SHIKARICHACHA commited on
Commit
a386765
·
verified ·
1 Parent(s): 78c07af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -94
app.py CHANGED
@@ -20,6 +20,8 @@ Major updates:
20
  import sys
21
  import subprocess
22
  from typing import Dict, Optional, Tuple, List
 
 
23
 
24
  def install(packages: List[str]):
25
  for package in packages:
@@ -37,7 +39,6 @@ install([
37
  # -----------------------------------------------------------------------------
38
  # 2. Static imports
39
  # -----------------------------------------------------------------------------
40
- import random
41
  import requests
42
  import json
43
  import tempfile
@@ -268,7 +269,7 @@ def get_technique_based_on_level(level: str) -> str:
268
  return random.choice(techniques.get(level, ["with slurs"]))
269
 
270
  # -----------------------------------------------------------------------------
271
- # 9. LLM Query Function (supports multiple models)
272
  # -----------------------------------------------------------------------------
273
  def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
274
  time_sig: str, measures: int) -> str:
@@ -302,7 +303,90 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
302
  "Sum must be exactly as specified. ONLY output the JSON array. No prose."
303
  )
304
 
305
- if model_name == "Mistral":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  headers = {
307
  "Authorization": f"Bearer {MISTRAL_API_KEY}",
308
  "Content-Type": "application/json",
@@ -319,53 +403,12 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
319
  "frequency_penalty": 0.2,
320
  "presence_penalty": 0.2,
321
  }
322
- try:
323
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
324
- response.raise_for_status()
325
- content = response.json()["choices"][0]["message"]["content"]
326
- return content.replace("```json","").replace("```","").strip()
327
- except Exception as e:
328
- print(f"Error querying Mistral API: {e}")
329
- return get_fallback_exercise(instrument, level, key, time_sig, measures)
330
-
331
- elif model_name in ["DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"]:
332
- try:
333
- client = OpenAI(
334
- base_url="https://openrouter.ai/api/v1",
335
- api_key=OPENROUTER_API_KEYS[model_name],
336
- )
337
-
338
- model_map = {
339
- "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
340
- "Claude": "anthropic/claude-3.5-sonnet:beta",
341
- "Gemma": "google/gemma-3n-e2b-it:free",
342
- "Kimi": "moonshotai/kimi-dev-72b:free",
343
- "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
344
- }
345
-
346
- completion = client.chat.completions.create(
347
- extra_headers={
348
- "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
349
- "X-Title": "Music Exercise Generator",
350
- },
351
- model=model_map[model_name],
352
- messages=[
353
- {"role": "system", "content": system_prompt},
354
- {"role": "user", "content": user_prompt},
355
- ],
356
- temperature=0.7 if level == "Advanced" else 0.5,
357
- max_tokens=1000,
358
- top_p=0.95,
359
- frequency_penalty=0.2,
360
- presence_penalty=0.2,
361
- )
362
- content = completion.choices[0].message.content
363
- return content.replace("```json","").replace("```","").strip()
364
- except Exception as e:
365
- print(f"Error querying {model_name} API: {e}")
366
- return get_fallback_exercise(instrument, level, key, time_sig, measures)
367
-
368
- else:
369
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
370
 
371
  # -----------------------------------------------------------------------------
@@ -414,7 +457,7 @@ def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_si
414
  return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
415
 
416
  # -----------------------------------------------------------------------------
417
- # 12. AI chat assistant
418
  # -----------------------------------------------------------------------------
419
  def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
420
  if not message.strip():
