# Streaming-USEF-TP / local/StreamingGridNetV2Block.py
# Last commit eefb734 (verified) by VMoorjani: "V3 slight improvement over V2."
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from local.AllHeadPReLULayerNormalization4DCF import AllHeadPReLULayerNormalization4DCF
from local.LayerNormalization4DCF import LayerNormalization4DCF
from local.get_layer_from_string import get_layer
class StreamingGridNetV2Block(nn.Module):
    """Streaming-causal GridNetV2 block.

    Differences from the offline ``GridNetV2Block``:

    * ``inter_rnn`` (the time-axis LSTM) is unidirectional. ``intra_rnn`` stays
      bidirectional because it operates across frequency, not time.
    * The temporal self-attention is causal: frame ``t`` only attends to frames
      ``<= t``. The mask is applied inside
      ``F.scaled_dot_product_attention(..., is_causal=True)``, so no ``T x T``
      mask tensor is materialised and the module works for any sequence length.

    All other tensor shapes match the offline block, so the encoder, decoder and
    cross-attention weights from the existing USEF-TP checkpoint warm-start
    directly. Only ``inter_rnn`` (bi -> uni) and ``inter_linear`` (input dim
    halved) differ in parameter count.

    Latency note: both RNN paths consume their axis in windows of ``emb_ks``
    steps with hop ``emb_hs``, and every output step inside a window depends on
    the whole window. Along time this means the block is strictly frame-causal
    only when ``emb_ks == 1``; otherwise it has a look-ahead of up to
    ``emb_ks - 1`` frames.

    Args:
        emb_dim: channel dimension ``C`` of the input/output.
        emb_ks: RNN window size, used on both the frequency and time axes.
        emb_hs: RNN window hop, used on both the frequency and time axes.
        n_freqs: number of frequency bins; the attention norm layers are sized
            with it, so it presumably must equal the input's last dimension.
        hidden_channels: LSTM hidden size (per direction).
        n_head: number of attention heads; must divide ``emb_dim``.
        approx_qk_dim: approximate per-head query/key size summed over
            frequency; each head gets ``ceil(approx_qk_dim / n_freqs)`` channels.
        activation: activation name; only ``"prelu"`` is supported.
        eps: epsilon for every normalization layer.

    Raises:
        ValueError: if ``activation != "prelu"`` or ``emb_dim % n_head != 0``.
    """

    def __getitem__(self, key):
        """Return the attribute/submodule named ``key`` (e.g. ``self["attn_conv_Q"]``)."""
        return getattr(self, key)

    def __init__(self, emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels,
                 n_head=4, approx_qk_dim=1024, activation="prelu", eps=1e-5):
        super().__init__()
        # Validate with real exceptions (``assert`` is stripped under ``python -O``)
        # and before any layer is allocated.
        if activation != "prelu":
            raise ValueError(
                f"only activation='prelu' is supported, got {activation!r}"
            )
        if emb_dim % n_head != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_head ({n_head})"
            )

        # NOTE: submodules are registered in the same order and under the same
        # names as before so state_dict keys and parameter order are unchanged.
        in_channels = emb_dim * emb_ks

        # Frequency-axis RNN: bidirectional (frequency is not the streaming axis).
        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        if emb_ks == emb_hs:
            self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)
        else:
            self.intra_linear = nn.ConvTranspose1d(
                hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
            )

        # Time-axis RNN: unidirectional, hence half the projection input width.
        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=False
        )
        if emb_ks == emb_hs:
            self.inter_linear = nn.Linear(hidden_channels, in_channels)
        else:
            self.inter_linear = nn.ConvTranspose1d(
                hidden_channels, emb_dim, emb_ks, stride=emb_hs
            )

        # Per-head query/key channels, chosen so E * n_freqs ~= approx_qk_dim.
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1))
        self.add_module(
            "attn_norm_Q",
            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
        )
        self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1))
        self.add_module(
            "attn_norm_K",
            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
        )
        self.add_module(
            "attn_conv_V", nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1)
        )
        self.add_module(
            "attn_norm_V",
            AllHeadPReLULayerNormalization4DCF(
                (n_head, emb_dim // n_head, n_freqs), eps=eps
            ),
        )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _padded_length(self, n):
        """Length ``n`` is zero-padded to so ``emb_ks``-windows with hop ``emb_hs`` tile it exactly."""
        olp = self.emb_ks - self.emb_hs
        return (
            math.ceil((n + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )

    def _windowed_rnn(self, x, norm, rnn, linear):
        """Run ``rnn`` along dim 2 of ``x`` in ``emb_ks``-step windows and add a residual.

        Shared by the frequency (intra) and time (inter) paths.

        Args:
            x: [B, N, L, C]. ``N`` is folded into the batch; ``L`` is the axis the
                RNN runs over and must already be padded (see ``_padded_length``).
            norm: LayerNorm over ``C``.
            rnn: batch-first LSTM over windows.
            linear: ``nn.Linear`` (non-overlapping windows) or
                ``nn.ConvTranspose1d`` (overlapping windows) mapping back to ``C``.

        Returns:
            [B, N, L, C]: the projected RNN output plus ``x``.
        """
        B, N, L, C = x.shape
        y = norm(x)
        if self.emb_ks == self.emb_hs:
            # Non-overlapping windows: fold each window of emb_ks steps into features.
            y = y.view([B * N, -1, self.emb_ks * C])
            y, _ = rnn(y)
            y = linear(y)
            y = y.view([B, N, L, C])
        else:
            # Overlapping windows: unfold -> RNN -> ConvTranspose1d back to length L.
            y = y.view([B * N, L, C]).transpose(1, 2)  # [B*N, C, L]
            y = F.unfold(
                y[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
            )  # [B*N, C*emb_ks, n_windows]
            y, _ = rnn(y.transpose(1, 2))  # [B*N, n_windows, H]
            y = linear(y.transpose(1, 2))  # [B*N, C, L]
            y = y.view([B, N, C, L]).transpose(-2, -1)  # [B, N, L, C]
        return y + x

    def _causal_temporal_attention(self, x):
        """Multi-head causal self-attention over time; each frame's tokens span all frequencies.

        Args:
            x: [B, C, T, Q]

        Returns:
            [B, C, T, Q]: output of ``attn_concat_proj`` (the caller adds the residual).
        """
        B, _, old_T, old_Q = x.shape
        q = self["attn_norm_Q"](self["attn_conv_Q"](x))
        k = self["attn_norm_K"](self["attn_conv_K"](x))
        v = self["attn_norm_V"](self["attn_conv_V"](x))
        # NOTE(review): the norm layers are assumed to return [B, n_head, D, T, Q];
        # merging the first two dims yields B' = B * n_head independent heads.
        q = q.view(-1, *q.shape[2:])
        k = k.view(-1, *k.shape[2:])
        v = v.view(-1, *v.shape[2:])
        q = q.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        k = k.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        v = v.transpose(1, 2)  # [B', T, D, Q]
        v_shape = v.shape
        v = v.flatten(start_dim=2)  # [B', T, D*Q]
        # SDPA's default scale is 1/sqrt(q.size(-1)) = 1/sqrt(E*Q), and
        # is_causal=True applies the lower-triangular mask internally.
        # NOTE(review): the per-head q/k size here is E*Q (>= approx_qk_dim) and
        # differs from the v size, which likely exceeds the fused flash kernel's
        # head-dim limits, so SDPA may fall back to another backend — confirm on
        # the target GPU before relying on flash-attention speed.
        v = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # [B', T, D*Q]
        v = v.reshape(v_shape).transpose(1, 2)  # [B', D, T, Q]
        head_dim = v.shape[1]
        out = v.contiguous().view([B, self.n_head * head_dim, old_T, old_Q])
        return self["attn_concat_proj"](out)

    def forward(self, x):
        """Forward pass for full-sequence (training) mode.

        Args:
            x: [B, C, T, Q]

        Returns:
            out: [B, C, T, Q]
        """
        _, _, old_T, old_Q = x.shape
        olp = self.emb_ks - self.emb_hs
        T = self._padded_length(old_T)
        Q = self._padded_length(old_Q)

        # Zero-pad both axes: ``olp`` steps in front, the remainder at the end.
        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]
        x = F.pad(x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp))  # [B, T, Q, C]

        # Intra RNN over frequency (bidirectional).
        intra_rnn = self._windowed_rnn(
            x, self.intra_norm, self.intra_rnn, self.intra_linear
        )  # [B, T, Q, C]

        # Inter RNN over time (unidirectional; causal up to the window look-ahead).
        inter_rnn = self._windowed_rnn(
            intra_rnn.transpose(1, 2), self.inter_norm, self.inter_rnn, self.inter_linear
        )  # [B, Q, T, C]
        inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # [B, C, T, Q]
        # Crop the padding back off.
        inter_rnn = inter_rnn[..., olp : olp + old_T, olp : olp + old_Q]

        out = self._causal_temporal_attention(inter_rnn) + inter_rnn
        return out