File size: 22,408 Bytes
4313d1d
 
37969f2
69f75a7
 
37969f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
37969f2
4313d1d
eee8304
f5e08b6
4313d1d
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
37969f2
 
69f75a7
 
 
 
 
 
 
 
4313d1d
37969f2
 
 
 
 
 
 
 
 
4313d1d
69f75a7
 
 
37969f2
4313d1d
 
 
eee8304
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5e08b6
69f75a7
 
 
 
 
37969f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4313d1d
69f75a7
4313d1d
69f75a7
 
6e0a6e4
4313d1d
69f75a7
 
 
 
37969f2
69f75a7
 
 
eee8304
4313d1d
 
37969f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
37969f2
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
37969f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
37969f2
 
 
 
 
69f75a7
 
37969f2
 
 
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
 
 
 
 
4313d1d
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
 
 
 
9264232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
9264232
69f75a7
 
 
 
 
 
 
 
 
 
4313d1d
 
69f75a7
 
 
4313d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4313d1d
 
50954ed
4313d1d
 
9264232
 
 
 
 
50954ed
 
 
 
 
 
 
 
4313d1d
 
 
 
eee8304
 
4313d1d
b180d02
eee8304
 
b180d02
4313d1d
eee8304
 
4313d1d
b180d02
eee8304
 
b180d02
4313d1d
eee8304
 
4313d1d
 
 
 
b7e597c
4313d1d
 
 
 
 
 
 
 
 
 
 
 
b180d02
 
eee8304
 
 
 
4313d1d
 
61c68a1
4313d1d
 
 
 
 
 
 
 
69f75a7
 
 
 
 
37969f2
 
 
 
 
 
69f75a7
 
9264232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69f75a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4313d1d
 
 
 
 
 
 
b180d02
 
eee8304
dac40f1
 
4313d1d
 
 
9264232
 
69f75a7
 
 
 
 
9264232
 
4313d1d
69f75a7
 
 
 
 
 
 
 
 
eee8304
4313d1d
9264232
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
"""Gradio demo for LocalVQE β€” real-time AEC + NS + dereverb.

Loads released model versions side-by-side and exposes a runtime
selector so you can A/B them on the same clip:

  v1.2 β€” newest, default. 1.3 M params. SiLU activation + dmax 64
         (1024 ms echo-search window) + wider clean-pool DNSMOS
         filter + phone-bandwidth + codec round-trip aug. Adds
         ~+0.3 echo_mos / ~+1 dB ERLE on AEC blind FE-ST vs v1.1.
         Path resolves from LOCALVQE_V12_CKPT, else HF.
  v1.1 β€” previous release. 1.3 M params. ReLU6, pre-norm
         CausalGroupNorm, STFT-256 codec. Fixes intermittent
         crackling that v1 produced under heavy background noise.
         Path resolves from LOCALVQE_V11_CKPT, else HF.
  v1   β€” original release. Path resolves from LOCALVQE_V1_CKPT
         (or LOCALVQE_LOCAL_CKPT for backward compat), else HF.

If a checkpoint isn't reachable that entry is hidden from the
selector. Each architecture lives in an independent Python
package so they can be loaded simultaneously without import
collisions:
    v1   β†’ space/localvqe_model/
    v1.1 β†’ space/localvqe_v11/
    v1.2 β†’ space/localvqe_v12/
"""
import hashlib
import os
from pathlib import Path

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from scipy.signal import resample_poly

# v1 (original release) β€” namespace 'localvqe_model'
from localvqe_model import (
    Config as ConfigV1,
    LocalVQE as LocalVQEv1,
    apply_ckpt_model_config as apply_ckpt_v1,
    load_checkpoint as load_ckpt_v1,
)

# v1.1 / v1.2 — bundled in this directory. Imported on demand so startup
# stays fast when those versions aren't configured.
def _import_v11():
    """Lazily import the v1.1 package and hand back its four entry points.

    Returns:
        ``(Config, LocalVQE, apply_ckpt_model_config, load_checkpoint)``
        taken from the ``localvqe_v11`` package.
    """
    from localvqe_v11 import (
        Config,
        LocalVQE,
        apply_ckpt_model_config,
        load_checkpoint,
    )
    return Config, LocalVQE, apply_ckpt_model_config, load_checkpoint

