harismlnaslm committed on
Commit
3a93207
·
1 Parent(s): e3e9bcd

Serverless-safe defaults: fall back from LLaMA/TinyLlama to DialoGPT; automatically fall back to distilgpt2 on a 404; robust prompt/cleanup for both families

Browse files
Files changed (1) hide show
  1. app_backup.py +57 -292
app_backup.py CHANGED
@@ -150,198 +150,21 @@ class TrainingManager:
150
  """Manage AI model training using the training scripts"""
151
 
152
  def __init__(self):
153
- self.training_status = {
154
- "is_training": False,
155
- "progress": 0,
156
- "status": "idle",
157
- "start_time": None,
158
- "end_time": None,
159
- "error": None,
160
- "logs": []
161
- }
162
- self.training_thread = None
163
-
164
- def start_training(self, model_name: str = "meta-llama/Llama-3.1-8B-Instruct", epochs: int = 3, batch_size: int = 4):
165
- """Start training in background thread"""
166
- if self.training_status["is_training"]:
167
- return {"error": "Training already in progress"}
168
-
169
- self.training_status = {
170
- "is_training": True,
171
- "progress": 0,
172
- "status": "starting",
173
- "start_time": datetime.now().isoformat(),
174
- "end_time": None,
175
- "error": None,
176
- "logs": []
177
- }
178
-
179
- # Start training in background thread
180
- self.training_thread = threading.Thread(
181
- target=self._run_training,
182
- args=(model_name, epochs, batch_size),
183
- daemon=True
184
- )
185
- self.training_thread.start()
186
-
187
- return {"message": "Training started", "status": "starting"}
188
-
189
- def _run_training(self, model_name: str, epochs: int, batch_size: int):
190
- """Run the actual training process"""
191
- try:
192
- self.training_status["status"] = "preparing"
193
- self.training_status["logs"].append("Preparing training environment...")
194
-
195
- # Check if training data exists
196
- data_path = "data/textilindo_training_data.jsonl"
197
- if not os.path.exists(data_path):
198
- raise Exception("Training data not found")
199
-
200
- self.training_status["status"] = "training"
201
- self.training_status["logs"].append("Starting model training...")
202
-
203
- # Create a simple training script for HF Spaces
204
- training_script = f"""
205
- import os
206
- import sys
207
- import json
208
- import logging
209
- from pathlib import Path
210
- from datetime import datetime
211
-
212
- # Add current directory to path
213
- sys.path.append('.')
214
-
215
- # Setup logging
216
- logging.basicConfig(level=logging.INFO)
217
- logger = logging.getLogger(__name__)
218
-
219
- def simple_training():
220
- \"\"\"Simple training simulation for HF Spaces with Llama support\"\"\"
221
- logger.info("Starting training process...")
222
- logger.info(f"Model: {model_name}")
223
- logger.info(f"Epochs: {epochs}")
224
- logger.info(f"Batch Size: {batch_size}")
225
-
226
- # Load training data
227
- data_path = "data/textilindo_training_data.jsonl"
228
- with open(data_path, 'r', encoding='utf-8') as f:
229
- data = [json.loads(line) for line in f if line.strip()]
230
-
231
- logger.info(f"Loaded {{len(data)}} training samples")
232
-
233
- # Model-specific training simulation
234
- if "llama" in model_name.lower():
235
- logger.info("Using Llama model - High quality training simulation")
236
- training_steps = len(data) * {epochs} * 2 # More steps for Llama
237
- else:
238
- logger.info("Using standard model - Basic training simulation")
239
- training_steps = len(data) * {epochs}
240
-
241
- # Simulate training progress
242
- for epoch in range({epochs}):
243
- logger.info(f"Epoch {{epoch + 1}}/{epochs}")
244
- for i, sample in enumerate(data):
245
- # Simulate training step
246
- progress = ((epoch * len(data) + i) / ({epochs} * len(data))) * 100
247
- logger.info(f"Training progress: {{progress:.1f}}% - Processing: {{sample.get('instruction', 'Unknown')[:50]}}...")
248
-
249
- # Update training status
250
- with open("training_status.json", "w") as f:
251
- json.dump({{
252
- "is_training": True,
253
- "progress": progress,
254
- "status": "training",
255
- "model": "{model_name}",
256
- "epoch": epoch + 1,
257
- "step": i + 1,
258
- "total_steps": len(data),
259
- "current_sample": sample.get('instruction', 'Unknown')[:50]
260
- }}, f)
261
-
262
- logger.info("Training completed successfully!")
263
- logger.info(f"Model {model_name} has been fine-tuned with Textilindo data")
264
-
265
- # Save final status
266
- with open("training_status.json", "w") as f:
267
- json.dump({{
268
- "is_training": False,
269
- "progress": 100,
270
- "status": "completed",
271
- "model": "{model_name}",
272
- "end_time": datetime.now().isoformat(),
273
- "message": f"Model {model_name} training completed successfully!"
274
- }}, f)
275
-
276
- if __name__ == "__main__":
277
- simple_training()
278
- """
279
-
280
- # Write training script
281
- with open("run_training.py", "w") as f:
282
- f.write(training_script)
283
-
284
- # Run training
285
- result = subprocess.run(
286
- ["python", "run_training.py"],
287
- capture_output=True,
288
- text=True,
289
- cwd="."
290
  )
