SHIKARICHACHA commited on
Commit
93a0fac
·
verified ·
1 Parent(s): 4f2c4a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -48
app.py CHANGED
@@ -8,6 +8,7 @@ AND time signature, guaranteeing exact durations in MIDI and in the UI!
8
  Major updates:
9
  - Added duration sum display in Exercise Data tab
10
  - Shows total duration units (16th notes) for verification
 
11
  - Fixed difficulty level implementation
12
  - Maintained all original functionality
13
  """
@@ -28,7 +29,7 @@ def install(packages: List[str]):
28
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
29
 
30
  install([
31
- "mido", "midi2audio", "pydub", "gradio",
32
  "requests", "numpy", "matplotlib", "librosa", "scipy",
33
  ])
34
 
@@ -54,12 +55,14 @@ import os
54
  import subprocess as sp
55
  import base64
56
  import shutil
 
57
 
58
  # -----------------------------------------------------------------------------
59
  # 3. Configuration & constants
60
  # -----------------------------------------------------------------------------
61
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
62
- MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # Replace with your key!
 
63
 
64
  SOUNDFONT_URLS = {
65
  "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
@@ -258,14 +261,10 @@ def get_technique_based_on_level(level: str) -> str:
258
  return random.choice(techniques.get(level, ["with slurs"]))
259
 
260
  # -----------------------------------------------------------------------------
261
- # 9. Mistral API: query, fallback on errors (using integer durations)
262
  # -----------------------------------------------------------------------------
263
- def query_mistral(prompt: str, instrument: str, level: str, key: str,
264
- time_sig: str, measures: int) -> str:
265
- headers = {
266
- "Authorization": f"Bearer {MISTRAL_API_KEY}",
267
- "Content-Type": "application/json",
268
- }
269
  numerator, denominator = map(int, time_sig.split('/'))
270
  units_per_measure = numerator * (16 // denominator)
271
  required_total = measures * units_per_measure
@@ -296,26 +295,62 @@ def query_mistral(prompt: str, instrument: str, level: str, key: str,
296
  "Sum must be exactly as specified. ONLY output the JSON array. No prose."
297
  )
298
 
299
- payload = {
300
- "model": "mistral-medium",
301
- "messages": [
302
- {"role": "system", "content": system_prompt},
303
- {"role": "user", "content": user_prompt},
304
- ],
305
- "temperature": 0.7 if level == "Advanced" else 0.5,
306
- "max_tokens": 1000,
307
- "top_p": 0.95,
308
- "frequency_penalty": 0.2,
309
- "presence_penalty": 0.2,
310
- }
311
-
312
- try:
313
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
314
- response.raise_for_status()
315
- content = response.json()["choices"][0]["message"]["content"]
316
- return content.replace("```json","").replace("```","").strip()
317
- except Exception as e:
318
- print(f"Error querying Mistral API: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
320
 
321
  # -----------------------------------------------------------------------------
@@ -336,10 +371,10 @@ def safe_parse_json(text: str) -> Optional[list]:
336
  # 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values
337
  # -----------------------------------------------------------------------------
338
  def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
339
- measures: int, custom_prompt: str, mode: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
340
  try:
341
  prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
342
- output = query_mistral(prompt_to_use, instrument, level, key, time_signature, measures)
343
  parsed = safe_parse_json(output)
344
  if not parsed:
345
  return "Invalid JSON format", None, str(tempo), None, "0", time_signature, 0
@@ -366,7 +401,7 @@ def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_si
366
  # -----------------------------------------------------------------------------
367
  # 12. Simple AI chat assistant (optional, shares LLM)
368
  # -----------------------------------------------------------------------------
369
- def handle_chat(message: str, history: List, instrument: str, level: str):
370
  if not message.strip():
371
  return "", history
372
  messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
@@ -374,16 +409,46 @@ def handle_chat(message: str, history: List, instrument: str, level: str):
374
  messages.append({"role": "user", "content": user_msg})
375
  messages.append({"role": "assistant", "content": assistant_msg})
376
  messages.append({"role": "user", "content": message})
377
- headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
378
- payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
379
- try:
380
- response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
381
- response.raise_for_status()
382
- content = response.json()["choices"][0]["message"]["content"]
383
- history.append((message, content))
384
- return "", history
385
- except Exception as e:
386
- history.append((message, f"Error: {str(e)}"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  return "", history
388
 
389
  # -----------------------------------------------------------------------------
@@ -400,6 +465,11 @@ def create_ui() -> gr.Blocks:
400
  with gr.Column(scale=1):
401
  with gr.Group(visible=True) as params_group:
402
  gr.Markdown("### Exercise Parameters")
 
 
 
 
 
403
  instrument = gr.Dropdown([
404
  "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
405
  ], value="Trumpet", label="Instrument")
@@ -425,7 +495,7 @@ def create_ui() -> gr.Blocks:
425
  duration_display = gr.Textbox(label="Audio Duration", interactive=False)
426
  with gr.TabItem("Exercise Data"):
427
  json_output = gr.Code(label="JSON Representation", language="json")
428
- # NEW: Duration sum display
429
  duration_sum = gr.Number(
430
  label="Total Duration Units (16th notes)",
431
  interactive=False,
@@ -447,16 +517,16 @@ def create_ui() -> gr.Blocks:
447
  inputs=[mode], outputs=[params_group, prompt_group]
448
  )
449
  def generate_caller(mode_val, instrument_val, level_val, key_val,
450
- time_sig_val, measures_val, prompt_val, measures_prompt_val):
451
  real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
452
  fixed_tempo = 60
453
  return generate_exercise(
454
  instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
455
- real_measures, prompt_val, mode_val
456
  )
457
  generate_btn.click(
458
  fn=generate_caller,
459
- inputs=[mode, instrument, level, key, time_signature,measures, custom_prompt, measures_prompt],
460
  outputs=[json_output, audio_output, bpm_display, current_midi, duration_display, time_sig_display, duration_sum]
461
  )
462
  def save_midi(json_data, instr, time_sig):
@@ -479,7 +549,7 @@ def create_ui() -> gr.Blocks:
479
  )
480
  send_chat_btn.click(
481
  fn=handle_chat,
482
- inputs=[chat_message, chat_history, instrument, level],
483
  outputs=[chat_message, chat_history],
484
  )
485
  return demo
@@ -489,4 +559,4 @@ def create_ui() -> gr.Blocks:
489
  # -----------------------------------------------------------------------------
490
  if __name__ == "__main__":
491
  demo = create_ui()
492
- demo.launch()
 
8
  Major updates:
9
  - Added duration sum display in Exercise Data tab
10
  - Shows total duration units (16th notes) for verification
11
+ - Added DeepSeek AI model option
12
  - Fixed difficulty level implementation
13
  - Maintained all original functionality
14
  """
 