def _import_v12():
    """Lazily import the v1.2 package and hand back its four entry points.

    Every v1.2-family checkpoint shares this package.

    Returns:
        ``(Config, LocalVQE, apply_ckpt_model_config, load_checkpoint)``
        taken from the ``localvqe_v12`` package.
    """
    from localvqe_v12 import (
        Config,
        LocalVQE,
        apply_ckpt_model_config,
        load_checkpoint,
    )
    return Config, LocalVQE, apply_ckpt_model_config, load_checkpoint

SR = 16_000  # model sample rate (Hz); every input is resampled to this
HF_REPO_ID = "LocalAI-io/LocalVQE"
# Checkpoint filenames of the published releases inside HF_REPO_ID.
HF_V1_FILE, HF_V11_FILE, HF_V12_FILE = (
    "localvqe-v1-1.3M.pt",
    "localvqe-v1.1-1.3M.pt",
    "localvqe-v1.2-1.3M.pt",
)
EXAMPLES_DIR = Path(__file__).resolve().parent.joinpath("examples")


def _sha256(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def _resolve_v1_ckpt() -> str | None:
    # Backward-compat: LOCALVQE_LOCAL_CKPT used to be the way to override.
    for env in ("LOCALVQE_V1_CKPT", "LOCALVQE_LOCAL_CKPT"):
        v = os.environ.get(env)
        if v:
            return v
    try:
        from huggingface_hub import hf_hub_download
        return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V1_FILE)
    except Exception as e:
        print(f"v1 unavailable from HF ({e})")
        return None


def _resolve_v11_ckpt() -> str | None:
    v = os.environ.get("LOCALVQE_V11_CKPT")
    if v:
        return v
    try:
        from huggingface_hub import hf_hub_download
        return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V11_FILE)
    except Exception:
        return None


def _resolve_v12_ckpt() -> str | None:
    v = os.environ.get("LOCALVQE_V12_CKPT")
    if v:
        return v
    try:
        from huggingface_hub import hf_hub_download
        return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_V12_FILE)
    except Exception:
        return None


def _resolve_v121_ckpt() -> str | None:
    # No HF fallback yet β€” v1.2.1 isn't published. Set LOCALVQE_V121_CKPT
    # in docker-compose.yml (defaults to checkpoints/release/...) to load
    # the local finetuned copy.
    return os.environ.get("LOCALVQE_V121_CKPT") or None


def _resolve_v12a_ckpt() -> str | None:
    # v1.2a β€” v9 (widened DRR + longer RIRs + global gain) from-scratch
    # epoch 14. Architecture identical to v1.2/v1.2.1 (uses localvqe_v12
    # package). No HF publish yet.
    return os.environ.get("LOCALVQE_V12A_CKPT") or None


def _resolve_v12b_ckpt() -> str | None:
    # v1.2b β€” v10 (v1.2 + audible reverb + 80/20 conference mix +
    # pipeline pop fixes, no experimental augs) from-scratch e19.
    # Architecture identical to v1.2 (uses localvqe_v12 package).
    return os.environ.get("LOCALVQE_V12B_CKPT") or None


def _resolve_v12c_ckpt() -> str | None:
    # v1.2c β€” v11 (v10 + level-invariance mic-gain aug,
    # clean_attenuation_factor=1.0) from-scratch e17. Addresses
    # low-SNR wobble near noise floor. Architecture identical to
    # v1.2 (uses localvqe_v12 package).
    return os.environ.get("LOCALVQE_V12C_CKPT") or None


def _resolve_v12d_ckpt() -> str | None:
    # v1.2d β€” v11_refine e22 (10-epoch low-LR cosine continuation
    # of v1.2c from v11 e20, peak LR 1e-4). Blind eval beats
    # v1.2c on FE-ST echo_mos (+0.31) and NE-ST deg_mos (+0.04)
    # while recovering 2.4 dB of FE-ST ERLE. Architecture
    # identical to v1.2 (uses localvqe_v12 package).
    return os.environ.get("LOCALVQE_V12D_CKPT") or None


