---
license: mit
library_name: pytorch
tags:
  - speech-enhancement
  - audio
  - denoising
  - onnx
  - causal
  - streaming
  - real-time
  - edge-ai
  - stm32
datasets:
  - JacobLinCool/VoiceBank-DEMAND-16k
---

LiSenNet

Ultra-compact, causal, real-time speech enhancers trained on VoiceBank-DEMAND-16k. The model is a sub-band U-Net that predicts a magnitude-only mask. Phase comes from 2 Griffin-Lim iterations offline, or from the noisy input for real-time use. Port of Yan, Zhou, Chen & Lu, LiSenNet, arXiv:2409.13285 (hyyan2k/LiSenNet, MIT).

This repo holds two variants, each in its own subfolder:

| subfolder | bottleneck | params | NPU | FP32 PESQ | real-time int8 PESQ |
|---|---|---|---|---|---|
| gru/ | dual-path GRU (faithful) | 36,783 | no | 3.006 | 2.930 |
| conv/ | dual-path conv (NPU) | 41,063 | yes | 2.970 | 2.855 |

PESQ is wideband PESQ, measured on the full 824-utterance VoiceBank-DEMAND test split.
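The quality cost of the conv/ bottleneck can be read off the table as a per-column difference. The short check below simply subtracts the table values; it introduces no new measurements.

```python
# FP32 and real-time int8 PESQ, copied from the table above.
pesq = {
    "gru/":  {"fp32": 3.006, "int8_rt": 2.930},
    "conv/": {"fp32": 2.970, "int8_rt": 2.855},
}

# PESQ lost by replacing the GRU bottleneck with the conv one, per column.
gap = {col: pesq["gru/"][col] - pesq["conv/"][col] for col in ("fp32", "int8_rt")}
for col, d in gap.items():
    print(f"{col}: gru/ - conv/ = {d:.3f}")
# fp32: gru/ - conv/ = 0.036
# int8_rt: gru/ - conv/ = 0.075
```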

  • gru/ is the faithful reproduction and serves as the quality reference. Its GRU and 2-axis LayerNorm layers do not compile to the STM32N6 Neural-ART NPU.
  • conv/ replaces the GRU bottleneck with a dual-path convolutional one so that the whole model maps to the NPU (the exported graph contains 0 GRU and 0 LayerNormalization nodes). It also ships conv/g_best_streaming_fp32.onnx, the frame-by-frame streaming graph with explicit FIFO state I/O (feat + N state_i_in -> est_mag + N state_i_out). This is the artifact handed to stedgeai. The ~0.04 FP32 PESQ gap (3.006 vs 2.970) is the cost of dropping recurrence to fit the NPU.
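Explicit FIFO state I/O is what makes a causal conv stack streamable. The toy sketch below illustrates this in NumPy. It uses a single hypothetical depthwise causal convolution over time, not the actual conv/ bottleneck. Each streaming call receives the previous K-1 frames as state and returns an updated state, starting from zeros. Run frame by frame, it matches the whole-utterance output.

```python
import numpy as np

def causal_conv_full(x, w):
    """Whole-utterance causal depthwise conv over time (toy layer).
    x: (T, C) frames, w: (K, C) kernel; output frame t sees only frames t-K+1..t."""
    K = w.shape[0]
    xp = np.concatenate([np.zeros((K - 1, x.shape[1])), x], axis=0)  # left-pad only
    return np.stack([(xp[t:t + K] * w).sum(axis=0) for t in range(x.shape[0])])

def causal_conv_step(x_t, state, w):
    """Streaming version: one frame in, one frame out.
    x_t: (C,) current frame, state: (K-1, C) FIFO of the previous K-1 frames."""
    window = np.concatenate([state, x_t[None, :]], axis=0)  # (K, C), oldest first
    y_t = (window * w).sum(axis=0)
    return y_t, window[1:]                                  # new state: drop the oldest frame

# Frame-by-frame processing with a zero initial state reproduces the full output.
rng = np.random.default_rng(0)
x, w = rng.standard_normal((20, 4)), rng.standard_normal((3, 4))
state = np.zeros((w.shape[0] - 1, x.shape[1]))              # start-of-stream = zeros
ys = []
for x_t in x:
    y_t, state = causal_conv_step(x_t, state, w)
    ys.append(y_t)
print(np.allclose(np.stack(ys), causal_conv_full(x, w)))    # True
```

In the exported streaming graph, each state_i_in / state_i_out pair plays the role of `state` here, one per causal layer.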

Code and full write-up: https://github.com/LarocheC/eco8-neaixt (see RESULTS_LISENNET.md).

Files (per subfolder)

  • config.json: model configuration, passed to build_lisennet.
  • g_best: PyTorch checkpoint, {"generator": state_dict}.
  • g_best_fp32.onnx and g_best_int8_static.onnx: whole-utterance mask sub-network, feat (B,3,T,F) -> est_mag (B,T,F).
  • conv/ only: g_best_streaming_fp32.onnx and g_best_streaming_int8_static.onnx, the single-frame graphs with explicit state I/O.

The ONNX graphs contain only the mask sub-network. STFT, feature construction and phase recovery run on the host.
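Since phase recovery stays host-side, the sketch below shows both options from the introduction in NumPy. The first combines the estimated magnitude with the noisy phase (real-time). The second runs 2 Griffin-Lim iterations (offline). It rests on several assumptions:

  • N_FFT = 512 is inferred from F = 257 in the streaming example. HOP = 128 and the periodic Hann window are placeholders; check config.json.
  • Initializing Griffin-Lim from the noisy phase is an assumption.
  • est_mag is taken to be a linear magnitude. If the model predicts a compressed magnitude, undo the compression first.

This is not the repository's own STFT code.

```python
import numpy as np

# ASSUMPTIONS: n_fft=512 matches F=257; HOP and the window are placeholders.
N_FFT, HOP = 512, 128
WIN = np.hanning(N_FFT + 1)[:-1]  # periodic Hann window

def stft(x):
    """(L,) real signal -> (T, F) complex spectrogram with centred frames."""
    x = np.pad(x, (N_FFT // 2, N_FFT // 2))
    n_frames = 1 + (len(x) - N_FFT) // HOP
    idx = np.arange(N_FFT)[None, :] + HOP * np.arange(n_frames)[:, None]
    return np.fft.rfft(x[idx] * WIN, axis=-1)

def istft(spec, length):
    """Weighted overlap-add inverse of stft(); returns `length` samples."""
    frames = np.fft.irfft(spec, n=N_FFT, axis=-1) * WIN
    n = N_FFT + HOP * (len(frames) - 1)
    out, norm = np.zeros(n), np.zeros(n)
    for t, frame in enumerate(frames):
        out[t * HOP:t * HOP + N_FFT] += frame
        norm[t * HOP:t * HOP + N_FFT] += WIN ** 2
    out /= np.maximum(norm, 1e-8)
    return out[N_FFT // 2:N_FFT // 2 + length]

def enhance_noisy_phase(est_mag, noisy_spec, length):
    """Real-time option: estimated magnitude combined with the noisy phase."""
    return istft(est_mag * np.exp(1j * np.angle(noisy_spec)), length)

def enhance_griffin_lim(est_mag, noisy_spec, length, n_iter=2):
    """Offline option: a few Griffin-Lim iterations, initialised (by assumption)
    from the noisy phase. Each iteration re-estimates phase from a consistent STFT."""
    phase = np.angle(noisy_spec)
    for _ in range(n_iter):
        x = istft(est_mag * np.exp(1j * phase), length)
        phase = np.angle(stft(x))
    return istft(est_mag * np.exp(1j * phase), length)
```

For a single utterance from the whole-utterance ONNX graph, pass est_mag[0] (shape (T, F)). Its frames must line up with the STFT used to build feat, which this sketch does not reproduce.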

Loading (PyTorch)

```python
import json, torch
from huggingface_hub import hf_hub_download
# `common` and `lisennet` are not PyPI packages: they come from the eco8-neaixt
# code repository, so run from a checkout of it (or add it to PYTHONPATH).
from common.env import AttrDict
from lisennet.model import build_lisennet

REPO, SUB = "claroche1/LiSenNet", "conv"      # or "gru"
with open(hf_hub_download(REPO, f"{SUB}/config.json")) as f:
    cfg = json.load(f)
ckpt = torch.load(hf_hub_download(REPO, f"{SUB}/g_best"), map_location="cpu", weights_only=True)
model = build_lisennet(AttrDict(cfg)).eval()
model.load_state_dict(ckpt["generator"])   # checkpoint is {"generator": state_dict}
# usage: model(noisy_wav)["est"]
```

Running the NPU streaming graph (frame-by-frame)

```python
import numpy as np, onnxruntime as ort
from huggingface_hub import hf_hub_download

sess = ort.InferenceSession(
    hf_hub_download("claroche1/LiSenNet", "conv/g_best_streaming_fp32.onnx"),
    providers=["CPUExecutionProvider"],
)
state_in = [i for i in sess.get_inputs() if i.name != "feat"]   # FIFO state inputs
out_names = [o.name for o in sess.get_outputs()]                # est_mag, then state_*_out

def zeros(shape):
    # Symbolic/dynamic dims (str or None) are set to 1 (batch of one).
    return np.zeros([d if isinstance(d, int) else 1 for d in shape], np.float32)

states = {i.name: zeros(i.shape) for i in state_in}            # start of stream = zeros

def step(feat_t):                                              # feat_t: (1, 3, 1, 257) float32
    res = sess.run(out_names, {"feat": feat_t, **states})
    # Relies on the state outputs following est_mag in the same order as the state inputs.
    for i, v in zip(state_in, res[1:]):
        states[i.name] = v                                     # state_i_out becomes next state_i_in
    return res[0]                                              # est_mag (1, 1, 257)
```
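To push a whole utterance through the streaming graph, a small driver can slice the features one frame at a time and concatenate the per-frame outputs. `run_stream` below is a hypothetical helper, not part of the repository. It only assumes a per-frame function with the same contract as `step` above.

```python
import numpy as np

def run_stream(step_fn, feat):
    """Feed whole-utterance features through a per-frame step function.

    feat:    (1, 3, T, F) features, the input layout of the whole-utterance graph.
    step_fn: maps one frame (1, 3, 1, F) float32 -> est_mag (1, 1, F), carrying
             any internal state between calls (e.g. `step` with its FIFO states).
    Returns est_mag for all frames, shape (1, T, F).
    """
    frames = [
        step_fn(np.ascontiguousarray(feat[:, :, t:t + 1, :], dtype=np.float32))
        for t in range(feat.shape[2])
    ]
    return np.concatenate(frames, axis=1)
```

Reset `states` to zeros before each new utterance. A natural sanity check is to compare `run_stream(step, feat)` against the whole-utterance conv/g_best_fp32.onnx output on the same features.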

License

MIT. See the source repository for training code and full attribution.