# Streaming-USEF-TP / model_streaming_usef_tp.py
# VMoorjani's picture
# V3 slight improvement over V2.
# eefb734 verified
"""Streaming USEF-TP model.
Causal version of USEF-TP designed for chunk-by-chunk inference at 8 kHz.
Streaming config (matches the offline training STFT settings so that the
checkpoint warm-starts cleanly):
* Sample rate : 8000 Hz
* Chunk hop H : 64 samples = 8 ms
* STFT window W : 128 samples = 16 ms (one new STFT frame per chunk)
* Frame look-ahead L : 2 STFT frames = 16 ms
Total algorithmic latency = W + L*H = 32 ms (excluding compute).
The L=2 look-ahead is built into the architecture by keeping the original
symmetric time padding (kernel=3, pad=1) on the encoder ``Conv2d`` and on the
TSE/PVAD decoder ``ConvTranspose2d`` layers -- each such layer contributes one
frame of right-side context. The TSE and PVAD decoders run in parallel on the
separator output, so every output path crosses exactly one encoder conv and
one decoder conv. Every other time-axis layer (the GridNet ``inter_rnn``, the
GridNet self-attention, the PVAD ``conv1d``) is strictly causal, so the total
look-ahead is exactly the 2 frames spent in the encoder + decoder convs.
Compared to the offline ``model_USEF_TP.py``:
* The encoder ``GroupNorm`` is dropped (it aggregates over time).
* The ``GridNetV2Block`` is replaced by ``StreamingGridNetV2Block``
(unidirectional ``inter_rnn``, causal-masked self-attention via SDPA).
* The PVAD ``conv1d`` is given causal left-padding so it adds zero look-ahead.
* The ``InteractionModule`` ConvTranspose1d is unchanged -- it is already
causal (output[n] depends only on input[n-1] and input[n]).
* The global ``std = mix.std(...)`` input/output gain is dropped so that
there is nothing utterance-global at streaming time.
Encoder, decoder, CMHA and ``intra_rnn`` weights warm-start directly from an
existing offline USEF-TP checkpoint; ``inter_rnn`` (bi -> uni) and the
self-attention need re-init.
"""
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from local.CMHA import CMHA
from local.STFT import STFT, iSTFT
from local.StreamingGridNetV2Block import StreamingGridNetV2Block
class PVADDecoder(nn.Module):
    """Causal decoder head producing the per-frame PVAD logits.

    Layers:
        ``tconv2d``: ``ConvTranspose2d(in_channels -> 1)`` with kernel
            ``(t_ksize, 3)`` and padding ``(t_ksize // 2, 1)``. With the
            default ``t_ksize=3`` the time padding is symmetric, so this layer
            sees one frame of right-side context (part of the model's L=2
            look-ahead budget).
        ``conv1d``: ``Conv1d(n_freqs -> 1)`` with kernel size 2. The input is
            left-padded by one frame before the convolution, so output[t]
            depends only on input[t-1] and input[t] and the output length
            equals the input length (the offline version produced ``L - 1``).

    Args:
        in_channels: channel count of the separator output fed to ``forward``.
        n_freqs: number of frequency bins (input channels of ``conv1d``).
        t_ksize: kernel size of ``tconv2d`` along the time axis.
    """

    def __init__(self, in_channels, n_freqs, t_ksize=3):
        super().__init__()
        kernel_size = (t_ksize, 3)
        tconv_padding = (t_ksize // 2, 1)
        self.tconv2d = nn.ConvTranspose2d(
            in_channels, 1, kernel_size, stride=1, padding=tconv_padding
        )
        self.conv1d = nn.Conv1d(n_freqs, 1, kernel_size=2, stride=1)

    def forward(self, Eo):
        """Map separator features ``[B, C, L, F]`` to logits ``[B, 1, L]``."""
        collapsed = self.tconv2d(Eo)                         # [B, 1, L, F]
        freq_major = collapsed.squeeze(1).permute(0, 2, 1)   # [B, F, L]
        # One zero frame on the left only: keeps the kernel-2 conv causal and
        # makes the output as long as the input.
        left_padded = F.pad(freq_major, (1, 0))              # [B, F, L + 1]
        return self.conv1d(left_padded)                      # [B, 1, L]
class InteractionModule(nn.Module):
    """Turn PVAD logits into a non-negative gain: sigmoid -> tconv -> ReLU.

    The single layer is a ``ConvTranspose1d`` with one input channel, one
    output channel, kernel size 2, stride 1 and no padding. For such a layer
    output[t] is a weighted sum of input[t-1] and input[t] (plus bias), so no
    future frame is used. The output is one frame longer than the input
    (``L + 1``); the caller is expected to crop it.
    """

    def __init__(self):
        super().__init__()
        self.tconv1d = nn.ConvTranspose1d(
            in_channels=1, out_channels=1, kernel_size=2, stride=1
        )

    def forward(self, Ptgt):
        """Return the gain ``[B, 1, L + 1]`` for logits ``Ptgt`` ``[B, 1, L]``."""
        probs = Ptgt.sigmoid()
        smoothed = self.tconv1d(probs)
        return smoothed.relu()
class Streaming_USEF_TP(nn.Module):
    """Streaming-causal USEF-TP.

    Args:
        hidden_channels: GridNet LSTM hidden size.
        n_head: number of attention heads in CMHA and GridNet self-attention.
        emb_dim: encoder/decoder channel width.
        emb_ks, emb_hs: GridNet unfold kernel/stride along intra/inter axes.
        num_layers: number of stacked GridNet blocks.
        n_fft, hop_length, win_length: STFT parameters. Defaults match the
            offline USEF-TP training (8 kHz, 16 ms window, 8 ms hop).
        cmha_approx_qk_dim: approx. Q/K dim for the cross-attention.
        eps: epsilon forwarded to the CMHA module.
        separator_approx_qk_dim: approx. Q/K dim passed to every GridNet
            block of the separator. Defaults to 512, the value that was
            previously hard-coded.

    The number of STFT frequency bins ``n_freqs = n_fft // 2 + 1`` is derived
    automatically. The STFT/iSTFT modules are constructed internally and stored
    on the model, so callers don't need to set them up separately.
    """

    def __init__(self, hidden_channels, n_head, emb_dim, emb_ks, emb_hs,
                 num_layers=6, n_fft=128, hop_length=64, win_length=128,
                 cmha_approx_qk_dim=512, eps=1e-5,
                 separator_approx_qk_dim=512):
        super().__init__()
        self.num_layers = num_layers
        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        n_freqs = n_fft // 2 + 1
        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)

        # Encoder: shared between mixture and reference (post-STFT).
        # GroupNorm from the offline model is dropped (it aggregates over time).
        # Kernel 3 with pad 1 in time gives 1 frame of look-ahead -- the first
        # half of the model's L=2 budget.
        self.encoder = nn.Conv2d(2, emb_dim, ks, padding=padding)

        # Cross multi-head attention: Q from E_m, K/V from E_r.
        # Already streaming-compatible: each query frame attends across the
        # reference time axis only, so the reference is encoded once at
        # enrollment and K/V are reused for every chunk.
        self.cmha = CMHA(
            emb_dim=emb_dim, n_freqs=n_freqs, n_head=n_head,
            approx_qk_dim=cmha_approx_qk_dim, eps=eps,
        )

        # Separator: stack of streaming-causal TF-GridNet blocks on 2C channels.
        # Each iteration constructs an independent block, so no copy is needed
        # to keep the layers from sharing parameters.
        self.separator = nn.Sequential(*[
            StreamingGridNetV2Block(
                2 * emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels,
                n_head, approx_qk_dim=separator_approx_qk_dim,
                activation="prelu",
            )
            for _ in range(num_layers)
        ])

        # TSE decoder: TConv2d(2C -> 2) -- one frame of time look-ahead, the
        # second half of the L=2 budget.
        self.tse_decoder = nn.ConvTranspose2d(
            2 * emb_dim, 2, ks, stride=1, padding=padding
        )

        # PVAD decoder and interaction module.
        self.pvad_decoder = PVADDecoder(
            in_channels=2 * emb_dim, n_freqs=n_freqs, t_ksize=t_ksize
        )
        self.interaction = InteractionModule()

    def _wave_to_ri(self, wave):
        """STFT a ``[B, T]`` waveform and stack real/imag parts as 2 channels."""
        # The last element of the STFT output is used as the complex spectrum.
        # NOTE(review): assumes it is shaped [B, 1, F, L] so that the result
        # below is [B, 2, L, F] -- confirm against local.STFT.
        spec = self.stft(wave.unsqueeze(1))[-1]
        stacked = torch.cat([spec.real, spec.imag], dim=1)
        return stacked.permute(0, 1, 3, 2).contiguous()

    def _ri_to_wave(self, spec_ri):
        """iSTFT a 2-channel (real, imag) ``[B, 2, L, F]`` tensor to a waveform."""
        real = spec_ri[:, 0].permute(0, 2, 1).contiguous()   # [B, F, L]
        imag = spec_ri[:, 1].permute(0, 2, 1).contiguous()   # [B, F, L]
        return self.istft((real, imag), input_type="real_imag")

    def forward(self, mix, ref, return_attn=False, return_no_mask=False):
        """Full-sequence (training) forward.

        Inputs are 1D waveforms ``[B, T]`` at 8 kHz. The model is causal up to
        the L=2 STFT-frame look-ahead built into the encoder/decoder kernels;
        no additional masking is required during training.

        Args:
            mix: mixture waveform.
            ref: reference (enrollment) waveform.
            return_attn: also return the attention output of the CMHA module.
            return_no_mask: also return the waveform decoded from the TSE
                branch without the PVAD-derived mask applied.

        Returns:
            A tuple ``(Xtgt, Ptgt)`` -- the masked target waveform and the
            PVAD decoder output -- followed by ``attn`` when ``return_attn``
            is set and then by the unmasked waveform when ``return_no_mask``
            is set, in that order.
        """
        mix_ri = self._wave_to_ri(mix)
        ref_ri = self._wave_to_ri(ref)

        # The same encoder is applied to mixture and reference.
        Em = self.encoder(mix_ri)
        Er = self.encoder(ref_ri)

        attn = None
        if return_attn:
            Espk, attn = self.cmha(Em, Er, return_attn=True)
        else:
            Espk = self.cmha(Em, Er)

        Ef = torch.cat([Em, Espk], dim=1)
        Eo = self.separator(Ef)

        # Two parallel heads on the separator output.
        Dtse = self.tse_decoder(Eo)
        Ptgt = self.pvad_decoder(Eo)
        Pi = self.interaction(Ptgt)

        # Align the gain with the TSE frame count: the interaction module
        # emits one extra trailing frame, which is cropped here. Zero-padding
        # on the right covers the case where the gain comes out shorter.
        n_frames = Dtse.shape[2]
        surplus = Pi.shape[-1] - n_frames
        if surplus < 0:
            Pi = F.pad(Pi, (0, -surplus))
        elif surplus > 0:
            Pi = Pi[..., :n_frames]

        # Broadcast the per-frame gain over both (real, imag) channels and
        # over every frequency bin.
        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])

        outputs = [self._ri_to_wave(Dtse * mask), Ptgt]
        if return_attn:
            outputs.append(attn)
        if return_no_mask:
            outputs.append(self._ri_to_wave(Dtse))
        return tuple(outputs)