| """Streaming USEF-TP model. |
| |
| Causal version of USEF-TP designed for chunk-by-chunk inference at 8 kHz. |
| |
| Streaming config (matches the offline training STFT settings so that the |
| checkpoint warm-starts cleanly): |
| |
| * Sample rate : 8000 Hz |
| * Chunk hop H : 64 samples = 8 ms |
| * STFT window W : 128 samples = 16 ms (one new STFT frame per chunk) |
| * Frame look-ahead L : 2 STFT frames = 16 ms |
| |
| Total algorithmic latency = W + L*H = 32 ms (excluding compute). |
| |
| The L=2 look-ahead is built into the architecture by keeping the encoder and |
| TSE/PVAD ConvTranspose2d layers with their original symmetric (kernel=3, pad=1) |
| time padding -- each contributes one frame of right-side context. Every other |
| time-axis layer (the GridNet ``inter_rnn``, the GridNet self-attention, the |
| PVAD ``conv1d``) is strictly causal, so the total look-ahead is exactly the |
| 2 frames spent in the encoder + decoder convs. |
| |
| Compared to the offline ``model_USEF_TP.py``: |
| |
| * The encoder ``GroupNorm`` is dropped (it aggregates over time). |
| * The ``GridNetV2Block`` is replaced by ``StreamingGridNetV2Block`` |
| (unidirectional ``inter_rnn``, causal-masked self-attention via SDPA). |
| * The PVAD ``conv1d`` is given causal left-padding so it adds zero look-ahead. |
| * The ``InteractionModule`` ConvTranspose1d is unchanged -- it is already |
| causal (output[n] depends only on input[n-1] and input[n]). |
| * The global ``std = mix.std(...)`` input/output gain is dropped so that |
| there is nothing utterance-global at streaming time. |
| |
| Encoder, decoder, CMHA and ``intra_rnn`` weights warm-start directly from an |
| existing offline USEF-TP checkpoint; ``inter_rnn`` (bi -> uni) and the |
| self-attention need re-init. |
| """ |
|
|
| import copy |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from local.CMHA import CMHA |
| from local.STFT import STFT, iSTFT |
| from local.StreamingGridNetV2Block import StreamingGridNetV2Block |
|
|
|
|
class PVADDecoder(nn.Module):
    """Causal PVAD decoder.

    ``tconv2d``: same shape as the offline version (kernel (3, 3), padding
    (1, 1)) -- contributes 1 STFT frame of look-ahead which is part of the
    model's L=2 budget.

    ``conv1d``: kernel size 2, applied with left-only padding so output[t]
    depends only on input[t-1] and input[t]. Output length matches input
    length (the offline version produced ``L - 1``).

    Args:
        in_channels: channel count of the feature map fed to ``tconv2d``.
        n_freqs: number of frequency bins; becomes the ``conv1d`` input
            channel count.
        t_ksize: ``tconv2d`` kernel size along the time axis. The time padding
            is ``t_ksize // 2``, so odd sizes keep the frame count unchanged.
    """

    def __init__(self, in_channels: int, n_freqs: int, t_ksize: int = 3) -> None:
        super().__init__()
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.tconv2d = nn.ConvTranspose2d(in_channels, 1, ks, stride=1, padding=padding)
        self.conv1d = nn.Conv1d(n_freqs, 1, kernel_size=2, stride=1)

    def forward(self, Eo: torch.Tensor) -> torch.Tensor:
        """Map separator features to per-frame PVAD logits.

        Args:
            Eo: features laid out as ``[B, in_channels, T, n_freqs]``.

        Returns:
            Tensor of shape ``[B, 1, T]`` (no activation applied).
        """
        x = self.tconv2d(Eo)              # [B, 1, T, F]
        x = x.squeeze(1).transpose(1, 2)  # [B, F, T]: frequency becomes the channel axis
        # Left-only padding of (kernel - 1) frames keeps the conv causal and
        # the output length equal to T. Deriving it from the layer avoids a
        # magic constant that must be kept in sync with ``kernel_size``.
        left_pad = self.conv1d.kernel_size[0] - 1
        x = F.pad(x, (left_pad, 0))
        return self.conv1d(x)
|
|
|
|
class InteractionModule(nn.Module):
    """Sigmoid -> ConvTranspose1d (kernel 2) -> ReLU.

    ConvTranspose1d kernel 2 with stride 1 and no padding produces output[t]
    that depends on input[t-1] and input[t] only -- i.e. it is causal already.
    Output length is L_in + 1; the model crops to the TSE length.
    """

    def __init__(self) -> None:
        super().__init__()
        self.tconv1d = nn.ConvTranspose1d(1, 1, kernel_size=2, stride=1)

    def forward(self, Ptgt: torch.Tensor) -> torch.Tensor:
        """Turn PVAD logits into a non-negative gating signal.

        Args:
            Ptgt: logits of shape ``[B, 1, L_in]`` (single channel, as required
                by ``tconv1d``).

        Returns:
            Tensor of shape ``[B, 1, L_in + 1]`` with values ``>= 0``.
        """
        prob = torch.sigmoid(Ptgt)
        # The transposed conv mixes each frame with its predecessor; ReLU
        # clamps the result so it can be used as a multiplicative mask.
        return F.relu(self.tconv1d(prob))
