atkiya110 committed on
Commit
14e5ccc
Β·
verified Β·
1 Parent(s): ace5292

Update llm_generator.py

Browse files
Files changed (1) hide show
  1. llm_generator.py +297 -31
llm_generator.py CHANGED
@@ -1,58 +1,324 @@
1
- # llm_generator.py - FIXED
2
-
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
- import logging
6
-
7
- logger = logging.getLogger(__name__)
8
 
9
 
10
  class LLMGenerator:
11
- def __init__(self, model_name: str = "microsoft/Phi-3-mini-4k-instruct"):
12
- logger.info(f"πŸ€– Loading {model_name}...")
13
- logger.info(f"πŸ“Ÿ Device: cpu")
14
- logger.info(f"⏳ Loading (this takes ~30 seconds)...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Load tokenizer
17
- logger.info("πŸ“¦ [1/2] Loading tokenizer...")
18
  self.tokenizer = AutoTokenizer.from_pretrained(
19
  model_name,
20
- trust_remote_code=True
 
21
  )
22
 
23
- # Load model with FIXED parameters
24
- logger.info("πŸ“¦ [2/2] Loading model weights...")
 
 
 
 
25
  self.model = AutoModelForCausalLM.from_pretrained(
26
  model_name,
27
- torch_dtype=torch.float32, # Use float32 for CPU
28
  trust_remote_code=True,
29
- low_cpu_mem_usage=False, # βœ… CHANGED: Set to False for CPU
30
- # device_map="cpu", # βœ… REMOVED: Don't use device_map on CPU
31
- attn_implementation="eager" # βœ… ADDED: Fix flash-attention warning
32
  )
33
 
34
- self.model.eval() # Set to evaluation mode
 
 
35
 
36
- logger.info("βœ… Model loaded successfully!")
 
 
37
 
38
- def generate(self, prompt: str, max_length: int = 512) -> str:
39
- """Generate text from prompt"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- inputs = self.tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- with torch.no_grad():
 
 
44
  outputs = self.model.generate(
45
  **inputs,
46
- max_length=max_length,
47
- num_return_sequences=1,
48
  temperature=0.7,
49
  do_sample=True,
50
- pad_token_id=self.tokenizer.eos_token_id
 
 
 
 
 
 
51
  )
 
52
 
53
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
54
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  except Exception as e:
57
- logger.error(f"Generation error: {e}")
58
- return "Error generating response"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
2
  import torch
3
+ import time
 
 
4
 
5
 
6
class LLMGenerator:
    """TinyLlama-based answer generator optimized for CPU inference speed."""

    def __init__(
        self,
        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device: str = "cpu",
        max_new_tokens: int = 150,  # reduced from 250 (faster)
        use_cache: bool = True  # enable KV cache
    ):
        """
        Initialize TinyLlama model (optimized for speed)

        Args:
            model_name: HuggingFace model name
            device: 'cpu' or 'cuda'
            max_new_tokens: Max tokens to generate (lower = faster)
            use_cache: Use key-value caching (faster generation)
        """
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.use_cache = use_cache

        print(f" πŸ€– Loading {model_name}...")
        print(f" πŸ“Ÿ Device: {device}")
        print(f" ⏳ Loading (this takes ~30 seconds)...")

        start_time = time.time()

        # Load tokenizer
        print(f" πŸ“¦ [1/2] Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=True  # fast (Rust-backed) tokenizer
        )

        # Decoder-only chat models often ship without a pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with optimizations
        print(f" πŸ“¦ [2/2] Loading model weights...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,  # CPU requires float32
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_cache=use_cache  # enable KV cache
        )

        # Move to device
        self.model = self.model.to(device)
        self.model.eval()  # evaluation mode (no gradients)

        load_time = time.time() - start_time
        print(f" βœ… TinyLlama loaded in {load_time:.1f}s!")
        print(f" ⚑ Max tokens: {max_new_tokens} (lower = faster)")

    def generate_answer(
        self,
        query: str,
        context: str,
        conversation_history: str = ""
    ) -> str:
        """
        Generate answer (optimized for speed)

        Args:
            query: User's question
            context: Retrieved context (will be truncated if too long)
            conversation_history: Previous turns (optional); now actually
                included in the prompt (it was previously accepted but ignored)

        Returns:
            Generated answer, or a formatted-context fallback on error
        """
        start_time = time.time()

        try:
            # Truncate context aggressively (faster tokenization)
            context = self._truncate_context(context, max_chars=1500)

            # Build prompt (includes conversation_history when provided)
            prompt = self._build_prompt(query, context, conversation_history)

            # Tokenize (faster with truncation)
            t1 = time.time()
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1500,  # reduced from 2000
                padding=False,
                return_attention_mask=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            tokenize_time = time.time() - t1

            # Generate (optimized settings)
            t2 = time.time()
            with torch.no_grad():  # no gradients = faster
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,  # configurable
                    min_new_tokens=20,  # ensure minimum response
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,  # top-k sampling (faster)
                    repetition_penalty=1.15,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=self.use_cache,  # use KV cache
                    # FIX(comment): with do_sample=True this is single-beam
                    # *sampling*, not greedy decoding as previously claimed.
                    num_beams=1
                )
            generate_time = time.time() - t2

            # Decode
            t3 = time.time()
            full_response = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            decode_time = time.time() - t3

            # Extract answer
            answer = self._extract_answer(full_response, prompt)

            # Performance stats
            total_time = time.time() - start_time
            print(f" ⏱️ Generation timing:")
            print(f" β€’ Tokenize: {tokenize_time:.3f}s")
            print(f" β€’ Generate: {generate_time:.3f}s")
            print(f" β€’ Decode: {decode_time:.3f}s")
            print(f" β€’ Total: {total_time:.3f}s")

            return answer

        except Exception as e:
            print(f" ❌ Generation error: {str(e)}")
            return self._fallback_answer(context)

    def _truncate_context(self, context: str, max_chars: int = 1500) -> str:
        """
        Intelligently truncate context to speed up processing.

        Cuts at the last sentence boundary when one falls in the final 30%
        of the window; otherwise hard-truncates and appends an ellipsis.
        """
        if len(context) <= max_chars:
            return context

        # Try to truncate at sentence boundary
        truncated = context[:max_chars]
        last_period = truncated.rfind('.')

        if last_period > max_chars * 0.7:  # keep at least 70% of the window
            return truncated[:last_period + 1]
        else:
            return truncated + "..."

    def _build_prompt(self, query: str, context: str, history: str) -> str:
        """Build optimized prompt (shorter = faster).

        FIX: *history* was accepted but silently dropped; it is now placed
        before the context. An empty history produces the exact old prompt,
        so existing callers are unaffected.
        """
        # Short system message (fewer tokens)
        system_msg = "You are an EWU admissions assistant. Answer based only on the context provided. Be concise."

        # Prepend prior turns only when present
        history_block = f"{history}\n\n" if history else ""

        prompt = f"""<|system|>
{system_msg}</s>
<|user|>
{history_block}Context: {context}

Question: {query}</s>
<|assistant|>
"""
        return prompt

    def _extract_answer(self, full_response: str, prompt: str) -> str:
        """Extract clean answer text from the model's full decoded output."""

        # Find assistant response
        if "<|assistant|>" in full_response:
            parts = full_response.split("<|assistant|>")
            answer = parts[-1] if len(parts) > 1 else full_response
        else:
            # Remove prompt
            answer = full_response.replace(prompt, "").strip()

        # Clean special tokens
        for token in ["</s>", "<|system|>", "<|user|>", "<|assistant|>", "<s>"]:
            answer = answer.replace(token, "")

        # Clean extra whitespace
        answer = " ".join(answer.split())

        # Limit length (avoid rambling)
        if len(answer) > 500:
            head = answer[:500]
            last_period = head.rfind('.')
            # FIX: old code appended '.' even when the slice contained no
            # period, yielding a stray period mid-word; use ellipsis instead.
            answer = head[:last_period + 1] if last_period != -1 else head + "..."

        return answer.strip() if answer.strip() else self._fallback_answer("")

    def _fallback_answer(self, context: str) -> str:
        """
        Fallback when generation fails
        Return formatted context instead
        """
        if not context:
            return "I apologize, but I couldn't find relevant information to answer your question."

        # Return first few lines of context
        lines = [line.strip() for line in context.split('\n') if line.strip()]
        return "\n".join(lines[:5]) + "\n\nπŸ“ž For more details: +880-2-9882308"
215
+
216
+
217
+ # ============================================================================
218
+ # EVEN FASTER: Ultra-lightweight alternative
219
+ # ============================================================================
220
+
221
class FastLLMGenerator:
    """
    Ultra-fast generator with DistilGPT-2 (10x smaller model)
    Use this if TinyLlama is still too slow
    """

    def __init__(self, model_name: str = "distilgpt2", device: str = "cpu"):
        print(f" ⚑ Loading {model_name} (ultra-fast)...")

        # Tokenizer: fast variant, EOS doubles as the pad token.
        tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        tok.pad_token = tok.eos_token
        self.tokenizer = tok

        # Model: float32 weights, moved to the target device, inference mode.
        net = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32
        )
        self.model = net.to(device)
        self.model.eval()
        self.device = device

        print(f" βœ… Loaded! (82M params, 10x faster than TinyLlama)")

    def generate_answer(self, query: str, context: str, **kwargs) -> str:
        """Generate with ultra-fast model"""

        # Very simple prompt
        prompt = f"Context: {context[:800]}\n\nQ: {query}\nA:"

        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1000
        ).to(self.device)

        # Sample a short continuation without tracking gradients.
        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                max_new_tokens=80,  # Very short
                temperature=0.8,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        stripped = decoded.replace(prompt, "").strip()
        return stripped if stripped else "Based on the information provided."
