SHIKARICHACHA commited on
Commit
1d84416
·
verified ·
1 Parent(s): 6b1b297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -28
app.py CHANGED
@@ -6,8 +6,8 @@ Generates custom musical exercises with LLM, perfectly fit to user-specified num
6
  AND time signature, guaranteeing exact durations in MIDI and in the UI!
7
 
8
  Major updates:
9
- - Added Gemma 2b model option
10
- - Added Claude 3.5 Sonnet model option
11
  - Added duration sum display in Exercise Data tab
12
  - Shows total duration units (16th notes) for verification
13
  - Fixed difficulty level implementation
@@ -64,9 +64,10 @@ from openai import OpenAI # For API models
64
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
65
  MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # Replace with your key
66
  OPENROUTER_API_KEYS = {
67
- "DeepSeek": "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",
68
  "Gemma": "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",
69
- "Claude": "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"
 
 
70
  }
71
 
72
  SOUNDFONT_URLS = {
@@ -326,7 +327,7 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
326
  print(f"Error querying Mistral API: {e}")
327
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
328
 
329
- elif model_name in ["DeepSeek", "Gemma", "Claude"]:
330
  try:
331
  client = OpenAI(
332
  base_url="https://openrouter.ai/api/v1",
@@ -334,15 +335,12 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
334
  )
335
 
336
  model_map = {
337
- "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
338
  "Gemma": "google/gemma-3n-e2b-it:free",
339
- "Claude": "anthropic/claude-3.5-sonnet:beta"
 
 
340
  }
341
 
342
- # Special handling for Gemma to ensure JSON output
343
- if model_name == "Gemma":
344
- user_prompt += " Output in valid JSON format only."
345
-
346
  completion = client.chat.completions.create(
347
  extra_headers={
348
  "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
@@ -358,17 +356,9 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
358
  top_p=0.95,
359
  frequency_penalty=0.2,
360
  presence_penalty=0.2,
361
- response_format={"type": "json_object"} if model_name == "Gemma" else None
362
  )
363
  content = completion.choices[0].message.content
364
-
365
- # Clean up Gemma output
366
- if model_name == "Gemma":
367
- content = content.replace("```json", "").replace("```", "").strip()
368
- # Remove any text before the first [
369
- content = re.sub(r'^[^[]*', '', content)
370
-
371
- return content
372
  except Exception as e:
373
  print(f"Error querying {model_name} API: {e}")
374
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
@@ -382,10 +372,6 @@ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: st
382
  def safe_parse_json(text: str) -> Optional[list]:
383
  try:
384
  text = text.replace("'", '"')
385
- # Remove any text before the first [ and after the last ]
386
- text = re.sub(r'^[^[]*', '', text)
387
- text = re.sub(r'[^\]]*$', '', text)
388
- # Extract JSON array
389
  match = re.search(r"\[(\s*\[.*?\]\s*,?)*\]", text, re.DOTALL)
390
  if match:
391
  return json.loads(match.group(0))
@@ -450,7 +436,7 @@ def handle_chat(message: str, history: List, instrument: str, level: str, ai_mod
450
  history.append((message, f"Error: {str(e)}"))
451
  return "", history
452
 
453
- elif ai_model in ["DeepSeek", "Gemma", "Claude"]:
454
  try:
455
  client = OpenAI(
456
  base_url="https://openrouter.ai/api/v1",
@@ -458,9 +444,10 @@ def handle_chat(message: str, history: List, instrument: str, level: str, ai_mod
458
  )
459
 
460
  model_map = {
461
- "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
462
  "Gemma": "google/gemma-3n-e2b-it:free",
463
- "Claude": "anthropic/claude-3.5-sonnet:beta"
 
 
464
  }
465
 
466
  completion = client.chat.completions.create(
@@ -499,7 +486,7 @@ def create_ui() -> gr.Blocks:
499
  with gr.Group(visible=True) as params_group:
500
  gr.Markdown("### Exercise Parameters")
501
  ai_model = gr.Radio(
502
- ["Mistral", "DeepSeek", "Gemma", "Claude"],
503
  value="Mistral",
504
  label="AI Model"
505
  )
 
6
  AND time signature, guaranteeing exact durations in MIDI and in the UI!
7
 
8
  Major updates:
9
+ - Added Gemma and Kimi Dev 72b AI model options
10
+ - Added DeepSeek R1 model option
11
  - Added duration sum display in Exercise Data tab
12
  - Shows total duration units (16th notes) for verification
13
  - Fixed difficulty level implementation
 
64
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
65
  MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # Replace with your key
66
  OPENROUTER_API_KEYS = {
 
67
  "Gemma": "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",
68
+ "Kimi": "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",
69
+ "Claude": "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",
70
+ "DeepSeek-R1": "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"
71
  }
72
 
73
  SOUNDFONT_URLS = {
 
327
  print(f"Error querying Mistral API: {e}")
328
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
329
 
330
+ elif model_name in ["Gemma", "Kimi", "Claude", "DeepSeek-R1"]:
331
  try:
332
  client = OpenAI(
333
  base_url="https://openrouter.ai/api/v1",
 
335
  )
336
 
337
  model_map = {
 
338
  "Gemma": "google/gemma-3n-e2b-it:free",
339
+ "Kimi": "moonshotai/kimi-dev-72b:free",
340
+ "Claude": "anthropic/claude-3.5-sonnet:beta",
341
+ "DeepSeek-R1": "deepseek/deepseek-r1-0528:free"
342
  }
343
 
 
 
 
 
344
  completion = client.chat.completions.create(
345
  extra_headers={
346
  "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
 
356
  top_p=0.95,
357
  frequency_penalty=0.2,
358
  presence_penalty=0.2,
 
359
  )
360
  content = completion.choices[0].message.content
361
+ return content.replace("```json","").replace("```","").strip()
 
 
 
 
 
 
 
362
  except Exception as e:
363
  print(f"Error querying {model_name} API: {e}")
364
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
 
372
  def safe_parse_json(text: str) -> Optional[list]:
373
  try:
374
  text = text.replace("'", '"')
 
 
 
 
375
  match = re.search(r"\[(\s*\[.*?\]\s*,?)*\]", text, re.DOTALL)
376
  if match:
377
  return json.loads(match.group(0))
 
436
  history.append((message, f"Error: {str(e)}"))
437
  return "", history
438
 
439
+ elif ai_model in ["Gemma", "Kimi", "Claude", "DeepSeek-R1"]:
440
  try:
441
  client = OpenAI(
442
  base_url="https://openrouter.ai/api/v1",
 
444
  )
445
 
446
  model_map = {
 
447
  "Gemma": "google/gemma-3n-e2b-it:free",
448
+ "Kimi": "moonshotai/kimi-dev-72b:free",
449
+ "Claude": "anthropic/claude-3.5-sonnet:beta",
450
+ "DeepSeek-R1": "deepseek/deepseek-r1-0528:free"
451
  }
452
 
453
  completion = client.chat.completions.create(
 
486
  with gr.Group(visible=True) as params_group:
487
  gr.Markdown("### Exercise Parameters")
488
  ai_model = gr.Radio(
489
+ ["Mistral", "Gemma", "Kimi", "Claude", "DeepSeek-R1"],
490
  value="Mistral",
491
  label="AI Model"
492
  )