# NOTE: removed non-Python residue scraped from the Hugging Face Spaces page
# header ("Spaces:" / "Runtime error" status banner) — it was not part of the
# source file and made the module unparseable.
"""
MusicSampler HF Space — Modular AudioGenerator for DAW-INVADER.
Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
"""
from __future__ import annotations

import os
import uuid
from typing import Any

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel

# Project helper: builds a text-to-audio pipeline object. Presumably it exposes
# generate() and generate_with_profile() — see usages below; confirm in
# src/text_to_audio.py.
from src.text_to_audio import build_pipeline, list_presets

# Defaults to audioldm2-music as a robust alternative to MusicGen
MODEL_PRESET = os.getenv("MODEL_PRESET", "audioldm2-music")
# 4-bit quantization is opt-in via the USE_4BIT env var ("true"/"false").
USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"

print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
# Prefer GPU when available; the pipeline is built once at import time and
# shared module-wide (it can be swapped later by gradio_gen).
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)
class GenRequest(BaseModel):
    """Request body for the POST /generate REST endpoint."""

    prompt: str                  # text description of the audio to synthesize
    duration: float = 5.0        # target clip length in seconds
    model: str = MODEL_PRESET    # preset name (used to pick generate kwargs)
def gradio_gen(prompt, duration, selected_model):
    """Generate an audio sample for the Gradio UI.

    Args:
        prompt: Text description of the audio to synthesize.
        duration: Target clip length in seconds (drives MusicGen's token budget).
        selected_model: Preset name; the shared pipeline is rebuilt when it
            differs from the currently loaded preset.

    Returns:
        ``(filepath, status)``: path to the rendered WAV (``None`` on empty
        prompt) and a human-readable status string with timing info.
    """
    global pipe, MODEL_PRESET
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."
    # Reload model if preset changed.
    # NOTE(review): swapping the module-global `pipe` is not thread-safe under
    # concurrent requests — acceptable for a single-user Space; confirm.
    if selected_model != MODEL_PRESET:
        print(f"Switching to {selected_model}...")
        pipe = build_pipeline(preset=selected_model, use_4bit=USE_4BIT, device_map=device)
        MODEL_PRESET = selected_model
    # Tokens/steps vary by model:
    #   MusicGen: ~50 tokens per second of audio
    #   AudioLDM: diffusion steps (25 = good quality/speed trade-off)
    generate_kwargs = {}
    if "musicgen" in MODEL_PRESET:
        generate_kwargs["max_new_tokens"] = int(duration * 50)
    elif "audioldm" in MODEL_PRESET:
        generate_kwargs["num_inference_steps"] = 25
    out, profile = pipe.generate_with_profile(
        prompt,
        generate_kwargs=generate_kwargs
    )
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    # BUG FIX: a CUDA and/or grad-tracking tensor also has a .numpy attribute,
    # but calling it raises — detach and move to CPU first.
    if torch.is_tensor(audio):
        arr = audio.detach().cpu().numpy()
    else:
        arr = np.asarray(audio)
    # Drop singleton batch/channel dims (e.g. (1, N) -> (N,)) so soundfile
    # receives clean 1-D mono or 2-D multichannel data.
    arr = np.squeeze(arr)
    path = f"/tmp/gradio_{uuid.uuid4()}.wav"
    # NOTE(review): assumes 2-D output is (channels, frames); transposed to
    # soundfile's (frames, channels) layout — confirm against the pipeline.
    sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
    return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
# Gradio front-end: prompt + duration + preset picker on the left,
# rendered audio and timing stats on the right.
with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
    gr.Markdown("# 🎹 MusicSampler")
    gr.Markdown("Modular AudioGenerator for **DAW-INVADER**. Use the UI or POST to `/generate`.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Musical/Audio Prompt", placeholder="An ambient synth pad with a slow filter sweep...", lines=3)
            with gr.Row():
                duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
                preset_choice = gr.Dropdown(
                    choices=list(list_presets().keys()),
                    value=MODEL_PRESET,
                    label="Model Preset"
                )
            btn = gr.Button("Sample", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Output Sample", type="filepath")
            stats = gr.Label(label="Performance")
    btn.click(gradio_gen, inputs=[prompt, duration, preset_choice], outputs=[audio_out, stats])

# BUG FIX: `gr.Blocks` exposes no public `.app` attribute until launch(), so
# the original `app = ui.app` raised AttributeError at import time (the Space
# showed "Runtime error"). Create a standalone FastAPI app instead; serve it
# with `uvicorn app:app` (or attach the UI via gr.mount_gradio_app AFTER
# registering API routes) when the REST endpoint is needed alongside the UI.
app = FastAPI(title="MusicSampler API")
# BUG FIX: the route decorator was missing, so this handler was never
# registered and POST /generate returned 404 despite the UI advertising it.
@app.post("/generate")
def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
    """API endpoint for DAW-INVADER / Vercel integration.

    Declared as plain ``def`` (not ``async def``) so FastAPI runs the blocking
    model inference in its threadpool instead of stalling the event loop.

    Returns the generated clip as a WAV FileResponse; the temp file is removed
    by a background task after the response has been sent.

    NOTE(review): `req.model` only selects generate kwargs here — it does NOT
    reload the pipeline like the UI path does; confirm this is intended.
    """
    filename = f"gen_{uuid.uuid4()}.wav"
    output_path = os.path.join("/tmp", filename)
    # Per-model knobs: MusicGen budgets ~50 tokens/sec; AudioLDM uses
    # diffusion steps (25 = good quality/speed trade-off).
    generate_kwargs = {}
    if "musicgen" in req.model:
        generate_kwargs["max_new_tokens"] = int(req.duration * 50)
    elif "audioldm" in req.model:
        generate_kwargs["num_inference_steps"] = 25
    out = pipe.generate(
        req.prompt,
        generate_kwargs=generate_kwargs
    )
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    # BUG FIX: a CUDA and/or grad-tracking tensor also has a .numpy attribute,
    # but calling it raises — detach and move to CPU first.
    if torch.is_tensor(audio):
        arr = audio.detach().cpu().numpy()
    else:
        arr = np.asarray(audio)
    # NOTE(review): assumes 2-D output is (channels, frames); transposed to
    # soundfile's (frames, channels) layout — confirm against the pipeline.
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    # BackgroundTasks run only after the response body is streamed, so the
    # file is still present for FileResponse and cleaned up afterwards.
    background_tasks.add_task(os.remove, output_path)
    return FileResponse(output_path, media_type="audio/wav", filename=filename)
if __name__ == "__main__":
    # Entry point: serve the Gradio UI on all interfaces at the HF-Spaces
    # standard port 7860.
    ui.launch(server_name="0.0.0.0", server_port=7860)