Spaces:
Sleeping
Sleeping
| """ | |
| Adaptive Music Exercise Generator (Strict Duration Enforcement) | |
| ============================================================== | |
| Generates custom musical exercises with LLM, perfectly fit to user-specified number of measures | |
| AND time signature, guaranteeing exact durations in MIDI and in the UI! | |
| Major updates: | |
| - Added Gemma, Kimi Dev 72b, and Llama 3.1 AI model options | |
| - Added duration sum display in Exercise Data tab | |
| - Shows total duration units (16th notes) for verification | |
| - Added DeepSeek AI model option | |
| - Fixed difficulty level implementation | |
| - Maintained all original functionality | |
| """ | |
| # ----------------------------------------------------------------------------- | |
| # 1. Runtime package installation (for fresh containers/Colab/etc.) | |
| # ----------------------------------------------------------------------------- | |
| import sys | |
| import subprocess | |
| from typing import Dict, Optional, Tuple, List | |
| import time | |
| import random | |
def install(packages: List[str]):
    """Make sure each listed package can be imported, installing it if not.

    Every entry is used both as the module name for the import probe and as
    the distribution name handed to ``pip install``.
    """
    for name in packages:
        try:
            __import__(name)
        except ImportError:
            # Install with the interpreter that is running this script.
            pip_command = [sys.executable, "-m", "pip", "install", name]
            print(f"Installing missing package: {name}")
            subprocess.check_call(pip_command)
# Bootstrap the third-party packages before the static imports below run.
install(
    [
        "mido",
        "midi2audio",
        "pydub",
        "gradio",
        "openai",
        "requests",
        "numpy",
        "matplotlib",
        "librosa",
        "scipy",
    ]
)
| # ----------------------------------------------------------------------------- | |
| # 2. Static imports | |
| # ----------------------------------------------------------------------------- | |
| import requests | |
| import json | |
| import tempfile | |
| import mido | |
| from mido import Message, MidiFile, MidiTrack, MetaMessage | |
| import re | |
| from io import BytesIO | |
| from midi2audio import FluidSynth | |
| from pydub import AudioSegment | |
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import librosa | |
| from scipy.io import wavfile | |
| import os | |
| import subprocess as sp | |
| import base64 | |
| import shutil | |
| from openai import OpenAI # For API models | |
| # ----------------------------------------------------------------------------- | |
| # 3. Configuration & constants | |
| # ----------------------------------------------------------------------------- | |
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"

# SECURITY: API credentials must not be committed to source control. They are
# read from the environment instead; any key that was previously hard-coded
# here should be treated as compromised and revoked.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "")

# One key per model label shown in the UI. A model-specific variable such as
# OPENROUTER_API_KEY_DEEPSEEK or OPENROUTER_API_KEY_LLAMA_3_1 takes precedence;
# otherwise the shared OPENROUTER_API_KEY is used. Missing keys become "".
OPENROUTER_API_KEYS = {
    model: os.environ.get(
        "OPENROUTER_API_KEY_" + re.sub(r"[^A-Z0-9]+", "_", model.upper()),
        os.environ.get("OPENROUTER_API_KEY", ""),
    )
    for model in ("DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1")
}
# Download location of the SoundFont used to render each instrument.
# Violin, Clarinet and Flute all point at the same SoundFont file.
SOUNDFONT_URLS = {
    "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
    "Piano": "https://musical-artifacts.com/artifacts/2719/GeneralUser_GS_1.471.sf2",
    **dict.fromkeys(
        ("Violin", "Clarinet", "Flute"),
        "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
    ),
}

SAMPLE_RATE = 44100  # Hz
TICKS_PER_BEAT = 480  # Standard MIDI resolution
TICKS_PER_16TH = TICKS_PER_BEAT // 4  # 120 ticks per 16th note
# Best-effort installation of the FluidSynth binary when it is not on PATH.
# subprocess.run(check=True) is used instead of os.system because os.system
# reports failure through its return value and never raises, which made the
# original error message unreachable.
if shutil.which("fluidsynth") is None:
    try:
        subprocess.run(["apt-get", "update"], check=True, timeout=300)
        subprocess.run(
            ["apt-get", "install", "-y", "fluidsynth"], check=True, timeout=600
        )
    except (OSError, subprocess.SubprocessError):
        # Covers a missing apt-get, a non-zero exit status and a timeout.
        print("Could not install FluidSynth automatically. Please install it manually.")

