Spaces:
Running
Running
File size: 22,408 Bytes
4313d1d 37969f2 69f75a7 37969f2 69f75a7 37969f2 4313d1d eee8304 f5e08b6 4313d1d 69f75a7 37969f2 69f75a7 4313d1d 37969f2 4313d1d 69f75a7 37969f2 4313d1d eee8304 69f75a7 f5e08b6 69f75a7 37969f2 69f75a7 4313d1d 69f75a7 4313d1d 69f75a7 6e0a6e4 4313d1d 69f75a7 37969f2 69f75a7 eee8304 4313d1d 37969f2 69f75a7 37969f2 69f75a7 37969f2 69f75a7 37969f2 69f75a7 37969f2 69f75a7 4313d1d 69f75a7 9264232 69f75a7 9264232 69f75a7 4313d1d 69f75a7 4313d1d 69f75a7 4313d1d 50954ed 4313d1d 9264232 50954ed 4313d1d eee8304 4313d1d b180d02 eee8304 b180d02 4313d1d eee8304 4313d1d b180d02 eee8304 b180d02 4313d1d eee8304 4313d1d b7e597c 4313d1d b180d02 eee8304 4313d1d 61c68a1 4313d1d 69f75a7 37969f2 69f75a7 9264232 69f75a7 4313d1d b180d02 eee8304 dac40f1 4313d1d 9264232 69f75a7 9264232 4313d1d 69f75a7 eee8304 4313d1d 9264232 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 | """Gradio demo for LocalVQE β real-time AEC + NS + dereverb.
Loads released model versions side-by-side and exposes a runtime
selector so you can A/B them on the same clip:
v1.2 — newest, default. 1.3 M params. SiLU activation + dmax 64
(1024 ms echo-search window) + wider clean-pool DNSMOS
filter + phone-bandwidth + codec round-trip aug. Adds
~+0.3 echo_mos / ~+1 dB ERLE on AEC blind FE-ST vs v1.1.
Path resolves from LOCALVQE_V12_CKPT, else HF.
v1.1 — previous release. 1.3 M params. ReLU6, pre-norm
CausalGroupNorm, STFT-256 codec. Fixes intermittent
crackling that v1 produced under heavy background noise.
Path resolves from LOCALVQE_V11_CKPT, else HF.
v1 — original release. Path resolves from LOCALVQE_V1_CKPT
(or LOCALVQE_LOCAL_CKPT for backward compat), else HF.
If a checkpoint isn't reachable that entry is hidden from the
selector. Each architecture lives in an independent Python
package so they can be loaded simultaneously without import
collisions:
v1 → space/localvqe_model/
v1.1 → space/localvqe_v11/
v1.2 → space/localvqe_v12/
"""
import hashlib
import os
from pathlib import Path
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from scipy.signal import resample_poly
# v1 (original release) — namespace 'localvqe_model'
from localvqe_model import (
Config as ConfigV1,
LocalVQE as LocalVQEv1,
apply_ckpt_model_config as apply_ckpt_v1,
load_checkpoint as load_ckpt_v1,
)
# v1.1 / v1.2 — bundled in this directory. Imported on demand to keep
# startup time low when those versions aren't configured.
def _import_v11():
    """Import the v1.1 package on first use and return its entry points.

    Returns:
        The tuple ``(Config, LocalVQE, apply_ckpt_model_config,
        load_checkpoint)`` taken from the ``localvqe_v11`` package.
    """
    from localvqe_v11 import (
        Config,
        LocalVQE,
        apply_ckpt_model_config,
        load_checkpoint,
    )

    return Config, LocalVQE, apply_ckpt_model_config, load_checkpoint
def _import_v12():
    """Import the v1.2 package on first use and return its entry points.

    Returns:
        The tuple ``(Config, LocalVQE, apply_ckpt_model_config,
        load_checkpoint)`` taken from the ``localvqe_v12`` package.
    """
    from localvqe_v12 import (
        Config,
        LocalVQE,
        apply_ckpt_model_config,
        load_checkpoint,
    )

    return Config, LocalVQE, apply_ckpt_model_config, load_checkpoint
