Spaces:
Running on Zero
Running on Zero
| """ZeroGPU Gradio demo for Stable Audio 3 — Medium, Small Music, Small SFX. | |
| Two tabs: | |
| * **Simple** — prompt + duration with a slim Advanced accordion (steps/CFG/seed | |
| /sampler). Mirrors the original tiny UI. | |
| * **Advanced** — replicates the reference repo's | |
| ``stable_audio_3/interface/diffusion_cond.py`` controls: negative prompt, | |
| sampler params (sigma_max, APG, duration padding), init audio + noise level, | |
| inpainting with mask start/end, spectrogram gallery, send-to-init / | |
| send-to-inpaint buttons. | |
| """ | |
| from __future__ import annotations | |
| import spaces # noqa: F401 | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple | |
def _ensure_stable_audio_tools() -> None:
    """Make ``stable_audio_tools`` importable, pip-installing it if absent.

    stable-audio-tools 0.0.20 strict-pins torch==2.7.1 / torchaudio==2.7.1,
    which lack sm_120 (Blackwell) kernels. The package is therefore installed
    with ``--no-deps``; its transitive deps are listed in requirements.txt and
    resolved against the sm_120-capable torch at build time.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
        ImportError: if the package is installed but cannot be imported
            (e.g. one of its own dependencies is missing). A ``--no-deps``
            reinstall could not fix that, so it is not attempted.
    """
    import importlib

    try:
        import stable_audio_tools  # noqa: F401
        return
    except ModuleNotFoundError as exc:
        # Only a missing top-level package is fixable by installing it; an
        # import error raised from inside the package propagates unchanged.
        if exc.name != "stable_audio_tools":
            raise

    print("[startup] installing stable-audio-tools (--no-deps) …", flush=True)
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
         "stable-audio-tools"],
    )
    # The import system caches sys.path directory listings; a package that
    # appeared after start-up may not be found until the caches are dropped.
    importlib.invalidate_caches()
    import stable_audio_tools  # noqa: F401,F811
    print("[startup] stable-audio-tools installed.", flush=True)
# Runs at import time, before the stable_audio_tools imports below that
# depend on it. On a fresh environment this shells out to pip.
_ensure_stable_audio_tools()
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| import torchaudio.transforms as T | |
| from einops import rearrange | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| from matplotlib.figure import Figure | |
| from PIL import Image | |
| from stable_audio_tools import get_pretrained_model | |
| from stable_audio_tools.inference.generation import generate_diffusion_cond_inpaint | |
| # --------------------------------------------------------------------------- | |
| # Variants | |
| # --------------------------------------------------------------------------- | |
@dataclass
class Variant:
    """Static description of one Stable Audio 3 checkpoint offered in the UI."""

    key: str  # short id: the radio-button value and the key into LOADED
    repo: str  # model repo id passed to get_pretrained_model
    label: str  # human-readable radio-button label
    default_duration: int  # initial duration (s) when this variant is selected
    placeholder: str  # example prompt shown as the prompt textbox placeholder


# Order matters: VARIANTS[0] is the initial selection in both UI tabs.
VARIANTS: list[Variant] = [
    Variant(
        key="medium",
        repo="stabilityai/stable-audio-3-medium",
        label="Medium — general audio (largest)",
        default_duration=60,
        placeholder="A dream-like Synthpop instrumental that would accompany a dream-sequence in a surrealist movie 120 BPM",
    ),
    Variant(
        key="small-music",
        repo="stabilityai/stable-audio-3-small-music",
        label="Small Music — 0.6B, music-focused",
        default_duration=60,
        placeholder="Cinematic neo-soul groove with electric piano, brushed drums, walking upright bass, smoky vibe 92 BPM",
    ),
    Variant(
        key="small-sfx",
        repo="stabilityai/stable-audio-3-small-sfx",
        label="Small SFX — 0.6B, sound effects",
        default_duration=7,
        placeholder="Chugging train coming into station with horn",
    ),
]
| # --------------------------------------------------------------------------- | |
| # Preload all variants at module level (ZeroGPU CUDA emulation accepts it) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class LoadedVariant:
    """A Variant together with its loaded model and audio limits."""

    variant: Variant  # the static description this entry was loaded from
    model: object  # the model returned by get_pretrained_model
    sample_rate: int  # from the model config's "sample_rate"
    sample_size: int  # from the model config's "sample_size" (samples)
    max_seconds: int  # whole seconds that fit in sample_size; slider maximum
