saemstunes commited on
Commit
aba4d8f
·
verified ·
1 Parent(s): 38ccc97

Update src/ai_system.py

Browse files
Files changed (1) hide show
  1. src/ai_system.py +444 -164
src/ai_system.py CHANGED
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Any, Tuple
7
  import json
8
  import requests
9
  import hashlib
 
10
 
11
  try:
12
  from llama_cpp import Llama
@@ -15,9 +16,10 @@ except ImportError:
15
  print("Warning: llama-cpp-python not available. AI functionality will be limited.")
16
 
17
  try:
18
- from huggingface_hub import hf_hub_download
19
  except ImportError:
20
  hf_hub_download = None
 
21
  print("Warning: huggingface_hub not available. Model download will not work.")
22
 
23
  from .supabase_integration import AdvancedSupabaseIntegration
@@ -27,164 +29,271 @@ from .monitoring_system import ComprehensiveMonitor
27
  class SaemsTunesAISystem:
28
  """
29
  Main AI system for Saem's Tunes music education and streaming platform.
30
- Handles user queries with context from the Supabase database.
31
  """
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def __init__(
34
  self,
35
  supabase_integration: AdvancedSupabaseIntegration,
36
  security_system: AdvancedSecuritySystem,
37
  monitor: ComprehensiveMonitor,
38
- model_name: str = "microsoft/Phi-3.5-mini-instruct",
39
- model_repo: str = "bartowski/Phi-3.5-mini-instruct-GGUF",
40
- model_file: str = "Phi-3.5-mini-instruct-Q4_K_M.gguf",
41
- max_response_length: int = 300,
42
- temperature: float = 0.7,
43
- top_p: float = 0.9,
44
- context_window: int = 4096
45
  ):
46
  self.supabase = supabase_integration
47
  self.security = security_system
48
  self.monitor = monitor
49
- self.model_name = model_name
50
- self.model_repo = model_repo
51
- self.model_file = model_file
52
- self.max_response_length = max_response_length
53
- self.temperature = temperature
54
- self.top_p = top_p
55
- self.context_window = context_window
56
-
57
- self.model = None
58
  self.model_loaded = False
59
- self.model_path = None
60
- self.model_hash = None
61
 
 
62
  self.conversation_history = {}
63
  self.response_cache = {}
 
64
 
65
  self.setup_logging()
66
- self.load_model()
67
 
68
  def setup_logging(self):
69
- """Setup logging for the AI system"""
70
- self.logger = logging.getLogger(__name__)
71
  self.logger.setLevel(logging.INFO)
 
 
 
 
 
 
 
 
72
 
73
- def load_model(self):
74
- """Load the Phi-3.5-mini-instruct model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  try:
76
- self.logger.info(f"🔄 Loading {self.model_name} model...")
77
-
78
  model_dir = "./models"
79
  os.makedirs(model_dir, exist_ok=True)
80
 
81
- local_path = os.path.join(model_dir, self.model_file)
82
 
 
83
  if os.path.exists(local_path):
84
- self.model_path = local_path
85
  self.logger.info(f"✅ Found local model: {local_path}")
86
 
87
- with open(local_path, 'rb') as f:
88
- file_hash = hashlib.md5()
89
- while chunk := f.read(8192):
90
- file_hash.update(chunk)
91
- self.model_hash = file_hash.hexdigest()
 
 
 
92
 
93
- else:
94
- if hf_hub_download is None:
95
- self.logger.error("❌ huggingface_hub not available for model download")
96
- return
97
-
98
- self.logger.info(f"📥 Downloading model from {self.model_repo}")
99
- self.model_path = hf_hub_download(
100
- repo_id=self.model_repo,
101
- filename=self.model_file,
102
- cache_dir=model_dir,
103
- local_dir_use_symlinks=False
104
- )
105
- self.logger.info(f"✅ Model downloaded: {self.model_path}")
106
-
107
- with open(self.model_path, 'rb') as f:
108
- file_hash = hashlib.md5()
109
- while chunk := f.read(8192):
110
- file_hash.update(chunk)
111
- self.model_hash = file_hash.hexdigest()
112
 
113
- if Llama is None:
114
- self.logger.error("❌ llama-cpp-python not available for model loading")
115
- return
 
116
 
117
- self.model = Llama(
118
- model_path=self.model_path,
119
- n_ctx=self.context_window,
120
- n_threads=min(4, os.cpu_count() or 1),
121
- n_batch=512,
122
- verbose=False,
123
- use_mlock=False,
124
- use_mmap=True,
125
- low_vram=False
126
- )
127
 
128
- test_response = self.model.create_completion(
129
- "Test",
130
- max_tokens=10,
131
- temperature=0.1,
132
- stop=["<|end|>", "</s>"]
 
 
 
133
  )
134
 
135
- if test_response and 'choices' in test_response and len(test_response['choices']) > 0:
136
- self.model_loaded = True
137
- self.logger.info("✅ Model loaded and tested successfully!")
138
- self.logger.info(f"📊 Model info: {self.model_path} (Hash: {self.model_hash})")
139
- else:
140
- self.logger.error("❌ Model test failed")
141
- self.model_loaded = False
142
 
143
  except Exception as e:
144
- self.logger.error(f"❌ Error loading model: {e}")
145
- self.model_loaded = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  def process_query(
148
  self,
149
  query: str,
150
  user_id: str,
151
- conversation_id: Optional[str] = None
152
- ) -> str:
 
153
  """
