sparse-nsnet2 — checkpoints

Best-PESQ checkpoints from the sparse-nsnet2 compression sweep: NSNet2 speech enhancement with the FC and GRU layers swappable between dense, Butterfly, and Monarch factorizations.

Trained on VoiceBank-DEMAND-16k for 200 epochs with batch size 256 and n_fft 512, on a single GTX 1080 Ti.

Each variant ships in three formats:

  • g_best — PyTorch checkpoint (training-time weights)
  • g_best_fp32.onnx — streaming-shape FP32 ONNX (frame-by-frame, opset 17)
  • g_best.onnx — static int8 ONNX (QDQ format, per-channel weight quant, MinMax calibration on 200 VBD-train utterances)
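To illustrate what "per-channel weight quant" means in the int8 export, here is a minimal NumPy sketch of symmetric per-channel int8 quantization compared with a single per-tensor scale. It is an illustration only, not the onnxruntime implementation or the export script from the source repo; the function names and the toy weight matrix are assumptions made for the example.

```python
import numpy as np


def quantize_per_channel(w: np.ndarray, axis: int = 0):
    """Symmetric int8 quantization with one scale per slice along `axis`."""
    # Reduce over every axis except the channel axis to get max |w| per channel.
    reduce_axes = tuple(i for i in range(w.ndim) if i != axis)
    max_abs = np.max(np.abs(w), axis=reduce_axes, keepdims=True)
    # Guard all-zero channels so the division below is always defined.
    scale = np.where(max_abs > 0, max_abs / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale) -> np.ndarray:
    """Map int8 codes back to float32 using the stored scale(s)."""
    return q.astype(np.float32) * scale


# Toy weight matrix (hypothetical): 4 output channels, one with a much larger range.
rng = np.random.default_rng(0)
w = rng.standard_normal((4, 16)).astype(np.float32)
w[0] *= 50.0

# Per-channel: each row gets its own scale.
q_pc, s_pc = quantize_per_channel(w, axis=0)
err_per_channel = np.abs(dequantize(q_pc, s_pc) - w).max(axis=1)

# Per-tensor: one scale for the whole matrix, dominated by the large row.
s_pt = np.float32(np.abs(w).max() / 127.0)
q_pt = np.clip(np.round(w / s_pt), -127, 127).astype(np.int8)
err_per_tensor = np.abs(dequantize(q_pt, s_pt) - w).max(axis=1)

print("max abs error per channel, per-channel scales:", err_per_channel)
print("max abs error per channel, per-tensor scale:  ", err_per_tensor)
```

With a per-tensor scale, the rows with a small dynamic range share the step size of the largest row, which is why per-channel scales tend to preserve more precision when channel ranges differ.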

Results

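PESQ is measured on the full VBD test split (824 utterances). Δ is FP32 PESQ minus int8 PESQ, so a positive value means quantization lost quality. RTF (real-time factor, processing time divided by audio duration) is for the int8 ONNX session under onnxruntime CPU; lower is faster.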

| run | params | FP32 PESQ | int8 PESQ | Δ (FP32 − int8) | int8 RTF |
|---|---|---|---|---|---|
| wide_monarch | 2.36 M | 2.864 | 2.842 | +0.021 | 0.166 |
| baseline | 2.78 M | 2.845 | 2.833 | +0.012 | 0.452 |
| monarch_8 | 0.36 M | 2.832 | 2.826 | +0.006 | 0.025 |
| monarch_full | 0.70 M | 2.827 | 2.848 | −0.021 | 0.039 |
| monarch_fc | 2.14 M | 2.805 | 2.789 | +0.016 | 0.448 |
| butterfly_2blocks | 0.36 M | 2.805 | 2.202 | +0.602 | 0.441 |
| butterfly_fc | 1.99 M | 2.799 | 2.494 | +0.306 | 0.522 |
| butterfly_ortho | 0.19 M | 2.780 | 2.577 | +0.203 | 0.232 |
| butterfly_full | 0.19 M | 2.772 | 2.128 | +0.644 | 0.230 |
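For reference, a real-time factor is processing time divided by the duration of the audio processed. The sketch below shows one way to measure it. The card does not state how its RTF numbers were timed (number of runs, thread settings, warm-up), so the best-of-`repeats` policy and the helper name `real_time_factor` are assumptions.

```python
import time


def real_time_factor(process, audio_seconds: float, repeats: int = 3) -> float:
    """Return the best-of-`repeats` wall-clock time of `process()` divided by audio duration."""
    if audio_seconds <= 0:
        raise ValueError("audio_seconds must be positive")
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        process()
        best = min(best, time.perf_counter() - start)
    return best / audio_seconds


# Hypothetical stand-in workload; replace with a call that enhances one utterance
# and pass that utterance's duration in seconds.
rtf = real_time_factor(lambda: sum(range(10_000)), audio_seconds=2.0)
print(f"RTF = {rtf:.6f}")
```

An RTF below 1 means the model processes audio faster than it is played back.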

Key findings

  • Monarch variants are essentially loss-free under int8 (|Δ| ≤ 0.021 across the board). A single Einsum per FC plus per-channel weight quantization is genuinely friendly to int8 calibration.
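  • wide_monarch is the best deployment target for quality: highest FP32 PESQ (2.864) with near-zero int8 loss (2.842). For speed-constrained deployment, monarch_full and monarch_8 reach an int8 RTF of 0.039 and 0.025, over 10× faster than the dense baseline (0.452), while staying within 0.04 PESQ of wide_monarch in FP32 and within 0.02 in int8 (monarch_full's int8 PESQ of 2.848 is the highest in the table).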
  • Butterfly with random init degrades catastrophically under int8 (Δ up to 0.64 PESQ on butterfly_full). Int8 deployment with butterfly factorizations should use init=ortho: butterfly_ortho loses 0.20 PESQ to int8 versus 0.64 for butterfly_full (same architecture, same training data, only the init differs).
  • Longer training makes randn-init butterfly worse on int8. The same butterfly_full config saw its int8 gap grow from 0.36 PESQ at 50 epochs → 0.64 PESQ at 200 epochs as twiddle factors drifted further from orthogonality. Ortho-init butterfly does not show this regression.
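The speedups and int8 gaps quoted in these findings can be rechecked from the table with a few lines of arithmetic. The numbers below are copied from the results table; the helper names are illustrative. Gaps recomputed from the rounded scores can differ from the Δ column in the last digit (for example 0.022 versus the listed +0.021 for wide_monarch), presumably because that column was computed before rounding; that explanation is a guess.

```python
# (FP32 PESQ, int8 PESQ, int8 RTF), copied from the results table.
results = {
    "wide_monarch": (2.864, 2.842, 0.166),
    "baseline": (2.845, 2.833, 0.452),
    "monarch_8": (2.832, 2.826, 0.025),
    "monarch_full": (2.827, 2.848, 0.039),
    "butterfly_ortho": (2.780, 2.577, 0.232),
    "butterfly_full": (2.772, 2.128, 0.230),
}


def int8_gap(run: str) -> float:
    """PESQ lost when going from FP32 to int8 (negative means int8 scored higher)."""
    fp32, int8, _ = results[run]
    return fp32 - int8


def speedup_vs_baseline(run: str) -> float:
    """How many times faster the run's int8 session is than the dense baseline's."""
    return results["baseline"][2] / results[run][2]


for run in results:
    print(f"{run:16s} gap={int8_gap(run):+.3f}  speedup={speedup_vs_baseline(run):.1f}x")
```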

Layout

Each subdirectory is one run, containing the saved generator (g_best), the streaming FP32 ONNX, the static int8 ONNX, and the exact config.json the run was trained with.

baseline/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_fc/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_fc/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_full/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_full/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_8/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_ortho/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_2blocks/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
wide_monarch/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
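Because every run uses the same four filenames, the repo-relative paths can be generated rather than typed out. A small sketch, with `run_files` as a hypothetical helper name; the resulting strings are the kind of path passed to hf_hub_download in the Loading section.

```python
RUNS = [
    "baseline", "monarch_fc", "butterfly_fc", "monarch_full", "butterfly_full",
    "monarch_8", "butterfly_ortho", "butterfly_2blocks", "wide_monarch",
]
FILES = ["g_best", "g_best_fp32.onnx", "g_best.onnx", "config.json"]


def run_files(run: str) -> list[str]:
    """Return the repo-relative paths of the four artifacts for one run."""
    if run not in RUNS:
        raise KeyError(f"unknown run: {run!r}")
    return [f"{run}/{name}" for name in FILES]


print(run_files("wide_monarch"))
```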

Loading

PyTorch checkpoint

Clone the repo first (model classes live there):

git clone https://github.com/LarocheC/sparse-nsnet2
cd sparse-nsnet2
uv sync

Then:

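import json

import torch
from huggingface_hub import hf_hub_download

# Both of these come from the cloned sparse-nsnet2 repo, so run this from its root.
from env import AttrDict
from models.model import NSNet2

REPO = "claroche1/sparse-nsnet2-checkpoints"
RUN  = "wide_monarch"  # or any name from the table

# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

with open(hf_hub_download(REPO, f"{RUN}/config.json")) as f:
    cfg = json.load(f)
# weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
ckpt = torch.load(hf_hub_download(REPO, f"{RUN}/g_best"),
                  map_location=device, weights_only=False)

model = NSNet2(AttrDict(cfg)).to(device)
model.load_state_dict(ckpt["generator"])
model.eval()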

ONNX (FP32 or int8)

The ONNX models are streaming-shape: each session call takes a single frame (B, n_freq) plus the GRU state (num_layers, B, hidden), and the state is threaded from one frame to the next. The end-to-end pipeline (RMS-norm → STFT → frame loop → iSTFT) is in inference_onnx.py in the source repo.

import onnxruntime as ort
from huggingface_hub import hf_hub_download

REPO = "claroche1/sparse-nsnet2-checkpoints"
RUN  = "wide_monarch"

# FP32:
fp32_path = hf_hub_download(REPO, f"{RUN}/g_best_fp32.onnx")
fp32_sess = ort.InferenceSession(fp32_path, providers=["CPUExecutionProvider"])

# int8 (deployment):
int8_path = hf_hub_download(REPO, f"{RUN}/g_best.onnx")
int8_sess = ort.InferenceSession(int8_path, providers=["CPUExecutionProvider"])

An end-to-end inference example with PESQ measurement is in inference_onnx.py in the source repo.
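The frame-by-frame state threading described above can be sketched as follows. This is not the code from inference_onnx.py. The helpers `stream` and `make_ort_step` are hypothetical, and the wrapper assumes the model has exactly two inputs (frame, then GRU state) and two outputs in the same order, which the card does not confirm; check `session.get_inputs()` and `session.get_outputs()` before relying on it. Starting from an all-zero state of shape (num_layers, B, hidden) is also an assumption.

```python
from typing import Callable, Iterable, List, Tuple, TypeVar

Frame = TypeVar("Frame")
State = TypeVar("State")


def stream(step: Callable[[Frame, State], Tuple[Frame, State]],
           frames: Iterable[Frame],
           state: State) -> Tuple[List[Frame], State]:
    """Run `step` on each frame, feeding every call the state returned by the previous one."""
    outputs = []
    for frame in frames:
        out, state = step(frame, state)
        outputs.append(out)
    return outputs, state


def make_ort_step(session):
    """Wrap an onnxruntime InferenceSession as a `step` callable.

    ASSUMPTION: input 0 is the frame, input 1 is the GRU state, and the two
    outputs follow the same order. Verify against the actual model.
    """
    frame_name, state_name = [arg.name for arg in session.get_inputs()]

    def step(frame, state):
        out, new_state = session.run(None, {frame_name: frame, state_name: state})
        return out, new_state

    return step


# Toy check with plain numbers: the "state" is a running sum of the frames seen so far.
outs, final = stream(lambda f, s: (f + s, s + f), [1, 2, 3], 0)
print(outs, final)
```

Keeping the loop separate from the session wrapper means the same `stream` function can drive either the FP32 or the int8 session.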

Citations

@inproceedings{braun2021nsnet2,
    title={Towards efficient models for real-time deep noise suppression},
    author={Braun, Sebastian and Tashev, Ivan},
    booktitle={ICASSP},
    year={2021}
}

@inproceedings{dao2019butterfly,
    title={Learning fast algorithms for linear transforms using butterfly factorizations},
    author={Dao, Tri and Gu, Albert and Eichhorn, Matthew and Rudra, Atri and R{\'e}, Christopher},
    booktitle={ICML},
    year={2019}
}

@inproceedings{dao2022monarch,
    title={Monarch: Expressive structured matrices for efficient and accurate training},
    author={Dao, Tri and Chen, Beidi and Sohoni, Nimit S and Desai, Arjun and Poli, Michael and Grogan, Jessica and Liu, Alexander and Rao, Aniruddh and Rudra, Atri and R{\'e}, Christopher},
    booktitle={ICML},
    year={2022}
}