Manmay Nakhashi commited on
Commit
5cc51a5
·
1 Parent(s): e694869

Add explicit Target Duration slider (0–60s) + gen_duration kwarg

Browse files

- TTSServer.generate gains gen_duration: float = 0.0 (override estimator)
- Gradio app exposes a 'Target duration (s)' slider (0 = auto)
- inference.py --gen-duration help text expanded
- Auto-rescale and end-of-clip patch already active for long outputs

For music / multi-section scenes set the slider to 20–60 s and the
auto-rescale schedule keeps the output safe at any cfg.

Files changed (3) hide show
  1. app.py +14 -6
  2. src/inference.py +10 -3
  3. src/inference_server.py +47 -5
app.py CHANGED
@@ -107,7 +107,8 @@ EXAMPLES: list[tuple[str, str, str]] = [
107
 
108
 
109
  @spaces.GPU(duration=120)
110
- def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float, seed: int):
 
111
  if not prompt or not prompt.strip():
112
  raise gr.Error("Prompt is empty.")
113
  t0 = time.time()
@@ -119,6 +120,7 @@ def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float,
119
  voice_ref=ref_path,
120
  cfg_scale=cfg, stg_scale=stg,
121
  duration_multiplier=dur_mult, seed=int(seed),
 
122
  )
123
  elapsed = time.time() - t0
124
  logging.info(f"Generated in {elapsed:.2f}s -> {output}")
@@ -159,7 +161,11 @@ with gr.Blocks(
159
  with gr.Accordion("Inference settings", open=False):
160
  cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
161
  stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
162
- dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05, label="Duration ×")
 
 
 
 
163
  seed_input = gr.Number(value=42, label="Seed", precision=0)
164
  audio_out = gr.Audio(label="Generated audio", type="filepath")
165
  with gr.Accordion("Prompt writing guide", open=False):
@@ -176,7 +182,8 @@ with gr.Blocks(
176
 
177
  gen_btn.click(
178
  on_generate,
179
- inputs=[prompt_box, audio_ref, cfg_slider, stg_slider, dur_slider, seed_input],
 
180
  outputs=[audio_out],
181
  )
182
 
@@ -185,15 +192,16 @@ with gr.Blocks(
185
  gr.Examples(
186
  label="🎬 Click any row to generate a sample",
187
  examples=[
188
- [name, prompt, voice_path, 2.5, 1.5, 1.1, 42]
189
  for name, voice_path, prompt in EXAMPLES
190
  ],
191
  example_labels=[name for name, _, _ in EXAMPLES],
192
  inputs=[gr.Textbox(visible=False, label="Scene"),
193
  prompt_box, audio_ref,
194
- cfg_slider, stg_slider, dur_slider, seed_input],
195
  outputs=[audio_out],
196
- fn=lambda _name, prompt, ref, cfg, stg, dur, seed: on_generate(prompt, ref, cfg, stg, dur, seed),
 
197
  cache_examples=False,
198
  run_on_click=True,
199
  examples_per_page=20,
 
107
 
108
 
109
  @spaces.GPU(duration=120)
110
+ def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float,
111
+ gen_dur: float, seed: int):
112
  if not prompt or not prompt.strip():
113
  raise gr.Error("Prompt is empty.")
114
  t0 = time.time()
 
120
  voice_ref=ref_path,
121
  cfg_scale=cfg, stg_scale=stg,
122
  duration_multiplier=dur_mult, seed=int(seed),
123
+ gen_duration=float(gen_dur),
124
  )
125
  elapsed = time.time() - t0
126
  logging.info(f"Generated in {elapsed:.2f}s -> {output}")
 
161
  with gr.Accordion("Inference settings", open=False):
162
  cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
163
  stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
164
+ dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05,
165
+ label="Duration × (only used when target duration = 0)")
166
+ gen_dur_slider = gr.Slider(0.0, 60.0, value=0.0, step=1.0,
167
+ label="Target duration (s) — 0 = auto from prompt; "
168
+ "set higher (≥20s) for long-form music or scenes")
169
  seed_input = gr.Number(value=42, label="Seed", precision=0)
170
  audio_out = gr.Audio(label="Generated audio", type="filepath")
171
  with gr.Accordion("Prompt writing guide", open=False):
 
182
 
183
  gen_btn.click(
184
  on_generate,
185
+ inputs=[prompt_box, audio_ref, cfg_slider, stg_slider,
186
+ dur_slider, gen_dur_slider, seed_input],
187
  outputs=[audio_out],
188
  )