154
- Process user query and generate response with context from Supabase.
155
 
156
  Args:
157
  query: User's question
158
  user_id: Unique user identifier
159
  conversation_id: Optional conversation ID for context
 
160
 
161
  Returns:
162
- AI-generated response
163
  """
 
 
 
 
 
 
 
164
  if not self.model_loaded:
165
  self.logger.warning("Model not loaded, returning fallback response")
166
- return self.get_fallback_response(query)
 
 
 
 
 
 
 
167
 
168
- cache_key = f"{user_id}:{hash(query)}"
 
169
  if cache_key in self.response_cache:
170
  cached_response, timestamp = self.response_cache[cache_key]
171
- if time.time() - timestamp < 300:
172
- self.logger.info("Returning cached response")
173
- return cached_response
 
 
 
 
 
 
 
174
 
175
  try:
176
- start_time = time.time()
177
-
178
  context = self.supabase.get_music_context(query, user_id)
179
 
180
- prompt = self.build_enhanced_prompt(query, context, user_id, conversation_id)
 
181
 
182
- response = self.model.create_completion(
 
 
183
  prompt,
184
- max_tokens=self.max_response_length,
185
- temperature=self.temperature,
186
- top_p=self.top_p,
187
- stop=["<|end|>", "</s>", "###", "Human:", "Assistant:", "<|endoftext|>"],
188
  echo=False,
189
  stream=False
190
  )
@@ -192,44 +301,64 @@ class SaemsTunesAISystem:
192
  processing_time = time.time() - start_time
193
 
194
  response_text = response['choices'][0]['text'].strip()
 
195
 
196
- response_text = self.clean_response(response_text)
197
-
198
  self.record_metrics(
199
  query=query,
200
  response=response_text,
201
  processing_time=processing_time,
202
  user_id=user_id,
203
  conversation_id=conversation_id,
 
204
  context_used=context,
205
  success=True
206
  )
207
 
 
208
  self.response_cache[cache_key] = (response_text, time.time())
209
 
 
210
  if conversation_id:
211
  self.update_conversation_history(conversation_id, query, response_text)
212
 
213
- self.logger.info(f"✅ Query processed in {processing_time:.2f}s: {query[:50]}...")
214
 
215
- return response_text
 
 
 
 
 
 
 
216
 
217
  except Exception as e:
218
- self.logger.error(f"❌ Error processing query: {e}")
 
219
 
220
  self.record_metrics(
221
  query=query,
222
  response="",
223
- processing_time=0,
224
  user_id=user_id,
225
  conversation_id=conversation_id,
 
226
  error_message=str(e),
227
  success=False
228
  )
229
 
230
- return self.get_error_response(e)
 
 
 
 
 
 
 
 
231
 
232
- def build_enhanced_prompt(
233
  self,
234
  query: str,
235
  context: Dict[str, Any],
@@ -237,7 +366,7 @@ class SaemsTunesAISystem:
237
  conversation_id: Optional[str] = None
238
  ) -> str:
239
  """
240
- Build comprehensive prompt with context from Saem's Tunes platform.
241
  """
242
  conversation_context = ""
243
  if conversation_id and conversation_id in self.conversation_history:
@@ -246,19 +375,101 @@ class SaemsTunesAISystem:
246
  role = "User" if msg["role"] == "user" else "Assistant"
247
  conversation_context += f"{role}: {msg['content']}\n"
248
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  system_prompt = f"""<|system|>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  You are the AI assistant for Saem's Tunes, a comprehensive music education and streaming platform.
251
 
252
  PLATFORM OVERVIEW:
253
- 🎵 **Music Streaming**: {context.get('stats', {}).get('track_count', 0)}+ tracks, {context.get('stats', {}).get('artist_count', 0)}+ artists
254
- 📚 **Education**: Courses, lessons, quizzes, and learning paths
255
- 👥 **Community**: User profiles, favorites, social features
256
- 🎨 **Creator Tools**: Music upload, artist analytics, promotion tools
257
- 💎 **Premium**: Subscription-based premium features
258
 
259
  PLATFORM STATISTICS:
260
  - Total Tracks: {context.get('stats', {}).get('track_count', 0)}
261
- - Total Artists: {context.get('stats', {}).get('artist_count', 0)}
262
  - Total Users: {context.get('stats', {}).get('user_count', 0)}
263
  - Total Courses: {context.get('stats', {}).get('course_count', 0)}
264
  - Active Playlists: {context.get('stats', {}).get('playlist_count', 0)}
@@ -278,14 +489,11 @@ CONVERSATION HISTORY:
278
  RESPONSE GUIDELINES:
279
  1. Be passionate about music and education
280
  2. Provide specific, actionable information about Saem's Tunes
281
- 3. Reference platform features when relevant
282
- 4. Keep responses concise (under {self.max_response_length} words)
283
- 5. Be encouraging and supportive
284
- 6. If unsure, guide users to relevant platform sections
285
- 7. Personalize responses when user context is available
286
- 8. Always maintain a professional, helpful tone
287
- 9. Focus on music education, streaming, and platform features
288
- 10. Avoid discussing unrelated topics
289
 