def _load_variant(v: Variant) -> LoadedVariant:
    """Load one checkpoint, move it to CUDA in fp16 and record its limits."""
    print(f"[startup] loading {v.repo} …", flush=True)
    t0 = time.time()
    model, config = get_pretrained_model(v.repo)
    sr = int(config["sample_rate"])
    ss = int(config["sample_size"])
    model = model.to("cuda").to(torch.float16)
    loaded = LoadedVariant(
        variant=v,
        model=model,
        sample_rate=sr,
        sample_size=ss,
        # Never advertise a 0 s maximum: the duration sliders start at 1 and
        # _run_inference clamps durations to at least 1 s anyway.
        max_seconds=max(1, ss // sr),
    )
    print(
        f"[startup] {v.key} ready in {time.time() - t0:.1f}s · "
        f"sr={sr} · sample_size={ss} (~{ss // sr}s max)",
        flush=True,
    )
    return loaded


# All variants are loaded once at import time, in VARIANTS order, keyed by
# Variant.key. Loop temporaries stay local to _load_variant instead of
# leaking into module globals.
LOADED: dict[str, LoadedVariant] = {v.key: _load_variant(v) for v in VARIANTS}
VARIANT_CHOICES = [(v.label, v.key) for v in VARIANTS]
# Samplers valid for rf_denoiser diffusion objective (the SA3 family).
SAMPLERS = ["pingpong", "euler", "rk4", "dpmpp"]
| # --------------------------------------------------------------------------- | |
| # Spectrogram helper (Mel; adapted from the reference repo's aeiou.py) | |
| # --------------------------------------------------------------------------- | |
def _power_to_db(spec: np.ndarray, amin: float = 1e-10) -> np.ndarray:
    """Express a power spectrogram in decibels (10·log10), flooring at ``amin``.

    The floor keeps zero-power bins from producing ``-inf``.
    """
    floored = np.maximum(spec, amin)
    return np.log10(floored) * 10.0
def audio_spectrogram_image(
    waveform: torch.Tensor,
    sample_rate: int,
    db_range=(35, 120),
    figsize=(5, 4),
) -> Image.Image:
    """Plot the Mel spectrogram of the first channel of ``waveform``.

    A 1-D ``waveform`` is treated as a single channel. The 128-band power
    spectrogram (n_fft=1024, hop=512) is converted to dB and drawn with the
    colour scale pinned to ``db_range[0]..db_range[1]``. The rendered figure
    is returned as a PIL image built from the Agg RGBA buffer.
    """
    channels_first = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform
    fft_size = 1024
    mel_transform = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=fft_size,
        win_length=None,
        hop_length=fft_size // 2,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm="slaney",
        onesided=True,
        n_mels=128,
        mel_scale="htk",
    )
    # Index 0 keeps only the first (left) channel.
    left_power = mel_transform(channels_first.float())[0]

    figure = Figure(figsize=figsize, dpi=100)
    agg = FigureCanvasAgg(figure)
    axes = figure.add_subplot()
    axes.imshow(
        _power_to_db(left_power.numpy()),
        origin="lower",
        aspect="auto",
        vmin=db_range[0],
        vmax=db_range[1],
    )
    axes.set_title("MelSpectrogram")
    axes.set_xlabel("frame")
    axes.set_ylabel("mel bins (log freq)")
    agg.draw()
    rgba = np.asarray(agg.buffer_rgba())
    return Image.fromarray(rgba)
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
def _gradio_audio_to_tensor(
    audio_in: Optional[Tuple[int, np.ndarray]],
) -> Optional[Tuple[int, torch.Tensor]]:
    """Convert a gr.Audio (numpy) value to ``(sr, torch.Tensor[C, N])``.

    This is the shape ``generate_diffusion_cond_inpaint`` expects. Mono and
    stereo are both accepted.

    * Signed integer PCM is scaled by ``1 / iinfo.max``.
    * Unsigned integer PCM (offset-binary, e.g. 8-bit WAV) is re-centred on
      its mid-point before scaling, so silence maps to 0.0.
    * Float input is only cast to float32.
    * A 2-D ``(N, C)`` array is transposed to ``(C, N)``.

    Returns ``None`` when there is no audio (``None`` value or empty array).
    """
    if audio_in is None:
        return None
    sr, arr = audio_in
    if arr is None:
        return None
    arr = np.asarray(arr)
    if arr.size == 0:
        return None

    kind = arr.dtype.kind
    if kind == "i":
        arr = arr.astype(np.float32) / float(np.iinfo(arr.dtype).max)
    elif kind == "u":
        # Offset-binary: the mid-point (e.g. 128 for uint8) is silence.
        mid = (float(np.iinfo(arr.dtype).max) + 1.0) / 2.0
        arr = (arr.astype(np.float32) - mid) / mid
    else:
        arr = arr.astype(np.float32)

    if arr.ndim == 1:
        arr = arr[None, :]  # (1, N)
    elif arr.shape[0] > arr.shape[1]:
        # gr.Audio returns (N, C); transpose to (C, N)
        arr = arr.T
    return int(sr), torch.from_numpy(np.ascontiguousarray(arr))
