Spaces:
Running
Running
| """Gradio demo for LocalVQE — real-time AEC + NS + dereverb. | |
| Loads released model versions side-by-side and exposes a runtime | |
| selector so you can A/B them on the same clip: | |
| v1.2 — newest, default. 1.3 M params. SiLU activation + dmax 64 | |
| (1024 ms echo-search window) + wider clean-pool DNSMOS | |
| filter + phone-bandwidth + codec round-trip aug. Adds | |
| ~+0.3 echo_mos / ~+1 dB ERLE on AEC blind FE-ST vs v1.1. | |
| Path resolves from LOCALVQE_V12_CKPT, else HF. | |
| v1.1 — previous release. 1.3 M params. ReLU6, pre-norm | |
| CausalGroupNorm, STFT-256 codec. Fixes intermittent | |
| crackling that v1 produced under heavy background noise. | |
| Path resolves from LOCALVQE_V11_CKPT, else HF. | |
| v1 — original release. Path resolves from LOCALVQE_V1_CKPT | |
| (or LOCALVQE_LOCAL_CKPT for backward compat), else HF. | |
| If a checkpoint isn't reachable that entry is hidden from the | |
| selector. Each architecture lives in an independent Python | |
| package so they can be loaded simultaneously without import | |
| collisions: | |
| v1 → space/localvqe_model/ | |
| v1.1 → space/localvqe_v11/ | |
| v1.2 → space/localvqe_v12/ | |
| """ | |
| import hashlib | |
| import os | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| from scipy.signal import resample_poly | |
| # v1 (original release) — namespace 'localvqe_model' | |
| from localvqe_model import ( | |
| Config as ConfigV1, | |
| LocalVQE as LocalVQEv1, | |
| apply_ckpt_model_config as apply_ckpt_v1, | |
| load_checkpoint as load_ckpt_v1, | |
| ) | |
| # v1.1 / v1.2 — bundled in this directory. Imported on demand to keep | |
| # startup time low when those versions aren't configured. | |
def _import_v11():
    """Lazily import the v1.1 package.

    Returns (Config, LocalVQE, apply_ckpt_model_config, load_checkpoint)
    from `localvqe_v11`, deferring the import so startup stays fast when
    v1.1 isn't configured.
    """
    import localvqe_v11 as pkg
    return (
        pkg.Config,
        pkg.LocalVQE,
        pkg.apply_ckpt_model_config,
        pkg.load_checkpoint,
    )
def _import_v12():
    """Lazily import the v1.2 package.

    Returns (Config, LocalVQE, apply_ckpt_model_config, load_checkpoint)
    from `localvqe_v12`, deferring the import so startup stays fast when
    v1.2 isn't configured.
    """
    import localvqe_v12 as pkg
    return (
        pkg.Config,
        pkg.LocalVQE,
        pkg.apply_ckpt_model_config,
        pkg.load_checkpoint,
    )
