CMGAN — LiteRT (CompiledModel GPU) speech enhancement / noise suppression

CMGAN (TASLP 2024, MIT-licensed, trained on VoiceBank-DEMAND), re-authored as a GPU-native LiteRT .tflite: a 1.83 M-parameter dual-path conformer (4 two-stage time + frequency blocks) that denoises 2 s chunks of 16 kHz audio. The STFT (a Hamming-windowed DFT implemented as a single Conv1d) and the mag^0.3 power compression run inside the graph; the host only does reflect padding, un-compression, the inverse STFT and overlap-add. FP16, 4.2 MB.
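To make the in-graph STFT concrete, here is a minimal NumPy sketch of a Hamming-windowed DFT expressed as one strided Conv1d: 2·201 output channels (windowed cosines, then negated windowed sines), kernel size 400, stride 100, no padding. It assumes a periodic Hamming window (the torch.hamming_window default) and writes the convolution as a matrix product over frames rather than calling a framework Conv1d. The helper names are hypothetical; this is not the conversion script.

```python
import numpy as np

NFFT, HOP = 400, 100
F = NFFT // 2 + 1                     # 201 one-sided frequency bins

def dft_conv_kernel(n_fft=NFFT):
    """Conv1d weights [2*F, n_fft]: rows 0..F-1 = win*cos, rows F..2F-1 = -win*sin."""
    n = np.arange(n_fft)
    k = np.arange(n_fft // 2 + 1)[:, None]
    win = 0.54 - 0.46 * np.cos(2 * np.pi * n / n_fft)   # periodic Hamming window
    ang = 2 * np.pi * k * n / n_fft
    return np.concatenate([win * np.cos(ang), -win * np.sin(ang)], axis=0)

def stft_as_conv1d(xp, kernel, hop=HOP):
    """Stride-`hop`, unpadded Conv1d (cross-correlation) over an already padded signal.

    Returns (real, imag), each [T, F], the same [T, F] layout the model uses."""
    n_fft = kernel.shape[1]
    n_frames = 1 + (len(xp) - n_fft) // hop
    frames = np.stack([xp[t * hop : t * hop + n_fft] for t in range(n_frames)])  # [T, n_fft]
    out = frames @ kernel.T               # [T, 2F]: one dot product per output channel
    f = kernel.shape[0] // 2
    return out[:, :f], out[:, f:]

x = np.random.default_rng(0).standard_normal(32000)
xp = np.pad(x, NFFT // 2, mode="reflect")   # [32400], like the model input
real, imag = stft_as_conv1d(xp, dft_conv_kernel())
print(real.shape, imag.shape)               # (321, 201) (321, 201)
```

A convolution with 402 output channels costs more multiply-adds than an FFT, but it maps onto a single standard conv op instead of requiring an FFT kernel on the GPU.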

CMGAN on-device: noise suppression

Noisy input → enhanced output (on-device, +7.2 dB SI-SNR on this clip). Waveforms from the fp16 model.

Verified on a Pixel 8a (Tensor G3): 1 651 / 1 651 nodes on LITERT_CL (full residency, 1 partition), ~20 ms per 2 s chunk (RTF ≈ 0.01), SI-SNR +7.2 dB on a noisy sample at 6.6 dB (PyTorch reference: +9.6 dB), and a device-vs-PyTorch waveform correlation of 0.997. On desktop, the fp16 .tflite matches PyTorch with a correlation of 0.999999.
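The quality figures above are scale-invariant SNRs. A minimal NumPy sketch of the usual SI-SNR definition (zero-mean signals, projection onto the clean reference) follows. The card does not say whether +7.2 dB is the absolute SI-SNR of the output or its gain over the 6.6 dB input, so both helpers are given; the function names are illustrative.

```python
import numpy as np

def si_snr(est, ref, eps=1e-12):
    """Scale-invariant SNR in dB: energy of est's projection on ref vs. the residual."""
    est = np.asarray(est, np.float64) - np.mean(est)
    ref = np.asarray(ref, np.float64) - np.mean(ref)
    target = (est @ ref) / (ref @ ref + eps) * ref   # component of est along ref
    noise = est - target                             # everything else
    return 10 * np.log10((target @ target + eps) / (noise @ noise + eps))

def si_snr_improvement(enhanced, noisy, clean):
    """SI-SNR gain of the enhanced signal over the noisy input, in dB."""
    return si_snr(enhanced, clean) - si_snr(noisy, clean)
```

The waveform correlations quoted above are presumably Pearson correlations, which NumPy gives as np.corrcoef(a, b)[0, 1].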

I/O

  • Input [1, 32400] float32: a 32 000-sample (2 s @ 16 kHz) mono chunk, RMS-normalized as x·c with c = sqrt(N / Σx²) and N = 32 000 (divide the output by c to undo it), then reflect-padded by 200 samples on each side (equivalent to torch.stft(center=True)).
  • Outputs: two [1, 1, 321, 201] tensors (real, imag) holding the enhanced complex spectrogram in the compressed mag^0.3 domain, laid out [T, F] (321 frames × 201 frequency bins). Un-compress by scaling (r, i) by (r²+i²)^(7/6), then run the iSTFT (n_fft 400, hop 100, periodic Hamming window, trimming the 200-sample center pad) and overlap-add the chunks.
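The host side of these two bullets can be sketched in NumPy. prepare_chunk builds the [1, 32400] input; np.pad(..., mode="reflect") does not repeat the edge sample, which matches torch's reflect padding. For the output, the model emits C = mag^0.3·e^{jθ}, so |C| = mag^0.3 and mag = |C|^(10/3); the factor that restores the magnitude is mag/|C| = |C|^(7/3) = (r²+i²)^(7/6). compress exists only to test the round trip, and all three names are hypothetical.

```python
import numpy as np

S, PAD = 32000, 200          # 2 s @ 16 kHz; n_fft // 2 reflect padding per side

def prepare_chunk(chunk):
    """Zero-pad/truncate to S, RMS-normalize (x*c), reflect-pad -> ([1, 32400] float32, c)."""
    x = np.zeros(S, np.float32)
    n = min(len(chunk), S)
    x[:n] = chunk[:n]
    c = float(np.sqrt(S / (np.dot(x, x) + 1e-12)))   # c = sqrt(N / sum(x^2)), N = 32 000
    xp = np.pad(x * c, PAD, mode="reflect")           # edge sample not repeated, like torch
    return xp[None].astype(np.float32), c             # divide the final waveform by c

def compress(spec, p=0.3):
    """Reference mag^p compression that keeps the phase (only used to test uncompress)."""
    return np.abs(spec) ** p * np.exp(1j * np.angle(spec))

def uncompress(r, i):
    """Map the model's (r, i) from the mag^0.3 domain back to linear magnitude."""
    s = np.maximum(r * r + i * i, 1e-12) ** (7.0 / 6.0)   # |C|^(7/3) = mag / |C|
    return r * s, i * s
```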

Minimal usage

import numpy as np, soundfile as sf, torch
from ai_edge_litert.interpreter import Interpreter

SR, NFFT, HOP, S = 16000, 400, 100, 32000
wav, sr = sf.read("noisy.wav", dtype="float32")        # mono 16 kHz
assert sr == SR and wav.ndim == 1, "expects mono 16 kHz audio"
x = np.zeros(S, np.float32); n = min(len(wav), S); x[:n] = wav[:n]   # first 2 s, zero-padded
c = np.sqrt(S / (x @ x + 1e-12)); x *= c              # RMS normalize (x * c)
xp = np.pad(x, NFFT // 2, mode="reflect")             # reflect pad 200 per side -> [32400]

it = Interpreter(model_path="cmgan_fp16.tflite"); it.allocate_tensors()
it.set_tensor(it.get_input_details()[0]["index"], xp[None]); it.invoke()
r, i = (torch.tensor(it.get_tensor(o["index"])) for o in
        sorted(it.get_output_details(), key=lambda o: o["index"]))   # [1,1,321,201] compressed

m2 = (r * r + i * i).clamp_min(1e-12) ** (7.0 / 6.0)                 # mag^0.3 -> mag
spec = torch.complex(r * m2, i * m2)[0, 0].T                         # [201, 321]
den = torch.istft(spec, NFFT, HOP, window=torch.hamming_window(NFFT), length=S)
sf.write("denoised.wav", (den.numpy() / c), SR)

Kotlin (Android, LiteRT CompiledModel GPU)

// implementation("com.google.ai.edge.litert:litert:2.1.5")
val model = CompiledModel.create(File(ctx.filesDir, "cmgan_fp16.tflite").absolutePath,
    CompiledModel.Options(Accelerator.GPU), null)
val inBuf = model.createInputBuffers()
val outBuf = model.createOutputBuffers()

// 2 s chunk, RMS-normalized (x*c), reflect-padded by 200 both sides -> [32400]
inBuf[0].writeFloat(paddedChunk)
model.run(inBuf, outBuf)
val real = outBuf[0].readFloat()   // [321 * 201] compressed real, t-major (t*201 + f)
val imag = outBuf[1].readFloat()
// host: scale (r,i) by (r*r+i*i)^(7/6), then iSTFT (n_fft 400 / hop 100 / periodic hamming,
// trim the 200-sample center pad), overlap-add chunks, divide by c —
// see NoiseSuppressor.kt in the speech_enhancement LiteRT sample.
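The pipeline ends with "overlap-add chunks" but does not state the chunk overlap. One minimal way to process long recordings is sketched below in NumPy: 50 % overlap with a periodic Hann cross-fade, whose copies shifted by half a chunk sum exactly to 1. The hop, the window and the enhance_chunk callable are assumptions, not necessarily what the speech_enhancement sample does.

```python
import numpy as np

S = 32000          # samples per model chunk (2 s @ 16 kHz)
H = S // 2         # hypothetical hop: 50 % overlap (must stay S == 2 * H for exact cross-fade)

def periodic_hann(n):
    """Periodic Hann window; w[m] + w[m + n/2] == 1 for every m."""
    k = np.arange(n)
    return 0.5 - 0.5 * np.cos(2 * np.pi * k / n)

def enhance_long(wav, enhance_chunk):
    """Split wav into overlapping S-sample chunks, enhance each, cross-fade and sum.

    enhance_chunk is a hypothetical callable mapping S samples to S enhanced samples
    (normalize, run the .tflite, un-compress, iSTFT, divide by c)."""
    n = len(wav)
    k = -(-n // H) + 1                     # number of chunks: ceil(n / H) + 1
    xp = np.zeros((k - 1) * H + S)         # H leading zeros so every sample sees two windows
    xp[H:H + n] = wav
    out = np.zeros_like(xp)
    w = periodic_hann(S)
    for j in range(k):
        seg = xp[j * H : j * H + S]
        out[j * H : j * H + S] += w * enhance_chunk(seg)
    return out[H:H + n].astype(np.float32)
```

With 50 % overlap every sample passes through the model twice, so the cost roughly doubles (about RTF 0.02 at the quoted ~20 ms per chunk).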

Conversion (numerically equivalent re-authoring)

  • Phase path: it cancels algebraically (mask·mag·cos(∠x) ≡ mask·x_r), so the graph has no atan2/cos/sin.
  • Shaw relative positional embedding: the Embedding lookup (a GATHER) is baked to a constant for the fixed chunk size and applied as a 2D FULLY_CONNECTED plus a pad/reshape skew realignment.
  • Conformer: folded batches become batch-1 4D tensors, with channel LayerNorm per position, Linears as 1×1 convs, depthwise Conv1d as (1, k) Conv2d, and attention heads folded into the 3D batched-matmul (BMM) batch with the 1/√d scale folded into Q.
  • Power compression: mag^0.3 → exp(0.3·ln(·)), because POW is banned.
  • SPConvTranspose2d: its 5-D view → an exact 4D reshape chain.
  • Norms: InstanceNorm → safe spatial norm; eval-mode BatchNorm → constant scale/shift; every norm eps ≥ 1e-4, above the fp16 minimum normal (2^-14 ≈ 6.1e-5) on the GPU delegate.
  • No multiplies that broadcast over a size-1 dimension.
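Several of these rewrites are exact identities that can be checked numerically. The NumPy sketch below is not the conversion script: it only verifies the phase cancellation, the POW-free compression, eval-mode BatchNorm folding, and one pad/reshape skew. The relative-position layout (column k of rel holds offset k - (L - 1)) and the particular skew variant are assumptions; other equivalent skew tricks exist.

```python
import numpy as np

rng = np.random.default_rng(0)

# 1) Phase path: mask * |x| * cos(angle(x)) == mask * Re(x), so no atan2/cos is needed.
x = rng.standard_normal(1000) + 1j * rng.standard_normal(1000)
mask = rng.random(1000)
assert np.allclose(mask * np.abs(x) * np.cos(np.angle(x)), mask * x.real)

# 2) POW-free compression: m ** 0.3 == exp(0.3 * ln(m)) for m > 0 (clamp m in the graph).
m = rng.random(1000) + 1e-4
assert np.allclose(m ** 0.3, np.exp(0.3 * np.log(m)))

# 3) Eval-mode BatchNorm as a constant per-channel scale/shift (eps is baked into scale).
C = 8
gamma, beta = rng.standard_normal(C), rng.standard_normal(C)
mean, var, eps = rng.standard_normal(C), rng.random(C) + 0.1, 1e-5
h = rng.standard_normal((C, 50))
bn = gamma[:, None] * (h - mean[:, None]) / np.sqrt(var[:, None] + eps) + beta[:, None]
scale = gamma / np.sqrt(var + eps)
shift = beta - mean * scale
assert np.allclose(bn, scale[:, None] * h + shift[:, None])

# 4) Skew realignment without GATHER: out[i, j] = rel[i, j - i + L - 1].
def skew(rel):
    """rel: [L, 2L-1] scores per query and relative offset; returns [L, L] aligned scores."""
    L = rel.shape[0]
    padded = np.pad(rel, ((0, 0), (0, 1)))                  # [L, 2L], zero column on the right
    flat = padded.reshape(-1)[L - 1 : L - 1 + L * (2 * L - 1)]
    return flat.reshape(L, 2 * L - 1)[:, :L]                # each row shifted left by its index

L = 6
rel = rng.standard_normal((L, 2 * L - 1))
i, j = np.meshgrid(np.arange(L), np.arange(L), indexing="ij")
assert np.allclose(skew(rel), rel[i, j - i + L - 1])        # matches an explicit gather
```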

Upstream

  • Code + weights: ruizhecao96/CMGAN (MIT)
  • Please cite Cao et al., CMGAN: Conformer-Based Metric-GAN for Monaural Speech Enhancement (Interspeech 2022 / TASLP 2024) when you use this model.