CMGAN — LiteRT (CompiledModel GPU) speech enhancement / noise suppression
CMGAN (TASLP 2024, MIT license, trained on VoiceBank-DEMAND),
re-authored as a GPU-native LiteRT .tflite: a 1.83 M-parameter dual-path conformer
(4× time + frequency blocks) that denoises 2 s chunks of 16 kHz audio. The STFT (a Hamming-windowed DFT implemented as a single Conv1d) and
the mag^0.3 power compression run inside the graph. The host only RMS-normalizes and reflect-pads the input, then
un-compresses the output, applies the inverse STFT and overlap-adds the chunks. FP16, 4.2 MB.
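The in-graph STFT works because the DFT of each windowed frame is a fixed linear map, so the whole STFT can be written as one strided Conv1d whose kernels are Hamming-windowed cosines and sines. The NumPy sketch below illustrates that idea, assuming a periodic Hamming window, n_fft 400 and hop 100 as in the I/O section; the helper names `dft_kernels` and `stft_as_conv1d` are hypothetical and the kernel ordering in the exported graph may differ.

```python
import numpy as np

NFFT, HOP = 400, 100

def dft_kernels(n_fft=NFFT):
    """Hamming-windowed DFT basis, shape [2F, n_fft]: F cosine rows, then F negated sine rows."""
    n = np.arange(n_fft)
    k = np.arange(n_fft // 2 + 1)[:, None]                 # F = n_fft/2 + 1 bins
    win = 0.54 - 0.46 * np.cos(2 * np.pi * n / n_fft)      # periodic Hamming window
    ang = 2 * np.pi * k * n / n_fft
    return np.concatenate([np.cos(ang) * win, -np.sin(ang) * win])

def stft_as_conv1d(xp, n_fft=NFFT, hop=HOP):
    """Strided correlation of the padded signal with every kernel, i.e. a Conv1d with stride=hop."""
    kernels = dft_kernels(n_fft)
    frames = np.lib.stride_tricks.sliding_window_view(xp, n_fft)[::hop]  # [T, n_fft]
    out = frames @ kernels.T                                            # [T, 2F]
    f = n_fft // 2 + 1
    return out[:, :f], out[:, f:]                                       # real, imag in [T, F]

x = np.random.default_rng(0).standard_normal(32000)
xp = np.pad(x, NFFT // 2, mode="reflect")    # [32400], reflect-padded like the model input
spec_re, spec_im = stft_as_conv1d(xp)
```

The result matches `np.fft.rfft` applied to the windowed frames and has the same `[321, 201]` grid as the model outputs.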
Noisy input → enhanced output (on-device, +7.2 dB SI-SNR on this clip). Waveforms from the fp16 model.
Verified on a Pixel 8a (Tensor G3): 1 651 / 1 651 nodes on LITERT_CL (full residency, 1 partition), ~20 ms per 2 s chunk (RTF ≈ 0.01), SI-SNR +7.2 dB on a 6.6 dB noisy sample (PyTorch reference: +9.6 dB), and a device-vs-PyTorch waveform correlation of 0.997. On desktop, the fp16 tflite-vs-PyTorch correlation is 0.999999.
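For reference, SI-SNR (scale-invariant SNR) projects the estimate onto the clean reference and compares the projected part with the residual. The sketch below follows that standard definition. It reads the "+7.2 dB" figures as the gain of the enhanced output over the noisy input, which is our interpretation of the card. The RTF is processing time divided by audio duration. The function names are illustrative and not taken from the card's evaluation code.

```python
import numpy as np

def si_snr(est, ref, eps=1e-8):
    """Scale-invariant SNR in dB between an estimate and a clean reference."""
    est = est - est.mean()                            # zero-mean both signals
    ref = ref - ref.mean()
    target = (est @ ref) / (ref @ ref + eps) * ref    # component of est along ref
    noise = est - target                              # everything else
    return 10 * np.log10((target @ target + eps) / (noise @ noise + eps))

def si_snr_improvement(noisy, enhanced, clean):
    """Gain of the enhanced signal over the noisy input, in dB."""
    return si_snr(enhanced, clean) - si_snr(noisy, clean)

# Real-time factor: processing time / audio duration (~20 ms per 2 s chunk).
rtf = 0.020 / 2.0
```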
I/O
- Input
`[1, 32400]` `float32`: a 32 000-sample (2 s @ 16 kHz) mono chunk, RMS-normalized (`x·c` with `c = sqrt(N / Σx²)`, N = 32 000; divide the output by `c` to undo it) and reflect-padded by 200 samples on both sides (equivalent to `torch.stft(center=True)`).
- Outputs
`[1, 1, 321, 201]` × 2 (real, imag): the enhanced, compressed complex spectrogram (`mag^0.3` domain, layout `[T, F]` = 321 frames × 201 frequency bins). Un-compress by scaling `(r, i)` by `(r²+i²)^(7/6)`, then apply the iSTFT (n_fft 400 / hop 100 / periodic Hamming window; trim the 200-sample center pad) and overlap-add the chunks.
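The numbers above follow directly from the stated parameters. Writing the clean-estimate STFT as X with phase φ, the compressed outputs are (r, i) = |X|^0.3 (cos φ, sin φ). The first line checks the RMS gain, the second the STFT grid of the padded chunk, and the third where the 7/6 exponent comes from:

$$
\begin{aligned}
c &= \sqrt{N \big/ \textstyle\sum_n x_n^2} \;\Rightarrow\; \tfrac{1}{N}\textstyle\sum_n (c\,x_n)^2 = 1, && N = 32\,000\\
T &= 1 + \frac{32\,000 + 2\cdot 200 - 400}{100} = 321, && F = \frac{400}{2} + 1 = 201\\
r^2 + i^2 &= |X|^{0.6} \;\Rightarrow\; \frac{|X|}{|X|^{0.3}} = |X|^{0.7} = \left(r^2 + i^2\right)^{7/6}
\end{aligned}
$$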
Minimal usage
```python
import numpy as np, soundfile as sf, torch
from ai_edge_litert.interpreter import Interpreter

SR, NFFT, HOP, S = 16000, 400, 100, 32000
wav, _ = sf.read("noisy.wav", dtype="float32")  # mono 16 kHz
x = np.zeros(S, np.float32); n = min(len(wav), S); x[:n] = wav[:n]  # first 2 s only, zero-padded if shorter
c = np.sqrt(S / (x @ x + 1e-12)); x *= c  # RMS normalize
xp = np.pad(x, NFFT // 2, mode="reflect")[None]  # reflect pad -> [1, 32400], like torch.stft(center=True)

it = Interpreter(model_path="cmgan_fp16.tflite"); it.allocate_tensors()
it.set_tensor(it.get_input_details()[0]["index"], xp); it.invoke()
r, i = (torch.tensor(it.get_tensor(o["index"])) for o in
        sorted(it.get_output_details(), key=lambda o: o["index"]))  # [1,1,321,201] compressed
m2 = (r * r + i * i).clamp_min(1e-12) ** (7.0 / 6.0)  # mag^0.3 -> mag
spec = torch.complex(r * m2, i * m2)[0, 0].T  # [F, T] = [201, 321], as torch.istft expects
den = torch.istft(spec, NFFT, HOP, window=torch.hamming_window(NFFT), length=S)  # periodic Hamming
sf.write("denoised.wav", den.numpy()[:n] / c, SR)  # undo RMS normalization, drop the zero padding
```
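The minimal example above enhances only the first 2 s. Longer files have to be split into chunks and overlap-added. The card does not state the chunk hop or crossfade, so the sketch below assumes 50 % chunk overlap and strictly positive Hann-shaped weights, normalized by their sum. `enhance_chunk` is a hypothetical callable that wraps the per-chunk steps above (normalize, pad, invoke, un-compress, iSTFT, un-normalize) and maps 32 000 samples to 32 000 samples.

```python
import numpy as np

def enhance_long(wav, enhance_chunk, S=32000, hop=16000):
    """Weighted overlap-add of fixed-size chunks (assumed 50 % overlap, Hann-shaped weights)."""
    n = len(wav)
    n_chunks = 1 + int(np.ceil(max(n - S, 0) / hop))
    total = (n_chunks - 1) * hop + S
    x = np.zeros(total, np.float32)
    x[:n] = wav                                         # zero-pad the tail to whole chunks
    w = np.hanning(S + 2)[1:-1].astype(np.float32)      # Hann without its zero endpoints: w > 0
    out = np.zeros(total, np.float32)
    wsum = np.zeros(total, np.float32)
    for k in range(n_chunks):
        s = k * hop
        out[s:s + S] += w * enhance_chunk(x[s:s + S])   # hypothetical per-chunk enhancer
        wsum[s:s + S] += w
    return out[:n] / wsum[:n]                           # normalize by the summed weights
```

With an identity `enhance_chunk` the function returns its input unchanged, which is a quick way to check the chunk bookkeeping before plugging in the model.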
Kotlin (Android, LiteRT CompiledModel GPU)
```kotlin
// implementation("com.google.ai.edge.litert:litert:2.1.5")
val model = CompiledModel.create(File(ctx.filesDir, "cmgan_fp16.tflite").absolutePath,
    CompiledModel.Options(Accelerator.GPU), null)
val inBuf = model.createInputBuffers()
val outBuf = model.createOutputBuffers()
// 2 s chunk, RMS-normalized (x*c), reflect-padded by 200 both sides -> [32400]
inBuf[0].writeFloat(paddedChunk)
model.run(inBuf, outBuf)
val real = outBuf[0].readFloat() // [321 * 201] compressed real, t-major (t*201 + f)
val imag = outBuf[1].readFloat()
// host: scale (r,i) by (r*r+i*i)^(7/6), then iSTFT (n_fft 400 / hop 100 / periodic hamming,
// trim the 200-sample center pad), overlap-add chunks, divide by c;
// see NoiseSuppressor.kt in the speech_enhancement LiteRT sample.
```
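The Kotlin snippet leaves the host post-processing to NoiseSuppressor.kt, which is not reproduced here. As a language-neutral reference for porting it, the NumPy sketch below (our own, not the sample's code) reshapes the two flat t-major buffers into a complex `[T, F]` spectrum and applies the `(r²+i²)^(7/6)` un-compression. It then inverts the centered STFT by windowed overlap-add with squared-window normalization, which is how `torch.istft` reconstructs the signal for these parameters. The names `uncompress` and `istft` are hypothetical.

```python
import numpy as np

NFFT, HOP, T, F = 400, 100, 321, 201

def uncompress(real, imag):
    """Flat t-major buffers (index t*201 + f) -> complex [T, F] spectrum with linear magnitude."""
    r = np.asarray(real, np.float64).reshape(T, F)
    i = np.asarray(imag, np.float64).reshape(T, F)
    scale = np.maximum(r * r + i * i, 1e-12) ** (7.0 / 6.0)   # mag^0.3 -> mag
    return (r + 1j * i) * scale

def istft(spec, length=32000):
    """Inverse centered STFT: irfft, window, overlap-add, divide by the summed squared window."""
    win = np.hamming(NFFT + 1)[:-1]                           # periodic Hamming
    frames = np.fft.irfft(spec, n=NFFT, axis=1) * win         # [T, NFFT]
    total = NFFT + HOP * (spec.shape[0] - 1)
    y, env = np.zeros(total), np.zeros(total)
    for t, frame in enumerate(frames):
        y[t * HOP:t * HOP + NFFT] += frame
        env[t * HOP:t * HOP + NFFT] += win ** 2
    pad = NFFT // 2
    return (y / np.maximum(env, 1e-11))[pad:pad + length]     # trim the 200-sample center pad
```

The result is the enhanced, still RMS-normalized chunk, so it still has to be divided by `c` as described above.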
Conversion (numerically-equivalent re-authoring)
The phase path cancels algebraically (mask·mag·cos(∠x) ≡ mask·x_r) — no atan2/cos/sin in the
graph. Shaw relative positional embedding (an Embedding lookup = GATHER) is baked to a constant for
the fixed chunk and applied as a 2D FULLY_CONNECTED plus a pad/reshape skew realignment. The
conformer's folded batches become batch-1 4D tensors (channel-LayerNorm per position, Linears as
1×1 convs, depthwise Conv1d as (1,k) Conv2d, heads folded into the 3D-BMM batch with 1/√d in Q).
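For a fixed sequence length the Shaw relative-position term needs no GATHER. Multiplying the queries by a baked table with one embedding per relative offset is a plain matmul (the FULLY_CONNECTED), and the resulting `[L, 2L-1]` scores can be realigned to `[L, L]` with pad and reshape only. The sketch below shows one such pad/reshape skew and checks it against an explicit gather. The offset ordering of the table (column m holds offset j - i = m - (L-1)) is an assumption, as is the table itself; Shaw-style offset clipping would be folded into the table when baking. The exact reshape chain in the exported graph may differ.

```python
import numpy as np

def skew(rel):
    """[L, 2L-1] relative-offset scores -> [L, L] absolute scores, out[i, j] = rel[i, j - i + L - 1]."""
    L = rel.shape[0]
    flat = np.pad(rel, ((0, 0), (0, 1))).reshape(-1)        # pad one column: [L, 2L] -> [2L*L]
    return flat[L - 1:L - 1 + L * (2 * L - 1)].reshape(L, 2 * L - 1)[:, :L]

L, d = 5, 3
rng = np.random.default_rng(0)
q = rng.standard_normal((L, d))
E = rng.standard_normal((2 * L - 1, d))                     # hypothetical baked relative table
rel = q @ E.T                                               # the FULLY_CONNECTED step: [L, 2L-1]
offsets = np.arange(L)[None, :] - np.arange(L)[:, None] + L - 1   # j - i + L - 1
gather = np.einsum("id,ijd->ij", q, E[offsets])             # reference: explicit embedding lookup
```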
mag^0.3 → exp(0.3·ln(·)) (POW is banned); SPConvTranspose2d's 5-D view → an exact 4D reshape
chain; InstanceNorm → safe spatial norm; eval-mode BatchNorm → constant scale/shift; all norm eps
≥ 1e-4 (fp16 min-normal on the GPU delegate); no dim-1 broadcast multiplies.
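Three of the rewrites above are exact algebraic identities, which the NumPy sketch below checks numerically on random data. The shapes, the 1e-4 offset and the BatchNorm statistics are illustrative only, and the channel axis is taken as the last axis for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)) + 1j * rng.standard_normal((4, 8))  # a complex spectrum
mask = rng.random((4, 8))

# 1. Phase path: mask * |x| * cos(angle(x)) == mask * Re(x), so no atan2/cos/sin is needed.
via_phase = mask * np.abs(x) * np.cos(np.angle(x))
direct = mask * x.real

# 2. Power compression without POW: m ** 0.3 == exp(0.3 * ln(m)) for m > 0.
m = np.abs(x) + 1e-4                     # keep the log argument strictly positive
pow_form, exp_ln_form = m ** 0.3, np.exp(0.3 * np.log(m))

# 3. Eval-mode BatchNorm folded into a constant per-channel scale and shift.
gamma, beta = rng.random(8) + 0.5, rng.random(8)
mean, var, eps = rng.random(8), rng.random(8) + 0.1, 1e-5   # eps is absorbed into `scale`
bn = gamma * (m - mean) / np.sqrt(var + eps) + beta
scale = gamma / np.sqrt(var + eps)
folded = m * scale + (beta - mean * scale)
```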
Upstream
- Code + weights: ruizhecao96/CMGAN (MIT)
- Please cite Cao et al., CMGAN: Conformer-Based Metric-GAN for Monaural Speech Enhancement (Interspeech 2022 / TASLP 2024) when you use this model.