290
  PLATFORM FEATURES TO MENTION:
291
  - Music streaming and discovery
@@ -296,14 +504,34 @@ PLATFORM FEATURES TO MENTION:
296
  - Premium subscription benefits
297
  - Mobile app availability
298
  - Music recommendations
299
- - Learning progress tracking
300
-
301
- ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
302
  """
 
 
 
303
 
304
- user_prompt = f"<|user|>\n{query}<|end|>\n<|assistant|>\n"
305
-
306
- return system_prompt + user_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
  def format_popular_content(self, context: Dict[str, Any]) -> str:
309
  """Format popular content for the prompt"""
@@ -362,26 +590,30 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
362
 
363
  return "\n".join(user_lines) if user_lines else "Basic user account"
364
 
365
- def clean_response(self, response: str) -> str:
366
- """Clean and format the AI response"""
367
  if not response:
368
  return "I apologize, but I couldn't generate a response. Please try again."
369
 
370
  response = response.strip()
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  if response.startswith("I'm sorry") or response.startswith("I apologize"):
373
  if len(response) < 20:
374
  response = "I'd be happy to help you with that! Our platform offers comprehensive music education and streaming features."
375
 
376
- stop_phrases = [
377
- "<|end|>", "</s>", "###", "Human:", "Assistant:",
378
- "<|endoftext|>", "<|assistant|>", "<|user|>"
379
- ]
380
-
381
- for phrase in stop_phrases:
382
- if phrase in response:
383
- response = response.split(phrase)[0].strip()
384
-
385
  sentences = response.split('. ')
386
  if len(sentences) > 1:
387
  response = '. '.join(sentences[:-1]) + '.' if not sentences[-1].endswith('.') else '. '.join(sentences)
@@ -391,8 +623,9 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
391
 
392
  response = response.replace('**', '').replace('__', '').replace('*', '')
393
 
394
- if len(response) > self.max_response_length:
395
- response = response[:self.max_response_length].rsplit(' ', 1)[0] + '...'
 
396
 
397
  return response
398
 
@@ -416,13 +649,14 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
416
  processing_time: float,
417
  user_id: str,
418
  conversation_id: Optional[str] = None,
 
419
  context_used: Optional[Dict] = None,
420
  error_message: Optional[str] = None,
421
  success: bool = True
422
  ):
423
- """Record metrics for monitoring and analytics"""
424
  metrics = {
425
- 'model_name': 'phi3.5-mini-Q4_K_M',
426
  'processing_time_ms': processing_time * 1000,
427
  'input_tokens': len(query.split()),
428
  'output_tokens': len(response.split()) if response else 0,
@@ -433,7 +667,7 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
433
  'timestamp': datetime.now(),
434
  'query_length': len(query),
435
  'response_length': len(response) if response else 0,
436
- 'model_hash': self.model_hash
437
  }
438
 
439
  if error_message:
@@ -448,6 +682,27 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
448
  'context_summary': context_used.get('summary', '')[:100]
449
  }
450
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  self.monitor.record_inference(metrics)
452
 
453
  def get_fallback_response(self, query: str) -> str:
@@ -485,24 +740,33 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
485
 
486
  def is_healthy(self) -> bool:
487
  """Check if AI system is healthy and ready"""
488
- return self.model_loaded and self.model is not None and self.supabase.is_connected()
489
 
490
  def get_system_info(self) -> Dict[str, Any]:
491
- """Get system information for monitoring"""
 
 
 
 
 
 
 
 
 
492
  return {
 
493
  "model_loaded": self.model_loaded,
494
- "model_name": self.model_name,
495
- "model_path": self.model_path,
496
- "model_hash": self.model_hash,
497
- "max_response_length": self.max_response_length,
498
- "temperature": self.temperature,
499
- "top_p": self.top_p,
500
- "context_window": self.context_window,
501
  "supabase_connected": self.supabase.is_connected(),
502
  "conversations_active": len(self.conversation_history),
503
- "cache_size": len(self.response_cache)
 
504
  }
505
 
 
 
 
 
506
  def clear_cache(self, user_id: Optional[str] = None):
507
  """Clear response cache"""
508
  if user_id:
@@ -512,15 +776,31 @@ ANSWER THE USER'S QUESTION BASED ON THE ABOVE CONTEXT:<|end|>
512
  else:
513
  self.response_cache.clear()
514
 
515
- def get_model_stats(self) -> Dict[str, Any]:
516
- """Get model statistics"""
517
- if not self.model_loaded:
518
- return {"error": "Model not loaded"}
519
-
520
- return {
521
- "context_size": self.context_window,
522
- "parameters": "3.8B",
523
- "quantization": "Q4_K_M",
524
- "model_size_gb": round(os.path.getsize(self.model_path) / (1024**3), 2) if self.model_path else 0,
525
- "cache_hit_rate": len(self.response_cache) / (len(self.response_cache) + len(self.conversation_history)) if self.conversation_history else 0
526
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import json
8
  import requests
9
  import hashlib
10
+ import gc
11
 
12
  try:
13
  from llama_cpp import Llama
 
16
  print("Warning: llama-cpp-python not available. AI functionality will be limited.")
17
 
18
  try:
19
+ from huggingface_hub import hf_hub_download, snapshot_download
20
  except ImportError:
21
  hf_hub_download = None
22
+ snapshot_download = None
23
  print("Warning: huggingface_hub not available. Model download will not work.")
24
 
25
  from .supabase_integration import AdvancedSupabaseIntegration
 
29
  class SaemsTunesAISystem:
30
  """