189
 
 
192
  gr.Examples(
193
  label="🎬 Click any row to generate a sample",
194
  examples=[
195
+ [name, prompt, voice_path, 2.5, 1.5, 1.1, 0.0, 42]
196
  for name, voice_path, prompt in EXAMPLES
197
  ],
198
  example_labels=[name for name, _, _ in EXAMPLES],
199
  inputs=[gr.Textbox(visible=False, label="Scene"),
200
  prompt_box, audio_ref,
201
+ cfg_slider, stg_slider, dur_slider, gen_dur_slider, seed_input],
202
  outputs=[audio_out],
203
+ fn=lambda _name, prompt, ref, cfg, stg, dur, gen_dur, seed: on_generate(
204
+ prompt, ref, cfg, stg, dur, gen_dur, seed),
205
  cache_examples=False,
206
  run_on_click=True,
207
  examples_per_page=20,
src/inference.py CHANGED
@@ -230,7 +230,10 @@ def parse_args():
230
  p.add_argument("--output", default="tts_output.wav")
231
 
232
  p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
233
- p.add_argument("--gen-duration", type=float, default=0.0, help="Target duration (0=auto)")
 
 
 
234
  p.add_argument("--pad-start", type=float, default=0.0,
235
  help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
236
  p.add_argument("--speed", type=float, default=1.0)
@@ -260,7 +263,9 @@ def parse_args():
260
  p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
261
  p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
262
  p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
263
- p.add_argument("--rescale-scale", type=float, default=None, help="Rescale (auto: 0.0 distilled, 0.7 dev)")
 
 
264
  p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
265
  p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
266
  p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
@@ -324,7 +329,9 @@ def main():
324
  if args.stg_scale is None:
325
  args.stg_scale = 0.0 if is_distilled else 1.0
326
  if args.rescale_scale is None:
327
- args.rescale_scale = 0.0 if is_distilled else 0.7
 
 
328
  if args.modality_scale is None:
329
  args.modality_scale = 1.0 if is_distilled else 3.0
330
  if args.fps is None:
 
230
  p.add_argument("--output", default="tts_output.wav")
231
 
232
  p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
233
+ p.add_argument("--gen-duration", type=float, default=0.0,
234
+ help="Target output duration in seconds (0 = auto from prompt + multiplier). "
235
+ "Set explicitly for long-form prompts (e.g. --gen-duration 30 for music). "
236
+ "Outputs >20.5s automatically engage the end-of-clip silence-prior patch.")
237
  p.add_argument("--pad-start", type=float, default=0.0,
238
  help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
239
  p.add_argument("--speed", type=float, default=1.0)
 
263
  p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
264
  p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
265
  p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
266
+ p.add_argument("--rescale-scale", type=float, default=None,
267
+ help="Latent CFG std-rescale (default auto: cfg-aware schedule that prevents "
268
+ "output clipping at high cfg; pass any float in [0,1] to override).")
269
  p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
270
  p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
271
  p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
 
329
  if args.stg_scale is None:
330
  args.stg_scale = 0.0 if is_distilled else 1.0
331
  if args.rescale_scale is None:
332
+ # Auto cfg-aware rescale: imported from inference_server to keep one source of truth.
333
+ from inference_server import auto_rescale_for_cfg
334
+ args.rescale_scale = 0.0 if is_distilled else auto_rescale_for_cfg(args.cfg_scale)
335
  if args.modality_scale is None:
336
  args.modality_scale = 1.0 if is_distilled else 3.0
337
  if args.fps is None:
src/inference_server.py CHANGED
@@ -60,6 +60,34 @@ def estimate_duration(prompt, multiplier=1.1):
60
  return max(3.0, round(base * multiplier, 1))
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  class TTSServer:
64
  def __init__(self, checkpoint=None, full_checkpoint=None, gemma_root=None,
65
  device="cuda", dtype="bf16", compile_model=True, bnb_4bit=True):
@@ -177,12 +205,23 @@ class TTSServer:
177
 
178
  @torch.inference_mode()
179
  def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
180
- duration_multiplier=1.1, seed=42, ref_duration=10.0):
181
- """Generate audio. Returns (waveform_path, duration_seconds)."""
 
 
 
 
 
 
 
 
182
  t_total = time.time()
183
 
184
- # Duration + target shape
185
- gen_dur = estimate_duration(prompt, duration_multiplier)
 
 
 
186
  fps = 25.0
187
  n_frames = int(round(gen_dur * fps)) + 1
188
  n_frames = ((n_frames - 1 + 4) // 8) * 8 + 1
@@ -231,10 +270,13 @@ class TTSServer:
231
  logging.info(f"Prompt: {time.time()-t0:.2f}s")
232
 
233
  # Denoiser
 
 
 
234
  guider = MultiModalGuider(
235
  params=MultiModalGuiderParams(
236
  cfg_scale=cfg_scale, stg_scale=stg_scale,
237
- stg_blocks=[29], rescale_scale=0.0, modality_scale=1.0,
238
  ),
239
  negative_context=a_ctx_neg,
240
  )
 
60
  return max(3.0, round(base * multiplier, 1))
61
 
62
 
63
+ def auto_rescale_for_cfg(cfg: float) -> float:
64
+ """CFG-aware std-rescale schedule that prevents output clipping at high cfg.
65
+
66
+ The CFG formula `pred = cond + (cfg-1)*(cond - uncond)` makes pred.std()
67
+ grow roughly linearly with cfg, which the audio VAE+vocoder render as
68
+ progressively louder waveforms. By cfg≈3 the output starts hard-clipping
69
+ at 0 dBFS — and clipped information is unrecoverable in post.
70
+
71
+ Empirical sweep on the blues prompt with the back-porch-boogie ref
72
+ (rescale_scale needed for ≥1 dB peak headroom):
73
+ cfg=2.5 → 0.2 ; cfg=3 → 0.6 ; cfg=4 → 0.8 ; cfg=5–8 → 0.8 ; cfg=10 → 1.0
74
+
75
+ Piecewise-linear fit through those points; returns 0 below cfg=2 (no CFG
76
+ even applied at cfg=1), plateaus at 0.8 between cfg=4 and cfg=8 to
77
+ preserve the "extra punch" of high-CFG generations, and ramps to 1.0 by
78
+ cfg=10.
79
+ """
80
+ if cfg <= 2.0:
81
+ return 0.0
82
+ if cfg <= 3.0:
83
+ return 0.6 * (cfg - 2.0) # 0 → 0.6
84
+ if cfg <= 4.0:
85
+ return 0.6 + 0.2 * (cfg - 3.0) # 0.6 → 0.8
86
+ if cfg <= 8.0:
87
+ return 0.8 # plateau
88
+ return min(1.0, 0.8 + 0.1 * (cfg - 8.0)) # 0.8 → 1.0 at cfg=10
89
+
90
+
91
  class TTSServer:
92
  def __init__(self, checkpoint=None, full_checkpoint=None, gemma_root=None,
93
  device="cuda", dtype="bf16", compile_model=True, bnb_4bit=True):
 
205
 
206
  @torch.inference_mode()
207
  def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
208
+ duration_multiplier=1.1, seed=42, ref_duration=10.0,
209
+ rescale_scale="auto", gen_duration: float = 0.0):
210
+ """Generate audio. Returns (waveform_path, duration_seconds).
211
+
212
+ rescale_scale: latent-side CFG std-rescale that prevents clipping at
213
+ high cfg. Set to "auto" (default) for the cfg-aware schedule, a
214
+ float in [0, 1] for a fixed override, or 0 to disable.
215
+ gen_duration: explicit target duration in seconds. 0 (default) → auto
216
+ from prompt + duration_multiplier; >0 overrides everything else.
217
+ """
218
  t_total = time.time()
219
 
220
+ # Duration + target shape — explicit gen_duration wins over the estimator.
221
+ if gen_duration and gen_duration > 0:
222
+ gen_dur = float(gen_duration)
223
+ else:
224
+ gen_dur = estimate_duration(prompt, duration_multiplier)
225
  fps = 25.0
226
  n_frames = int(round(gen_dur * fps)) + 1
227
  n_frames = ((n_frames - 1 + 4) // 8) * 8 + 1
 
270
  logging.info(f"Prompt: {time.time()-t0:.2f}s")
271
 
272
  # Denoiser
273
+ resc = auto_rescale_for_cfg(cfg_scale) if rescale_scale == "auto" else float(rescale_scale)
274
+ if rescale_scale == "auto":
275
+ logging.info(f"Auto rescale_scale = {resc:.2f} for cfg={cfg_scale}")
276
  guider = MultiModalGuider(
277
  params=MultiModalGuiderParams(
278
  cfg_scale=cfg_scale, stg_scale=stg_scale,
279
+ stg_blocks=[29], rescale_scale=resc, modality_scale=1.0,
280
  ),
281
  negative_context=a_ctx_neg,
282
  )