Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import gradio as gr | |
| import requests | |
| from dotenv import load_dotenv | |
| from datetime import datetime | |
| from pathlib import Path | |
| from basic_pitch.inference import predict_and_save | |
| from basic_pitch import ICASSP_2022_MODEL_PATH | |
| from music21 import converter | |
| import base64 | |
# === 1. Environment Configuration & OpenAI Client ===
# Load .env first so the os.getenv calls below see the configured values.
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MUSICGEN_API_URL = os.getenv("MUSICGEN_API_URL")  # external MusicGen generation endpoint (optional; checked at call time)
VEROVIO_API_URL = os.getenv("VEROVIO_API_URL")    # external Verovio score-rendering endpoint (optional; checked at call time)
# Hard requirement: the app cannot start without an OpenAI key.
assert OPENAI_API_KEY, "β Please set OPENAI_API_KEY in your .env file"
# Use OpenAI v1 client
from openai import OpenAI
openai_client = OpenAI(api_key=OPENAI_API_KEY)
# Create output directory if it doesn't exist
Path("output").mkdir(exist_ok=True)
| # === 2. Tool Functions === | |
def generate_music_from_hum(melody_file: str, prompt: str) -> str:
    """
    Generate a music WAV from a user's humming plus a style prompt.

    Posts the humming audio and the text prompt to the external MusicGen
    API and writes the returned audio bytes under output/.

    Args:
        melody_file: Path to the user's humming WAV file.
        prompt: Text description of the desired musical style.

    Returns:
        Path of the generated WAV on success, or an error string on failure.
    """
    if not MUSICGEN_API_URL:
        return "β MUSICGEN_API_URL is not configured"
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_wav = f"output/generated_{ts}.wav"
    try:
        # Stream the melody file to the API together with the style prompt.
        with open(melody_file, "rb") as melody:
            resp = requests.post(
                MUSICGEN_API_URL,
                files={"melody": ("hum.wav", melody, "audio/wav")},
                data={"text": prompt},
                timeout=180,
            )
        if resp.status_code != 200:
            return f"β MusicGen API error {resp.status_code}: {resp.text}"
        # Persist the returned audio bytes for the UI to play back.
        with open(target_wav, "wb") as out:
            out.write(resp.content)
        return target_wav
    except Exception as e:
        return f"β Music generation failed: {e}"
def wav_to_musicxml(wav_path: str, timestamp: str | None = None) -> str:
    """
    Convert a WAV audio file to MusicXML using basic-pitch for pitch detection.

    Transcribes the audio to MIDI with basic-pitch, then converts that MIDI
    to MusicXML with music21.

    Args:
        wav_path: Path of the WAV file to transcribe.
        timestamp: Optional timestamp string for the output filename;
            defaults to the current time.

    Returns:
        Path of the written .musicxml file on success, or an error string
        if no MIDI was produced.
    """
    ts = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    # Remove stale basic-pitch MIDI output so an old file can't be picked up.
    for midi_file in Path("output").glob("*_basic_pitch.mid"):
        midi_file.unlink()
    # Generate MIDI from the WAV (basic-pitch names it <stem>_basic_pitch.mid)
    predict_and_save(
        audio_path_list=[wav_path],
        output_directory="output",
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=False,
        model_or_model_path=ICASSP_2022_MODEL_PATH
    )
    # Glob only the basic-pitch output; the previous "*.mid" pattern could
    # match unrelated MIDI files lingering in output/.
    midi_files = list(Path("output").glob("*_basic_pitch.mid"))
    if not midi_files:
        return "β Failed to generate MIDI file"
    score = converter.parse(str(midi_files[0]))
    xml_path = f"output/generated_{ts}.musicxml"
    score.write("musicxml", fp=xml_path)
    return xml_path
def render_musicxml_via_verovio_api(musicxml_path: str) -> str:
    """
    Render a MusicXML file to an SVG preview using the Verovio API.

    Posts the file to VEROVIO_API_URL, takes the "svg" field from the JSON
    response, and embeds it as a base64 data-URI <img> inside a styled <div>.

    Returns:
        HTML containing the embedded SVG, or an error string on failure.
    """
    if not VEROVIO_API_URL:
        return "β VEROVIO_API_URL is not configured"
    try:
        with open(musicxml_path, "rb") as xml_file:
            resp = requests.post(VEROVIO_API_URL, files={"file": xml_file}, timeout=120)
        if resp.status_code != 200:
            return f"β Verovio API error {resp.status_code}: {resp.text}"
        svg_markup = resp.json().get("svg", "")
        # Encode as a data URI so the SVG can be embedded in an <img> tag.
        encoded = base64.b64encode(svg_markup.encode("utf-8")).decode("utf-8")
        pieces = [
            '<div style="background:white;padding:10px;border-radius:8px;">',
            f'<img src="data:image/svg+xml;base64,{encoded}" style="width:100%;" />',
            '</div>',
        ]
        return "".join(pieces)
    except Exception as e:
        return f"β SVG rendering failed: {e}"
def generate_score_from_audio(wav_file: str) -> str:
    """
    Extract a MusicXML score from a generated music WAV file.

    Thin wrapper around wav_to_musicxml that converts any raised
    exception into an error string instead of propagating it.
    """
    try:
        result = wav_to_musicxml(wav_file)
    except Exception as e:
        return f"β Score extraction failed: {e}"
    return result