31
  Main AI system for Saem's Tunes music education and streaming platform.
32
+ Supports multiple GGUF models with intelligent model switching and context-aware responses.
33
  """
34
 
35
+ # Model configurations for all three LLM options
36
+ MODEL_CONFIGS = {
37
+ 'TinyLlama-1.1B-Chat': {
38
+ 'model_repo': "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
39
+ 'model_file': "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
40
+ 'max_response_length': 200,
41
+ 'temperature': 0.7,
42
+ 'top_p': 0.9,
43
+ 'context_window': 2048,
44
+ 'description': 'Fastest response, basic conversations'
45
+ },
46
+ 'Phi-2': {
47
+ 'model_repo': "TheBloke/phi-2-GGUF",
48
+ 'model_file': "phi-2.Q4_K_M.gguf",
49
+ 'max_response_length': 250,
50
+ 'temperature': 0.7,
51
+ 'top_p': 0.9,
52
+ 'context_window': 2048,
53
+ 'description': 'Good balance of speed and quality'
54
+ },
55
+ 'Qwen-1.8B-Chat': {
56
+ 'model_repo': "TheBloke/Qwen1.5-1.8B-Chat-GGUF",
57
+ 'model_file': "qwen1.5-1.8b-chat-q4_k_m.gguf",
58
+ 'max_response_length': 300,
59
+ 'temperature': 0.7,
60
+ 'top_p': 0.9,
61
+ 'context_window': 4096,
62
+ 'description': 'Best for complex conversations'
63
+ }
64
+ }
65
+
66
  def __init__(
67
  self,
68
  supabase_integration: AdvancedSupabaseIntegration,
69
  security_system: AdvancedSecuritySystem,
70
  monitor: ComprehensiveMonitor,
71
+ default_model: str = 'TinyLlama-1.1B-Chat'
 
 
 
 
 
 
72
  ):
73
  self.supabase = supabase_integration
74
  self.security = security_system
75
  self.monitor = monitor
76
+
77
+ # Model management
78
+ self.available_models = {}
79
+ self.current_model = None
80
+ self.current_model_name = default_model
 
 
 
 
81
  self.model_loaded = False
 
 
82
 
83
+ # Response management
84
  self.conversation_history = {}
85
  self.response_cache = {}
86
+ self.model_usage_stats = {}
87
 
88
  self.setup_logging()
89
+ self.initialize_models()
90
 
91
  def setup_logging(self):
92
+ """Setup comprehensive logging for the AI system"""
93
+ self.logger = logging.getLogger('SaemsTunesAI')
94
  self.logger.setLevel(logging.INFO)
95
+
96
+ if not self.logger.handlers:
97
+ handler = logging.StreamHandler()
98
+ formatter = logging.Formatter(
99
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
100
+ )
101
+ handler.setFormatter(formatter)
102
+ self.logger.addHandler(handler)
103
 
104
+ def initialize_models(self):
105
+ """Initialize all available models"""
106
+ self.logger.info("🔄 Initializing AI models for Saem's Tunes...")
107
+
108
+ for model_name, config in self.MODEL_CONFIGS.items():
109
+ try:
110
+ self.logger.info(f"📦 Setting up {model_name}...")
111
+ model_path = self.download_or_get_model(
112
+ config['model_repo'],
113
+ config['model_file']
114
+ )
115
+
116
+ if model_path and Llama is not None:
117
+ self.available_models[model_name] = {
118
+ 'model': Llama(
119
+ model_path=model_path,
120
+ n_ctx=config['context_window'],
121
+ n_threads=min(6, os.cpu_count() or 2),
122
+ n_batch=512,
123
+ verbose=False,
124
+ use_mlock=False,
125
+ use_mmap=True,
126
+ low_vram=False
127
+ ),
128
+ 'config': config,
129
+ 'path': model_path,
130
+ 'loaded': True
131
+ }
132
+ self.logger.info(f"✅ {model_name} initialized successfully")
133
+ else:
134
+ self.available_models[model_name] = {
135
+ 'model': None,
136
+ 'config': config,
137
+ 'path': model_path,
138
+ 'loaded': False
139
+ }
140
+ self.logger.warning(f"⚠️ {model_name} setup incomplete")
141
+
142
+ except Exception as e:
143
+ self.logger.error(f"❌ Failed to initialize {model_name}: {e}")
144
+ self.available_models[model_name] = {
145
+ 'model': None,
146
+ 'config': config,
147
+ 'path': None,
148
+ 'loaded': False
149
+ }
150
+
151
+ # Set current model
152
+ if self.current_model_name in self.available_models:
153
+ self.current_model = self.available_models[self.current_model_name]
154
+ self.model_loaded = self.current_model['loaded']
155
+ self.logger.info(f"🎯 Current model set to: {self.current_model_name}")
156
+ else:
157
+ self.logger.error("❌ Default model not available")
158
+
159
+ def download_or_get_model(self, model_repo: str, model_file: str) -> Optional[str]:
160
+ """Download model from Hugging Face or use local if available"""
161
  try:
 
 
162
  model_dir = "./models"
163
  os.makedirs(model_dir, exist_ok=True)
164
 
165
+ local_path = os.path.join(model_dir, model_file)
166
 
167
+ # Check if model already exists locally
168
  if os.path.exists(local_path):
 
169
  self.logger.info(f"✅ Found local model: {local_path}")
170
 
171
+ # Verify file integrity
172
+ file_size = os.path.getsize(local_path)
173
+ if file_size > 1000000: # At least 1MB
174
+ with open(local_path, 'rb') as f:
175
+ file_hash = hashlib.md5()
176
+ while chunk := f.read(8192):
177
+ file_hash.update(chunk)
178
+ model_hash = file_hash.hexdigest()
179
 
180
+ self.logger.info(f"📊 Model verified: {local_path} (Size: {file_size/(1024**3):.2f}GB, Hash: {model_hash[:16]}...)")
181
+ return local_path
182
+ else:
183
+ self.logger.warning("⚠️ Local model file seems corrupted, re-downloading...")
184
+ os.remove(local_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ # Download from Hugging Face
187
+ if hf_hub_download is None:
188
+ self.logger.error("❌ huggingface_hub not available for model download")
189
+ return None
190
 
191
+ self.logger.info(f"📥 Downloading {model_file} from {model_repo}...")
 
 
 
 
 
 
 
 
 
192
 
193
+ start_time = time.time()
194
+ model_path = hf_hub_download(
195
+ repo_id=model_repo,
196
+ filename=model_file,
197
+ cache_dir=model_dir,
198
+ local_dir=model_dir,
199
+ local_dir_use_symlinks=False,
200
+ resume_download=True
201
  )
202
 
203
+ download_time = time.time() - start_time
204
+ file_size = os.path.getsize(model_path) / (1024**3) # GB
205
+
206
+ self.logger.info(f" Download completed: {model_path} ({file_size:.2f}GB in {download_time:.1f}s)")
207
+ return model_path
 
 
208
 
209
  except Exception as e:
210
+ self.logger.error(f"❌ Error downloading model {model_file}: {e}")
211
+ return None
212
+
213
+ def switch_model(self, model_name: str) -> bool:
214
+ """Switch to a different model"""
215
+ if model_name not in self.available_models:
216
+ self.logger.error(f"❌ Model {model_name} not available")
217
+ return False
218
+
219
+ if not self.available_models[model_name]['loaded']:
220
+ self.logger.error(f"❌ Model {model_name} not loaded properly")
221
+ return False
222
+
223
+ self.current_model_name = model_name
224
+ self.current_model = self.available_models[model_name]
225
+ self.model_loaded = True
226
+
227
+ self.logger.info(f"🔄 Switched to model: {model_name}")
228
+ return True
229
 
230
  def process_query(
231
  self,
232
  query: str,
233
  user_id: str,
234
+ conversation_id: Optional[str] = None,
235
+ model_name: Optional[str] = None
236
+ ) -> Dict[str, Any]:
237
  """
