Spaces:
Running on Zero
Running on Zero
| """engine.py — Stable Audio 3 Small Music continuation core. | |
| CODA's job is one thing done well: take a short, unfinished clip and continue it | |
| into a finished-sounding track in the same key, tempo and feel. SA3 does that in | |
| a SINGLE call. Its `generate_diffusion_cond_inpaint` is a native audio-inpainting | |
| diffusion sampler: place the user's clip at the front of the buffer, mask the | |
| region after it, and the model fills the masked region conditioned on the kept | |
| audio — true long-form continuation, 44.1 kHz stereo, no multi-pass chaining | |
| and no energy guards. | |
| Candidate selection (the deployed quality fix): the lab verified a take with a | |
| pinned seed on torch 2.7.1, but ZeroGPU runs a different torch (2.8–2.11), so the | |
| same seed no longer reproduces it — it just freezes one arbitrary draw, which on | |
| the deployed build was the bad "sporadic loud synth noise" take. Instead of | |
| trusting a magic seed, CODA draws a few candidates and keeps the cleanest by a | |
| cheap artifact score (it rejects four failure modes: loud random bursts, | |
| whole-tail silence collapse, mid-tail dropouts AND squashed, transient-less | |
| draws). A wall-clock budget + early-accept keep this inside the ZeroGPU window, | |
| so it costs at most a few extra fast 8-step draws. | |
| This module is the whole generation core. It returns ONLY the newly generated | |
| tail plus the source length in seconds; `stitch.py` joins that tail onto the | |
| user's *pristine* original so the real recording (and any vocals) plays | |
| untouched up to the seam. | |
| Bounded lead-in (the deployed-bug fix): SA3 only needs a short run-up to know | |
| where the song is going. We therefore condition on at most MAX_LEAD_SECONDS of | |
| the clip's TAIL, not the whole clip. Feeding a long clip (e.g. 100 s) into the | |
| buffer and masking only a few seconds makes the 8-step distilled sampler | |
| collapse to near-silence in that tiny window — the bug that shipped. A bounded | |
| lead keeps the masked (generated) region substantial and healthy, and because | |
| stitch rejoins the tail onto the full pristine original, the listener still | |
| hears their entire clip before the seam. | |
| Mask convention (verified against the installed library source): | |
| inpaint_mask = ones(buffer); inpaint_mask[start:end] = 0 | |
| -> 1 = keep the input audio, 0 = generate. We place `lead` seconds of source | |
| at the front and mask [lead, lead+new], so SA3 keeps the lead and generates a | |
| fresh `new`-second tail that continues from the clip's end. | |
| """ | |
| import contextlib | |
| import time | |
| import numpy as np | |
| import torch | |
# --- H200 (sm_90) fp16 numerical-stability hardening -------------------------
# Symptom: CODA sounds perfect on local Blackwell (sm_120) but produced
# "sporadic loud bursts" on ZeroGPU's H200. The root cause is fp16 numerical
# instability on H200, not a torch version:
#   (1) torch 2.8 changed cuDNN SDPA backend selection for H200, and the
#       cuDNN/mem-efficient fp16 SDPA kernels have known NaN/garbage bugs
#       (pytorch#139298/#166211/#124877/#112577) — exactly the "loud burst"
#       signature;
#   (2) TF32 matmul accumulation adds error that compounds through the 8-step
#       distilled sampler.
# Mitigation: force full-precision matmul accumulation here (module import
# time) and pin the math (reference) SDPA backend around generation via
# `_sdpa_math()`. Both are no-ops on CPU.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

try:
    from torch.nn.attention import SDPBackend
    from torch.nn.attention import sdpa_kernel as _sdpa_kernel
except Exception:  # very old/unexpected torch: fall back to "no backend pin"
    SDPBackend = _sdpa_kernel = None
def _sdpa_math():
    """Return a context manager pinning SDPA to the MATH (reference) backend.

    SA3's transformer calls F.scaled_dot_product_attention with no backend pin
    and casts q/k/v to fp16; on H200 torch may route that to a
    cuDNN/mem-efficient kernel with the known fp16 NaN bug. MATH is the stable
    reference path.

    A fresh context object is built on every call. When the
    `torch.nn.attention` API could not be imported (older torch) this returns
    a `contextlib.nullcontext()`, i.e. a no-op.
    """
    api_available = _sdpa_kernel is not None
    if api_available:
        return _sdpa_kernel(SDPBackend.MATH)
    return contextlib.nullcontext()