270
+
271
+
272
+ # ============================================================================
273
+ # TEST
274
+ # ============================================================================
275
+
276
if __name__ == "__main__":
    # Smoke-test both generators end to end (downloads models on first run).
    line = "=" * 70

    print(line)
    print("Testing Optimized LLM Generators")
    print(line)

    test_context = """Program: B.Sc. in Computer Science Engineering (CSE)
Total Tuition Fee: 634,500 BDT
Total Credits: 141
Fee Per Credit: 4,500 BDT
Application Deadline: August 25, 2025
Admission Test Date: August 30, 2025"""

    test_query = "How much does the CSE program cost?"

    # Test 1: Optimized TinyLlama
    print("\n" + line)
    print("TEST 1: Optimized TinyLlama")
    print(line)

    try:
        tiny = LLMGenerator(
            model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            device="cpu",
            max_new_tokens=100  # Short responses
        )
        print(f"\nβœ… Answer: {tiny.generate_answer(test_query, test_context)}\n")
    except Exception as e:
        print(f"❌ Error: {e}")

    # Test 2: Ultra-fast DistilGPT-2
    print("\n" + line)
    print("TEST 2: Ultra-Fast DistilGPT-2")
    print(line)

    try:
        fast = FastLLMGenerator(model_name="distilgpt2", device="cpu")
        print(f"\nβœ… Answer: {fast.generate_answer(test_query, test_context)}\n")
    except Exception as e:
        print(f"❌ Error: {e}")

    print(line)
    print("βœ… All tests completed!")
    print(line)