File size: 4,464 Bytes
e3f3734
 
 
 
 
 
 
 
ab80cc2
 
 
 
 
 
e3f3734
ab80cc2
e3f3734
 
 
e696c96
e3f3734
e696c96
 
e3f3734
 
 
ab80cc2
 
e3f3734
 
 
 
 
 
e696c96
 
ab80cc2
 
 
e696c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f3734
 
e696c96
e3f3734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e696c96
 
 
 
 
 
 
 
e3f3734
 
 
 
 
e696c96
e3f3734
ab80cc2
 
 
 
 
 
 
 
e696c96
 
 
 
 
 
ab80cc2
 
e696c96
ab80cc2
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f3734
 
ab80cc2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
MusicSampler HF Space — Modular AudioGenerator for DAW-INVADER.
Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
"""

from __future__ import annotations

import os
import uuid
from typing import Any

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel

from src.text_to_audio import build_pipeline, list_presets

# Defaults to audioldm2-music as a robust alternative to MusicGen
# Preset key understood by build_pipeline; override via the MODEL_PRESET env var.
MODEL_PRESET = os.getenv("MODEL_PRESET", "audioldm2-music")
# 4-bit quantization is opt-in: only the literal string "true" (any case) enables it.
USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"

print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
# Prefer GPU when available; the pipeline is loaded once at import time and
# later swapped in-place by gradio_gen when the user changes presets.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)

class GenRequest(BaseModel):
    """Request payload for the POST /generate API endpoint."""

    prompt: str  # Text description of the audio to synthesize.
    duration: float = 5.0  # Requested clip length in seconds.
    model: str = MODEL_PRESET  # Preset name; defaults to the preset loaded at startup.

def gradio_gen(prompt, duration, selected_model):
    """Generate an audio sample from *prompt* for the Gradio UI.

    Args:
        prompt: Text description of the desired audio.
        duration: Requested clip length in seconds (from the UI slider).
        selected_model: Preset key; a change relative to the currently loaded
            preset triggers a full pipeline reload (slow).

    Returns:
        Tuple of (wav_filepath_or_None, status_message).
    """
    global pipe, MODEL_PRESET
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."

    # Reload the pipeline only when the user actually picked a different preset.
    if selected_model != MODEL_PRESET:
        print(f"Switching to {selected_model}...")
        pipe = build_pipeline(preset=selected_model, use_4bit=USE_4BIT, device_map=device)
        MODEL_PRESET = selected_model

    # Duration control differs per model family:
    #  - MusicGen (autoregressive): ~50 tokens per generated second.
    #  - AudioLDM (diffusion): clip length via audio_length_in_s, quality via
    #    num_inference_steps.
    generate_kwargs = {}
    if "musicgen" in MODEL_PRESET:
        generate_kwargs["max_new_tokens"] = int(duration * 50)
    elif "audioldm" in MODEL_PRESET:
        generate_kwargs["num_inference_steps"] = 25  # Default good quality
        # BUGFIX: the duration slider was previously ignored for AudioLDM
        # presets; AudioLDM pipelines take the clip length in seconds.
        generate_kwargs["audio_length_in_s"] = float(duration)

    out, profile = pipe.generate_with_profile(
        prompt,
        generate_kwargs=generate_kwargs
    )
    # Pipelines may return a single dict or a batch (list of dicts); take the first.
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]

    # Torch tensors expose .numpy(); anything else is coerced through numpy.
    if hasattr(audio, "numpy"):
        arr = audio.numpy()
    else:
        arr = np.asarray(audio)

    # Assumes 2-D audio is (channels, frames); soundfile wants (frames, channels)
    # — TODO confirm against the pipeline's output layout.
    path = f"/tmp/gradio_{uuid.uuid4()}.wav"
    sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
    return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"

# Gradio UI: prompt + duration + preset on the left, audio player and
# performance stats on the right.
with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
    gr.Markdown("# 🎹 MusicSampler")
    gr.Markdown("Modular AudioGenerator for **DAW-INVADER**. Use the UI or POST to `/generate`.")
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Musical/Audio Prompt", placeholder="An ambient synth pad with a slow filter sweep...", lines=3)
            with gr.Row():
                duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
                # Preset choices come from the project's registry; the startup
                # preset is preselected.
                preset_choice = gr.Dropdown(
                    choices=list(list_presets().keys()), 
                    value=MODEL_PRESET, 
                    label="Model Preset"
                )
            btn = gr.Button("Sample", variant="primary")
        with gr.Column():
            # type="filepath" matches gradio_gen returning a /tmp WAV path.
            audio_out = gr.Audio(label="Output Sample", type="filepath")
            stats = gr.Label(label="Performance")
            
    btn.click(gradio_gen, inputs=[prompt, duration, preset_choice], outputs=[audio_out, stats])

# FastAPI app carrying the custom API routes.
# NOTE(review): the original used `app = ui.app`, but Blocks.app is only
# populated once launch() runs, so accessing it at import time fails on
# HF Spaces (`uvicorn app:app`). Build the FastAPI app explicitly and mount
# the Gradio UI onto it instead.
app = FastAPI()

@app.post("/generate")
async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
    """API Endpoint for DAW-INVADER / Vercel integration.

    Generates audio with the currently loaded pipeline and returns it as a
    WAV file; the temp file is deleted after the response has been sent.
    """
    filename = f"gen_{uuid.uuid4()}.wav"
    output_path = os.path.join("/tmp", filename)
    
    # Mirror the UI's per-family duration handling (see gradio_gen).
    generate_kwargs = {}
    if "musicgen" in req.model:
        generate_kwargs["max_new_tokens"] = int(req.duration * 50)
    elif "audioldm" in req.model:
        generate_kwargs["num_inference_steps"] = 25
        # BUGFIX: req.duration was previously ignored for AudioLDM presets.
        generate_kwargs["audio_length_in_s"] = float(req.duration)
    
    out = pipe.generate(
        req.prompt, 
        generate_kwargs=generate_kwargs
    )
    
    # Single dict or batch (list of dicts); take the first result.
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    
    if hasattr(audio, "numpy"):
        arr = audio.numpy()
    else:
        arr = np.asarray(audio)
        
    # Assumes 2-D audio is (channels, frames); soundfile wants (frames, channels).
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    # Background tasks run only after the response is fully sent, so the file
    # is served before it is removed.
    background_tasks.add_task(os.remove, output_path)
    return FileResponse(output_path, media_type="audio/wav", filename=filename)

# Mount the Gradio UI AFTER registering custom routes: Starlette matches in
# registration order, so a root-path mount added first would shadow /generate.
app = gr.mount_gradio_app(app, ui, path="/")

if __name__ == "__main__":
    # Local/dev entry point; on HF Spaces run `uvicorn app:app` instead.
    ui.launch(server_name="0.0.0.0", server_port=7860)