238
+ Process user query with context from Saem's Tunes platform.
239
 
240
  Args:
241
  query: User's question
242
  user_id: Unique user identifier
243
  conversation_id: Optional conversation ID for context
244
+ model_name: Specific model to use for this query
245
 
246
  Returns:
247
+ Dictionary containing response and metadata
248
  """
249
+ start_time = time.time()
250
+
251
+ # Switch model if requested
252
+ if model_name and model_name != self.current_model_name:
253
+ if not self.switch_model(model_name):
254
+ self.logger.warning(f"⚠️ Failed to switch to {model_name}, using current model")
255
+
256
  if not self.model_loaded:
257
  self.logger.warning("Model not loaded, returning fallback response")
258
+ fallback_response = self.get_fallback_response(query)
259
+ return {
260
+ 'response': fallback_response,
261
+ 'processing_time': time.time() - start_time,
262
+ 'model_used': 'fallback',
263
+ 'success': False,
264
+ 'conversation_id': conversation_id
265
+ }
266
 
267
+ # Check cache first
268
+ cache_key = f"{user_id}:{hash(query)}:{self.current_model_name}"
269
  if cache_key in self.response_cache:
270
  cached_response, timestamp = self.response_cache[cache_key]
271
+ if time.time() - timestamp < 300: # 5 minute cache
272
+ self.logger.info("💾 Returning cached response")
273
+ return {
274
+ 'response': cached_response,
275
+ 'processing_time': 0.01,
276
+ 'model_used': self.current_model_name,
277
+ 'success': True,
278
+ 'cached': True,
279
+ 'conversation_id': conversation_id
280
+ }
281
 
282
  try:
283
+ # Get comprehensive context from Supabase
 
284
  context = self.supabase.get_music_context(query, user_id)
285
 
286
+ # Build model-specific prompt
287
+ prompt = self.build_model_specific_prompt(query, context, user_id, conversation_id)
288
 
289
+ # Generate response
290
+ model_config = self.current_model['config']
291
+ response = self.current_model['model'].create_completion(
292
  prompt,
293
+ max_tokens=model_config['max_response_length'],
294
+ temperature=model_config['temperature'],
295
+ top_p=model_config['top_p'],
296
+ stop=self.get_model_stop_tokens(self.current_model_name),
297
  echo=False,
298
  stream=False
299
  )
 
301
  processing_time = time.time() - start_time
302
 
303
  response_text = response['choices'][0]['text'].strip()
304
+ response_text = self.clean_response(response_text, self.current_model_name)
305
 
306
+ # Update metrics
 
307
  self.record_metrics(
308
  query=query,
309
  response=response_text,
310
  processing_time=processing_time,
311
  user_id=user_id,
312
  conversation_id=conversation_id,
313
+ model_used=self.current_model_name,
314
  context_used=context,
315
  success=True
316
  )
317
 
318
+ # Cache response
319
  self.response_cache[cache_key] = (response_text, time.time())
320
 
321
+ # Update conversation history
322
  if conversation_id:
323
  self.update_conversation_history(conversation_id, query, response_text)
324
 
325
+ self.logger.info(f"✅ Query processed in {processing_time:.2f}s using {self.current_model_name}")
326
 
327
+ return {
328
+ 'response': response_text,
329
+ 'processing_time': processing_time,
330
+ 'model_used': self.current_model_name,
331
+ 'success': True,
332
+ 'conversation_id': conversation_id,
333
+ 'context_used': context.get('summary', 'General platform context')
334
+ }
335
 
336
  except Exception as e:
337
+ processing_time = time.time() - start_time
338
+ self.logger.error(f"❌ Error processing query with {self.current_model_name}: {e}")
339
 
340
  self.record_metrics(
341
  query=query,
342
  response="",
343
+ processing_time=processing_time,
344
  user_id=user_id,
345
  conversation_id=conversation_id,
346
+ model_used=self.current_model_name,
347
  error_message=str(e),
348
  success=False
349
  )
350
 
351
+ error_response = self.get_error_response(e)
352
+ return {
353
+ 'response': error_response,
354
+ 'processing_time': processing_time,
355
+ 'model_used': self.current_model_name,
356
+ 'success': False,
357
+ 'error': str(e),
358
+ 'conversation_id': conversation_id
359
+ }
360
 
361
+ def build_model_specific_prompt(
362
  self,
363
  query: str,
364
  context: Dict[str, Any],
 
366
  conversation_id: Optional[str] = None
367
  ) -> str:
368
  """
