metadata
license: mit
language: en
library_name: pytorch
tags:
  - speech-enhancement
  - speech-denoising
  - nsnet2
  - butterfly
  - monarch
  - structured-matrices
  - onnx
  - quantization
  - int8
datasets:
  - JacobLinCool/VoiceBank-DEMAND-16k
metrics:
  - pesq

sparse-nsnet2 — checkpoints

Best-PESQ checkpoints from the sparse-nsnet2 compression sweep: NSNet2 speech enhancement with the FC and GRU layers swappable between dense, Butterfly, and Monarch factorizations.

Trained on VoiceBank-DEMAND-16k for 200 epochs at batch 256, n_fft 512, on a single GTX 1080 Ti.

Each variant ships in three formats:

  • g_best — PyTorch checkpoint (training-time weights)
  • g_best_fp32.onnx — streaming-shape FP32 ONNX (frame-by-frame, opset 17)
  • g_best.onnx — static int8 ONNX (QDQ format, per-channel weight quant, MinMax calibration on 200 VBD-train utterances)
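To make "per-channel weight quant" with MinMax ranges concrete, here is a minimal pure-Python sketch of symmetric int8 quantization with one scale per output channel, compared against a single per-tensor scale. It is an illustration only, not the onnxruntime implementation: the function names, the toy weight matrix, and the choice of rows as channels are assumptions made for this example.

```python
def _quantize_row(row, scale):
    # Round to the nearest int8 step and clamp to the symmetric range [-127, 127].
    return [max(-127, min(127, round(v / scale))) for v in row]


def quantize_per_channel(weight):
    """One MinMax-derived scale per row (assumed here to be the output channel)."""
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_rows.append(_quantize_row(row, scale))
        scales.append(scale)
    return q_rows, scales


def quantize_per_tensor(weight):
    """A single MinMax-derived scale shared by every row."""
    max_abs = max(abs(v) for row in weight for v in row)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    return [_quantize_row(row, scale) for row in weight], [scale] * len(weight)


def dequantize(q_rows, scales):
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


def row_error(original_row, restored_row):
    """Largest absolute round-trip error within one row."""
    return max(abs(a - b) for a, b in zip(original_row, restored_row))


# Toy weights: row 0 is about 200x smaller in magnitude than row 1.
W = [[0.010, -0.020, 0.015],
     [4.000, -3.000, 2.500]]

per_channel = dequantize(*quantize_per_channel(W))
per_tensor = dequantize(*quantize_per_tensor(W))

small_row_err_channel = row_error(W[0], per_channel[0])
small_row_err_tensor = row_error(W[0], per_tensor[0])
print(small_row_err_channel, small_row_err_tensor)
```

With a shared scale, the small-magnitude row is rounded almost entirely to zero, while a per-row scale keeps its round-trip error several orders of magnitude lower.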

Results

PESQ is measured on the full VoiceBank-DEMAND (VBD) test split (824 utterances). RTF (real-time factor, i.e. processing time divided by audio duration) is for the int8 ONNX session under onnxruntime CPU; lower is faster.
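As a reference for how an RTF figure is formed, the sketch below uses the standard definition (processing time divided by audio duration). The exact timing protocol behind the numbers on this card is not described here, so `real_time_factor` and `measure_rtf` are hypothetical helper names and the 16 kHz default is taken from the dataset name.

```python
import time


def real_time_factor(processing_seconds, num_samples, sample_rate=16000):
    """RTF = time spent processing / duration of the processed audio."""
    audio_seconds = num_samples / sample_rate
    return processing_seconds / audio_seconds


def measure_rtf(process_fn, samples, sample_rate=16000):
    """Time one call of process_fn(samples) and convert it to an RTF."""
    start = time.perf_counter()
    process_fn(samples)
    elapsed = time.perf_counter() - start
    return real_time_factor(elapsed, len(samples), sample_rate)


# 0.5 s of compute for 2 s of 16 kHz audio gives an RTF of 0.25.
example_rtf = real_time_factor(0.5, 32000)
```

An RTF below 1.0 means the model processes audio faster than it arrives.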

| run | params | FP32 PESQ | int8 PESQ | Δ (FP32 − int8) | int8 RTF |
| --- | --- | --- | --- | --- | --- |
| wide_monarch | 2.36 M | 2.864 | 2.842 | +0.021 | 0.166 |
| baseline | 2.78 M | 2.845 | 2.833 | +0.012 | 0.452 |
| monarch_8 | 0.36 M | 2.832 | 2.826 | +0.006 | 0.025 |
| monarch_full | 0.70 M | 2.827 | 2.848 | −0.021 | 0.039 |
| monarch_fc | 2.14 M | 2.805 | 2.789 | +0.016 | 0.448 |
| butterfly_2blocks | 0.36 M | 2.805 | 2.202 | +0.602 | 0.441 |
| butterfly_fc | 1.99 M | 2.799 | 2.494 | +0.306 | 0.522 |
| butterfly_ortho | 0.19 M | 2.780 | 2.577 | +0.203 | 0.232 |
| butterfly_full | 0.19 M | 2.772 | 2.128 | +0.644 | 0.230 |
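The Δ column and the speed comparisons can be re-derived from the table with a few lines of Python. The sketch below copies the published FP32 PESQ, int8 PESQ, and int8 RTF values; `RESULTS` and `summarize` are names chosen for this example. Recomputed deltas can differ from the table by 0.001, presumably because the table was built from unrounded scores (an assumption).

```python
# (FP32 PESQ, int8 PESQ, int8 RTF) copied from the results table.
RESULTS = {
    "wide_monarch":      (2.864, 2.842, 0.166),
    "baseline":          (2.845, 2.833, 0.452),
    "monarch_8":         (2.832, 2.826, 0.025),
    "monarch_full":      (2.827, 2.848, 0.039),
    "monarch_fc":        (2.805, 2.789, 0.448),
    "butterfly_2blocks": (2.805, 2.202, 0.441),
    "butterfly_fc":      (2.799, 2.494, 0.306 + 0.216),  # RTF 0.522
    "butterfly_ortho":   (2.780, 2.577, 0.232),
    "butterfly_full":    (2.772, 2.128, 0.230),
}
RESULTS["butterfly_fc"] = (2.799, 2.494, 0.522)  # keep the literal table value


def summarize(results, reference="baseline"):
    """Return {run: (PESQ lost to int8, speedup over the reference run)}."""
    ref_rtf = results[reference][2]
    summary = {}
    for run, (fp32, int8, rtf) in results.items():
        delta = round(fp32 - int8, 3)       # positive = quality lost to int8
        speedup = round(ref_rtf / rtf, 1)   # >1 = faster than the reference
        summary[run] = (delta, speedup)
    return summary


summary = summarize(RESULTS)
for run, (delta, speedup) in summary.items():
    print(f"{run:18s} delta={delta:+.3f} speedup={speedup:.1f}x")
```

On these numbers monarch_8 runs at roughly 18× the baseline speed and monarch_full at roughly 11.6×.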

Key findings

  • Monarch variants are essentially loss-free under int8 (|Δ| ≤ 0.021 across the board). Each Monarch FC exports as a single Einsum, and combined with per-channel weight quantization this calibrates well to int8.
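  • wide_monarch is the best deployment target for quality: highest FP32 PESQ (2.864) with near-zero int8 loss (2.842). For speed-constrained deployment, monarch_full and monarch_8 give up about 0.03 to 0.04 FP32 PESQ relative to wide_monarch, while their int8 PESQ (2.848 and 2.826) stays within about 0.02 of it. Their int8 RTF is 0.039 and 0.025, more than 10× faster than the dense baseline (0.452).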
  • Butterfly with random init degrades catastrophically under int8 (Δ up to 0.64 PESQ on butterfly_full). Int8 deployment with butterfly factorizations should use init=ortho: butterfly_ortho loses 0.20 PESQ to int8 versus 0.64 for butterfly_full (same architecture, same training data, only the init differs).
  • Longer training makes randn-init butterfly worse on int8. The same butterfly_full config saw its int8 gap grow from 0.36 PESQ at 50 epochs → 0.64 PESQ at 200 epochs as twiddle factors drifted further from orthogonality. Ortho-init butterfly does not show this regression.
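The "drift from orthogonality" mentioned above can be quantified per twiddle. The sketch below assumes real-valued 2×2 twiddle blocks, as in the butterfly factorization of Dao et al. (2019); how the repo actually parameterizes and initializes them is not stated on this card, so `rotation_twiddle` and `orthogonality_error` are hypothetical helpers for illustration.

```python
import math


def rotation_twiddle(theta):
    """An exactly orthogonal 2x2 block (a plane rotation)."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def orthogonality_error(t):
    """Frobenius norm of T^T T - I for a 2x2 matrix T (0 when T is orthogonal)."""
    (a, b), (c, d) = t
    g00 = a * a + c * c - 1.0   # (T^T T)[0][0] - 1
    g01 = a * b + c * d         # (T^T T)[0][1], equal to [1][0]
    g11 = b * b + d * d - 1.0   # (T^T T)[1][1] - 1
    return math.sqrt(g00 ** 2 + 2.0 * g01 ** 2 + g11 ** 2)


ortho_err = orthogonality_error(rotation_twiddle(0.7))
drifted_err = orthogonality_error([[2.0, 0.0], [0.0, 0.5]])
print(ortho_err, drifted_err)
```

Tracking the mean of this error over all twiddles during training would be one way to check whether a larger int8 gap coincides with larger drift. Orthogonal factors preserve vector norms, which is one plausible reason they keep activation ranges easier to calibrate.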

Layout

Each subdirectory is one run, containing the saved generator (g_best), the streaming FP32 ONNX, the static int8 ONNX, and the exact config.json the run was trained with.

baseline/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_fc/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_fc/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_full/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_full/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
monarch_8/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_ortho/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
butterfly_2blocks/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
wide_monarch/{g_best,g_best_fp32.onnx,g_best.onnx,config.json}
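Because every run directory has the same four files, the layout above can be enumerated programmatically. The sketch below builds the repo-relative paths and, optionally, downloads them with `hf_hub_download`; `RUNS`, `FILES`, `run_files`, and `download_run` are names introduced for this example, and the repo id is the one used in the loading snippets on this card.

```python
RUNS = [
    "baseline", "monarch_fc", "butterfly_fc", "monarch_full", "butterfly_full",
    "monarch_8", "butterfly_ortho", "butterfly_2blocks", "wide_monarch",
]
FILES = ["g_best", "g_best_fp32.onnx", "g_best.onnx", "config.json"]


def run_files(run):
    """Repo-relative paths of the four artifacts belonging to one run."""
    if run not in RUNS:
        raise ValueError(f"unknown run {run!r}; expected one of {RUNS}")
    return [f"{run}/{name}" for name in FILES]


def download_run(run, repo_id="claroche1/sparse-nsnet2-checkpoints"):
    """Download all artifacts of one run; returns {repo path: local path}."""
    # Imported lazily so run_files() works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download
    return {path: hf_hub_download(repo_id, path) for path in run_files(run)}


print(run_files("monarch_8"))
```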

Loading

PyTorch checkpoint

Clone the repo first (model classes live there):

git clone https://github.com/LarocheC/sparse-nsnet2
cd sparse-nsnet2
uv sync

Then:

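import json

import torch
from huggingface_hub import hf_hub_download

from env import AttrDict          # from the cloned sparse-nsnet2 repo
from models.model import NSNet2   # from the cloned sparse-nsnet2 repo

REPO = "claroche1/sparse-nsnet2-checkpoints"
RUN  = "wide_monarch"  # or any name from the table

# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

with open(hf_hub_download(REPO, f"{RUN}/config.json")) as f:
    cfg = json.load(f)

# weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
ckpt = torch.load(hf_hub_download(REPO, f"{RUN}/g_best"),
                  map_location=device, weights_only=False)

model = NSNet2(AttrDict(cfg)).to(device).eval()
model.load_state_dict(ckpt["generator"])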

ONNX (FP32 or int8)

The ONNX models are streaming-shape: each session call takes a single frame of shape (B, n_freq) plus the GRU state of shape (num_layers, B, hidden), and the state is threaded from one frame to the next. The end-to-end pipeline (RMS-norm → STFT → frame loop → iSTFT) is in inference_onnx.py in the source repo.

import onnxruntime as ort
from huggingface_hub import hf_hub_download

REPO = "claroche1/sparse-nsnet2-checkpoints"
RUN  = "wide_monarch"

# FP32:
fp32_path = hf_hub_download(REPO, f"{RUN}/g_best_fp32.onnx")
fp32_sess = ort.InferenceSession(fp32_path, providers=["CPUExecutionProvider"])

# int8 (deployment):
int8_path = hf_hub_download(REPO, f"{RUN}/g_best.onnx")
int8_sess = ort.InferenceSession(int8_path, providers=["CPUExecutionProvider"])

End-to-end inference example with PESQ measurement is in inference_onnx.py in the source repo.
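For orientation, the frame loop described above can be sketched as follows. This is not the code from inference_onnx.py: `run_streaming` is a hypothetical helper, and it assumes the session's first input is the frame, its second input is the GRU state, and its first two outputs are the per-frame result and the updated state. Check `sess.get_inputs()` and `sess.get_outputs()` on the real model before relying on that order.

```python
def run_streaming(sess, frames, init_state):
    """Feed frames one at a time, threading the GRU state between calls.

    sess       : an onnxruntime.InferenceSession (or any object with the same
                 get_inputs() / run() interface)
    frames     : iterable of per-frame inputs, each shaped (B, n_freq)
    init_state : initial GRU state, shaped (num_layers, B, hidden)
    """
    input_names = [arg.name for arg in sess.get_inputs()]
    # ASSUMPTION: input 0 is the frame and input 1 is the recurrent state.
    frame_name, state_name = input_names[0], input_names[1]

    state = init_state
    outputs = []
    for frame in frames:
        # Passing None as the output list asks the session for all outputs.
        results = sess.run(None, {frame_name: frame, state_name: state})
        # ASSUMPTION: output 0 is the frame result, output 1 the new state.
        out, state = results[0], results[1]
        outputs.append(out)
    return outputs, state
```

With the real models, `frames` would be float32 NumPy arrays produced by the STFT front end, and `init_state` would typically be an all-zero array of the state shape reported by `sess.get_inputs()`.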

Citations

@inproceedings{braun2021nsnet2,
    title={Towards efficient models for real-time deep noise suppression},
    author={Braun, Sebastian and Tashev, Ivan},
    booktitle={ICASSP},
    year={2021}
}

@inproceedings{dao2019butterfly,
    title={Learning fast algorithms for linear transforms using butterfly factorizations},
    author={Dao, Tri and Gu, Albert and Eichhorn, Matthew and Rudra, Atri and R{\'e}, Christopher},
    booktitle={ICML},
    year={2019}
}

@inproceedings{dao2022monarch,
    title={Monarch: Expressive structured matrices for efficient and accurate training},
    author={Dao, Tri and Chen, Beidi and Sohoni, Nimit S and Desai, Arjun and Poli, Michael and Grogan, Jessica and Liu, Alexander and Rao, Aniruddh and Rudra, Atri and R{\'e}, Christopher},
    booktitle={ICML},
    year={2022}
}