# coda/engine.py (last commit 1cca94c by blackboxanalytics:
# "Harden best-of-N selection: reject dropouts and squashed takes")
"""engine.py — Stable Audio 3 Small Music continuation core.

CODA's job is one thing done well: take a short, unfinished clip and continue it
into a finished-sounding track in the same key, tempo and feel. SA3 does that in
a SINGLE inpainting call per draw. stable_audio_tools'
`generate_diffusion_cond_inpaint` is a native audio-inpainting diffusion
sampler: place (the tail of) the user's clip at the front of the buffer, mask
the region after it, and the model fills the masked region conditioned on the
kept audio. The result is true long-form continuation at 44.1 kHz stereo, with
no multi-pass chaining and no energy guards.

Candidate selection (the deployed quality fix): the lab verified a take with a
pinned seed on torch 2.7.1, but ZeroGPU runs a different torch (2.8–2.11), so
the same seed no longer reproduces it; it just freezes one arbitrary draw, which
on the deployed build was the bad "sporadic loud synth noise" take. Instead of
trusting a magic seed, CODA draws up to DEFAULT_CANDIDATES takes and keeps the
cleanest by a cheap artifact score that penalizes loud random bursts, silence
collapse, mid-tail dropouts and squashed, transient-less takes. A wall-clock
budget (GPU_BUDGET_SECONDS) plus early-accept (EARLY_ACCEPT_SCORE) keep this
inside the ZeroGPU window: at most DEFAULT_CANDIDATES - 1 extra fast 8-step
draws, and usually fewer.
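
A minimal sketch of that selection loop follows. It assumes hypothetical
`draw(seed)` and `score(candidate)` callables, where a lower score is better.
It mirrors the early-accept and budget checks in `continue_audio` but is not
the module's code:

```python
import time


def best_of_n(draw, score, seeds, early_accept=3.5, budget_s=85.0):
    # Keep the lowest-scoring candidate; stop early on a clean draw or once
    # the wall-clock budget is spent (a draw already running still finishes).
    best, best_score, best_seed = None, float("inf"), None
    t0 = time.monotonic()
    for seed in seeds:
        cand = draw(seed)
        s = score(cand)
        if best is None or s < best_score:
            best, best_score, best_seed = cand, s, seed
        if best_score <= early_accept:
            break  # clean enough: skip the remaining draws
        if time.monotonic() - t0 > budget_s:
            break  # budget spent: keep the best so far
    return best, best_score, best_seed


fake_scores = {1: 9.0, 2: 5.0, 3: 2.0, 4: 1.0}
drawn = []


def fake_draw(seed):
    drawn.append(seed)
    return seed


print(best_of_n(fake_draw, fake_scores.get, [1, 2, 3, 4]))
# (3, 2.0, 3): seed 3 clears the 3.5 early-accept, so seed 4 is never drawn
```

The defaults 3.5 and 85.0 copy EARLY_ACCEPT_SCORE and GPU_BUDGET_SECONDS.
The sketch uses `time.monotonic()` because that clock cannot jump backwards.
The module times its own loop with `time.time()`.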

This module is the whole generation core. It returns ONLY the newly generated
tail plus the source length in seconds; `stitch.py` joins that tail onto the
user's *pristine* original so the real recording (and any vocals) plays
untouched up to the seam.
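
To illustrate that contract, here is a hypothetical minimal join. It is not
stitch.py's actual code, which is not shown here. It assumes both arrays are
(2, n) float32 at 44.1 kHz. The original's samples are kept exactly, and the
tail is appended at the seam with no crossfade:

```python
import numpy as np

SR = 44100


def join_at_seam(original, tail):
    # original: the user's pristine (2, N) clip; tail: the (2, M) continuation.
    # Concatenate along time so every original sample survives untouched.
    assert original.shape[0] == tail.shape[0] == 2
    return np.concatenate([original, tail], axis=-1)


original = np.random.default_rng(0).standard_normal((2, 3 * SR)).astype(np.float32)
tail = np.zeros((2, 2 * SR), dtype=np.float32)
track = join_at_seam(original, tail)
print(track.shape, track.shape[-1] / SR)  # (2, 220500) 5.0
```

Any seam smoothing the real stitch applies is outside this sketch. This
module only notes that stitch fades the ending.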

Bounded lead-in (the deployed-bug fix): SA3 only needs a short run-up to know
where the song is going. We therefore condition on at most MAX_LEAD_SECONDS of
the clip's TAIL, not the whole clip. Feeding a long clip (e.g. 100 s) into the
buffer and masking only a few seconds makes the 8-step distilled sampler
collapse to near-silence in that tiny window — the bug that shipped. A bounded
lead keeps the masked (generated) region substantial and healthy, and because
stitch rejoins the tail onto the full pristine original, the listener still
hears their entire clip before the seam.
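
The tail slice itself is one line of indexing. Below is a NumPy sketch; the
module does the same on a torch tensor. `tail_lead` is a hypothetical helper
name:

```python
import numpy as np

SR = 44100
MAX_LEAD_SECONDS = 30


def tail_lead(source, source_seconds, sr=SR):
    # Keep only the last min(source_seconds, MAX_LEAD_SECONDS) of a (2, N) clip.
    lead = min(source_seconds, MAX_LEAD_SECONDS)
    n = min(int(round(lead * sr)), source.shape[-1])
    if n <= 0:
        return source[:, :0]  # guard: source[:, -0:] would be the WHOLE clip
    return source[:, -n:]


clip = np.zeros((2, 100 * SR), dtype=np.float32)  # a 100 s clip
clip[:, -1] = 1.0                                 # mark the clip's last sample
lead_audio = tail_lead(clip, clip.shape[-1] / SR)
print(lead_audio.shape)  # (2, 1323000): 30 s, ending on the marked sample
```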

Mask convention (verified against the installed library source):

    inpaint_mask = ones(buffer); inpaint_mask[start:end] = 0

-> 1 = keep the input audio, 0 = generate. We place `lead` seconds of source
at the front and mask [lead, lead+new], so SA3 keeps the lead and generates a
fresh `new`-second tail that continues from the clip's end.
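
Below is a NumPy sketch of that convention, with the ones/zeros placement made
explicit. It is illustrative only. The real mask is built inside
stable_audio_tools from `inpaint_mask_start_seconds` and
`inpaint_mask_end_seconds`, and its buffer length and dtype may differ:

```python
import numpy as np

SR = 44100


def inpaint_mask(buffer_s, start_s, end_s, sr=SR):
    # 1 = keep the input audio, 0 = generate.
    mask = np.ones(int(round(buffer_s * sr)), dtype=np.float32)
    mask[int(round(start_s * sr)):int(round(end_s * sr))] = 0.0
    return mask


# A 30 s lead followed by 20 s of new audio: keep [0, 30) s, generate [30, 50) s.
mask = inpaint_mask(buffer_s=50.0, start_s=30.0, end_s=50.0)
print(mask[:30 * SR].min(), mask[30 * SR:].max())  # 1.0 0.0
```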
"""

import contextlib
import time

import numpy as np
import torch

# --- H200 (sm_90) fp16 numerical-stability hardening -------------------------
# CODA sounds perfect on local Blackwell (sm_120) but produced "sporadic loud
# bursts" on ZeroGPU's H200. Root cause is fp16 numerical instability on H200,
# not a torch version: (1) torch 2.8 changed cuDNN SDPA backend selection for
# H200 and the cuDNN/mem-efficient fp16 SDPA kernels have known NaN/garbage bugs
# (pytorch#139298/#166211/#124877/#112577) — exactly the "loud burst" signature;
# (2) TF32 matmul accumulation adds error that compounds through the 8-step
# distilled sampler. We pin the math (reference) SDPA backend around generation
# and force full-precision matmul accumulation. Both are no-ops on CPU.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

try:
    from torch.nn.attention import SDPBackend, sdpa_kernel as _sdpa_kernel
except Exception:  # very old/unexpected torch
    SDPBackend = None
    _sdpa_kernel = None


def _sdpa_math():
    """Context manager that forces F.scaled_dot_product_attention onto the MATH
    (reference) backend for the wrapped call. SA3's transformer calls SDPA with
    no backend pin and casts q/k/v to fp16; on H200 torch may route that to a
    cuDNN/mem-efficient kernel with the known fp16 NaN bug. MATH is the stable
    reference path. Returns a fresh context each call; a no-op if the API is
    missing (older torch) or on CPU."""
    if _sdpa_kernel is None:
        return contextlib.nullcontext()
    return _sdpa_kernel(SDPBackend.MATH)


MODEL_ID = "stabilityai/stable-audio-3-small-music"
SR = 44100                 # SA3 native sample rate (model_config: sample_rate)
STEPS = 8                  # SA3 Small is an 8-step adversarially-distilled model.
                           # It was tuned for 8-step pingpong; pushing it to 16/25
                           # steps is OFF its distilled regime and tends to ADD
                           # artifacts, not remove them. We stay at 8 and fix
                           # quality by picking the best of a few draws instead.
SAMPLER = "pingpong"       # the sampler the distilled model was tuned for
DEFAULT_CFG = 1.0          # distilled-model guidance; the prompt still conditions
                           # at 1.0 (CFG amplification off, conditional path on)

# Best-of-N: with a random seed each draw differs, so we generate a few and keep
# the cleanest. Bounded so it never blows the ZeroGPU window.
DEFAULT_CANDIDATES = 5     # how many draws to consider when no seed is pinned.
                           # Raised from 3: the quality bar is the BEST draw, not
                           # the first clean-ish one, and a fast H200 draw is
                           # cheap. Early-accept + the GPU budget still short-
                           # circuit when an early draw is already great.
GPU_BUDGET_SECONDS = 85.0  # no new draw starts once this much wall-clock has
                           # elapsed (a draw already running still finishes);
                           # the @spaces.GPU window is 120 s, so leave slack
EARLY_ACCEPT_SCORE = 3.5   # a draw this clean is taken immediately, no re-draw.
                           # Tightened from 4.0 so a merely-okay draw doesn't
                           # short-circuit the search for a genuinely good one.
DROPOUT_FLOOR = 0.12       # quietest sustained 0.2 s below this fraction of the
                           # tail median counts as a mid-tail hole (re-draw it)
CREST_FLOOR = 4.0          # peak/RMS below this is a squashed, transient-less
                           # wash (real music here is ~6-8); penalize it
CREST_SCALE = 1.5          # how hard a collapsed crest is penalized
MAX_TOTAL_SECONDS = 120    # SA3 Small duration cap (sample_size / sample_rate)
MIN_NEW_SECONDS = 5        # below this a "continuation" isn't worth a GPU call
MAX_LEAD_SECONDS = 30      # how much of the clip's TAIL to feed SA3 as run-up.
                           # SA3 generates a healthy continuation from a bounded
                           # lead-in; keeping a very long source in the buffer
                           # and masking only a few seconds makes the distilled
                           # sampler produce near-silence. 30 s is inside the
                           # model's healthy range (the verified lab takes used
                           # ~30 s leads) and the splice restores the full clip
                           # anyway.

_model = None
_model_config = None
_sample_size = None
_on_device = None  # which device the weights currently live on


def _device():
    return "cuda" if torch.cuda.is_available() else "cpu"


def preload():
    """Load model + autoencoder + T5Gemma conditioner into CPU RAM at process
    start. On ZeroGPU the per-call GPU window is the scarce resource, so weights
    must come off disk at boot, not inside the window. The CUDA placement + fp16
    cast is deferred to the first `continue_audio` call (i.e. the @spaces.GPU
    window), matching how Stability's own Space defers it."""
    global _model, _model_config, _sample_size
    if _model is None:
        from stable_audio_tools import get_pretrained_model
        _model, _model_config = get_pretrained_model(MODEL_ID)
        _sample_size = int(_model_config["sample_size"])
        _model.eval()
        print(f"[coda] preload: SA3 resident "
              f"(sr={_model_config['sample_rate']}, "
              f"sample_size={_sample_size} = "
              f"{_sample_size / int(_model_config['sample_rate']):.0f}s)",
              flush=True)
    return _model, _model_config


def _ensure_on_device():
    """Ensure weights are on the GPU with the right per-component dtype. Called
    inside the @spaces.GPU window on every generation. `.to()` is a cheap no-op
    when already placed, so we re-ensure each call rather than caching device
    state — that stays correct even if ZeroGPU detaches the GPU between calls.

    Dtype split (the H200 quality fix): the DiT runs in fp16 (fast, and Stability
    loads the model fp16), but the autoencoder/pretransform runs in fp32. SA3's
    Oobleck/DAC decoder uses Snake1d activations — x + (1/(beta+1e-9))*sin(a*x)^2
    — whose reciprocal*sin^2 term can exceed fp16's ~65504 ceiling, giving
    inf/NaN -> loud broadband noise on decode. fp32 there has ample headroom.
    sampling.py casts the sampled latents to the pretransform's dtype right
    before `decode`, so an fp32 pretransform makes the WHOLE encode+decode path
    fp32 with no library change; only the DiT stays fp16. All fp16/fp32 here is
    CUDA-only; on CPU the model stays fp32 throughout."""
    global _model, _on_device
    dev = _device()
    _model = _model.to(dev)
    if dev == "cuda":
        _model = _model.to(torch.float16)
        # Keep the autoencoder in fp32 (Snake1d fp16 overflow guard). The
        # blanket fp16 cast above also reaches the pretransform, so its weights
        # round-trip through fp16 on every call: casting back restores fp32
        # activation headroom (what the guard needs), not the rounded-off bits.
        if getattr(_model, "pretransform", None) is not None:
            _model.pretransform.to(torch.float32)
    _on_device = dev
    return _model


def _load_source(clip_path):
    """Load the clip as stereo float32 @44.1k as a (2, N) tensor. SA3's
    autoencoder is stereo; `prepare_audio` inside the sampler will pad/crop to
    the buffer length and place this at the FRONT (PadCrop, randomize=False)."""
    import librosa
    y, _ = librosa.load(clip_path, sr=SR, mono=False)
    y = np.asarray(y, dtype=np.float32)
    if y.ndim == 1:
        y = np.stack([y, y])  # mono -> stereo
    elif y.shape[0] > 2:
        y = y[:2]
    return torch.from_numpy(np.ascontiguousarray(y))


#: longest clip that still leaves room for MIN_NEW of continuation under the cap
MAX_SOURCE_SECONDS = MAX_TOTAL_SECONDS - MIN_NEW_SECONDS


def plan_continuation(source_seconds, total_seconds):
    """Pure helper (unit-testable, no model): turn a (source, requested-total)
    pair into the SA3 generation buffer and return
    (lead, new_seconds, buffer_seconds). `total` is first capped at
    MAX_TOTAL_SECONDS.

    - `lead`           : seconds of the clip's TAIL used as run-up context,
                         capped at MAX_LEAD_SECONDS so a long clip can't drown
                         the masked region.
    - `new_seconds`    : seconds of fresh audio to generate. We extend to the
                         requested finished length (`total - source`), floored
                         at MIN_NEW_SECONDS so every call earns its GPU time,
                         and bounded so the buffer (lead + new) never exceeds
                         SA3's MAX_TOTAL_SECONDS cap.
    - `buffer_seconds` : lead + new, i.e. the full generation buffer. The mask
                         runs [lead, buffer]; buffer > lead always, so it never
                         inverts.

    Raises ValueError only for a clip longer than MAX_SOURCE_SECONDS — at that
    point it's a full track, not an unfinished clip to continue.
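
    Worked example: a standalone copy of the arithmetic above. `plan` is a
    hypothetical name; it duplicates the logic and constants so it runs
    without the model stack. A 100 s clip asked to reach 180 s is first
    capped at a 120 s total, so it gets a 30 s lead and 20 s of new audio in a
    50 s buffer. A 112 s clip asked for 115 s hits the MIN_NEW_SECONDS floor
    and gets 5 s.

```python
MAX_TOTAL_SECONDS = 120
MIN_NEW_SECONDS = 5
MAX_LEAD_SECONDS = 30
MAX_SOURCE_SECONDS = MAX_TOTAL_SECONDS - MIN_NEW_SECONDS


def plan(source_seconds, total_seconds):
    # Same steps as plan_continuation: reject full tracks, cap the total,
    # cap the lead, floor the new audio, and keep lead + new under the cap.
    if source_seconds > MAX_SOURCE_SECONDS:
        raise ValueError("clip too long to continue")
    total_seconds = min(float(total_seconds), MAX_TOTAL_SECONDS)
    lead = min(float(source_seconds), MAX_LEAD_SECONDS)
    new = max(total_seconds - source_seconds, MIN_NEW_SECONDS)
    new = min(new, MAX_TOTAL_SECONDS - lead)
    return lead, new, lead + new


print(plan(100, 180))  # lead 30 s, new 20 s, buffer 50 s
print(plan(112, 115))  # lead 30 s, new 5 s (the floor), buffer 35 s
```
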
    """
    source_seconds = float(source_seconds)
    total_seconds = float(total_seconds)
    if source_seconds > MAX_SOURCE_SECONDS:
        raise ValueError(
            f"clip is {source_seconds:.0f}s — that's a finished-length track, "
            f"not an unfinished clip. CODA continues clips up to "
            f"{MAX_SOURCE_SECONDS:.0f}s; trim it shorter and re-upload.")
    total_seconds = min(total_seconds, MAX_TOTAL_SECONDS)
    lead = min(source_seconds, MAX_LEAD_SECONDS)
    # extend to the requested finished length; floor at MIN_NEW, and never let
    # lead + new exceed the buffer cap.
    new_seconds = max(total_seconds - source_seconds, MIN_NEW_SECONDS)
    new_seconds = min(new_seconds, MAX_TOTAL_SECONDS - lead)
    buffer_seconds = lead + new_seconds
    return lead, new_seconds, buffer_seconds


def _tail_artifact_score(tail, sr=SR):
    """Lower is better. A blind, ear-free quality score for a generated tail,
    used to pick the cleanest of several candidate draws.

    It targets the four ways an SA3 draw goes bad:

    * "sporadic loud random synth noises" — even a FEW short windows far louder
      than the body push the loudest window way above the median. (After the
      whole-buffer peak-normalize, a burst that set the peak crushes the body,
      making the gap larger still.) Sustained dynamics rarely make any single
      50 ms window many times the median, so musical loudness doesn't trip it.
    * silence collapse — a near-silent WHOLE tail is caught by the loudness
      floor below (it keys off the tail's overall RMS).
    * mid-tail dropout — a brief near-silent HOLE inside an otherwise healthy
      tail. This is the gap the first two terms miss: overall RMS stays high (so
      the silence floor never fires) and max/median stays low (so spikiness
      never fires), yet a listener plainly hears the music cut out for a beat.
      We detect it as a SUSTAINED quiet stretch — the quietest ~0.2 s envelope
      falling well below the median.
    * dynamics/transient collapse — a draw can be perfectly tonal and steady yet
      sound DULL and lifeless: its transients are smeared, so there's no attack,
      just a wall of mush. Flatness and loudness checks all read "clean". It
      shows up as a collapsed crest factor (peak/RMS): real music here sits at
      crest ~6-8, a squashed draw falls to ~2-3. We penalize a low crest so
      best-of-N prefers the punchy draw over the mushy one.

    Score = max/median + silence + dropout + dynamics penalties.
    Computed on a mono mix over short (~50 ms) windows. A steady signal scores
    ~1 on the max/median term; loud bursts, a crushed body, a mid-tail hole, or
    a smeared, transient-less wash (low crest) all push the total up.
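
    The sketch below computes two of these terms (max/median spikiness and the
    crest penalty) on synthetic signals. The helper names and test signals are
    hypothetical, not SA3 output. A steady tone scores ~1 on spikiness but,
    having no transients, collects the crest penalty. A re-struck, fast-
    decaying tone clears the crest floor.

```python
import numpy as np

SR = 44100
CREST_FLOOR, CREST_SCALE = 4.0, 1.5  # same values as this module's constants


def window_rms(mono, sr=SR):
    # RMS of consecutive ~50 ms windows (a trailing partial window is dropped)
    win = max(1, int(0.05 * sr))
    n = mono.size // win
    return np.sqrt(np.mean(mono[:n * win].reshape(n, win) ** 2, axis=1) + 1e-12)


def spikiness(mono):
    # loudest window relative to the median window
    e = window_rms(mono)
    return float(e.max() / (np.median(e) + 1e-9))


def crest_penalty(mono):
    # penalize only a crest factor (peak / RMS) below CREST_FLOOR
    peak = float(np.abs(mono).max())
    rms = float(np.sqrt(np.mean(mono ** 2)) + 1e-12)
    return max(0.0, CREST_FLOOR - peak / rms) * CREST_SCALE


t = np.arange(2 * SR) / SR
steady = 0.5 * np.sin(2 * np.pi * 220 * t)   # constant-level 220 Hz tone
burst = steady.copy()
burst[SR:SR + int(0.05 * SR)] *= 9.0         # one 50 ms window 9x louder
# the same tone re-struck every 0.5 s with a fast exponential decay
plucks = np.sin(2 * np.pi * 220 * t) * np.exp(-(t % 0.5) / 0.05)

print(round(spikiness(steady), 2), round(spikiness(burst), 2))          # 1.0 9.0
print(round(crest_penalty(steady), 2), round(crest_penalty(plucks), 2))  # 3.88 0.0
```
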
    """
    mono = tail.mean(axis=0) if tail.ndim == 2 else np.asarray(tail)
    mono = np.asarray(mono, dtype=np.float64)
    win = max(1, int(0.05 * sr))
    if mono.size < win * 4:
        return float("inf")  # too short to judge — avoid it
    n = mono.size // win
    energies = np.sqrt(
        np.mean(mono[:n * win].reshape(n, win) ** 2, axis=1) + 1e-12)
    median = float(np.median(energies)) + 1e-9
    loudest = float(np.max(energies))
    spikiness = loudest / median
    overall = float(np.sqrt(np.mean(mono ** 2)) + 1e-12)
    silence_penalty = 0.0 if overall > 0.02 else (0.02 - overall) * 200.0
    # mid-tail dropout: smooth the window energies over ~0.2 s and find how far
    # the quietest SUSTAINED stretch falls below the median. Exclude the final
    # 0.5 s so a natural ending taper (which stitch fades anyway) isn't punished.
    # A clean tail's quietest 0.2 s sits ~0.15-0.4x the median -> no penalty. A
    # real near-silent hole drops far lower; below ~0.05x the penalty alone
    # (> 2.5, capped at 8) lifts even a spikiness-1 draw past the early-accept
    # score and forces another draw, so best-of-N rolls past the glitch.
    dropout_penalty = 0.0
    smooth = np.convolve(energies, np.ones(4) / 4.0, mode="valid")
    guard = int(0.5 / 0.05)  # last 0.5 s of windows
    body = smooth[:-guard] if smooth.size > guard + 4 else smooth
    if body.size:
        dropout = float(np.min(body)) / median
        dropout_penalty = min(8.0, max(0.0, DROPOUT_FLOOR / max(dropout, 1e-3)
                                       - 1.0) * 2.0)
    # dynamics/transient collapse: crest = peak / RMS. The tail is peak-normalized
    # to ~1.0, so this is essentially 1/RMS — a squashed, attack-less wash reads
    # high RMS (low crest); a punchy, dynamic take reads low RMS (high crest).
    # Penalize only a clearly collapsed crest, so we never punish a naturally
    # dynamic draw.
    peak = float(np.abs(mono).max())
    crest = peak / overall
    crest_penalty = max(0.0, CREST_FLOOR - crest) * CREST_SCALE
    return spikiness + silence_penalty + dropout_penalty + crest_penalty


def continue_audio(clip_path, total_seconds, prompt="", cfg_scale=DEFAULT_CFG,
                   seed=-1, candidates=None, progress=None):
    """Continue `clip_path` up to `total_seconds` with SA3 inpainting.

    With the default `seed` (< 0) this draws up to `candidates` (default
    DEFAULT_CANDIDATES) SA3 inpaint takes and returns the cleanest by
    `_tail_artifact_score`, early-accepting a clean draw and respecting a GPU
    wall-clock budget. Pin `seed` >= 0 for a single deterministic draw
    (debug/repro).

    Returns (new_tail, source_seconds, SR) where:
        new_tail       : (2, M) float32 @44.1k — ONLY the generated region
                         [source_end, source_end + new_seconds], which normally
                         ends at `total_seconds` (capped at MAX_TOTAL_SECONDS)
                         and runs longer only when the MIN_NEW_SECONDS floor
                         applies. Peak-normalized to <= 1.0.
        source_seconds : the clip's true length (the splice boundary, in seconds)
        SR             : 44100

    `progress(stage_name)` is called (best-effort) at each stage so the UI can
    paint a live status. Progress is stage-based: "reading", "composing" and
    "finalizing".
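
    The seed policy described above can be sketched on its own. `plan_seeds`
    is a hypothetical name, and its logic is copied from the body of this
    function. A pinned seed yields one draw. Otherwise a random base seed is
    strided by 0x9E3779B1, the golden-ratio hashing constant (about
    2**32 / phi), modulo 2**31 - 1. This keeps every seed inside signed 32-bit
    range.

```python
import numpy as np

DEFAULT_CANDIDATES = 5
STRIDE = 0x9E3779B1   # golden-ratio stride, as in continue_audio
MOD = 2 ** 31 - 1     # keeps every seed inside the signed 32-bit range


def plan_seeds(seed=-1, candidates=None):
    # A pinned seed (>= 0) gives exactly one deterministic draw.
    if seed is not None and seed >= 0:
        return [int(seed)]
    n = max(1, int(candidates or DEFAULT_CANDIDATES))
    base = int(np.random.randint(0, MOD))
    # Stride the candidates far apart so the draws are genuinely different.
    return [(base + i * STRIDE) % MOD for i in range(n)]


print(plan_seeds(seed=1234))  # [1234]
print(len(plan_seeds()))      # 5
```

    With five candidates the seeds are always distinct. Each pairwise offset is
    1-4 times (STRIDE mod MOD), and none of those products is a multiple of MOD.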
    """
    from einops import rearrange
    from stable_audio_tools.inference.generation import (
        generate_diffusion_cond_inpaint)

    def _notify(stage):
        if progress is not None:
            try:
                progress(stage)
            except Exception as e:
                print(f"[coda] progress callback failed ({e})", flush=True)

    preload()
    model = _ensure_on_device()
    dev = _device()
    # Seed policy. The library does `np.random.randint(0, 2**32-1)` when
    # seed == -1, which overflows int32 on Windows/numpy<2, so we always draw
    # our own safe seeds. A caller that pins a seed (>= 0) gets exactly one
    # deterministic draw (reproducibility/debug paths). The default path
    # (seed < 0) draws several candidates and keeps the cleanest — that's what
    # the app uses, because a single pinned seed doesn't survive a torch change.
    pinned = seed is not None and seed >= 0
    n_candidates = 1 if pinned else max(1, int(candidates or DEFAULT_CANDIDATES))
    if pinned:
        seeds = [int(seed)]
    else:
        base = int(np.random.randint(0, 2 ** 31 - 1))
        # spread the seeds far apart so the draws are genuinely different
        seeds = [(base + i * 0x9E3779B1) % (2 ** 31 - 1)
                 for i in range(n_candidates)]

    _notify("reading")
    source = _load_source(clip_path)
    source_seconds = source.shape[-1] / SR
    lead, new_seconds, buffer_seconds = plan_continuation(
        source_seconds, total_seconds)
    # condition on only the TAIL `lead` seconds of the clip. This is the bug fix:
    # a long source no longer fills the buffer and starves the masked region.
    lead_samples = min(int(round(lead * SR)), source.shape[-1])
    lead_audio = source[:, -lead_samples:]
    # the lead is encoded by the autoencoder (pretransform), which we now run in
    # fp32; its conv1d rejects a mismatched input dtype, so match the
    # pretransform's parameter dtype (fp32 on CUDA), NOT the DiT's fp16. Resample
    # is a no-op here (clip is already 44.1k), so prepare_audio keeps this dtype.
    ae = getattr(model, "pretransform", None)
    ae_dtype = (next(ae.parameters()).dtype if ae is not None
                else next(model.model.parameters()).dtype)
    lead_audio = lead_audio.to(ae_dtype)
    mask_start, mask_end = lead, buffer_seconds
    prompt = (prompt or "").strip()
    print(f"[coda] continuation: source={source_seconds:.1f}s, "
          f"lead={lead:.1f}s -> buffer={buffer_seconds:.1f}s "
          f"(+{new_seconds:.1f}s new), mask=[{mask_start:.1f}s, {mask_end:.1f}s], "
          f"steps={STEPS}, cfg={cfg_scale}, prompt={prompt!r}", flush=True)

    def _draw(draw_seed):
        """One full SA3 inpaint draw -> normalized generated tail (2, M)."""
        # _sdpa_math() forces the reference attention backend for the whole
        # sample loop — the H200 fp16-SDPA-garbage guard (no-op on CPU).
        with torch.no_grad(), _sdpa_math():
            output = generate_diffusion_cond_inpaint(
                model,
                steps=STEPS,
                cfg_scale=cfg_scale,
                conditioning=[{"prompt": prompt, "seconds_total": buffer_seconds}],
                sample_size=_sample_size,
                sampler_type=SAMPLER,
                inpaint_audio=(SR, lead_audio),
                inpaint_mask_start_seconds=mask_start,
                inpaint_mask_end_seconds=mask_end,
                seed=int(draw_seed),
                device=dev,
            )
        # (b, d, n) -> (d, b*n); peak-normalize like Stability's reference Space
        # (unchanged from the verified-good local path).
        output = rearrange(output, "b d n -> d (b n)")
        audio = output.to(torch.float32).cpu().numpy()
        peak = float(np.abs(audio).max())
        if peak > 1e-9:
            audio = audio / peak
        if audio.shape[0] == 1:  # safety: ensure stereo
            audio = np.repeat(audio, 2, axis=0)
        # the generated region is [lead, buffer]; lead ends at the clip's true
        # end, so this slice is the continuation that follows the source.
        start = int(round(lead * SR))
        end = min(int(round(buffer_seconds * SR)), audio.shape[-1])
        return np.ascontiguousarray(audio[:, start:end].astype(np.float32))

    _notify("composing")
    # Best-of-N: draw, score, keep the cleanest. Early-accept a clean draw and
    # stop if the GPU wall-clock budget runs low, so this never blows the window.
    best_tail, best_score, best_seed = None, float("inf"), None
    t0 = time.time()
    for i, draw_seed in enumerate(seeds):
        tail = _draw(draw_seed)
        score = _tail_artifact_score(tail, SR)
        if not np.isfinite(score):
            # a NaN score (e.g. an fp16 blow-up) or a too-short tail ranks last;
            # any finite draw beats it, but it is still kept if nothing else is
            score = float("inf")
        elapsed = time.time() - t0
        print(f"[coda] candidate {i + 1}/{len(seeds)} seed={draw_seed} "
              f"shape={tail.shape} ({tail.shape[-1] / SR:.1f}s) "
              f"artifact_score={score:.2f} "
              f"rms={float(np.sqrt(np.mean(tail ** 2))):.3f} "
              f"elapsed={elapsed:.1f}s", flush=True)
        # `best_tail is None` guarantees a tail is returned even when every
        # draw scores inf, instead of crashing on `best_tail.shape` below.
        if best_tail is None or score < best_score:
            best_tail, best_score, best_seed = tail, score, draw_seed
        if best_score <= EARLY_ACCEPT_SCORE:
            print(f"[coda] candidate {i + 1} clean enough "
                  f"(score {best_score:.2f} <= {EARLY_ACCEPT_SCORE}); accepting",
                  flush=True)
            break
        if elapsed > GPU_BUDGET_SECONDS and i + 1 < len(seeds):
            print(f"[coda] GPU budget {GPU_BUDGET_SECONDS:.0f}s reached after "
                  f"{i + 1} draw(s); keeping best so far", flush=True)
            break
    _notify("finalizing")
    print(f"[coda] selected seed={best_seed} artifact_score={best_score:.2f} "
          f"tail={best_tail.shape[-1] / SR:.1f}s peak after norm "
          f"{float(np.abs(best_tail).max()):.3f} "
          f"rms {float(np.sqrt(np.mean(best_tail ** 2))):.3f}", flush=True)
    return best_tail, source_seconds, SR