def _build_v1():
    """Load the original v1 checkpoint, or return ``(None, None)`` if absent.

    Returns:
        ``(model, info)``. ``model`` is on CPU in eval mode with the trained
        AlignBlock temperature folded in. ``info`` holds the checkpoint
        path, its SHA-256, the parameter count and the selector label.
    """
    ckpt_path = _resolve_v1_ckpt()
    if ckpt_path is None:
        return None, None
    cfg = ConfigV1()
    # NOTE(review): weights_only=False unpickles arbitrary objects. The path
    # comes from an env var or this project's own HF repo and is treated as
    # trusted; never point it at a checkpoint from an unknown source.
    peek = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_ckpt_v1(peek, cfg)
    del peek
    model = LocalVQEv1.from_config(cfg).to("cpu")
    load_ckpt_v1(ckpt_path, model)
    # Fold the trained AlignBlock softmax temperature (a buffer in the
    # checkpoint) into the smoothing conv — without this, eval runs at
    # the default 1.0 instead of the trained value, losing ~5 dB ERLE.
    model.align.fold_temperature()
    model.eval()
    info = {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": sum(p.numel() for p in model.parameters()),
        # v1.1 is the "previous release"; v1 is the original one.
        "label": "v1 (original release)",
    }
    print(f"v1 loaded: {info['n_params']:,} params  sha={info['sha256'][:16]}…  "
          f"src={ckpt_path}")
    return model, info


def _build_v11():
    """Load the v1.1 checkpoint, or return ``(None, None)`` when unreachable.

    The v1.1 package is imported only once a checkpoint path is known.

    Returns:
        ``(model, info)``: a CPU model in eval mode and a dict describing
        the checkpoint (source path, sha256, n_params, label).
    """
    path = _resolve_v11_ckpt()
    if path is None:
        return None, None
    cfg_cls, model_cls, apply_cfg, load_weights = _import_v11()
    config = cfg_cls()
    # The raw checkpoint is read once only so it can adjust `config` before
    # the network is built (presumably architecture fields); dropped after.
    raw = torch.load(path, map_location="cpu", weights_only=False)
    apply_cfg(raw, config)
    del raw
    net = model_cls.from_config(config).to("cpu")
    load_weights(path, net)
    # Same temperature fold as v1: bake the trained value into the conv.
    net.align.fold_temperature()
    net.eval()
    info = dict(
        source=path,
        sha256=_sha256(path),
        n_params=sum(p.numel() for p in net.parameters()),
        label="v1.1 (previous release)",
    )
    print(f"v1.1 loaded: {info['n_params']:,} params  "
          f"sha={info['sha256'][:16]}…  src={path}")
    return net, info


def _build_v12_like(ckpt_path, label):
    """Build an eval-mode model from a checkpoint with the v1.2 architecture.

    Shared by every v1.2-family entry (v1.2, v1.2.1, v1.2a, v1.2b, v1.2c,
    v1.2d); they all use the ``localvqe_v12`` package.

    Args:
        ckpt_path: local path of the checkpoint file.
        label: human-readable name shown in the UI footer.

    Returns:
        ``(model, info)``: the CPU model in eval mode, plus a dict with
        ``source``, ``sha256``, ``n_params`` and ``label``.
    """
    cfg_cls, model_cls, apply_cfg, load_weights = _import_v12()
    config = cfg_cls()
    # Read the raw checkpoint first so it can adjust `config` before the
    # network is built; the copy is released right away.
    raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    apply_cfg(raw, config)
    del raw
    net = model_cls.from_config(config).to("cpu")
    load_weights(ckpt_path, net)
    # Bake the trained AlignBlock temperature into the smoothing conv.
    net.align.fold_temperature()
    net.eval()
    param_count = sum(p.numel() for p in net.parameters())
    return net, {
        "source": ckpt_path,
        "sha256": _sha256(ckpt_path),
        "n_params": param_count,
        "label": label,
    }


def _build_v12():
    """Load the published v1.2 checkpoint; ``(None, None)`` if unavailable."""
    path = _resolve_v12_ckpt()
    if path is not None:
        model, info = _build_v12_like(path, "v1.2 (current release)")
        print(f"v1.2 loaded: {info['n_params']:,} params  "
              f"sha={info['sha256'][:16]}…  src={path}")
        return model, info
    return None, None


