"""Optimized stateful inference path for Streaming USEF-TP.

This module keeps the full-sequence ``forward`` behavior compatible with
``model_streaming_usef_tp.py`` while adding an explicit chunk-by-chunk
``stream_step`` API. The optimized path caches reference-side CMHA tensors,
rolling STFT/decoder/iSTFT context, temporal LSTM state, and GridNet
self-attention K/V history.
"""

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from local.CMHA import CMHA
from local.STFT import STFT, iSTFT
from local.StreamingGridNetV2Block import StreamingGridNetV2Block
from model_streaming_usef_tp import InteractionModule, PVADDecoder


def _check_max_attention_frames(max_attention_frames):
    """Reject attention-window sizes that the cache slicing cannot honour."""
    # ``cache[:, -0:]`` keeps *everything* and a negative size drops the
    # oldest frames instead of capping the window, so only None or >= 1 work.
    if max_attention_frames is not None and max_attention_frames < 1:
        raise ValueError(
            "max_attention_frames must be None or >= 1, "
            f"got {max_attention_frames}"
        )


class OptimizedStreamingGridNetV2Block(StreamingGridNetV2Block):
    """Streaming step extension for ``StreamingGridNetV2Block``.

    The step path specializes the common real-time configuration
    ``emb_ks == emb_hs == 1``. The inherited ``forward`` remains available for
    full-sequence training/evaluation and state-dict compatibility.
    """

    def init_stream_state(self, batch_size, n_freqs, device, dtype=None,
                          max_attention_frames=None):
        """Create the per-block streaming state.

        Args:
            batch_size: number of streams processed together (``B``).
            n_freqs: number of frequency bins per frame (``F``).
            device: device on which the state tensors are allocated.
            dtype: state dtype; defaults to the block's parameter dtype.
            max_attention_frames: optional cap on the number of cached
                self-attention frames (the current frame included). ``None``
                keeps the whole history.

        Returns:
            A dict with the inter-frame LSTM state ``inter_h``/``inter_c``
            (one sequence per batch item and frequency bin, i.e. ``B * F``)
            and an initially empty self-attention K/V cache.

        Raises:
            ValueError: if ``max_attention_frames`` is not None and < 1.
        """
        _check_max_attention_frames(max_attention_frames)
        dtype = dtype or next(self.parameters()).dtype
        hidden = self.inter_rnn.hidden_size
        state = {
            "inter_h": torch.zeros(
                1, batch_size * n_freqs, hidden, device=device, dtype=dtype
            ),
            "inter_c": torch.zeros(
                1, batch_size * n_freqs, hidden, device=device, dtype=dtype
            ),
            "attn_k": None,
            "attn_v": None,
            "max_attention_frames": max_attention_frames,
        }
        return state

    def stream_step(self, x, state):
        """Process one mature time frame.

        Args:
            x: ``[B, C, 1, F]`` frame.
            state: state from ``init_stream_state``; it is updated in place.

        Returns:
            Tuple ``(out, updated_state)`` with ``out`` shaped ``[B, C, 1, F]``.

        Raises:
            NotImplementedError: if ``emb_ks`` or ``emb_hs`` is not 1.
            ValueError: if ``x`` contains more than one time frame.
        """
        if self.emb_ks != 1 or self.emb_hs != 1:
            raise NotImplementedError(
                "Optimized stream_step currently requires emb_ks == emb_hs == 1."
            )
        B, C, T, Q = x.shape
        if T != 1:
            raise ValueError(f"stream_step expects one time frame, got T={T}")

        # Intra-frame RNN: the F frequency bins of this frame form the sequence.
        frame = x.permute(0, 2, 3, 1)  # [B, 1, F, C]
        input_ = frame
        intra = self.intra_norm(input_)
        intra = intra.reshape(B, Q, C)
        intra, _ = self.intra_rnn(intra)
        intra = self.intra_linear(intra)
        intra = intra.reshape(B, 1, Q, C)
        intra = intra + input_
        intra = intra.transpose(1, 2)  # [B, F, 1, C]

        # Inter-frame RNN: a single time step per (batch, frequency) sequence,
        # carrying (h, c) across calls so the temporal recurrence is causal.
        input_ = intra
        inter = self.inter_norm(input_)
        inter = inter.reshape(B * Q, 1, C)
        inter, (h, c) = self.inter_rnn(
            inter, (state["inter_h"], state["inter_c"])
        )
        state["inter_h"] = h
        state["inter_c"] = c
        inter = self.inter_linear(inter)
        inter = inter.reshape(B, Q, 1, C)
        inter = inter + input_
        inter = inter.permute(0, 3, 2, 1).contiguous()  # [B, C, 1, F]

        # Self-attention of the current frame over the cached K/V history.
        q = self["attn_norm_Q"](self["attn_conv_Q"](inter))
        k = self["attn_norm_K"](self["attn_conv_K"](inter))
        v = self["attn_norm_V"](self["attn_conv_V"](inter))

        # Fold the head axis into the batch and flatten (channel, freq) into
        # one feature axis; presumably yields [B*n_head, 1, E*F] — verify.
        q = q.reshape(-1, *q.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        k = k.reshape(-1, *k.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        v = v.reshape(-1, *v.shape[2:]).transpose(1, 2)
        v_shape = v.shape  # kept to undo the flatten after attention
        v = v.flatten(start_dim=2)

        # Append this frame to the time-axis (dim 1) K/V cache.
        if state["attn_k"] is None:
            k_cache = k
            v_cache = v
        else:
            k_cache = torch.cat([state["attn_k"], k], dim=1)
            v_cache = torch.cat([state["attn_v"], v], dim=1)

        # Optionally keep only the most recent ``max_frames`` cached frames.
        max_frames = state.get("max_attention_frames")
        if max_frames is not None and k_cache.shape[1] > max_frames:
            k_cache = k_cache[:, -max_frames:, :].contiguous()
            v_cache = v_cache[:, -max_frames:, :].contiguous()
        state["attn_k"] = k_cache
        state["attn_v"] = v_cache

        # The query is the newest frame and the cache holds only past/current
        # frames, so no causal mask is needed here.
        attn = F.scaled_dot_product_attention(q, k_cache, v_cache, is_causal=False)
        attn = attn.reshape(v_shape).transpose(1, 2)
        head_dim = attn.shape[1]
        attn = attn.contiguous().reshape(B, self.n_head * head_dim, 1, Q)
        attn = self["attn_concat_proj"](attn)
        return attn + inter, state


class Streaming_USEF_TP_Optimized(nn.Module):
    """Streaming USEF-TP with cached, stateful PyTorch inference."""

    def __init__(self, hidden_channels, n_head, emb_dim, emb_ks, emb_hs,
                 num_layers=6, n_fft=128, hop_length=64, win_length=128,
                 cmha_approx_qk_dim=512, eps=1e-5, max_attention_frames=None):
        """Build the model.

        ``max_attention_frames`` caps the GridNet self-attention history used
        by the streaming path (``None`` = unbounded) and must be >= 1 when
        given. Other arguments configure the STFT, encoder, CMHA and the
        ``num_layers`` separator blocks.
        """
        super().__init__()
        _check_max_attention_frames(max_attention_frames)
        self.num_layers = num_layers
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_freqs = n_fft // 2 + 1
        self.emb_dim = emb_dim
        self.max_attention_frames = max_attention_frames

        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        # Analysis/synthesis window for the streaming STFT/iSTFT; a
        # non-persistent buffer so it follows ``.to()`` but stays out of the
        # state dict.
        self.register_buffer(
            "stream_window", torch.hann_window(win_length), persistent=False
        )

        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.encoder = nn.Conv2d(2, emb_dim, ks, padding=padding)
        self.cmha = CMHA(
            emb_dim=emb_dim,
            n_freqs=self.n_freqs,
            n_head=n_head,
            approx_qk_dim=cmha_approx_qk_dim,
            eps=eps,
        )
        self.separator = nn.ModuleList([
            copy.deepcopy(
                OptimizedStreamingGridNetV2Block(
                    2 * emb_dim,
                    emb_ks,
                    emb_hs,
                    self.n_freqs,
                    hidden_channels,
                    n_head,
                    approx_qk_dim=512,
                    activation="prelu",
                )
            )
            for _ in range(num_layers)
        ])
        self.tse_decoder = nn.ConvTranspose2d(
            2 * emb_dim, 2, ks, stride=1, padding=padding
        )
        self.pvad_decoder = PVADDecoder(
            in_channels=2 * emb_dim, n_freqs=self.n_freqs, t_ksize=t_ksize
        )
        self.interaction = InteractionModule()

    def forward(self, mix, ref, return_attn=False, return_no_mask=False):
        """Full-sequence compatibility path.

        Args:
            mix: mixture waveform; a channel axis is added with
                ``unsqueeze(1)``, so presumably ``[B, T]`` — verify.
            ref: reference waveform with the same layout as ``mix``.
            return_attn: also return the CMHA attention weights.
            return_no_mask: also return the waveform synthesized from the TSE
                decoder output without the interaction mask.

        Returns:
            ``(Xtgt, Ptgt)``, extended with ``attn`` and/or ``Xtgt_nomask``
            (in that order) when requested. ``Xtgt`` is the masked estimate
            and ``Ptgt`` the PVAD decoder output.
        """
        mix = mix.unsqueeze(1)
        ref = ref.unsqueeze(1)
        mix_c = self.stft(mix)[-1]
        ref_c = self.stft(ref)[-1]
        mix_ri = torch.cat([mix_c.real, mix_c.imag], dim=1).permute(0, 1, 3, 2).contiguous()
        ref_ri = torch.cat([ref_c.real, ref_c.imag], dim=1).permute(0, 1, 3, 2).contiguous()

        Em = self.encoder(mix_ri)
        Er = self.encoder(ref_ri)
        if return_attn:
            Espk, attn = self.cmha(Em, Er, return_attn=True)
        else:
            Espk = self.cmha(Em, Er)
        Ef = torch.cat([Em, Espk], dim=1)

        Eo = Ef
        for block in self.separator:
            Eo = block(Eo)

        Dtse = self.tse_decoder(Eo)
        Ptgt = self.pvad_decoder(Eo)
        Pi = self.interaction(Ptgt)

        # Align the interaction gate's length with the decoder's time axis.
        L_m = Dtse.shape[2]
        if Pi.shape[-1] < L_m:
            Pi = F.pad(Pi, (0, L_m - Pi.shape[-1]))
        elif Pi.shape[-1] > L_m:
            Pi = Pi[..., :L_m]
        # Broadcast the per-frame gate over real/imag channels and frequency.
        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])
        Xf = Dtse * mask

        out_r = Xf[:, 0, :, :].permute(0, 2, 1).contiguous()
        out_i = Xf[:, 1, :, :].permute(0, 2, 1).contiguous()
        Xtgt = self.istft((out_r, out_i), input_type="real_imag").unsqueeze(1)

        if return_no_mask:
            out_r_nm = Dtse[:, 0, :, :].permute(0, 2, 1).contiguous()
            out_i_nm = Dtse[:, 1, :, :].permute(0, 2, 1).contiguous()
            Xtgt_nomask = self.istft(
                (out_r_nm, out_i_nm), input_type="real_imag"
            ).unsqueeze(1)

        if return_attn and return_no_mask:
            return Xtgt.squeeze(1), Ptgt, attn, Xtgt_nomask.squeeze(1)
        if return_attn:
            return Xtgt.squeeze(1), Ptgt, attn
        if return_no_mask:
            return Xtgt.squeeze(1), Ptgt, Xtgt_nomask.squeeze(1)
        return Xtgt.squeeze(1), Ptgt

    def reset_stream_state(self, batch_size, device, dtype=None):
        """Alias of ``init_stream``: return a fresh streaming state."""
        return self.init_stream(batch_size, device, dtype=dtype)

    def init_stream(self, batch_size, device, dtype=None):
        """Create a fresh streaming state for ``stream_step``/``stream``.

        Args:
            batch_size: number of streams processed together.
            device: device on which the state tensors are allocated.
            dtype: state dtype; defaults to the model's parameter dtype.

        Returns:
            A dict holding the rolling STFT input buffer, encoder/decoder
            context frames, PVAD/interaction history, iSTFT overlap-add
            accumulators, per-block separator states and a frame counter.

        Raises:
            NotImplementedError: if ``win_length != n_fft`` or
                ``hop_length > win_length``; the streaming STFT/iSTFT below
                assumes a window of exactly ``n_fft`` samples advanced by at
                most one window per hop.
        """
        if self.win_length != self.n_fft:
            raise NotImplementedError(
                "Streaming inference currently requires win_length == n_fft "
                f"(got win_length={self.win_length}, n_fft={self.n_fft})."
            )
        if self.hop_length > self.win_length:
            raise NotImplementedError(
                "Streaming inference requires hop_length <= win_length "
                f"(got hop_length={self.hop_length}, win_length={self.win_length})."
            )
        dtype = dtype or next(self.parameters()).dtype
        # One zero frame of left context, matching the zero padding the
        # 3-frame encoder/decoder kernels see at the start of a sequence.
        zeros_ri = torch.zeros(batch_size, 2, 1, self.n_freqs, device=device, dtype=dtype)
        zeros_eo = torch.zeros(
            batch_size, 2 * self.emb_dim, 1, self.n_freqs, device=device, dtype=dtype
        )
        window_len = self.win_length
        return {
            "sample_buffer": torch.zeros(batch_size, window_len, device=device, dtype=dtype),
            "input_buffer": torch.zeros(batch_size, 0, device=device, dtype=dtype),
            "encoder_frames": zeros_ri,
            "decoder_eo": zeros_eo,
            "pvad_conv_prev": torch.zeros(batch_size, self.n_freqs, 1, device=device, dtype=dtype),
            "interaction_prev": torch.zeros(batch_size, 1, 1, device=device, dtype=dtype),
            "istft_ola": torch.zeros(batch_size, self.n_fft, device=device, dtype=dtype),
            "istft_norm": torch.zeros(batch_size, self.n_fft, device=device, dtype=dtype),
            "separator": [
                block.init_stream_state(
                    batch_size,
                    self.n_freqs,
                    device,
                    dtype=dtype,
                    max_attention_frames=self.max_attention_frames,
                )
                for block in self.separator
            ],
            "frames_seen": 0,
        }

    @torch.no_grad()
    def prepare_reference(self, ref):
        """Precompute reference encoding and CMHA K/V tensors.

        Args:
            ref: waveform ``[B, T]`` or ``[B, 1, T]``.

        Returns:
            A cache dictionary consumed by ``stream_step``/``stream`` with
            ``K`` shaped ``[B*n_head, qk_dim, Lr]``, ``V`` shaped
            ``[B*n_head, Lr, D]``, the reference length ``Lr``, the
            ``batch_size`` and ``qk_dim`` (used for attention scaling).

        Raises:
            ValueError: if ``ref`` is neither 2-D nor 3-D.
        """
        if ref.dim() == 2:
            ref = ref.unsqueeze(1)
        elif ref.dim() != 3:
            raise ValueError(f"Expected ref with shape [B,T] or [B,1,T], got {tuple(ref.shape)}")

        ref_c = self.stft(ref)[-1]
        ref_ri = torch.cat([ref_c.real, ref_c.imag], dim=1).permute(0, 1, 3, 2).contiguous()
        Er = self.encoder(ref_ri)

        K = self.cmha["attn_norm_K"](self.cmha["attn_conv_K"](Er))
        V = self.cmha["attn_norm_V"](self.cmha["attn_conv_V"](Er))
        B = Er.shape[0]
        Lr = Er.shape[-2]
        # Fold heads into the batch axis, then lay K out as [feat, time] so a
        # single matmul with each query frame gives scores over the reference.
        K = K.reshape(-1, *K.shape[2:])
        V = V.reshape(-1, *V.shape[2:])
        K = K.transpose(2, 3).contiguous().reshape(B * self.cmha.n_head, -1, Lr)
        V = V.transpose(1, 2).flatten(start_dim=2).contiguous()
        return {
            "K": K,
            "V": V,
            "Lr": Lr,
            "batch_size": B,
            "qk_dim": K.shape[1],
        }

    def _cmha_stream_step(self, Em, ref_cache, return_attn=False):
        """Cross-attend one encoded mixture frame to the cached reference.

        Args:
            Em: encoded mixture frame ``[B, emb_dim, 1, F]``.
            ref_cache: dict from ``prepare_reference``.
            return_attn: also return the attention weights
                ``[B, n_head, 1, Lr]``.

        Raises:
            ValueError: if ``Em`` has more than one frame or its batch size
                differs from the reference cache's.
        """
        B, _, Lm, _ = Em.shape
        if Lm != 1:
            raise ValueError(f"CMHA stream step expects one frame, got Lm={Lm}")
        if ref_cache["batch_size"] != B:
            raise ValueError(
                f"Reference cache batch size {ref_cache['batch_size']} does not match chunk batch {B}"
            )

        Q = self.cmha["attn_norm_Q"](self.cmha["attn_conv_Q"](Em))
        Q = Q.reshape(-1, *Q.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        attn = torch.matmul(Q, ref_cache["K"]) / math.sqrt(ref_cache["qk_dim"])
        attn = F.softmax(attn, dim=2)  # normalize over reference frames
        out = torch.matmul(attn, ref_cache["V"])
        out = out.reshape(B * self.cmha.n_head, 1, -1, self.n_freqs).transpose(1, 2)
        head_dim = out.shape[1]
        out = out.contiguous().reshape(B, self.cmha.n_head * head_dim, 1, self.n_freqs)
        out = self.cmha["attn_concat_proj"](out)
        if return_attn:
            return out, attn.reshape(B, self.cmha.n_head, 1, ref_cache["Lr"]).detach()
        return out

    def _stft_stream_frame(self, chunk, state):
        """Shift one hop of samples into the analysis buffer and take its STFT.

        Returns:
            ``(frame, state)`` where ``frame`` is ``[B, 2, 1, n_freqs]``
            (real and imaginary parts stacked on dim 1).

        Raises:
            ValueError: for non-mono input, wrong rank, or a chunk that is not
                exactly ``hop_length`` samples long.
        """
        if chunk.dim() == 3:
            if chunk.shape[1] != 1:
                raise ValueError("stream_step expects mono chunks shaped [B,H] or [B,1,H]")
            chunk = chunk.squeeze(1)
        if chunk.dim() != 2:
            raise ValueError(f"stream_step expects chunk [B,H] or [B,1,H], got {tuple(chunk.shape)}")
        if chunk.shape[-1] != self.hop_length:
            raise ValueError(
                f"stream_step expects {self.hop_length} samples per chunk, got {chunk.shape[-1]}"
            )

        # Drop the oldest hop and append the new samples (sliding window).
        state["sample_buffer"] = torch.cat(
            [state["sample_buffer"][:, chunk.shape[-1]:], chunk], dim=-1
        )
        window = self.stream_window.to(device=chunk.device, dtype=chunk.dtype)
        frame = torch.fft.rfft(state["sample_buffer"] * window, n=self.n_fft)
        frame = torch.stack([frame.real, frame.imag], dim=1).unsqueeze(2)
        return frame, state

    def _encoder_stream_step(self, stft_frame, state):
        """Run the 3-frame encoder conv for the previous (now mature) frame.

        The encoder kernel spans 3 frames, so the centre output of a
        ``[t-2, t-1, t]`` window is frame ``t-1`` with full context (one frame
        of look-ahead). Returns ``(None, state)`` until enough frames exist.
        """
        frames = torch.cat([state["encoder_frames"], stft_frame], dim=2)
        state["encoder_frames"] = frames[:, :, -2:, :].contiguous()
        if frames.shape[2] < 3:
            return None, state
        Em = self.encoder(frames[:, :, -3:, :])[:, :, 1:2, :]
        return Em, state

    def _decoder_stream_step(self, Eo, state):
        """Decode the previous separator frame into a masked spectrum frame.

        Like the encoder, the TSE decoder and PVAD ``tconv2d`` are evaluated
        on a 3-frame window and only the centre frame is kept.

        Returns:
            ``(Xf, Ptgt, state)`` with ``Xf`` shaped ``[B, 2, 1, F]``, or
            ``(None, None, state)`` during warm-up.
        """
        frames = torch.cat([state["decoder_eo"], Eo], dim=2)
        state["decoder_eo"] = frames[:, :, -2:, :].contiguous()
        if frames.shape[2] < 3:
            return None, None, state

        window = frames[:, :, -3:, :]
        Dtse = self.tse_decoder(window)[:, :, 1:2, :]

        pvad_2d = self.pvad_decoder.tconv2d(window)[:, :, 1:2, :]
        pvad_feat = pvad_2d.squeeze(1).transpose(1, 2)  # [B, F, 1]
        # Assumes pvad_decoder.conv1d has a 2-frame temporal kernel, so one
        # previous frame of context suffices — verify against PVADDecoder.
        pvad_in = torch.cat([state["pvad_conv_prev"], pvad_feat], dim=-1)
        Ptgt = self.pvad_decoder.conv1d(pvad_in)
        state["pvad_conv_prev"] = pvad_feat

        p = torch.sigmoid(Ptgt)
        # NOTE(review): index 1 of the 2-frame tconv1d output is taken as the
        # current frame's gate — confirm against InteractionModule.
        interaction_in = torch.cat([state["interaction_prev"], p], dim=-1)
        Pi = F.relu(self.interaction.tconv1d(interaction_in))[..., 1:2]
        state["interaction_prev"] = p

        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])
        return Dtse * mask, Ptgt, state

    def _istft_stream_step(self, Xf, state):
        """Inverse-FFT one frame and emit one hop via weighted overlap-add.

        The windowed frame is added to ``istft_ola`` and the squared window to
        ``istft_norm``; the first ``hop_length`` samples are then normalized,
        emitted, and both accumulators are shifted left by one hop.
        """
        real = Xf[:, 0, 0, :]
        imag = Xf[:, 1, 0, :]
        frame = torch.fft.irfft(torch.complex(real, imag), n=self.n_fft)
        window = self.stream_window.to(device=Xf.device, dtype=Xf.dtype)
        frame = frame * window
        state["istft_ola"][:, :self.n_fft] += frame
        state["istft_norm"][:, :self.n_fft] += window.square().unsqueeze(0)

        # Clamp avoids division by zero where the window sum is ~0.
        denom = state["istft_norm"][:, :self.hop_length].clamp_min(1e-8)
        chunk = state["istft_ola"][:, :self.hop_length] / denom

        zeros = torch.zeros_like(state["istft_ola"][:, :self.hop_length])
        state["istft_ola"] = torch.cat([state["istft_ola"][:, self.hop_length:], zeros], dim=-1)
        state["istft_norm"] = torch.cat([state["istft_norm"][:, self.hop_length:], zeros], dim=-1)
        return chunk, state

    def _stream_step_impl(self, chunk, state, ref_cache, return_attn=False):
        """Run one hop-sized streaming step and report output maturity.

        Returns:
            ``(audio_chunk, state, pvad_frame, attn, ready)``. ``ready`` is
            false during encoder/decoder warm-up, when returned audio is only a
            placeholder used by the low-level ``stream_step`` compatibility
            API.
        """
        batch_size = chunk.shape[0]
        zero_audio = torch.zeros(
            batch_size, self.hop_length, device=chunk.device, dtype=chunk.dtype
        )
        zero_pvad = torch.zeros(batch_size, 1, 1, device=chunk.device, dtype=chunk.dtype)

        stft_frame, state = self._stft_stream_frame(chunk, state)
        Em, state = self._encoder_stream_step(stft_frame, state)
        if Em is None:
            return zero_audio, state, zero_pvad, None, False

        if return_attn:
            Espk, attn = self._cmha_stream_step(Em, ref_cache, return_attn=True)
        else:
            Espk = self._cmha_stream_step(Em, ref_cache)
            attn = None

        Eo = torch.cat([Em, Espk], dim=1)
        for idx, block in enumerate(self.separator):
            Eo, state["separator"][idx] = block.stream_step(Eo, state["separator"][idx])

        Xf, Ptgt, state = self._decoder_stream_step(Eo, state)
        if Xf is None:
            return zero_audio, state, zero_pvad, attn, False

        audio, state = self._istft_stream_step(Xf, state)
        state["frames_seen"] += 1
        return audio, state, Ptgt, attn, True

    @torch.no_grad()
    def stream_step(self, chunk, state, ref_cache, return_attn=False):
        """Run one hop-sized streaming step (``hop_length`` samples).

        This low-level API always returns one hop of audio, using zeros during
        warm-up. Prefer ``stream`` for application code that feeds arbitrary
        audio lengths and only wants mature output.

        Returns:
            ``(audio_chunk, state, pvad_frame)`` or
            ``(audio_chunk, state, pvad_frame, attn)`` when
            ``return_attn=True``.
        """
        audio, state, Ptgt, attn, _ = self._stream_step_impl(
            chunk, state, ref_cache, return_attn=return_attn
        )
        if return_attn:
            return audio, state, Ptgt, attn
        return audio, state, Ptgt

    @torch.no_grad()
    def stream(self, audio, state, ref_cache, return_attn=False):
        """Accept any number of samples and return only mature streaming output.

        ``audio`` may be shaped ``[B, N]`` or ``[B, 1, N]``. Samples that do
        not complete a hop are buffered in ``state["input_buffer"]`` for the
        next call. During cold start, this method returns an empty audio
        tensor until the STFT/encoder/decoder alignment has enough context.

        Returns:
            ``(audio_out, state, pvad_frames)`` or
            ``(audio_out, state, pvad_frames, attn_frames)`` when
            ``return_attn=True``. ``audio_out`` has shape ``[B, M]`` where
            ``M`` may be zero.

        Raises:
            ValueError: for non-mono/wrong-rank audio or a batch size that
                differs from the buffered samples.
        """
        if audio.dim() == 3:
            if audio.shape[1] != 1:
                raise ValueError("stream expects mono audio shaped [B,N] or [B,1,N]")
            audio = audio.squeeze(1)
        if audio.dim() != 2:
            raise ValueError(f"stream expects audio [B,N] or [B,1,N], got {tuple(audio.shape)}")

        buffered = state.get("input_buffer")
        if buffered is None:
            buffered = torch.zeros(audio.shape[0], 0, device=audio.device, dtype=audio.dtype)
        if buffered.shape[0] != audio.shape[0]:
            raise ValueError(
                f"Buffered batch size {buffered.shape[0]} does not match audio batch {audio.shape[0]}"
            )

        pending = torch.cat([buffered.to(device=audio.device, dtype=audio.dtype), audio], dim=-1)
        n_hops = pending.shape[-1] // self.hop_length
        consume = n_hops * self.hop_length
        # Keep the incomplete tail for the next call.
        state["input_buffer"] = pending[:, consume:].contiguous()

        chunks = []
        pvads = []
        attns = []
        for idx in range(n_hops):
            start = idx * self.hop_length
            chunk = pending[:, start:start + self.hop_length]
            out, state, pvad, attn, ready = self._stream_step_impl(
                chunk, state, ref_cache, return_attn=return_attn
            )
            if ready:
                chunks.append(out)
                pvads.append(pvad)
                if return_attn:
                    attns.append(attn)

        if chunks:
            audio_out = torch.cat(chunks, dim=-1)
            pvad_out = torch.cat(pvads, dim=-1)
        else:
            audio_out = torch.zeros(audio.shape[0], 0, device=audio.device, dtype=audio.dtype)
            pvad_out = torch.zeros(audio.shape[0], 1, 0, device=audio.device, dtype=audio.dtype)

        if return_attn:
            attn_out = torch.cat(attns, dim=2) if attns else None
            return audio_out, state, pvad_out, attn_out
        return audio_out, state, pvad_out


OptimizedStreaming_USEF_TP = Streaming_USEF_TP_Optimized