dnj0 committed on
Commit 3d62d9e · verified · 1 Parent(s): 21f3961

Update src/multimodal_model.py

Files changed (1)
  1. src/multimodal_model.py +171 -81
src/multimodal_model.py CHANGED
@@ -1,81 +1,171 @@
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoImageProcessor
- from typing import Optional, Tuple
- import numpy as np
- from PIL import Image
-
- class GemmaVisionModel:
-     def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
-         self.device = device
-         self.model_name = model_name
-
-         print(f"→ Loading {model_name}...")
-
-         # Load with 4-bit quantization for memory efficiency
-         try:
-             from transformers import BitsAndBytesConfig
-
-             quantization_config = BitsAndBytesConfig(
-                 load_in_4bit=True,
-                 bnb_4bit_compute_dtype=torch.float32,
-                 bnb_4bit_use_double_quant=False,
-                 bnb_4bit_quant_type="nf4"
-             )
-
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 quantization_config=quantization_config,
-                 device_map="auto",
-                 trust_remote_code=True
-             )
-         except:
-             # Fallback without quantization
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 torch_dtype=torch.float32,
-                 device_map="cpu",
-                 trust_remote_code=True
-             )
-
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-         self.model.eval()
-
-         print(f"✓ Model loaded successfully")
-
-     def generate_response(self, prompt: str, max_length: int = 512, temperature: float = 0.7) -> str:
-         """Generate text response"""
-         with torch.no_grad():
-             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-
-             outputs = self.model.generate(
-                 **inputs,
-                 temperature=0.8, # ← Keep in 0.5-1.5 range
-                 do_sample=True, # ← Use sampling for variety
-                 top_p=0.95, # ← Nucleus sampling
-                 top_k=50, # ← Top-K sampling
-                 remove_invalid_values=True, # ← Remove NaN/Inf
-                 repetition_penalty=1.2, # ← Avoid repetition
-                 pad_token_id=self.tokenizer.eos_token_id,
-                 eos_token_id=self.tokenizer.eos_token_id
-             )
-
-             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         return response
-
-     def summarize_text(self, text: str, max_length: int = 256) -> str:
-         """Summarize provided text"""
-         prompt = f"Summarize the following text in Russian:\n\n{text}\n\nSummary:"
-         return self.generate_response(prompt, max_length=max_length)
-
-     def answer_question(self, question: str, context: str) -> str:
-         """Answer question based on context"""
-         prompt = f"""Based on the following context, answer the question in Russian.
-
- Context:
- {context}
-
- Question: {question}
-
- Answer:"""
-         return self.generate_response(prompt, max_length=512)
+ import torch
+ import sys
+ import logging
+ import time
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from typing import Optional
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     stream=sys.stdout
+ )
+ logger = logging.getLogger(__name__)
+
+ class Gemma3Model:
+     def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
+         self.device = device
+         self.model_name = model_name
+
+         logger.info(f"→ Loading {model_name}...")
+         print(f"→ Loading {model_name}...", flush=True)
+
+         try:
+             from transformers import BitsAndBytesConfig
+
+             # Gemma 3 MUST use float32 for compute (not float16!)
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float32, # ← CRITICAL for Gemma 3
+                 bnb_4bit_use_double_quant=False,
+                 bnb_4bit_quant_type="nf4"
+             )
+
+             logger.debug("Loading model with 4-bit quantization...")
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 quantization_config=quantization_config,
+                 device_map="auto",
+                 trust_remote_code=True,
+                 torch_dtype=torch.float32 # ← Explicit float32
+             )
+             logger.info("✓ 4-bit Gemma 3 model loaded successfully")
+             print("✓ 4-bit Gemma 3 model loaded successfully", flush=True)
+
+         except Exception as e:
+             logger.warning(f"Quantization failed ({e}), falling back to float32...")
+             print(f"Quantization failed, using float32...", flush=True)
+
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.float32, # ← Never use float16 with Gemma 3!
+                 device_map="cpu",
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True
+             )
+             logger.info("✓ Float32 Gemma 3 model loaded")
+
+         logger.debug("Loading tokenizer...")
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         self.model.eval()
+         logger.info(f"✓ Model ready with dtype {self.model.dtype}")
+         print(f"✓ Model ready with dtype {self.model.dtype}", flush=True)
+
+     def generate_response(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8) -> str:
+         """Generate with Gemma 3 1B (very slow on CPU - expected!)"""
+
+         logger.info(f"Starting generation - Gemma 3 1B on CPU takes 1-3 min for 200 tokens")
+         print(f"→ Generating response...", flush=True)
+         print(f" ℹ️ Gemma 3 1B CPU inference: ~1-2 tokens/second", flush=True)
+         print(f" ℹ️ Estimated time: {int(max_new_tokens * 0.75)}-{int(max_new_tokens * 1.5)} seconds", flush=True)
+
+         # Clamp temperature for Gemma 3 stability
+         temperature = max(0.5, min(temperature, 1.5))
+
+         start_time = time.time()
+
+         try:
+             logger.debug(f"Tokenizing: {prompt[:50]}...")
+             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+             input_len = inputs["input_ids"].shape[1]
+             logger.debug(f"Input: {input_len} tokens")
+             print(f" → Input: {input_len} tokens", flush=True)
+
+             logger.debug("Starting model.generate()...")
+             print(f" ⏳ Generating (this WILL take time on CPU)...", flush=True)
+
+             with torch.no_grad():
+                 # ALWAYS set max_new_tokens!
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_new_tokens, # ← CRITICAL
+                     temperature=temperature,
+                     top_p=0.95,
+                     top_k=50,
+                     do_sample=True,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     remove_invalid_values=True,
+                     repetition_penalty=1.2
+                 )
+
+             elapsed = time.time() - start_time
+             tokens_generated = outputs.shape[1] - input_len
+             rate = tokens_generated / elapsed if elapsed > 0 else 0
+
+             logger.debug(f"Generation took {elapsed:.2f}s ({rate:.2f} tokens/sec)")
+             print(f" ✓ Generated {tokens_generated} tokens in {elapsed:.1f}s ({rate:.2f} tok/s)", flush=True)
+
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             logger.info("✓ Generation successful")
+             return response
+
+         except Exception as e:
+             logger.error(f"Generation failed: {str(e)}", exc_info=True)
+             raise
+
+     def generate_response_greedy(self, prompt: str, max_new_tokens: int = 200) -> str:
+         """Faster greedy decoding (deterministic, no sampling)"""
+
+         logger.info("Using greedy decoding (faster than sampling)")
+         print(f"→ Generating (greedy mode - faster)...", flush=True)
+
+         start_time = time.time()
+
+         try:
+             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_new_tokens=max_new_tokens,
+                     do_sample=False, # Greedy - much faster
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id
+                 )
+
+             elapsed = time.time() - start_time
+             logger.debug(f"Greedy generation in {elapsed:.2f}s")
+             print(f" ✓ Generated in {elapsed:.1f}s", flush=True)
+
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             return response
+
+         except Exception as e:
+             logger.error(f"Greedy generation failed: {str(e)}", exc_info=True)
+             raise
+
+     def summarize_text(self, text: str, max_new_tokens: int = 150) -> str:
+         """Summarize (use greedy - faster)"""
+         logger.info(f"Summarizing {len(text)} chars")
+         prompt = f"Summarize in Russian:\n\n{text[:1500]}\n\nSummary:"
+         return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
+
+     def answer_question(self, question: str, context: str, max_new_tokens: int = 250) -> str:
+         """Answer based on context (use greedy - faster)"""
+         logger.info(f"Answering: {question[:50]}...")
+
+         context = context[:2000] # Limit context
+         prompt = f"""Based on context, answer in Russian.
+
+ Context:
+ {context}
+
+ Question: {question}
+
+ Answer:"""
+         return self.generate_response_greedy(prompt, max_new_tokens=max_new_tokens)
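
For reference, a minimal usage sketch of the updated class (not part of the commit; the import path, prompts, and variable names are illustrative, and it assumes src/ is on the Python path and the unsloth/gemma-3-1b-pt weights are reachable):

    # Hypothetical caller for the committed Gemma3Model class (illustrative only).
    from multimodal_model import Gemma3Model  # assumes src/ is on sys.path

    model = Gemma3Model(model_name="unsloth/gemma-3-1b-pt", device="cpu")

    # Sampled generation: slow on CPU, as the logging in generate_response() warns.
    reply = model.generate_response("Describe the weather today.", max_new_tokens=50)

    # Greedy helpers: faster and deterministic, used by summarize_text/answer_question.
    summary = model.summarize_text("Long source text to condense ...")
    answer = model.answer_question(question="What is the text about?", context="Some context ...")
    print(reply, summary, answer, sep="\n")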