"""Streaming USEF-TP model. Causal version of USEF-TP designed for chunk-by-chunk inference at 8 kHz. Streaming config (matches the offline training STFT settings so that the checkpoint warm-starts cleanly): * Sample rate : 8000 Hz * Chunk hop H : 64 samples = 8 ms * STFT window W : 128 samples = 16 ms (one new STFT frame per chunk) * Frame look-ahead L : 2 STFT frames = 16 ms Total algorithmic latency = W + L*H = 32 ms (excluding compute). The L=2 look-ahead is built into the architecture by keeping the encoder and TSE/PVAD ConvTranspose2d layers with their original symmetric (kernel=3, pad=1) time padding -- each contributes one frame of right-side context. Every other time-axis layer (the GridNet ``inter_rnn``, the GridNet self-attention, the PVAD ``conv1d``) is strictly causal, so the total look-ahead is exactly the 2 frames spent in the encoder + decoder convs. Compared to the offline ``model_USEF_TP.py``: * The encoder ``GroupNorm`` is dropped (it aggregates over time). * The ``GridNetV2Block`` is replaced by ``StreamingGridNetV2Block`` (unidirectional ``inter_rnn``, causal-masked self-attention via SDPA). * The PVAD ``conv1d`` is given causal left-padding so it adds zero look-ahead. * The ``InteractionModule`` ConvTranspose1d is unchanged -- it is already causal (output[n] depends only on input[n-1] and input[n]). * The global ``std = mix.std(...)`` input/output gain is dropped so that there is nothing utterance-global at streaming time. Encoder, decoder, CMHA and ``intra_rnn`` weights warm-start directly from an existing offline USEF-TP checkpoint; ``inter_rnn`` (bi -> uni) and the self-attention need re-init. """ import copy import torch import torch.nn as nn import torch.nn.functional as F from local.CMHA import CMHA from local.STFT import STFT, iSTFT from local.StreamingGridNetV2Block import StreamingGridNetV2Block class PVADDecoder(nn.Module): """Causal PVAD decoder. ``tconv2d``: same shape as the offline version (kernel (3, 3), padding (1, 1)) -- contributes 1 STFT frame of look-ahead which is part of the model's L=2 budget. ``conv1d``: kernel size 2, but applied with left-only padding so output[t] depends only on input[t-1] and input[t]. Output length matches input length (the offline version produced ``L - 1``). """ def __init__(self, in_channels, n_freqs, t_ksize=3): super().__init__() ks, padding = (t_ksize, 3), (t_ksize // 2, 1) self.tconv2d = nn.ConvTranspose2d(in_channels, 1, ks, stride=1, padding=padding) self.conv1d = nn.Conv1d(n_freqs, 1, kernel_size=2, stride=1) def forward(self, Eo): x = self.tconv2d(Eo) # [B, 1, L, F] x = x.squeeze(1).transpose(1, 2) # [B, F, L] x = F.pad(x, (1, 0)) # causal left-pad on time return self.conv1d(x) # [B, 1, L] class InteractionModule(nn.Module): """Sigmoid -> ConvTranspose1d (kernel 2) -> ReLU. ConvTranspose1d kernel 2 with stride 1 and no padding produces output[t] that depends on input[t-1] and input[t] only -- i.e. it is causal already. Output length is L_in + 1; the model crops to the TSE length. """ def __init__(self): super().__init__() self.tconv1d = nn.ConvTranspose1d(1, 1, kernel_size=2, stride=1) def forward(self, Ptgt): p = torch.sigmoid(Ptgt) p = F.relu(self.tconv1d(p)) return p # [B, 1, L + 1] class Streaming_USEF_TP(nn.Module): """Streaming-causal USEF-TP. Args: hidden_channels: GridNet LSTM hidden size. n_head: number of attention heads in CMHA and GridNet self-attention. emb_dim: encoder/decoder channel width. emb_ks, emb_hs: GridNet unfold kernel/stride along intra/inter axes. num_layers: number of stacked GridNet blocks. n_fft, hop_length, win_length: STFT parameters. Defaults match the offline USEF-TP training (8 kHz, 16 ms window, 8 ms hop). cmha_approx_qk_dim: approx. Q/K dim for the cross-attention. The number of STFT frequency bins ``n_freqs = n_fft // 2 + 1`` is derived automatically. The STFT/iSTFT modules are constructed internally and stored on the model, so callers don't need to set them up separately. """ def __init__(self, hidden_channels, n_head, emb_dim, emb_ks, emb_hs, num_layers=6, n_fft=128, hop_length=64, win_length=128, cmha_approx_qk_dim=512, eps=1e-5): super().__init__() self.num_layers = num_layers self.stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length) self.istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length) n_freqs = n_fft // 2 + 1 t_ksize = 3 ks, padding = (t_ksize, 3), (t_ksize // 2, 1) # Encoder: shared between mixture and reference (post-STFT). # GroupNorm from the offline model is dropped (it aggregates over time). # Kernel 3 with pad 1 in time gives 1 frame of look-ahead -- the first # half of the model's L=2 budget. self.encoder = nn.Conv2d(2, emb_dim, ks, padding=padding) # Cross multi-head attention: Q from E_m, K/V from E_r. # Already streaming-compatible: each query frame attends across the # reference time axis only, so the reference is encoded once at # enrollment and K/V are reused for every chunk. self.cmha = CMHA( emb_dim=emb_dim, n_freqs=n_freqs, n_head=n_head, approx_qk_dim=cmha_approx_qk_dim, eps=eps, ) # Separator: stack of streaming-causal TF-GridNet blocks on 2C channels. self.separator = nn.Sequential(*[ copy.deepcopy( StreamingGridNetV2Block( 2 * emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels, n_head, approx_qk_dim=512, activation="prelu", ) ) for _ in range(num_layers) ]) # TSE decoder: TConv2d(2C -> 2) -- one frame of time look-ahead, the # second half of the L=2 budget. self.tse_decoder = nn.ConvTranspose2d( 2 * emb_dim, 2, ks, stride=1, padding=padding ) # PVAD decoder and interaction module. self.pvad_decoder = PVADDecoder( in_channels=2 * emb_dim, n_freqs=n_freqs, t_ksize=t_ksize ) self.interaction = InteractionModule() def forward(self, mix, ref, return_attn=False, return_no_mask=False): """Full-sequence (training) forward. Inputs are 1D waveforms ``[B, T]`` at 8 kHz. The model is causal up to the L=2 STFT-frame look-ahead built into the encoder/decoder kernels; no additional masking is required during training. """ mix = mix.unsqueeze(1) ref = ref.unsqueeze(1) mix_c = self.stft(mix)[-1] ref_c = self.stft(ref)[-1] mix_ri = torch.cat([mix_c.real, mix_c.imag], dim=1).permute(0, 1, 3, 2).contiguous() ref_ri = torch.cat([ref_c.real, ref_c.imag], dim=1).permute(0, 1, 3, 2).contiguous() Em = self.encoder(mix_ri) Er = self.encoder(ref_ri) if return_attn: Espk, attn = self.cmha(Em, Er, return_attn=True) else: Espk = self.cmha(Em, Er) Ef = torch.cat([Em, Espk], dim=1) Eo = self.separator(Ef) Dtse = self.tse_decoder(Eo) Ptgt = self.pvad_decoder(Eo) Pi = self.interaction(Ptgt) L_m = Dtse.shape[2] if Pi.shape[-1] < L_m: Pi = F.pad(Pi, (0, L_m - Pi.shape[-1])) elif Pi.shape[-1] > L_m: Pi = Pi[..., :L_m] mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1]) Xf = Dtse * mask out_r = Xf[:, 0, :, :].permute(0, 2, 1).contiguous() out_i = Xf[:, 1, :, :].permute(0, 2, 1).contiguous() Xtgt = self.istft((out_r, out_i), input_type="real_imag").unsqueeze(1) if return_no_mask: out_r_nm = Dtse[:, 0, :, :].permute(0, 2, 1).contiguous() out_i_nm = Dtse[:, 1, :, :].permute(0, 2, 1).contiguous() Xtgt_nomask = self.istft( (out_r_nm, out_i_nm), input_type="real_imag" ).unsqueeze(1) if return_attn and return_no_mask: return Xtgt.squeeze(1), Ptgt, attn, Xtgt_nomask.squeeze(1) if return_attn: return Xtgt.squeeze(1), Ptgt, attn if return_no_mask: return Xtgt.squeeze(1), Ptgt, Xtgt_nomask.squeeze(1) return Xtgt.squeeze(1), Ptgt