369
+ Build comprehensive prompt tailored to specific model requirements.
370
  """
371
  conversation_context = ""
372
  if conversation_id and conversation_id in self.conversation_history:
 
375
  role = "User" if msg["role"] == "user" else "Assistant"
376
  conversation_context += f"{role}: {msg['content']}\n"
377
 
378
+ # Model-specific prompt templates
379
+ if self.current_model_name == 'TinyLlama-1.1B-Chat':
380
+ return self.build_tinyllama_prompt(query, context, conversation_context, user_id)
381
+ elif self.current_model_name == 'Phi-2':
382
+ return self.build_phi2_prompt(query, context, conversation_context, user_id)
383
+ elif self.current_model_name == 'Qwen-1.8B-Chat':
384
+ return self.build_qwen_prompt(query, context, conversation_context, user_id)
385
+ else:
386
+ return self.build_default_prompt(query, context, conversation_context, user_id)
387
+
388
+ def build_tinyllama_prompt(self, query: str, context: Dict, conversation_context: str, user_id: str) -> str:
389
+ """Build prompt for TinyLlama model"""
390
  system_prompt = f"""<|system|>
391
+ You are the AI assistant for Saem's Tunes - a music education and streaming platform.
392
+
393
+ PLATFORM CONTEXT:
394
+ - Music Streaming: {context.get('stats', {}).get('track_count', 0)}+ tracks
395
+ - Education: Courses, lessons, and learning paths
396
+ - Community: User profiles and social features
397
+ - Creator Tools: Music upload and analytics
398
+
399
+ CURRENT STATS:
400
+ - Tracks: {context.get('stats', {}).get('track_count', 0)}
401
+ - Artists: {context.get('stats', {}).get('artist_count', 0)}
402
+ - Users: {context.get('stats', {}).get('user_count', 0)}
403
+ - Courses: {context.get('stats', {}).get('course_count', 0)}
404
+
405
+ POPULAR CONTENT:
406
+ {self.format_popular_content(context)}
407
+
408
+ CONVERSATION HISTORY:
409
+ {conversation_context if conversation_context else 'No recent history'}
410
+
411
+ GUIDELINES:
412
+ - Be concise and helpful about Saem's Tunes
413
+ - Focus on music streaming and education
414
+ - Keep responses under 150 words
415
+ - Be enthusiastic about music
416
+
417
+ Answer the user's question based on the platform context.</s>
418
+ """
419
+ return system_prompt + f"<|user|>\n{query}</s>\n<|assistant|>\n"
420
+
421
+ def build_phi2_prompt(self, query: str, context: Dict, conversation_context: str, user_id: str) -> str:
422
+ """Build prompt for Phi-2 model"""
423
+ system_prompt = f"""You are an AI assistant for Saem's Tunes music platform.
424
+
425
+ Platform Overview:
426
+ - Music streaming with {context.get('stats', {}).get('track_count', 0)}+ tracks
427
+ - Educational courses and learning paths
428
+ - Community features and creator tools
429
+ - Premium subscription benefits
430
+
431
+ Current Statistics:
432
+ - Total Tracks: {context.get('stats', {}).get('track_count', 0)}
433
+ - Total Artists: {context.get('stats', {}).get('artist_count', 0)}
434
+ - Total Users: {context.get('stats', {}).get('user_count', 0)}
435
+ - Total Courses: {context.get('stats', {}).get('course_count', 0)}
436
+
437
+ Popular Content:
438
+ {self.format_popular_content(context)}
439
+
440
+ Conversation History:
441
+ {conversation_context if conversation_context else 'No recent conversation'}
442
+
443
+ User Context:
444
+ {self.format_user_context(context.get('user_context', {}))}
445
+
446
+ Instructions:
447
+ 1. Provide specific information about Saem's Tunes features
448
+ 2. Focus on music education and streaming capabilities
449
+ 3. Keep responses clear and informative
450
+ 4. Reference platform statistics when relevant
451
+ 5. Be passionate about music education
452
+
453
+ Question: {query}
454
+
455
+ Answer:"""
456
+ return system_prompt
457
+
458
+ def build_qwen_prompt(self, query: str, context: Dict, conversation_context: str, user_id: str) -> str:
459
+ """Build prompt for Qwen model"""
460
+ system_prompt = f"""<|im_start|>system
461
  You are the AI assistant for Saem's Tunes, a comprehensive music education and streaming platform.