# Directory that generated MP3/MIDI files are written to.
os.makedirs("static", exist_ok=True)
| # ----------------------------------------------------------------------------- | |
| # 4. Music theory helpers (note names ↔︎ MIDI numbers) | |
| # ----------------------------------------------------------------------------- | |
# Upper-cased pitch spelling -> semitone offset from C (0-11). Flats are
# spelled with an upper-case "B" suffix, e.g. "BB" is B-flat.
NOTE_MAP: Dict[str, int] = {
    spelling: semitone
    for semitone, spellings in enumerate((
        ("C",),
        ("C#", "DB"),
        ("D",),
        ("D#", "EB"),
        ("E",),
        ("F",),
        ("F#", "GB"),
        ("G",),
        ("G#", "AB"),
        ("A",),
        ("A#", "BB"),
        ("B",),
    ))
    for spelling in spellings
}

# Instrument name -> MIDI program number sent in the program_change message.
INSTRUMENT_PROGRAMS: Dict[str, int] = dict(
    Piano=0,
    Trumpet=56,
    Violin=40,
    Clarinet=71,
    Flute=73,
)
def note_name_to_midi(note: str) -> int:
    """Convert a note name such as ``"Bb4"`` or ``"F#5"`` to a MIDI number.

    The name is a letter A-G (either case), an optional ``#`` or ``b``
    accidental, and an octave number, where octave 4 contains middle C
    (``"C4"`` -> 60). Surrounding whitespace is ignored.

    The whole string must match: trailing text is rejected and multi-digit
    or negative octaves are read in full, so ``"C10"`` is no longer silently
    interpreted as ``"C1"``. Enharmonic spellings that cross an octave
    boundary are handled (``"Cb4"`` -> 59, ``"B#3"`` -> 60).

    Raises:
        ValueError: if the text is not a note name or the resulting number
            is outside the MIDI range 0-127.
    """
    if not isinstance(note, str):
        raise ValueError(f"Invalid note: {note}")
    match = re.fullmatch(r"\s*([A-Ga-g])([#b]?)(-?\d+)\s*", note)
    if not match:
        raise ValueError(f"Invalid note: {note}")
    letter, accidental, octave = match.groups()

    # Semitone offset of each natural note from C, shifted by the accidental.
    natural_offsets = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}
    accidental_shift = {"": 0, "#": 1, "b": -1}
    semitone = natural_offsets[letter.upper()] + accidental_shift[accidental]

    midi_num = semitone + (int(octave) + 1) * 12
    if not 0 <= midi_num <= 127:
        raise ValueError(f"Note out of MIDI range: {note}")
    return midi_num
def midi_to_note_name(midi_num: int) -> str:
    """Return the sharp-spelled note name for a MIDI number (60 -> ``"C4"``)."""
    pitch_names = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    octave_block, pitch_class = divmod(midi_num, 12)
    # MIDI octave numbering starts at -1, so block 0 is octave -1.
    return f"{pitch_names[pitch_class]}{octave_block - 1}"
| # ----------------------------------------------------------------------------- | |
| # 5. Duration scaling: guarantee the output sums to requested total (using integers) | |
| # ----------------------------------------------------------------------------- | |
def scale_json_durations(json_data, target_units: int) -> list:
    """Rescale note durations so they sum to exactly ``target_units``.

    Args:
        json_data: iterable of ``[note, duration]`` pairs. Durations may be
            ints, floats or numeric strings; negative values count as zero.
        target_units: required total length in 16th-note units.

    Returns:
        A new list of ``[note, int_duration]`` pairs whose durations are all
        at least 1 and sum to exactly ``target_units``. Durations keep their
        relative proportions as closely as integers allow, and input that
        already sums to the target is returned unchanged in value. If there
        are more notes than units, only the first ``target_units`` notes are
        kept, since each note needs at least one unit. An empty input or a
        target below 1 yields ``[]``.

    Raises:
        ValueError: if an entry is not a pair or a duration is not numeric.
    """
    # Local import: exact rational arithmetic avoids float rounding drift.
    from fractions import Fraction

    target = int(target_units)
    entries = [(note, duration) for note, duration in json_data]
    if not entries or target < 1:
        return []

    # Each note needs at least one unit, so surplus notes cannot fit.
    entries = entries[:target]
    count = len(entries)

    weights = [
        max(Fraction(0), Fraction(str(duration).strip()))
        for _, duration in entries
    ]
    total = sum(weights)
    if total == 0:
        # No usable proportions were supplied: share the time equally.
        weights = [Fraction(1)] * count
        total = Fraction(count)

    # Largest-remainder apportionment: floor every exact share, then hand the
    # units lost to flooring to the notes with the largest fractional parts.
    shares = [weight * target / total for weight in weights]
    units = [int(share) for share in shares]
    spare = target - sum(units)
    by_fraction = sorted(range(count), key=lambda i: (units[i] - shares[i], i))
    for i in by_fraction[:spare]:
        units[i] += 1

    # Enforce the one-unit minimum by borrowing from the longest note. The sum
    # is at least `count`, so while any note is at 0 another holds at least 2.
    for i in range(count):
        if units[i] < 1:
            donor = max(range(count), key=units.__getitem__)
            units[donor] -= 1
            units[i] = 1

    return [[note, allotted] for (note, _), allotted in zip(entries, units)]