def _tensor_to_wav(
    output: torch.Tensor,
    sample_rate: int,
    duration_seconds: Optional[int],
    out_dir: Optional[str] = None,
) -> Tuple[str, torch.Tensor]:
    """Peak-normalise a ``(B, C, N)`` generation to int16 and write a WAV.

    The batch is concatenated along time, giving ``(C, B*N)``. The audio is
    scaled so its absolute peak is full-scale and optionally cut to the first
    ``duration_seconds`` seconds. It is written as 16-bit PCM to
    ``<out_dir or a fresh temp dir>/sa3.wav``.

    Returns:
        ``(path, int16_tensor)``, where the tensor is ``(C, samples)`` on CPU.
    """
    audio = rearrange(output, "b d n -> d (b n)").to(torch.float32)
    # Take the peak in float32: the model runs in fp16, where the 1e-9 floor
    # underflows to 0. An all-silent output would then divide 0/0 and
    # turn into NaN-derived garbage samples.
    peak = audio.abs().max().clamp(min=1e-9)
    audio = (
        audio.div(peak)
        .clamp(-1, 1)
        .mul(32767)
        .to(torch.int16)
        .cpu()
    )
    if duration_seconds is not None:
        audio = audio[:, : int(duration_seconds) * sample_rate]
    target_dir = out_dir or tempfile.mkdtemp()
    out_path = os.path.join(target_dir, "sa3.wav")
    sf.write(out_path, audio.numpy().T, sample_rate, subtype="PCM_16")
    return out_path, audio