SR = 16000  # all model I/O is 16 kHz mono; _load_mono_16k resamples to this
HF_REPO_ID = "LocalAI-io/LocalVQE"  # Hub repo holding the published weights
HF_V1_FILE = "localvqe-v1-1.3M.pt"
HF_V11_FILE = "localvqe-v1.1-1.3M.pt"
HF_V12_FILE = "localvqe-v1.2-1.3M.pt"
# Bundled demo clips live in examples/ next to this file.
EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
def _sha256(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    Streams the file in 1 MiB chunks so arbitrarily large checkpoints
    never need to fit in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while chunk := fh.read(1 << 20):
            digest.update(chunk)
    return digest.hexdigest()
| def _resolve_v1_ckpt() -> str | None: | |
| # Backward-compat: LOCALVQE_LOCAL_CKPT used to be the way to override. | |
| for env in ("LOCALVQE_V1_CKPT", "LOCALVQE_LOCAL_CKPT"): | |
| v = os.environ.get(env) | |
| if v: | |
| return v | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V1_FILE) | |
| except Exception as e: | |
| print(f"v1 unavailable from HF ({e})") | |
| return None | |
| def _resolve_v11_ckpt() -> str | None: | |
| v = os.environ.get("LOCALVQE_V11_CKPT") | |
| if v: | |
| return v | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V11_FILE) | |
| except Exception: | |
| return None | |
| def _resolve_v12_ckpt() -> str | None: | |
| v = os.environ.get("LOCALVQE_V12_CKPT") | |
| if v: | |
| return v | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V12_FILE) | |
| except Exception: | |
| return None | |
| def _resolve_v121_ckpt() -> str | None: | |
| # No HF fallback yet — v1.2.1 isn't published. Set LOCALVQE_V121_CKPT | |
| # in docker-compose.yml (defaults to checkpoints/release/...) to load | |
| # the local finetuned copy. | |
| return os.environ.get("LOCALVQE_V121_CKPT") or None | |
| def _resolve_v12a_ckpt() -> str | None: | |
| # v1.2a — v9 (widened DRR + longer RIRs + global gain) from-scratch | |
| # epoch 14. Architecture identical to v1.2/v1.2.1 (uses localvqe_v12 | |
| # package). No HF publish yet. | |
| return os.environ.get("LOCALVQE_V12A_CKPT") or None | |
| def _resolve_v12b_ckpt() -> str | None: | |
| # v1.2b — v10 (v1.2 + audible reverb + 80/20 conference mix + | |
| # pipeline pop fixes, no experimental augs) from-scratch e19. | |
| # Architecture identical to v1.2 (uses localvqe_v12 package). | |
| return os.environ.get("LOCALVQE_V12B_CKPT") or None | |
| def _resolve_v12c_ckpt() -> str | None: | |
| # v1.2c — v11 (v10 + level-invariance mic-gain aug, | |
| # clean_attenuation_factor=1.0) from-scratch e17. Addresses | |
| # low-SNR wobble near noise floor. Architecture identical to | |
| # v1.2 (uses localvqe_v12 package). | |
| return os.environ.get("LOCALVQE_V12C_CKPT") or None | |
| def _resolve_v12d_ckpt() -> str | None: | |
| # v1.2d — v11_refine e22 (10-epoch low-LR cosine continuation | |
| # of v1.2c from v11 e20, peak LR 1e-4). Blind eval beats | |
| # v1.2c on FE-ST echo_mos (+0.31) and NE-ST deg_mos (+0.04) | |
| # while recovering 2.4 dB of FE-ST ERLE. Architecture | |
| # identical to v1.2 (uses localvqe_v12 package). | |
| return os.environ.get("LOCALVQE_V12D_CKPT") or None | |
def _build_v1():
    """Load and prepare the v1 model.

    Returns (model, info) where info carries provenance metadata
    (source path, sha256, parameter count, UI label), or (None, None)
    when no checkpoint could be resolved.
    """
    ckpt_path = _resolve_v1_ckpt()
    if ckpt_path is None:
        return None, None
    cfg = ConfigV1()
    # Read the checkpoint once just to copy its stored model config into
    # cfg, then drop it — load_ckpt_v1 re-reads the file for the weights.
    peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_ckpt_v1(peek, cfg)
    del peek
    model = LocalVQEv1.from_config(cfg).to("cpu")
    load_ckpt_v1(ckpt_path, model)
    # Fold the trained AlignBlock softmax temperature (a buffer in the
    # checkpoint) into the smoothing conv — without this, eval runs at
    # the default 1.0 instead of the trained value, losing ~5 dB ERLE.
    model.align.fold_temperature()
    model.eval()
    info = {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in model.parameters()),
        # Fixed: was "v1 (previous release)", which clashed with v1.1's
        # label and contradicts the module docstring (v1 is the original).
        "label": "v1 (original release)",
    }
    print(f"v1 loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v11():
    """Load and prepare the v1.1 model; (None, None) if no checkpoint."""
    path = _resolve_v11_ckpt()
    if path is None:
        return None, None
    ConfigV11, LocalVQEv11, apply_ckpt_v11, load_ckpt_v11 = _import_v11()
    cfg = ConfigV11()
    # Read the checkpoint once just to transfer its stored model config.
    state = torch.load(path, map_location="cpu", weights_only=False)
    apply_ckpt_v11(state, cfg)
    del state
    net = LocalVQEv11.from_config(cfg).to("cpu")
    load_ckpt_v11(path, net)
    # Bake the trained AlignBlock softmax temperature into the model so
    # eval uses the trained value rather than the default.
    net.align.fold_temperature()
    net.eval()
    meta = {
        "source": path,
        "sha256": _sha256(path),
        "n_params": sum(p.numel() for p in net.parameters()),
        "label": "v1.1 (previous release)",
    }
    print(f"v1.1 loaded: {meta['n_params']:,} params sha={meta['sha256'][:16]}… "
          f"src={path}")
    return net, meta
def _build_v12_like(ckpt_path, label):
    """Shared builder for v1.2 and v1.2.1 — same architecture, same package."""
    ConfigV12, LocalVQEv12, apply_ckpt_v12, load_ckpt_v12 = _import_v12()
    cfg = ConfigV12()
    # Read the checkpoint once just to transfer its stored model config.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_ckpt_v12(state, cfg)
    del state
    net = LocalVQEv12.from_config(cfg).to("cpu")
    load_ckpt_v12(ckpt_path, net)
    # Bake the trained AlignBlock softmax temperature into the model.
    net.align.fold_temperature()
    net.eval()
    meta = {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in net.parameters()),
        "label": label,
    }
    return net, meta
