"""Gradio demo for LocalVQE — real-time AEC + NS + dereverb.
Loads released model versions side-by-side and exposes a runtime
selector so you can A/B them on the same clip:
v1.2 — newest, default. 1.3 M params. SiLU activation + dmax 64
(1024 ms echo-search window) + wider clean-pool DNSMOS
filter + phone-bandwidth + codec round-trip aug. Adds
~+0.3 echo_mos / ~+1 dB ERLE on AEC blind FE-ST vs v1.1.
Path resolves from LOCALVQE_V12_CKPT, else HF.
v1.1 — previous release. 1.3 M params. ReLU6, pre-norm
CausalGroupNorm, STFT-256 codec. Fixes intermittent
crackling that v1 produced under heavy background noise.
Path resolves from LOCALVQE_V11_CKPT, else HF.
v1 — original release. Path resolves from LOCALVQE_V1_CKPT
(or LOCALVQE_LOCAL_CKPT for backward compat), else HF.
If a checkpoint isn't reachable that entry is hidden from the
selector. Each architecture lives in an independent Python
package so they can be loaded simultaneously without import
collisions:
v1 → space/localvqe_model/
v1.1 → space/localvqe_v11/
v1.2 → space/localvqe_v12/
"""
import hashlib
import os
from pathlib import Path
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from scipy.signal import resample_poly
# v1 (original release) — namespace 'localvqe_model'
from localvqe_model import (
Config as ConfigV1,
LocalVQE as LocalVQEv1,
apply_ckpt_model_config as apply_ckpt_v1,
load_checkpoint as load_ckpt_v1,
)
# v1.1 / v1.2 — bundled in this directory. Imported lazily so startup
# stays fast when those versions aren't configured.
def _import_v11():
    """Lazily import the v1.1 package.

    Returns the tuple (Config, LocalVQE, apply_ckpt_model_config,
    load_checkpoint) from the independent `localvqe_v11` namespace.
    """
    from localvqe_v11 import (
        Config,
        LocalVQE,
        apply_ckpt_model_config,
        load_checkpoint,
    )
    return Config, LocalVQE, apply_ckpt_model_config, load_checkpoint
def _import_v12():
    """Lazily import the v1.2 package.

    Returns the tuple (Config, LocalVQE, apply_ckpt_model_config,
    load_checkpoint) from the independent `localvqe_v12` namespace.
    """
    from localvqe_v12 import (
        Config,
        LocalVQE,
        apply_ckpt_model_config,
        load_checkpoint,
    )
    return Config, LocalVQE, apply_ckpt_model_config, load_checkpoint
