LNA β€” Learning Noise Adapters for Incremental Speech Enhancement

Reimplementation of the paper "Learning Noise Adapters for Incremental Speech Enhancement".
Code: annkisluk/speech


Files

| File | Description |
| --- | --- |
| `session0_pretrain/lna_pretrained.pt` | Pretrained LNA backbone (40 epochs, 10 NOISEX-92 noise types) |
| `session1_incremental/lna_session1.pt` | After incremental session 1 (alarm noise) |
| `session2_incremental/lna_session2.pt` | After incremental session 2 (cough noise) |
| `session3_incremental/lna_session3.pt` | After incremental session 3 (destroyerops noise) |
| `session4_incremental/lna_session4.pt` | After incremental session 4 (machinegun noise) |
| `config.json` | Model architecture and training hyperparameters |

Architecture

  • Backbone: SepFormer (N=256, L=16, 2 DPT blocks Γ— 8 layers, 8 heads, d_ffn=1024)
  • Adapters: noise adapters with bottleneck dim Ĉ=1, one after the feed-forward layer (FFL) and one after the multi-head attention (MHA) in each transformer layer
  • Parameters: 25.6M total (98k per new noise adapter)
  • Input: 8 kHz mono waveform
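
The adapter design can be illustrated with a minimal sketch. This assumes a standard residual bottleneck layout; `NoiseAdapter` and its attribute names are illustrative, not the repo's actual classes:

```python
import torch
import torch.nn as nn

class NoiseAdapter(nn.Module):
    """Illustrative residual bottleneck adapter (not the repo's class):
    project d_model down to a tiny bottleneck C_hat, apply a nonlinearity,
    project back up, and add the result to the input."""
    def __init__(self, d_model: int = 256, bottleneck_dim: int = 1):
        super().__init__()
        self.down = nn.Linear(d_model, bottleneck_dim)
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck_dim, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The residual connection leaves the frozen backbone's output
        # intact when the adapter's contribution is small.
        return x + self.up(self.act(self.down(x)))

adapter = NoiseAdapter(d_model=256, bottleneck_dim=1)
x = torch.randn(2, 100, 256)   # [batch, frames, d_model]
y = adapter(x)                 # same shape as x
```

In the paper, each new noise type gets its own set of such adapters, trained while the SepFormer backbone stays frozen, which is why a new session adds only a small parameter budget.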

Usage

import sys
import torch
from huggingface_hub import hf_hub_download

# Clone the code repo and put it on the import path:
#   git clone https://github.com/annkisluk/speech
sys.path.append("speech")

from src.models.lna_model import LNAModel

# Download checkpoint
ckpt_path = hf_hub_download(repo_id="Annkisluk/lna-speech",
                            filename="session0_pretrain/lna_pretrained.pt")

# Build model
model = LNAModel(
    n_basis=256, kernel_size=16, num_layers=8, num_blocks=2,
    nhead=8, dim_feedforward=1024, dropout=0.1,
    adapter_bottleneck_dim=1, max_sessions=6
)
model.load_checkpoint(ckpt_path)
model.eval()

# Enhance speech  (noisy: [1, T] tensor at 8 kHz)
noisy = torch.randn(1, 32000)  # 4-second example
with torch.no_grad():
    enhanced = model(noisy, session_id=None)  # session_id=None β†’ base model only
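
The model expects 8 kHz mono input, so other recordings need downmixing and resampling first. A minimal preprocessing sketch, using plain linear interpolation as a rough stand-in for a proper polyphase resampler such as `torchaudio.functional.resample`:

```python
import torch
import torch.nn.functional as F

def to_8k_mono(wav: torch.Tensor, sr: int) -> torch.Tensor:
    """Downmix a [channels, T] waveform to mono and resample it to 8 kHz.
    Linear interpolation is only an approximation; prefer a real resampler
    in practice."""
    wav = wav.mean(dim=0, keepdim=True)  # downmix -> [1, T]
    if sr != 8000:
        new_len = int(wav.shape[-1] * 8000 / sr)
        # F.interpolate wants [N, C, L], so add and drop a batch dim.
        wav = F.interpolate(wav.unsqueeze(0), size=new_len,
                            mode="linear", align_corners=False).squeeze(0)
    return wav

noisy = to_8k_mono(torch.randn(2, 16000), sr=16000)  # 1 s stereo @ 16 kHz
```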

For incremental sessions, load lna_session{N}.pt and add adapters first:

# Load session 4 model (knows all 4 noise domains)
ckpt_path = hf_hub_download(repo_id="Annkisluk/lna-speech",
                            filename="session4_incremental/lna_session4.pt")

model = LNAModel(n_basis=256, kernel_size=16, num_layers=8, num_blocks=2,
                 nhead=8, dim_feedforward=1024, dropout=0.1,
                 adapter_bottleneck_dim=1, max_sessions=6)
for sid in range(1, 5):
    model.add_new_session(session_id=sid, bottleneck_dim=1)
model.load_checkpoint(ckpt_path)
model.eval()

with torch.no_grad():
    enhanced = model(noisy, session_id=2)  # route to session 2 adapter (cough)

Training Data

| Session | Noise types | Train samples | Test samples |
| --- | --- | --- | --- |
| 0 (pretrain) | 10 NOISEX-92 | 40,400 | 6,510 |
| 1–4 (incremental) | 1 new noise each | 1,212/session | 651/session |

SNR range: {βˆ’5, 0, 5, 10} dB. Sample rate: 8 kHz. Speech: LibriSpeech train-clean-100.
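
The SNR convention above can be reproduced with a small mixing helper. This is a sketch of the standard scale-the-noise recipe; `mix_at_snr` is illustrative and not part of the repo:

```python
import torch

def mix_at_snr(speech: torch.Tensor, noise: torch.Tensor,
               snr_db: float) -> torch.Tensor:
    """Scale the noise so that the speech-to-noise power ratio equals
    snr_db, then add it to the speech."""
    noise = noise[..., :speech.shape[-1]]        # trim to speech length
    p_speech = speech.pow(2).mean()
    p_noise = noise.pow(2).mean()
    # Want p_speech / (scale^2 * p_noise) == 10^(snr_db / 10).
    scale = torch.sqrt(p_speech / (p_noise * 10 ** (snr_db / 10)))
    return speech + scale * noise

speech = torch.randn(1, 8000)   # 1 s of "speech" at 8 kHz
noise = torch.randn(1, 8000)
mixed = mix_at_snr(speech, noise, snr_db=5.0)
```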