462
 
463
  PLATFORM OVERVIEW:
464
+ 🎵 Music Streaming: {context.get('stats', {}).get('track_count', 0)}+ tracks, {context.get('stats', {}).get('artist_count', 0)}+ artists
465
+ 📚 Education: Courses, lessons, quizzes, learning paths
466
+ 👥 Community: User profiles, favorites, social features
467
+ 🎨 Creator Tools: Music upload, artist analytics, promotion
468
+ 💎 Premium: Subscription-based premium features
469
 
470
  PLATFORM STATISTICS:
471
  - Total Tracks: {context.get('stats', {}).get('track_count', 0)}
472
+ - Total Artists: {context.get('stats', {}).get('artist_count', 0)}
473
  - Total Users: {context.get('stats', {}).get('user_count', 0)}
474
  - Total Courses: {context.get('stats', {}).get('course_count', 0)}
475
  - Active Playlists: {context.get('stats', {}).get('playlist_count', 0)}
 
489
  RESPONSE GUIDELINES:
490
  1. Be passionate about music and education
491
  2. Provide specific, actionable information about Saem's Tunes
492
+ 3. Reference platform features when relevant
493
+ 4. Keep responses concise and helpful
494
+ 5. Personalize responses when user context is available
495
+ 6. Focus on music education, streaming, and platform features
496
+ 7. Always maintain a professional, helpful tone
 
 
 
497
 
498
  PLATFORM FEATURES TO MENTION:
499
  - Music streaming and discovery
 
504
  - Premium subscription benefits
505
  - Mobile app availability
506
  - Music recommendations