def _build_v12():
    """Load the published v1.2 release; (None, None) if no checkpoint."""
    path = _resolve_v12_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(path, "v1.2 (current release)")
    print(f"v1.2 loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v121():
    """Load the v1.2.1 dev finetune; (None, None) if no checkpoint."""
    path = _resolve_v121_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(path, "v1.2.1 (movement-aug finetune)")
    print(f"v1.2.1 loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12a():
    """Load the v1.2a dev checkpoint; (None, None) if not configured."""
    path = _resolve_v12a_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2a (widened DRR + longer RIRs, from-scratch)")
    print(f"v1.2a loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12b():
    """Load the v1.2b dev checkpoint; (None, None) if not configured."""
    path = _resolve_v12b_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2b (v10: audible reverb + conference mix + pop fixes)")
    print(f"v1.2b loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12c():
    """Load the v1.2c dev checkpoint; (None, None) if not configured."""
    path = _resolve_v12c_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2c (v11: level-invariance mic-gain on v1.2b base)")
    print(f"v1.2c loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12d():
    """Load the v1.2d dev checkpoint; (None, None) if not configured."""
    path = _resolve_v12d_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2d (v11_refine e22: low-LR cosine polish of v1.2c)")
    print(f"v1.2d loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
# Build every version eagerly at import time. Each builder returns
# (None, None) when its checkpoint cannot be resolved, which later
# hides that entry from the UI selector.
MODEL_V1, INFO_V1 = _build_v1()
MODEL_V11, INFO_V11 = _build_v11()
MODEL_V12, INFO_V12 = _build_v12()
MODEL_V121, INFO_V121 = _build_v121()
MODEL_V12A, INFO_V12A = _build_v12a()
MODEL_V12B, INFO_V12B = _build_v12b()
MODEL_V12C, INFO_V12C = _build_v12c()
MODEL_V12D, INFO_V12D = _build_v12d()
MODELS: dict[str, object] = {}
INFOS: dict[str, dict] = {}
# Register each successfully-built version; entries whose checkpoint was
# unresolvable (model is None) stay out so the UI only offers loadable
# models. Data-driven loop replaces eight copy-pasted if-blocks.
for _key, _model, _info in (
    ("v1", MODEL_V1, INFO_V1),
    ("v1.1", MODEL_V11, INFO_V11),
    ("v1.2", MODEL_V12, INFO_V12),
    ("v1.2.1", MODEL_V121, INFO_V121),
    ("v1.2a", MODEL_V12A, INFO_V12A),
    ("v1.2b", MODEL_V12B, INFO_V12B),
    ("v1.2c", MODEL_V12C, INFO_V12C),
    ("v1.2d", MODEL_V12D, INFO_V12D),
):
    if _model is not None:
        MODELS[_key] = _model
        INFOS[_key] = _info
if not MODELS:
    raise RuntimeError(
        "No model could be loaded. Set LOCALVQE_V1_CKPT, "
        "LOCALVQE_V11_CKPT, LOCALVQE_V12_CKPT, LOCALVQE_V121_CKPT, "
        "LOCALVQE_V12A_CKPT, LOCALVQE_V12B_CKPT, LOCALVQE_V12C_CKPT, "
        "or LOCALVQE_V12D_CKPT, or ensure HF access for the "
        "published files."
    )
# Newest available build wins: dev checkpoints (v1.2d … v1.2.1) take
# precedence over published releases when their env vars are set.
# Replaces an eight-way chained conditional expression with a lookup.
_DEFAULT_PREFERENCE = (
    "v1.2d", "v1.2c", "v1.2b", "v1.2a", "v1.2.1", "v1.2", "v1.1", "v1",
)
DEFAULT_MODEL_KEY = next(
    (k for k in _DEFAULT_PREFERENCE if k in MODELS), "v1"
)
# Dev mode: shows the diagnostic-source dropdown and mask-smoother
# accordion in the UI. Auto-on locally, auto-off on HF Spaces (which
# always sets `SPACE_ID`). Either can be overridden by setting
# LOCALVQE_DEV_MODE=1 (force on) or =0 (force off).
def _dev_mode() -> bool:
    """Resolve dev mode: explicit LOCALVQE_DEV_MODE wins, else on outside Spaces."""
    forced = os.environ.get("LOCALVQE_DEV_MODE")
    if forced == "1":
        return True
    if forced == "0":
        return False
    return "SPACE_ID" not in os.environ


DEV_MODE = _dev_mode()
if DEV_MODE:
    print("DEV_MODE=on (debug accordions visible). Set LOCALVQE_DEV_MODE=0 to hide.")
def _load_mono_16k(path: str) -> np.ndarray:
    """Read an audio file, downmix to mono, and resample to 16 kHz float32."""
    audio, rate = sf.read(path, dtype="float32", always_2d=False)
    if audio.ndim == 2:
        # Stereo (or multi-channel) → simple channel average.
        audio = audio.mean(axis=1)
    if rate != SR:
        from math import gcd
        divisor = gcd(rate, SR)
        audio = resample_poly(audio, SR // divisor, rate // divisor).astype(np.float32)
    return audio
# Debug / diagnostic helpers live in `_debug.py`, which is excluded
# from the HuggingFace Spaces deploy. When this file is missing the
# app silently degrades: no debug accordions, no diagnostic-source
# branches, just the standard model forward.
try:
    import _debug as _dbg
except ImportError:
    _dbg = None
    DEBUG_AVAILABLE = False
else:
    DEBUG_AVAILABLE = True
def _noise_gate(x: np.ndarray, threshold_dbfs: float) -> np.ndarray:
    """Zero out 10 ms frames whose RMS sits below `threshold_dbfs` dBFS.

    Frames are 160 samples at 16 kHz — short enough that speech bursts
    aren't truncated, long enough that a single out-of-band sample inside
    an active region doesn't get muted. The partial final frame (if any)
    is passed through untouched; inputs shorter than one frame are
    returned unchanged.
    """
    frame_len = 160
    n_frames = len(x) // frame_len
    if n_frames == 0:
        return x
    head_len = n_frames * frame_len
    frames = x[:head_len].reshape(n_frames, frame_len).astype(np.float32)
    rms = np.sqrt(np.mean(frames * frames, axis=-1) + 1e-12)
    level_db = 20.0 * np.log10(rms + 1e-12)
    mask = (level_db > threshold_dbfs).astype(np.float32)
    gated = (frames * mask[:, None]).reshape(-1)
    return np.concatenate([gated, x[head_len:]]).astype(x.dtype)
def enhance(mic_path: str, ref_path: str,
            model_choice: str = DEFAULT_MODEL_KEY,
            gate_enabled: bool = False,
            gate_threshold_db: float = -45.0,
            smoother_mode: str = "off",
            smoother_attack_db: float = 12.0,
            smoother_release_db: float = 1.0,
            smoother_ema_alpha: float = 0.7,
            smoother_floor_db: float = 20.0,
            smoother_median_k: int = 3,
            debug_source: str = "enhanced",
            f_smooth_kernel: int = 31,
            f_smooth_mode: str = "median") -> tuple[int, np.ndarray]:
    """Run one (mic, ref) clip through the selected model.

    Returns (SR, int16 waveform) for the gr.Audio output. A missing
    ref_path is treated as silence (pure noise-suppression mode). The
    smoother_* / debug_source / f_smooth_* knobs only take effect when
    `_debug.py` is importable (DEBUG_AVAILABLE); in production they
    arrive as neutral gr.State defaults and are ignored.

    Raises:
        gr.Error: when no mic clip is given or `model_choice` isn't loaded.
    """
    if mic_path is None:
        raise gr.Error("Upload or pick a mic recording first.")
    if model_choice not in MODELS:
        raise gr.Error(f"Model {model_choice!r} not loaded. Available: {list(MODELS)}")
    model = MODELS[model_choice]
    mic = _load_mono_16k(mic_path)
    if ref_path is None:
        # No reference supplied — feed silence as the far end.
        ref = np.zeros_like(mic)
    else:
        ref = _load_mono_16k(ref_path)
    # Zero-pad the shorter signal so mic and ref share one length n.
    n = max(len(mic), len(ref))
    if len(mic) < n:
        mic = np.pad(mic, (0, n - len(mic)))
    if len(ref) < n:
        ref = np.pad(ref, (0, n - len(ref)))
    mic_t = torch.from_numpy(mic).unsqueeze(0)  # add batch dimension
    ref_t = torch.from_numpy(ref).unsqueeze(0)
    with torch.no_grad():
        if DEBUG_AVAILABLE and debug_source != "enhanced":
            # Diagnostic tap: _debug substitutes an intermediate signal
            # for the normal model output.
            enc = _dbg.apply_debug_source(
                model, mic_t, ref_t, debug_source,
                smoother_ema_alpha=smoother_ema_alpha,
                f_smooth_kernel=f_smooth_kernel,
                f_smooth_mode=f_smooth_mode,
            )
        else:
            enc = model(mic_t, ref_t)
        if (DEBUG_AVAILABLE and smoother_mode != "off"
                and debug_source not in ("passthrough", "bypass_ccm")):
            # Optional mask smoothing (debug-only); skipped for sources
            # that bypass the model's masking stage entirely.
            enc = _dbg.apply_smoother(
                enc, model.encoder(mic_t), smoother_mode,
                attack_db=smoother_attack_db,
                release_db=smoother_release_db,
                ema_alpha=smoother_ema_alpha,
                floor_db=smoother_floor_db,
                median_k=smoother_median_k,
            )
        # Decode back to a waveform at the padded clip length.
        enh = model.decoder(enc.float(), length=n)
    out = enh[0].cpu().numpy()
    # Peak protection: rescale the whole clip if it would exceed ±0.95.
    peak = float(np.abs(out).max())
    if peak > 0.95:
        out = out / peak * 0.95
    # Optional residual-echo gate: silence frames whose RMS sits below
    # `gate_threshold_db` dBFS. Off by default so listeners can A/B
    # against the raw model output via the slider.
    if gate_enabled:
        out = _noise_gate(out, gate_threshold_db)
    # Convert to int16 ourselves: Gradio's gr.Audio output otherwise
    # peak-normalises float arrays via convert_to_16_bit_wav (data /=
    # np.abs(data).max(); * 32767), which amplifies the cancelled-echo
    # residual on AEC-heavy clips by 1000×+ and makes it sound like
    # the model isn't suppressing anything. Returning int16 preserves
    # the true (quiet) loudness so listeners hear the actual output.
    out_i16 = np.clip(out * 32767, -32768, 32767).astype(np.int16)
    return SR, out_i16
# Bundled demo clips, each a (mic, far-end reference) path pair fed to
# gr.Examples; see the Examples label in the UI for the full blurb.
EXAMPLES = [
    [  # near-end single-talk + heavy noise (5 dB SNR, pure NS)
        str(EXAMPLES_DIR / "ne_st_noisy_mic.wav"),
        str(EXAMPLES_DIR / "ne_st_noisy_ref.wav"),
    ],
    [  # near-end single-talk + light noise (20 dB SNR)
        str(EXAMPLES_DIR / "ne_st_clean_mic.wav"),
        str(EXAMPLES_DIR / "ne_st_clean_ref.wav"),
    ],
    [  # far-end single-talk (pure AEC)
        str(EXAMPLES_DIR / "fe_st_mic.wav"),
        str(EXAMPLES_DIR / "fe_st_ref.wav"),
    ],
    [  # far-end with brief near-end overlap (AEC + NE preservation)
        str(EXAMPLES_DIR / "fe_st2_mic.wav"),
        str(EXAMPLES_DIR / "fe_st2_ref.wav"),
    ],
    [  # double-talk (AEC while near-end is also talking)
        str(EXAMPLES_DIR / "dt_mic.wav"),
        str(EXAMPLES_DIR / "dt_ref.wav"),
    ],
]
| DESCRIPTION = """ | |
| **LocalVQE** is a ~1 M-parameter open-source model that cleans up a | |
| microphone signal on a voice call: it cancels the remote participant's | |
| voice being picked up again (echo), suppresses background noise, and | |
| removes reverberation — all in a single causal pass on CPU. | |
| Provide two inputs: | |
| - **Mic**: the raw microphone recording (what the far end would hear | |
| without any processing). | |
| - **Far-end reference**: the audio being played out of your speakers. | |
| For a pure noise-suppression test (no speaker playback), upload | |
| silence or leave empty. | |
| Try the bundled examples first — they cover heavy and light | |
| near-end noise (NE-ST mixed with DNS5 background at 5 dB and 20 dB | |
| SNR), a clean far-end single-talk clip, a far-end clip with some | |
| near-end overlap (mislabelled in the source corpus, but a useful | |
| test of AEC + near-end preservation together), and a double-talk | |
| clip — all from the ICASSP 2022 AEC Challenge blind set. | |
| Weights: [LocalAI-io/LocalVQE](https://huggingface.co/LocalAI-io/LocalVQE) · | |
| Code: [github.com/localai-org/LocalVQE](https://github.com/localai-org/LocalVQE) | |
| """ | |
| with gr.Blocks(title="LocalVQE Demo") as demo: | |
| gr.Markdown("# LocalVQE: real-time AEC + noise suppression + dereverb") | |
| gr.Markdown(DESCRIPTION) | |
| with gr.Row(): | |
| mic_in = gr.Audio(label="Mic (microphone recording)", type="filepath") | |
| ref_in = gr.Audio(label="Far-end reference (speaker playback)", type="filepath") | |
| model_choice = gr.Radio( | |
| choices=list(MODELS.keys()), | |
| value=DEFAULT_MODEL_KEY, | |
| label="Model", | |
| info=( | |
| "v1.2 is the current release. SiLU activation + 1024 ms " | |
| "echo-search window + wider clean-pool DNSMOS filter + " | |
| "phone-bandwidth + codec round-trip aug. Adds ~+0.3 " | |
| "echo_mos and ~+1 dB ERLE on the AEC blind set vs v1.1. " | |
| "v1.1 / v1 are kept for A/B. Same param count (1.3 M). " | |
| "Switch and re-run on the same clip to compare." | |
| ), | |
| ) if len(MODELS) > 1 else gr.State(DEFAULT_MODEL_KEY) | |
| with gr.Row(): | |
| gate_enabled = gr.Checkbox( | |
| label="Residual-echo gate", | |
| value=False, | |
| info=( | |
| "Post-process the enhanced output: silence any 10 ms frame " | |
| "whose RMS falls below the threshold. Cleans up the quiet " | |
| "residual you'd hear during far-end-only stretches; will " | |
| "also mute genuinely quiet speech below the threshold." | |
| ), | |
| ) | |
| gate_threshold_db = gr.Slider( | |
| label="Gate threshold (dBFS)", | |
| minimum=-70.0, maximum=-20.0, value=-45.0, step=1.0, | |
| ) | |
| if DEBUG_AVAILABLE and DEV_MODE: | |
| _dbg_components = _dbg.build_debug_ui(gr) | |
| debug_source = _dbg_components["debug_source"] | |
| f_smooth_kernel = _dbg_components["f_smooth_kernel"] | |
| f_smooth_mode = _dbg_components["f_smooth_mode"] | |
| smoother_mode = _dbg_components["smoother_mode"] | |
| smoother_attack_db = _dbg_components["smoother_attack_db"] | |
| smoother_release_db = _dbg_components["smoother_release_db"] | |
| smoother_ema_alpha = _dbg_components["smoother_ema_alpha"] | |
| smoother_floor_db = _dbg_components["smoother_floor_db"] | |
| smoother_median_k = _dbg_components["smoother_median_k"] | |
| else: | |
| # Production / no _debug.py — hidden gr.State holders carrying | |
| # neutral defaults, so `enhance()` keeps a stable input list. | |
| debug_source = gr.State("enhanced") | |
| f_smooth_kernel = gr.State(31) | |
| f_smooth_mode = gr.State("median") | |
| smoother_mode = gr.State("off") | |
| smoother_attack_db = gr.State(12.0) | |
| smoother_release_db = gr.State(1.0) | |
| smoother_ema_alpha = gr.State(0.7) | |
| smoother_floor_db = gr.State(20.0) | |
| smoother_median_k = gr.State(3) | |
| btn = gr.Button("Enhance", variant="primary") | |
| out = gr.Audio(label="Enhanced output", type="numpy") | |
| gr.Examples( | |
| examples=EXAMPLES, | |
| inputs=[mic_in, ref_in], | |
| label=( | |
| "Examples — top to bottom: near-end + heavy noise (5 dB SNR, " | |
| "pure NS), near-end + light noise (20 dB SNR, NS preserving " | |
| "clean speech), far-end single-talk (pure AEC), far-end with " | |
| "brief near-end overlap (AEC while preserving NE), and " | |
| "double-talk (AEC while near-end is also talking)." | |
| ), | |
| ) | |
| btn.click( | |
| enhance, | |
| inputs=[mic_in, ref_in, model_choice, | |
| gate_enabled, gate_threshold_db, | |
| smoother_mode, smoother_attack_db, smoother_release_db, | |
| smoother_ema_alpha, smoother_floor_db, smoother_median_k, | |
| debug_source, f_smooth_kernel, f_smooth_mode], | |
| outputs=out, | |
| ) | |
| _info_lines = [] | |
| for key in MODELS: | |
| i = INFOS[key] | |
| _info_lines.append( | |
| f"<b>{i['label']}</b> — <code>{i['source']}</code> · " | |
| f"sha256 <code>{i['sha256'][:16]}…</code> · " | |
| f"{i['n_params']:,} params" | |
| ) | |
| gr.Markdown("<sub>Loaded models:<br>" + "<br>".join(_info_lines) + "</sub>") | |
| if __name__ == "__main__": | |
| demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")) | |