"""Optimized stateful inference path for Streaming USEF-TP.

This module keeps the full-sequence ``forward`` behavior compatible with
``model_streaming_usef_tp.py`` while adding an explicit chunk-by-chunk
``stream_step`` API. The optimized path caches the reference-side CMHA
tensors, the rolling STFT/decoder/iSTFT context, the temporal LSTM state, and
the GridNet self-attention K/V history.
"""
|
|
| import copy |
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from local.CMHA import CMHA |
| from local.STFT import STFT, iSTFT |
| from local.StreamingGridNetV2Block import StreamingGridNetV2Block |
| from model_streaming_usef_tp import InteractionModule, PVADDecoder |
|
|
|
|
class OptimizedStreamingGridNetV2Block(StreamingGridNetV2Block):
    """Streaming step extension for ``StreamingGridNetV2Block``.

    The step path specializes the common real-time configuration
    ``emb_ks == emb_hs == 1``. The inherited ``forward`` remains available for
    full-sequence training/evaluation and state-dict compatibility.
    """

    @staticmethod
    def _check_max_attention_frames(max_attention_frames):
        """Reject attention-cache limits that cannot hold the current frame.

        ``None`` means "unbounded". Any other value must be at least 1:
        slicing with ``-0:`` returns the *whole* cache (so a limit of 0 would
        silently behave as unbounded) and a negative limit slices from the
        wrong end, which can leave no keys to attend to.

        Raises:
            ValueError: if ``max_attention_frames`` is not ``None`` and < 1.
        """
        if max_attention_frames is not None and max_attention_frames < 1:
            raise ValueError(
                "max_attention_frames must be None or >= 1, "
                f"got {max_attention_frames}"
            )

    def init_stream_state(self, batch_size, n_freqs, device, dtype=None,
                          max_attention_frames=None):
        """Create the per-block streaming state.

        Args:
            batch_size: number of streams processed in parallel.
            n_freqs: number of frequency bins per frame.
            device: device on which the state tensors are allocated.
            dtype: dtype of the state tensors; defaults to the dtype of the
                block's first parameter.
            max_attention_frames: optional upper bound on the number of cached
                attention frames (``None`` keeps the full history).
        Returns:
            A dict holding the inter-RNN hidden/cell state, the (initially
            empty) attention K/V cache, and the cache limit.
        Raises:
            ValueError: if ``max_attention_frames`` is not ``None`` and < 1.
        """
        self._check_max_attention_frames(max_attention_frames)
        dtype = dtype or next(self.parameters()).dtype
        # One recurrent state row per (batch, frequency) pair.
        # NOTE(review): the leading 1 assumes a single-layer, unidirectional
        # ``inter_rnn`` -- confirm against the base class.
        rnn_state_shape = (1, batch_size * n_freqs, self.inter_rnn.hidden_size)
        return {
            "inter_h": torch.zeros(rnn_state_shape, device=device, dtype=dtype),
            "inter_c": torch.zeros(rnn_state_shape, device=device, dtype=dtype),
            "attn_k": None,
            "attn_v": None,
            "max_attention_frames": max_attention_frames,
        }

    def stream_step(self, x, state):
        """Process one mature time frame.

        Args:
            x: ``[B, C, 1, F]`` frame.
            state: state from ``init_stream_state``; it is updated in place
                and also returned.
        Returns:
            Tuple ``(out, updated_state)`` with ``out`` shaped ``[B, C, 1, F]``.
        Raises:
            NotImplementedError: if ``emb_ks`` or ``emb_hs`` is not 1.
            ValueError: if ``x`` holds more than one time frame, or if the
                state carries an invalid ``max_attention_frames``.
        """
        if self.emb_ks != 1 or self.emb_hs != 1:
            raise NotImplementedError(
                "Optimized stream_step currently requires emb_ks == emb_hs == 1."
            )

        B, C, T, Q = x.shape
        if T != 1:
            raise ValueError(f"stream_step expects one time frame, got T={T}")

        # Validate before touching any state so a bad limit cannot leave the
        # recurrent state half-updated.
        max_frames = state.get("max_attention_frames")
        self._check_max_attention_frames(max_frames)

        # ---- intra-frame path: sequence axis is frequency ----
        frame = x.permute(0, 2, 3, 1)  # [B, 1, Q, C]
        residual = frame
        intra = self.intra_norm(residual)
        intra = intra.reshape(B, Q, C)
        intra, _ = self.intra_rnn(intra)
        intra = self.intra_linear(intra)
        intra = intra.reshape(B, 1, Q, C)
        intra = intra + residual

        # ---- inter-frame path: one time step per frequency bin ----
        intra = intra.transpose(1, 2)  # [B, Q, 1, C]
        residual = intra
        inter = self.inter_norm(residual)
        inter = inter.reshape(B * Q, 1, C)
        # The recurrent state carried across calls replaces the time loop of
        # the full-sequence forward pass.
        inter, (h, c) = self.inter_rnn(inter, (state["inter_h"], state["inter_c"]))
        state["inter_h"] = h
        state["inter_c"] = c
        inter = self.inter_linear(inter)
        inter = inter.reshape(B, Q, 1, C)
        inter = inter + residual

        inter = inter.permute(0, 3, 2, 1).contiguous()  # [B, C, 1, Q]

        # ---- self-attention over the cached frame history ----
        q = self["attn_norm_Q"](self["attn_conv_Q"](inter))
        k = self["attn_norm_K"](self["attn_conv_K"](inter))
        v = self["attn_norm_V"](self["attn_conv_V"](inter))

        # Fold heads into the batch axis and flatten each frame to one vector.
        q = q.reshape(-1, *q.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        k = k.reshape(-1, *k.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        v = v.reshape(-1, *v.shape[2:]).transpose(1, 2)
        v_shape = v.shape  # remembered to un-flatten the attention output
        v = v.flatten(start_dim=2)

        if state["attn_k"] is None:
            k_cache, v_cache = k, v
        else:
            k_cache = torch.cat([state["attn_k"], k], dim=1)
            v_cache = torch.cat([state["attn_v"], v], dim=1)

        # Keep only the most recent ``max_frames`` frames (always >= 1 here,
        # so the current frame is never dropped).
        if max_frames is not None and k_cache.shape[1] > max_frames:
            k_cache = k_cache[:, -max_frames:, :].contiguous()
            v_cache = v_cache[:, -max_frames:, :].contiguous()

        state["attn_k"] = k_cache
        state["attn_v"] = v_cache

        # The cache only contains past and current frames, so no causal mask
        # is needed for the single query frame.
        attn = F.scaled_dot_product_attention(q, k_cache, v_cache, is_causal=False)
        attn = attn.reshape(v_shape).transpose(1, 2)

        head_dim = attn.shape[1]
        attn = attn.contiguous().reshape(B, self.n_head * head_dim, 1, Q)
        attn = self["attn_concat_proj"](attn)
        return attn + inter, state
|
|
|
|
class Streaming_USEF_TP_Optimized(nn.Module):
    """Streaming USEF-TP with cached, stateful PyTorch inference.

    ``forward`` is the full-sequence path. The streaming path consists of
    ``prepare_reference`` (run once per reference utterance), ``init_stream``
    (creates the rolling state) and ``stream_step`` / ``stream`` (consume
    mixture audio hop by hop).
    """

    def __init__(self, hidden_channels, n_head, emb_dim, emb_ks, emb_hs,
                 num_layers=6, n_fft=128, hop_length=64, win_length=128,
                 cmha_approx_qk_dim=512, eps=1e-5,
                 max_attention_frames=None):
        """Build the encoder, CMHA fusion, separator stack and decoders.

        Args:
            hidden_channels: hidden size passed to each separator block.
            n_head: number of attention heads (CMHA and separator blocks).
            emb_dim: encoder embedding channels; separator blocks operate on
                ``2 * emb_dim`` channels (mixture + speaker embedding).
            emb_ks: ``emb_ks`` of the separator blocks.
            emb_hs: ``emb_hs`` of the separator blocks.
            num_layers: number of separator blocks.
            n_fft: FFT size.
            hop_length: hop size in samples.
            win_length: analysis/synthesis window length in samples.
            cmha_approx_qk_dim: ``approx_qk_dim`` passed to ``CMHA``.
            eps: ``eps`` passed to ``CMHA``.
            max_attention_frames: optional bound on the per-block streaming
                attention cache (``None`` keeps the full history).
        """
        super().__init__()
        self.num_layers = num_layers
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_freqs = n_fft // 2 + 1
        self.emb_dim = emb_dim
        self.max_attention_frames = max_attention_frames

        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        # Window used by the hand-rolled streaming STFT/iSTFT; non-persistent
        # so state dicts stay compatible with the non-streaming model.
        self.register_buffer("stream_window", torch.hann_window(win_length), persistent=False)

        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)

        self.encoder = nn.Conv2d(2, emb_dim, ks, padding=padding)
        self.cmha = CMHA(
            emb_dim=emb_dim, n_freqs=self.n_freqs, n_head=n_head,
            approx_qk_dim=cmha_approx_qk_dim, eps=eps,
        )

        # Each iteration constructs an independent block, so no copy is needed.
        self.separator = nn.ModuleList([
            OptimizedStreamingGridNetV2Block(
                2 * emb_dim, emb_ks, emb_hs, self.n_freqs, hidden_channels,
                n_head, approx_qk_dim=512, activation="prelu",
            )
            for _ in range(num_layers)
        ])

        self.tse_decoder = nn.ConvTranspose2d(
            2 * emb_dim, 2, ks, stride=1, padding=padding
        )
        self.pvad_decoder = PVADDecoder(
            in_channels=2 * emb_dim, n_freqs=self.n_freqs, t_ksize=t_ksize
        )
        self.interaction = InteractionModule()

    # ------------------------------------------------------------------
    # Full-sequence path
    # ------------------------------------------------------------------

    @staticmethod
    def _complex_to_ri(spec):
        """Stack real/imag parts on the channel axis and swap the last two axes."""
        return torch.cat([spec.real, spec.imag], dim=1).permute(0, 1, 3, 2).contiguous()

    def _ri_to_wave(self, spec_ri):
        """Run the full-sequence iSTFT on a 2-channel (real, imag) spectrogram."""
        real = spec_ri[:, 0, :, :].permute(0, 2, 1).contiguous()
        imag = spec_ri[:, 1, :, :].permute(0, 2, 1).contiguous()
        return self.istft((real, imag), input_type="real_imag")

    def forward(self, mix, ref, return_attn=False, return_no_mask=False):
        """Full-sequence compatibility path.

        Args:
            mix: mixture waveform; a channel axis is inserted at dim 1.
            ref: reference (enrollment) waveform, handled like ``mix``.
            return_attn: also return the CMHA attention map.
            return_no_mask: also return the waveform decoded without the
                PVAD-derived mask.
        Returns:
            ``(Xtgt, Ptgt)`` followed by ``attn`` (if ``return_attn``) and then
            the unmasked waveform (if ``return_no_mask``).
        """
        mix_c = self.stft(mix.unsqueeze(1))[-1]
        ref_c = self.stft(ref.unsqueeze(1))[-1]

        Em = self.encoder(self._complex_to_ri(mix_c))
        Er = self.encoder(self._complex_to_ri(ref_c))

        attn = None
        if return_attn:
            Espk, attn = self.cmha(Em, Er, return_attn=True)
        else:
            Espk = self.cmha(Em, Er)

        Eo = torch.cat([Em, Espk], dim=1)
        for block in self.separator:
            Eo = block(Eo)

        Dtse = self.tse_decoder(Eo)
        Ptgt = self.pvad_decoder(Eo)
        Pi = self.interaction(Ptgt)

        # Align the interaction output with the decoder's frame count.
        n_frames = Dtse.shape[2]
        if Pi.shape[-1] < n_frames:
            Pi = F.pad(Pi, (0, n_frames - Pi.shape[-1]))
        elif Pi.shape[-1] > n_frames:
            Pi = Pi[..., :n_frames]

        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])
        outputs = [self._ri_to_wave(Dtse * mask), Ptgt]
        if return_attn:
            outputs.append(attn)
        if return_no_mask:
            outputs.append(self._ri_to_wave(Dtse))
        return tuple(outputs)

    # ------------------------------------------------------------------
    # Streaming state
    # ------------------------------------------------------------------

    def reset_stream_state(self, batch_size, device, dtype=None):
        """Return a fresh streaming state (alias of ``init_stream``)."""
        return self.init_stream(batch_size, device, dtype=dtype)

    def init_stream(self, batch_size, device, dtype=None):
        """Allocate the rolling state consumed by ``stream_step``/``stream``.

        Args:
            batch_size: number of parallel streams.
            device: device for all state tensors.
            dtype: dtype for all state tensors; defaults to the dtype of the
                model's first parameter.
        Returns:
            Dict with the sample/input buffers, one frame of zero left-context
            for the encoder and decoder convolutions, the previous PVAD and
            interaction frames, the iSTFT overlap-add accumulators, the
            per-block separator states, and a mature-frame counter.
        Raises:
            NotImplementedError: if ``win_length != n_fft``. The streaming
                STFT/iSTFT multiplies ``n_fft``-long frames by the
                ``win_length``-long window, which only lines up when both
                lengths match.
        """
        if self.win_length != self.n_fft:
            raise NotImplementedError(
                "Streaming inference currently requires win_length == n_fft, "
                f"got win_length={self.win_length}, n_fft={self.n_fft}."
            )

        dtype = dtype or next(self.parameters()).dtype

        def zeros(*shape):
            return torch.zeros(*shape, device=device, dtype=dtype)

        return {
            "sample_buffer": zeros(batch_size, self.win_length),
            "input_buffer": zeros(batch_size, 0),
            "encoder_frames": zeros(batch_size, 2, 1, self.n_freqs),
            "decoder_eo": zeros(batch_size, 2 * self.emb_dim, 1, self.n_freqs),
            "pvad_conv_prev": zeros(batch_size, self.n_freqs, 1),
            "interaction_prev": zeros(batch_size, 1, 1),
            "istft_ola": zeros(batch_size, self.n_fft),
            "istft_norm": zeros(batch_size, self.n_fft),
            "separator": [
                block.init_stream_state(
                    batch_size, self.n_freqs, device, dtype=dtype,
                    max_attention_frames=self.max_attention_frames,
                )
                for block in self.separator
            ],
            "frames_seen": 0,
        }

    @torch.no_grad()
    def prepare_reference(self, ref):
        """Precompute reference encoding and CMHA K/V tensors.

        Args:
            ref: waveform ``[B, T]`` or ``[B, 1, T]``.
        Returns:
            A cache dictionary consumed by ``stream_step``.
        Raises:
            ValueError: if ``ref`` is neither 2-D nor 3-D.
        """
        if ref.dim() == 2:
            ref = ref.unsqueeze(1)
        elif ref.dim() != 3:
            raise ValueError(f"Expected ref with shape [B,T] or [B,1,T], got {tuple(ref.shape)}")

        Er = self.encoder(self._complex_to_ri(self.stft(ref)[-1]))
        B = Er.shape[0]
        Lr = Er.shape[-2]  # number of reference frames
        n_head = self.cmha.n_head

        K = self.cmha["attn_norm_K"](self.cmha["attn_conv_K"](Er))
        V = self.cmha["attn_norm_V"](self.cmha["attn_conv_V"](Er))

        # Fold heads into the batch axis, then lay K out as
        # [B*heads, features, Lr] so ``Q @ K`` needs no transpose per step.
        K = K.reshape(-1, *K.shape[2:])
        K = K.transpose(2, 3).contiguous().reshape(B * n_head, -1, Lr)
        # V is laid out as [B*heads, Lr, features].
        V = V.reshape(-1, *V.shape[2:])
        V = V.transpose(1, 2).flatten(start_dim=2).contiguous()

        return {
            "K": K,
            "V": V,
            "Lr": Lr,
            "batch_size": B,
            "qk_dim": K.shape[1],
        }

    # ------------------------------------------------------------------
    # Streaming building blocks (one frame each)
    # ------------------------------------------------------------------

    def _cmha_stream_step(self, Em, ref_cache, return_attn=False):
        """Attend one mixture frame to the cached reference K/V.

        Args:
            Em: encoded mixture frame with a single time step (dim 2).
            ref_cache: output of ``prepare_reference``.
            return_attn: also return the detached attention weights reshaped
                to ``[B, n_head, 1, Lr]``.
        Raises:
            ValueError: if ``Em`` holds more than one frame or its batch size
                differs from the reference cache.
        """
        B, _, Lm, _ = Em.shape
        if Lm != 1:
            raise ValueError(f"CMHA stream step expects one frame, got Lm={Lm}")
        if ref_cache["batch_size"] != B:
            raise ValueError(
                f"Reference cache batch size {ref_cache['batch_size']} does not match chunk batch {B}"
            )

        n_head = self.cmha.n_head
        Q = self.cmha["attn_norm_Q"](self.cmha["attn_conv_Q"](Em))
        Q = Q.reshape(-1, *Q.shape[2:]).transpose(1, 2).flatten(start_dim=2)

        scores = torch.matmul(Q, ref_cache["K"]) / math.sqrt(ref_cache["qk_dim"])
        attn = F.softmax(scores, dim=2)
        out = torch.matmul(attn, ref_cache["V"])

        out = out.reshape(B * n_head, 1, -1, self.n_freqs).transpose(1, 2)
        head_dim = out.shape[1]
        out = out.contiguous().reshape(B, n_head * head_dim, 1, self.n_freqs)
        out = self.cmha["attn_concat_proj"](out)

        if return_attn:
            return out, attn.reshape(B, n_head, 1, ref_cache["Lr"]).detach()
        return out

    def _stft_stream_frame(self, chunk, state):
        """Slide the sample buffer by one hop and return its windowed spectrum.

        Args:
            chunk: ``[B, H]`` or ``[B, 1, H]`` with ``H == hop_length``.
            state: streaming state; ``sample_buffer`` is replaced.
        Returns:
            ``(frame, state)`` with ``frame`` shaped ``[B, 2, 1, n_freqs]``
            (real and imaginary parts on the channel axis).
        Raises:
            ValueError: if ``chunk`` is not mono or not exactly one hop long.
        """
        if chunk.dim() == 3:
            if chunk.shape[1] != 1:
                raise ValueError("stream_step expects mono chunks shaped [B,H] or [B,1,H]")
            chunk = chunk.squeeze(1)
        if chunk.dim() != 2:
            raise ValueError(f"stream_step expects chunk [B,H] or [B,1,H], got {tuple(chunk.shape)}")
        if chunk.shape[-1] != self.hop_length:
            raise ValueError(
                f"stream_step expects {self.hop_length} samples per chunk, got {chunk.shape[-1]}"
            )

        # Drop the oldest hop, append the newest.
        kept = state["sample_buffer"][:, chunk.shape[-1]:]
        state["sample_buffer"] = torch.cat([kept, chunk], dim=-1)

        window = self.stream_window.to(device=chunk.device, dtype=chunk.dtype)
        spectrum = torch.fft.rfft(state["sample_buffer"] * window, n=self.n_fft)
        frame = torch.stack([spectrum.real, spectrum.imag], dim=1).unsqueeze(2)
        return frame, state

    def _encoder_stream_step(self, stft_frame, state):
        """Encode the frame that now has both temporal neighbours available.

        The encoder kernel spans three frames, so the output for a frame can
        only be produced once the next frame has arrived. Returns
        ``(None, state)`` until three frames of context exist.
        """
        frames = torch.cat([state["encoder_frames"], stft_frame], dim=2)
        state["encoder_frames"] = frames[:, :, -2:, :].contiguous()
        if frames.shape[2] < 3:
            return None, state
        # Keep the centre frame only; the outer two see zero padding.
        Em = self.encoder(frames[:, :, -3:, :])[:, :, 1:2, :]
        return Em, state

    def _decoder_stream_step(self, Eo, state):
        """Decode one frame and apply the PVAD-derived mask.

        Mirrors ``_encoder_stream_step``: a three-frame window is decoded and
        the centre frame kept. Returns ``(None, None, state)`` until enough
        separator frames have been collected.

        Returns:
            ``(masked_frame, Ptgt, state)``.
        """
        frames = torch.cat([state["decoder_eo"], Eo], dim=2)
        state["decoder_eo"] = frames[:, :, -2:, :].contiguous()
        if frames.shape[2] < 3:
            return None, None, state

        window = frames[:, :, -3:, :]
        Dtse = self.tse_decoder(window)[:, :, 1:2, :]

        # PVAD head: the 1-D conv sees the previous and the current feature.
        pvad_2d = self.pvad_decoder.tconv2d(window)[:, :, 1:2, :]
        pvad_feat = pvad_2d.squeeze(1).transpose(1, 2)
        pvad_in = torch.cat([state["pvad_conv_prev"], pvad_feat], dim=-1)
        Ptgt = self.pvad_decoder.conv1d(pvad_in)
        state["pvad_conv_prev"] = pvad_feat

        # Interaction module, again fed with (previous, current).
        p = torch.sigmoid(Ptgt)
        interaction_in = torch.cat([state["interaction_prev"], p], dim=-1)
        Pi = F.relu(self.interaction.tconv1d(interaction_in))[..., 1:2]
        state["interaction_prev"] = p

        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])
        return Dtse * mask, Ptgt, state

    def _istft_stream_step(self, Xf, state):
        """Overlap-add one spectral frame and emit the oldest finished hop.

        Args:
            Xf: ``[B, 2, 1, n_freqs]`` real/imag frame.
            state: streaming state; the overlap-add and normalisation
                accumulators are updated and shifted by one hop.
        Returns:
            ``(chunk, state)`` with ``chunk`` shaped ``[B, hop_length]``.
        """
        hop = self.hop_length
        spectrum = torch.complex(Xf[:, 0, 0, :], Xf[:, 1, 0, :])
        window = self.stream_window.to(device=Xf.device, dtype=Xf.dtype)
        frame = torch.fft.irfft(spectrum, n=self.n_fft) * window

        state["istft_ola"][:, :self.n_fft] += frame
        state["istft_norm"][:, :self.n_fft] += window.square().unsqueeze(0)

        # No later frame overlaps the first hop, so it is final. The clamp
        # guards the division where the accumulated window energy is ~0.
        denom = state["istft_norm"][:, :hop].clamp_min(1e-8)
        chunk = state["istft_ola"][:, :hop] / denom

        tail = torch.zeros_like(state["istft_ola"][:, :hop])
        state["istft_ola"] = torch.cat([state["istft_ola"][:, hop:], tail], dim=-1)
        state["istft_norm"] = torch.cat([state["istft_norm"][:, hop:], tail], dim=-1)
        return chunk, state

    def _warmup_placeholders(self, chunk):
        """Return zero ``(audio, pvad)`` tensors matching ``chunk``'s batch/device/dtype."""
        batch_size = chunk.shape[0]
        audio = torch.zeros(
            batch_size, self.hop_length, device=chunk.device, dtype=chunk.dtype
        )
        pvad = torch.zeros(batch_size, 1, 1, device=chunk.device, dtype=chunk.dtype)
        return audio, pvad

    # ------------------------------------------------------------------
    # Streaming entry points
    # ------------------------------------------------------------------

    def _stream_step_impl(self, chunk, state, ref_cache, return_attn=False):
        """Run one hop-sized streaming step and report output maturity.

        Returns:
            ``(audio_chunk, state, pvad_frame, attn, ready)``. ``ready`` is
            false during encoder/decoder warm-up, when returned audio is only a
            placeholder used by the low-level ``stream_step`` compatibility API.
        """
        stft_frame, state = self._stft_stream_frame(chunk, state)
        Em, state = self._encoder_stream_step(stft_frame, state)
        if Em is None:
            # Encoder still warming up: placeholders are built only here so
            # steady-state steps do not pay for unused allocations.
            zero_audio, zero_pvad = self._warmup_placeholders(chunk)
            return zero_audio, state, zero_pvad, None, False

        attn = None
        if return_attn:
            Espk, attn = self._cmha_stream_step(Em, ref_cache, return_attn=True)
        else:
            Espk = self._cmha_stream_step(Em, ref_cache)

        Eo = torch.cat([Em, Espk], dim=1)
        block_states = state["separator"]
        for idx, block in enumerate(self.separator):
            Eo, block_states[idx] = block.stream_step(Eo, block_states[idx])

        Xf, Ptgt, state = self._decoder_stream_step(Eo, state)
        if Xf is None:
            # Decoder still warming up.
            zero_audio, zero_pvad = self._warmup_placeholders(chunk)
            return zero_audio, state, zero_pvad, attn, False

        audio, state = self._istft_stream_step(Xf, state)
        state["frames_seen"] += 1
        return audio, state, Ptgt, attn, True

    @torch.no_grad()
    def stream_step(self, chunk, state, ref_cache, return_attn=False):
        """Run one streaming step on exactly ``hop_length`` samples.

        This low-level API always returns one hop of audio, using zeros during
        warm-up. Prefer ``stream`` for application code that feeds arbitrary
        audio lengths and only wants mature output.

        Returns:
            ``(audio_chunk, state, pvad_frame)`` or
            ``(audio_chunk, state, pvad_frame, attn)`` when ``return_attn=True``.
        """
        audio, state, Ptgt, attn, _ = self._stream_step_impl(
            chunk, state, ref_cache, return_attn=return_attn
        )
        if return_attn:
            return audio, state, Ptgt, attn
        return audio, state, Ptgt

    @torch.no_grad()
    def stream(self, audio, state, ref_cache, return_attn=False):
        """Accept any number of samples and return only mature streaming output.

        ``audio`` may be shaped ``[B, N]`` or ``[B, 1, N]``. Samples that do not
        complete a hop are buffered in ``state["input_buffer"]`` for the next
        call. During cold start, this method returns an empty audio tensor until
        the STFT/encoder/decoder alignment has enough context.

        Returns:
            ``(audio_out, state, pvad_frames)`` or
            ``(audio_out, state, pvad_frames, attn_frames)`` when
            ``return_attn=True``. ``audio_out`` has shape ``[B, M]`` where
            ``M`` may be zero.
        Raises:
            ValueError: if ``audio`` is not mono ``[B,N]``/``[B,1,N]`` or its
                batch size differs from the buffered samples.
        """
        if audio.dim() == 3:
            if audio.shape[1] != 1:
                raise ValueError("stream expects mono audio shaped [B,N] or [B,1,N]")
            audio = audio.squeeze(1)
        if audio.dim() != 2:
            raise ValueError(f"stream expects audio [B,N] or [B,1,N], got {tuple(audio.shape)}")

        batch_size = audio.shape[0]
        hop = self.hop_length

        buffered = state.get("input_buffer")
        if buffered is None:
            buffered = torch.zeros(batch_size, 0, device=audio.device, dtype=audio.dtype)
        if buffered.shape[0] != batch_size:
            raise ValueError(
                f"Buffered batch size {buffered.shape[0]} does not match audio batch {audio.shape[0]}"
            )

        pending = torch.cat([buffered.to(device=audio.device, dtype=audio.dtype), audio], dim=-1)
        n_hops = pending.shape[-1] // hop
        # Whatever does not fill a whole hop waits for the next call.
        state["input_buffer"] = pending[:, n_hops * hop:].contiguous()

        chunks, pvads, attns = [], [], []
        for start in range(0, n_hops * hop, hop):
            out, state, pvad, attn, ready = self._stream_step_impl(
                pending[:, start:start + hop], state, ref_cache, return_attn=return_attn
            )
            if not ready:
                continue  # warm-up placeholder, not real output
            chunks.append(out)
            pvads.append(pvad)
            if return_attn:
                attns.append(attn)

        if chunks:
            audio_out = torch.cat(chunks, dim=-1)
            pvad_out = torch.cat(pvads, dim=-1)
        else:
            audio_out = torch.zeros(batch_size, 0, device=audio.device, dtype=audio.dtype)
            pvad_out = torch.zeros(batch_size, 1, 0, device=audio.device, dtype=audio.dtype)

        if return_attn:
            attn_out = torch.cat(attns, dim=2) if attns else None
            return audio_out, state, pvad_out, attn_out
        return audio_out, state, pvad_out
|
|
|
|
# Alternative public name bound to the same class object.
OptimizedStreaming_USEF_TP = Streaming_USEF_TP_Optimized
|
|