def _build_v121():
    """Load the local v1.2.1 checkpoint; ``(None, None)`` if not configured."""
    path = _resolve_v121_ckpt()
    if path is not None:
        model, info = _build_v12_like(path, "v1.2.1 (movement-aug finetune)")
        print(f"v1.2.1 loaded: {info['n_params']:,} params  "
              f"sha={info['sha256'][:16]}…  src={path}")
        return model, info
    return None, None


def _build_v12a():
    """Load the local v1.2a checkpoint; ``(None, None)`` if not configured."""
    path = _resolve_v12a_ckpt()
    if path is not None:
        model, info = _build_v12_like(
            path, "v1.2a (widened DRR + longer RIRs, from-scratch)")
        print(f"v1.2a loaded: {info['n_params']:,} params  "
              f"sha={info['sha256'][:16]}…  src={path}")
        return model, info
    return None, None


def _build_v12b():
    """Load the local v1.2b checkpoint; ``(None, None)`` if not configured."""
    path = _resolve_v12b_ckpt()
    if path is not None:
        model, info = _build_v12_like(
            path, "v1.2b (v10: audible reverb + conference mix + pop fixes)")
        print(f"v1.2b loaded: {info['n_params']:,} params  "
              f"sha={info['sha256'][:16]}…  src={path}")
        return model, info
    return None, None


def _build_v12c():
    """Load the local v1.2c checkpoint; ``(None, None)`` if not configured."""
    path = _resolve_v12c_ckpt()
    if path is not None:
        model, info = _build_v12_like(
            path, "v1.2c (v11: level-invariance mic-gain on v1.2b base)")
        print(f"v1.2c loaded: {info['n_params']:,} params  "
              f"sha={info['sha256'][:16]}…  src={path}")
        return model, info
    return None, None


def _build_v12d():
    """Load the local v1.2d checkpoint; ``(None, None)`` if not configured."""
    path = _resolve_v12d_ckpt()
    if path is not None:
        model, info = _build_v12_like(
            path, "v1.2d (v11_refine e22: low-LR cosine polish of v1.2c)")
        print(f"v1.2d loaded: {info['n_params']:,} params  "
              f"sha={info['sha256'][:16]}…  src={path}")
        return model, info
    return None, None


MODEL_V1, INFO_V1 = _build_v1()
MODEL_V11, INFO_V11 = _build_v11()
MODEL_V12, INFO_V12 = _build_v12()
MODEL_V121, INFO_V121 = _build_v121()
MODEL_V12A, INFO_V12A = _build_v12a()
MODEL_V12B, INFO_V12B = _build_v12b()
MODEL_V12C, INFO_V12C = _build_v12c()
MODEL_V12D, INFO_V12D = _build_v12d()

# Selector key, model, info — in the order the radio buttons are listed.
# Entries whose checkpoint couldn't be loaded (model is None) are dropped.
_REGISTRY = (
    ("v1", MODEL_V1, INFO_V1),
    ("v1.1", MODEL_V11, INFO_V11),
    ("v1.2", MODEL_V12, INFO_V12),
    ("v1.2.1", MODEL_V121, INFO_V121),
    ("v1.2a", MODEL_V12A, INFO_V12A),
    ("v1.2b", MODEL_V12B, INFO_V12B),
    ("v1.2c", MODEL_V12C, INFO_V12C),
    ("v1.2d", MODEL_V12D, INFO_V12D),
)
MODELS: dict[str, object] = {
    key: model for key, model, _ in _REGISTRY if model is not None
}
INFOS: dict[str, dict] = {
    key: info for key, model, info in _REGISTRY if model is not None
}
del _REGISTRY
if not MODELS:
    raise RuntimeError(
        "No model could be loaded. Set LOCALVQE_V1_CKPT, "
        "LOCALVQE_V11_CKPT, LOCALVQE_V12_CKPT, LOCALVQE_V121_CKPT, "
        "LOCALVQE_V12A_CKPT, LOCALVQE_V12B_CKPT, LOCALVQE_V12C_CKPT, "
        "or LOCALVQE_V12D_CKPT, or ensure HF access for the "
        "published files."
    )