MODEL_ID = "stabilityai/stable-audio-3-small-music"

# SA3 native sample rate (model_config: sample_rate).
SR = 44_100

# SA3 Small is an 8-step adversarially-distilled model. It was tuned for 8-step
# pingpong; pushing it to 16/25 steps is OFF its distilled regime and tends to
# ADD artifacts, not remove them. We stay at 8 and fix quality by picking the
# best of a few draws instead.
STEPS = 8

# The sampler the distilled model was tuned for.
SAMPLER = "pingpong"

# Distilled-model guidance; the prompt still conditions at 1.0 (CFG
# amplification off, conditional path on).
DEFAULT_CFG = 1.0

# --- Best-of-N -----------------------------------------------------------------
# With a random seed each draw differs, so we generate a few and keep the
# cleanest. Bounded so it never blows the ZeroGPU window.

# How many draws to consider when no seed is pinned. Raised from 3: the quality
# bar is the BEST draw, not the first clean-ish one, and a fast H200 draw is
# cheap. Early-accept + the GPU budget still short-circuit when an early draw is
# already great.
DEFAULT_CANDIDATES = 5

# Stop drawing once this much wall-clock is spent (the @spaces.GPU window is
# 120s; leave slack).
GPU_BUDGET_SECONDS = 85.0

# A draw this clean is taken immediately, no re-draw. Tightened from 4.0 so a
# merely-okay draw doesn't short-circuit the search for a genuinely good one.
EARLY_ACCEPT_SCORE = 3.5

# Quietest sustained 0.2 s below this fraction of the tail median counts as a
# mid-tail hole (re-draw it).
DROPOUT_FLOOR = 0.12

# Peak/RMS below this is a squashed, transient-less wash (real music here is
# ~6-8); penalize it.
CREST_FLOOR = 4.0

# How hard a collapsed crest is penalized.
CREST_SCALE = 1.5

# --- Duration planning ---------------------------------------------------------

# SA3 Small duration cap (sample_size / sample_rate).
MAX_TOTAL_SECONDS = 120

# Below this a "continuation" isn't worth a GPU call.
MIN_NEW_SECONDS = 5

# How much of the clip's TAIL to feed SA3 as run-up. SA3 generates a healthy
# continuation from a bounded lead-in; keeping a very long source in the buffer
# and masking only a few seconds makes the distilled sampler produce
# near-silence. 30 s is inside the model's healthy range (the verified lab takes
# used ~30 s leads) and the splice restores the full clip anyway.
MAX_LEAD_SECONDS = 30

# --- Lazily populated module state ---------------------------------------------
# _model / _model_config / _sample_size are filled in by `preload()`;
# _on_device records which device the weights currently live on.
_model = _model_config = _sample_size = _on_device = None
| def _device(): | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
def preload():
    """Load the SA3 model into CPU RAM once and return (model, model_config).

    On ZeroGPU the per-call GPU window is the scarce resource, so weights must
    come off disk at boot, not inside the window. This pulls the model
    (autoencoder and T5Gemma conditioner included) via
    `get_pretrained_model`, caches it in module globals and switches it to
    eval mode. The CUDA placement + fp16 cast is deferred to the first
    `continue_audio` call (i.e. the @spaces.GPU window), matching how
    Stability's own Space defers it.

    Subsequent calls are cheap: they return the cached pair untouched.
    """
    global _model, _model_config, _sample_size

    # Already resident — nothing to load.
    if _model is not None:
        return _model, _model_config

    # Imported lazily so merely importing this module stays light.
    from stable_audio_tools import get_pretrained_model

    _model, _model_config = get_pretrained_model(MODEL_ID)
    _sample_size = int(_model_config["sample_size"])
    _model.eval()
    print(f"[coda] preload: SA3 resident "
          f"(sr={_model_config['sample_rate']}, "
          f"sample_size={_sample_size} = "
          f"{_sample_size / int(_model_config['sample_rate']):.0f}s)",
          flush=True)
    return _model, _model_config