@@ -425,54 +468,81 @@ def handle_chat(message: str, history: List, instrument: str, level: str, ai_mod
425
  messages.append({"role": "assistant", "content": assistant_msg})
426
  messages.append({"role": "user", "content": message})
427
 
428
- if ai_model == "Mistral":
429
- headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
430
- payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
431
- try:
432
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
433
- response.raise_for_status()
434
- content = response.json()["choices"][0]["message"]["content"]
435
- history.append((message, content))
436
- return "", history
437
- except Exception as e:
438
- history.append((message, f"Error: {str(e)}"))
439
- return "", history
440
 
441
- elif ai_model in ["DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"]:
442
  try:
443
- client = OpenAI(
444
- base_url="https://openrouter.ai/api/v1",
445
- api_key=OPENROUTER_API_KEYS[ai_model],
446
- )
 
 
 
 
447
 
448
- model_map = {
449
- "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
450
- "Claude": "anthropic/claude-3.5-sonnet:beta",
451
- "Gemma": "google/gemma-3n-e2b-it:free",
452
- "Kimi": "moonshotai/kimi-dev-72b:free",
453
- "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
454
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
 
456
- completion = client.chat.completions.create(
457
- extra_headers={
458
- "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
459
- "X-Title": "Music Exercise Generator",
460
- },
461
- model=model_map[ai_model],
462
- messages=messages,
463
- temperature=0.7,
464
- max_tokens=500,
465
- )
466
- content = completion.choices[0].message.content
467
- history.append((message, content))
468
- return "", history
469
  except Exception as e:
470
- history.append((message, f"Error: {str(e)}"))
471
- return "", history
472
-
473
- else:
474
- history.append((message, "Error: Invalid AI model selected"))
475
- return "", history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
  # -----------------------------------------------------------------------------
478
  # 13. Gradio user interface definition
 
20
  import sys
21
  import subprocess
22
  from typing import Dict, Optional, Tuple, List
23
+ import time
24
+ import random
25
 
26
  def install(packages: List[str]):
27
  for package in packages:
 
39
  # -----------------------------------------------------------------------------
40
  # 2. Static imports
41
  # -----------------------------------------------------------------------------
 
42
  import requests
43
  import json
44
  import tempfile
 
269
  return random.choice(techniques.get(level, ["with slurs"]))
270
 
271
  # -----------------------------------------------------------------------------
272
+ # 9. LLM Query Function (with enhanced error handling)
273
  # -----------------------------------------------------------------------------
274
  def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
275
  time_sig: str, measures: int) -> str:
 
303
  "Sum must be exactly as specified. ONLY output the JSON array. No prose."
304
  )
305
 
306
+ # Retry up to 3 times for rate limited models
307
+ max_retries = 3
308
+ retry_delay = 5 # seconds
309
+
310
+ for attempt in range(max_retries):
311
+ try:
312
+ if model_name == "Mistral":
313
+ headers = {
314
+ "Authorization": f"Bearer {MISTRAL_API_KEY}",
315
+ "Content-Type": "application/json",
316
+ }
317
+ payload = {
318
+ "model": "mistral-medium",
319
+ "messages": [
320
+ {"role": "system", "content": system_prompt},
321
+ {"role": "user", "content": user_prompt},
322
+ ],
323
+ "temperature": 0.7 if level == "Advanced" else 0.5,
324
+ "max_tokens": 1000,
325
+ "top_p": 0.95,
326
+ "frequency_penalty": 0.2,
327
+ "presence_penalty": 0.2,
328
+ }
329
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
330
+ response.raise_for_status()
331
+ content = response.json()["choices"][0]["message"]["content"]
332
+ return content.replace("```json","").replace("```","").strip()
333
+
334
+ elif model_name in ["DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"]:
335
+ client = OpenAI(
336
+ base_url="https://openrouter.ai/api/v1",
337
+ api_key=OPENROUTER_API_KEYS[model_name],
338
+ )
339
+
340
+ model_map = {
341
+ "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
342
+ "Claude": "anthropic/claude-3.5-sonnet:beta",
343
+ "Gemma": "google/gemma-3n-e2b-it:free",
344
+ "Kimi": "moonshotai/kimi-dev-72b:free",
345
+ "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
346
+ }
347
+
348
+ # Special handling for Gemma API structure
349
+ if model_name == "Gemma":
350
+ messages = [
351
+ {"role": "user", "content": user_prompt}
352
+ ]
353
+ else:
354
+ messages = [
355
+ {"role": "system", "content": system_prompt},
356
+ {"role": "user", "content": user_prompt},
357
+ ]
358
+
359
+ completion = client.chat.completions.create(
360
+ extra_headers={
361
+ "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
362
+ "X-Title": "Music Exercise Generator",
363
+ },
364
+ model=model_map[model_name],
365
+ messages=messages,
366
+ temperature=0.7 if level == "Advanced" else 0.5,
367
+ max_tokens=1000,
368
+ top_p=0.95,
369
+ frequency_penalty=0.2,
370
+ presence_penalty=0.2,
371
+ )
372
+ content = completion.choices[0].message.content
373
+ return content.replace("```json","").replace("```","").strip()
374
+
375
+ else:
376
+ return get_fallback_exercise(instrument, level, key, time_sig, measures)
377
+
378
+ except Exception as e:
379
+ print(f"Error querying {model_name} API (attempt {attempt+1}): {e}")
380
+ if "429" in str(e) or "Rate limit" in str(e):
381
+ print(f"Rate limited, retrying in {retry_delay} seconds...")
382
+ time.sleep(retry_delay)
383
+ retry_delay *= 2 # Exponential backoff
384
+ else:
385
+ break
386
+
387
+ # Fallback to Mistral if other APIs fail
388
+ print(f"All attempts failed for {model_name}, using Mistral fallback")
389
+ try:
390
  headers = {
391
  "Authorization": f"Bearer {MISTRAL_API_KEY}",
392
  "Content-Type": "application/json",
 
403
  "frequency_penalty": 0.2,
404
  "presence_penalty": 0.2,
405
  }