# Preselect the first available key in this preference order; MODELS is
# non-empty here, so falling through to "v1" means v1 is the only one left.
DEFAULT_MODEL_KEY = next(
    (key
     for key in ("v1.2d", "v1.2c", "v1.2b", "v1.2a", "v1.2.1", "v1.2", "v1.1")
     if key in MODELS),
    "v1",
)

# Dev mode: shows the diagnostic-source dropdown and mask-smoother
# accordion in the UI. Auto-on locally, auto-off on HF Spaces (which
# always sets `SPACE_ID`). Either can be overridden by setting
# LOCALVQE_DEV_MODE=1 (force on) or =0 (force off).
def _dev_mode() -> bool:
    explicit = os.environ.get("LOCALVQE_DEV_MODE")
    if explicit in ("0", "1"):
        return explicit == "1"
    return "SPACE_ID" not in os.environ
DEV_MODE = _dev_mode()
if DEV_MODE:
    print("DEV_MODE=on (debug accordions visible). Set LOCALVQE_DEV_MODE=0 to hide.")


def _load_mono_16k(path: str) -> np.ndarray:
    """Read an audio file as a mono float32 signal at the model rate ``SR``.

    Multi-channel files are down-mixed by averaging the channels. Any other
    sample rate is converted with polyphase resampling, using up/down
    factors reduced by their greatest common divisor.
    """
    audio, rate = sf.read(path, dtype="float32", always_2d=False)
    if audio.ndim == 2:
        # (frames, channels) -> (frames,)
        audio = audio.mean(axis=1)
    if rate == SR:
        return audio
    from math import gcd
    common = gcd(rate, SR)
    up, down = SR // common, rate // common
    return resample_poly(audio, up, down).astype(np.float32)


# Debug / diagnostic helpers live in `_debug.py`, which is excluded
# from the HuggingFace Spaces deploy. When this file is missing the
# app silently degrades: no debug accordions, no diagnostic-source
# branches, just the standard model forward.
try:
    import _debug as _dbg
    DEBUG_AVAILABLE = True
except ImportError:
    _dbg = None
    DEBUG_AVAILABLE = False


def _noise_gate(x: np.ndarray, threshold_dbfs: float) -> np.ndarray:
    """Hard-gate frames whose RMS is below `threshold_dbfs` to zero.

    Operates on 10 ms frames (160 samples at 16 kHz) β€” short enough
    that speech bursts aren't truncated, long enough that a single
    out-of-band sample inside an active region doesn't get muted.
    The ungated tail (samples that don't fill a full final frame) is
    passed through unchanged.
    """
    frame = 160
    n = len(x) // frame
    if n == 0:
        return x
    f = x[: n * frame].reshape(n, frame).astype(np.float32)
    rms = np.sqrt((f * f).mean(axis=-1) + 1e-12)
    rms_db = 20.0 * np.log10(rms + 1e-12)
    keep = (rms_db > threshold_dbfs).astype(np.float32)
    gated = (f * keep[:, None]).reshape(-1)
    return np.concatenate([gated, x[n * frame:]]).astype(x.dtype)


