Keith committed on
Commit
e696c96
·
1 Parent(s): ab80cc2

Switch default to audioldm2-music and add model selector

Browse files
Files changed (2) hide show
  1. app.py +37 -21
  2. src/text_to_audio/pipeline.py +4 -0
app.py CHANGED
@@ -17,14 +17,13 @@ from fastapi import BackgroundTasks, FastAPI
17
  from fastapi.responses import FileResponse
18
  from pydantic import BaseModel
19
 
20
- from src.text_to_audio import build_pipeline
21
 
22
- # Initialize Pipeline (defaulting to musicgen-small for MusicSampler)
23
- MODEL_PRESET = os.getenv("MODEL_PRESET", "musicgen-small")
24
  USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
25
 
26
  print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
27
- # Force device to cuda if available, otherwise cpu
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)
30
 
@@ -33,16 +32,29 @@ class GenRequest(BaseModel):
33
  duration: float = 5.0
34
  model: str = MODEL_PRESET
35
 
36
- # Gradio Interface functions
37
- def gradio_gen(prompt, duration):
38
  if not prompt or not prompt.strip():
39
  return None, "Please enter a prompt."
40
 
41
- # MusicGen: 5 seconds ~ 250 tokens (50 tokens/sec approx)
42
- tokens = int(duration * 50)
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  out, profile = pipe.generate_with_profile(
44
  prompt,
45
- generate_kwargs={"max_new_tokens": tokens}
46
  )
47
  single = out if isinstance(out, dict) else out[0]
48
  audio = single["audio"]
@@ -54,7 +66,6 @@ def gradio_gen(prompt, duration):
54
  arr = np.asarray(audio)
55
 
56
  path = f"/tmp/gradio_{uuid.uuid4()}.wav"
57
- # Ensure audio is properly formatted for soundfile
58
  sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
59
  return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
60
 
@@ -64,17 +75,21 @@ with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
64
 
65
  with gr.Row():
66
  with gr.Column():
67
- prompt = gr.Textbox(label="Musical Prompt", placeholder="Lo-fi hip hop beat with smooth rhodes piano...", lines=3)
68
- duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
 
 
 
 
 
 
69
  btn = gr.Button("Sample", variant="primary")
70
  with gr.Column():
71
  audio_out = gr.Audio(label="Output Sample", type="filepath")
72
  stats = gr.Label(label="Performance")
73
 
74
- btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])
75
 
76
- # HF Spaces automatically launches the app defined in app_file if it's sdk: gradio
77
- # To expose a custom API alongside Gradio, we use the internal FastAPI app.
78
  app = ui.app
79
 
80
  @app.post("/generate")
@@ -83,10 +98,15 @@ async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
83
  filename = f"gen_{uuid.uuid4()}.wav"
84
  output_path = os.path.join("/tmp", filename)
85
 
86
- tokens = int(req.duration * 50)
 
 
 
 
 
87
  out = pipe.generate(
88
  req.prompt,
89
- generate_kwargs={"max_new_tokens": tokens}
90
  )
91
 
92
  single = out if isinstance(out, dict) else out[0]
@@ -99,12 +119,8 @@ async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
99
  arr = np.asarray(audio)
100
 
101
  sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
102
-
103
- # Clean up file after serving
104
  background_tasks.add_task(os.remove, output_path)
105
-
106
  return FileResponse(output_path, media_type="audio/wav", filename=filename)
107
 
108
- # Standard entry point for HF Spaces
109
  if __name__ == "__main__":
110
  ui.launch(server_name="0.0.0.0", server_port=7860)
 
17
  from fastapi.responses import FileResponse
18
  from pydantic import BaseModel
19
 
20
+ from src.text_to_audio import build_pipeline, list_presets
21
 
22
+ # Defaults to audioldm2-music as a robust alternative to MusicGen
23
+ MODEL_PRESET = os.getenv("MODEL_PRESET", "audioldm2-music")
24
  USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
25
 
26
  print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT, device_map=device)
29
 
 
32
  duration: float = 5.0
