# Streaming-USEF-TP / model_streaming_usef_tp_optimized.py
# VMoorjani -- eefb734 ("V3 slight improvement over V2.")
"""Optimized stateful inference path for Streaming USEF-TP.
This module keeps the full-sequence ``forward`` behavior compatible with
``model_streaming_usef_tp.py`` while adding an explicit chunk-by-chunk
``stream_step`` API. The optimized path caches reference-side CMHA tensors,
rolling STFT/decoder/iSTFT context, temporal LSTM state, and GridNet
self-attention K/V history.
"""
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from local.CMHA import CMHA
from local.STFT import STFT, iSTFT
from local.StreamingGridNetV2Block import StreamingGridNetV2Block
from model_streaming_usef_tp import InteractionModule, PVADDecoder
class OptimizedStreamingGridNetV2Block(StreamingGridNetV2Block):
    """Streaming step extension for ``StreamingGridNetV2Block``.

    The step path specializes the common real-time configuration
    ``emb_ks == emb_hs == 1``. The inherited ``forward`` remains available for
    full-sequence training/evaluation and state-dict compatibility.
    """

    def init_stream_state(self, batch_size, n_freqs, device, dtype=None,
                          max_attention_frames=None):
        """Create the per-block recurrent state and attention K/V cache.

        Args:
            batch_size: number of streams processed in parallel.
            n_freqs: number of frequency bins ``F``; the inter-frame RNN runs
                over ``batch_size * n_freqs`` independent sequences.
            device: device for the RNN state tensors.
            dtype: dtype for the RNN state; defaults to the block's parameter
                dtype.
            max_attention_frames: if not ``None``, the self-attention K/V
                cache keeps only this many most recent frames.

        Returns:
            Dict with keys ``inter_h``, ``inter_c``, ``attn_k``, ``attn_v`` and
            ``max_attention_frames``.

        Raises:
            NotImplementedError: if ``inter_rnn`` is bidirectional, which
                cannot be advanced one frame at a time.
        """
        if dtype is None:
            dtype = next(self.parameters()).dtype
        rnn = self.inter_rnn
        if getattr(rnn, "bidirectional", False):
            raise NotImplementedError(
                "Optimized stream_step requires a unidirectional inter_rnn."
            )
        # nn.LSTM expects h_0/c_0 shaped [num_layers, N, *]; with proj_size > 0
        # the hidden state h is projected to proj_size while c keeps hidden_size.
        num_layers = getattr(rnn, "num_layers", 1)
        hidden = rnn.hidden_size
        h_size = getattr(rnn, "proj_size", 0) or hidden
        n_seq = batch_size * n_freqs
        return {
            "inter_h": torch.zeros(num_layers, n_seq, h_size, device=device, dtype=dtype),
            "inter_c": torch.zeros(num_layers, n_seq, hidden, device=device, dtype=dtype),
            # K/V caches start empty and grow by one frame per stream_step.
            "attn_k": None,
            "attn_v": None,
            "max_attention_frames": max_attention_frames,
        }

    def stream_step(self, x, state):
        """Advance the block by exactly one time frame.

        Args:
            x: ``[B, C, 1, F]`` frame.
            state: dict from ``init_stream_state``; it is updated in place and
                also returned.

        Returns:
            Tuple ``(out, updated_state)`` with ``out`` shaped ``[B, C, 1, F]``.

        Raises:
            NotImplementedError: if ``emb_ks`` or ``emb_hs`` is not 1.
            ValueError: if ``x`` holds more than one time frame.
        """
        if self.emb_ks != 1 or self.emb_hs != 1:
            raise NotImplementedError(
                "Optimized stream_step currently requires emb_ks == emb_hs == 1."
            )
        B, C, T, Q = x.shape
        if T != 1:
            raise ValueError(f"stream_step expects one time frame, got T={T}")

        # Intra-frame RNN across frequency, with residual connection.
        frame = x.permute(0, 2, 3, 1)  # [B, 1, F, C]
        input_ = frame
        intra = self.intra_norm(input_)
        intra = intra.reshape(B, Q, C)
        intra, _ = self.intra_rnn(intra)
        intra = self.intra_linear(intra)
        intra = intra.reshape(B, 1, Q, C)
        intra = intra + input_
        intra = intra.transpose(1, 2)  # [B, F, 1, C]

        # Inter-frame RNN across time: one step per call, state carried over.
        input_ = intra
        inter = self.inter_norm(input_)
        inter = inter.reshape(B * Q, 1, C)
        inter, (h, c) = self.inter_rnn(inter, (state["inter_h"], state["inter_c"]))
        state["inter_h"] = h
        state["inter_c"] = c
        inter = self.inter_linear(inter)
        inter = inter.reshape(B, Q, 1, C)
        inter = inter + input_
        inter = inter.permute(0, 3, 2, 1).contiguous()  # [B, C, 1, F]

        # Self-attention of the current frame over the cached K/V history.
        q = self["attn_norm_Q"](self["attn_conv_Q"](inter))
        k = self["attn_norm_K"](self["attn_conv_K"](inter))
        v = self["attn_norm_V"](self["attn_conv_V"](inter))
        q = q.reshape(-1, *q.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        k = k.reshape(-1, *k.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        v = v.reshape(-1, *v.shape[2:]).transpose(1, 2)
        v_shape = v.shape
        v = v.flatten(start_dim=2)
        if state["attn_k"] is None:
            k_cache = k
            v_cache = v
        else:
            k_cache = torch.cat([state["attn_k"], k], dim=1)
            v_cache = torch.cat([state["attn_v"], v], dim=1)
        max_frames = state.get("max_attention_frames")
        if max_frames is not None and k_cache.shape[1] > max_frames:
            k_cache = k_cache[:, -max_frames:, :].contiguous()
            v_cache = v_cache[:, -max_frames:, :].contiguous()
        state["attn_k"] = k_cache
        state["attn_v"] = v_cache
        # The cache only holds past and current frames, so no causal mask is
        # needed for the single query frame. SDPA's default 1/sqrt(d) scaling
        # uses the flattened q dimension.
        attn = F.scaled_dot_product_attention(q, k_cache, v_cache, is_causal=False)
        attn = attn.reshape(v_shape).transpose(1, 2)
        head_dim = attn.shape[1]
        attn = attn.contiguous().reshape(B, self.n_head * head_dim, 1, Q)
        attn = self["attn_concat_proj"](attn)
        return attn + inter, state
class Streaming_USEF_TP_Optimized(nn.Module):
    """Streaming USEF-TP with cached, stateful PyTorch inference.

    ``forward`` is the full-sequence path. The streaming path is::

        ref_cache = model.prepare_reference(ref)
        state = model.init_stream(batch_size, device)
        audio_out, state, pvad = model.stream(audio, state, ref_cache)

    ``stream`` accepts any number of samples. ``stream_step`` is the
    lower-level API that consumes exactly ``hop_length`` samples per call.
    """

    def __init__(self, hidden_channels, n_head, emb_dim, emb_ks, emb_hs,
                 num_layers=6, n_fft=128, hop_length=64, win_length=128,
                 cmha_approx_qk_dim=512, eps=1e-5,
                 max_attention_frames=None):
        super().__init__()
        self.num_layers = num_layers
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_freqs = n_fft // 2 + 1
        self.emb_dim = emb_dim
        self.max_attention_frames = max_attention_frames
        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        self.istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        # Window for the hand-rolled streaming STFT/iSTFT. It is registered
        # non-persistent, so it is not part of the state dict.
        self.register_buffer("stream_window", torch.hann_window(win_length), persistent=False)
        t_ksize = 3
        # 3-frame x 3-bin kernels with padding 1 on each side keep T and F
        # unchanged. The streaming encoder/decoder steps rely on this
        # 3-frame time kernel.
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.encoder = nn.Conv2d(2, emb_dim, ks, padding=padding)
        self.cmha = CMHA(
            emb_dim=emb_dim, n_freqs=self.n_freqs, n_head=n_head,
            approx_qk_dim=cmha_approx_qk_dim, eps=eps,
        )
        # NOTE(review): separator blocks use a fixed approx_qk_dim=512,
        # independent of cmha_approx_qk_dim -- confirm this is intended.
        self.separator = nn.ModuleList([
            copy.deepcopy(
                OptimizedStreamingGridNetV2Block(
                    2 * emb_dim, emb_ks, emb_hs, self.n_freqs, hidden_channels,
                    n_head, approx_qk_dim=512, activation="prelu",
                )
            ) for _ in range(num_layers)
        ])
        self.tse_decoder = nn.ConvTranspose2d(
            2 * emb_dim, 2, ks, stride=1, padding=padding
        )
        self.pvad_decoder = PVADDecoder(
            in_channels=2 * emb_dim, n_freqs=self.n_freqs, t_ksize=t_ksize
        )
        self.interaction = InteractionModule()

    def forward(self, mix, ref, return_attn=False, return_no_mask=False):
        """Full-sequence compatibility path.

        Args:
            mix: mixture waveform; a channel axis is inserted with
                ``unsqueeze(1)`` before the STFT (presumably ``[B, T]`` --
                TODO confirm against ``STFT``).
            ref: reference waveform with the same layout as ``mix``.
            return_attn: also return the CMHA attention weights.
            return_no_mask: also return the waveform decoded from ``Dtse``
                without the interaction mask applied.

        Returns:
            ``(Xtgt, Ptgt)`` followed by ``attn`` and/or ``Xtgt_nomask`` (in
            that order) when the corresponding flags are set.
        """
        mix = mix.unsqueeze(1)
        ref = ref.unsqueeze(1)
        mix_c = self.stft(mix)[-1]
        ref_c = self.stft(ref)[-1]
        mix_ri = torch.cat([mix_c.real, mix_c.imag], dim=1).permute(0, 1, 3, 2).contiguous()
        ref_ri = torch.cat([ref_c.real, ref_c.imag], dim=1).permute(0, 1, 3, 2).contiguous()
        Em = self.encoder(mix_ri)
        Er = self.encoder(ref_ri)
        if return_attn:
            Espk, attn = self.cmha(Em, Er, return_attn=True)
        else:
            Espk = self.cmha(Em, Er)
        Ef = torch.cat([Em, Espk], dim=1)
        Eo = Ef
        for block in self.separator:
            Eo = block(Eo)
        Dtse = self.tse_decoder(Eo)
        Ptgt = self.pvad_decoder(Eo)
        Pi = self.interaction(Ptgt)
        # Match the interaction mask's time length to the decoder output.
        L_m = Dtse.shape[2]
        if Pi.shape[-1] < L_m:
            Pi = F.pad(Pi, (0, L_m - Pi.shape[-1]))
        elif Pi.shape[-1] > L_m:
            Pi = Pi[..., :L_m]
        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])
        Xf = Dtse * mask
        out_r = Xf[:, 0, :, :].permute(0, 2, 1).contiguous()
        out_i = Xf[:, 1, :, :].permute(0, 2, 1).contiguous()
        Xtgt = self.istft((out_r, out_i), input_type="real_imag").unsqueeze(1)
        if return_no_mask:
            out_r_nm = Dtse[:, 0, :, :].permute(0, 2, 1).contiguous()
            out_i_nm = Dtse[:, 1, :, :].permute(0, 2, 1).contiguous()
            Xtgt_nomask = self.istft((out_r_nm, out_i_nm), input_type="real_imag").unsqueeze(1)
        if return_attn and return_no_mask:
            return Xtgt.squeeze(1), Ptgt, attn, Xtgt_nomask.squeeze(1)
        if return_attn:
            return Xtgt.squeeze(1), Ptgt, attn
        if return_no_mask:
            return Xtgt.squeeze(1), Ptgt, Xtgt_nomask.squeeze(1)
        return Xtgt.squeeze(1), Ptgt

    def reset_stream_state(self, batch_size, device, dtype=None):
        """Alias of ``init_stream``: return a fresh streaming state."""
        return self.init_stream(batch_size, device, dtype=dtype)

    def init_stream(self, batch_size, device, dtype=None):
        """Create a fresh streaming state.

        Args:
            batch_size: number of streams processed in parallel; must match the
                ``ref_cache`` passed to ``stream``/``stream_step``.
            device: device for all state tensors.
            dtype: state dtype; defaults to the model's parameter dtype.

        Returns:
            Dict of rolling buffers (STFT samples, partial input hop, encoder
            and decoder context, PVAD/interaction history, iSTFT overlap-add),
            one state dict per separator block, the pending CMHA attention
            frame, and a ``frames_seen`` counter.
        """
        if dtype is None:
            dtype = next(self.parameters()).dtype
        # One zero frame of history stands in for the zero time-padding the
        # full-sequence encoder/decoder convs apply before the first frame.
        zeros_ri = torch.zeros(batch_size, 2, 1, self.n_freqs, device=device, dtype=dtype)
        zeros_eo = torch.zeros(
            batch_size, 2 * self.emb_dim, 1, self.n_freqs, device=device, dtype=dtype
        )
        window_len = self.win_length
        return {
            "sample_buffer": torch.zeros(batch_size, window_len, device=device, dtype=dtype),
            "input_buffer": torch.zeros(batch_size, 0, device=device, dtype=dtype),
            "encoder_frames": zeros_ri,
            "decoder_eo": zeros_eo,
            "pvad_conv_prev": torch.zeros(batch_size, self.n_freqs, 1, device=device, dtype=dtype),
            "interaction_prev": torch.zeros(batch_size, 1, 1, device=device, dtype=dtype),
            "istft_ola": torch.zeros(batch_size, self.n_fft, device=device, dtype=dtype),
            "istft_norm": torch.zeros(batch_size, self.n_fft, device=device, dtype=dtype),
            "separator": [
                block.init_stream_state(
                    batch_size, self.n_freqs, device, dtype=dtype,
                    max_attention_frames=self.max_attention_frames,
                )
                for block in self.separator
            ],
            # CMHA attention of the last encoder frame. It is emitted together
            # with that frame's audio one step later.
            "cmha_attn_prev": None,
            "frames_seen": 0,
        }

    @torch.no_grad()
    def prepare_reference(self, ref):
        """Precompute reference encoding and CMHA K/V tensors.

        Args:
            ref: waveform ``[B, T]`` or ``[B, 1, T]``.

        Returns:
            A cache dictionary consumed by ``stream_step``. ``K`` is shaped
            ``[B * n_head, qk_dim, Lr]`` and ``V`` is ``[B * n_head, Lr, -1]``,
            where ``Lr`` is the number of reference frames.

        Raises:
            ValueError: if ``ref`` is not 2-D or 3-D.
        """
        if ref.dim() == 2:
            ref = ref.unsqueeze(1)
        elif ref.dim() != 3:
            raise ValueError(f"Expected ref with shape [B,T] or [B,1,T], got {tuple(ref.shape)}")
        ref_c = self.stft(ref)[-1]
        ref_ri = torch.cat([ref_c.real, ref_c.imag], dim=1).permute(0, 1, 3, 2).contiguous()
        Er = self.encoder(ref_ri)
        K = self.cmha["attn_norm_K"](self.cmha["attn_conv_K"](Er))
        V = self.cmha["attn_norm_V"](self.cmha["attn_conv_V"](Er))
        B = Er.shape[0]
        Lr = Er.shape[-2]
        # Fold heads into the batch axis; assumes the CMHA conv/norm output is
        # [B, n_head, E, Lr, F] -- TODO confirm against CMHA.
        K = K.reshape(-1, *K.shape[2:])
        V = V.reshape(-1, *V.shape[2:])
        K = K.transpose(2, 3).contiguous().reshape(B * self.cmha.n_head, -1, Lr)
        V = V.transpose(1, 2).flatten(start_dim=2).contiguous()
        return {
            "K": K,
            "V": V,
            "Lr": Lr,
            "batch_size": B,
            "qk_dim": K.shape[1],
        }

    def _cmha_stream_step(self, Em, ref_cache, return_attn=False):
        """Attend one mixture frame ``Em`` (``[B, C, 1, F]``) to the reference.

        Returns the CMHA output and, if ``return_attn``, the softmax weights
        reshaped to ``[B, n_head, 1, Lr]`` and detached.

        Raises:
            ValueError: if ``Em`` holds more than one frame or its batch size
                differs from the reference cache.
        """
        B, _, Lm, _ = Em.shape
        if Lm != 1:
            raise ValueError(f"CMHA stream step expects one frame, got Lm={Lm}")
        if ref_cache["batch_size"] != B:
            raise ValueError(
                f"Reference cache batch size {ref_cache['batch_size']} does not match chunk batch {B}"
            )
        Q = self.cmha["attn_norm_Q"](self.cmha["attn_conv_Q"](Em))
        Q = Q.reshape(-1, *Q.shape[2:]).transpose(1, 2).flatten(start_dim=2)
        attn = torch.matmul(Q, ref_cache["K"]) / math.sqrt(ref_cache["qk_dim"])
        attn = F.softmax(attn, dim=2)
        out = torch.matmul(attn, ref_cache["V"])
        out = out.reshape(B * self.cmha.n_head, 1, -1, self.n_freqs).transpose(1, 2)
        head_dim = out.shape[1]
        out = out.contiguous().reshape(B, self.cmha.n_head * head_dim, 1, self.n_freqs)
        out = self.cmha["attn_concat_proj"](out)
        if return_attn:
            return out, attn.reshape(B, self.cmha.n_head, 1, ref_cache["Lr"]).detach()
        return out

    def _stft_stream_frame(self, chunk, state):
        """Shift one hop into the sample buffer and return its STFT frame.

        The returned frame is the real/imag stack ``[B, 2, 1, n_freqs]`` of the
        windowed ``win_length``-sample buffer.

        Raises:
            ValueError: if ``chunk`` is not mono ``[B, H]``/``[B, 1, H]`` with
                ``H == hop_length``.
        """
        if chunk.dim() == 3:
            if chunk.shape[1] != 1:
                raise ValueError("stream_step expects mono chunks shaped [B,H] or [B,1,H]")
            chunk = chunk.squeeze(1)
        if chunk.dim() != 2:
            raise ValueError(f"stream_step expects chunk [B,H] or [B,1,H], got {tuple(chunk.shape)}")
        if chunk.shape[-1] != self.hop_length:
            raise ValueError(
                f"stream_step expects {self.hop_length} samples per chunk, got {chunk.shape[-1]}"
            )
        state["sample_buffer"] = torch.cat(
            [state["sample_buffer"][:, chunk.shape[-1]:], chunk], dim=-1
        )
        window = self.stream_window.to(device=chunk.device, dtype=chunk.dtype)
        frame = torch.fft.rfft(state["sample_buffer"] * window, n=self.n_fft)
        frame = torch.stack([frame.real, frame.imag], dim=1).unsqueeze(2)
        return frame, state

    def _encoder_stream_step(self, stft_frame, state):
        """Encode the frame before ``stft_frame`` once its look-ahead exists.

        Returns ``(None, state)`` while fewer than three frames of context are
        available.
        """
        frames = torch.cat([state["encoder_frames"], stft_frame], dim=2)
        state["encoder_frames"] = frames[:, :, -2:, :].contiguous()
        if frames.shape[2] < 3:
            return None, state
        # The middle output of a 3-frame window equals the full-sequence conv
        # output for that frame.
        Em = self.encoder(frames[:, :, -3:, :])[:, :, 1:2, :]
        return Em, state

    def _decoder_stream_step(self, Eo, state):
        """Decode the separator frame preceding ``Eo`` into a masked spectrum.

        Returns ``(Xf, Ptgt, state)``, or ``(None, None, state)`` during
        warm-up.
        """
        frames = torch.cat([state["decoder_eo"], Eo], dim=2)
        state["decoder_eo"] = frames[:, :, -2:, :].contiguous()
        if frames.shape[2] < 3:
            return None, None, state
        window = frames[:, :, -3:, :]
        Dtse = self.tse_decoder(window)[:, :, 1:2, :]
        pvad_2d = self.pvad_decoder.tconv2d(window)[:, :, 1:2, :]
        pvad_feat = pvad_2d.squeeze(1).transpose(1, 2)  # [B, F, 1]
        pvad_in = torch.cat([state["pvad_conv_prev"], pvad_feat], dim=-1)
        Ptgt = self.pvad_decoder.conv1d(pvad_in)
        state["pvad_conv_prev"] = pvad_feat
        p = torch.sigmoid(Ptgt)
        interaction_in = torch.cat([state["interaction_prev"], p], dim=-1)
        Pi = F.relu(self.interaction.tconv1d(interaction_in))[..., 1:2]
        state["interaction_prev"] = p
        mask = Pi.unsqueeze(-1).expand(-1, 2, -1, Dtse.shape[-1])
        return Dtse * mask, Ptgt, state

    def _istft_stream_step(self, Xf, state):
        """Overlap-add one spectral frame and pop ``hop_length`` samples.

        The popped samples are normalized by the accumulated squared window.
        """
        real = Xf[:, 0, 0, :]
        imag = Xf[:, 1, 0, :]
        frame = torch.fft.irfft(torch.complex(real, imag), n=self.n_fft)
        window = self.stream_window.to(device=Xf.device, dtype=Xf.dtype)
        frame = frame * window
        state["istft_ola"][:, :self.n_fft] += frame
        state["istft_norm"][:, :self.n_fft] += window.square().unsqueeze(0)
        denom = state["istft_norm"][:, :self.hop_length].clamp_min(1e-8)
        chunk = state["istft_ola"][:, :self.hop_length] / denom
        zeros = torch.zeros_like(state["istft_ola"][:, :self.hop_length])
        state["istft_ola"] = torch.cat([state["istft_ola"][:, self.hop_length:], zeros], dim=-1)
        state["istft_norm"] = torch.cat([state["istft_norm"][:, self.hop_length:], zeros], dim=-1)
        return chunk, state

    def _stream_step_impl(self, chunk, state, ref_cache, return_attn=False):
        """Run one hop-sized streaming step and report output maturity.

        The decoder needs one frame of look-ahead, so the audio/PVAD frame
        emitted on a step belongs to the encoder frame produced on the
        previous step. The CMHA attention of the current encoder frame is
        therefore parked in ``state["cmha_attn_prev"]`` and returned on the
        next step, which keeps ``attn`` aligned with the returned audio/PVAD.

        Returns:
            ``(audio_chunk, state, pvad_frame, attn, ready)``. ``ready`` is
            false during encoder/decoder warm-up, when returned audio is only a
            placeholder used by the low-level ``stream_step`` compatibility API.
            ``attn`` is ``None`` unless ``return_attn`` is set and an aligned
            attention frame is available.
        """
        batch_size = chunk.shape[0]
        zero_audio = torch.zeros(batch_size, self.hop_length, device=chunk.device, dtype=chunk.dtype)
        zero_pvad = torch.zeros(batch_size, 1, 1, device=chunk.device, dtype=chunk.dtype)
        stft_frame, state = self._stft_stream_frame(chunk, state)
        Em, state = self._encoder_stream_step(stft_frame, state)
        if Em is None:
            return zero_audio, state, zero_pvad, None, False
        # Always cache the attention (only a reshape on top of work already
        # done) so alignment does not depend on earlier return_attn values.
        Espk, current_attn = self._cmha_stream_step(Em, ref_cache, return_attn=True)
        aligned_attn = state.get("cmha_attn_prev")
        state["cmha_attn_prev"] = current_attn
        attn = aligned_attn if return_attn else None
        Eo = torch.cat([Em, Espk], dim=1)
        for idx, block in enumerate(self.separator):
            Eo, state["separator"][idx] = block.stream_step(Eo, state["separator"][idx])
        Xf, Ptgt, state = self._decoder_stream_step(Eo, state)
        if Xf is None:
            return zero_audio, state, zero_pvad, attn, False
        audio, state = self._istft_stream_step(Xf, state)
        state["frames_seen"] += 1
        return audio, state, Ptgt, attn, True

    @torch.no_grad()
    def stream_step(self, chunk, state, ref_cache, return_attn=False):
        """Run one streaming step of exactly ``hop_length`` samples.

        This low-level API always returns one hop of audio, using zeros during
        warm-up. Prefer ``stream`` for application code that feeds arbitrary
        audio lengths and only wants mature output.

        Returns:
            ``(audio_chunk, state, pvad_frame)`` or
            ``(audio_chunk, state, pvad_frame, attn)`` when ``return_attn=True``.
            ``attn`` belongs to the same frame as ``audio_chunk`` and is
            ``None`` during warm-up.
        """
        audio, state, Ptgt, attn, _ = self._stream_step_impl(
            chunk, state, ref_cache, return_attn=return_attn
        )
        if return_attn:
            return audio, state, Ptgt, attn
        return audio, state, Ptgt

    @torch.no_grad()
    def stream(self, audio, state, ref_cache, return_attn=False):
        """Accept any number of samples and return only mature streaming output.

        ``audio`` may be shaped ``[B, N]`` or ``[B, 1, N]``. Samples that do not
        complete a hop are buffered in ``state["input_buffer"]`` for the next
        call. During cold start, this method returns an empty audio tensor until
        the STFT/encoder/decoder alignment has enough context.

        Returns:
            ``(audio_out, state, pvad_frames)`` or
            ``(audio_out, state, pvad_frames, attn_frames)`` when
            ``return_attn=True``. ``audio_out`` has shape ``[B, M]`` where
            ``M`` may be zero. ``attn_frames`` is aligned frame-by-frame with
            ``pvad_frames``, or is ``None`` when no frame was produced.

        Raises:
            ValueError: if ``audio`` is not mono or its batch size differs from
                the buffered input.
        """
        if audio.dim() == 3:
            if audio.shape[1] != 1:
                raise ValueError("stream expects mono audio shaped [B,N] or [B,1,N]")
            audio = audio.squeeze(1)
        if audio.dim() != 2:
            raise ValueError(f"stream expects audio [B,N] or [B,1,N], got {tuple(audio.shape)}")
        buffered = state.get("input_buffer")
        if buffered is None:
            buffered = torch.zeros(audio.shape[0], 0, device=audio.device, dtype=audio.dtype)
        if buffered.shape[0] != audio.shape[0]:
            raise ValueError(
                f"Buffered batch size {buffered.shape[0]} does not match audio batch {audio.shape[0]}"
            )
        pending = torch.cat([buffered.to(device=audio.device, dtype=audio.dtype), audio], dim=-1)
        n_hops = pending.shape[-1] // self.hop_length
        consume = n_hops * self.hop_length
        state["input_buffer"] = pending[:, consume:].contiguous()
        chunks = []
        pvads = []
        attns = []
        for idx in range(n_hops):
            start = idx * self.hop_length
            chunk = pending[:, start:start + self.hop_length]
            out, state, pvad, attn, ready = self._stream_step_impl(
                chunk, state, ref_cache, return_attn=return_attn
            )
            if ready:
                chunks.append(out)
                pvads.append(pvad)
                if return_attn:
                    attns.append(attn)
        if chunks:
            audio_out = torch.cat(chunks, dim=-1)
            pvad_out = torch.cat(pvads, dim=-1)
        else:
            audio_out = torch.zeros(audio.shape[0], 0, device=audio.device, dtype=audio.dtype)
            pvad_out = torch.zeros(audio.shape[0], 1, 0, device=audio.device, dtype=audio.dtype)
        if return_attn:
            # A state created without "cmha_attn_prev" can yield a None frame;
            # return None rather than a misaligned partial stack.
            if attns and all(a is not None for a in attns):
                attn_out = torch.cat(attns, dim=2)
            else:
                attn_out = None
            return audio_out, state, pvad_out, attn_out
        return audio_out, state, pvad_out
# Alternative name for ``Streaming_USEF_TP_Optimized``; both names refer to the
# same class object.
OptimizedStreaming_USEF_TP = Streaming_USEF_TP_Optimized