291
-
292
- if result.returncode == 0:
293
- self.training_status["status"] = "completed"
294
- self.training_status["progress"] = 100
295
- self.training_status["logs"].append("Training completed successfully!")
296
- else:
297
- raise Exception(f"Training failed: {result.stderr}")
298
-
299
- except Exception as e:
300
- logger.error(f"Training error: {e}")
301
- self.training_status["status"] = "error"
302
- self.training_status["error"] = str(e)
303
- self.training_status["logs"].append(f"Error: {e}")
304
- finally:
305
- self.training_status["is_training"] = False
306
- self.training_status["end_time"] = datetime.now().isoformat()
307
-
308
- def get_training_status(self):
309
- """Get current training status"""
310
- # Try to read from file if available
311
- status_file = "training_status.json"
312
- if os.path.exists(status_file):
313
- try:
314
- with open(status_file, "r") as f:
315
- file_status = json.load(f)
316
- self.training_status.update(file_status)
317
- except:
318
- pass
319
-
320
- return self.training_status
321
-
322
- def stop_training(self):
323
- """Stop training if running"""
324
- if self.training_status["is_training"]:
325
- self.training_status["status"] = "stopped"
326
- self.training_status["is_training"] = False
327
- return {"message": "Training stopped"}
328
- return {"message": "No training in progress"}
329
-
330
- class TextilindoAI:
331
- """Textilindo AI Assistant using HuggingFace Inference API with Auto-Training"""
332
-
333
- def __init__(self):
334
- self.api_key = os.getenv('HUGGINGFAC_API_KEY_2')
335
- # Use available model with your API key
336
- self.model = os.getenv('DEFAULT_MODEL', 'meta-llama/Llama-3.2-1B-Instruct')
337
  self.system_prompt = self.load_system_prompt()
338
- self.data_loader = TrainingDataLoader()
339
-
340
- # Auto-training configuration
341
- self.auto_training_enabled = True
342
- self.training_interval = 300 # Train every 5 minutes
343
- self.last_training_time = 0
344
- self.trained_responses = {} # Cache for trained responses
345
 
346
  if not self.api_key:
347
  logger.warning("HUGGINGFAC_API_KEY_2 not found. Using mock responses.")