| # ----------------------------------------------------------------------------- | |
| # 6. MIDI from scaled JSON (using integer durations) | |
| # ----------------------------------------------------------------------------- | |
def json_to_midi(json_data: list, instrument: str, tempo: int, time_signature: str, measures: int) -> MidiFile:
    """Build a single-track MIDI file from ``[note, duration]`` pairs.

    Args:
        json_data: ``[note_name, duration_units]`` pairs, with durations
            counted in 16th notes.
        instrument: key into ``INSTRUMENT_PROGRAMS``; unknown names use
            program 56.
        tempo: tempo in beats per minute.
        time_signature: text of the form ``"numerator/denominator"``.
        measures: accepted for interface compatibility; not used here.

    Returns:
        A ``MidiFile`` containing time-signature, tempo and program-change
        messages followed by one note_on/note_off pair per playable note.

    A note whose name cannot be converted is reported and turned into a rest
    of the same length, so the total duration of the exercise is kept. An
    entry whose duration is not numeric is reported and skipped.
    """
    mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
    track = MidiTrack()
    mid.tracks.append(track)

    numerator, denominator = map(int, time_signature.split('/'))
    track.append(MetaMessage('time_signature', numerator=numerator,
                             denominator=denominator, time=0))
    track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
    program = INSTRUMENT_PROGRAMS.get(instrument, 56)
    track.append(Message('program_change', program=program, time=0))

    # Silence owed to the listener because preceding notes were unplayable.
    pending_rest = 0
    for note_name, duration_units in json_data:
        try:
            ticks = max(int(float(duration_units) * TICKS_PER_16TH), 1)
        except (TypeError, ValueError) as e:
            print(f"Error parsing note {note_name}: {e}")
            continue
        try:
            note_num = note_name_to_midi(note_name)
            velocity = random.randint(60, 100)
            # The note_on delta time carries any rest accumulated so far.
            note_on = Message('note_on', note=note_num, velocity=velocity,
                              time=pending_rest)
            note_off = Message('note_off', note=note_num, velocity=velocity,
                               time=ticks)
        except Exception as e:
            print(f"Error parsing note {note_name}: {e}")
            pending_rest += ticks
            continue
        track.append(note_on)
        track.append(note_off)
        pending_rest = 0

    if pending_rest:
        # Trailing rest: place the end of the track after the silence.
        track.append(MetaMessage('end_of_track', time=pending_rest))
    return mid
| # ----------------------------------------------------------------------------- | |
| # 7. MIDI → Audio (MP3) helpers | |
| # ----------------------------------------------------------------------------- | |
def get_soundfont(instrument: str) -> str:
    """Return the local path of the instrument's SoundFont, downloading it once.

    Files are cached as ``soundfonts/<instrument>.sf2``. Instruments without
    an entry in ``SOUNDFONT_URLS`` use the Trumpet URL.

    The download is written to a ``.part`` file and renamed only after it
    completes, and HTTP errors are raised rather than saved, so a failed
    download is never cached as if it were a valid SoundFont.

    Raises:
        requests.RequestException: on connection failure, timeout or an HTTP
            error status.
    """
    os.makedirs("soundfonts", exist_ok=True)
    sf2_path = f"soundfonts/{instrument}.sf2"
    # An empty file is treated as a failed earlier download and fetched again.
    if os.path.exists(sf2_path) and os.path.getsize(sf2_path) > 0:
        return sf2_path

    url = SOUNDFONT_URLS.get(instrument, SOUNDFONT_URLS["Trumpet"])
    print(f"Downloading SoundFont for {instrument}…")
    response = requests.get(url, timeout=60)
    response.raise_for_status()

    partial_path = sf2_path + ".part"
    with open(partial_path, "wb") as f:
        f.write(response.content)
    os.replace(partial_path, sf2_path)
    return sf2_path
