"""Gemma 3 1B wrapper: 4-bit quantized loading with a float32 CPU fallback,
plus sampling and greedy text-generation helpers."""
import torch
import sys
import logging
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

class Gemma3Model:
    def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
        self.device = device
        self.model_name = model_name
        
        logger.info(f"β†’ Loading {model_name}...")
        print(f"β†’ Loading {model_name}...", flush=True)
        
        try:
            from transformers import BitsAndBytesConfig
            
            # Gemma 3 is numerically unstable in float16; keep compute in float32.
            # Note: bitsandbytes 4-bit loading normally needs a CUDA GPU, so on a
            # CPU-only machine this block is expected to fail and trigger the
            # float32 fallback below.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float32,  # critical for Gemma 3 stability
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4"
            )
            
            logger.debug("Loading model with 4-bit quantization...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float32  # ← Explicit float32
            )
            logger.info("βœ“ 4-bit Gemma 3 model loaded successfully")
            print("βœ“ 4-bit Gemma 3 model loaded successfully", flush=True)
        
        except Exception as e:
            logger.warning(f"Quantization failed ({e}), falling back to float32...")
            print(f"Quantization failed, using float32...", flush=True)
            
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # ← Never use float16 with Gemma 3!
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            logger.info("βœ“ Float32 Gemma 3 model loaded")
        
        logger.debug("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        self.model.eval()
        logger.info(f"βœ“ Model ready with dtype {self.model.dtype}")
        print(f"βœ“ Model ready with dtype {self.model.dtype}", flush=True)
    
    def generate_response(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8) -> str:
        """Sampled generation with Gemma 3 1B (very slow on CPU - expected)."""
        
        logger.info("Starting generation - Gemma 3 1B on CPU runs at roughly 1-2 tokens/sec")
        print("→ Generating response...", flush=True)
        print("  ℹ️ Gemma 3 1B CPU inference: ~1-2 tokens/second", flush=True)
        print(f"  ℹ️ Estimated time: {max_new_tokens // 2}-{max_new_tokens} seconds", flush=True)
        
        # Clamp temperature for Gemma 3 stability
        temperature = max(0.5, min(temperature, 1.5))
        
        start_time = time.time()
        
        try:
            logger.debug(f"Tokenizing: {prompt[:50]}...")
            # Use the model's actual device: with device_map="auto" the model may
            # not live on self.device.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            input_len = inputs["input_ids"].shape[1]
            logger.debug(f"Input: {input_len} tokens")
            print(f"  β†’ Input: {input_len} tokens", flush=True)
            
            logger.debug("Starting model.generate()...")
            print(f"  ⏳ Generating (this WILL take time on CPU)...", flush=True)
            
            with torch.no_grad():
                # ALWAYS set max_new_tokens!
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,  # ← CRITICAL
                    temperature=temperature,
                    top_p=0.95,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    remove_invalid_values=True,
                    repetition_penalty=1.2
                )
            
            elapsed = time.time() - start_time
            tokens_generated = outputs.shape[1] - input_len
            rate = tokens_generated / elapsed if elapsed > 0 else 0
            
            logger.debug(f"Generation took {elapsed:.2f}s ({rate:.2f} tokens/sec)")
            print(f"  βœ“ Generated {tokens_generated} tokens in {elapsed:.1f}s ({rate:.2f} tok/s)", flush=True)
            
            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
            logger.info("βœ“ Generation successful")
            return response
        
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}", exc_info=True)
            raise
    
    def generate_response_greedy(self, prompt: str, max_new_tokens: int = 200) -> str:
        """Greedy decoding: deterministic, with a modest speed gain over sampling."""
        
        logger.info("Using greedy decoding (deterministic)")
        print("→ Generating (greedy mode)...", flush=True)
        
        start_time = time.time()
        
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            input_len = inputs["input_ids"].shape[1]
            
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # greedy decoding: deterministic, no sampling overhead
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            
            elapsed = time.time() - start_time
            logger.debug(f"Greedy generation in {elapsed:.2f}s")
            print(f"  βœ“ Generated in {elapsed:.1f}s", flush=True)
            
            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
            return response
        
        except Exception as e:
            logger.error(f"Greedy generation failed: {str(e)}", exc_info=True)
            raise
    
    def summarize_text(self, text: str, max_new_tokens: int = 150) -> str:
        """Summarize text (greedy decoding keeps latency down)."""
        logger.info(f"Summarizing {len(text)} chars")
        # Truncate input to keep prompt length (and CPU latency) manageable.
        prompt = f"Summarize in Russian:\n\n{text[:1500]}\n\nSummary:"
        return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
    
    def answer_question(self, question: str, context: str, max_new_tokens: int = 250) -> str:
        """Answer a question from the given context (greedy decoding for speed)."""
        logger.info(f"Answering: {question[:50]}...")
        
        context = context[:2000]  # Truncate context to keep prompt length manageable
        prompt = f"""Answer in Russian based on the context below.

Context:
{context}

Question: {question}

Answer:"""
        return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
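

# Usage sketch: a minimal, hypothetical demo of the wrapper above. The example
# prompts and token budgets are illustrative assumptions, not part of the class;
# loading falls back to float32 on CPU-only machines, so expect slow generation.
if __name__ == "__main__":
    gemma = Gemma3Model()  # 4-bit if bitsandbytes + CUDA are available, else float32

    # Sampled generation: slower, more varied output.
    print(gemma.generate_response("The capital of France is", max_new_tokens=20))

    # Greedy helpers: deterministic, used by summarize_text / answer_question.
    doc = "Gemma 3 is a family of open-weight language models."
    print(gemma.summarize_text(doc, max_new_tokens=60))
    print(gemma.answer_question("What is Gemma 3?", context=doc, max_new_tokens=60))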