# project2/src/multimodal_model.py
"""Gemma 3 1B wrapper: 4-bit quantized loading with a float32 CPU fallback, plus
simple generation, summarization, and question-answering helpers."""
import torch
import sys
import logging
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

class Gemma3Model:
    """Thin wrapper around a Gemma 3 1B checkpoint for CPU-friendly text generation."""

    def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
        self.device = device
        self.model_name = model_name
        logger.info(f"→ Loading {model_name}...")
        print(f"→ Loading {model_name}...", flush=True)
        try:
            from transformers import BitsAndBytesConfig
            # Gemma 3 MUST use float32 for compute (not float16!)
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float32,  # ← CRITICAL for Gemma 3
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4"
            )
            logger.debug("Loading model with 4-bit quantization...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float32  # ← Explicit float32
            )
            logger.info("✓ 4-bit Gemma 3 model loaded successfully")
            print("✓ 4-bit Gemma 3 model loaded successfully", flush=True)
        except Exception as e:
            logger.warning(f"Quantization failed ({e}), falling back to float32...")
            print("Quantization failed, using float32...", flush=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,  # ← Never use float16 with Gemma 3!
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            logger.info("✓ Float32 Gemma 3 model loaded")
        logger.debug("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()
        logger.info(f"✓ Model ready with dtype {self.model.dtype}")
        print(f"✓ Model ready with dtype {self.model.dtype}", flush=True)

    def generate_response(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8) -> str:
        """Generate with Gemma 3 1B via sampling (very slow on CPU - expected!)."""
        logger.info("Starting generation - Gemma 3 1B on CPU takes 1-3 min for 200 tokens")
        print("→ Generating response...", flush=True)
        print(" ℹ️ Gemma 3 1B CPU inference: ~1-2 tokens/second", flush=True)
        print(f" ℹ️ Estimated time: {max_new_tokens // 2}-{max_new_tokens} seconds", flush=True)
        # Clamp temperature for Gemma 3 stability
        temperature = max(0.5, min(temperature, 1.5))
        start_time = time.time()
        try:
            logger.debug(f"Tokenizing: {prompt[:50]}...")
            # Move inputs to wherever the model actually landed (device_map may differ from self.device)
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            input_len = inputs["input_ids"].shape[1]
            logger.debug(f"Input: {input_len} tokens")
            print(f" → Input: {input_len} tokens", flush=True)
            logger.debug("Starting model.generate()...")
            print(" ⏳ Generating (this WILL take time on CPU)...", flush=True)
            with torch.no_grad():
                # ALWAYS set max_new_tokens!
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,  # ← CRITICAL
                    temperature=temperature,
                    top_p=0.95,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    remove_invalid_values=True,
                    repetition_penalty=1.2
                )
            elapsed = time.time() - start_time
            tokens_generated = outputs.shape[1] - input_len
            rate = tokens_generated / elapsed if elapsed > 0 else 0
            logger.debug(f"Generation took {elapsed:.2f}s ({rate:.2f} tokens/sec)")
            print(f" ✓ Generated {tokens_generated} tokens in {elapsed:.1f}s ({rate:.2f} tok/s)", flush=True)
            # Decode only the newly generated tokens, not the echoed prompt
            response = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
            logger.info("✓ Generation successful")
            return response
        except Exception as e:
            logger.error(f"Generation failed: {str(e)}", exc_info=True)
            raise

    def generate_response_greedy(self, prompt: str, max_new_tokens: int = 200) -> str:
        """Faster greedy decoding (deterministic, no sampling)."""
        logger.info("Using greedy decoding (faster than sampling)")
        print("→ Generating (greedy mode - faster)...", flush=True)
        start_time = time.time()
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            input_len = inputs["input_ids"].shape[1]
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # Greedy - much faster
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            elapsed = time.time() - start_time
            logger.debug(f"Greedy generation in {elapsed:.2f}s")
            print(f" ✓ Generated in {elapsed:.1f}s", flush=True)
            # Return only the newly generated tokens
            response = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
            return response
        except Exception as e:
            logger.error(f"Greedy generation failed: {str(e)}", exc_info=True)
            raise

    def summarize_text(self, text: str, max_new_tokens: int = 150) -> str:
        """Summarize text in Russian (uses greedy decoding - faster)."""
        logger.info(f"Summarizing {len(text)} chars")
        prompt = f"Summarize in Russian:\n\n{text[:1500]}\n\nSummary:"
        return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)

    def answer_question(self, question: str, context: str, max_new_tokens: int = 250) -> str:
        """Answer a question based on the given context (uses greedy decoding - faster)."""
        logger.info(f"Answering: {question[:50]}...")
        context = context[:2000]  # Limit context length
        prompt = f"""Based on the context, answer in Russian.
Context:
{context}
Question: {question}
Answer:"""
        return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
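

# Minimal usage sketch, assuming the default "unsloth/gemma-3-1b-pt" checkpoint is
# available (downloaded or cached) and CPU-only inference; the example texts below
# are illustrative placeholders, and runtimes match the estimates printed by
# generate_response().
if __name__ == "__main__":
    model = Gemma3Model()  # defaults to "unsloth/gemma-3-1b-pt" on CPU

    # Deterministic summary via the faster greedy path
    article = "Example article text to summarize..."  # placeholder input
    print(model.summarize_text(article, max_new_tokens=100))

    # Sampled free-form generation (slower, non-deterministic)
    print(model.generate_response("Explain what 4-bit quantization does.", max_new_tokens=50))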