Spaces:
Sleeping
Sleeping
Commit ·
3a93207
1
Parent(s): e3e9bcd
Serverless-safe defaults: fallback from LLaMA/TinyLlama to DialoGPT; 404 auto-fallback to distilgpt2; robust prompt/cleanup for both families
Browse files- app_backup.py +57 -292
app_backup.py
CHANGED
|
@@ -150,198 +150,21 @@ class TrainingManager:
|
|
| 150 |
"""Manage AI model training using the training scripts"""
|
| 151 |
|
| 152 |
def __init__(self):
|
| 153 |
-
self.
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def start_training(self, model_name: str = "meta-llama/Llama-3.1-8B-Instruct", epochs: int = 3, batch_size: int = 4):
|
| 165 |
-
"""Start training in background thread"""
|
| 166 |
-
if self.training_status["is_training"]:
|
| 167 |
-
return {"error": "Training already in progress"}
|
| 168 |
-
|
| 169 |
-
self.training_status = {
|
| 170 |
-
"is_training": True,
|
| 171 |
-
"progress": 0,
|
| 172 |
-
"status": "starting",
|
| 173 |
-
"start_time": datetime.now().isoformat(),
|
| 174 |
-
"end_time": None,
|
| 175 |
-
"error": None,
|
| 176 |
-
"logs": []
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
# Start training in background thread
|
| 180 |
-
self.training_thread = threading.Thread(
|
| 181 |
-
target=self._run_training,
|
| 182 |
-
args=(model_name, epochs, batch_size),
|
| 183 |
-
daemon=True
|
| 184 |
-
)
|
| 185 |
-
self.training_thread.start()
|
| 186 |
-
|
| 187 |
-
return {"message": "Training started", "status": "starting"}
|
| 188 |
-
|
| 189 |
-
def _run_training(self, model_name: str, epochs: int, batch_size: int):
|
| 190 |
-
"""Run the actual training process"""
|
| 191 |
-
try:
|
| 192 |
-
self.training_status["status"] = "preparing"
|
| 193 |
-
self.training_status["logs"].append("Preparing training environment...")
|
| 194 |
-
|
| 195 |
-
# Check if training data exists
|
| 196 |
-
data_path = "data/textilindo_training_data.jsonl"
|
| 197 |
-
if not os.path.exists(data_path):
|
| 198 |
-
raise Exception("Training data not found")
|
| 199 |
-
|
| 200 |
-
self.training_status["status"] = "training"
|
| 201 |
-
self.training_status["logs"].append("Starting model training...")
|
| 202 |
-
|
| 203 |
-
# Create a simple training script for HF Spaces
|
| 204 |
-
training_script = f"""
|
| 205 |
-
import os
|
| 206 |
-
import sys
|
| 207 |
-
import json
|
| 208 |
-
import logging
|
| 209 |
-
from pathlib import Path
|
| 210 |
-
from datetime import datetime
|
| 211 |
-
|
| 212 |
-
# Add current directory to path
|
| 213 |
-
sys.path.append('.')
|
| 214 |
-
|
| 215 |
-
# Setup logging
|
| 216 |
-
logging.basicConfig(level=logging.INFO)
|
| 217 |
-
logger = logging.getLogger(__name__)
|
| 218 |
-
|
| 219 |
-
def simple_training():
|
| 220 |
-
\"\"\"Simple training simulation for HF Spaces with Llama support\"\"\"
|
| 221 |
-
logger.info("Starting training process...")
|
| 222 |
-
logger.info(f"Model: {model_name}")
|
| 223 |
-
logger.info(f"Epochs: {epochs}")
|
| 224 |
-
logger.info(f"Batch Size: {batch_size}")
|
| 225 |
-
|
| 226 |
-
# Load training data
|
| 227 |
-
data_path = "data/textilindo_training_data.jsonl"
|
| 228 |
-
with open(data_path, 'r', encoding='utf-8') as f:
|
| 229 |
-
data = [json.loads(line) for line in f if line.strip()]
|
| 230 |
-
|
| 231 |
-
logger.info(f"Loaded {{len(data)}} training samples")
|
| 232 |
-
|
| 233 |
-
# Model-specific training simulation
|
| 234 |
-
if "llama" in model_name.lower():
|
| 235 |
-
logger.info("Using Llama model - High quality training simulation")
|
| 236 |
-
training_steps = len(data) * {epochs} * 2 # More steps for Llama
|
| 237 |
-
else:
|
| 238 |
-
logger.info("Using standard model - Basic training simulation")
|
| 239 |
-
training_steps = len(data) * {epochs}
|
| 240 |
-
|
| 241 |
-
# Simulate training progress
|
| 242 |
-
for epoch in range({epochs}):
|
| 243 |
-
logger.info(f"Epoch {{epoch + 1}}/{epochs}")
|
| 244 |
-
for i, sample in enumerate(data):
|
| 245 |
-
# Simulate training step
|
| 246 |
-
progress = ((epoch * len(data) + i) / ({epochs} * len(data))) * 100
|
| 247 |
-
logger.info(f"Training progress: {{progress:.1f}}% - Processing: {{sample.get('instruction', 'Unknown')[:50]}}...")
|
| 248 |
-
|
| 249 |
-
# Update training status
|
| 250 |
-
with open("training_status.json", "w") as f:
|
| 251 |
-
json.dump({{
|
| 252 |
-
"is_training": True,
|
| 253 |
-
"progress": progress,
|
| 254 |
-
"status": "training",
|
| 255 |
-
"model": "{model_name}",
|
| 256 |
-
"epoch": epoch + 1,
|
| 257 |
-
"step": i + 1,
|
| 258 |
-
"total_steps": len(data),
|
| 259 |
-
"current_sample": sample.get('instruction', 'Unknown')[:50]
|
| 260 |
-
}}, f)
|
| 261 |
-
|
| 262 |
-
logger.info("Training completed successfully!")
|
| 263 |
-
logger.info(f"Model {model_name} has been fine-tuned with Textilindo data")
|
| 264 |
-
|
| 265 |
-
# Save final status
|
| 266 |
-
with open("training_status.json", "w") as f:
|
| 267 |
-
json.dump({{
|
| 268 |
-
"is_training": False,
|
| 269 |
-
"progress": 100,
|
| 270 |
-
"status": "completed",
|
| 271 |
-
"model": "{model_name}",
|
| 272 |
-
"end_time": datetime.now().isoformat(),
|
| 273 |
-
"message": f"Model {model_name} training completed successfully!"
|
| 274 |
-
}}, f)
|
| 275 |
-
|
| 276 |
-
if __name__ == "__main__":
|
| 277 |
-
simple_training()
|
| 278 |
-
"""
|
| 279 |
-
|
| 280 |
-
# Write training script
|
| 281 |
-
with open("run_training.py", "w") as f:
|
| 282 |
-
f.write(training_script)
|
| 283 |
-
|
| 284 |
-
# Run training
|
| 285 |
-
result = subprocess.run(
|
| 286 |
-
["python", "run_training.py"],
|
| 287 |
-
capture_output=True,
|
| 288 |
-
text=True,
|
| 289 |
-
cwd="."
|
| 290 |
)
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
self.training_status["status"] = "completed"
|
| 294 |
-
self.training_status["progress"] = 100
|
| 295 |
-
self.training_status["logs"].append("Training completed successfully!")
|
| 296 |
-
else:
|
| 297 |
-
raise Exception(f"Training failed: {result.stderr}")
|
| 298 |
-
|
| 299 |
-
except Exception as e:
|
| 300 |
-
logger.error(f"Training error: {e}")
|
| 301 |
-
self.training_status["status"] = "error"
|
| 302 |
-
self.training_status["error"] = str(e)
|
| 303 |
-
self.training_status["logs"].append(f"Error: {e}")
|
| 304 |
-
finally:
|
| 305 |
-
self.training_status["is_training"] = False
|
| 306 |
-
self.training_status["end_time"] = datetime.now().isoformat()
|
| 307 |
-
|
| 308 |
-
def get_training_status(self):
|
| 309 |
-
"""Get current training status"""
|
| 310 |
-
# Try to read from file if available
|
| 311 |
-
status_file = "training_status.json"
|
| 312 |
-
if os.path.exists(status_file):
|
| 313 |
-
try:
|
| 314 |
-
with open(status_file, "r") as f:
|
| 315 |
-
file_status = json.load(f)
|
| 316 |
-
self.training_status.update(file_status)
|
| 317 |
-
except:
|
| 318 |
-
pass
|
| 319 |
-
|
| 320 |
-
return self.training_status
|
| 321 |
-
|
| 322 |
-
def stop_training(self):
|
| 323 |
-
"""Stop training if running"""
|
| 324 |
-
if self.training_status["is_training"]:
|
| 325 |
-
self.training_status["status"] = "stopped"
|
| 326 |
-
self.training_status["is_training"] = False
|
| 327 |
-
return {"message": "Training stopped"}
|
| 328 |
-
return {"message": "No training in progress"}
|
| 329 |
-
|
| 330 |
-
class TextilindoAI:
|
| 331 |
-
"""Textilindo AI Assistant using HuggingFace Inference API with Auto-Training"""
|
| 332 |
-
|
| 333 |
-
def __init__(self):
|
| 334 |
-
self.api_key = os.getenv('HUGGINGFAC_API_KEY_2')
|
| 335 |
-
# Use available model with your API key
|
| 336 |
-
self.model = os.getenv('DEFAULT_MODEL', 'meta-llama/Llama-3.2-1B-Instruct')
|
| 337 |
self.system_prompt = self.load_system_prompt()
|
| 338 |
-
self.
|
| 339 |
-
|
| 340 |
-
# Auto-training configuration
|
| 341 |
-
self.auto_training_enabled = True
|
| 342 |
-
self.training_interval = 300 # Train every 5 minutes
|
| 343 |
-
self.last_training_time = 0
|
| 344 |
-
self.trained_responses = {} # Cache for trained responses
|
| 345 |
|
| 346 |
if not self.api_key:
|
| 347 |
logger.warning("HUGGINGFAC_API_KEY_2 not found. Using mock responses.")
|
|
@@ -491,120 +314,62 @@ Minimum purchase is 1 roll (67-70 yards)."""
|
|
| 491 |
return self.get_fallback_response(user_message)
|
| 492 |
|
| 493 |
try:
|
| 494 |
-
#
|
| 495 |
-
if
|
| 496 |
-
|
| 497 |
-
prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_message}\n<|assistant|>\n"
|
| 498 |
-
elif "dialogpt" in self.model.lower():
|
| 499 |
-
prompt = f"User: {user_message}\nAssistant:"
|
| 500 |
-
elif "gpt2" in self.model.lower():
|
| 501 |
-
prompt = f"User: {user_message}\nAssistant:"
|
| 502 |
else:
|
| 503 |
-
|
| 504 |
-
prompt = f"User: {user_message}\nAssistant:"
|
| 505 |
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
temperature=0.8,
|
| 517 |
-
top_p=0.9,
|
| 518 |
-
top_k=50,
|
| 519 |
-
repetition_penalty=1.1,
|
| 520 |
-
do_sample=True,
|
| 521 |
-
stop_sequences=["User:", "Assistant:", "\n\n"]
|
| 522 |
-
)
|
| 523 |
-
else:
|
| 524 |
-
# GPT-2 parameters for other models
|
| 525 |
-
response = self.client.text_generation(
|
| 526 |
-
prompt,
|
| 527 |
-
max_new_tokens=150,
|
| 528 |
-
temperature=0.8,
|
| 529 |
-
top_p=0.9,
|
| 530 |
-
top_k=50,
|
| 531 |
-
repetition_penalty=1.2,
|
| 532 |
-
do_sample=True,
|
| 533 |
-
stop_sequences=["User:", "Assistant:", "\n\n"]
|
| 534 |
-
)
|
| 535 |
-
|
| 536 |
-
logger.info(f"Raw AI response: {response[:200]}...")
|
| 537 |
|
| 538 |
-
#
|
| 539 |
-
if
|
| 540 |
-
# Clean up Llama response
|
| 541 |
if "<|assistant|>" in response:
|
| 542 |
assistant_response = response.split("<|assistant|>")[-1].strip()
|
| 543 |
else:
|
| 544 |
assistant_response = response.strip()
|
| 545 |
-
|
| 546 |
-
# Remove any remaining conversation markers
|
| 547 |
-
assistant_response = assistant_response.replace("<|end|>", "").strip()
|
| 548 |
-
elif "dialogpt" in self.model.lower() or "gpt2" in self.model.lower():
|
| 549 |
-
# Clean up DialoGPT/GPT-2 response
|
| 550 |
-
if "Assistant:" in response:
|
| 551 |
-
assistant_response = response.split("Assistant:")[-1].strip()
|
| 552 |
-
else:
|
| 553 |
-
assistant_response = response.strip()
|
| 554 |
-
|
| 555 |
-
# Remove any remaining conversation markers
|
| 556 |
-
assistant_response = assistant_response.replace("User:", "").replace("Assistant:", "").strip()
|
| 557 |
else:
|
| 558 |
-
# Clean up other model responses
|
| 559 |
if "Assistant:" in response:
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
assistant_response = response.strip()
|
| 563 |
-
|
| 564 |
-
# Remove any remaining conversation markers
|
| 565 |
-
assistant_response = assistant_response.replace("User:", "").replace("Assistant:", "").strip()
|
| 566 |
-
|
| 567 |
-
# Remove any incomplete sentences or cut-off text
|
| 568 |
-
if assistant_response.endswith(('.', '!', '?')):
|
| 569 |
-
pass # Complete sentence
|
| 570 |
-
elif '.' in assistant_response:
|
| 571 |
-
# Take only the first complete sentence
|
| 572 |
-
assistant_response = assistant_response.split('.')[0] + '.'
|
| 573 |
-
else:
|
| 574 |
-
# If no complete sentence, take first 100 characters
|
| 575 |
-
assistant_response = assistant_response[:100]
|
| 576 |
-
|
| 577 |
-
logger.info(f"Cleaned AI response: {assistant_response[:100]}...")
|
| 578 |
-
|
| 579 |
-
# If response is too short or generic, use fallback
|
| 580 |
-
if len(assistant_response) < 10 or "I don't know" in assistant_response.lower():
|
| 581 |
-
logger.warning("AI response too short, using fallback response")
|
| 582 |
-
return self.get_fallback_response(user_message)
|
| 583 |
-
|
| 584 |
-
return assistant_response
|
| 585 |
|
| 586 |
except Exception as e:
|
| 587 |
logger.error(f"Error generating response: {e}")
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
|
| 609 |
def get_mock_response(self, user_message: str) -> str:
|
| 610 |
"""Enhanced mock responses with better context awareness"""
|
|
|
|
| 150 |
"""Manage AI model training using the training scripts"""
|
| 151 |
|
| 152 |
def __init__(self):
|
| 153 |
+
self.api_key = os.getenv('HUGGINGFACE_API_KEY') or os.getenv('HF_TOKEN')
|
| 154 |
+
# Resolve model with safe defaults for HF Serverless Inference
|
| 155 |
+
requested_model = (os.getenv('DEFAULT_MODEL') or '').strip()
|
| 156 |
+
unsupported = ['meta-llama/', 'llama-', 'llama ', 'llama-', 'tinyllama', 'gemma']
|
| 157 |
+
if not requested_model or any(x in requested_model.lower() for x in unsupported):
|
| 158 |
+
# Fallback to widely available serverless models
|
| 159 |
+
self.model = 'microsoft/DialoGPT-medium'
|
| 160 |
+
logger.warning(
|
| 161 |
+
f"DEFAULT_MODEL '{requested_model or 'unset'}' not available on Serverless Inference. "
|
| 162 |
+
f"Falling back to {self.model}."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
+
else:
|
| 165 |
+
self.model = requested_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
self.system_prompt = self.load_system_prompt()
|
| 167 |
+
self._fallback_model = 'distilgpt2'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
if not self.api_key:
|
| 170 |
logger.warning("HUGGINGFAC_API_KEY_2 not found. Using mock responses.")
|
|
|
|
| 314 |
return self.get_fallback_response(user_message)
|
| 315 |
|
| 316 |
try:
|
| 317 |
+
# Create full prompt with system prompt
|
| 318 |
+
if any(x in self.model.lower() for x in ['llama', 'tinyllama', 'gemma']):
|
| 319 |
+
full_prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_message}\n<|assistant|>\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
else:
|
| 321 |
+
full_prompt = f"User: {user_message}\nAssistant:"
|
|
|
|
| 322 |
|
| 323 |
+
# Generate response
|
| 324 |
+
response = self.client.text_generation(
|
| 325 |
+
full_prompt,
|
| 326 |
+
max_new_tokens=512,
|
| 327 |
+
temperature=0.7,
|
| 328 |
+
top_p=0.9,
|
| 329 |
+
top_k=40,
|
| 330 |
+
repetition_penalty=1.1,
|
| 331 |
+
stop_sequences=["<|end|>", "<|user|>", "User:", "Assistant:"]
|
| 332 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
+
# Extract only the assistant's response
|
| 335 |
+
if any(x in self.model.lower() for x in ['llama', 'tinyllama', 'gemma']):
|
|
|
|
| 336 |
if "<|assistant|>" in response:
|
| 337 |
assistant_response = response.split("<|assistant|>")[-1].strip()
|
| 338 |
else:
|
| 339 |
assistant_response = response.strip()
|
| 340 |
+
return assistant_response.replace("<|end|>", "").strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
else:
|
|
|
|
| 342 |
if "Assistant:" in response:
|
| 343 |
+
return response.split("Assistant:")[-1].strip()
|
| 344 |
+
return response.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
except Exception as e:
|
| 347 |
logger.error(f"Error generating response: {e}")
|
| 348 |
+
# One-time fallback if current model is not available (404 Not Found)
|
| 349 |
+
err_text = str(e).lower()
|
| 350 |
+
if ("404" in err_text or "not found" in err_text) and self.model != self._fallback_model:
|
| 351 |
+
try:
|
| 352 |
+
logger.warning(
|
| 353 |
+
f"Model {self.model} unavailable on serverless. Falling back to {self._fallback_model} and retrying."
|
| 354 |
+
)
|
| 355 |
+
self.model = self._fallback_model
|
| 356 |
+
self.client = InferenceClient(token=self.api_key, model=self.model)
|
| 357 |
+
retry_prompt = f"User: {user_message}\nAssistant:"
|
| 358 |
+
response = self.client.text_generation(
|
| 359 |
+
retry_prompt,
|
| 360 |
+
max_new_tokens=200,
|
| 361 |
+
temperature=0.7,
|
| 362 |
+
top_p=0.9,
|
| 363 |
+
top_k=40,
|
| 364 |
+
repetition_penalty=1.1,
|
| 365 |
+
stop_sequences=["User:", "Assistant:"]
|
| 366 |
+
)
|
| 367 |
+
if "Assistant:" in response:
|
| 368 |
+
return response.split("Assistant:")[-1].strip()
|
| 369 |
+
return response.strip()
|
| 370 |
+
except Exception as e2:
|
| 371 |
+
logger.error(f"Fallback retry failed: {e2}")
|
| 372 |
+
return self.get_mock_response(user_message)
|
| 373 |
|
| 374 |
def get_mock_response(self, user_message: str) -> str:
|
| 375 |
"""Enhanced mock responses with better context awareness"""
|