def midi_to_mp3(midi_obj: MidiFile, instrument: str = "Trumpet") -> Tuple[str, float]:
    """Render a MIDI object to an MP3 under ``static/``.

    The MIDI is saved to a temporary file and synthesised to WAV with the
    ``fluidsynth`` command line tool, falling back to the ``midi2audio``
    wrapper if that call fails. The WAV is then filtered for Trumpet
    (high-pass 200) or Violin (low-pass 5000) and exported as MP3.

    Returns:
        ``(path_of_mp3_in_static, duration_in_seconds)``.

    All intermediate files are removed on both success and failure.
    """
    # Reserve a unique name, then close the handle before anything writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
        midi_path = mid_file.name
    # Swap only the extension; a bare str.replace could hit ".mid" elsewhere
    # in the path.
    stem = os.path.splitext(midi_path)[0]
    wav_path = stem + ".wav"
    mp3_path = stem + ".mp3"

    try:
        midi_obj.save(midi_path)
        sf2_path = get_soundfont(instrument)
        try:
            sp.run([
                'fluidsynth', '-ni', sf2_path, midi_path,
                '-F', wav_path, '-r', '44100', '-g', '1.0'
            ], check=True, capture_output=True)
        except (OSError, sp.SubprocessError):
            # Binary missing or exited non-zero: try the Python wrapper.
            fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
            fs.midi_to_audio(midi_path, wav_path)

        sound = AudioSegment.from_wav(wav_path)
        if instrument == "Trumpet":
            sound = sound.high_pass_filter(200)
        elif instrument == "Violin":
            sound = sound.low_pass_filter(5000)
        sound.export(mp3_path, format="mp3")

        static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
        shutil.move(mp3_path, static_mp3_path)
        return static_mp3_path, sound.duration_seconds
    finally:
        # mp3_path only still exists here if the move did not happen.
        for leftover in (midi_path, wav_path, mp3_path):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
| # ----------------------------------------------------------------------------- | |
| # 8. Prompt engineering for variety (using integer durations) | |
| # ----------------------------------------------------------------------------- | |
def get_fallback_exercise(instrument: str, level: str, key: str,
                          time_sig: str, measures: int) -> str:
    """Return a canned exercise as a JSON string of ``[note, duration]`` pairs.

    The instrument's six-note motif is repeated in quarter notes (4 units)
    until the requested length is reached, and the final note is shortened
    when needed so the durations total ``measures`` full bars of ``time_sig``.
    ``level`` and ``key`` are accepted but do not affect the result.
    """
    motifs = {
        "Trumpet": ["C4", "D4", "E4", "G4", "E4", "C4"],
        "Piano": ["C4", "E4", "G4", "C5", "G4", "E4"],
        "Violin": ["G4", "A4", "B4", "D5", "B4", "G4"],
        "Clarinet": ["E4", "F4", "G4", "Bb4", "G4", "E4"],
        "Flute": ["A4", "B4", "C5", "E5", "C5", "A4"],
    }
    motif = motifs.get(instrument, motifs["Trumpet"])

    beats, beat_unit = map(int, time_sig.split('/'))
    target_units = measures * (beats * (16 // beat_unit))

    # Number of quarter notes needed to reach or pass the target.
    note_count = 0
    while note_count * 4 < target_units:
        note_count += 1

    notes = [motif[position % len(motif)] for position in range(note_count)]
    durs = [4] * note_count

    # Trim the last quarter note by however far the run overshoots.
    overshoot = note_count * 4 - target_units
    if overshoot > 0:
        durs[-1] = 4 - overshoot

    return json.dumps([[name, length] for name, length in zip(notes, durs)])
def get_style_based_on_level(level: str) -> str:
    """Pick a random style descriptor suited to the difficulty level.

    Levels other than Beginner/Intermediate/Advanced yield ``"technical"``.
    """
    styles_by_level = dict(
        Beginner=["simple", "legato", "stepwise"],
        Intermediate=["jazzy", "bluesy", "march-like", "syncopated"],
        Advanced=["technical", "chromatic", "fast arpeggios", "wide intervals"],
    )
    if level in styles_by_level:
        candidates = styles_by_level[level]
    else:
        candidates = ["technical"]
    return random.choice(candidates)
def get_technique_based_on_level(level: str) -> str:
    """Pick a random technique phrase suited to the difficulty level.

    Levels other than Beginner/Intermediate/Advanced yield ``"with slurs"``.
    """
    techniques_by_level = dict(
        Beginner=["with long tones", "with simple rhythms", "focusing on tone"],
        Intermediate=["with slurs", "with accents", "using triplets"],
        Advanced=["with double tonguing", "with extreme registers", "complex rhythms"],
    )
    if level in techniques_by_level:
        candidates = techniques_by_level[level]
    else:
        candidates = ["with slurs"]
    return random.choice(candidates)
| # ----------------------------------------------------------------------------- | |
| # 9. LLM Query Function (with enhanced error handling) | |
| # ----------------------------------------------------------------------------- | |
def _strip_code_fences(content: str) -> str:
    """Remove Markdown code-fence markers and surrounding whitespace."""
    return content.replace("```json", "").replace("```", "").strip()


def _request_mistral_exercise(system_prompt: str, user_prompt: str, level: str,
                              timeout: float = 60.0) -> str:
    """Ask the Mistral chat API for an exercise and return its cleaned text."""
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistral-medium",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.7 if level == "Advanced" else 0.5,
        "max_tokens": 1000,
        "top_p": 0.95,
        "frequency_penalty": 0.2,
        "presence_penalty": 0.2,
    }
    # A timeout keeps a stalled connection from blocking the request forever.
    response = requests.post(MISTRAL_API_URL, headers=headers, json=payload,
                             timeout=timeout)
    response.raise_for_status()
    return _strip_code_fences(response.json()["choices"][0]["message"]["content"])


def _request_openrouter_exercise(model_name: str, system_prompt: str,
                                 user_prompt: str, level: str,
                                 timeout: float = 120.0) -> str:
    """Ask an OpenRouter-hosted model for an exercise and return its cleaned text."""
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=OPENROUTER_API_KEYS[model_name],
        timeout=timeout,
    )
    model_map = {
        "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
        "Claude": "anthropic/claude-3.5-sonnet:beta",
        "Gemma": "google/gemma-3n-e2b-it:free",
        "Kimi": "moonshotai/kimi-dev-72b:free",
        "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
    }
    # Special handling for Gemma API structure: send only the user message.
    if model_name == "Gemma":
        messages = [
            {"role": "user", "content": user_prompt}
        ]
    else:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
    completion = client.chat.completions.create(
        extra_headers={
            "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
            "X-Title": "Music Exercise Generator",
        },
        model=model_map[model_name],
        messages=messages,
        temperature=0.7 if level == "Advanced" else 0.5,
        max_tokens=1000,
        top_p=0.95,
        frequency_penalty=0.2,
        presence_penalty=0.2,
    )
    return _strip_code_fences(completion.choices[0].message.content)


