""" MusicSampler HF Space — Modular AudioGenerator for DAW-INVADER. Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration. """ from __future__ import annotations import os import uuid from typing import Any import gradio as gr import numpy as np import soundfile as sf import torch from fastapi import BackgroundTasks, FastAPI from fastapi.responses import FileResponse from pydantic import BaseModel from src.text_to_audio import build_pipeline, list_presets # Defaults to audioldm2-music as a robust alternative to MusicGen MODEL_PRESET = os.getenv("MODEL_PRESET", "audioldm2-music") USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true" print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...") device = "cuda" if torch.cuda.is_available() else "cpu" pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device) class GenRequest(BaseModel): prompt: str duration: float = 5.0 model: str = MODEL_PRESET def gradio_gen(prompt, duration, selected_model): global pipe, MODEL_PRESET if not prompt or not prompt.strip(): return None, "Please enter a prompt." # Reload model if preset changed if selected_model != MODEL_PRESET: print(f"Switching to {selected_model}...") pipe = build_pipeline(preset=selected_model, use_4bit=USE_4BIT, device_map=device) MODEL_PRESET = selected_model # Tokens/Steps vary by model; # For MusicGen: ~50 tokens/sec # For AudioLDM: uses num_inference_steps (passed via generate_kwargs) generate_kwargs = {} if "musicgen" in MODEL_PRESET: generate_kwargs["max_new_tokens"] = int(duration * 50) elif "audioldm" in MODEL_PRESET: generate_kwargs["num_inference_steps"] = 25 # Default good quality out, profile = pipe.generate_with_profile( prompt, generate_kwargs=generate_kwargs ) single = out if isinstance(out, dict) else out[0] audio = single["audio"] sr = single["sampling_rate"] if hasattr(audio, "numpy"): arr = audio.numpy() else: arr = np.asarray(audio) path = f"/tmp/gradio_{uuid.uuid4()}.wav" sf.write(path, arr.T if arr.ndim == 2 else arr, sr) return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})" with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui: gr.Markdown("# 🎹 MusicSampler") gr.Markdown("Modular AudioGenerator for **DAW-INVADER**. Use the UI or POST to `/generate`.") with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="Musical/Audio Prompt", placeholder="An ambient synth pad with a slow filter sweep...", lines=3) with gr.Row(): duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)") preset_choice = gr.Dropdown( choices=list(list_presets().keys()), value=MODEL_PRESET, label="Model Preset" ) btn = gr.Button("Sample", variant="primary") with gr.Column(): audio_out = gr.Audio(label="Output Sample", type="filepath") stats = gr.Label(label="Performance") btn.click(gradio_gen, inputs=[prompt, duration, preset_choice], outputs=[audio_out, stats]) app = ui.app @app.post("/generate") async def api_generate(req: GenRequest, background_tasks: BackgroundTasks): """API Endpoint for DAW-INVADER / Vercel integration.""" filename = f"gen_{uuid.uuid4()}.wav" output_path = os.path.join("/tmp", filename) generate_kwargs = {} if "musicgen" in req.model: generate_kwargs["max_new_tokens"] = int(req.duration * 50) elif "audioldm" in req.model: generate_kwargs["num_inference_steps"] = 25 out = pipe.generate( req.prompt, generate_kwargs=generate_kwargs ) single = out if isinstance(out, dict) else out[0] audio = single["audio"] sr = single["sampling_rate"] if hasattr(audio, "numpy"): arr = audio.numpy() else: arr = np.asarray(audio) sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr) background_tasks.add_task(os.remove, output_path) return FileResponse(output_path, media_type="audio/wav", filename=filename) if __name__ == "__main__": ui.launch(server_name="0.0.0.0", server_port=7860)