|
|
|
|
class Streaming_USEF_TP(nn.Module):
    """Streaming-causal USEF-TP.

    Args:
        hidden_channels: GridNet LSTM hidden size.
        n_head: number of attention heads in CMHA and GridNet self-attention.
        emb_dim: encoder/decoder channel width.
        emb_ks, emb_hs: GridNet unfold kernel/stride along intra/inter axes.
        num_layers: number of stacked GridNet blocks.
        n_fft, hop_length, win_length: STFT parameters. Defaults match the
            offline USEF-TP training (8 kHz, 16 ms window, 8 ms hop).
        cmha_approx_qk_dim: approx. Q/K dim for the cross-attention.
        eps: epsilon value forwarded to ``CMHA``.

    The number of STFT frequency bins ``n_freqs = n_fft // 2 + 1`` is derived
    automatically. The STFT/iSTFT modules are constructed internally and stored
    on the model, so callers don't need to set them up separately.
    """

    def __init__(self, hidden_channels, n_head, emb_dim, emb_ks, emb_hs,
                 num_layers=6, n_fft=128, hop_length=64, win_length=128,
                 cmha_approx_qk_dim=512, eps=1e-5):
        super().__init__()
        self.num_layers = num_layers
        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        n_freqs = n_fft // 2 + 1

        # Symmetric (time, freq) kernel shared by the encoder and both decoders.
        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)

        # Encoder: 2 input channels (real, imag) -> emb_dim.
        self.encoder = nn.Conv2d(2, emb_dim, ks, padding=padding)

        # Cross-attention between mixture and reference embeddings.
        self.cmha = CMHA(
            emb_dim=emb_dim, n_freqs=n_freqs, n_head=n_head,
            approx_qk_dim=cmha_approx_qk_dim, eps=eps,
        )

        # Each iteration builds a brand-new block, so no deepcopy is needed to
        # keep the layers' parameters independent.
        self.separator = nn.Sequential(*(
            StreamingGridNetV2Block(
                2 * emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels,
                n_head, approx_qk_dim=512, activation="prelu",
            )
            for _ in range(num_layers)
        ))

        # TSE head: back to 2 channels (real, imag).
        self.tse_decoder = nn.ConvTranspose2d(
            2 * emb_dim, 2, ks, stride=1, padding=padding
        )

        # PVAD head and the module that turns its logits into a TSE mask.
        self.pvad_decoder = PVADDecoder(
            in_channels=2 * emb_dim, n_freqs=n_freqs, t_ksize=t_ksize
        )
        self.interaction = InteractionModule()

    def _to_ri(self, wav):
        """Waveform ``[B, T]`` -> stacked real/imag features for the encoder.

        Takes the last element returned by ``self.stft`` as the complex
        spectrogram, concatenates real and imaginary parts on dim 1 and swaps
        the last two axes (presumably ``[B, 2, F, T] -> [B, 2, T, F]`` --
        confirm against ``local.STFT``).
        """
        spec = self.stft(wav.unsqueeze(1))[-1]
        stacked = torch.cat([spec.real, spec.imag], dim=1)
        return stacked.permute(0, 1, 3, 2).contiguous()

    def _to_wave(self, ri):
        """Inverse of ``_to_ri``: ``[B, 2, T, F]`` real/imag -> iSTFT output."""
        real = ri[:, 0].permute(0, 2, 1).contiguous()
        imag = ri[:, 1].permute(0, 2, 1).contiguous()
        return self.istft((real, imag), input_type="real_imag")

    def forward(self, mix, ref, return_attn=False, return_no_mask=False):
        """Full-sequence (training) forward.

        Inputs are 1D waveforms ``[B, T]`` at 8 kHz. The model is causal up to
        the L=2 STFT-frame look-ahead built into the encoder/decoder kernels;
        no additional masking is required during training.

        Args:
            mix: mixture waveform ``[B, T]``.
            ref: reference (enrollment) waveform ``[B, T_ref]``.
            return_attn: also return the attention produced by ``CMHA``.
            return_no_mask: also return the TSE waveform without the PVAD mask.

        Returns:
            ``(Xtgt, Ptgt)``, followed by ``attn`` if ``return_attn`` and then
            by the unmasked waveform if ``return_no_mask``. ``Xtgt`` is the
            masked iSTFT output and ``Ptgt`` the raw PVAD decoder output.
        """
        Em = self.encoder(self._to_ri(mix))
        Er = self.encoder(self._to_ri(ref))

        attn = None
        if return_attn:
            Espk, attn = self.cmha(Em, Er, return_attn=True)
        else:
            Espk = self.cmha(Em, Er)

        Eo = self.separator(torch.cat([Em, Espk], dim=1))

        Dtse = self.tse_decoder(Eo)
        Ptgt = self.pvad_decoder(Eo)
        Pi = self.interaction(Ptgt)

        # Align the mask with the TSE frame count: the interaction module
        # emits one extra frame, which is cropped; zero-pad if ever shorter.
        n_frames = Dtse.shape[2]
        surplus = Pi.shape[-1] - n_frames
        if surplus < 0:
            Pi = F.pad(Pi, (0, -surplus))
        elif surplus > 0:
            Pi = Pi[..., :n_frames]

        # [B, 1, T, 1] broadcasts over the real/imag and frequency axes.
        Xtgt = self._to_wave(Dtse * Pi.unsqueeze(-1))

        outputs = [Xtgt, Ptgt]
        if return_attn:
            outputs.append(attn)
        if return_no_mask:
            outputs.append(self._to_wave(Dtse))
        return tuple(outputs)
|
|