29
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
30
 
31
  install([
32
+ "mido", "midi2audio", "pydub", "gradio", "openai",
33
  "requests", "numpy", "matplotlib", "librosa", "scipy",
34
  ])
35
 
 
55
  import subprocess as sp
56
  import base64
57
  import shutil
58
+ from openai import OpenAI # For DeepSeek API
59
 
60
  # -----------------------------------------------------------------------------
61
  # 3. Configuration & constants
62
  # -----------------------------------------------------------------------------
63
  MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
64
+ MISTRAL_API_KEY = "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX" # Replace with your key
65
+ DEEPSEEK_API_KEY = "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"
66
 
67
  SOUNDFONT_URLS = {
68
  "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
 
261
  return random.choice(techniques.get(level, ["with slurs"]))
262
 
263
  # -----------------------------------------------------------------------------
264
+ # 9. LLM Query Function (supports Mistral and DeepSeek)
265
  # -----------------------------------------------------------------------------
266
+ def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
267
+ time_sig: str, measures: int) -> str:
 
 
 
 
268
  numerator, denominator = map(int, time_sig.split('/'))
269
  units_per_measure = numerator * (16 // denominator)
270
  required_total = measures * units_per_measure
 
295
  "Sum must be exactly as specified. ONLY output the JSON array. No prose."
296
  )
297
 
298
+ if model_name == "Mistral":
299
+ headers = {
300
+ "Authorization": f"Bearer {MISTRAL_API_KEY}",
301
+ "Content-Type": "application/json",
302
+ }
303
+ payload = {
304
+ "model": "mistral-medium",
305
+ "messages": [
306
+ {"role": "system", "content": system_prompt},
307
+ {"role": "user", "content": user_prompt},
308
+ ],
309
+ "temperature": 0.7 if level == "Advanced" else 0.5,
310
+ "max_tokens": 1000,
311
+ "top_p": 0.95,
312
+ "frequency_penalty": 0.2,
313
+ "presence_penalty": 0.2,
314
+ }
315
+ try:
316
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
317
+ response.raise_for_status()
318
+ content = response.json()["choices"][0]["message"]["content"]
319
+ return content.replace("```json","").replace("```","").strip()
320
+ except Exception as e:
321
+ print(f"Error querying Mistral API: {e}")
322
+ return get_fallback_exercise(instrument, level, key, time_sig, measures)
323
+
324
+ elif model_name == "DeepSeek":
325
+ try:
326
+ client = OpenAI(
327
+ base_url="https://openrouter.ai/api/v1",
328
+ api_key=DEEPSEEK_API_KEY,
329
+ )
330
+
331
+ completion = client.chat.completions.create(
332
+ extra_headers={
333
+ "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
334
+ "X-Title": "Music Exercise Generator",
335
+ },
336
+ model="deepseek/deepseek-chat-v3-0324:free",
337
+ messages=[
338
+ {"role": "system", "content": system_prompt},
339
+ {"role": "user", "content": user_prompt},
340
+ ],
341
+ temperature=0.7 if level == "Advanced" else 0.5,
342
+ max_tokens=1000,
343
+ top_p=0.95,
344
+ frequency_penalty=0.2,
345
+ presence_penalty=0.2,
346
+ )
347
+ content = completion.choices[0].message.content
348
+ return content.replace("```json","").replace("```","").strip()
349
+ except Exception as e:
350
+ print(f"Error querying DeepSeek API: {e}")
351
+ return get_fallback_exercise(instrument, level, key, time_sig, measures)
352
+
353
+ else:
354
  return get_fallback_exercise(instrument, level, key, time_sig, measures)
