"""
MusicSampler HF Space — Modular AudioGenerator for DAW-INVADER.
Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
"""
from __future__ import annotations
import os
import uuid
from typing import Any
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from src.text_to_audio import build_pipeline, list_presets
# Defaults to audioldm2-music as a robust alternative to MusicGen
MODEL_PRESET = os.getenv("MODEL_PRESET", "audioldm2-music")
# 4-bit quantization is opt-in via env var; any value other than "true"
# (case-insensitive) leaves it disabled.
USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
# Prefer GPU when available; the pipeline is built once at import time and
# shared by both the Gradio handler and the API endpoint.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)
class GenRequest(BaseModel):
    """Request body for the POST /generate endpoint."""
    prompt: str
    # Requested clip length in seconds; used to size generation kwargs.
    duration: float = 5.0
    # Model preset name; defaults to the preset loaded at startup.
    model: str = MODEL_PRESET
def gradio_gen(prompt, duration, selected_model):
    """Gradio click handler: generate an audio clip from a text prompt.

    Args:
        prompt: Free-text description of the desired audio.
        duration: Target clip length in seconds (from the UI slider).
        selected_model: Preset name from the dropdown; if it differs from
            the currently loaded preset, the pipeline is rebuilt.

    Returns:
        A ``(filepath, status_message)`` tuple for the Audio/Label outputs,
        or ``(None, error_message)`` when the prompt is empty.
    """
    global pipe, MODEL_PRESET
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."
    # Reload model if preset changed
    if selected_model != MODEL_PRESET:
        print(f"Switching to {selected_model}...")
        pipe = build_pipeline(preset=selected_model, use_4bit=USE_4BIT, device_map=device)
        MODEL_PRESET = selected_model
    # Tokens/Steps vary by model;
    # For MusicGen: ~50 tokens/sec
    # For AudioLDM: uses num_inference_steps and audio_length_in_s
    generate_kwargs = {}
    if "musicgen" in MODEL_PRESET:
        generate_kwargs["max_new_tokens"] = int(duration * 50)
    elif "audioldm" in MODEL_PRESET:
        generate_kwargs["num_inference_steps"] = 25  # Default good quality
        # FIX: honor the duration slider — AudioLDM pipelines size their
        # output via audio_length_in_s; previously duration was ignored here.
        generate_kwargs["audio_length_in_s"] = float(duration)
    out, profile = pipe.generate_with_profile(
        prompt,
        generate_kwargs=generate_kwargs
    )
    # Output may be a single dict or a batch (list of dicts); take the first.
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
    path = f"/tmp/gradio_{uuid.uuid4()}.wav"
    # soundfile expects (frames, channels); 2-D model output is assumed to be
    # (channels, frames) — TODO confirm against the pipeline's output layout.
    sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
    return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
# Gradio UI: prompt + duration + preset selector on the left, audio player
# and timing stats on the right. The click handler is gradio_gen above.
with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
    gr.Markdown("# 🎹 MusicSampler")
    gr.Markdown("Modular AudioGenerator for **DAW-INVADER**. Use the UI or POST to `/generate`.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Musical/Audio Prompt", placeholder="An ambient synth pad with a slow filter sweep...", lines=3)
            with gr.Row():
                duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
                # Preset choices come from src.text_to_audio.list_presets().
                preset_choice = gr.Dropdown(
                    choices=list(list_presets().keys()),
                    value=MODEL_PRESET,
                    label="Model Preset"
                )
            btn = gr.Button("Sample", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Output Sample", type="filepath")
            stats = gr.Label(label="Performance")
    btn.click(gradio_gen, inputs=[prompt, duration, preset_choice], outputs=[audio_out, stats])
# FIX: `ui.app` is only populated after `Blocks.launch()` runs, so reading it
# at import time raises AttributeError (the likely cause of the Space's
# runtime error). The documented pattern is to create a FastAPI app and mount
# the Gradio Blocks onto it with `gr.mount_gradio_app`.
app = FastAPI(title="MusicSampler")


@app.post("/generate")
async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
    """API Endpoint for DAW-INVADER / Vercel integration.

    Generates audio for ``req.prompt`` with the globally loaded pipeline and
    returns the WAV file; the temp file is deleted after the response is sent.
    """
    filename = f"gen_{uuid.uuid4()}.wav"
    output_path = os.path.join("/tmp", filename)
    generate_kwargs = {}
    if "musicgen" in req.model:
        # MusicGen emits ~50 tokens per generated second of audio.
        generate_kwargs["max_new_tokens"] = int(req.duration * 50)
    elif "audioldm" in req.model:
        generate_kwargs["num_inference_steps"] = 25
        # FIX: honor req.duration — AudioLDM pipelines size their output via
        # audio_length_in_s; previously the requested duration was ignored.
        generate_kwargs["audio_length_in_s"] = float(req.duration)
    out = pipe.generate(
        req.prompt,
        generate_kwargs=generate_kwargs
    )
    # Output may be a single dict or a batch (list of dicts); take the first.
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
    # soundfile expects (frames, channels); 2-D model output is assumed to be
    # (channels, frames) — TODO confirm against the pipeline's output layout.
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    # Clean up the temp WAV once the response has been delivered.
    background_tasks.add_task(os.remove, output_path)
    return FileResponse(output_path, media_type="audio/wav", filename=filename)


# Serve the Gradio UI at the root path of the same ASGI application.
app = gr.mount_gradio_app(app, ui, path="/")
if __name__ == "__main__":
    # Local/dev entry point: serve the Gradio UI on all interfaces at the
    # standard HF Spaces port (7860).
    ui.launch(server_name="0.0.0.0", server_port=7860)