# Map of tool names to functions
# Dispatch registry used by handle_request: the keys must stay in sync
# with the tool names listed in gpt_decide_tool's system prompt.
TOOL_MAP = {
    "generate_music_from_hum": generate_music_from_hum,
    "wav_to_musicxml": wav_to_musicxml,
    "render_musicxml_via_verovio_api": render_musicxml_via_verovio_api,
    "generate_score_from_audio": generate_score_from_audio,
}
# === 3. GPT Tool Selection ===
def gpt_decide_tool(message: str, audio_path: str) -> dict:
    """
    Ask GPT to choose which tool should handle the user's request.

    Args:
        message: The user's free-text request.
        audio_path: Path of the uploaded audio file.

    Returns:
        The parsed plan dict ({"tool_name", "args", "explanation"}), or
        {"error": ...} when the model's reply is not valid JSON.
    """
    # BUGFIX: the argument names advertised to GPT must match the real
    # function signatures, because handle_request dispatches via fn(**args)
    # with the keys GPT returns. The old prompt listed wav_file /
    # musicxml_file, which raised TypeError for wav_to_musicxml(wav_path)
    # and render_musicxml_via_verovio_api(musicxml_path).
    system_prompt = """
You are an AI music assistant. The user uploads an audio file and provides a request.
Choose the most appropriate tool from the list below and respond with strict JSON:
- generate_music_from_hum(melody_file, prompt)
- wav_to_musicxml(wav_path)
- render_musicxml_via_verovio_api(musicxml_path)
- generate_score_from_audio(wav_file)
JSON format:
{
  "tool_name": "...",
  "args": { ... },
  "explanation": "Reasoning explanation"
}
"""
    user_prompt = f"User request: {message}\nAudio file path: {audio_path}"
    response = openai_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.2  # low temperature for deterministic tool selection
    )
    text = response.choices[0].message.content
    try:
        return json.loads(text)
    except Exception:
        return {"error": f"Failed to parse JSON from GPT response:\n{text}"}
# === 4. Main Logic & Dynamic Output Display ===
def handle_request(audio_file, user_prompt):
    """
    Gradio callback: ask GPT for a tool plan, execute it, and route the
    result to the matching output component.

    Args:
        audio_file: Filepath of the uploaded audio (from gr.Audio).
        user_prompt: The user's free-text request.

    Returns:
        A 6-tuple matching the click() outputs:
        (status, explanation, log, audio update, svg update, text update).
    """
    # Input validation
    if not audio_file or not user_prompt:
        return (
            "β Please upload an audio file and enter a request",
            "", "",
            gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
        )
    plan = gpt_decide_tool(user_prompt, audio_file)
    if "error" in plan:
        return (plan["error"], "", "") + (gr.update(visible=False),) * 3
    tool_name = plan["tool_name"]
    args = plan.get("args", {})
    explanation = plan.get("explanation", "")
    log = f"π§ GPT chose: {tool_name}\nπ¦ Args: {json.dumps(args, ensure_ascii=False, indent=2)}"
    fn = TOOL_MAP.get(tool_name)
    if not fn:
        return (f"β Unknown tool: {tool_name}", explanation, log) + (gr.update(visible=False),) * 3
    # BUGFIX: GPT-produced args may not match the tool's signature; an
    # unguarded fn(**args) raised TypeError and crashed the handler.
    try:
        output = fn(**args)
    except Exception as e:
        return (f"β Tool execution failed: {e}", explanation, log) + (gr.update(visible=False),) * 3
    # Determine output type and update components accordingly
    if isinstance(output, str) and output.endswith(".wav") and os.path.isfile(output):
        return (
            "β Success", explanation, log,
            gr.update(value=output, visible=True),  # Audio
            gr.update(visible=False),               # SVG
            gr.update(visible=False)                # Text
        )
    if isinstance(output, str) and output.endswith(".musicxml") and os.path.isfile(output):
        # Automatically render MusicXML to SVG
        svg_html = render_musicxml_via_verovio_api(output)
        return (
            "β Success", explanation, log,
            gr.update(visible=False),
            gr.update(value=svg_html, visible=True),
            gr.update(visible=False)
        )
    if isinstance(output, str) and output.strip().startswith("<div"):
        # Already HTML SVG
        return (
            "β Success", explanation, log,
            gr.update(visible=False),
            gr.update(value=output, visible=True),
            gr.update(visible=False)
        )
    # Otherwise treat as plain text (this includes tools' error strings)
    return (
        "β Success", explanation, log,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(value=str(output), visible=True)
    )
# === 5. Gradio Interface ===
# NOTE: component creation order determines the page layout, and the
# outputs list below must line up index-for-index with the 6-tuple
# returned by handle_request.
with gr.Blocks(title="πΆ Vibe Jamming β Your Music Assistant") as demo:
    gr.Markdown("## π΅ Vibe Jamming β Your Music Assistant")
    with gr.Row():
        # Inputs: the user's audio file and their free-text request
        audio_input = gr.Audio(label="Upload Audio (.wav)", type="filepath")
        text_input = gr.Textbox(label="Your Request", placeholder="e.g., Generate jazz music from my humming")
    run_button = gr.Button("π Run")
    # Always-visible status / diagnostic fields
    status_box = gr.Textbox(label="Status")
    explanation_box = gr.Textbox(label="Explanation")
    log_box = gr.Textbox(label="Tool Log", lines=6)
    # Result components: handle_request makes exactly one of these visible
    audio_output = gr.Audio(label="π§ Audio Output", visible=False, type="filepath")
    svg_output = gr.HTML(label="πΌοΈ Score Preview (SVG)", visible=False)
    text_output = gr.Textbox(label="π Text Output", visible=False, lines=4)
    run_button.click(
        fn=handle_request,
        inputs=[audio_input, text_input],
        outputs=[status_box, explanation_box, log_box, audio_output, svg_output, text_output]
    )
if __name__ == "__main__":
    demo.launch()