"""ZeroGPU Gradio demo for Stable Audio 3 — Medium, Small Music, Small SFX. Two tabs: * **Simple** — prompt + duration with a slim Advanced accordion (steps/CFG/seed /sampler). Mirrors the original tiny UI. * **Advanced** — replicates the reference repo's ``stable_audio_3/interface/diffusion_cond.py`` controls: negative prompt, sampler params (sigma_max, APG, duration padding), init audio + noise level, inpainting with mask start/end, spectrogram gallery, send-to-init / send-to-inpaint buttons. """ from __future__ import annotations import spaces # noqa: F401 import os import subprocess import sys import tempfile import time from dataclasses import dataclass from typing import Optional, Tuple def _ensure_stable_audio_tools() -> None: try: import stable_audio_tools # noqa: F401 return except ImportError: pass # stable-audio-tools 0.0.20 strict-pins torch==2.7.1 / torchaudio==2.7.1, # which lack sm_120 (Blackwell) kernels. Install with --no-deps; the # transitive deps are listed in requirements.txt and resolved against the # sm_120-capable torch at build time. print("[startup] installing stable-audio-tools (--no-deps) …", flush=True) subprocess.check_call( [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps", "stable-audio-tools"], ) import stable_audio_tools # noqa: F401 print("[startup] stable-audio-tools installed.", flush=True) _ensure_stable_audio_tools() import gradio as gr import numpy as np import soundfile as sf import torch import torchaudio import torchaudio.transforms as T from einops import rearrange from matplotlib.backends.backend_agg import FigureCanvasAgg from matplotlib.figure import Figure from PIL import Image from stable_audio_tools import get_pretrained_model from stable_audio_tools.inference.generation import generate_diffusion_cond_inpaint # --------------------------------------------------------------------------- # Variants # --------------------------------------------------------------------------- @dataclass class Variant: key: str repo: str label: str default_duration: int placeholder: str VARIANTS: list[Variant] = [ Variant( key="medium", repo="stabilityai/stable-audio-3-medium", label="Medium — general audio (largest)", default_duration=60, placeholder="A dream-like Synthpop instrumental that would accompany a dream-sequence in a surrealist movie 120 BPM", ), Variant( key="small-music", repo="stabilityai/stable-audio-3-small-music", label="Small Music — 0.6B, music-focused", default_duration=60, placeholder="Cinematic neo-soul groove with electric piano, brushed drums, walking upright bass, smoky vibe 92 BPM", ), Variant( key="small-sfx", repo="stabilityai/stable-audio-3-small-sfx", label="Small SFX — 0.6B, sound effects", default_duration=7, placeholder="Chugging train coming into station with horn", ), ] # --------------------------------------------------------------------------- # Preload all variants at module level (ZeroGPU CUDA emulation accepts it) # --------------------------------------------------------------------------- @dataclass class LoadedVariant: variant: Variant model: object sample_rate: int sample_size: int max_seconds: int LOADED: dict[str, LoadedVariant] = {} for v in VARIANTS: print(f"[startup] loading {v.repo} …", flush=True) t0 = time.time() model, config = get_pretrained_model(v.repo) sr = int(config["sample_rate"]) ss = int(config["sample_size"]) model = model.to("cuda").to(torch.float16) LOADED[v.key] = LoadedVariant( variant=v, model=model, sample_rate=sr, sample_size=ss, max_seconds=ss // sr, ) print( f"[startup] {v.key} ready in {time.time() - t0:.1f}s · " f"sr={sr} · sample_size={ss} (~{ss // sr}s max)", flush=True, ) VARIANT_CHOICES = [(v.label, v.key) for v in VARIANTS] # Samplers valid for rf_denoiser diffusion objective (the SA3 family). SAMPLERS = ["pingpong", "euler", "rk4", "dpmpp"] # --------------------------------------------------------------------------- # Spectrogram helper (Mel; adapted from the reference repo's aeiou.py) # --------------------------------------------------------------------------- def _power_to_db(spec: np.ndarray, amin: float = 1e-10) -> np.ndarray: return 10.0 * np.log10(np.maximum(amin, spec)) def audio_spectrogram_image( waveform: torch.Tensor, sample_rate: int, db_range=(35, 120), figsize=(5, 4), ) -> Image.Image: """Render a Mel spectrogram (left channel) as a PIL image.""" if waveform.dim() == 1: waveform = waveform.unsqueeze(0) n_fft = 1024 hop_length = n_fft // 2 mel_op = T.MelSpectrogram( sample_rate=sample_rate, n_fft=n_fft, win_length=None, hop_length=hop_length, center=True, pad_mode="reflect", power=2.0, norm="slaney", onesided=True, n_mels=128, mel_scale="htk", ) melspec = mel_op(waveform.float())[0] # left channel fig = Figure(figsize=figsize, dpi=100) canvas = FigureCanvasAgg(fig) ax = fig.add_subplot() ax.imshow(_power_to_db(melspec.numpy()), origin="lower", aspect="auto", vmin=db_range[0], vmax=db_range[1]) ax.set_ylabel("mel bins (log freq)") ax.set_xlabel("frame") ax.set_title("MelSpectrogram") canvas.draw() return Image.fromarray(np.asarray(canvas.buffer_rgba())) # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- def _gradio_audio_to_tensor( audio_in: Optional[Tuple[int, np.ndarray]], ) -> Optional[Tuple[int, torch.Tensor]]: """Convert a gr.Audio (numpy) value to the (sr, torch.Tensor[C,N]) tuple that ``generate_diffusion_cond_inpaint`` expects. Accepts mono or stereo.""" if audio_in is None: return None sr, arr = audio_in if arr is None or (hasattr(arr, "size") and arr.size == 0): return None arr = np.asarray(arr) if arr.dtype.kind in ("i", "u"): max_val = float(np.iinfo(arr.dtype).max) arr = arr.astype(np.float32) / max_val else: arr = arr.astype(np.float32) if arr.ndim == 1: arr = arr[None, :] # (1, N) else: # gr.Audio returns (N, C); transpose to (C, N) arr = arr.T if arr.shape[0] > arr.shape[1] else arr return int(sr), torch.from_numpy(arr) def _tensor_to_wav( output: torch.Tensor, sample_rate: int, duration_seconds: Optional[int], out_dir: Optional[str] = None, ) -> Tuple[str, torch.Tensor]: """Pack a (B, C, N) generation tensor to int16, optionally cut to duration, write to disk, and return (path, int16-tensor).""" output = rearrange(output, "b d n -> d (b n)") output = ( output.to(torch.float32) .div(torch.max(torch.abs(output)).clamp(min=1e-9)) .clamp(-1, 1) .mul(32767) .to(torch.int16) .cpu() ) if duration_seconds is not None: output = output[:, : int(duration_seconds) * sample_rate] out_dir = out_dir or tempfile.mkdtemp() out_path = os.path.join(out_dir, "sa3.wav") sf.write(out_path, output.numpy().T, sample_rate, subtype="PCM_16") return out_path, output def _run_inference( variant_key: str, prompt: str, negative_prompt: str = "", duration: int = 60, steps: int = 8, cfg_scale: float = 1.0, sampler_type: str = "pingpong", seed: int = 0, sigma_max: float = 1.0, apg_scale: float = 1.0, duration_padding_sec: float = 6.0, cut_to_seconds_total: bool = True, init_audio: Optional[Tuple[int, np.ndarray]] = None, init_noise_level: float = 0.9, inpaint_audio: Optional[Tuple[int, np.ndarray]] = None, mask_start_sec: float = 0.0, mask_end_sec: float = 0.0, preview_every: int = 0, return_spectrogram: bool = True, progress: gr.Progress = gr.Progress(), ): """Full-featured generation. Returns (audio_path, [spectrogram_img, *previews]) when ``return_spectrogram`` is True, else just ``audio_path``.""" prompt = (prompt or "").strip() if not prompt: raise gr.Error("Please enter a prompt.") if variant_key not in LOADED: raise gr.Error(f"Unknown variant {variant_key!r}.") lv = LOADED[variant_key] duration = max(1, min(int(duration), lv.max_seconds)) progress(0.05, desc=f"[{variant_key}] preparing conditioning") conditioning = [{"prompt": prompt, "seconds_total": int(duration)}] negative_conditioning = None neg = (negative_prompt or "").strip() if neg: negative_conditioning = [{"prompt": neg, "seconds_total": int(duration)}] # The pretransform encoder is fp16 (we cast the whole model at startup), # but prepare_audio's torchaudio Resample uses an fp32 kernel. Pre-resample # in fp32 here so prepare_audio's resample is a no-op, then cast to the # model dtype so the encoder doesn't see a dtype mismatch. model_dtype = next(lv.model.parameters()).dtype def _prep(tup): if tup is None: return None sr, t = tup t = t.float() if sr != lv.sample_rate: t = torchaudio.functional.resample(t, sr, lv.sample_rate) return lv.sample_rate, t.to(model_dtype) init_audio_t = _prep(_gradio_audio_to_tensor(init_audio)) inpaint_audio_t = _prep(_gradio_audio_to_tensor(inpaint_audio)) # Inpaint mask: only enable if mask_end > mask_start AND we have either # inpaint_audio or init_audio (otherwise the mask wraps zero content). mask_start = max(0.0, float(mask_start_sec)) mask_end = min(float(duration), float(mask_end_sec)) use_mask = ( inpaint_audio_t is not None and mask_end > mask_start ) seed_val = int(seed) if seed and int(seed) > 0 else -1 preview_images: list = [] callback = None if preview_every and int(preview_every) > 0: every = int(preview_every) def _cb(info): i = info["i"] if i % every != 0: return denoised = info["denoised"] try: if lv.model.pretransform is not None: denoised = lv.model.pretransform.decode(denoised) d = rearrange(denoised, "b d n -> d (b n)") d = d.clamp(-1, 1).mul(32767).to(torch.int16).cpu() img = audio_spectrogram_image(d, sample_rate=lv.sample_rate) preview_images.append((img, f"Step {i + 1}")) except Exception as e: print(f"[preview] skipped step {i}: {e}", flush=True) callback = _cb gen_kwargs: dict = dict( steps=int(steps), cfg_scale=float(cfg_scale), conditioning=conditioning, negative_conditioning=negative_conditioning, sample_size=lv.sample_size, sampler_type=sampler_type, seed=seed_val, device="cuda", sigma_max=float(sigma_max), apg_scale=float(apg_scale), duration_padding_sec=float(duration_padding_sec), ) if init_audio_t is not None: gen_kwargs["init_audio"] = init_audio_t gen_kwargs["init_noise_level"] = float(init_noise_level) if inpaint_audio_t is not None: gen_kwargs["inpaint_audio"] = inpaint_audio_t if use_mask: gen_kwargs["inpaint_mask_start_seconds"] = mask_start gen_kwargs["inpaint_mask_end_seconds"] = mask_end if callback is not None: gen_kwargs["callback"] = callback progress(0.25, desc=f"[{variant_key}] sampling {steps} steps with {sampler_type}") t0 = time.time() output = generate_diffusion_cond_inpaint(lv.model, **gen_kwargs) print(f"[infer/{variant_key}] sampling done in {time.time() - t0:.1f}s", flush=True) progress(0.92, desc="Normalising & saving") cut_dur = int(duration) if cut_to_seconds_total else None out_path, int16_audio = _tensor_to_wav(output, lv.sample_rate, cut_dur) if not return_spectrogram: return out_path spec_img = audio_spectrogram_image(int16_audio, sample_rate=lv.sample_rate) return out_path, [spec_img, *preview_images] @spaces.GPU def infer( variant_key: str, prompt: str, duration: int = 60, steps: int = 8, cfg_scale: float = 1.0, sampler_type: str = "pingpong", seed: int = 0, progress: gr.Progress = gr.Progress(), ): """Slim handler used by the Simple tab and the Examples cache.""" return _run_inference( variant_key=variant_key, prompt=prompt, duration=duration, steps=steps, cfg_scale=cfg_scale, sampler_type=sampler_type, seed=seed, return_spectrogram=False, progress=progress, ) @spaces.GPU def infer_advanced( variant_key: str, prompt: str, negative_prompt: str, duration: int, steps: int, cfg_scale: float, sampler_type: str, seed: int, sigma_max: float, apg_scale: float, duration_padding_sec: float, cut_to_seconds_total: bool, init_audio: Optional[Tuple[int, np.ndarray]], init_noise_level: float, inpaint_audio: Optional[Tuple[int, np.ndarray]], mask_start_sec: float, mask_end_sec: float, preview_every: int, progress: gr.Progress = gr.Progress(), ): """Full-featured handler used by the Advanced tab.""" return _run_inference( variant_key=variant_key, prompt=prompt, negative_prompt=negative_prompt, duration=duration, steps=steps, cfg_scale=cfg_scale, sampler_type=sampler_type, seed=seed, sigma_max=sigma_max, apg_scale=apg_scale, duration_padding_sec=duration_padding_sec, cut_to_seconds_total=cut_to_seconds_total, init_audio=init_audio, init_noise_level=init_noise_level, inpaint_audio=inpaint_audio, mask_start_sec=mask_start_sec, mask_end_sec=mask_end_sec, preview_every=preview_every, return_spectrogram=True, progress=progress, ) # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- DESCRIPTION = """ # 🎵 Stable Audio 3 Text-to-audio generation with Stable Audio 3. Pick a variant, write a prompt, hit Generate. Switch to **Advanced** for the full sampler / init-audio / inpainting controls. """ EXAMPLES = [ ["medium", "House music that encapsulates the feeling of being at a festival in the sunny weather with all your friends 124 BPM", 60], ["small-music", "Cinematic neo-soul groove with electric piano, brushed drums, walking upright bass, smoky vibe 92 BPM", 45], ["small-music", "Driving techno track with rolling 16th-note hats, deep sub bass, acid arpeggios building tension 132 BPM", 60], ["small-sfx", "Chugging train coming into station with horn", 7], ["small-sfx", "Heavy rain on a tin roof with distant thunder rolls", 10], ["medium", "Rainy night, lo-fi hip-hop beat with vinyl crackle, mellow piano chords, soft kick and snare 80 BPM", 30], ] def _variant_change_simple(variant_key: str): lv = LOADED[variant_key] return ( gr.update(maximum=lv.max_seconds, value=min(lv.variant.default_duration, lv.max_seconds), label=f"Duration (s) · model max {lv.max_seconds}s"), gr.update(placeholder=lv.variant.placeholder), ) def _variant_change_advanced(variant_key: str): lv = LOADED[variant_key] dur = min(lv.variant.default_duration, lv.max_seconds) return ( gr.update(maximum=lv.max_seconds, value=dur, label=f"Seconds total · model max {lv.max_seconds}s"), gr.update(placeholder=lv.variant.placeholder), gr.update(maximum=float(lv.max_seconds), value=0.0), gr.update(maximum=float(lv.max_seconds), value=float(dur)), ) with gr.Blocks(theme=gr.themes.Citrus(), title="Stable Audio 3") as demo: gr.Markdown(DESCRIPTION) with gr.Tabs(): # ----------------------------------------------------------------- # Simple tab # ----------------------------------------------------------------- with gr.Tab("Simple"): variant = gr.Radio( choices=VARIANT_CHOICES, value=VARIANTS[0].key, label="Model", ) with gr.Row(): with gr.Column(scale=2): prompt = gr.Textbox( label="Prompt", placeholder=VARIANTS[0].placeholder, lines=3, ) duration = gr.Slider( 1, LOADED[VARIANTS[0].key].max_seconds, value=VARIANTS[0].default_duration, step=1, label=f"Duration (s) · model max {LOADED[VARIANTS[0].key].max_seconds}s", ) with gr.Accordion("Advanced settings", open=False): steps = gr.Slider(1, 50, value=8, step=1, label="Steps") cfg_scale = gr.Slider(0.5, 8.0, value=1.0, step=0.1, label="CFG scale") sampler_type = gr.Dropdown(SAMPLERS, value="pingpong", label="Sampler") seed = gr.Number(value=0, precision=0, label="Seed (0 = random)") run_btn = gr.Button("🎼 Generate", variant="primary", size="lg") with gr.Column(scale=1): audio_out = gr.Audio(label="Output", type="filepath", autoplay=True) gr.Examples( examples=EXAMPLES, inputs=[variant, prompt, duration], outputs=[audio_out], fn=infer, cache_examples=True, cache_mode="lazy", label="Examples (lazy-cached on first click)", ) variant.change( fn=_variant_change_simple, inputs=[variant], outputs=[duration, prompt], ) run_btn.click( fn=infer, inputs=[variant, prompt, duration, steps, cfg_scale, sampler_type, seed], outputs=[audio_out], ) # ----------------------------------------------------------------- # Advanced tab — mirrors stable_audio_3/interface/diffusion_cond.py # ----------------------------------------------------------------- with gr.Tab("Advanced"): adv_variant = gr.Radio( choices=VARIANT_CHOICES, value=VARIANTS[0].key, label="Model", ) with gr.Row(): with gr.Column(scale=6): adv_prompt = gr.Textbox( show_label=False, placeholder=VARIANTS[0].placeholder, ) adv_negative = gr.Textbox( show_label=False, placeholder="Negative prompt" ) adv_generate = gr.Button("Generate", variant="primary", scale=1) with gr.Row(equal_height=False): with gr.Column(): adv_seconds_total = gr.Slider( minimum=1, maximum=LOADED[VARIANTS[0].key].max_seconds, step=1, value=VARIANTS[0].default_duration, label=f"Seconds total · model max {LOADED[VARIANTS[0].key].max_seconds}s", ) with gr.Row(): adv_steps = gr.Slider( minimum=1, maximum=500, step=1, value=8, label="Steps" ) adv_cfg = gr.Slider( minimum=0.0, maximum=25.0, step=0.1, value=1.0, label="CFG scale", ) with gr.Accordion("Sampler params", open=False): with gr.Row(): adv_seed = gr.Number( label="Seed (set to -1 for random seed)", value=-1, precision=0, ) adv_sampler = gr.Dropdown( SAMPLERS, label="Sampler type", value="pingpong", ) adv_sigma_max = gr.Slider( minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Sigma max", ) with gr.Row(): adv_apg = gr.Slider( minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="APG scale", info="1.0=full APG, 0.0=vanilla CFG", ) adv_dur_padding = gr.Slider( minimum=0.0, maximum=30.0, step=0.5, value=6.0, label="Duration padding (sec)", ) with gr.Accordion("Output params", open=False): with gr.Row(): adv_preview_every = gr.Slider( minimum=0, maximum=100, step=1, value=0, label="Spec preview every N steps (0 = off)", ) adv_cut_to_total = gr.Checkbox( label="Cut to seconds total", value=True, ) with gr.Accordion("Init audio", open=False): adv_init_audio = gr.Audio( label="Init audio", type="numpy", ) adv_init_noise = gr.Slider( minimum=0.01, maximum=1.0, step=0.01, value=0.9, label="Init noise level", ) with gr.Accordion("Inpainting", open=False): adv_inpaint_audio = gr.Audio( label="Inpaint audio", type="numpy", ) adv_mask_start = gr.Slider( minimum=0.0, maximum=float(LOADED[VARIANTS[0].key].max_seconds), step=0.1, value=0.0, label="Mask start (sec)", ) adv_mask_end = gr.Slider( minimum=0.0, maximum=float(LOADED[VARIANTS[0].key].max_seconds), step=0.1, value=0.0, label="Mask end (sec)", ) with gr.Column(): adv_audio_out = gr.Audio( label="Output audio", type="filepath", autoplay=False, sources=[], ) adv_spec_gallery = gr.Gallery( label="Output spectrogram", show_label=True, columns=2, ) send_to_init_btn = gr.Button("Send to init audio") send_to_inpaint_btn = gr.Button("Send to inpaint audio") send_to_init_btn.click( fn=lambda a: a, inputs=[adv_audio_out], outputs=[adv_init_audio] ) send_to_inpaint_btn.click( fn=lambda a: a, inputs=[adv_audio_out], outputs=[adv_inpaint_audio] ) # Keep the inpaint mask bounded by the current duration. def _update_mask_max(seconds_total): m = max(float(seconds_total), 1.0) return ( gr.update(maximum=m), gr.update(maximum=m, value=m), ) adv_seconds_total.change( _update_mask_max, inputs=[adv_seconds_total], outputs=[adv_mask_start, adv_mask_end], ) adv_variant.change( fn=_variant_change_advanced, inputs=[adv_variant], outputs=[adv_seconds_total, adv_prompt, adv_mask_start, adv_mask_end], ) adv_generate.click( fn=infer_advanced, inputs=[ adv_variant, adv_prompt, adv_negative, adv_seconds_total, adv_steps, adv_cfg, adv_sampler, adv_seed, adv_sigma_max, adv_apg, adv_dur_padding, adv_cut_to_total, adv_init_audio, adv_init_noise, adv_inpaint_audio, adv_mask_start, adv_mask_end, adv_preview_every, ], outputs=[adv_audio_out, adv_spec_gallery], ) if __name__ == "__main__": demo.launch()