import torch
import sys
import logging
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)


class Gemma3Model:
    def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
        self.device = device
        self.model_name = model_name

        logger.info(f"→ Loading {model_name}...")
        print(f"→ Loading {model_name}...", flush=True)

        try:
            from transformers import BitsAndBytesConfig

            # Gemma 3 MUST use float32 for compute (not float16!)
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float32,  # ← CRITICAL for Gemma 3
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4"
            )

            logger.debug("Loading model with 4-bit quantization...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float32  # ← Explicit float32
            )
            logger.info("✓ 4-bit Gemma 3 model loaded successfully")
            print("✓ 4-bit Gemma 3 model loaded successfully", flush=True)
        except Exception as e:
            logger.warning(f"Quantization failed ({e}), falling back to float32...")
            print("Quantization failed, using float32...", flush=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # ← Never use float16 with Gemma 3!
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            logger.info("✓ Float32 Gemma 3 model loaded")

        logger.debug("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()
        logger.info(f"✓ Model ready with dtype {self.model.dtype}")
        print(f"✓ Model ready with dtype {self.model.dtype}", flush=True)

    def generate_response(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8) -> str:
        """Generate with Gemma 3 1B (very slow on CPU - expected!)"""
        logger.info("Starting generation - Gemma 3 1B on CPU takes 1-3 min for 200 tokens")
        print("→ Generating response...", flush=True)
        print(" ℹ️ Gemma 3 1B CPU inference: ~1-2 tokens/second", flush=True)
        print(f" ℹ️ Estimated time: {int(max_new_tokens * 0.75)}-{int(max_new_tokens * 1.5)} seconds", flush=True)

        # Clamp temperature for Gemma 3 stability
        temperature = max(0.5, min(temperature, 1.5))

        start_time = time.time()
        try:
            logger.debug(f"Tokenizing: {prompt[:50]}...")
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            input_len = inputs["input_ids"].shape[1]
            logger.debug(f"Input: {input_len} tokens")
            print(f" → Input: {input_len} tokens", flush=True)

            logger.debug("Starting model.generate()...")
            print(" ⏳ Generating (this WILL take time on CPU)...", flush=True)

            with torch.no_grad():
                # ALWAYS set max_new_tokens!
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,  # ← CRITICAL
                    temperature=temperature,
                    top_p=0.95,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    remove_invalid_values=True,
                    repetition_penalty=1.2
                )

            elapsed = time.time() - start_time
            tokens_generated = outputs.shape[1] - input_len
            rate = tokens_generated / elapsed if elapsed > 0 else 0
            logger.debug(f"Generation took {elapsed:.2f}s ({rate:.2f} tokens/sec)")
            print(f" ✓ Generated {tokens_generated} tokens in {elapsed:.1f}s ({rate:.2f} tok/s)", flush=True)

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("✓ Generation successful")
            return response
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}", exc_info=True)
            raise

    def generate_response_greedy(self, prompt: str, max_new_tokens: int = 200) -> str:
        """Faster greedy decoding (deterministic, no sampling)"""
        logger.info("Using greedy decoding (faster than sampling)")
        print("→ Generating (greedy mode - faster)...", flush=True)

        start_time = time.time()
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # Greedy - much faster
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            elapsed = time.time() - start_time
            logger.debug(f"Greedy generation in {elapsed:.2f}s")
            print(f" ✓ Generated in {elapsed:.1f}s", flush=True)

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            logger.error(f"Greedy generation failed: {str(e)}", exc_info=True)
            raise

    def summarize_text(self, text: str, max_new_tokens: int = 150) -> str:
        """Summarize (use greedy - faster)"""
        logger.info(f"Summarizing {len(text)} chars")
        prompt = f"Summarize in Russian:\n\n{text[:1500]}\n\nSummary:"
        return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)

    def answer_question(self, question: str, context: str, max_new_tokens: int = 250) -> str:
        """Answer based on context (use greedy - faster)"""
        logger.info(f"Answering: {question[:50]}...")
        context = context[:2000]  # Limit context
        prompt = f"""Based on context, answer in Russian.

Context: {context}

Question: {question}

Answer:"""
        return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
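

# Minimal usage sketch (illustrative only, not part of the class above).
# Assumptions: the "unsloth/gemma-3-1b-pt" weights can be downloaded from the
# Hugging Face Hub, and the question/context strings below are hypothetical
# placeholders supplied by the caller.
if __name__ == "__main__":
    model = Gemma3Model()  # defaults: "unsloth/gemma-3-1b-pt" on CPU

    # Greedy decoding (used by answer_question/summarize_text) is the faster
    # path on CPU; generate_response() samples and trades speed for variety.
    answer = model.answer_question(
        question="What is this text about?",
        context="Gemma 3 is a family of open-weight language models.",
    )
    print(answer)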