def enhance(mic_path: str, ref_path: str,
            model_choice: str = DEFAULT_MODEL_KEY,
            gate_enabled: bool = False,
            gate_threshold_db: float = -45.0,
            smoother_mode: str = "off",
            smoother_attack_db: float = 12.0,
            smoother_release_db: float = 1.0,
            smoother_ema_alpha: float = 0.7,
            smoother_floor_db: float = 20.0,
            smoother_median_k: int = 3,
            debug_source: str = "enhanced",
            f_smooth_kernel: int = 31,
            f_smooth_mode: str = "median") -> tuple[int, np.ndarray]:
    """Enhance one mic recording against its far-end reference.

    Args:
        mic_path: file path of the microphone recording (required).
        ref_path: file path of the far-end / loudspeaker signal, or None to
            run with an all-zero reference.
        model_choice: key into ``MODELS`` selecting the checkpoint.
        gate_enabled, gate_threshold_db: optional 10 ms noise gate applied
            after peak limiting.
        smoother_*, debug_source, f_smooth_*: diagnostic knobs that only
            take effect when the optional ``_debug`` module was imported.

    Returns:
        ``(SR, samples)`` where ``samples`` is int16, ready for gr.Audio.

    Raises:
        gr.Error: no mic file was given, or the model key isn't loaded.
    """
    if mic_path is None:
        raise gr.Error("Upload or pick a mic recording first.")
    model = MODELS.get(model_choice)
    if model is None:
        raise gr.Error(f"Model {model_choice!r} not loaded. Available: {list(MODELS)}")

    mic = _load_mono_16k(mic_path)
    ref = np.zeros_like(mic) if ref_path is None else _load_mono_16k(ref_path)

    # Zero-pad the shorter stream so mic and reference cover the same span.
    n = max(len(mic), len(ref))
    mic = np.pad(mic, (0, n - len(mic))) if len(mic) < n else mic
    ref = np.pad(ref, (0, n - len(ref))) if len(ref) < n else ref

    mic_t = torch.from_numpy(mic).unsqueeze(0)
    ref_t = torch.from_numpy(ref).unsqueeze(0)

    wants_debug_source = DEBUG_AVAILABLE and debug_source != "enhanced"
    wants_smoother = (DEBUG_AVAILABLE and smoother_mode != "off"
                      and debug_source not in ("passthrough", "bypass_ccm"))

    with torch.no_grad():
        enc = (
            _dbg.apply_debug_source(
                model, mic_t, ref_t, debug_source,
                smoother_ema_alpha=smoother_ema_alpha,
                f_smooth_kernel=f_smooth_kernel,
                f_smooth_mode=f_smooth_mode,
            )
            if wants_debug_source
            else model(mic_t, ref_t)
        )
        if wants_smoother:
            enc = _dbg.apply_smoother(
                enc, model.encoder(mic_t), smoother_mode,
                attack_db=smoother_attack_db,
                release_db=smoother_release_db,
                ema_alpha=smoother_ema_alpha,
                floor_db=smoother_floor_db,
                median_k=smoother_median_k,
            )
        enh = model.decoder(enc.float(), length=n)

    audio = enh[0].cpu().numpy()
    peak = float(np.abs(audio).max())
    if peak > 0.95:
        # Scale down only when needed so the loudest sample sits at 0.95.
        audio = audio / peak * 0.95
    # Residual-echo gate is opt-in so listeners can A/B against the raw
    # model output.
    if gate_enabled:
        audio = _noise_gate(audio, gate_threshold_db)
    # Hand Gradio int16 ourselves. Float arrays are peak-normalised by
    # gr.Audio (convert_to_16_bit_wav), which would boost the quiet
    # cancelled-echo residual by 1000x+ and hide how much was suppressed.
    pcm = np.clip(audio * 32767, -32768, 32767).astype(np.int16)
    return SR, pcm


# (mic, far-end reference) pairs for gr.Examples, listed in the order the
# UI label describes: NE-ST heavy noise, NE-ST light noise, FE-ST, FE-ST
# with near-end overlap, double-talk.
EXAMPLES = [
    [str(EXAMPLES_DIR / f"{stem}_mic.wav"), str(EXAMPLES_DIR / f"{stem}_ref.wav")]
    for stem in ("ne_st_noisy", "ne_st_clean", "fe_st", "fe_st2", "dt")
]

DESCRIPTION = """
**LocalVQE** is a ~1 M-parameter open-source model that cleans up a
microphone signal on a voice call: it cancels the remote participant's
voice being picked up again (echo), suppresses background noise, and
removes reverberation β€” all in a single causal pass on CPU.

Provide two inputs:

- **Mic**: the raw microphone recording (what the far end would hear
  without any processing).
- **Far-end reference**: the audio being played out of your speakers.
  For a pure noise-suppression test (no speaker playback), upload
  silence or leave empty.

Try the bundled examples first β€” they cover heavy and light
near-end noise (NE-ST mixed with DNS5 background at 5 dB and 20 dB
SNR), a clean far-end single-talk clip, a far-end clip with some
near-end overlap (mislabelled in the source corpus, but a useful
test of AEC + near-end preservation together), and a double-talk
clip β€” all from the ICASSP 2022 AEC Challenge blind set.

Weights: [LocalAI-io/LocalVQE](https://huggingface.co/LocalAI-io/LocalVQE) Β·
Code: [github.com/localai-org/LocalVQE](https://github.com/localai-org/LocalVQE)
"""