# Sample rate (Hz) every input clip is resampled to before inference.
SR = 16000
# Hugging Face repository the published checkpoints are downloaded from
# when no local override is configured.
HF_REPO_ID = "LocalAI-io/LocalVQE"
# Checkpoint file names inside HF_REPO_ID, one per published version.
HF_V1_FILE = "localvqe-v1-1.3M.pt"
HF_V11_FILE = "localvqe-v1.1-1.3M.pt"
HF_V12_FILE = "localvqe-v1.2-1.3M.pt"
# Directory holding the bundled example clips, located next to this file.
EXAMPLES_DIR = Path(__file__).resolve().parent / "examples"
def _sha256(path: str) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is consumed in 1 MiB pieces so a large checkpoint never has
    to be held in memory all at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while piece := stream.read(1 << 20):
            digest.update(piece)
    return digest.hexdigest()
def _resolve_v1_ckpt() -> str | None:
# Backward-compat: LOCALVQE_LOCAL_CKPT used to be the way to override.
for env in ("LOCALVQE_V1_CKPT", "LOCALVQE_LOCAL_CKPT"):
v = os.environ.get(env)
if v:
return v
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V1_FILE)
except Exception as e:
print(f"v1 unavailable from HF ({e})")
return None
def _resolve_v11_ckpt() -> str | None:
v = os.environ.get("LOCALVQE_V11_CKPT")
if v:
return v
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V11_FILE)
except Exception:
return None
def _resolve_v12_ckpt() -> str | None:
v = os.environ.get("LOCALVQE_V12_CKPT")
if v:
return v
try:
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V12_FILE)
except Exception:
return None
def _resolve_v121_ckpt() -> str | None:
# No HF fallback yet β v1.2.1 isn't published. Set LOCALVQE_V121_CKPT
# in docker-compose.yml (defaults to checkpoints/release/...) to load
# the local finetuned copy.
return os.environ.get("LOCALVQE_V121_CKPT") or None
def _resolve_v12a_ckpt() -> str | None:
# v1.2a β v9 (widened DRR + longer RIRs + global gain) from-scratch
# epoch 14. Architecture identical to v1.2/v1.2.1 (uses localvqe_v12
# package). No HF publish yet.
return os.environ.get("LOCALVQE_V12A_CKPT") or None
def _resolve_v12b_ckpt() -> str | None:
# v1.2b β v10 (v1.2 + audible reverb + 80/20 conference mix +
# pipeline pop fixes, no experimental augs) from-scratch e19.
# Architecture identical to v1.2 (uses localvqe_v12 package).
return os.environ.get("LOCALVQE_V12B_CKPT") or None
def _resolve_v12c_ckpt() -> str | None:
# v1.2c β v11 (v10 + level-invariance mic-gain aug,
# clean_attenuation_factor=1.0) from-scratch e17. Addresses
# low-SNR wobble near noise floor. Architecture identical to
# v1.2 (uses localvqe_v12 package).
return os.environ.get("LOCALVQE_V12C_CKPT") or None
def _resolve_v12d_ckpt() -> str | None:
# v1.2d β v11_refine e22 (10-epoch low-LR cosine continuation
# of v1.2c from v11 e20, peak LR 1e-4). Blind eval beats
# v1.2c on FE-ST echo_mos (+0.31) and NE-ST deg_mos (+0.04)
# while recovering 2.4 dB of FE-ST ERLE. Architecture
# identical to v1.2 (uses localvqe_v12 package).
return os.environ.get("LOCALVQE_V12D_CKPT") or None
def _build_v1():
    """Load the v1 model on CPU and describe the checkpoint it came from.

    Returns:
        ``(model, info)``, or ``(None, None)`` when no v1 checkpoint
        could be resolved. ``info`` holds the checkpoint path
        (``source``), its ``sha256``, the parameter count (``n_params``)
        and a display ``label``.
    """
    ckpt_path = _resolve_v1_ckpt()
    if ckpt_path is None:
        return None, None
    cfg = ConfigV1()
    # NOTE(security): weights_only=False lets torch unpickle arbitrary
    # objects, so only point this at checkpoints you trust.
    peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Presumably copies the model hyper-parameters stored in the
    # checkpoint onto cfg before the network is constructed — verify
    # against localvqe_model.
    apply_ckpt_v1(peek, cfg)
    del peek
    model = LocalVQEv1.from_config(cfg).to("cpu")
    load_ckpt_v1(ckpt_path, model)
    # Fold the trained AlignBlock softmax temperature (a buffer in the
    # checkpoint) into the smoothing conv — without this, eval runs at
    # the default 1.0 instead of the trained value, losing ~5 dB ERLE.
    model.align.fold_temperature()
    model.eval()
    info = {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in model.parameters()),
        # v1.1 already carries the "(previous release)" label; v1 is the
        # original release, as the module docstring states.
        "label": "v1 (original release)",
    }
    print(f"v1 loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v11():
    """Load the v1.1 model on CPU and describe its checkpoint.

    The ``localvqe_v11`` package is imported only once a checkpoint has
    been resolved.

    Returns:
        ``(model, info)``, or ``(None, None)`` when no v1.1 checkpoint
        could be resolved. ``info`` holds the checkpoint path
        (``source``), its ``sha256``, the parameter count (``n_params``)
        and a display ``label``.
    """
    ckpt_path = _resolve_v11_ckpt()
    if ckpt_path is None:
        return None, None
    ConfigV11, LocalVQEv11, apply_ckpt_v11, load_ckpt_v11 = _import_v11()
    cfg = ConfigV11()
    # NOTE(security): weights_only=False lets torch unpickle arbitrary
    # objects, so only point this at checkpoints you trust.
    peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_ckpt_v11(peek, cfg)
    del peek
    model = LocalVQEv11.from_config(cfg).to("cpu")
    load_ckpt_v11(ckpt_path, model)
    # Fold the trained alignment temperature into the model before
    # switching to eval mode.
    model.align.fold_temperature()
    model.eval()
    info = {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in model.parameters()),
        "label": "v1.1 (previous release)",
    }
    print(f"v1.1 loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v12_like(ckpt_path, label):
    """Build any checkpoint that uses the v1.2 architecture and package.

    Args:
        ckpt_path: Path of the checkpoint file to load.
        label: Human-readable name stored in the returned info dict.

    Returns:
        ``(model, info)`` where ``info`` holds ``source``, ``sha256``,
        ``n_params`` and ``label``.
    """
    config_cls, model_cls, apply_ckpt, load_ckpt = _import_v12()
    cfg = config_cls()
    # First pass over the checkpoint only feeds apply_ckpt; the object
    # is dropped again before the network is instantiated.
    # NOTE(review): weights_only=False unpickles arbitrary objects, so
    # the checkpoint must come from a trusted source.
    raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_ckpt(raw, cfg)
    del raw
    net = model_cls.from_config(cfg).to("cpu")
    load_ckpt(ckpt_path, net)
    # Fold the trained alignment temperature into the model before
    # switching to eval mode.
    net.align.fold_temperature()
    net.eval()
    digest = _sha256(ckpt_path)
    param_count = sum(p.numel() for p in net.parameters())
    details = {
        "source": ckpt_path,
        "sha256": digest,
        "n_params": param_count,
        "label": label,
    }
    return net, details
def _build_v12():
    """Load the v1.2 model, or return ``(None, None)`` if unresolved.

    Returns:
        ``(model, info)`` as produced by ``_build_v12_like``.
    """
    ckpt_path = _resolve_v12_ckpt()
    if ckpt_path is None:
        return None, None
    model, info = _build_v12_like(ckpt_path, "v1.2 (current release)")
    print(f"v1.2 loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v121():
    """Load the v1.2.1 finetune, or return ``(None, None)`` if unset.

    Returns:
        ``(model, info)`` as produced by ``_build_v12_like``.
    """
    ckpt_path = _resolve_v121_ckpt()
    if ckpt_path is None:
        return None, None
    model, info = _build_v12_like(ckpt_path, "v1.2.1 (movement-aug finetune)")
    print(f"v1.2.1 loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v12a():
    """Load the v1.2a variant, or return ``(None, None)`` if unset.

    Returns:
        ``(model, info)`` as produced by ``_build_v12_like``.
    """
    ckpt_path = _resolve_v12a_ckpt()
    if ckpt_path is None:
        return None, None
    model, info = _build_v12_like(
        ckpt_path, "v1.2a (widened DRR + longer RIRs, from-scratch)")
    print(f"v1.2a loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v12b():
    """Load the v1.2b variant, or return ``(None, None)`` if unset.

    Returns:
        ``(model, info)`` as produced by ``_build_v12_like``.
    """
    ckpt_path = _resolve_v12b_ckpt()
    if ckpt_path is None:
        return None, None
    model, info = _build_v12_like(
        ckpt_path, "v1.2b (v10: audible reverb + conference mix + pop fixes)")
    print(f"v1.2b loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v12c():
    """Load the v1.2c variant, or return ``(None, None)`` if unset.

    Returns:
        ``(model, info)`` as produced by ``_build_v12_like``.
    """
    ckpt_path = _resolve_v12c_ckpt()
    if ckpt_path is None:
        return None, None
    model, info = _build_v12_like(
        ckpt_path, "v1.2c (v11: level-invariance mic-gain on v1.2b base)")
    print(f"v1.2c loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
def _build_v12d():
    """Load the v1.2d variant, or return ``(None, None)`` if unset.

    Returns:
        ``(model, info)`` as produced by ``_build_v12_like``.
    """
    ckpt_path = _resolve_v12d_ckpt()
    if ckpt_path is None:
        return None, None
    model, info = _build_v12_like(
        ckpt_path, "v1.2d (v11_refine e22: low-LR cosine polish of v1.2c)")
    print(f"v1.2d loaded: {info['n_params']:,} params sha={info['sha256'][:16]}… "
          f"src={ckpt_path}")
    return model, info
# Every known variant is built eagerly at import time; a builder yields
# (None, None) when its checkpoint cannot be resolved.
MODEL_V1, INFO_V1 = _build_v1()
MODEL_V11, INFO_V11 = _build_v11()
MODEL_V12, INFO_V12 = _build_v12()
MODEL_V121, INFO_V121 = _build_v121()
MODEL_V12A, INFO_V12A = _build_v12a()
MODEL_V12B, INFO_V12B = _build_v12b()
MODEL_V12C, INFO_V12C = _build_v12c()
MODEL_V12D, INFO_V12D = _build_v12d()

# Registries keyed by the selector name. Insertion order (oldest first)
# is the order the choices appear in the UI.
MODELS: dict[str, object] = {}
INFOS: dict[str, dict] = {}
for _key, _model, _info in (
    ("v1", MODEL_V1, INFO_V1),
    ("v1.1", MODEL_V11, INFO_V11),
    ("v1.2", MODEL_V12, INFO_V12),
    ("v1.2.1", MODEL_V121, INFO_V121),
    ("v1.2a", MODEL_V12A, INFO_V12A),
    ("v1.2b", MODEL_V12B, INFO_V12B),
    ("v1.2c", MODEL_V12C, INFO_V12C),
    ("v1.2d", MODEL_V12D, INFO_V12D),
):
    if _model is None:
        continue
    MODELS[_key] = _model
    INFOS[_key] = _info
del _key, _model, _info

if not MODELS:
    raise RuntimeError(
        "No model could be loaded. Set LOCALVQE_V1_CKPT, "
        "LOCALVQE_V11_CKPT, LOCALVQE_V12_CKPT, LOCALVQE_V121_CKPT, "
        "LOCALVQE_V12A_CKPT, LOCALVQE_V12B_CKPT, LOCALVQE_V12C_CKPT, "
        "or LOCALVQE_V12D_CKPT, or ensure HF access for the "
        "published files."
    )

# Preselect the newest variant that actually loaded; "v1" is the last
# resort.
DEFAULT_MODEL_KEY = next(
    (
        candidate
        for candidate in (
            "v1.2d", "v1.2c", "v1.2b", "v1.2a", "v1.2.1", "v1.2", "v1.1",
        )
        if candidate in MODELS
    ),
    "v1",
)
# Dev mode: shows the diagnostic-source dropdown and mask-smoother
# accordion in the UI. Auto-on locally, auto-off on HF Spaces (which
# always sets `SPACE_ID`). Either can be overridden by setting
# LOCALVQE_DEV_MODE=1 (force on) or =0 (force off).
def _dev_mode() -> bool:
    """Decide whether the developer-only UI controls should be shown.

    ``LOCALVQE_DEV_MODE`` set to exactly ``"1"`` or ``"0"`` forces the
    answer; any other value is ignored. Without a forced value, dev mode
    is on unless ``SPACE_ID`` is present in the environment.
    """
    forced = os.environ.get("LOCALVQE_DEV_MODE")
    if forced == "1":
        return True
    if forced == "0":
        return False
    return os.environ.get("SPACE_ID") is None
# Evaluated once at import; the UI section below checks this flag
# (together with DEBUG_AVAILABLE) before building the debug controls.
DEV_MODE = _dev_mode()
if DEV_MODE:
    print("DEV_MODE=on (debug accordions visible). Set LOCALVQE_DEV_MODE=0 to hide.")
def _load_mono_16k(path: str) -> np.ndarray:
    """Read an audio file as mono float32 at the model sample rate.

    Two-dimensional data is averaged across axis 1 to get a single
    channel, and anything not already at ``SR`` is polyphase-resampled
    to it.
    """
    samples, rate = sf.read(path, dtype="float32", always_2d=False)
    if samples.ndim == 2:
        samples = samples.mean(axis=1)
    if rate == SR:
        return samples
    from math import gcd
    # Reduce the up/down ratio by the common divisor before resampling.
    common = gcd(rate, SR)
    resampled = resample_poly(samples, SR // common, rate // common)
    return resampled.astype(np.float32)
# Debug / diagnostic helpers live in `_debug.py`, which is excluded
# from the HuggingFace Spaces deploy. When this file is missing the
# app silently degrades: no debug accordions, no diagnostic-source
# branches, just the standard model forward.
try:
    # Optional module: present in local checkouts only.
    import _debug as _dbg
    DEBUG_AVAILABLE = True
except ImportError:
    # No debug helpers: every `_dbg` use below is guarded by
    # DEBUG_AVAILABLE, so the name only needs to exist.
    _dbg = None
    DEBUG_AVAILABLE = False
def _noise_gate(x: np.ndarray, threshold_dbfs: float) -> np.ndarray:
    """Zero every 10 ms frame whose RMS level is at or below the threshold.

    The signal is cut into 160-sample frames (10 ms at 16 kHz): short
    enough not to clip speech bursts, long enough that one stray quiet
    sample inside an active region is not muted. Samples left over after
    the last full frame are passed through untouched. Input shorter than
    one frame is returned as-is.

    Args:
        x: One-dimensional signal.
        threshold_dbfs: Frames must exceed this RMS level (dBFS) to be
            kept.

    Returns:
        The gated signal, same length and dtype as ``x``.
    """
    frame_len = 160
    n_frames = len(x) // frame_len
    if n_frames == 0:
        return x
    body_len = n_frames * frame_len
    frames = x[:body_len].reshape(n_frames, frame_len).astype(np.float32)
    # The small offsets keep sqrt/log10 finite on all-zero frames.
    power = np.mean(frames * frames, axis=-1)
    level_db = 20.0 * np.log10(np.sqrt(power + 1e-12) + 1e-12)
    # A 0/1 float mask, broadcast over the samples of each frame.
    keep_mask = (level_db > threshold_dbfs).astype(np.float32)
    gated_body = (frames * keep_mask[:, None]).reshape(-1)
    tail = x[body_len:]
    return np.concatenate([gated_body, tail]).astype(x.dtype)
def enhance(mic_path: str, ref_path: str,
            model_choice: str = DEFAULT_MODEL_KEY,
            gate_enabled: bool = False,
            gate_threshold_db: float = -45.0,
            smoother_mode: str = "off",
            smoother_attack_db: float = 12.0,
            smoother_release_db: float = 1.0,
            smoother_ema_alpha: float = 0.7,
            smoother_floor_db: float = 20.0,
            smoother_median_k: int = 3,
            debug_source: str = "enhanced",
            f_smooth_kernel: int = 31,
            f_smooth_mode: str = "median") -> tuple[int, np.ndarray]:
    """Run one mic/reference pair through the selected model.

    Args:
        mic_path: Path of the microphone recording; required.
        ref_path: Path of the far-end reference, or ``None`` to use
            silence of the same length as the mic signal.
        model_choice: Key into ``MODELS``.
        gate_enabled: Apply ``_noise_gate`` to the output when true.
        gate_threshold_db: Threshold (dBFS) handed to ``_noise_gate``.
        smoother_mode, smoother_attack_db, smoother_release_db,
        smoother_ema_alpha, smoother_floor_db, smoother_median_k:
            Forwarded to ``_dbg.apply_smoother``; only used when the
            debug module is available and ``smoother_mode`` is not
            ``"off"``.
        debug_source, f_smooth_kernel, f_smooth_mode: Forwarded to
            ``_dbg.apply_debug_source``; only used when the debug module
            is available and ``debug_source`` is not ``"enhanced"``.

    Returns:
        ``(SR, samples)`` with ``samples`` as an int16 array.

    Raises:
        gr.Error: If no mic file was supplied or ``model_choice`` is not
            a loaded model.
    """
    if mic_path is None:
        raise gr.Error("Upload or pick a mic recording first.")
    if model_choice not in MODELS:
        raise gr.Error(f"Model {model_choice!r} not loaded. Available: {list(MODELS)}")
    model = MODELS[model_choice]
    mic = _load_mono_16k(mic_path)
    if ref_path is None:
        ref = np.zeros_like(mic)
    else:
        ref = _load_mono_16k(ref_path)
    # Zero-pad the shorter of the two signals so both have n samples.
    n = max(len(mic), len(ref))
    if len(mic) < n:
        mic = np.pad(mic, (0, n - len(mic)))
    if len(ref) < n:
        ref = np.pad(ref, (0, n - len(ref)))
    # Add a leading batch dimension of 1.
    mic_t = torch.from_numpy(mic).unsqueeze(0)
    ref_t = torch.from_numpy(ref).unsqueeze(0)
    with torch.no_grad():
        # Diagnostic sources replace the plain forward pass; they exist
        # only when the optional debug module was importable.
        if DEBUG_AVAILABLE and debug_source != "enhanced":
            enc = _dbg.apply_debug_source(
                model, mic_t, ref_t, debug_source,
                smoother_ema_alpha=smoother_ema_alpha,
                f_smooth_kernel=f_smooth_kernel,
                f_smooth_mode=f_smooth_mode,
            )
        else:
            enc = model(mic_t, ref_t)
        # The smoother is skipped for the "passthrough" and "bypass_ccm"
        # diagnostic sources.
        if (DEBUG_AVAILABLE and smoother_mode != "off"
                and debug_source not in ("passthrough", "bypass_ccm")):
            enc = _dbg.apply_smoother(
                enc, model.encoder(mic_t), smoother_mode,
                attack_db=smoother_attack_db,
                release_db=smoother_release_db,
                ema_alpha=smoother_ema_alpha,
                floor_db=smoother_floor_db,
                median_k=smoother_median_k,
            )
        enh = model.decoder(enc.float(), length=n)
    out = enh[0].cpu().numpy()
    # Scale down only when the peak would exceed 0.95 of full scale;
    # quieter output is left at its true level.
    peak = float(np.abs(out).max())
    if peak > 0.95:
        out = out / peak * 0.95
    # Optional residual-echo gate: silence frames whose RMS sits below
    # `gate_threshold_db` dBFS. Off by default so listeners can A/B
    # against the raw model output via the slider.
    if gate_enabled:
        out = _noise_gate(out, gate_threshold_db)
    # Convert to int16 ourselves: Gradio's gr.Audio output otherwise
    # peak-normalises float arrays via convert_to_16_bit_wav (data /=
    # np.abs(data).max(); * 32767), which amplifies the cancelled-echo
    # residual on AEC-heavy clips by 1000×+ and makes it sound like
    # the model isn't suppressing anything. Returning int16 preserves
    # the true (quiet) loudness so listeners hear the actual output.
    out_i16 = np.clip(out * 32767, -32768, 32767).astype(np.int16)
    return SR, out_i16
# (mic, reference) file names of the bundled clips, in display order.
_EXAMPLE_PAIRS = (
    ("ne_st_noisy_mic.wav", "ne_st_noisy_ref.wav"),
    ("ne_st_clean_mic.wav", "ne_st_clean_ref.wav"),
    ("fe_st_mic.wav", "fe_st_ref.wav"),
    ("fe_st2_mic.wav", "fe_st2_ref.wav"),
    ("dt_mic.wav", "dt_ref.wav"),
)
# Rows for gr.Examples: one [mic_path, ref_path] pair per clip.
EXAMPLES = [
    [str(EXAMPLES_DIR / mic_name), str(EXAMPLES_DIR / ref_name)]
    for mic_name, ref_name in _EXAMPLE_PAIRS
]
# Markdown intro rendered at the top of the demo page.
DESCRIPTION = """
**LocalVQE** is a ~1 M-parameter open-source model that cleans up a
microphone signal on a voice call: it cancels the remote participant's
voice being picked up again (echo), suppresses background noise, and
removes reverberation — all in a single causal pass on CPU.
Provide two inputs:
- **Mic**: the raw microphone recording (what the far end would hear
without any processing).
- **Far-end reference**: the audio being played out of your speakers.
For a pure noise-suppression test (no speaker playback), upload
silence or leave empty.
Try the bundled examples first — they cover heavy and light
near-end noise (NE-ST mixed with DNS5 background at 5 dB and 20 dB
SNR), a clean far-end single-talk clip, a far-end clip with some
near-end overlap (mislabelled in the source corpus, but a useful
test of AEC + near-end preservation together), and a double-talk
clip — all from the ICASSP 2022 AEC Challenge blind set.
Weights: [LocalAI-io/LocalVQE](https://huggingface.co/LocalAI-io/LocalVQE) ·
Code: [github.com/localai-org/LocalVQE](https://github.com/localai-org/LocalVQE)
"""
# Page layout: intro text, the two audio inputs, model selector, output
# gate controls, optional debug controls, the run button with its
# output, the bundled examples, and a footer listing what was loaded.
with gr.Blocks(title="LocalVQE Demo") as demo:
    gr.Markdown("# LocalVQE: real-time AEC + noise suppression + dereverb")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        mic_in = gr.Audio(label="Mic (microphone recording)", type="filepath")
        ref_in = gr.Audio(label="Far-end reference (speaker playback)", type="filepath")
    # With a single loaded model there is nothing to choose, so a hidden
    # State carries the key instead of a visible selector.
    model_choice = gr.Radio(
        choices=list(MODELS.keys()),
        value=DEFAULT_MODEL_KEY,
        label="Model",
        info=(
            "v1.2 is the current release. SiLU activation + 1024 ms "
            "echo-search window + wider clean-pool DNSMOS filter + "
            "phone-bandwidth + codec round-trip aug. Adds ~+0.3 "
            "echo_mos and ~+1 dB ERLE on the AEC blind set vs v1.1. "
            "v1.1 / v1 are kept for A/B. Same param count (1.3 M). "
            "Switch and re-run on the same clip to compare."
        ),
    ) if len(MODELS) > 1 else gr.State(DEFAULT_MODEL_KEY)
    with gr.Row():
        gate_enabled = gr.Checkbox(
            label="Residual-echo gate",
            value=False,
            info=(
                "Post-process the enhanced output: silence any 10 ms frame "
                "whose RMS falls below the threshold. Cleans up the quiet "
                "residual you'd hear during far-end-only stretches; will "
                "also mute genuinely quiet speech below the threshold."
            ),
        )
        gate_threshold_db = gr.Slider(
            label="Gate threshold (dBFS)",
            minimum=-70.0, maximum=-20.0, value=-45.0, step=1.0,
        )
    if DEBUG_AVAILABLE and DEV_MODE:
        _dbg_components = _dbg.build_debug_ui(gr)
        debug_source = _dbg_components["debug_source"]
        f_smooth_kernel = _dbg_components["f_smooth_kernel"]
        f_smooth_mode = _dbg_components["f_smooth_mode"]
        smoother_mode = _dbg_components["smoother_mode"]
        smoother_attack_db = _dbg_components["smoother_attack_db"]
        smoother_release_db = _dbg_components["smoother_release_db"]
        smoother_ema_alpha = _dbg_components["smoother_ema_alpha"]
        smoother_floor_db = _dbg_components["smoother_floor_db"]
        smoother_median_k = _dbg_components["smoother_median_k"]
    else:
        # Production / no _debug.py — hidden gr.State holders carrying
        # neutral defaults, so `enhance()` keeps a stable input list.
        debug_source = gr.State("enhanced")
        f_smooth_kernel = gr.State(31)
        f_smooth_mode = gr.State("median")
        smoother_mode = gr.State("off")
        smoother_attack_db = gr.State(12.0)
        smoother_release_db = gr.State(1.0)
        smoother_ema_alpha = gr.State(0.7)
        smoother_floor_db = gr.State(20.0)
        smoother_median_k = gr.State(3)
    btn = gr.Button("Enhance", variant="primary")
    out = gr.Audio(label="Enhanced output", type="numpy")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[mic_in, ref_in],
        label=(
            "Examples — top to bottom: near-end + heavy noise (5 dB SNR, "
            "pure NS), near-end + light noise (20 dB SNR, NS preserving "
            "clean speech), far-end single-talk (pure AEC), far-end with "
            "brief near-end overlap (AEC while preserving NE), and "
            "double-talk (AEC while near-end is also talking)."
        ),
    )
    # The input order must match the positional parameters of enhance().
    btn.click(
        enhance,
        inputs=[mic_in, ref_in, model_choice,
                gate_enabled, gate_threshold_db,
                smoother_mode, smoother_attack_db, smoother_release_db,
                smoother_ema_alpha, smoother_floor_db, smoother_median_k,
                debug_source, f_smooth_kernel, f_smooth_mode],
        outputs=out,
    )
    # Footer: one line per loaded model with its source path, a short
    # checksum prefix and the parameter count.
    _info_lines = []
    for key in MODELS:
        i = INFOS[key]
        _info_lines.append(
            f"<b>{i['label']}</b> — <code>{i['source']}</code> · "
            f"sha256 <code>{i['sha256'][:16]}…</code> · "
            f"{i['n_params']:,} params"
        )
    gr.Markdown("<sub>Loaded models:<br>" + "<br>".join(_info_lines) + "</sub>")
if __name__ == "__main__":
    # Bind to loopback unless GRADIO_SERVER_NAME says otherwise.
    _host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    demo.launch(server_name=_host)
|