Spaces:
Runtime error
Runtime error
Keith commited on
Commit ·
ab80cc2
1
Parent(s): e3f3734
Update SDK version and app.py for HF stability
Browse files
README.md
CHANGED
|
@@ -4,7 +4,7 @@ emoji: 🎹
|
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 4.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
app.py
CHANGED
|
@@ -6,14 +6,16 @@ Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
|
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import torch
|
| 10 |
-
from fastapi import
|
| 11 |
from fastapi.responses import FileResponse
|
| 12 |
from pydantic import BaseModel
|
| 13 |
-
import gradio as gr
|
| 14 |
-
import soundfile as sf
|
| 15 |
-
import numpy as np
|
| 16 |
-
import uuid
|
| 17 |
|
| 18 |
from src.text_to_audio import build_pipeline
|
| 19 |
|
|
@@ -22,50 +24,21 @@ MODEL_PRESET = os.getenv("MODEL_PRESET", "musicgen-small")
|
|
| 22 |
USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
|
| 23 |
|
| 24 |
print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
app = FastAPI(title="MusicSampler API")
|
| 29 |
|
| 30 |
class GenRequest(BaseModel):
|
| 31 |
prompt: str
|
| 32 |
duration: float = 5.0
|
| 33 |
model: str = MODEL_PRESET
|
| 34 |
|
| 35 |
-
|
| 36 |
-
async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
|
| 37 |
-
"""API Endpoint for DAW-INVADER / Vercel integration."""
|
| 38 |
-
filename = f"gen_{uuid.uuid4()}.wav"
|
| 39 |
-
output_path = os.path.join("/tmp", filename)
|
| 40 |
-
|
| 41 |
-
# Generate audio
|
| 42 |
-
# MusicGen supports 'max_new_tokens' via generate_kwargs
|
| 43 |
-
# 5 seconds ~ 250 tokens for MusicGen small (50 tokens/sec)
|
| 44 |
-
tokens = int(req.duration * 50)
|
| 45 |
-
|
| 46 |
-
out = pipe.generate(
|
| 47 |
-
req.prompt,
|
| 48 |
-
generate_kwargs={"max_new_tokens": tokens}
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
single = out if isinstance(out, dict) else out[0]
|
| 52 |
-
audio = single["audio"]
|
| 53 |
-
sr = single["sampling_rate"]
|
| 54 |
-
|
| 55 |
-
if hasattr(audio, "numpy"):
|
| 56 |
-
arr = audio.numpy()
|
| 57 |
-
else:
|
| 58 |
-
arr = np.asarray(audio)
|
| 59 |
-
|
| 60 |
-
sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
|
| 61 |
-
|
| 62 |
-
# Clean up file after serving
|
| 63 |
-
background_tasks.add_task(os.remove, output_path)
|
| 64 |
-
|
| 65 |
-
return FileResponse(output_path, media_type="audio/wav", filename=filename)
|
| 66 |
-
|
| 67 |
-
# Gradio Interface
|
| 68 |
def gradio_gen(prompt, duration):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
tokens = int(duration * 50)
|
| 70 |
out, profile = pipe.generate_with_profile(
|
| 71 |
prompt,
|
|
@@ -81,6 +54,7 @@ def gradio_gen(prompt, duration):
|
|
| 81 |
arr = np.asarray(audio)
|
| 82 |
|
| 83 |
path = f"/tmp/gradio_{uuid.uuid4()}.wav"
|
|
|
|
| 84 |
sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
|
| 85 |
return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
|
| 86 |
|
|
@@ -99,9 +73,38 @@ with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
|
|
| 99 |
|
| 100 |
btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])
|
| 101 |
|
| 102 |
-
#
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
|
|
|
| 105 |
if __name__ == "__main__":
|
| 106 |
-
|
| 107 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import os
|
| 9 |
+
import uuid
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import numpy as np
|
| 14 |
+
import soundfile as sf
|
| 15 |
import torch
|
| 16 |
+
from fastapi import BackgroundTasks, FastAPI
|
| 17 |
from fastapi.responses import FileResponse
|
| 18 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
from src.text_to_audio import build_pipeline
|
| 21 |
|
|
|
|
| 24 |
USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
|
| 25 |
|
| 26 |
print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
|
| 27 |
+
# Force device to cuda if available, otherwise cpu
|
| 28 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
+
pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)
|
|
|
|
| 30 |
|
| 31 |
class GenRequest(BaseModel):
|
| 32 |
prompt: str
|
| 33 |
duration: float = 5.0
|
| 34 |
model: str = MODEL_PRESET
|
| 35 |
|
| 36 |
+
# Gradio Interface functions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def gradio_gen(prompt, duration):
|
| 38 |
+
if not prompt or not prompt.strip():
|
| 39 |
+
return None, "Please enter a prompt."
|
| 40 |
+
|
| 41 |
+
# MusicGen: 5 seconds ~ 250 tokens (50 tokens/sec approx)
|
| 42 |
tokens = int(duration * 50)
|
| 43 |
out, profile = pipe.generate_with_profile(
|
| 44 |
prompt,
|
|
|
|
| 54 |
arr = np.asarray(audio)
|
| 55 |
|
| 56 |
path = f"/tmp/gradio_{uuid.uuid4()}.wav"
|
| 57 |
+
# Ensure audio is properly formatted for soundfile
|
| 58 |
sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
|
| 59 |
return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
|
| 60 |
|
|
|
|
| 73 |
|
| 74 |
btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])
|
| 75 |
|
| 76 |
+
# HF Spaces automatically launches the app defined in app_file if it's sdk: gradio
|
| 77 |
+
# To expose a custom API alongside Gradio, we use the internal FastAPI app.
|
| 78 |
+
app = ui.app
|
| 79 |
+
|
| 80 |
+
@app.post("/generate")
|
| 81 |
+
async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
|
| 82 |
+
"""API Endpoint for DAW-INVADER / Vercel integration."""
|
| 83 |
+
filename = f"gen_{uuid.uuid4()}.wav"
|
| 84 |
+
output_path = os.path.join("/tmp", filename)
|
| 85 |
+
|
| 86 |
+
tokens = int(req.duration * 50)
|
| 87 |
+
out = pipe.generate(
|
| 88 |
+
req.prompt,
|
| 89 |
+
generate_kwargs={"max_new_tokens": tokens}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
single = out if isinstance(out, dict) else out[0]
|
| 93 |
+
audio = single["audio"]
|
| 94 |
+
sr = single["sampling_rate"]
|
| 95 |
+
|
| 96 |
+
if hasattr(audio, "numpy"):
|
| 97 |
+
arr = audio.numpy()
|
| 98 |
+
else:
|
| 99 |
+
arr = np.asarray(audio)
|
| 100 |
+
|
| 101 |
+
sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
|
| 102 |
+
|
| 103 |
+
# Clean up file after serving
|
| 104 |
+
background_tasks.add_task(os.remove, output_path)
|
| 105 |
+
|
| 106 |
+
return FileResponse(output_path, media_type="audio/wav", filename=filename)
|
| 107 |
|
| 108 |
+
# Standard entry point for HF Spaces
|
| 109 |
if __name__ == "__main__":
|
| 110 |
+
ui.launch(server_name="0.0.0.0", server_port=7860)
|
|
|