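"""Wrapper around the Gemma 3 1B model for CPU-friendly inference.

Loads unsloth/gemma-3-1b-pt with 4-bit quantization when bitsandbytes is
available, falls back to plain float32 on CPU otherwise, and exposes
sampling, greedy, summarization, and question-answering helpers.
"""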
import torch
import sys
import logging
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
stream=sys.stdout
)
logger = logging.getLogger(__name__)
class Gemma3Model:
def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
self.device = device
self.model_name = model_name
logger.info(f"β Loading {model_name}...")
print(f"β Loading {model_name}...", flush=True)
try:
from transformers import BitsAndBytesConfig
# Gemma 3 MUST use float32 for compute (not float16!)
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float32,  # CRITICAL for Gemma 3
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4"
)
logger.debug("Loading model with 4-bit quantization...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True,
                torch_dtype=torch.float32  # explicit float32 for non-quantized modules
)
logger.info("β 4-bit Gemma 3 model loaded successfully")
print("β 4-bit Gemma 3 model loaded successfully", flush=True)
except Exception as e:
logger.warning(f"Quantization failed ({e}), falling back to float32...")
print(f"Quantization failed, using float32...", flush=True)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
                torch_dtype=torch.float32,  # never use float16 with Gemma 3
device_map="cpu",
trust_remote_code=True,
low_cpu_mem_usage=True
)
logger.info("β Float32 Gemma 3 model loaded")
logger.debug("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
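        # Some checkpoints ship without a pad token; reuse EOS so padding works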
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.eval()
logger.info(f"β Model ready with dtype {self.model.dtype}")
print(f"β Model ready with dtype {self.model.dtype}", flush=True)
def generate_response(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8) -> str:
"""Generate with Gemma 3 1B (very slow on CPU - expected!)"""
logger.info(f"Starting generation - Gemma 3 1B on CPU takes 1-3 min for 200 tokens")
print(f"β Generating response...", flush=True)
print(f" βΉοΈ Gemma 3 1B CPU inference: ~1-2 tokens/second", flush=True)
print(f" βΉοΈ Estimated time: {int(max_new_tokens * 0.75)}-{int(max_new_tokens * 1.5)} seconds", flush=True)
# Clamp temperature for Gemma 3 stability
temperature = max(0.5, min(temperature, 1.5))
start_time = time.time()
try:
logger.debug(f"Tokenizing: {prompt[:50]}...")
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
input_len = inputs["input_ids"].shape[1]
logger.debug(f"Input: {input_len} tokens")
print(f" β Input: {input_len} tokens", flush=True)
logger.debug("Starting model.generate()...")
print(f" β³ Generating (this WILL take time on CPU)...", flush=True)
with torch.no_grad():
# ALWAYS set max_new_tokens!
outputs = self.model.generate(
**inputs,
                    max_new_tokens=max_new_tokens,  # CRITICAL: always cap generation length
temperature=temperature,
top_p=0.95,
top_k=50,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
remove_invalid_values=True,
repetition_penalty=1.2
)
elapsed = time.time() - start_time
tokens_generated = outputs.shape[1] - input_len
rate = tokens_generated / elapsed if elapsed > 0 else 0
logger.debug(f"Generation took {elapsed:.2f}s ({rate:.2f} tokens/sec)")
print(f" β Generated {tokens_generated} tokens in {elapsed:.1f}s ({rate:.2f} tok/s)", flush=True)
            # Decode only the newly generated tokens, not the echoed prompt
            response = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
            logger.info("Generation successful")
            return response
except Exception as e:
logger.error(f"Generation failed: {str(e)}", exc_info=True)
raise
def generate_response_greedy(self, prompt: str, max_new_tokens: int = 200) -> str:
"""Faster greedy decoding (deterministic, no sampling)"""
logger.info("Using greedy decoding (faster than sampling)")
print(f"β Generating (greedy mode - faster)...", flush=True)
start_time = time.time()
try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            input_len = inputs["input_ids"].shape[1]
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False, # Greedy - much faster
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
elapsed = time.time() - start_time
logger.debug(f"Greedy generation in {elapsed:.2f}s")
print(f" β Generated in {elapsed:.1f}s", flush=True)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
except Exception as e:
logger.error(f"Greedy generation failed: {str(e)}", exc_info=True)
raise
def summarize_text(self, text: str, max_new_tokens: int = 150) -> str:
"""Summarize (use greedy - faster)"""
logger.info(f"Summarizing {len(text)} chars")
prompt = f"Summarize in Russian:\n\n{text[:1500]}\n\nSummary:"
return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
def answer_question(self, question: str, context: str, max_new_tokens: int = 250) -> str:
"""Answer based on context (use greedy - faster)"""
logger.info(f"Answering: {question[:50]}...")
context = context[:2000] # Limit context
prompt = f"""Based on context, answer in Russian.
Context:
{context}
Question: {question}
Answer:"""
return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
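if __name__ == "__main__":
    # Minimal usage sketch; the question and context below are illustrative
    # (hypothetical) prompt text, not part of any fixed test suite.
    # Note: instantiating Gemma3Model downloads the checkpoint on first run.
    model = Gemma3Model()
    answer = model.answer_question(
        question="What is Gemma 3?",
        context="Gemma 3 is a family of open-weight language models; the 1B "
                "variant is small enough for CPU inference.",
    )
    print(answer)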