import os
import json
import gradio as gr
import requests
from dotenv import load_dotenv
from datetime import datetime
from pathlib import Path
from basic_pitch.inference import predict_and_save
from basic_pitch import ICASSP_2022_MODEL_PATH
from music21 import converter
import base64
# === 1. Environment Configuration & OpenAI Client ===
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MUSICGEN_API_URL = os.getenv("MUSICGEN_API_URL")
VEROVIO_API_URL = os.getenv("VEROVIO_API_URL")
assert OPENAI_API_KEY, "❌ Please set OPENAI_API_KEY in your .env file"
# Use OpenAI v1 client
from openai import OpenAI
openai_client = OpenAI(api_key=OPENAI_API_KEY)
# Create output directory if it doesn't exist
Path("output").mkdir(exist_ok=True)
# === 2. Tool Functions ===
def generate_music_from_hum(melody_file: str, prompt: str) -> str:
    """
    Call an external MusicGen API to generate a music WAV file
    based on a user's humming audio and a style prompt.
    """
    if not MUSICGEN_API_URL:
        return "❌ MUSICGEN_API_URL is not configured"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_wav = f"output/generated_{timestamp}.wav"
    try:
        with open(melody_file, "rb") as f:
            files = {"melody": ("hum.wav", f, "audio/wav")}
            data = {"text": prompt}
            response = requests.post(MUSICGEN_API_URL, files=files, data=data, timeout=180)
        if response.status_code != 200:
            return f"❌ MusicGen API error {response.status_code}: {response.text}"
        with open(output_wav, "wb") as out:
            out.write(response.content)
        return output_wav
    except Exception as e:
        return f"❌ Music generation failed: {e}"
def wav_to_musicxml(wav_path: str, timestamp: str = None) -> str:
    """
    Convert a WAV audio file to MusicXML using basic-pitch for pitch detection.
    """
    ts = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    # Remove any old MIDI files left over from previous runs
    for midi_file in Path("output").glob("*_basic_pitch.mid"):
        midi_file.unlink()
    # Generate MIDI from the WAV
    predict_and_save(
        audio_path_list=[wav_path],
        output_directory="output",
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )
    # Look only for the MIDI file basic-pitch just wrote, not arbitrary *.mid files
    midi_files = list(Path("output").glob("*_basic_pitch.mid"))
    if not midi_files:
        return "❌ Failed to generate MIDI file"
    score = converter.parse(str(midi_files[0]))
    xml_path = f"output/generated_{ts}.musicxml"
    score.write("musicxml", fp=xml_path)
    return xml_path
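# Illustrative usage (paths are hypothetical):
#   wav_to_musicxml("output/generated_20240101_120000.wav")
#   -> "output/generated_20240101_120000.musicxml"
# basic-pitch writes "<input stem>_basic_pitch.mid" into output/, which music21
# then parses and re-exports as MusicXML.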
def render_musicxml_via_verovio_api(musicxml_path: str) -> str:
    """
    Render a MusicXML file to an SVG preview using the Verovio API.
    Returns HTML containing the embedded SVG.
    """
    if not VEROVIO_API_URL:
        return "❌ VEROVIO_API_URL is not configured"
    try:
        with open(musicxml_path, "rb") as f:
            response = requests.post(VEROVIO_API_URL, files={"file": f}, timeout=120)
        if response.status_code != 200:
            return f"❌ Verovio API error {response.status_code}: {response.text}"
        svg = response.json().get("svg", "")
        # Embed the SVG as a base64 data URI so it displays inside Gradio's HTML component
        b64_svg = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
        return (
            '<div style="width:100%; overflow-x:auto; background:#fff;">'
            f'<img src="data:image/svg+xml;base64,{b64_svg}" '
            'style="max-width:100%; height:auto;" alt="Score preview"/>'
            '</div>'
        )
    except Exception as e:
        return f"❌ SVG rendering failed: {e}"
def generate_score_from_audio(wav_file: str) -> str:
    """
    Extract a MusicXML score from a generated music WAV file.
    """
    try:
        return wav_to_musicxml(wav_file)
    except Exception as e:
        return f"❌ Score extraction failed: {e}"
# Map of tool names to functions
TOOL_MAP = {
    "generate_music_from_hum": generate_music_from_hum,
    "wav_to_musicxml": wav_to_musicxml,
    "render_musicxml_via_verovio_api": render_musicxml_via_verovio_api,
    "generate_score_from_audio": generate_score_from_audio,
}
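# Dispatch is purely by name: gpt_decide_tool() returns {"tool_name": ..., "args": {...}}
# and handle_request() calls TOOL_MAP[tool_name](**args), so the argument names the model
# emits must match these functions' parameter names exactly.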
# === 3. GPT Tool Selection ===
def gpt_decide_tool(message: str, audio_path: str) -> dict:
    system_prompt = """
You are an AI music assistant. The user uploads an audio file and provides a request.
Choose the most appropriate tool from the list below and respond with strict JSON only
(no surrounding text or Markdown). Argument names must match the signatures exactly:
- generate_music_from_hum(melody_file, prompt)
- wav_to_musicxml(wav_path)
- render_musicxml_via_verovio_api(musicxml_path)
- generate_score_from_audio(wav_file)
JSON format:
{
  "tool_name": "...",
  "args": { ... },
  "explanation": "Reasoning explanation"
}
"""
    user_prompt = f"User request: {message}\nAudio file path: {audio_path}"
    response = openai_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.2
    )
    text = response.choices[0].message.content.strip()
    # The model occasionally wraps its JSON in Markdown code fences; strip them before parsing
    if text.startswith("```"):
        text = text.strip("`")
        if text.startswith("json"):
            text = text[len("json"):].strip()
    try:
        return json.loads(text)
    except Exception:
        return {"error": f"Failed to parse JSON from GPT response:\n{text}"}
# === 4. Main Logic & Dynamic Output Display ===
def handle_request(audio_file, user_prompt):
    # Input validation
    if not audio_file or not user_prompt:
        return (
            "❗ Please upload an audio file and enter a request",
            "", "",
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
        )
    plan = gpt_decide_tool(user_prompt, audio_file)
    if "error" in plan:
        return (plan["error"], "", "") + (gr.update(visible=False),) * 3
    tool_name = plan["tool_name"]
    args = plan.get("args", {})
    explanation = plan.get("explanation", "")
    log = f"🧠 GPT chose: {tool_name}\n📦 Args: {json.dumps(args, ensure_ascii=False, indent=2)}"
    fn = TOOL_MAP.get(tool_name)
    if not fn:
        return (f"❌ Unknown tool: {tool_name}", explanation, log) + (gr.update(visible=False),) * 3
    # Guard the call: the model may still return unexpected argument names
    try:
        output = fn(**args)
    except TypeError as e:
        return (f"❌ Tool call failed: {e}", explanation, log) + (gr.update(visible=False),) * 3
    # Determine output type and update components accordingly
    if isinstance(output, str) and output.endswith(".wav") and os.path.isfile(output):
        return (
            "✅ Success", explanation, log,
            gr.update(value=output, visible=True),   # Audio
            gr.update(visible=False),                # SVG
            gr.update(visible=False)                 # Text
        )
    if isinstance(output, str) and output.endswith(".musicxml") and os.path.isfile(output):
        # Automatically render MusicXML to SVG
        svg_html = render_musicxml_via_verovio_api(output)
        return (
            "✅ Success", explanation, log,
            gr.update(visible=False),
            gr.update(value=svg_html, visible=True),
            gr.update(visible=False)
        )
if isinstance(output, str) and output.strip().startswith("