def _ensure_on_device():
    """Place the preloaded model on the active device and return it.

    Called inside the @spaces.GPU window on every generation. The placement is
    re-applied each call rather than trusting cached device state, so it stays
    correct even if ZeroGPU detaches the GPU between calls. Requires
    `preload()` to have populated `_model` first.

    Dtype split (the H200 quality fix), applied on CUDA only:
      * the whole model is cast to fp16 (fast, and Stability loads it fp16);
      * the autoencoder (`pretransform`), when present, is then cast back to
        fp32. SA3's Oobleck/DAC decoder uses Snake1d activations —
        x + (1/(beta+1e-9))*sin(a*x)^2 — whose reciprocal*sin^2 term can
        exceed fp16's ~65504 ceiling, giving inf/NaN -> loud broadband noise
        on decode. fp32 there has ample headroom.
    Note that this fp16-then-fp32 round trip of the pretransform is repeated
    on every call.

    sampling.py casts the sampled latents to the pretransform's dtype right
    before `decode`, so an fp32 pretransform makes the WHOLE encode+decode
    path fp32 with no library change; only the DiT stays fp16. On CPU no dtype
    cast is applied at all.
    """
    global _model, _on_device

    target = _device()
    placed = _model.to(target)
    if target == "cuda":
        placed = placed.to(torch.float16)
        # keep the autoencoder in fp32 (Snake1d fp16 overflow guard)
        autoencoder = getattr(placed, "pretransform", None)
        if autoencoder is not None:
            autoencoder.to(torch.float32)

    _model = placed
    _on_device = target
    return _model
def _load_source(clip_path):
    """Read `clip_path` and return it as a float32 torch tensor at 44.1 kHz.

    The file is decoded by librosa at `SR` with channels preserved. A 1-D
    (mono) result is duplicated into two identical channels; a result with
    more than two channels keeps only the first two. SA3's autoencoder is
    stereo; `prepare_audio` inside the sampler will pad/crop to the buffer
    length and place this at the FRONT (PadCrop, randomize=False).
    """
    import librosa

    samples, _native_sr = librosa.load(clip_path, sr=SR, mono=False)
    samples = np.asarray(samples, dtype=np.float32)

    n_dims = samples.ndim
    if n_dims == 1:
        # mono -> stereo
        samples = np.stack([samples, samples])
    elif samples.shape[0] > 2:
        # multichannel -> first two channels only
        samples = samples[:2]

    contiguous = np.ascontiguousarray(samples)
    return torch.from_numpy(contiguous)
#: Longest clip, in seconds, that still leaves room for MIN_NEW_SECONDS of
#: continuation under the MAX_TOTAL_SECONDS cap. `plan_continuation` rejects
#: anything longer with a ValueError.
MAX_SOURCE_SECONDS = MAX_TOTAL_SECONDS - MIN_NEW_SECONDS
def plan_continuation(source_seconds, total_seconds):
    """Plan the SA3 generation buffer for a clip; pure and model-free.

    Given the clip length and the requested finished length (both seconds),
    returns the tuple `(lead, new_seconds, buffer_seconds)`:

    - `lead`           : seconds of the clip's TAIL used as run-up context,
                         capped at MAX_LEAD_SECONDS so a long clip can't drown
                         the masked region.
    - `new_seconds`    : seconds of fresh audio to generate. It extends to the
                         requested finished length (`total - source`), is
                         floored at MIN_NEW_SECONDS so every call earns its
                         GPU time, and is bounded so `lead + new` never
                         exceeds MAX_TOTAL_SECONDS.
    - `buffer_seconds` : `lead + new_seconds`, the full generation buffer. The
                         mask runs [lead, buffer].

    The requested total is itself clamped to MAX_TOTAL_SECONDS first.

    Raises ValueError only for a clip longer than MAX_SOURCE_SECONDS — at that
    point it's a full track, not an unfinished clip to continue.
    """
    source_seconds = float(source_seconds)
    total_seconds = float(total_seconds)

    # Guard: a clip this long leaves no room for a worthwhile continuation.
    if source_seconds > MAX_SOURCE_SECONDS:
        raise ValueError(
            f"clip is {source_seconds:.0f}s — that's a finished-length track, "
            f"not an unfinished clip. CODA continues clips up to "
            f"{MAX_SOURCE_SECONDS:.0f}s; trim it shorter and re-upload.")

    capped_total = min(total_seconds, MAX_TOTAL_SECONDS)
    run_up = min(source_seconds, MAX_LEAD_SECONDS)

    # Extend to the requested finished length, floor at MIN_NEW_SECONDS, then
    # make sure run-up + fresh audio still fits under the buffer cap.
    wanted = max(capped_total - source_seconds, MIN_NEW_SECONDS)
    fresh = min(wanted, MAX_TOTAL_SECONDS - run_up)

    return run_up, fresh, run_up + fresh
