import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from local.AllHeadPReLULayerNormalization4DCF import AllHeadPReLULayerNormalization4DCF
from local.LayerNormalization4DCF import LayerNormalization4DCF
from local.get_layer_from_string import get_layer


class StreamingGridNetV2Block(nn.Module):
    """GridNetV2 block with a causal (left-to-right) time axis.

    Differences from the offline ``GridNetV2Block``:

    * ``inter_rnn`` (the time-axis LSTM) is unidirectional. ``intra_rnn``
      stays bidirectional because it operates across frequency, not time.
    * The temporal self-attention is causal: it calls
      ``F.scaled_dot_product_attention`` with ``is_causal=True``, so frame
      ``t`` only attends to frames ``<= t``. No mask tensor is stored, so
      the module works for any sequence length.

    Latency note: both RNN paths group ``emb_ks`` consecutive steps (hop
    ``emb_hs``) into one RNN step. This is done by a reshape when
    ``emb_ks == emb_hs``, otherwise by ``F.unfold`` + ``ConvTranspose1d``.
    On the time axis, every output frame therefore also depends on up to
    ``emb_ks - 1`` *later* frames from its own window(s). The block is
    strictly frame-causal only when ``emb_ks == 1``.

    All other parameter shapes are intended to match the offline block, so
    existing offline checkpoints can warm-start this one. Only ``inter_rnn``
    (bi -> uni) and ``inter_linear`` (input dim ``hidden_channels`` instead
    of ``2 * hidden_channels``) differ in parameter count.
    """

    def __getitem__(self, key):
        # Allows ``self["attn_conv_Q"]``-style access to registered submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=1024,
        activation="prelu",
        eps=1e-5,
    ):
        """Build the block.

        Args:
            emb_dim: channel dimension ``C`` of the input/output.
            emb_ks: window length (in steps) grouped into one RNN step.
            emb_hs: hop between windows; ``emb_ks - emb_hs`` is the overlap.
            n_freqs: number of frequency bins ``Q`` expected by the
                attention normalisation layers.
            hidden_channels: LSTM hidden size.
            n_head: number of attention heads; must divide ``emb_dim``.
            approx_qk_dim: target size of the flattened per-head query/key
                embedding. The per-head channel count is
                ``E = ceil(approx_qk_dim / n_freqs)``.
            activation: only ``"prelu"`` is supported.
            eps: epsilon for all normalisation layers.
        """
        super().__init__()
        assert activation == "prelu", (
            f"only activation='prelu' is supported, got {activation!r}"
        )

        in_channels = emb_dim * emb_ks

        # Frequency-axis path: bidirectional LSTM (frequency is not streamed).
        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        if emb_ks == emb_hs:
            self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)
        else:
            self.intra_linear = nn.ConvTranspose1d(
                hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
            )

        # Time-axis path: unidirectional LSTM (causal across windows).
        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=False
        )
        if emb_ks == emb_hs:
            self.inter_linear = nn.Linear(hidden_channels, in_channels)
        else:
            self.inter_linear = nn.ConvTranspose1d(
                hidden_channels, emb_dim, emb_ks, stride=emb_hs
            )

        # Per-head query/key channels, chosen so that E * n_freqs ~ approx_qk_dim.
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0, (
            f"emb_dim ({emb_dim}) must be divisible by n_head ({n_head})"
        )

        # Registered via add_module with fixed names so state_dict keys match
        # the offline block.
        self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1))
        self.add_module(
            "attn_norm_Q",
            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
        )
        self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1))
        self.add_module(
            "attn_norm_K",
            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
        )
        self.add_module(
            "attn_conv_V", nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1)
        )
        self.add_module(
            "attn_norm_V",
            AllHeadPReLULayerNormalization4DCF(
                (n_head, emb_dim // n_head, n_freqs), eps=eps
            ),
        )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _windowed_rnn(self, x, norm, rnn, linear):
        """Residual ``norm -> windowed rnn -> linear`` along dim 2 of ``x``.

        Shared by the intra (frequency) and inter (time) paths.

        Args:
            x: [B, N, L, C]. The RNN runs along ``L``, batched over ``B * N``.
                ``L`` must already be padded so the ``(emb_ks, emb_hs)``
                windows tile it exactly (``forward`` does this).
            norm, rnn, linear: the matching ``intra_*`` or ``inter_*`` modules.

        Returns:
            [B, N, L, C]: ``x`` plus the RNN branch output.
        """
        B, N, L, C = x.shape
        out = norm(x)  # [B, N, L, C]
        if self.emb_ks == self.emb_hs:
            # Non-overlapping windows: fold emb_ks consecutive steps into the
            # feature dimension of a single RNN step.
            out = out.view([B * N, -1, self.emb_ks * C])
            out, _ = rnn(out)
            out = linear(out)
            out = out.view([B, N, L, C])
        else:
            # Overlapping windows: unfold, run the RNN over windows, then map
            # back to L steps with ConvTranspose1d.
            out = out.view([B * N, L, C])
            out = out.transpose(1, 2)  # [B*N, C, L]
            out = F.unfold(
                out[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
            )  # [B*N, C*emb_ks, n_windows]
            out = out.transpose(1, 2)  # [B*N, n_windows, C*emb_ks]
            out, _ = rnn(out)
            out = out.transpose(1, 2)
            out = linear(out)  # [B*N, C, L]
            out = out.view([B, N, C, L])
            out = out.transpose(-2, -1)  # [B, N, L, C]
        return out + x

    def _causal_time_attention(self, batch):
        """Multi-head causal self-attention over time.

        Args:
            batch: [B, C, T, Q] features.

        Returns:
            [B, C, T, Q]: output of ``attn_concat_proj`` (no residual added).
        """
        B, _, n_frames, n_bins = batch.shape

        Q_ = self["attn_norm_Q"](self["attn_conv_Q"](batch))
        K_ = self["attn_norm_K"](self["attn_conv_K"](batch))
        V_ = self["attn_norm_V"](self["attn_conv_V"](batch))

        # Merge the first two dims into one batch dim B'. NOTE(review): this
        # assumes the norm layers return [B, n_head, ch, T, Q] (batch-major);
        # the final view back to [B, n_head * head_dim, T, Q] relies on that.
        Q_ = Q_.view(-1, *Q_.shape[2:])
        K_ = K_.view(-1, *K_.shape[2:])
        V_ = V_.view(-1, *V_.shape[2:])

        Q_ = Q_.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        K_ = K_.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        V_ = V_.transpose(1, 2)  # [B', T, C/n_head, Q]
        v_shape = V_.shape
        V_ = V_.flatten(start_dim=2)  # [B', T, (C/n_head)*Q]

        # is_causal=True applies the lower-triangular mask inside the kernel,
        # so no T x T mask is materialised. The default scale 1/sqrt(E*Q)
        # uses the flattened query width. PyTorch picks the backend at
        # runtime; the fused flash kernel needs equal (and small) q/k/v head
        # dims, which these flattened embeddings generally are not, so another
        # backend is likely used.
        V_ = F.scaled_dot_product_attention(Q_, K_, V_, is_causal=True)
        V_ = V_.reshape(v_shape)  # [B', T, C/n_head, Q]
        V_ = V_.transpose(1, 2)  # [B', C/n_head, T, Q]

        head_dim = V_.shape[1]
        batch = V_.contiguous().view([B, self.n_head * head_dim, n_frames, n_bins])
        return self["attn_concat_proj"](batch)

    def forward(self, x):
        """Forward pass over a full sequence (training / offline evaluation).

        Args:
            x: [B, C, T, Q]

        Returns:
            out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape
        olp = self.emb_ks - self.emb_hs

        # Pad T and Q so the (emb_ks, emb_hs) windows tile them exactly:
        # ``olp`` steps at the start, the remainder at the end.
        T = (
            math.ceil((old_T + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )
        Q = (
            math.ceil((old_Q + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )
        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]
        x = F.pad(
            x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp)
        )  # [B, T, Q, C]

        # intra RNN: over frequency, bidirectional (frequency is not streamed).
        intra_rnn = self._windowed_rnn(
            x, self.intra_norm, self.intra_rnn, self.intra_linear
        )  # [B, T, Q, C]
        intra_rnn = intra_rnn.transpose(1, 2)  # [B, Q, T, C]

        # inter RNN: over time, unidirectional (causal across windows).
        inter_rnn = self._windowed_rnn(
            intra_rnn, self.inter_norm, self.inter_rnn, self.inter_linear
        )  # [B, Q, T, C]
        inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # [B, C, T, Q]
        # Drop the padding added above.
        inter_rnn = inter_rnn[..., olp : olp + old_T, olp : olp + old_Q]

        # Causal self-attention over time, with residual connection.
        out = self._causal_time_attention(inter_rnn) + inter_rnn
        return out