406
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
407
+ response.raise_for_status()
408
+ content = response.json()["choices"][0]["message"]["content"]
409
+ return content.replace("```json","").replace("```","").strip()
410
+ except Exception as e:
411
+ print(f"Error querying Mistral fallback: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
413
 
414
  # -----------------------------------------------------------------------------
 
457
  return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
458
 
459
  # -----------------------------------------------------------------------------
460
+ # 12. AI chat assistant with enhanced error handling
461
  # -----------------------------------------------------------------------------
462
  def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
463
  if not message.strip():
 
468
  messages.append({"role": "assistant", "content": assistant_msg})
469
  messages.append({"role": "user", "content": message})
470
 
471
+ max_retries = 3
472
+ retry_delay = 3 # seconds
 
 
 
 
 
 
 
 
 
 
473
 
474
+ for attempt in range(max_retries):
475
  try:
476
+ if ai_model == "Mistral":
477
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
478
+ payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
479
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
480
+ response.raise_for_status()
481
+ content = response.json()["choices"][0]["message"]["content"]
482
+ history.append((message, content))
483
+ return "", history
484
 
485
+ elif ai_model in ["DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"]:
486
+ client = OpenAI(
487
+ base_url="https://openrouter.ai/api/v1",
488
+ api_key=OPENROUTER_API_KEYS[ai_model],
489
+ )
490
+
491
+ model_map = {
492
+ "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
493
+ "Claude": "anthropic/claude-3.5-sonnet:beta",
494
+ "Gemma": "google/gemma-3n-e2b-it:free",
495
+ "Kimi": "moonshotai/kimi-dev-72b:free",
496
+ "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
497
+ }
498
+
499
+ # Special handling for Gemma API structure
500
+ if ai_model == "Gemma":
501
+ adjusted_messages = [{"role": "user", "content": msg["content"]} for msg in messages]
502
+ else:
503
+ adjusted_messages = messages
504
+
505
+ completion = client.chat.completions.create(
506
+ extra_headers={
507
+ "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
508
+ "X-Title": "Music Exercise Generator",
509
+ },
510
+ model=model_map[ai_model],
511
+ messages=adjusted_messages,
512
+ temperature=0.7,
513
+ max_tokens=500,
514
+ )
515
+ content = completion.choices[0].message.content
516
+ history.append((message, content))
517
+ return "", history
518
 
519
+ else:
520
+ history.append((message, "Error: Invalid AI model selected"))
521
+ return "", history
522
+
 
 
 
 
 
 
 
 
 
523
  except Exception as e:
524
+ print(f"Chat error with {ai_model} (attempt {attempt+1}): {e}")
525
+ if "429" in str(e) or "Rate limit" in str(e):
526
+ print(f"Rate limited, retrying in {retry_delay} seconds...")
527
+ time.sleep(retry_delay)
528
+ retry_delay *= 2 # Exponential backoff
529
+ else:
530
+ # Fallback to Mistral
531
+ print(f"Using Mistral fallback for chat")
532
+ try:
533
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
534
+ payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
535
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
536
+ response.raise_for_status()
537
+ content = response.json()["choices"][0]["message"]["content"]
538
+ history.append((message, content))
539
+ return "", history
540
+ except Exception as e:
541
+ history.append((message, f"Error: {str(e)}"))
542
+ return "", history
543
+
544
+ history.append((message, "Error: All API attempts failed"))
545
+ return "", history
546
 
547
  # -----------------------------------------------------------------------------
548
  # 13. Gradio user interface definition