def _tail_artifact_score(tail, sr=SR):
    """Lower is better. A blind, ear-free quality score for a generated tail,
    used to pick the cleanest of several candidate draws.

    It targets the four ways an SA3 draw goes bad:
      * "sporadic loud random synth noises" — even a FEW short windows far
        louder than the body push the loudest window way above the median.
        Sustained dynamics rarely make any single 50 ms window many times the
        median, so musical loudness doesn't trip it.
      * silence collapse — a near-silent WHOLE tail is caught by the loudness
        floor (it keys off the tail's overall RMS).
      * mid-tail dropout — a brief near-silent HOLE inside an otherwise healthy
        tail: overall RMS stays high and max/median stays low, yet a listener
        plainly hears the music cut out for a beat. Detected as a SUSTAINED
        quiet stretch — the quietest ~0.2 s envelope falling well below the
        median.
      * dynamics/transient collapse — a steady, tonal draw whose transients
        are smeared shows up as a collapsed crest factor (peak/RMS); a low
        crest is penalized so best-of-N prefers the punchy draw.

    Score = max/median + silence + dropout + dynamics penalties, computed on a
    mono mix over short (~50 ms) windows.

    Args:
        tail: audio samples, either 2-D (channels, samples) — averaged to mono
            — or already mono. Any array-like is accepted.
        sr: sample rate in Hz, used to size the 50 ms analysis window.

    Returns:
        A float score. `inf` is returned for a tail that cannot be judged:
        shorter than four analysis windows, or containing NaN/inf samples
        (a numerically blown-up draw must never outrank a finite one, and a
        NaN score would break the "lower is better" ordering).
    """
    samples = np.asarray(tail)
    mono = samples.mean(axis=0) if samples.ndim == 2 else samples
    mono = np.asarray(mono, dtype=np.float64)

    win = max(1, int(0.05 * sr))
    if mono.size < win * 4:
        return float("inf")  # too short to judge — avoid it
    if not np.isfinite(mono).all():
        return float("inf")  # NaN/inf samples: a blown-up draw — avoid it

    # Per-window RMS envelope over whole windows only (the remainder is dropped).
    n = mono.size // win
    energies = np.sqrt(
        np.mean(mono[:n * win].reshape(n, win) ** 2, axis=1) + 1e-12)
    median = float(np.median(energies)) + 1e-9
    loudest = float(np.max(energies))
    spikiness = loudest / median

    # Whole-tail loudness floor: only a tail quieter than 0.02 RMS is penalized.
    overall = float(np.sqrt(np.mean(mono ** 2)) + 1e-12)
    silence_penalty = 0.0 if overall > 0.02 else (0.02 - overall) * 200.0

    # mid-tail dropout: smooth the window energies over ~0.2 s (4 windows) and
    # find how far the quietest SUSTAINED stretch falls below the median.
    # Exclude the final 0.5 s so a natural ending taper isn't punished. The
    # penalty is zero until the quietest stretch drops under DROPOUT_FLOOR x
    # median, and is capped at 8.
    dropout_penalty = 0.0
    smooth = np.convolve(energies, np.ones(4) / 4.0, mode="valid")
    guard = int(0.5 / 0.05)  # last 0.5 s of windows
    body = smooth[:-guard] if smooth.size > guard + 4 else smooth
    if body.size:
        dropout = float(np.min(body)) / median
        dropout_penalty = min(8.0, max(0.0, DROPOUT_FLOOR / max(dropout, 1e-3)
                                       - 1.0) * 2.0)

    # dynamics/transient collapse: crest = peak / RMS, measured on this tail.
    # Penalize only a crest below CREST_FLOOR, so a naturally dynamic draw is
    # never punished.
    peak = float(np.abs(mono).max())
    crest = peak / overall
    crest_penalty = max(0.0, CREST_FLOOR - crest) * CREST_SCALE

    return spikiness + silence_penalty + dropout_penalty + crest_penalty