355
 
356
  # -----------------------------------------------------------------------------
 
371
  # 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values
372
  # -----------------------------------------------------------------------------
373
  def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
374
+ measures: int, custom_prompt: str, mode: str, ai_model: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
375
  try:
376
  prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
377
+ output = query_llm(ai_model, prompt_to_use, instrument, level, key, time_signature, measures)
378
  parsed = safe_parse_json(output)
379
  if not parsed:
380
  return "Invalid JSON format", None, str(tempo), None, "0", time_signature, 0
 
401
  # -----------------------------------------------------------------------------
402
  # 12. Simple AI chat assistant (optional, shares LLM)
403
  # -----------------------------------------------------------------------------
404
+ def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
405
  if not message.strip():
406
  return "", history
407
  messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
 
409
  messages.append({"role": "user", "content": user_msg})
410
  messages.append({"role": "assistant", "content": assistant_msg})
411
  messages.append({"role": "user", "content": message})
412
+
413
+ if ai_model == "Mistral":
414
+ headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
415
+ payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
416
+ try:
417
+ response = requests.post(MISTRAL_API_URL, headers=headers, json=payload)
418
+ response.raise_for_status()
419
+ content = response.json()["choices"][0]["message"]["content"]
420
+ history.append((message, content))
421
+ return "", history
422
+ except Exception as e:
423
+ history.append((message, f"Error: {str(e)}"))
424
+ return "", history
425
+
426
+ elif ai_model == "DeepSeek":
427
+ try:
428
+ client = OpenAI(
429
+ base_url="https://openrouter.ai/api/v1",
430
+ api_key=DEEPSEEK_API_KEY,
431
+ )
432
+
433
+ completion = client.chat.completions.create(
434
+ extra_headers={
435
+ "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
436
+ "X-Title": "Music Exercise Generator",
437
+ },
438
+ model="deepseek/deepseek-chat-v3-0324:free",
439
+ messages=messages,
440
+ temperature=0.7,
441
+ max_tokens=500,
442
+ )
443
+ content = completion.choices[0].message.content
444
+ history.append((message, content))
445
+ return "", history
446
+ except Exception as e:
447
+ history.append((message, f"Error: {str(e)}"))
448
+ return "", history
449
+
450
+ else:
451
+ history.append((message, "Error: Invalid AI model selected"))
452
  return "", history
453
 
454
  # -----------------------------------------------------------------------------
 
465
  with gr.Column(scale=1):
466
  with gr.Group(visible=True) as params_group:
467
  gr.Markdown("### Exercise Parameters")
468
+ ai_model = gr.Radio(
469
+ ["Mistral", "DeepSeek"],
470
+ value="Mistral",
471
+ label="AI Model"
472
+ )
473
  instrument = gr.Dropdown([
474
  "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
475
  ], value="Trumpet", label="Instrument")
 
495
  duration_display = gr.Textbox(label="Audio Duration", interactive=False)
496
  with gr.TabItem("Exercise Data"):
497
  json_output = gr.Code(label="JSON Representation", language="json")
498
+ # Duration sum display
499
  duration_sum = gr.Number(
500
  label="Total Duration Units (16th notes)",
501
  interactive=False,
 
517
  inputs=[mode], outputs=[params_group, prompt_group]
518
  )
519
  def generate_caller(mode_val, instrument_val, level_val, key_val,
520
+ time_sig_val, measures_val, prompt_val, measures_prompt_val, ai_model_val):
521
  real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
522
  fixed_tempo = 60
523
  return generate_exercise(
524
  instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
525
+ real_measures, prompt_val, mode_val, ai_model_val
526
  )
527
  generate_btn.click(
528
  fn=generate_caller,
529
+ inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt, ai_model],
530
  outputs=[json_output, audio_output, bpm_display, current_midi, duration_display, time_sig_display, duration_sum]
531
  )
532
  def save_midi(json_data, instr, time_sig):
 
549
  )
550
  send_chat_btn.click(
551
  fn=handle_chat,
552
+ inputs=[chat_message, chat_history, instrument, level, ai_model],
553
  outputs=[chat_message, chat_history],
554
  )
555
  return demo
 
559
  # -----------------------------------------------------------------------------
560
  if __name__ == "__main__":
561
  demo = create_ui()
562
+ demo.launch()