def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
              time_sig: str, measures: int) -> str:
    """Ask the selected model for an exercise and return its raw JSON text.

    Args:
        model_name: "Mistral" or one of the OpenRouter model labels. Any
            other value returns the built-in fallback exercise immediately.
        prompt: custom user prompt; when blank, a prompt is composed from the
            instrument, level, key and time signature.
        instrument, level, key, time_sig, measures: exercise parameters.

    Returns:
        The model's reply with code fences stripped. If the selected model
        fails, Mistral is tried once; if that fails too, the JSON text from
        ``get_fallback_exercise`` is returned.

    Rate-limit errors (text containing "429" or "Rate limit") are retried up
    to three times with exponential backoff; any other error goes straight to
    the Mistral fallback.
    """
    numerator, denominator = map(int, time_sig.split('/'))
    units_per_measure = numerator * (16 // denominator)
    required_total = measures * units_per_measure
    duration_constraint = (
        f"Sum of all durations MUST BE EXACTLY {required_total} units (16th notes). "
        f"Each integer duration represents a 16th note (1=16th, 2=8th, 4=quarter, 8=half, 16=whole). "
        f"If it doesn't match, the exercise is invalid."
    )
    system_prompt = (
        f"You are an expert music teacher specializing in {instrument.lower()}. "
        "Create customized exercises using INTEGER durations representing 16th notes."
    )
    if prompt.strip():
        user_prompt = (
            f"{prompt} {duration_constraint} Output ONLY a JSON array of [note, duration] pairs."
        )
    else:
        style = get_style_based_on_level(level)
        technique = get_technique_based_on_level(level)
        user_prompt = (
            f"Create a {style} {instrument.lower()} exercise in {key} with {time_sig} time signature "
            f"{technique} for a {level.lower()} player. {duration_constraint} "
            "Output ONLY a JSON array of [note, duration] pairs following these rules: "
            "Use standard note names (e.g., \"Bb4\", \"F#5\"). Monophonic only. "
            "Durations: 1=16th, 2=8th, 4=quarter, 8=half, 16=whole. "
            "Sum must be exactly as specified. ONLY output the JSON array. No prose."
        )

    openrouter_models = ("DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1")
    if model_name != "Mistral" and model_name not in openrouter_models:
        return get_fallback_exercise(instrument, level, key, time_sig, measures)

    # Retry up to 3 times for rate limited models
    max_retries = 3
    retry_delay = 5  # seconds
    for attempt in range(max_retries):
        try:
            if model_name == "Mistral":
                return _request_mistral_exercise(system_prompt, user_prompt, level)
            return _request_openrouter_exercise(model_name, system_prompt,
                                                user_prompt, level)
        except Exception as e:
            print(f"Error querying {model_name} API (attempt {attempt+1}): {e}")
            if not ("429" in str(e) or "Rate limit" in str(e)):
                break
            # Sleeping after the final attempt would only delay the fallback.
            if attempt + 1 < max_retries:
                print(f"Rate limited, retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff

    # Fallback to Mistral if other APIs fail
    print(f"All attempts failed for {model_name}, using Mistral fallback")
    try:
        return _request_mistral_exercise(system_prompt, user_prompt, level)
    except Exception as e:
        print(f"Error querying Mistral fallback: {e}")
        return get_fallback_exercise(instrument, level, key, time_sig, measures)
| # ----------------------------------------------------------------------------- | |
| # 10. Robust JSON parsing for LLM outputs | |
| # ----------------------------------------------------------------------------- | |
def safe_parse_json(text: str) -> Optional[list]:
    """Extract a JSON list from model output, tolerating prose and single quotes.

    Parsing is attempted in this order, and the first non-empty list wins:
      1. the text exactly as received;
      2. each array-of-arrays fragment found inside the text;
      3. the same two steps after replacing single quotes with double quotes.

    The strict parse runs first so that valid JSON is never altered by the
    quote replacement. Values that parse but are not lists (objects, numbers,
    strings) are rejected, so the result is always a list or ``None``.

    Returns:
        The parsed list (an empty list only if nothing better was found), or
        ``None`` after printing a diagnostic when no list could be parsed.
    """
    if not isinstance(text, str):
        print(f"JSON parsing error: expected text, got {type(text).__name__}\nRaw text: {text}")
        return None

    array_pattern = re.compile(r"\[(\s*\[.*?\]\s*,?)*\]", re.DOTALL)
    first_empty = None
    last_error = None
    variants = [text]
    normalized = text.replace("'", '"')
    if normalized != text:
        variants.append(normalized)

    for variant in variants:
        candidates = [variant.strip()]
        candidates.extend(m.group(0) for m in array_pattern.finditer(variant))
        for candidate in candidates:
            try:
                value = json.loads(candidate)
            except ValueError as e:
                last_error = e
                continue
            if not isinstance(value, list):
                continue
            if value:
                return value
            if first_empty is None:
                # Keep looking: a stray "[]" in prose must not hide real data.
                first_empty = value

    if first_empty is not None:
        return first_empty
    print(f"JSON parsing error: {last_error}\nRaw text: {text}")
    return None
| # ----------------------------------------------------------------------------- | |
| # 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values | |
| # ----------------------------------------------------------------------------- | |
def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
                      measures: int, custom_prompt: str, mode: str, ai_model: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
    """Generate one exercise end to end and return the values shown in the UI.

    Steps: query the model, parse its JSON, scale the durations to the
    requested number of measures, build the MIDI and render it to MP3.

    Returns:
        ``(json_text, mp3_path, tempo_text, midi, duration_text,
        time_signature, total_units)``. On failure the first item carries the
        error message, the MP3 path and MIDI are ``None``, the duration text
        is ``"0"`` and the unit total is ``0``.
    """

    def failure(message: str):
        # Shape of every unsuccessful result.
        return message, None, str(tempo), None, "0", time_signature, 0

    try:
        if mode == "Exercise Prompt":
            prompt_to_use = custom_prompt
        else:
            prompt_to_use = ""
        raw_reply = query_llm(ai_model, prompt_to_use, instrument, level, key,
                              time_signature, measures)
        note_pairs = safe_parse_json(raw_reply)
        if not note_pairs:
            return failure("Invalid JSON format")

        # Length of the whole exercise in 16th-note units.
        beats, beat_unit = map(int, time_signature.split('/'))
        total_units = measures * (beats * (16 // beat_unit))

        scaled_pairs = scale_json_durations(note_pairs, total_units)
        total_duration = sum(units for _, units in scaled_pairs)

        midi = json_to_midi(scaled_pairs, instrument, tempo, time_signature, measures)
        mp3_path, real_duration = midi_to_mp3(midi, instrument)
        output_json_str = json.dumps(scaled_pairs, indent=2)
        return (output_json_str, mp3_path, str(tempo), midi,
                f"{real_duration:.2f} seconds", time_signature, total_duration)
    except Exception as e:
        return failure(f"Error: {str(e)}")
| # ----------------------------------------------------------------------------- | |
| # 12. AI chat assistant with enhanced error handling | |
| # ----------------------------------------------------------------------------- | |
def _chat_via_mistral(messages: list, timeout: float = 60.0) -> str:
    """Send a chat transcript to the Mistral API and return the reply text."""
    headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
    # A timeout keeps a stalled connection from blocking the chat forever.
    response = requests.post(MISTRAL_API_URL, headers=headers, json=payload,
                             timeout=timeout)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def _chat_via_openrouter(ai_model: str, messages: list, timeout: float = 120.0) -> str:
    """Send a chat transcript to an OpenRouter model and return the reply text."""
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=OPENROUTER_API_KEYS[ai_model],
        timeout=timeout,
    )
    model_map = {
        "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
        "Claude": "anthropic/claude-3.5-sonnet:beta",
        "Gemma": "google/gemma-3n-e2b-it:free",
        "Kimi": "moonshotai/kimi-dev-72b:free",
        "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free"
    }
    # Special handling for Gemma API structure: every message is sent with the
    # "user" role. NOTE(review): this also relabels assistant turns; confirm
    # that is intended rather than only relabelling the system message.
    if ai_model == "Gemma":
        adjusted_messages = [{"role": "user", "content": msg["content"]} for msg in messages]
    else:
        adjusted_messages = messages
    completion = client.chat.completions.create(
        extra_headers={
            "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
            "X-Title": "Music Exercise Generator",
        },
        model=model_map[ai_model],
        messages=adjusted_messages,
        temperature=0.7,
        max_tokens=500,
    )
    return completion.choices[0].message.content


def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
    """Answer one chat message and return ``("", updated_history)``.

    Args:
        message: the user's new message; blank input leaves history untouched.
        history: list of ``(user_text, assistant_text)`` pairs; the new
            exchange is appended to it in place. ``None`` is treated as empty.
        instrument, level: used to build the system prompt.
        ai_model: "Mistral" or an OpenRouter model label.

    The empty string in the result clears the input box. Rate-limit errors
    (text containing "429" or "Rate limit") are retried up to three times
    with exponential backoff; any other error triggers a single Mistral
    fallback. Failures are reported to the user as an "Error: ..." reply.
    """
    if history is None:
        history = []
    if not message.strip():
        return "", history

    messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
    for user_msg, assistant_msg in history:
        # Skip empty halves of an exchange instead of sending null content.
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    openrouter_models = ("DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1")
    if ai_model != "Mistral" and ai_model not in openrouter_models:
        history.append((message, "Error: Invalid AI model selected"))
        return "", history

    max_retries = 3
    retry_delay = 3  # seconds
    for attempt in range(max_retries):
        try:
            if ai_model == "Mistral":
                content = _chat_via_mistral(messages)
            else:
                content = _chat_via_openrouter(ai_model, messages)
            history.append((message, content))
            return "", history
        except Exception as e:
            print(f"Chat error with {ai_model} (attempt {attempt+1}): {e}")
            if "429" in str(e) or "Rate limit" in str(e):
                # Sleeping after the final attempt would only delay the error.
                if attempt + 1 < max_retries:
                    print(f"Rate limited, retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                continue
            # Fallback to Mistral
            print(f"Using Mistral fallback for chat")
            try:
                content = _chat_via_mistral(messages)
                history.append((message, content))
            except Exception as fallback_error:
                history.append((message, f"Error: {str(fallback_error)}"))
            return "", history

    history.append((message, "Error: All API attempts failed"))
    return "", history
| # ----------------------------------------------------------------------------- | |
| # 13. Gradio user interface definition | |
| # ----------------------------------------------------------------------------- | |
def create_ui() -> gr.Blocks:
    """Build the Gradio interface and wire its event handlers.

    Layout: a generation-mode selector, a left column with either the
    parameter controls or the free-text prompt controls plus the Generate
    button, and a right column with tabs for playback, exercise data, MIDI
    export and the AI chat.

    MIDI export writes each file under a unique name in ``static/`` so that
    concurrent sessions do not overwrite one another's downloads.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
        gr.Markdown("# 🎼 Adaptive Music Exercise Generator")
        current_midi = gr.State(None)
        current_exercise = gr.State("")  # created but not wired to any event
        mode = gr.Radio(["Exercise Parameters", "Exercise Prompt"], value="Exercise Parameters", label="Generation Mode")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(visible=True) as params_group:
                    gr.Markdown("### Exercise Parameters")
                    ai_model = gr.Radio(
                        ["Mistral", "DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"],
                        value="Mistral",
                        label="AI Model"
                    )
                    instrument = gr.Dropdown([
                        "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
                    ], value="Trumpet", label="Instrument")
                    level = gr.Radio([
                        "Beginner", "Intermediate", "Advanced",
                    ], value="Intermediate", label="Difficulty Level")
                    key = gr.Dropdown([
                        "C Major", "G Major", "D Major", "F Major", "Bb Major", "A Minor", "E Minor",
                    ], value="C Major", label="Key Signature")
                    time_signature = gr.Dropdown(["3/4", "4/4"], value="4/4", label="Time Signature")
                    measures = gr.Radio([4, 8], value=4, label="Length (measures)")
                with gr.Group(visible=False) as prompt_group:
                    gr.Markdown("### Exercise Prompt")
                    custom_prompt = gr.Textbox("", label="Enter your custom prompt", lines=3)
                    measures_prompt = gr.Radio([4, 8], value=4, label="Length (measures)")
                generate_btn = gr.Button("Generate Exercise", variant="primary")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Exercise Player"):
                        audio_output = gr.Audio(label="Generated Exercise", autoplay=True, type="filepath")
                        bpm_display = gr.Textbox(label="Tempo (BPM)")
                        time_sig_display = gr.Textbox(label="Time Signature")
                        duration_display = gr.Textbox(label="Audio Duration", interactive=False)
                    with gr.TabItem("Exercise Data"):
                        json_output = gr.Code(label="JSON Representation", language="json")
                        # Duration sum display
                        duration_sum = gr.Number(
                            label="Total Duration Units (16th notes)",
                            interactive=False,
                            precision=0
                        )
                    with gr.TabItem("MIDI Export"):
                        midi_output = gr.File(label="MIDI File")
                        download_midi = gr.Button("Generate MIDI File")
                    with gr.TabItem("AI Chat"):
                        chat_history = gr.Chatbot(label="Practice Assistant", height=400)
                        chat_message = gr.Textbox(label="Ask the AI anything about your practice")
                        send_chat_btn = gr.Button("Send")

        # Toggle UI groups: show only the controls for the selected mode.
        mode.change(
            fn=lambda m: {
                params_group: gr.update(visible=(m == "Exercise Parameters")),
                prompt_group: gr.update(visible=(m == "Exercise Prompt")),
            },
            inputs=[mode], outputs=[params_group, prompt_group]
        )

        def generate_caller(mode_val, instrument_val, level_val, key_val,
                            time_sig_val, measures_val, prompt_val, measures_prompt_val, ai_model_val):
            # Each mode has its own measures selector; use the visible one.
            real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
            fixed_tempo = 60
            return generate_exercise(
                instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
                real_measures, prompt_val, mode_val, ai_model_val
            )

        generate_btn.click(
            fn=generate_caller,
            inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt, ai_model],
            outputs=[json_output, audio_output, bpm_display, current_midi, duration_display, time_sig_display, duration_sum]
        )

        def save_midi(json_data, instr, time_sig):
            parsed = safe_parse_json(json_data)
            if not parsed:
                return None
            numerator, denominator = map(int, time_sig.split('/'))
            units_per_measure = numerator * (16 // denominator)
            total_units = sum(int(d[1]) for d in parsed)
            # Snap the (possibly hand-edited) JSON to a whole number of measures.
            measures_est = max(1, round(total_units / units_per_measure))
            scaled = scale_json_durations(parsed, measures_est * units_per_measure)
            midi_obj = json_to_midi(scaled, instr, 60, time_sig, measures_est)
            # A unique file per export keeps concurrent sessions from
            # overwriting one another's downloads.
            os.makedirs("static", exist_ok=True)
            handle, midi_path = tempfile.mkstemp(prefix="exercise_", suffix=".mid", dir="static")
            os.close(handle)
            midi_obj.save(midi_path)
            return midi_path

        download_midi.click(
            fn=save_midi,
            inputs=[json_output, instrument, time_signature],
            outputs=[midi_output],
        )
        send_chat_btn.click(
            fn=handle_chat,
            inputs=[chat_message, chat_history, instrument, level, ai_model],
            outputs=[chat_message, chat_history],
        )
    return demo
| # ----------------------------------------------------------------------------- | |
| # 14. Entry point | |
| # ----------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Build the interface and start serving it only when run as a script.
    demo = create_ui()
    demo.launch()