Spaces:
Running on Zero
Running on Zero
Manmay Nakhashi commited on
Commit ·
5cc51a5
1
Parent(s): e694869
Add explicit Target Duration slider (0–60s) + gen_duration kwarg
Browse files- TTSServer.generate gains gen_duration: float = 0.0 (override estimator)
- Gradio app exposes a 'Target duration (s)' slider (0 = auto)
- inference.py --gen-duration help text expanded
- Auto-rescale and end-of-clip patch already active for long outputs
For music / multi-section scenes set the slider to 20–60 s and the
auto-rescale schedule keeps the output safe at any cfg.
- app.py +14 -6
- src/inference.py +10 -3
- src/inference_server.py +47 -5
app.py
CHANGED
|
@@ -107,7 +107,8 @@ EXAMPLES: list[tuple[str, str, str]] = [
|
|
| 107 |
|
| 108 |
|
| 109 |
@spaces.GPU(duration=120)
|
| 110 |
-
def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float,
|
|
|
|
| 111 |
if not prompt or not prompt.strip():
|
| 112 |
raise gr.Error("Prompt is empty.")
|
| 113 |
t0 = time.time()
|
|
@@ -119,6 +120,7 @@ def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float,
|
|
| 119 |
voice_ref=ref_path,
|
| 120 |
cfg_scale=cfg, stg_scale=stg,
|
| 121 |
duration_multiplier=dur_mult, seed=int(seed),
|
|
|
|
| 122 |
)
|
| 123 |
elapsed = time.time() - t0
|
| 124 |
logging.info(f"Generated in {elapsed:.2f}s -> {output}")
|
|
@@ -159,7 +161,11 @@ with gr.Blocks(
|
|
| 159 |
with gr.Accordion("Inference settings", open=False):
|
| 160 |
cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
|
| 161 |
stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
|
| 162 |
-
dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
seed_input = gr.Number(value=42, label="Seed", precision=0)
|
| 164 |
audio_out = gr.Audio(label="Generated audio", type="filepath")
|
| 165 |
with gr.Accordion("Prompt writing guide", open=False):
|
|
@@ -176,7 +182,8 @@ with gr.Blocks(
|
|
| 176 |
|
| 177 |
gen_btn.click(
|
| 178 |
on_generate,
|
| 179 |
-
inputs=[prompt_box, audio_ref, cfg_slider, stg_slider,
|
|
|
|
| 180 |
outputs=[audio_out],
|
| 181 |
)
|
| 182 |
|
|
@@ -185,15 +192,16 @@ with gr.Blocks(
|
|
| 185 |
gr.Examples(
|
| 186 |
label="🎬 Click any row to generate a sample",
|
| 187 |
examples=[
|
| 188 |
-
[name, prompt, voice_path, 2.5, 1.5, 1.1, 42]
|
| 189 |
for name, voice_path, prompt in EXAMPLES
|
| 190 |
],
|
| 191 |
example_labels=[name for name, _, _ in EXAMPLES],
|
| 192 |
inputs=[gr.Textbox(visible=False, label="Scene"),
|
| 193 |
prompt_box, audio_ref,
|
| 194 |
-
cfg_slider, stg_slider, dur_slider, seed_input],
|
| 195 |
outputs=[audio_out],
|
| 196 |
-
fn=lambda _name, prompt, ref, cfg, stg, dur, seed: on_generate(
|
|
|
|
| 197 |
cache_examples=False,
|
| 198 |
run_on_click=True,
|
| 199 |
examples_per_page=20,
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
@spaces.GPU(duration=120)
|
| 110 |
+
def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float,
|
| 111 |
+
gen_dur: float, seed: int):
|
| 112 |
if not prompt or not prompt.strip():
|
| 113 |
raise gr.Error("Prompt is empty.")
|
| 114 |
t0 = time.time()
|
|
|
|
| 120 |
voice_ref=ref_path,
|
| 121 |
cfg_scale=cfg, stg_scale=stg,
|
| 122 |
duration_multiplier=dur_mult, seed=int(seed),
|
| 123 |
+
gen_duration=float(gen_dur),
|
| 124 |
)
|
| 125 |
elapsed = time.time() - t0
|
| 126 |
logging.info(f"Generated in {elapsed:.2f}s -> {output}")
|
|
|
|
| 161 |
with gr.Accordion("Inference settings", open=False):
|
| 162 |
cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale")
|
| 163 |
stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale")
|
| 164 |
+
dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05,
|
| 165 |
+
label="Duration × (only used when target duration = 0)")
|
| 166 |
+
gen_dur_slider = gr.Slider(0.0, 60.0, value=0.0, step=1.0,
|
| 167 |
+
label="Target duration (s) — 0 = auto from prompt; "
|
| 168 |
+
"set higher (≥20s) for long-form music or scenes")
|
| 169 |
seed_input = gr.Number(value=42, label="Seed", precision=0)
|
| 170 |
audio_out = gr.Audio(label="Generated audio", type="filepath")
|
| 171 |
with gr.Accordion("Prompt writing guide", open=False):
|
|
|
|
| 182 |
|
| 183 |
gen_btn.click(
|
| 184 |
on_generate,
|
| 185 |
+
inputs=[prompt_box, audio_ref, cfg_slider, stg_slider,
|
| 186 |
+
dur_slider, gen_dur_slider, seed_input],
|
| 187 |
outputs=[audio_out],
|
| 188 |
)
|
| 189 |
|
|
|
|
| 192 |
gr.Examples(
|
| 193 |
label="🎬 Click any row to generate a sample",
|
| 194 |
examples=[
|
| 195 |
+
[name, prompt, voice_path, 2.5, 1.5, 1.1, 0.0, 42]
|
| 196 |
for name, voice_path, prompt in EXAMPLES
|
| 197 |
],
|
| 198 |
example_labels=[name for name, _, _ in EXAMPLES],
|
| 199 |
inputs=[gr.Textbox(visible=False, label="Scene"),
|
| 200 |
prompt_box, audio_ref,
|
| 201 |
+
cfg_slider, stg_slider, dur_slider, gen_dur_slider, seed_input],
|
| 202 |
outputs=[audio_out],
|
| 203 |
+
fn=lambda _name, prompt, ref, cfg, stg, dur, gen_dur, seed: on_generate(
|
| 204 |
+
prompt, ref, cfg, stg, dur, gen_dur, seed),
|
| 205 |
cache_examples=False,
|
| 206 |
run_on_click=True,
|
| 207 |
examples_per_page=20,
|
src/inference.py
CHANGED
|
@@ -230,7 +230,10 @@ def parse_args():
|
|
| 230 |
p.add_argument("--output", default="tts_output.wav")
|
| 231 |
|
| 232 |
p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
|
| 233 |
-
p.add_argument("--gen-duration", type=float, default=0.0,
|
|
|
|
|
|
|
|
|
|
| 234 |
p.add_argument("--pad-start", type=float, default=0.0,
|
| 235 |
help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
|
| 236 |
p.add_argument("--speed", type=float, default=1.0)
|
|
@@ -260,7 +263,9 @@ def parse_args():
|
|
| 260 |
p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
|
| 261 |
p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
|
| 262 |
p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
|
| 263 |
-
p.add_argument("--rescale-scale", type=float, default=None,
|
|
|
|
|
|
|
| 264 |
p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
|
| 265 |
p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
|
| 266 |
p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
|
|
@@ -324,7 +329,9 @@ def main():
|
|
| 324 |
if args.stg_scale is None:
|
| 325 |
args.stg_scale = 0.0 if is_distilled else 1.0
|
| 326 |
if args.rescale_scale is None:
|
| 327 |
-
|
|
|
|
|
|
|
| 328 |
if args.modality_scale is None:
|
| 329 |
args.modality_scale = 1.0 if is_distilled else 3.0
|
| 330 |
if args.fps is None:
|
|
|
|
| 230 |
p.add_argument("--output", default="tts_output.wav")
|
| 231 |
|
| 232 |
p.add_argument("--ref-duration", type=float, default=10.0, help="Seconds of voice reference to use")
|
| 233 |
+
p.add_argument("--gen-duration", type=float, default=0.0,
|
| 234 |
+
help="Target output duration in seconds (0 = auto from prompt + multiplier). "
|
| 235 |
+
"Set explicitly for long-form prompts (e.g. --gen-duration 30 for music). "
|
| 236 |
+
"Outputs >20.5s automatically engage the end-of-clip silence-prior patch.")
|
| 237 |
p.add_argument("--pad-start", type=float, default=0.0,
|
| 238 |
help="Prepend N seconds of silent padding, trimmed after decode (use 0 for clean starts)")
|
| 239 |
p.add_argument("--speed", type=float, default=1.0)
|
|
|
|
| 263 |
p.add_argument("--cfg-scale", type=float, default=None, help="CFG scale (auto: 1.0 distilled, 7.0 dev)")
|
| 264 |
p.add_argument("--stg-scale", type=float, default=None, help="STG scale (auto: 0.0 distilled, 1.0 dev)")
|
| 265 |
p.add_argument("--stg-block", type=int, default=29, help="Block index for STG perturbation")
|
| 266 |
+
p.add_argument("--rescale-scale", type=float, default=None,
|
| 267 |
+
help="Latent CFG std-rescale (default auto: cfg-aware schedule that prevents "
|
| 268 |
+
"output clipping at high cfg; pass any float in [0,1] to override).")
|
| 269 |
p.add_argument("--modality-scale", type=float, default=None, help="Modality (auto: 1.0 distilled, 3.0 dev)")
|
| 270 |
p.add_argument("--cfg-clamp", type=float, default=0.0, help="Clamp guided pred std to N * cond std (0=disabled)")
|
| 271 |
p.add_argument("--steps", type=int, default=None, help="Override steps (auto: distilled sigmas / 30 dev)")
|
|
|
|
| 329 |
if args.stg_scale is None:
|
| 330 |
args.stg_scale = 0.0 if is_distilled else 1.0
|
| 331 |
if args.rescale_scale is None:
|
| 332 |
+
# Auto cfg-aware rescale: imported from inference_server to keep one source of truth.
|
| 333 |
+
from inference_server import auto_rescale_for_cfg
|
| 334 |
+
args.rescale_scale = 0.0 if is_distilled else auto_rescale_for_cfg(args.cfg_scale)
|
| 335 |
if args.modality_scale is None:
|
| 336 |
args.modality_scale = 1.0 if is_distilled else 3.0
|
| 337 |
if args.fps is None:
|
src/inference_server.py
CHANGED
|
@@ -60,6 +60,34 @@ def estimate_duration(prompt, multiplier=1.1):
|
|
| 60 |
return max(3.0, round(base * multiplier, 1))
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
class TTSServer:
|
| 64 |
def __init__(self, checkpoint=None, full_checkpoint=None, gemma_root=None,
|
| 65 |
device="cuda", dtype="bf16", compile_model=True, bnb_4bit=True):
|
|
@@ -177,12 +205,23 @@ class TTSServer:
|
|
| 177 |
|
| 178 |
@torch.inference_mode()
|
| 179 |
def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
|
| 180 |
-
duration_multiplier=1.1, seed=42, ref_duration=10.0
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
t_total = time.time()
|
| 183 |
|
| 184 |
-
# Duration + target shape
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
| 186 |
fps = 25.0
|
| 187 |
n_frames = int(round(gen_dur * fps)) + 1
|
| 188 |
n_frames = ((n_frames - 1 + 4) // 8) * 8 + 1
|
|
@@ -231,10 +270,13 @@ class TTSServer:
|
|
| 231 |
logging.info(f"Prompt: {time.time()-t0:.2f}s")
|
| 232 |
|
| 233 |
# Denoiser
|
|
|
|
|
|
|
|
|
|
| 234 |
guider = MultiModalGuider(
|
| 235 |
params=MultiModalGuiderParams(
|
| 236 |
cfg_scale=cfg_scale, stg_scale=stg_scale,
|
| 237 |
-
stg_blocks=[29], rescale_scale=
|
| 238 |
),
|
| 239 |
negative_context=a_ctx_neg,
|
| 240 |
)
|
|
|
|
| 60 |
return max(3.0, round(base * multiplier, 1))
|
| 61 |
|
| 62 |
|
| 63 |
+
def auto_rescale_for_cfg(cfg: float) -> float:
|
| 64 |
+
"""CFG-aware std-rescale schedule that prevents output clipping at high cfg.
|
| 65 |
+
|
| 66 |
+
The CFG formula `pred = cond + (cfg-1)*(cond - uncond)` makes pred.std()
|
| 67 |
+
grow roughly linearly with cfg, which the audio VAE+vocoder render as
|
| 68 |
+
progressively louder waveforms. By cfg≈3 the output starts hard-clipping
|
| 69 |
+
at 0 dBFS — and clipped information is unrecoverable in post.
|
| 70 |
+
|
| 71 |
+
Empirical sweep on the blues prompt with the back-porch-boogie ref
|
| 72 |
+
(rescale_scale needed for ≥1 dB peak headroom):
|
| 73 |
+
cfg=2.5 → 0.2 ; cfg=3 → 0.6 ; cfg=4 → 0.8 ; cfg=5–8 → 0.8 ; cfg=10 → 1.0
|
| 74 |
+
|
| 75 |
+
Piecewise-linear fit through those points; returns 0 below cfg=2 (no CFG
|
| 76 |
+
even applied at cfg=1), plateaus at 0.8 between cfg=4 and cfg=8 to
|
| 77 |
+
preserve the "extra punch" of high-CFG generations, and ramps to 1.0 by
|
| 78 |
+
cfg=10.
|
| 79 |
+
"""
|
| 80 |
+
if cfg <= 2.0:
|
| 81 |
+
return 0.0
|
| 82 |
+
if cfg <= 3.0:
|
| 83 |
+
return 0.6 * (cfg - 2.0) # 0 → 0.6
|
| 84 |
+
if cfg <= 4.0:
|
| 85 |
+
return 0.6 + 0.2 * (cfg - 3.0) # 0.6 → 0.8
|
| 86 |
+
if cfg <= 8.0:
|
| 87 |
+
return 0.8 # plateau
|
| 88 |
+
return min(1.0, 0.8 + 0.1 * (cfg - 8.0)) # 0.8 → 1.0 at cfg=10
|
| 89 |
+
|
| 90 |
+
|
| 91 |
class TTSServer:
|
| 92 |
def __init__(self, checkpoint=None, full_checkpoint=None, gemma_root=None,
|
| 93 |
device="cuda", dtype="bf16", compile_model=True, bnb_4bit=True):
|
|
|
|
| 205 |
|
| 206 |
@torch.inference_mode()
|
| 207 |
def generate(self, prompt, voice_ref=None, cfg_scale=2.5, stg_scale=1.5,
|
| 208 |
+
duration_multiplier=1.1, seed=42, ref_duration=10.0,
|
| 209 |
+
rescale_scale="auto", gen_duration: float = 0.0):
|
| 210 |
+
"""Generate audio. Returns (waveform_path, duration_seconds).
|
| 211 |
+
|
| 212 |
+
rescale_scale: latent-side CFG std-rescale that prevents clipping at
|
| 213 |
+
high cfg. Set to "auto" (default) for the cfg-aware schedule, a
|
| 214 |
+
float in [0, 1] for a fixed override, or 0 to disable.
|
| 215 |
+
gen_duration: explicit target duration in seconds. 0 (default) → auto
|
| 216 |
+
from prompt + duration_multiplier; >0 overrides everything else.
|
| 217 |
+
"""
|
| 218 |
t_total = time.time()
|
| 219 |
|
| 220 |
+
# Duration + target shape — explicit gen_duration wins over the estimator.
|
| 221 |
+
if gen_duration and gen_duration > 0:
|
| 222 |
+
gen_dur = float(gen_duration)
|
| 223 |
+
else:
|
| 224 |
+
gen_dur = estimate_duration(prompt, duration_multiplier)
|
| 225 |
fps = 25.0
|
| 226 |
n_frames = int(round(gen_dur * fps)) + 1
|
| 227 |
n_frames = ((n_frames - 1 + 4) // 8) * 8 + 1
|
|
|
|
| 270 |
logging.info(f"Prompt: {time.time()-t0:.2f}s")
|
| 271 |
|
| 272 |
# Denoiser
|
| 273 |
+
resc = auto_rescale_for_cfg(cfg_scale) if rescale_scale == "auto" else float(rescale_scale)
|
| 274 |
+
if rescale_scale == "auto":
|
| 275 |
+
logging.info(f"Auto rescale_scale = {resc:.2f} for cfg={cfg_scale}")
|
| 276 |
guider = MultiModalGuider(
|
| 277 |
params=MultiModalGuiderParams(
|
| 278 |
cfg_scale=cfg_scale, stg_scale=stg_scale,
|
| 279 |
+
stg_blocks=[29], rescale_scale=resc, modality_scale=1.0,
|
| 280 |
),
|
| 281 |
negative_context=a_ctx_neg,
|
| 282 |
)
|