Keith commited on
Commit
ab80cc2
·
1 Parent(s): e3f3734

Update SDK version and app.py for HF stability

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +49 -46
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎹
4
  colorFrom: indigo
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.0.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: indigo
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -6,14 +6,16 @@ Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
6
  from __future__ import annotations
7
 
8
  import os
 
 
 
 
 
 
9
  import torch
10
- from fastapi import FastAPI, BackgroundTasks
11
  from fastapi.responses import FileResponse
12
  from pydantic import BaseModel
13
- import gradio as gr
14
- import soundfile as sf
15
- import numpy as np
16
- import uuid
17
 
18
  from src.text_to_audio import build_pipeline
19
 
@@ -22,50 +24,21 @@ MODEL_PRESET = os.getenv("MODEL_PRESET", "musicgen-small")
22
  USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
23
 
24
  print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
25
- pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT)
26
-
27
- # FastAPI Setup
28
- app = FastAPI(title="MusicSampler API")
29
 
30
  class GenRequest(BaseModel):
31
  prompt: str
32
  duration: float = 5.0
33
  model: str = MODEL_PRESET
34
 
35
- @app.post("/generate")
36
- async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
37
- """API Endpoint for DAW-INVADER / Vercel integration."""
38
- filename = f"gen_{uuid.uuid4()}.wav"
39
- output_path = os.path.join("/tmp", filename)
40
-
41
- # Generate audio
42
- # MusicGen supports 'max_new_tokens' via generate_kwargs
43
- # 5 seconds ~ 250 tokens for MusicGen small (50 tokens/sec)
44
- tokens = int(req.duration * 50)
45
-
46
- out = pipe.generate(
47
- req.prompt,
48
- generate_kwargs={"max_new_tokens": tokens}
49
- )
50
-
51
- single = out if isinstance(out, dict) else out[0]
52
- audio = single["audio"]
53
- sr = single["sampling_rate"]
54
-
55
- if hasattr(audio, "numpy"):
56
- arr = audio.numpy()
57
- else:
58
- arr = np.asarray(audio)
59
-
60
- sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
61
-
62
- # Clean up file after serving
63
- background_tasks.add_task(os.remove, output_path)
64
-
65
- return FileResponse(output_path, media_type="audio/wav", filename=filename)
66
-
67
- # Gradio Interface
68
  def gradio_gen(prompt, duration):
 
 
 
 
69
  tokens = int(duration * 50)
70
  out, profile = pipe.generate_with_profile(
71
  prompt,
@@ -81,6 +54,7 @@ def gradio_gen(prompt, duration):
81
  arr = np.asarray(audio)
82
 
83
  path = f"/tmp/gradio_{uuid.uuid4()}.wav"
 
84
  sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
85
  return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
86
 
@@ -99,9 +73,38 @@ with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
99
 
100
  btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])
101
 
102
- # Mount Gradio into FastAPI
103
- app = gr.mount_gradio_app(app, ui, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
 
105
  if __name__ == "__main__":
106
- import uvicorn
107
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
6
  from __future__ import annotations
7
 
8
  import os
9
+ import uuid
10
+ from typing import Any
11
+
12
+ import gradio as gr
13
+ import numpy as np
14
+ import soundfile as sf
15
  import torch
16
+ from fastapi import BackgroundTasks, FastAPI
17
  from fastapi.responses import FileResponse
18
  from pydantic import BaseModel
 
 
 
 
19
 
20
  from src.text_to_audio import build_pipeline
21
 
 
24
  USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
25
 
26
  print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
27
+ # Force device to cuda if available, otherwise cpu
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)
 
30
 
31
  class GenRequest(BaseModel):
32
  prompt: str
33
  duration: float = 5.0
34
  model: str = MODEL_PRESET
35
 
36
+ # Gradio Interface functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def gradio_gen(prompt, duration):
38
+ if not prompt or not prompt.strip():
39
+ return None, "Please enter a prompt."
40
+
41
+ # MusicGen: 5 seconds ~ 250 tokens (50 tokens/sec approx)
42
  tokens = int(duration * 50)
43
  out, profile = pipe.generate_with_profile(
44
  prompt,
 
54
  arr = np.asarray(audio)
55
 
56
  path = f"/tmp/gradio_{uuid.uuid4()}.wav"
57
+ # Ensure audio is properly formatted for soundfile
58
  sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
59
  return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
60
 
 
73
 
74
  btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])
75
 
76
+ # HF Spaces automatically launches the app defined in app_file if it's sdk: gradio
77
+ # To expose a custom API alongside Gradio, we use the internal FastAPI app.
78
+ app = ui.app
79
+
80
+ @app.post("/generate")
81
+ async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
82
+ """API Endpoint for DAW-INVADER / Vercel integration."""
83
+ filename = f"gen_{uuid.uuid4()}.wav"
84
+ output_path = os.path.join("/tmp", filename)
85
+
86
+ tokens = int(req.duration * 50)
87
+ out = pipe.generate(
88
+ req.prompt,
89
+ generate_kwargs={"max_new_tokens": tokens}
90
+ )
91
+
92
+ single = out if isinstance(out, dict) else out[0]
93
+ audio = single["audio"]
94
+ sr = single["sampling_rate"]
95
+
96
+ if hasattr(audio, "numpy"):
97
+ arr = audio.numpy()
98
+ else:
99
+ arr = np.asarray(audio)
100
+
101
+ sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
102
+
103
+ # Clean up file after serving
104
+ background_tasks.add_task(os.remove, output_path)
105
+
106
+ return FileResponse(output_path, media_type="audio/wav", filename=filename)
107
 
108
+ # Standard entry point for HF Spaces
109
  if __name__ == "__main__":
110
+ ui.launch(server_name="0.0.0.0", server_port=7860)