license: mit
language: en
library_name: pytorch
tags:
- speech-enhancement
- speech-denoising
- nsnet2
- butterfly
- monarch
- structured-matrices
- onnx
- quantization
- int8
datasets:
- JacobLinCool/VoiceBank-DEMAND-16k
metrics:
- pesq
sparse-nsnet2 — checkpoints
Best-PESQ checkpoints from the sparse-nsnet2 compression sweep: NSNet2 speech enhancement with the FC and GRU layers swappable between dense, Butterfly, and Monarch factorizations.
Trained on VoiceBank-DEMAND-16k for 200 epochs with batch size 256 and n_fft 512, on a single GTX 1080 Ti.
Each variant ships in three formats:
- `g_best`: PyTorch checkpoint (training-time weights)
- `g_best_fp32.onnx`: streaming-shape FP32 ONNX (frame-by-frame, opset 17)
- `g_best.onnx`: static int8 ONNX (QDQ format, per-channel weight quantization, MinMax calibration on 200 VBD-train utterances)
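To make "per-channel weight quantization" concrete, here is a minimal pure-Python sketch of symmetric int8 quantization with one scale per output channel. It is illustrative only: the function names are hypothetical, and the exact scheme onnxruntime applied to these checkpoints (zero points, rounding mode, range handling) may differ.

```python
def quantize_per_channel(weights):
    """Symmetric int8 quantization with one scale per output channel (row)."""
    quantized, scales = [], []
    for row in weights:
        max_abs = max(abs(w) for w in row)
        # Assumed convention: map [-max_abs, max_abs] onto [-127, 127].
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_row = [max(-127, min(127, round(w / scale))) for w in row]
        quantized.append(q_row)
        scales.append(scale)
    return quantized, scales


def dequantize_per_channel(quantized, scales):
    """Recover approximate float weights from int8 values and per-row scales."""
    return [[q * s for q in row] for row, s in zip(quantized, scales)]


# Two channels with very different ranges: each gets its own scale,
# so the small-valued row is not crushed by the large-valued one.
weights = [[0.010, -0.020, 0.015], [4.0, -2.5, 1.0]]
q, scales = quantize_per_channel(weights)
restored = dequantize_per_channel(q, scales)
max_err = [max(abs(a - b) for a, b in zip(r, w)) for r, w in zip(restored, weights)]
```

With a single shared scale, the first row would round to almost nothing; with per-row scales, the round-trip error of each row stays within half of that row's own quantization step.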
Results
PESQ is measured on the full VBD test split (824 utterances). Δ is FP32 PESQ minus int8 PESQ, so a positive value is quality lost to quantization. RTF (real-time factor) is for the int8 ONNX session under onnxruntime CPU; lower is faster.
| run | params | FP32 PESQ | int8 PESQ | Δ (FP32→int8) | int8 RTF |
|---|---|---|---|---|---|
| wide_monarch | 2.36 M | 2.864 | 2.842 | +0.021 | 0.166 |
| baseline | 2.78 M | 2.845 | 2.833 | +0.012 | 0.452 |
| monarch_8 | 0.36 M | 2.832 | 2.826 | +0.006 | 0.025 |
| monarch_full | 0.70 M | 2.827 | 2.848 | −0.021 | 0.039 |
| monarch_fc | 2.14 M | 2.805 | 2.789 | +0.016 | 0.448 |
| butterfly_2blocks | 0.36 M | 2.805 | 2.202 | +0.602 | 0.441 |
| butterfly_fc | 1.99 M | 2.799 | 2.494 | +0.306 | 0.522 |
| butterfly_ortho | 0.19 M | 2.780 | 2.577 | +0.203 | 0.232 |
| butterfly_full | 0.19 M | 2.772 | 2.128 | +0.644 | 0.230 |
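For reference, RTF is conventionally the time spent processing divided by the duration of the audio processed. The sketch below is a hypothetical helper, not the benchmarking code behind the table; `enhance` stands in for whatever callable runs the model, and the 16 kHz default is assumed from the dataset name.

```python
import time


def real_time_factor(enhance, audio, sample_rate=16000):
    """Return processing time divided by audio duration (lower is faster)."""
    start = time.perf_counter()
    enhance(audio)
    elapsed = time.perf_counter() - start
    return elapsed / (len(audio) / sample_rate)


def speedup(rtf_reference, rtf_candidate):
    """How many times faster the candidate is than the reference."""
    return rtf_reference / rtf_candidate


# int8 RTF values copied from the results table.
baseline_rtf = 0.452
speedup_monarch_full = speedup(baseline_rtf, 0.039)
speedup_monarch_8 = speedup(baseline_rtf, 0.025)

# Timing a stand-in (identity) function on one second of 16 kHz silence.
rtf = real_time_factor(lambda x: x, [0.0] * 16000)
```

An RTF below 1 means the model keeps up with real time on the measured hardware; the two speedups computed here are the ratios behind the comparison with the dense baseline.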
Key findings
- Monarch variants are essentially loss-free under int8 (|Δ| ≤ 0.021 across the board). A single `Einsum` per FC plus per-channel weight quantization is genuinely friendly to int8 calibration.
- `wide_monarch` is the best deployment target for quality: highest FP32 PESQ (2.864) with near-zero int8 loss (2.842). For speed-constrained deployment, `monarch_full` and `monarch_8` trade ~0.02 PESQ for an RTF of 0.04 / 0.025, over 10× faster than the dense baseline.
- Butterfly with random init degrades catastrophically under int8 (Δ up to 0.64 PESQ on `butterfly_full`). Int8 deployment with butterfly factorizations should use `init=ortho`: `butterfly_ortho` loses 0.20 PESQ to int8 versus 0.64 for `butterfly_full` (same architecture, same training data, only the init differs).
- Longer training makes randn-init butterfly worse on int8. The same `butterfly_full` config saw its int8 gap grow from 0.36 PESQ at 50 epochs → 0.64 PESQ at 200 epochs as twiddle factors drifted further from orthogonality. Ortho-init butterfly does not show this regression.
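One way to quantify "drift from orthogonality" for a 2×2 twiddle block W is the Frobenius norm of WᵀW − I, which is zero exactly when W is orthogonal. This is an assumed diagnostic for illustration only: the source does not state how the drift was measured, the function name is hypothetical, and the 2×2 block shape follows the standard butterfly factorization rather than anything confirmed for this repo.

```python
import math


def orthogonality_gap(block):
    """Frobenius norm of W^T W - I for a 2x2 block [[a, b], [c, d]]."""
    (a, b), (c, d) = block
    # Entries of the symmetric Gram matrix W^T W.
    g00 = a * a + c * c
    g01 = a * b + c * d
    g11 = b * b + d * d
    return math.sqrt((g00 - 1) ** 2 + 2 * g01 ** 2 + (g11 - 1) ** 2)


theta = 0.3
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta), math.cos(theta)]]
stretched = [[2.0 * v for v in row] for row in rotation]

gap_rotation = orthogonality_gap(rotation)    # close to 0: a rotation is orthogonal
gap_stretched = orthogonality_gap(stretched)  # positive: scaling breaks orthogonality
```

Tracking such a gap per block over training would be one way to compare randn-init and ortho-init runs, under the assumptions stated above.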
Layout
Each subdirectory is one run, containing the saved generator (g_best),
the streaming FP32 ONNX, the static int8 ONNX, and the exact config.json
the run was trained with.
baseline/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_fc/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_fc/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_full/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_full/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_8/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_ortho/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_2blocks/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
wide_monarch/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
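Since every run directory holds the same four files, the repo-relative paths can be generated rather than typed. The helper below is a hypothetical convenience, not part of the source repo; the run and file names are copied from the listing above.

```python
RUNS = [
    "baseline", "monarch_fc", "butterfly_fc", "monarch_full", "butterfly_full",
    "monarch_8", "butterfly_ortho", "butterfly_2blocks", "wide_monarch",
]
FILES = ["g_best", "g_best_fp32.onnx", "g_best.onnx", "config.json"]


def run_files(run):
    """Repo-relative paths of the four files stored for one run."""
    if run not in RUNS:
        raise ValueError(f"unknown run: {run!r}")
    return [f"{run}/{name}" for name in FILES]


paths = run_files("monarch_8")
```

Each returned path can be passed as the filename argument to `hf_hub_download`.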
Loading
PyTorch checkpoint
Clone the repo first (model classes live there):
git clone https://github.com/LarocheC/sparse-nsnet2
cd sparse-nsnet2
uv sync
Then:
import json

import torch
from huggingface_hub import hf_hub_download
from env import AttrDict  # env and models.model come from the cloned sparse-nsnet2 repo
from models.model import NSNet2

REPO = "claroche1/sparse-nsnet2-checkpoints"
RUN = "wide_monarch"  # or any run name from the results table

# Load the exact config the run was trained with.
with open(hf_hub_download(REPO, f"{RUN}/config.json")) as f:
    cfg = json.load(f)

# weights_only=False runs the full unpickler, so only load checkpoints you trust.
ckpt = torch.load(hf_hub_download(REPO, f"{RUN}/g_best"),
                  map_location="cuda", weights_only=False)
model = NSNet2(AttrDict(cfg)).cuda().eval()
model.load_state_dict(ckpt["generator"])
ONNX (FP32 or int8)
The ONNX models are streaming-shape: each session call takes a single
frame (B, n_freq) plus the GRU state (num_layers, B, hidden), and the
state is threaded from one frame to the next. The end-to-end pipeline
(RMS-norm → STFT → frame loop → iSTFT) is in inference_onnx.py in the
source repo.
import onnxruntime as ort
from huggingface_hub import hf_hub_download
REPO = "claroche1/sparse-nsnet2-checkpoints"
RUN = "wide_monarch"
# FP32:
fp32_path = hf_hub_download(REPO, f"{RUN}/g_best_fp32.onnx")
fp32_sess = ort.InferenceSession(fp32_path, providers=["CPUExecutionProvider"])
# int8 (deployment):
int8_path = hf_hub_download(REPO, f"{RUN}/g_best.onnx")
int8_sess = ort.InferenceSession(int8_path, providers=["CPUExecutionProvider"])
An end-to-end inference example with PESQ measurement is in inference_onnx.py
in the source repo.
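The streaming contract (one frame plus the GRU state per call, with the state carried across frames) can be sketched without ONNX at all. In the sketch below, `step` is a hypothetical stand-in for a single session call, and `toy_step` is a toy recurrence, not the model. The real input and output names are not given here; they should be read from the session (for example with `sess.get_inputs()`) or taken from inference_onnx.py.

```python
def enhance_stream(frames, step, initial_state):
    """Run `step` once per frame, feeding each call the state from the last."""
    state = initial_state
    outputs = []
    for frame in frames:
        out, state = step(frame, state)
        outputs.append(out)
    return outputs, state


def toy_step(frame, state):
    """Stand-in for one session call: a leaky running average per bin."""
    new_state = [0.5 * s + 0.5 * x for s, x in zip(state, frame)]
    return new_state, new_state


frames = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]  # 3 frames, 2 bins each
outputs, final_state = enhance_stream(frames, toy_step, [0.0, 0.0])
```

Resetting the state to zeros on every call would make each frame independent of its history, which is why the returned state has to be passed into the next call.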
Citations
@inproceedings{braun2021nsnet2,
title={Towards efficient models for real-time deep noise suppression},
author={Braun, Sebastian and Tashev, Ivan},
booktitle={ICASSP},
year={2021}
}
@inproceedings{dao2019butterfly,
title={Learning fast algorithms for linear transforms using butterfly factorizations},
author={Dao, Tri and Gu, Albert and Eichhorn, Matthew and Rudra, Atri and R{\'e}, Christopher},
booktitle={ICML},
year={2019}
}
@inproceedings{dao2022monarch,
title={Monarch: Expressive structured matrices for efficient and accurate training},
author={Dao, Tri and Chen, Beidi and Sohoni, Nimit S and Desai, Arjun and Poli, Michael and Grogan, Jessica and Liu, Alexander and Rao, Aniruddh and Rudra, Atri and R{\'e}, Christopher},
booktitle={ICML},
year={2022}
}