33
  model: str = MODEL_PRESET
34
 
35
+ def gradio_gen(prompt, duration, selected_model):
36
+ global pipe, MODEL_PRESET
37
  if not prompt or not prompt.strip():
38
  return None, "Please enter a prompt."
39
 
40
+ # Reload model if preset changed
41
+ if selected_model != MODEL_PRESET:
42
+ print(f"Switching to {selected_model}...")
43
+ pipe = build_pipeline(preset=selected_model, use_4bit=USE_4BIT, device_map=device)
44
+ MODEL_PRESET = selected_model
45
+
46
+ # Tokens/Steps vary by model;
47
+ # For MusicGen: ~50 tokens/sec
48
+ # For AudioLDM: uses num_inference_steps (passed via generate_kwargs)
49
+ generate_kwargs = {}
50
+ if "musicgen" in MODEL_PRESET:
51
+ generate_kwargs["max_new_tokens"] = int(duration * 50)
52
+ elif "audioldm" in MODEL_PRESET:
53
+ generate_kwargs["num_inference_steps"] = 25 # Default good quality
54
+
55
  out, profile = pipe.generate_with_profile(
56
  prompt,
57
+ generate_kwargs=generate_kwargs
58
  )
59
  single = out if isinstance(out, dict) else out[0]
60
  audio = single["audio"]
 
66
  arr = np.asarray(audio)
67
 
68
  path = f"/tmp/gradio_{uuid.uuid4()}.wav"
 
69
  sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
70
  return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
71
 
 
75
 
76
  with gr.Row():
77
  with gr.Column():
78
+ prompt = gr.Textbox(label="Musical/Audio Prompt", placeholder="An ambient synth pad with a slow filter sweep...", lines=3)
79
+ with gr.Row():
80
+ duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
81
+ preset_choice = gr.Dropdown(
82
+ choices=list(list_presets().keys()),
83
+ value=MODEL_PRESET,
84
+ label="Model Preset"
85
+ )
86
  btn = gr.Button("Sample", variant="primary")
87
  with gr.Column():
88
  audio_out = gr.Audio(label="Output Sample", type="filepath")
89
  stats = gr.Label(label="Performance")
90
 
91
+ btn.click(gradio_gen, inputs=[prompt, duration, preset_choice], outputs=[audio_out, stats])
92
 
 
 
93
  app = ui.app
94
 
95
  @app.post("/generate")
 
98
  filename = f"gen_{uuid.uuid4()}.wav"
99
  output_path = os.path.join("/tmp", filename)
100
 
101
+ generate_kwargs = {}
102
+ if "musicgen" in req.model:
103
+ generate_kwargs["max_new_tokens"] = int(req.duration * 50)
104
+ elif "audioldm" in req.model:
105
+ generate_kwargs["num_inference_steps"] = 25
106
+
107
  out = pipe.generate(
108
  req.prompt,
109
+ generate_kwargs=generate_kwargs
110
  )
111
 
112
  single = out if isinstance(out, dict) else out[0]
 
119
  arr = np.asarray(audio)
120
 
121
  sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
 
 
122
  background_tasks.add_task(os.remove, output_path)
 
123
  return FileResponse(output_path, media_type="audio/wav", filename=filename)
124
 
 
125
  if __name__ == "__main__":
126
  ui.launch(server_name="0.0.0.0", server_port=7860)
src/text_to_audio/pipeline.py CHANGED
@@ -41,6 +41,10 @@ PRESETS = {
41
  "model_id": "facebook/musicgen-small",
42
  "description": "Music/sfx; 32k Hz, generation-style.",
43
  },
 
 
 
 
44
  }
45
 
46
 
 
41
  "model_id": "facebook/musicgen-small",
42
  "description": "Music/sfx; 32k Hz, generation-style.",
43
  },
44
+ "audioldm2-music": {
45
+ "model_id": "cvssp/audioldm2-music",
46
+ "description": "High-quality music generation via AudioLDM2; robust alternative to MusicGen.",
47
+ },
48
  }
49
 
50