SR = 16000  # model sample rate; all audio is resampled to 16 kHz mono
HF_REPO_ID = "LocalAI-io/LocalVQE"  # HF Hub repo holding the published weights
HF_V1_FILE = "localvqe-v1-1.3M.pt"
HF_V11_FILE = "localvqe-v1.1-1.3M.pt"
HF_V12_FILE = "localvqe-v1.2-1.3M.pt"
# Bundled demo clips live next to this file.
EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
def _sha256(path: str) -> str:
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1 << 20), b""):
h.update(chunk)
return h.hexdigest()
def _resolve_v1_ckpt() -> str | None:
# Backward-compat: LOCALVQE_LOCAL_CKPT used to be the way to override.
for env in ("LOCALVQE_V1_CKPT", "LOCALVQE_LOCAL_CKPT"):
v = os.environ.get(env)
if v:
return v
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V1_FILE)
except Exception as e:
print(f"v1 unavailable from HF ({e})")
return None
def _resolve_v11_ckpt() -> str | None:
v = os.environ.get("LOCALVQE_V11_CKPT")
if v:
return v
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V11_FILE)
except Exception:
return None
def _resolve_v12_ckpt() -> str | None:
v = os.environ.get("LOCALVQE_V12_CKPT")
if v:
return v
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V12_FILE)
except Exception:
return None
def _resolve_v121_ckpt() -> str | None:
# No HF fallback yet — v1.2.1 isn't published. Set LOCALVQE_V121_CKPT
# in docker-compose.yml (defaults to checkpoints/release/...) to load
# the local finetuned copy.
return os.environ.get("LOCALVQE_V121_CKPT") or None
def _resolve_v12a_ckpt() -> str | None:
# v1.2a — v9 (widened DRR + longer RIRs + global gain) from-scratch
# epoch 14. Architecture identical to v1.2/v1.2.1 (uses localvqe_v12
# package). No HF publish yet.
return os.environ.get("LOCALVQE_V12A_CKPT") or None
def _resolve_v12b_ckpt() -> str | None:
# v1.2b — v10 (v1.2 + audible reverb + 80/20 conference mix +
# pipeline pop fixes, no experimental augs) from-scratch e19.
# Architecture identical to v1.2 (uses localvqe_v12 package).
return os.environ.get("LOCALVQE_V12B_CKPT") or None
def _resolve_v12c_ckpt() -> str | None:
# v1.2c — v11 (v10 + level-invariance mic-gain aug,
# clean_attenuation_factor=1.0) from-scratch e17. Addresses
# low-SNR wobble near noise floor. Architecture identical to
# v1.2 (uses localvqe_v12 package).
return os.environ.get("LOCALVQE_V12C_CKPT") or None
def _resolve_v12d_ckpt() -> str | None:
# v1.2d — v11_refine e22 (10-epoch low-LR cosine continuation
# of v1.2c from v11 e20, peak LR 1e-4). Blind eval beats
# v1.2c on FE-ST echo_mos (+0.31) and NE-ST deg_mos (+0.04)
# while recovering 2.4 dB of FE-ST ERLE. Architecture
# identical to v1.2 (uses localvqe_v12 package).
return os.environ.get("LOCALVQE_V12D_CKPT") or None
def _build_v1():
    """Build the v1 model from its resolved checkpoint.

    Returns (model, info) on success, or (None, None) when no checkpoint
    is reachable — the version is then hidden from the UI selector.
    """
    ckpt_path = _resolve_v1_ckpt()
    if ckpt_path is None:
        return None, None
    cfg = ConfigV1()
    # Peek at the checkpoint once to copy its stored model config into
    # cfg before constructing the network. NOTE(review): weights_only=
    # False unpickles arbitrary objects — only load trusted checkpoints.
    peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_ckpt_v1(peek, cfg)
    del peek
    model = LocalVQEv1.from_config(cfg).to("cpu")
    load_ckpt_v1(ckpt_path, model)
    # Fold the trained AlignBlock softmax temperature (a buffer in the
    # checkpoint) into the smoothing conv — without this, eval runs at
    # the default 1.0 instead of the trained value, losing ~5 dB ERLE.
    model.align.fold_temperature()
    model.eval()
    info = {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in model.parameters()),
        # FIX: was "v1 (previous release)", contradicting the module
        # docstring — v1 is the original release; v1.1 is "previous".
        "label": "v1 (original release)",
    }
    print(f"v1 loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v11():
    """Build the v1.1 model, or return (None, None) if no checkpoint resolves."""
    path = _resolve_v11_ckpt()
    if path is None:
        return None, None
    Config, LocalVQE, apply_cfg, load_ckpt = _import_v11()
    cfg = Config()
    # Peek at the checkpoint only to copy its stored model config.
    state = torch.load(path, map_location="cpu", weights_only=False)
    apply_cfg(state, cfg)
    del state
    net = LocalVQE.from_config(cfg).to("cpu")
    load_ckpt(path, net)
    # Fold the trained AlignBlock softmax temperature into the smoothing
    # conv so eval runs at the trained value, not the default.
    net.align.fold_temperature()
    net.eval()
    info = {
        "source": path,
        "sha256": _sha256(path),
        "n_params": sum(p.numel() for p in net.parameters()),
        "label": "v1.1 (previous release)",
    }
    print(f"v1.1 loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return net, info
def _build_v12_like(ckpt_path, label):
    """Build any v1.2-family model (v1.2, v1.2.1, v1.2a–d).

    All of them share one architecture and the localvqe_v12 package;
    only the checkpoint and the display `label` differ.
    Returns (model, info).
    """
    Config, LocalVQE, apply_cfg, load_ckpt = _import_v12()
    cfg = Config()
    # Peek at the checkpoint only to copy its stored model config.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_cfg(state, cfg)
    del state
    net = LocalVQE.from_config(cfg).to("cpu")
    load_ckpt(ckpt_path, net)
    # Fold the trained AlignBlock softmax temperature into the smoothing
    # conv so eval runs at the trained value, not the default.
    net.align.fold_temperature()
    net.eval()
    return net, {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in net.parameters()),
        "label": label,
    }
def _build_v12():
    """Build the v1.2 model, or return (None, None) if no checkpoint resolves."""
    path = _resolve_v12_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(path, "v1.2 (current release)")
    print(f"v1.2 loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v121():
    """Build the v1.2.1 finetune, or return (None, None) if not configured."""
    path = _resolve_v121_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(path, "v1.2.1 (movement-aug finetune)")
    print(f"v1.2.1 loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12a():
    """Build the v1.2a variant, or return (None, None) if not configured."""
    path = _resolve_v12a_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2a (widened DRR + longer RIRs, from-scratch)")
    print(f"v1.2a loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12b():
    """Build the v1.2b variant, or return (None, None) if not configured."""
    path = _resolve_v12b_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2b (v10: audible reverb + conference mix + pop fixes)")
    print(f"v1.2b loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12c():
    """Build the v1.2c variant, or return (None, None) if not configured."""
    path = _resolve_v12c_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2c (v11: level-invariance mic-gain on v1.2b base)")
    print(f"v1.2c loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
def _build_v12d():
    """Build the v1.2d variant, or return (None, None) if not configured."""
    path = _resolve_v12d_ckpt()
    if path is None:
        return None, None
    model, info = _build_v12_like(
        path, "v1.2d (v11_refine e22: low-LR cosine polish of v1.2c)")
    print(f"v1.2d loaded: {info['n_params']:,} params "
          f"sha={info['sha256'][:16]}… src={path}")
    return model, info
# Eagerly build every version at import time; builders return
# (None, None) for anything not configured/reachable.
MODEL_V1, INFO_V1 = _build_v1()
MODEL_V11, INFO_V11 = _build_v11()
MODEL_V12, INFO_V12 = _build_v12()
MODEL_V121, INFO_V121 = _build_v121()
MODEL_V12A, INFO_V12A = _build_v12a()
MODEL_V12B, INFO_V12B = _build_v12b()
MODEL_V12C, INFO_V12C = _build_v12c()
MODEL_V12D, INFO_V12D = _build_v12d()

# Registry of the versions that actually loaded, keyed by UI label.
MODELS: dict[str, object] = {}
INFOS: dict[str, dict] = {}
for _key, _model, _info in (
    ("v1", MODEL_V1, INFO_V1),
    ("v1.1", MODEL_V11, INFO_V11),
    ("v1.2", MODEL_V12, INFO_V12),
    ("v1.2.1", MODEL_V121, INFO_V121),
    ("v1.2a", MODEL_V12A, INFO_V12A),
    ("v1.2b", MODEL_V12B, INFO_V12B),
    ("v1.2c", MODEL_V12C, INFO_V12C),
    ("v1.2d", MODEL_V12D, INFO_V12D),
):
    if _model is not None:
        MODELS[_key] = _model
        INFOS[_key] = _info

if not MODELS:
    raise RuntimeError(
        "No model could be loaded. Set LOCALVQE_V1_CKPT, "
        "LOCALVQE_V11_CKPT, LOCALVQE_V12_CKPT, LOCALVQE_V121_CKPT, "
        "LOCALVQE_V12A_CKPT, LOCALVQE_V12B_CKPT, LOCALVQE_V12C_CKPT, "
        "or LOCALVQE_V12D_CKPT, or ensure HF access for the "
        "published files."
    )

# Default to the newest loaded version; fall back toward v1.
DEFAULT_MODEL_KEY = next(
    (k for k in ("v1.2d", "v1.2c", "v1.2b", "v1.2a", "v1.2.1", "v1.2", "v1.1")
     if k in MODELS),
    "v1",
)
# Dev mode: shows the diagnostic-source dropdown and mask-smoother
# accordion in the UI. Auto-on locally, auto-off on HF Spaces (which
# always sets `SPACE_ID`). Either can be overridden by setting
# LOCALVQE_DEV_MODE=1 (force on) or =0 (force off).
def _dev_mode() -> bool:
explicit = os.environ.get("LOCALVQE_DEV_MODE")
if explicit in ("0", "1"):
return explicit == "1"
return "SPACE_ID" not in os.environ
DEV_MODE = _dev_mode()
if DEV_MODE:
print("DEV_MODE=on (debug accordions visible). Set LOCALVQE_DEV_MODE=0 to hide.")
def _load_mono_16k(path: str) -> np.ndarray:
    """Read an audio file, downmix to mono, and resample to SR (16 kHz).

    Returns a 1-D float32 array. Stereo files are averaged across
    channels; polyphase resampling is used for rate conversion.
    """
    from math import gcd

    audio, in_sr = sf.read(path, dtype="float32", always_2d=False)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    if in_sr != SR:
        g = gcd(in_sr, SR)
        audio = resample_poly(audio, SR // g, in_sr // g).astype(np.float32)
    return audio
# Debug / diagnostic helpers live in `_debug.py`, which is excluded
# from the HuggingFace Spaces deploy. When this file is missing the
# app silently degrades: no debug accordions, no diagnostic-source
# branches, just the standard model forward.
try:
    import _debug as _dbg
    DEBUG_AVAILABLE = True
except ImportError:
    # Production deploy: `_dbg` stays None and every call site gates on
    # DEBUG_AVAILABLE before touching it.
    _dbg = None
    DEBUG_AVAILABLE = False
def _noise_gate(x: np.ndarray, threshold_dbfs: float) -> np.ndarray:
"""Hard-gate frames whose RMS is below `threshold_dbfs` to zero.
Operates on 10 ms frames (160 samples at 16 kHz) — short enough
that speech bursts aren't truncated, long enough that a single
out-of-band sample inside an active region doesn't get muted.
The ungated tail (samples that don't fill a full final frame) is
passed through unchanged.
"""
frame = 160
n = len(x) // frame
if n == 0:
return x
f = x[: n * frame].reshape(n, frame).astype(np.float32)
rms = np.sqrt((f * f).mean(axis=-1) + 1e-12)
rms_db = 20.0 * np.log10(rms + 1e-12)
keep = (rms_db > threshold_dbfs).astype(np.float32)
gated = (f * keep[:, None]).reshape(-1)
return np.concatenate([gated, x[n * frame:]]).astype(x.dtype)
def enhance(mic_path: str, ref_path: str,
            model_choice: str = DEFAULT_MODEL_KEY,
            gate_enabled: bool = False,
            gate_threshold_db: float = -45.0,
            smoother_mode: str = "off",
            smoother_attack_db: float = 12.0,
            smoother_release_db: float = 1.0,
            smoother_ema_alpha: float = 0.7,
            smoother_floor_db: float = 20.0,
            smoother_median_k: int = 3,
            debug_source: str = "enhanced",
            f_smooth_kernel: int = 31,
            f_smooth_mode: str = "median") -> tuple[int, np.ndarray]:
    """Run the selected model on a mic/reference clip pair.

    Loads both files as 16 kHz mono, zero-pads the shorter one so they
    match, runs the model's forward pass (optionally routed through the
    `_debug` diagnostic-source / mask-smoother paths when available),
    decodes to a waveform, peak-limits, applies the optional RMS gate,
    and returns (sample_rate, int16 waveform) for gr.Audio.

    The smoother_* / debug_* / f_smooth_* parameters are dev-mode knobs
    forwarded to `_debug`; with no `_debug.py` present they are inert.
    Raises gr.Error on a missing mic clip or an unknown model key.
    """
    if mic_path is None:
        raise gr.Error("Upload or pick a mic recording first.")
    if model_choice not in MODELS:
        raise gr.Error(f"Model {model_choice!r} not loaded. Available: {list(MODELS)}")
    model = MODELS[model_choice]
    mic = _load_mono_16k(mic_path)
    if ref_path is None:
        # No far-end reference uploaded → pure noise-suppression mode:
        # feed silence as the reference.
        ref = np.zeros_like(mic)
    else:
        ref = _load_mono_16k(ref_path)
    # Zero-pad both signals to a common length n.
    n = max(len(mic), len(ref))
    if len(mic) < n:
        mic = np.pad(mic, (0, n - len(mic)))
    if len(ref) < n:
        ref = np.pad(ref, (0, n - len(ref)))
    mic_t = torch.from_numpy(mic).unsqueeze(0)  # add batch dim
    ref_t = torch.from_numpy(ref).unsqueeze(0)
    with torch.no_grad():
        if DEBUG_AVAILABLE and debug_source != "enhanced":
            # Dev-only: tap an alternative signal from inside the model.
            enc = _dbg.apply_debug_source(
                model, mic_t, ref_t, debug_source,
                smoother_ema_alpha=smoother_ema_alpha,
                f_smooth_kernel=f_smooth_kernel,
                f_smooth_mode=f_smooth_mode,
            )
        else:
            # NOTE(review): the forward pass appears to return the
            # encoder-domain representation (decoded below) — confirm
            # against the model package.
            enc = model(mic_t, ref_t)
        if (DEBUG_AVAILABLE and smoother_mode != "off"
                and debug_source not in ("passthrough", "bypass_ccm")):
            # Dev-only mask smoothing, skipped for sources that bypass
            # the mask path entirely.
            enc = _dbg.apply_smoother(
                enc, model.encoder(mic_t), smoother_mode,
                attack_db=smoother_attack_db,
                release_db=smoother_release_db,
                ema_alpha=smoother_ema_alpha,
                floor_db=smoother_floor_db,
                median_k=smoother_median_k,
            )
        enh = model.decoder(enc.float(), length=n)
    out = enh[0].cpu().numpy()
    # Peak-limit to 0.95 to leave headroom before int16 conversion.
    peak = float(np.abs(out).max())
    if peak > 0.95:
        out = out / peak * 0.95
    # Optional residual-echo gate: silence frames whose RMS sits below
    # `gate_threshold_db` dBFS. Off by default so listeners can A/B
    # against the raw model output via the slider.
    if gate_enabled:
        out = _noise_gate(out, gate_threshold_db)
    # Convert to int16 ourselves: Gradio's gr.Audio output otherwise
    # peak-normalises float arrays via convert_to_16_bit_wav (data /=
    # np.abs(data).max(); * 32767), which amplifies the cancelled-echo
    # residual on AEC-heavy clips by 1000×+ and makes it sound like
    # the model isn't suppressing anything. Returning int16 preserves
    # the true (quiet) loudness so listeners hear the actual output.
    out_i16 = np.clip(out * 32767, -32768, 32767).astype(np.int16)
    return SR, out_i16
# [mic, ref] clip pairs for gr.Examples. Order matches the UI label:
# heavy-noise NS, light-noise NS, far-end single-talk, far-end with
# near-end overlap, double-talk (ICASSP 2022 AEC Challenge blind set).
EXAMPLES = [
    [  # near-end single-talk + heavy noise (5 dB SNR) — pure NS
        str(EXAMPLES_DIR / "ne_st_noisy_mic.wav"),
        str(EXAMPLES_DIR / "ne_st_noisy_ref.wav"),
    ],
    [  # near-end single-talk + light noise (20 dB SNR)
        str(EXAMPLES_DIR / "ne_st_clean_mic.wav"),
        str(EXAMPLES_DIR / "ne_st_clean_ref.wav"),
    ],
    [  # far-end single-talk — pure AEC
        str(EXAMPLES_DIR / "fe_st_mic.wav"),
        str(EXAMPLES_DIR / "fe_st_ref.wav"),
    ],
    [  # far-end with brief near-end overlap — AEC + NE preservation
        str(EXAMPLES_DIR / "fe_st2_mic.wav"),
        str(EXAMPLES_DIR / "fe_st2_ref.wav"),
    ],
    [  # double-talk — AEC while near-end is also talking
        str(EXAMPLES_DIR / "dt_mic.wav"),
        str(EXAMPLES_DIR / "dt_ref.wav"),
    ],
]
# User-facing markdown rendered at the top of the demo page. This is a
# runtime string, not a docstring — keep its text in sync with the
# bundled examples and the model cards.
DESCRIPTION = """
**LocalVQE** is a ~1 M-parameter open-source model that cleans up a
microphone signal on a voice call: it cancels the remote participant's
voice being picked up again (echo), suppresses background noise, and
removes reverberation — all in a single causal pass on CPU.
Provide two inputs:
- **Mic**: the raw microphone recording (what the far end would hear
without any processing).
- **Far-end reference**: the audio being played out of your speakers.
For a pure noise-suppression test (no speaker playback), upload
silence or leave empty.
Try the bundled examples first — they cover heavy and light
near-end noise (NE-ST mixed with DNS5 background at 5 dB and 20 dB
SNR), a clean far-end single-talk clip, a far-end clip with some
near-end overlap (mislabelled in the source corpus, but a useful
test of AEC + near-end preservation together), and a double-talk
clip — all from the ICASSP 2022 AEC Challenge blind set.
Weights: [LocalAI-io/LocalVQE](https://huggingface.co/LocalAI-io/LocalVQE) ·
Code: [github.com/localai-org/LocalVQE](https://github.com/localai-org/LocalVQE)
"""
with gr.Blocks(title="LocalVQE Demo") as demo:
    gr.Markdown("# LocalVQE: real-time AEC + noise suppression + dereverb")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        mic_in = gr.Audio(label="Mic (microphone recording)", type="filepath")
        ref_in = gr.Audio(label="Far-end reference (speaker playback)", type="filepath")
    # With a single loaded model there is nothing to select — a hidden
    # gr.State stands in so enhance() keeps a stable input list.
    model_choice = gr.Radio(
        choices=list(MODELS.keys()),
        value=DEFAULT_MODEL_KEY,
        label="Model",
        info=(
            "v1.2 is the current release. SiLU activation + 1024 ms "
            "echo-search window + wider clean-pool DNSMOS filter + "
            "phone-bandwidth + codec round-trip aug. Adds ~+0.3 "
            "echo_mos and ~+1 dB ERLE on the AEC blind set vs v1.1. "
            "v1.1 / v1 are kept for A/B. Same param count (1.3 M). "
            "Switch and re-run on the same clip to compare."
        ),
    ) if len(MODELS) > 1 else gr.State(DEFAULT_MODEL_KEY)
    with gr.Row():
        gate_enabled = gr.Checkbox(
            label="Residual-echo gate",
            value=False,
            info=(
                "Post-process the enhanced output: silence any 10 ms frame "
                "whose RMS falls below the threshold. Cleans up the quiet "
                "residual you'd hear during far-end-only stretches; will "
                "also mute genuinely quiet speech below the threshold."
            ),
        )
        gate_threshold_db = gr.Slider(
            label="Gate threshold (dBFS)",
            minimum=-70.0, maximum=-20.0, value=-45.0, step=1.0,
        )
    if DEBUG_AVAILABLE and DEV_MODE:
        # Dev mode with _debug.py present: real debug components.
        _dbg_components = _dbg.build_debug_ui(gr)
        debug_source = _dbg_components["debug_source"]
        f_smooth_kernel = _dbg_components["f_smooth_kernel"]
        f_smooth_mode = _dbg_components["f_smooth_mode"]
        smoother_mode = _dbg_components["smoother_mode"]
        smoother_attack_db = _dbg_components["smoother_attack_db"]
        smoother_release_db = _dbg_components["smoother_release_db"]
        smoother_ema_alpha = _dbg_components["smoother_ema_alpha"]
        smoother_floor_db = _dbg_components["smoother_floor_db"]
        smoother_median_k = _dbg_components["smoother_median_k"]
    else:
        # Production / no _debug.py — hidden gr.State holders carrying
        # neutral defaults, so `enhance()` keeps a stable input list.
        debug_source = gr.State("enhanced")
        f_smooth_kernel = gr.State(31)
        f_smooth_mode = gr.State("median")
        smoother_mode = gr.State("off")
        smoother_attack_db = gr.State(12.0)
        smoother_release_db = gr.State(1.0)
        smoother_ema_alpha = gr.State(0.7)
        smoother_floor_db = gr.State(20.0)
        smoother_median_k = gr.State(3)
    btn = gr.Button("Enhance", variant="primary")
    out = gr.Audio(label="Enhanced output", type="numpy")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[mic_in, ref_in],
        label=(
            "Examples — top to bottom: near-end + heavy noise (5 dB SNR, "
            "pure NS), near-end + light noise (20 dB SNR, NS preserving "
            "clean speech), far-end single-talk (pure AEC), far-end with "
            "brief near-end overlap (AEC while preserving NE), and "
            "double-talk (AEC while near-end is also talking)."
        ),
    )
    btn.click(
        enhance,
        inputs=[mic_in, ref_in, model_choice,
                gate_enabled, gate_threshold_db,
                smoother_mode, smoother_attack_db, smoother_release_db,
                smoother_ema_alpha, smoother_floor_db, smoother_median_k,
                debug_source, f_smooth_kernel, f_smooth_mode],
        outputs=out,
    )
    # Footer: provenance (label, source path, hash prefix, param count)
    # for each loaded checkpoint.
    _info_lines = []
    for key in MODELS:
        i = INFOS[key]
        _info_lines.append(
            f"{i['label']} — {i['source']} · "
            f"sha256 {i['sha256'][:16]}… · "
            f"{i['n_params']:,} params"
        )
    # FIX: this call was a syntax error in the source — the string
    # literal was broken across raw newlines (its "\n\n" escapes had
    # been turned into real line breaks). Reassembled as the intended
    # paragraph-joined markdown list.
    gr.Markdown("Loaded models:\n\n" + "\n\n".join(_info_lines))
if __name__ == "__main__":
    # Bind to localhost by default; set GRADIO_SERVER_NAME (e.g.
    # 0.0.0.0 inside a container) to expose the server externally.
    demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1"))