def continue_audio(clip_path, total_seconds, prompt="", cfg_scale=DEFAULT_CFG,
                   seed=-1, candidates=None, progress=None):
    """Continue `clip_path` up to `total_seconds` with SA3 inpainting.

    With the default `seed` (< 0, or None) this draws up to `candidates` SA3
    inpaint takes (DEFAULT_CANDIDATES when `candidates` is falsy) and returns
    the cleanest by `_tail_artifact_score`, early-accepting a draw at or below
    EARLY_ACCEPT_SCORE and stopping once GPU_BUDGET_SECONDS of wall-clock has
    been spent. Pin `seed` >= 0 for a single deterministic draw (debug/repro).

    Returns (new_tail, source_seconds, SR) where:
      new_tail       : (2, M) float32 @44.1k — ONLY the generated region that
                       follows the lead-in. The whole buffer is
                       peak-normalized to <= 1.0 before this slice is taken.
      source_seconds : the clip's true length (the splice boundary, in seconds)
      SR             : 44100

    `progress(stage_name)` is called (best-effort; its exceptions are logged
    and swallowed) with "reading", "composing" and "finalizing" so the UI can
    paint a live status.

    Raises:
      ValueError   : from `plan_continuation`, when the clip is longer than
                     MAX_SOURCE_SECONDS.
      RuntimeError : when no drawn candidate produced a scorable tail (every
                     artifact score was inf/NaN), so there is nothing sane to
                     return.
    """
    from einops import rearrange
    from stable_audio_tools.inference.generation import (
        generate_diffusion_cond_inpaint)

    def _notify(stage):
        """Report a stage to the UI; a broken callback must not kill a draw."""
        if progress is not None:
            try:
                progress(stage)
            except Exception as e:
                print(f"[coda] progress callback failed ({e})", flush=True)

    preload()
    model = _ensure_on_device()
    dev = _device()

    # Seed policy. The library does `np.random.randint(0, 2**32-1)` when
    # seed == -1, which overflows int32 on Windows/numpy<2, so we always draw
    # our own safe seeds. A caller that pins a seed (>= 0) gets exactly one
    # deterministic draw (reproducibility/debug paths). The default path
    # (seed < 0) draws several candidates and keeps the cleanest — that's what
    # the app uses, because a single pinned seed doesn't survive a torch change.
    pinned = seed is not None and seed >= 0
    n_candidates = 1 if pinned else max(1, int(candidates or DEFAULT_CANDIDATES))
    if pinned:
        seeds = [int(seed)]
    else:
        base = int(np.random.randint(0, 2 ** 31 - 1))
        # spread the seeds far apart so the draws are genuinely different
        seeds = [(base + i * 0x9E3779B1) % (2 ** 31 - 1)
                 for i in range(n_candidates)]

    _notify("reading")
    source = _load_source(clip_path)
    source_seconds = source.shape[-1] / SR
    lead, new_seconds, buffer_seconds = plan_continuation(
        source_seconds, total_seconds)

    # condition on only the TAIL `lead` seconds of the clip. This is the bug fix:
    # a long source no longer fills the buffer and starves the masked region.
    lead_samples = min(int(round(lead * SR)), source.shape[-1])
    lead_audio = source[:, -lead_samples:]

    # the lead is encoded by the autoencoder (pretransform), which we now run in
    # fp32; its conv1d rejects a mismatched input dtype, so match the
    # pretransform's parameter dtype (fp32 on CUDA), NOT the DiT's fp16. Resample
    # is a no-op here (clip is already 44.1k), so prepare_audio keeps this dtype.
    ae = getattr(model, "pretransform", None)
    ae_dtype = (next(ae.parameters()).dtype if ae is not None
                else next(model.model.parameters()).dtype)
    lead_audio = lead_audio.to(ae_dtype)

    mask_start, mask_end = lead, buffer_seconds
    prompt = (prompt or "").strip()
    print(f"[coda] continuation: source={source_seconds:.1f}s, "
          f"lead={lead:.1f}s -> buffer={buffer_seconds:.1f}s "
          f"(+{new_seconds:.1f}s new), mask=[{mask_start:.1f}s, {mask_end:.1f}s], "
          f"steps={STEPS}, cfg={cfg_scale}, prompt={prompt!r}", flush=True)

    def _draw(draw_seed):
        """One full SA3 inpaint draw -> normalized generated tail (2, M)."""
        # _sdpa_math() forces the reference attention backend for the whole
        # sample loop — the H200 fp16-SDPA-garbage guard (no-op on CPU).
        with torch.no_grad(), _sdpa_math():
            output = generate_diffusion_cond_inpaint(
                model,
                steps=STEPS,
                cfg_scale=cfg_scale,
                conditioning=[{"prompt": prompt, "seconds_total": buffer_seconds}],
                sample_size=_sample_size,
                sampler_type=SAMPLER,
                inpaint_audio=(SR, lead_audio),
                inpaint_mask_start_seconds=mask_start,
                inpaint_mask_end_seconds=mask_end,
                seed=int(draw_seed),
                device=dev,
            )
        # (b, d, n) -> (d, b*n); peak-normalize like Stability's reference Space
        # (unchanged from the verified-good local path).
        output = rearrange(output, "b d n -> d (b n)")
        audio = output.to(torch.float32).cpu().numpy()
        peak = float(np.abs(audio).max())
        if peak > 1e-9:
            audio = audio / peak
        if audio.shape[0] == 1:  # safety: ensure stereo
            audio = np.repeat(audio, 2, axis=0)
        # the generated region is [lead, buffer]; lead ends at the clip's true
        # end, so this slice is the continuation that follows the source.
        start = int(round(lead * SR))
        end = min(int(round(buffer_seconds * SR)), audio.shape[-1])
        return np.ascontiguousarray(audio[:, start:end].astype(np.float32))

    _notify("composing")
    # Best-of-N: draw, score, keep the cleanest. Early-accept a clean draw and
    # stop if the GPU wall-clock budget runs low, so this never blows the window.
    best_tail, best_score, best_seed = None, float("inf"), None
    t0 = time.time()
    for i, draw_seed in enumerate(seeds):
        tail = _draw(draw_seed)
        score = _tail_artifact_score(tail, SR)
        elapsed = time.time() - t0
        print(f"[coda] candidate {i + 1}/{len(seeds)} seed={draw_seed} "
              f"shape={tail.shape} ({tail.shape[-1] / SR:.1f}s) "
              f"artifact_score={score:.2f} "
              f"rms={float(np.sqrt(np.mean(tail ** 2))):.3f} "
              f"elapsed={elapsed:.1f}s", flush=True)
        # An inf or NaN score never compares below best_score, so an unusable
        # draw is never selected.
        if score < best_score:
            best_tail, best_score, best_seed = tail, score, draw_seed
        if best_score <= EARLY_ACCEPT_SCORE:
            print(f"[coda] candidate {i + 1} clean enough "
                  f"(score {best_score:.2f} <= {EARLY_ACCEPT_SCORE}); accepting",
                  flush=True)
            break
        if elapsed > GPU_BUDGET_SECONDS and i + 1 < len(seeds):
            print(f"[coda] GPU budget {GPU_BUDGET_SECONDS:.0f}s reached after "
                  f"{i + 1} draw(s); keeping best so far", flush=True)
            break

    # Every draw was unscorable (too short, or NaN/inf samples): fail with a
    # clear message instead of crashing on `None.shape` in the summary below.
    if best_tail is None:
        raise RuntimeError(
            "[coda] no candidate draw produced a usable tail "
            "(every artifact score was inf/NaN); try again or pin a "
            "different seed")

    _notify("finalizing")
    print(f"[coda] selected seed={best_seed} artifact_score={best_score:.2f} "
          f"tail={best_tail.shape[-1] / SR:.1f}s peak after norm "
          f"{float(np.abs(best_tail).max()):.3f} "
          f"rms {float(np.sqrt(np.mean(best_tail ** 2))):.3f}", flush=True)
    return best_tail, source_seconds, SR