507
+ - Learning progress tracking<|im_end|>
 
 
508
  """
509
+ conversation_history = ""
510
+ if conversation_context:
511
+ conversation_history = conversation_context.replace("User:", "<|im_start|>user\n").replace("Assistant:", "<|im_start|>assistant\n")
512
 
513
+ return system_prompt + conversation_history + f"<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"
514
+
515
+ def build_default_prompt(self, query: str, context: Dict, conversation_context: str, user_id: str) -> str:
516
+ """Default prompt for unknown models"""
517
+ return f"""Saem's Tunes Platform Assistant
518
+ Platform Stats: {context.get('stats', {}).get('track_count', 0)} tracks, {context.get('stats', {}).get('user_count', 0)} users
519
+ Context: {context.get('summary', 'Music education and streaming platform')}
520
+
521
+ Question: {query}
522
+
523
+ Answer:"""
524
+
525
+ def get_model_stop_tokens(self, model_name: str) -> List[str]:
526
+ """Get appropriate stop tokens for each model"""
527
+ if model_name == 'TinyLlama-1.1B-Chat':
528
+ return ["</s>", "<|user|>", "<|assistant|>", "<|system|>"]
529
+ elif model_name == 'Phi-2':
530
+ return ["\n\n", "###", "Human:", "Assistant:"]
531
+ elif model_name == 'Qwen-1.8B-Chat':
532
+ return ["<|im_end|>", "<|im_start|>", "\n\n"]
533
+ else:
534
+ return ["\n\n", "###", "Human:", "Assistant:"]
535
 
536
  def format_popular_content(self, context: Dict[str, Any]) -> str:
537
  """Format popular content for the prompt"""
 
590
 
591
  return "\n".join(user_lines) if user_lines else "Basic user account"
592
 
593
+ def clean_response(self, response: str, model_name: str) -> str:
594
+ """Clean and format the AI response based on model"""
595
  if not response:
596
  return "I apologize, but I couldn't generate a response. Please try again."
597
 
598
  response = response.strip()
599
 
600
+ # Model-specific cleaning
601
+ if model_name == 'TinyLlama-1.1B-Chat':
602
+ stop_tokens = ["</s>", "<|user|>", "<|assistant|>", "<|system|>"]
603
+ elif model_name == 'Qwen-1.8B-Chat':
604
+ stop_tokens = ["<|im_end|>", "<|im_start|>"]
605
+ else:
606
+ stop_tokens = ["\n\n", "###", "Human:", "Assistant:"]
607
+
608
+ for token in stop_tokens:
609
+ if token in response:
610
+ response = response.split(token)[0].strip()
611
+
612
+ # General cleaning
613
  if response.startswith("I'm sorry") or response.startswith("I apologize"):
614
  if len(response) < 20:
615
  response = "I'd be happy to help you with that! Our platform offers comprehensive music education and streaming features."
616
 
 
 
 
 
 
 
 
 
 
617
  sentences = response.split('. ')
618
  if len(sentences) > 1:
619
  response = '. '.join(sentences[:-1]) + '.' if not sentences[-1].endswith('.') else '. '.join(sentences)
 
623
 
624
  response = response.replace('**', '').replace('__', '').replace('*', '')
625
 
626
+ max_length = self.current_model['config']['max_response_length'] if self.current_model else 200
627
+ if len(response) > max_length:
628
+ response = response[:max_length].rsplit(' ', 1)[0] + '...'
629
 
630
  return response
631
 
 
649
  processing_time: float,
650
  user_id: str,
651
  conversation_id: Optional[str] = None,
652
+ model_used: Optional[str] = None,
653
  context_used: Optional[Dict] = None,
654
  error_message: Optional[str] = None,
655
  success: bool = True
656
  ):
657
+ """Record comprehensive metrics for monitoring"""
658
  metrics = {
659
+ 'model_name': model_used or self.current_model_name,
660
  'processing_time_ms': processing_time * 1000,
661
  'input_tokens': len(query.split()),
662
  'output_tokens': len(response.split()) if response else 0,
 
667
  'timestamp': datetime.now(),
668
  'query_length': len(query),
669
  'response_length': len(response) if response else 0,
670
+ 'current_model': self.current_model_name
671
  }
672
 
673
  if error_message:
 
682
  'context_summary': context_used.get('summary', '')[:100]
683
  }
684
 
685
+ # Update model usage stats
686
+ if model_used:
687
+ if model_used not in self.model_usage_stats:
688
+ self.model_usage_stats[model_used] = {
689
+ 'total_requests': 0,
690
+ 'successful_requests': 0,
691
+ 'total_processing_time': 0,
692
+ 'average_response_time': 0
693
+ }
694
+
695
+ self.model_usage_stats[model_used]['total_requests'] += 1
696
+ self.model_usage_stats[model_used]['total_processing_time'] += processing_time
697
+
698
+ if success:
699
+ self.model_usage_stats[model_used]['successful_requests'] += 1
700
+
701
+ self.model_usage_stats[model_used]['average_response_time'] = (
702
+ self.model_usage_stats[model_used]['total_processing_time'] /
703
+ self.model_usage_stats[model_used]['total_requests']
704
+ )
705
+
706
  self.monitor.record_inference(metrics)
707
 
708
  def get_fallback_response(self, query: str) -> str:
 
740
 
741
  def is_healthy(self) -> bool:
742
  """Check if AI system is healthy and ready"""
743
+ return self.model_loaded and self.current_model is not None and self.supabase.is_connected()
744
 
745
  def get_system_info(self) -> Dict[str, Any]:
746
+ """Get comprehensive system information"""
747
+ loaded_models = []
748
+ for name, model_info in self.available_models.items():
749
+ loaded_models.append({
750
+ 'name': name,
751
+ 'loaded': model_info['loaded'],
752
+ 'path': model_info['path'],
753
+ 'config': model_info['config']
754
+ })
755
+
756
  return {
757
+ "current_model": self.current_model_name,
758
  "model_loaded": self.model_loaded,
759
+ "available_models": loaded_models,
 
 
 
 
 
 
760
  "supabase_connected": self.supabase.is_connected(),
761
  "conversations_active": len(self.conversation_history),
762
+ "cache_size": len(self.response_cache),
763
+ "model_usage_stats": self.model_usage_stats
764
  }
765
 
766
+ def get_model_performance(self) -> Dict[str, Any]:
767
+ """Get performance statistics for all models"""
768
+ return self.model_usage_stats
769
+
770
  def clear_cache(self, user_id: Optional[str] = None):
771
  """Clear response cache"""
772
  if user_id:
 
776
  else:
777
  self.response_cache.clear()
778
 
779
+ def get_available_models(self) -> List[Dict[str, Any]]:
780
+ """Get list of available models with their status"""
781
+ models = []
782
+ for name, info in self.available_models.items():
783
+ models.append({
784
+ 'name': name,
785
+ 'loaded': info['loaded'],
786
+ 'description': info['config']['description'],
787
+ 'max_response_length': info['config']['max_response_length'],
788
+ 'context_window': info['config']['context_window']
789
+ })
790
+ return models
791
+
792
+ def unload_model(self, model_name: str):
793
+ """Unload a specific model to free memory"""
794
+ if model_name in self.available_models:
795
+ if self.available_models[model_name]['model'] is not None:
796
+ # Force garbage collection
797
+ del self.available_models[model_name]['model']
798
+ self.available_models[model_name]['model'] = None
799
+ self.available_models[model_name]['loaded'] = False
800
+
801
+ if self.current_model_name == model_name:
802
+ self.current_model = None
803
+ self.model_loaded = False
804
+
805
+ gc.collect()
806
+ self.logger.info(f"🗑️ Unloaded model: {model_name}")