with gr.Blocks(title="LocalVQE Demo") as demo:
    gr.Markdown("# LocalVQE: real-time AEC + noise suppression + dereverb")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        mic_in = gr.Audio(label="Mic (microphone recording)", type="filepath")
        ref_in = gr.Audio(label="Far-end reference (speaker playback)", type="filepath")
    model_choice = gr.Radio(
        choices=list(MODELS.keys()),
        value=DEFAULT_MODEL_KEY,
        label="Model",
        info=(
            "v1.2 is the current release. SiLU activation + 1024 ms "
            "echo-search window + wider clean-pool DNSMOS filter + "
            "phone-bandwidth + codec round-trip aug. Adds ~+0.3 "
            "echo_mos and ~+1 dB ERLE on the AEC blind set vs v1.1. "
            "v1.1 / v1 are kept for A/B. Same param count (1.3 M). "
            "Switch and re-run on the same clip to compare."
        ),
    ) if len(MODELS) > 1 else gr.State(DEFAULT_MODEL_KEY)
    with gr.Row():
        gate_enabled = gr.Checkbox(
            label="Residual-echo gate",
            value=False,
            info=(
                "Post-process the enhanced output: silence any 10 ms frame "
                "whose RMS falls below the threshold. Cleans up the quiet "
                "residual you'd hear during far-end-only stretches; will "
                "also mute genuinely quiet speech below the threshold."
            ),
        )
        gate_threshold_db = gr.Slider(
            label="Gate threshold (dBFS)",
            minimum=-70.0, maximum=-20.0, value=-45.0, step=1.0,
        )
    if DEBUG_AVAILABLE and DEV_MODE:
        _dbg_components = _dbg.build_debug_ui(gr)
        debug_source = _dbg_components["debug_source"]
        f_smooth_kernel = _dbg_components["f_smooth_kernel"]
        f_smooth_mode = _dbg_components["f_smooth_mode"]
        smoother_mode = _dbg_components["smoother_mode"]
        smoother_attack_db = _dbg_components["smoother_attack_db"]
        smoother_release_db = _dbg_components["smoother_release_db"]
        smoother_ema_alpha = _dbg_components["smoother_ema_alpha"]
        smoother_floor_db = _dbg_components["smoother_floor_db"]
        smoother_median_k = _dbg_components["smoother_median_k"]
    else:
        # Production / no _debug.py β€” hidden gr.State holders carrying
        # neutral defaults, so `enhance()` keeps a stable input list.
        debug_source = gr.State("enhanced")
        f_smooth_kernel = gr.State(31)
        f_smooth_mode = gr.State("median")
        smoother_mode = gr.State("off")
        smoother_attack_db = gr.State(12.0)
        smoother_release_db = gr.State(1.0)
        smoother_ema_alpha = gr.State(0.7)
        smoother_floor_db = gr.State(20.0)
        smoother_median_k = gr.State(3)
    btn = gr.Button("Enhance", variant="primary")
    out = gr.Audio(label="Enhanced output", type="numpy")

    gr.Examples(
        examples=EXAMPLES,
        inputs=[mic_in, ref_in],
        label=(
            "Examples β€” top to bottom: near-end + heavy noise (5 dB SNR, "
            "pure NS), near-end + light noise (20 dB SNR, NS preserving "
            "clean speech), far-end single-talk (pure AEC), far-end with "
            "brief near-end overlap (AEC while preserving NE), and "
            "double-talk (AEC while near-end is also talking)."
        ),
    )

    btn.click(
        enhance,
        inputs=[mic_in, ref_in, model_choice,
                gate_enabled, gate_threshold_db,
                smoother_mode, smoother_attack_db, smoother_release_db,
                smoother_ema_alpha, smoother_floor_db, smoother_median_k,
                debug_source, f_smooth_kernel, f_smooth_mode],
        outputs=out,
    )

    _info_lines = []
    for key in MODELS:
        i = INFOS[key]
        _info_lines.append(
            f"<b>{i['label']}</b> β€” <code>{i['source']}</code> Β· "
            f"sha256 <code>{i['sha256'][:16]}…</code> Β· "
            f"{i['n_params']:,} params"
        )
    gr.Markdown("<sub>Loaded models:<br>" + "<br>".join(_info_lines) + "</sub>")

if __name__ == "__main__":
    demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1"))