def _run_inference(
    variant_key: str,
    prompt: str,
    negative_prompt: str = "",
    duration: int = 60,
    steps: int = 8,
    cfg_scale: float = 1.0,
    sampler_type: str = "pingpong",
    seed: int = 0,
    sigma_max: float = 1.0,
    apg_scale: float = 1.0,
    duration_padding_sec: float = 6.0,
    cut_to_seconds_total: bool = True,
    init_audio: Optional[Tuple[int, np.ndarray]] = None,
    init_noise_level: float = 0.9,
    inpaint_audio: Optional[Tuple[int, np.ndarray]] = None,
    mask_start_sec: float = 0.0,
    mask_end_sec: float = 0.0,
    preview_every: int = 0,
    return_spectrogram: bool = True,
    progress: gr.Progress = gr.Progress(),
):
    """Run one conditioned generation with the chosen variant and save a WAV.

    Shared implementation behind the Simple-tab and Advanced-tab handlers.
    ``duration`` is first clamped to ``[1, max_seconds]`` of the selected
    variant. It is then used as ``seconds_total`` conditioning and, when
    ``cut_to_seconds_total`` is set, as the length the output is cut to.

    Args:
        variant_key: Key into ``LOADED``.
        prompt: Text prompt; must be non-blank after stripping.
        negative_prompt: Optional negative prompt; ignored when blank.
        seed: Values <= 0 (or empty/None) are replaced by -1, which the UI
            labels as "random".
        init_audio, inpaint_audio: ``gr.Audio`` numpy values ``(sr, array)``
            or ``None``. They are resampled to the model rate in fp32, then
            cast to the model dtype.
        mask_start_sec, mask_end_sec: Inpaint mask bounds in seconds. Start
            is floored at 0 and end is capped at ``duration``. The mask is
            forwarded only when inpaint audio is given and end > start.
        preview_every: When > 0, a spectrogram of the sampler's denoised
            estimate is rendered at every step index that is a multiple of N.
        steps, cfg_scale, sampler_type, sigma_max, apg_scale,
        duration_padding_sec, init_noise_level: forwarded to
            ``generate_diffusion_cond_inpaint``.

    Returns:
        ``(audio_path, [spectrogram_img, *previews])`` when
        ``return_spectrogram`` is True, else just ``audio_path``.

    Raises:
        gr.Error: on an empty prompt or an unknown ``variant_key``.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    if variant_key not in LOADED:
        raise gr.Error(f"Unknown variant {variant_key!r}.")
    lv = LOADED[variant_key]
    # Keep the request within what this variant's sample_size can hold.
    duration = max(1, min(int(duration), lv.max_seconds))
    progress(0.05, desc=f"[{variant_key}] preparing conditioning")
    conditioning = [{"prompt": prompt, "seconds_total": int(duration)}]
    negative_conditioning = None
    neg = (negative_prompt or "").strip()
    if neg:
        negative_conditioning = [{"prompt": neg, "seconds_total": int(duration)}]
    # The pretransform encoder is fp16 (we cast the whole model at startup),
    # but prepare_audio's torchaudio Resample uses an fp32 kernel. Pre-resample
    # in fp32 here so prepare_audio's resample is a no-op, then cast to the
    # model dtype so the encoder doesn't see a dtype mismatch.
    model_dtype = next(lv.model.parameters()).dtype
    def _prep(tup):
        # (sr, tensor) -> (model sample rate, tensor in model dtype), or None.
        if tup is None:
            return None
        sr, t = tup
        t = t.float()
        if sr != lv.sample_rate:
            t = torchaudio.functional.resample(t, sr, lv.sample_rate)
        return lv.sample_rate, t.to(model_dtype)
    init_audio_t = _prep(_gradio_audio_to_tensor(init_audio))
    inpaint_audio_t = _prep(_gradio_audio_to_tensor(inpaint_audio))
    # Inpaint mask: only enabled when inpaint_audio was provided AND
    # mask_end > mask_start. init_audio on its own does NOT enable the mask.
    mask_start = max(0.0, float(mask_start_sec))
    mask_end = min(float(duration), float(mask_end_sec))
    use_mask = (
        inpaint_audio_t is not None
        and mask_end > mask_start
    )
    # Non-positive / empty seeds become -1 (the UI's "random" sentinel).
    seed_val = int(seed) if seed and int(seed) > 0 else -1
    # Filled by the sampler callback below; captions are "Step {i + 1}".
    preview_images: list = []
    callback = None
    if preview_every and int(preview_every) > 0:
        every = int(preview_every)
        def _cb(info):
            # NOTE(review): assumes the sampler calls back with a dict holding
            # the step index "i" and the current "denoised" estimate
            # (presumably latent when a pretransform exists) — confirm.
            i = info["i"]
            if i % every != 0:
                return
            denoised = info["denoised"]
            try:
                if lv.model.pretransform is not None:
                    denoised = lv.model.pretransform.decode(denoised)
                d = rearrange(denoised, "b d n -> d (b n)")
                # Hard-clipped, not peak-normalised like the final output.
                d = d.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
                img = audio_spectrogram_image(d, sample_rate=lv.sample_rate)
                preview_images.append((img, f"Step {i + 1}"))
            except Exception as e:
                # Best-effort: a failed preview must not abort sampling.
                print(f"[preview] skipped step {i}: {e}", flush=True)
        callback = _cb
    gen_kwargs: dict = dict(
        steps=int(steps),
        cfg_scale=float(cfg_scale),
        conditioning=conditioning,
        negative_conditioning=negative_conditioning,
        sample_size=lv.sample_size,
        sampler_type=sampler_type,
        seed=seed_val,
        device="cuda",
        sigma_max=float(sigma_max),
        apg_scale=float(apg_scale),
        duration_padding_sec=float(duration_padding_sec),
    )
    # Optional inputs are only passed when present so the library defaults
    # apply otherwise.
    if init_audio_t is not None:
        gen_kwargs["init_audio"] = init_audio_t
        gen_kwargs["init_noise_level"] = float(init_noise_level)
    if inpaint_audio_t is not None:
        gen_kwargs["inpaint_audio"] = inpaint_audio_t
    if use_mask:
        gen_kwargs["inpaint_mask_start_seconds"] = mask_start
        gen_kwargs["inpaint_mask_end_seconds"] = mask_end
    if callback is not None:
        gen_kwargs["callback"] = callback
    progress(0.25, desc=f"[{variant_key}] sampling {steps} steps with {sampler_type}")
    t0 = time.time()
    output = generate_diffusion_cond_inpaint(lv.model, **gen_kwargs)
    print(f"[infer/{variant_key}] sampling done in {time.time() - t0:.1f}s", flush=True)
    progress(0.92, desc="Normalising & saving")
    cut_dur = int(duration) if cut_to_seconds_total else None
    out_path, int16_audio = _tensor_to_wav(output, lv.sample_rate, cut_dur)
    if not return_spectrogram:
        return out_path
    # The final spectrogram leads the gallery, followed by any step previews.
    spec_img = audio_spectrogram_image(int16_audio, sample_rate=lv.sample_rate)
    return out_path, [spec_img, *preview_images]
def infer(
    variant_key: str,
    prompt: str,
    duration: int = 60,
    steps: int = 8,
    cfg_scale: float = 1.0,
    sampler_type: str = "pingpong",
    seed: int = 0,
    progress: gr.Progress = gr.Progress(),
):
    """Generate audio for the Simple tab and the cached Examples.

    Only the basic controls are forwarded to ``_run_inference``; every
    advanced option keeps its default. The spectrogram gallery is disabled,
    so the result is just the path of the written WAV file.
    """
    simple_kwargs = {
        "variant_key": variant_key,
        "prompt": prompt,
        "duration": duration,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "sampler_type": sampler_type,
        "seed": seed,
    }
    return _run_inference(
        **simple_kwargs,
        return_spectrogram=False,
        progress=progress,
    )
def infer_advanced(
    variant_key: str,
    prompt: str,
    negative_prompt: str,
    duration: int,
    steps: int,
    cfg_scale: float,
    sampler_type: str,
    seed: int,
    sigma_max: float,
    apg_scale: float,
    duration_padding_sec: float,
    cut_to_seconds_total: bool,
    init_audio: Optional[Tuple[int, np.ndarray]],
    init_noise_level: float,
    inpaint_audio: Optional[Tuple[int, np.ndarray]],
    mask_start_sec: float,
    mask_end_sec: float,
    preview_every: int,
    progress: gr.Progress = gr.Progress(),
):
    """Generate audio for the Advanced tab.

    Every control of the tab is passed straight through to
    ``_run_inference`` with the spectrogram gallery enabled. The result is
    therefore ``(audio_path, [spectrogram_img, *previews])``.
    """
    advanced_kwargs = dict(
        variant_key=variant_key,
        prompt=prompt,
        negative_prompt=negative_prompt,
        duration=duration,
        steps=steps,
        cfg_scale=cfg_scale,
        sampler_type=sampler_type,
        seed=seed,
        sigma_max=sigma_max,
        apg_scale=apg_scale,
        duration_padding_sec=duration_padding_sec,
        cut_to_seconds_total=cut_to_seconds_total,
        init_audio=init_audio,
        init_noise_level=init_noise_level,
        inpaint_audio=inpaint_audio,
        mask_start_sec=mask_start_sec,
        mask_end_sec=mask_end_sec,
        preview_every=preview_every,
    )
    advanced_kwargs["return_spectrogram"] = True
    return _run_inference(progress=progress, **advanced_kwargs)
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# Markdown header rendered at the top of the page.
DESCRIPTION = """
# 🎵 Stable Audio 3
Text-to-audio generation with <a href="https://huggingface.co/collections/stabilityai/stable-audio-3" target="_blank" rel="noopener noreferrer">Stable Audio 3</a>. Pick a variant, write a prompt, hit Generate. Switch to **Advanced** for the full sampler / init-audio / inpainting controls.
"""
# Rows of [variant_key, prompt, duration_seconds] for the Simple tab's
# gr.Examples. Durations above a variant's maximum are clamped at inference
# time by _run_inference.
EXAMPLES = [
    ["medium", "House music that encapsulates the feeling of being at a festival in the sunny weather with all your friends 124 BPM", 60],
    ["small-music", "Cinematic neo-soul groove with electric piano, brushed drums, walking upright bass, smoky vibe 92 BPM", 45],
    ["small-music", "Driving techno track with rolling 16th-note hats, deep sub bass, acid arpeggios building tension 132 BPM", 60],
    ["small-sfx", "Chugging train coming into station with horn", 7],
    ["small-sfx", "Heavy rain on a tin roof with distant thunder rolls", 10],
    ["medium", "Rainy night, lo-fi hip-hop beat with vinyl crackle, mellow piano chords, soft kick and snare 80 BPM", 30],
]
def _variant_change_simple(variant_key: str):
    """Re-target the Simple tab's duration slider and prompt placeholder.

    Returns updates for ``(duration_slider, prompt_textbox)``: the slider's
    maximum and label follow the selected model. The slider resets to the
    variant's default duration, capped at that maximum.
    """
    loaded = LOADED[variant_key]
    cap = loaded.max_seconds
    slider_update = gr.update(
        maximum=cap,
        value=min(loaded.variant.default_duration, cap),
        label=f"Duration (s) · model max {cap}s",
    )
    prompt_update = gr.update(placeholder=loaded.variant.placeholder)
    return slider_update, prompt_update
def _variant_change_advanced(variant_key: str):
    """Re-target the Advanced tab's duration, prompt and inpaint-mask controls.

    Returns updates for ``(seconds_total, prompt, mask_start, mask_end)``:

    * The duration resets to the variant default, capped at the model max.
    * The mask sliders are bounded by the model max.
    * The mask is reset to span 0..duration.
    """
    loaded = LOADED[variant_key]
    cap = loaded.max_seconds
    dur = min(loaded.variant.default_duration, cap)
    mask_cap = float(cap)
    seconds_update = gr.update(
        maximum=cap,
        value=dur,
        label=f"Seconds total · model max {cap}s",
    )
    prompt_update = gr.update(placeholder=loaded.variant.placeholder)
    mask_start_update = gr.update(maximum=mask_cap, value=0.0)
    mask_end_update = gr.update(maximum=mask_cap, value=float(dur))
    return seconds_update, prompt_update, mask_start_update, mask_end_update
# Two-tab UI: "Simple" wires the basic controls to infer(); "Advanced" wires
# the full control set to infer_advanced().
with gr.Blocks(theme=gr.themes.Citrus(), title="Stable Audio 3") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        # -----------------------------------------------------------------
        # Simple tab
        # -----------------------------------------------------------------
        with gr.Tab("Simple"):
            variant = gr.Radio(
                choices=VARIANT_CHOICES,
                value=VARIANTS[0].key,
                label="Model",
            )
            with gr.Row():
                with gr.Column(scale=2):
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder=VARIANTS[0].placeholder,
                        lines=3,
                    )
                    # NOTE(review): unlike _variant_change_simple, the initial
                    # value is not capped at the model max — confirm the
                    # default never exceeds it.
                    duration = gr.Slider(
                        1, LOADED[VARIANTS[0].key].max_seconds,
                        value=VARIANTS[0].default_duration, step=1,
                        label=f"Duration (s) · model max {LOADED[VARIANTS[0].key].max_seconds}s",
                    )
                    with gr.Accordion("Advanced settings", open=False):
                        steps = gr.Slider(1, 50, value=8, step=1, label="Steps")
                        cfg_scale = gr.Slider(0.5, 8.0, value=1.0, step=0.1, label="CFG scale")
                        sampler_type = gr.Dropdown(SAMPLERS, value="pingpong", label="Sampler")
                        # _run_inference maps any seed <= 0 to -1 ("random").
                        seed = gr.Number(value=0, precision=0, label="Seed (0 = random)")
                    run_btn = gr.Button("🎼 Generate", variant="primary", size="lg")
                with gr.Column(scale=1):
                    audio_out = gr.Audio(label="Output", type="filepath", autoplay=True)
            # Example rows fill (variant, prompt, duration); the remaining
            # infer() parameters keep their defaults.
            gr.Examples(
                examples=EXAMPLES,
                inputs=[variant, prompt, duration],
                outputs=[audio_out],
                fn=infer,
                cache_examples=True,
                cache_mode="lazy",
                label="Examples (lazy-cached on first click)",
            )
            variant.change(
                fn=_variant_change_simple,
                inputs=[variant],
                outputs=[duration, prompt],
            )
            run_btn.click(
                fn=infer,
                inputs=[variant, prompt, duration, steps, cfg_scale, sampler_type, seed],
                outputs=[audio_out],
            )
        # -----------------------------------------------------------------
        # Advanced tab — mirrors stable_audio_3/interface/diffusion_cond.py
        # -----------------------------------------------------------------
        with gr.Tab("Advanced"):
            adv_variant = gr.Radio(
                choices=VARIANT_CHOICES,
                value=VARIANTS[0].key,
                label="Model",
            )
            with gr.Row():
                with gr.Column(scale=6):
                    adv_prompt = gr.Textbox(
                        show_label=False,
                        placeholder=VARIANTS[0].placeholder,
                    )
                    adv_negative = gr.Textbox(
                        show_label=False, placeholder="Negative prompt"
                    )
                adv_generate = gr.Button("Generate", variant="primary", scale=1)
            with gr.Row(equal_height=False):
                with gr.Column():
                    # NOTE(review): initial value not capped at the model max
                    # (cf. _variant_change_advanced) — confirm.
                    adv_seconds_total = gr.Slider(
                        minimum=1,
                        maximum=LOADED[VARIANTS[0].key].max_seconds,
                        step=1,
                        value=VARIANTS[0].default_duration,
                        label=f"Seconds total · model max {LOADED[VARIANTS[0].key].max_seconds}s",
                    )
                    with gr.Row():
                        adv_steps = gr.Slider(
                            minimum=1, maximum=500, step=1, value=8, label="Steps"
                        )
                        adv_cfg = gr.Slider(
                            minimum=0.0, maximum=25.0, step=0.1, value=1.0,
                            label="CFG scale",
                        )
                    with gr.Accordion("Sampler params", open=False):
                        with gr.Row():
                            # _run_inference treats 0 as random as well as -1.
                            adv_seed = gr.Number(
                                label="Seed (set to -1 for random seed)",
                                value=-1, precision=0,
                            )
                            adv_sampler = gr.Dropdown(
                                SAMPLERS, label="Sampler type", value="pingpong",
                            )
                            adv_sigma_max = gr.Slider(
                                minimum=0.0, maximum=1.0, step=0.01, value=1.0,
                                label="Sigma max",
                            )
                        with gr.Row():
                            adv_apg = gr.Slider(
                                minimum=0.0, maximum=1.0, step=0.1, value=1.0,
                                label="APG scale", info="1.0=full APG, 0.0=vanilla CFG",
                            )
                            adv_dur_padding = gr.Slider(
                                minimum=0.0, maximum=30.0, step=0.5, value=6.0,
                                label="Duration padding (sec)",
                            )
                    with gr.Accordion("Output params", open=False):
                        with gr.Row():
                            adv_preview_every = gr.Slider(
                                minimum=0, maximum=100, step=1, value=0,
                                label="Spec preview every N steps (0 = off)",
                            )
                            adv_cut_to_total = gr.Checkbox(
                                label="Cut to seconds total", value=True,
                            )
                    with gr.Accordion("Init audio", open=False):
                        adv_init_audio = gr.Audio(
                            label="Init audio",
                            type="numpy",
                        )
                        adv_init_noise = gr.Slider(
                            minimum=0.01, maximum=1.0, step=0.01, value=0.9,
                            label="Init noise level",
                        )
                    with gr.Accordion("Inpainting", open=False):
                        adv_inpaint_audio = gr.Audio(
                            label="Inpaint audio",
                            type="numpy",
                        )
                        # Mask bounds in seconds; only used when inpaint
                        # audio is supplied and end > start.
                        adv_mask_start = gr.Slider(
                            minimum=0.0,
                            maximum=float(LOADED[VARIANTS[0].key].max_seconds),
                            step=0.1, value=0.0, label="Mask start (sec)",
                        )
                        adv_mask_end = gr.Slider(
                            minimum=0.0,
                            maximum=float(LOADED[VARIANTS[0].key].max_seconds),
                            step=0.1, value=0.0, label="Mask end (sec)",
                        )
                with gr.Column():
                    adv_audio_out = gr.Audio(
                        label="Output audio", type="filepath", autoplay=False,
                        sources=[],
                    )
                    # Receives [final spectrogram, *(preview, caption)].
                    adv_spec_gallery = gr.Gallery(
                        label="Output spectrogram", show_label=True, columns=2,
                    )
                    send_to_init_btn = gr.Button("Send to init audio")
                    send_to_inpaint_btn = gr.Button("Send to inpaint audio")
            # Identity handlers copy the output file into the init/inpaint
            # inputs. NOTE(review): relies on a numpy-type gr.Audio accepting
            # a filepath as an output value — confirm for the Gradio version.
            send_to_init_btn.click(
                fn=lambda a: a, inputs=[adv_audio_out], outputs=[adv_init_audio]
            )
            send_to_inpaint_btn.click(
                fn=lambda a: a, inputs=[adv_audio_out], outputs=[adv_inpaint_audio]
            )
            # Keep the inpaint mask bounded by the current duration.
            def _update_mask_max(seconds_total):
                """Cap both mask sliders at the duration (min 1 s); reset end to it."""
                m = max(float(seconds_total), 1.0)
                return (
                    gr.update(maximum=m),
                    gr.update(maximum=m, value=m),
                )
            adv_seconds_total.change(
                _update_mask_max,
                inputs=[adv_seconds_total],
                outputs=[adv_mask_start, adv_mask_end],
            )
            adv_variant.change(
                fn=_variant_change_advanced,
                inputs=[adv_variant],
                outputs=[adv_seconds_total, adv_prompt, adv_mask_start, adv_mask_end],
            )
            # Input order must match infer_advanced's positional parameters.
            adv_generate.click(
                fn=infer_advanced,
                inputs=[
                    adv_variant,
                    adv_prompt,
                    adv_negative,
                    adv_seconds_total,
                    adv_steps,
                    adv_cfg,
                    adv_sampler,
                    adv_seed,
                    adv_sigma_max,
                    adv_apg,
                    adv_dur_padding,
                    adv_cut_to_total,
                    adv_init_audio,
                    adv_init_noise,
                    adv_inpaint_audio,
                    adv_mask_start,
                    adv_mask_end,
                    adv_preview_every,
                ],
                outputs=[adv_audio_out, adv_spec_gallery],
            )
# Launch the Blocks app when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()