@@ -491,120 +314,62 @@ Minimum purchase is 1 roll (67-70 yards)."""
491
  return self.get_fallback_response(user_message)
492
 
493
  try:
494
- # Use appropriate conversation format
495
- if "llama" in self.model.lower():
496
- # Use proper chat format for Llama models
497
- prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_message}\n<|assistant|>\n"
498
- elif "dialogpt" in self.model.lower():
499
- prompt = f"User: {user_message}\nAssistant:"
500
- elif "gpt2" in self.model.lower():
501
- prompt = f"User: {user_message}\nAssistant:"
502
  else:
503
- # Fallback format for other models
504
- prompt = f"User: {user_message}\nAssistant:"
505
 
506
- logger.info(f"Using model: {self.model}")
507
- logger.info(f"API Key present: {bool(self.api_key)}")
508
-
509
- logger.info(f"Generating response for prompt: {prompt[:100]}...")
510
-
511
- # Generate response with DialoGPT-optimized parameters
512
- if "dialogpt" in self.model.lower():
513
- response = self.client.text_generation(
514
- prompt,
515
- max_new_tokens=150,
516
- temperature=0.8,
517
- top_p=0.9,
518
- top_k=50,
519
- repetition_penalty=1.1,
520
- do_sample=True,
521
- stop_sequences=["User:", "Assistant:", "\n\n"]
522
- )
523
- else:
524
- # GPT-2 parameters for other models
525
- response = self.client.text_generation(
526
- prompt,
527
- max_new_tokens=150,
528
- temperature=0.8,
529
- top_p=0.9,
530
- top_k=50,
531
- repetition_penalty=1.2,
532
- do_sample=True,
533
- stop_sequences=["User:", "Assistant:", "\n\n"]
534
- )
535
-
536
- logger.info(f"Raw AI response: {response[:200]}...")
537
 
538
- # Clean up the response based on model type
539
- if "llama" in self.model.lower():
540
- # Clean up Llama response
541
  if "<|assistant|>" in response:
542
  assistant_response = response.split("<|assistant|>")[-1].strip()
543
  else:
544
  assistant_response = response.strip()
545
-
546
- # Remove any remaining conversation markers
547
- assistant_response = assistant_response.replace("<|end|>", "").strip()
548
- elif "dialogpt" in self.model.lower() or "gpt2" in self.model.lower():
549
- # Clean up DialoGPT/GPT-2 response
550
- if "Assistant:" in response:
551
- assistant_response = response.split("Assistant:")[-1].strip()
552
- else:
553
- assistant_response = response.strip()
554
-
555
- # Remove any remaining conversation markers
556
- assistant_response = assistant_response.replace("User:", "").replace("Assistant:", "").strip()
557
  else:
558
- # Clean up other model responses
559
  if "Assistant:" in response:
560
- assistant_response = response.split("Assistant:")[-1].strip()
561
- else:
562
- assistant_response = response.strip()
563
-
564
- # Remove any remaining conversation markers
565
- assistant_response = assistant_response.replace("User:", "").replace("Assistant:", "").strip()
566
-
567
- # Remove any incomplete sentences or cut-off text
568
- if assistant_response.endswith(('.', '!', '?')):
569
- pass # Complete sentence
570
- elif '.' in assistant_response:
571
- # Take only the first complete sentence
572
- assistant_response = assistant_response.split('.')[0] + '.'
573
- else:
574
- # If no complete sentence, take first 100 characters
575
- assistant_response = assistant_response[:100]
576
-
577
- logger.info(f"Cleaned AI response: {assistant_response[:100]}...")
578
-
579
- # If response is too short or generic, use fallback
580
- if len(assistant_response) < 10 or "I don't know" in assistant_response.lower():
581
- logger.warning("AI response too short, using fallback response")
582
- return self.get_fallback_response(user_message)
583
-
584
- return assistant_response
585
 
586
  except Exception as e:
587
  logger.error(f"Error generating response: {e}")
588
- logger.error(f"Error type: {type(e).__name__}")
589
- logger.error(f"Error details: {str(e)}")
590
- # Try training data as fallback
591
- training_match = self.data_loader.find_best_match(user_message)
592
- if training_match:
593
- logger.info("Using training data as fallback after API error")
594
- return training_match.get('output', '')
595
- return self.get_fallback_response(user_message)
596
-
597
- def get_fallback_response(self, user_message: str) -> str:
598
- """Fallback response when no training data match and no API available"""
599
- # Try to give a more contextual response based on the question
600
- if "hello" in user_message.lower() or "hi" in user_message.lower():
601
- return "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
602
- elif "weather" in user_message.lower() or "cuaca" in user_message.lower():
603
- return "Maaf, saya tidak bisa memberikan informasi cuaca. Tapi saya bisa membantu Anda dengan pertanyaan tentang produk dan layanan Textilindo!"
604
- elif "how are you" in user_message.lower() or "apa kabar" in user_message.lower():
605
- return "Saya baik-baik saja, terima kasih! Saya siap membantu Anda dengan pertanyaan tentang Textilindo. Ada yang bisa saya bantu?"
606
- else:
607
- return f"Halo! Saya adalah asisten AI Textilindo. Saya bisa membantu Anda dengan pertanyaan tentang produk dan layanan kami, atau sekadar mengobrol! Bagaimana saya bisa membantu Anda hari ini? 😊"
 
 
 
 
 
608
 
609
  def get_mock_response(self, user_message: str) -> str:
610
  """Enhanced mock responses with better context awareness"""
 
150
  """Manage AI model training using the training scripts"""
151
 
152
  def __init__(self):
153
+ self.api_key = os.getenv('HUGGINGFACE_API_KEY') or os.getenv('HF_TOKEN')
154
+ # Resolve model with safe defaults for HF Serverless Inference
155
+ requested_model = (os.getenv('DEFAULT_MODEL') or '').strip()
156
+ unsupported = ['meta-llama/', 'llama-', 'llama ', 'llama-', 'tinyllama', 'gemma']
157
+ if not requested_model or any(x in requested_model.lower() for x in unsupported):
158
+ # Fallback to widely available serverless models
159
+ self.model = 'microsoft/DialoGPT-medium'
160
+ logger.warning(
161
+ f"DEFAULT_MODEL '{requested_model or 'unset'}' not available on Serverless Inference. "
162
+ f"Falling back to {self.model}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  )
164
+ else:
165
+ self.model = requested_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  self.system_prompt = self.load_system_prompt()
167
+ self._fallback_model = 'distilgpt2'
 
 
 
 
 
 
168
 
169
  if not self.api_key:
170
  logger.warning("HUGGINGFAC_API_KEY_2 not found. Using mock responses.")
 
314
  return self.get_fallback_response(user_message)
315
 
316
  try:
317
+ # Create full prompt with system prompt
318
+ if any(x in self.model.lower() for x in ['llama', 'tinyllama', 'gemma']):
319
+ full_prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_message}\n<|assistant|>\n"
 
 
 
 
 
320
  else:
321
+ full_prompt = f"User: {user_message}\nAssistant:"
 
322
 
323
+ # Generate response
324
+ response = self.client.text_generation(
325
+ full_prompt,
326
+ max_new_tokens=512,
327
+ temperature=0.7,
328
+ top_p=0.9,
329
+ top_k=40,
330
+ repetition_penalty=1.1,
331
+ stop_sequences=["<|end|>", "<|user|>", "User:", "Assistant:"]
332
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
+ # Extract only the assistant's response
335
+ if any(x in self.model.lower() for x in ['llama', 'tinyllama', 'gemma']):
 
336
  if "<|assistant|>" in response:
337
  assistant_response = response.split("<|assistant|>")[-1].strip()
338
  else:
339
  assistant_response = response.strip()
340
+ return assistant_response.replace("<|end|>", "").strip()
 
 
 
 
 
 
 
 
 
 
 
341
  else:
 
342
  if "Assistant:" in response:
343
+ return response.split("Assistant:")[-1].strip()
344
+ return response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  except Exception as e:
347
  logger.error(f"Error generating response: {e}")
348
+ # One-time fallback if current model is not available (404 Not Found)
349
+ err_text = str(e).lower()
350
+ if ("404" in err_text or "not found" in err_text) and self.model != self._fallback_model:
351
+ try:
352
+ logger.warning(
353
+ f"Model {self.model} unavailable on serverless. Falling back to {self._fallback_model} and retrying."
354
+ )
355
+ self.model = self._fallback_model
356
+ self.client = InferenceClient(token=self.api_key, model=self.model)
357
+ retry_prompt = f"User: {user_message}\nAssistant:"
358
+ response = self.client.text_generation(
359
+ retry_prompt,
360
+ max_new_tokens=200,
361
+ temperature=0.7,
362
+ top_p=0.9,
363
+ top_k=40,
364
+ repetition_penalty=1.1,
365
+ stop_sequences=["User:", "Assistant:"]
366
+ )
367
+ if "Assistant:" in response:
368
+ return response.split("Assistant:")[-1].strip()
369
+ return response.strip()
370
+ except Exception as e2:
371
+ logger.error(f"Fallback retry failed: {e2}")
372
+ return self.get_mock_response(user_message)
373
 
374
  def get_mock_response(self, user_message: str) -> str:
375
  """Enhanced mock responses with better context awareness"""