File size: 7,797 Bytes
eefb734 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 | import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from local.AllHeadPReLULayerNormalization4DCF import AllHeadPReLULayerNormalization4DCF
from local.LayerNormalization4DCF import LayerNormalization4DCF
from local.get_layer_from_string import get_layer
class StreamingGridNetV2Block(nn.Module):
    """Streaming-oriented GridNetV2 block with a causal time axis.

    Differences from the offline ``GridNetV2Block``:

    * ``inter_rnn`` (the time-axis LSTM) is unidirectional. ``intra_rnn`` stays
      bidirectional because it operates across frequency, not time.
    * The temporal self-attention is causal: frame ``t`` only attends to frames
      ``<= t``. The mask is applied by ``F.scaled_dot_product_attention`` with
      ``is_causal=True``, so the same module works for any sequence length.

    Causality caveat: the time-axis path is strictly causal only when
    ``emb_ks == 1``. For ``emb_ks > 1`` each ``inter_rnn`` step consumes a window
    of ``emb_ks`` frames (``F.unfold`` when ``emb_ks != emb_hs``, grouping of
    consecutive frames when ``emb_ks == emb_hs``), and ``inter_linear`` writes
    that step's output back over the whole window. An output frame can
    therefore depend on up to ``emb_ks - 1`` future input frames, which a
    streaming runtime has to buffer as algorithmic lookahead.

    All other tensor shapes match the offline block, so the encoder, decoder and
    cross-attention weights from the existing USEF-TP checkpoint warm-start
    directly. Only ``inter_rnn`` (bi -> uni) and ``inter_linear`` (input dim
    halved) differ in parameter count.

    Args:
        emb_dim: channel count ``C`` of the input/output feature map.
        emb_ks: window (kernel) size used by both RNN sub-layers.
        emb_hs: hop size of those windows.
        n_freqs: frequency size used to shape the attention normalization
            layers; presumably equal to the input's ``Q`` (confirm with caller).
        hidden_channels: LSTM hidden size per direction.
        n_head: number of attention heads; must divide ``emb_dim``.
        approx_qk_dim: target flattened query/key size per head; the per-head
            query/key channel count is ``ceil(approx_qk_dim / n_freqs)``.
        activation: only ``"prelu"`` is supported.
        eps: epsilon passed to every normalization layer.
    """

    def __getitem__(self, key):
        """Allow ``self["attn_conv_Q"]``-style access (delegates to ``getattr``)."""
        return getattr(self, key)

    def __init__(self, emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels,
                 n_head=4, approx_qk_dim=1024, activation="prelu", eps=1e-5):
        super().__init__()
        assert activation == "prelu"
        in_channels = emb_dim * emb_ks

        # Frequency-axis sub-layer (bidirectional LSTM).
        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        if emb_ks == emb_hs:
            self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)
        else:
            self.intra_linear = nn.ConvTranspose1d(
                hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
            )

        # Time-axis sub-layer (unidirectional LSTM, hence hidden_channels and
        # not hidden_channels * 2 going into inter_linear).
        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=False
        )
        if emb_ks == emb_hs:
            self.inter_linear = nn.Linear(hidden_channels, in_channels)
        else:
            self.inter_linear = nn.ConvTranspose1d(
                hidden_channels, emb_dim, emb_ks, stride=emb_hs
            )

        # Attention projections: E channels per head for Q/K so that the
        # flattened per-head feature size (E * n_freqs) is close to approx_qk_dim.
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0
        self.add_module("attn_conv_Q", nn.Conv2d(emb_dim, n_head * E, 1))
        self.add_module(
            "attn_norm_Q",
            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
        )
        self.add_module("attn_conv_K", nn.Conv2d(emb_dim, n_head * E, 1))
        self.add_module(
            "attn_norm_K",
            AllHeadPReLULayerNormalization4DCF((n_head, E, n_freqs), eps=eps),
        )
        self.add_module(
            "attn_conv_V", nn.Conv2d(emb_dim, n_head * emb_dim // n_head, 1)
        )
        self.add_module(
            "attn_norm_V",
            AllHeadPReLULayerNormalization4DCF(
                (n_head, emb_dim // n_head, n_freqs), eps=eps
            ),
        )
        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _padded_length(self, n):
        """Return the padded size of an axis of length ``n``.

        The result is at least ``n + 2 * (emb_ks - emb_hs)``. It is also chosen
        so that windows of ``emb_ks`` with hop ``emb_hs`` tile it exactly.
        """
        olp = self.emb_ks - self.emb_hs
        return (
            math.ceil((n + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )

    def _rnn_sublayer(self, x, norm, rnn, linear):
        """Pre-norm RNN sub-layer with a residual connection, run along axis 2.

        Args:
            x: [B, N, L, C] tensor. The RNN runs over ``L`` independently for
                each of the ``B * N`` sequences.
            norm: LayerNorm over ``C``.
            rnn: the sub-layer's LSTM.
            linear: output projection (``nn.Linear`` when ``emb_ks == emb_hs``,
                otherwise ``nn.ConvTranspose1d``).

        Returns:
            [B, N, L, C] tensor ``linear(rnn(norm(x))) + x``.
        """
        B, N, L, C = x.shape
        out = norm(x)  # [B, N, L, C]
        if self.emb_ks == self.emb_hs:
            # Non-overlapping windows: emb_ks consecutive positions become one
            # RNN step of size emb_ks * C.
            out = out.view([B * N, -1, self.emb_ks * C])
            out, _ = rnn(out)
            out = linear(out)
            out = out.view([B, N, L, C])
        else:
            # Overlapping windows: unfold emb_ks-long windows with hop emb_hs,
            # then map them back onto the L axis with the transposed conv.
            out = out.view([B * N, L, C]).transpose(1, 2)  # [B*N, C, L]
            out = F.unfold(
                out[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
            )  # [B*N, C*emb_ks, n_windows]
            out, _ = rnn(out.transpose(1, 2))  # [B*N, n_windows, hidden*dirs]
            out = linear(out.transpose(1, 2))  # [B*N, C, L]
            out = out.view([B, N, C, L]).transpose(-2, -1)  # [B, N, L, C]
        return out + x

    def _causal_attention(self, x):
        """Multi-head causal self-attention over time, without the residual.

        Args:
            x: [B, C, T, Q] tensor.

        Returns:
            [B, C, T, Q] output of ``attn_concat_proj``.
        """
        B, _, T, Q = x.shape
        q = self["attn_norm_Q"](self["attn_conv_Q"](x))
        k = self["attn_norm_K"](self["attn_conv_K"](x))
        v = self["attn_norm_V"](self["attn_conv_V"](x))
        # NOTE(review): the reshapes below assume the attn_norm_* layers return
        # 5-D [B, n_head, E, T, Q] tensors. Confirm against
        # AllHeadPReLULayerNormalization4DCF.
        q = q.view(-1, *q.shape[2:]).transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        k = k.view(-1, *k.shape[2:]).transpose(1, 2).flatten(start_dim=2)  # [B', T, E*Q]
        v = v.view(-1, *v.shape[2:]).transpose(1, 2)  # [B', T, C/n_head, Q]
        v_shape = v.shape
        # is_causal=True restricts frame t to keys <= t. The default scale is
        # 1 / sqrt(q.size(-1)).
        # NOTE(review): PyTorch's fused flash / memory-efficient kernels
        # generally require 4-D inputs and matching head sizes. With these
        # 3-D tensors and differing q/v sizes, the math backend is likely
        # used. Profile before counting on a fused-kernel speedup.
        out = F.scaled_dot_product_attention(
            q, k, v.flatten(start_dim=2), is_causal=True
        )  # [B', T, C/n_head * Q]
        out = out.reshape(v_shape).transpose(1, 2)  # [B', C/n_head, T, Q]
        head_dim = out.shape[1]
        out = out.contiguous().view([B, self.n_head * head_dim, T, Q])
        return self["attn_concat_proj"](out)

    def forward(self, x):
        """Forward pass for full-sequence (training) mode.

        Args:
            x: [B, C, T, Q] feature map with ``C == emb_dim``.

        Returns:
            out: [B, C, T, Q], same shape as ``x``.
        """
        old_T, old_Q = x.shape[-2:]
        olp = self.emb_ks - self.emb_hs
        T = self._padded_length(old_T)
        Q = self._padded_length(old_Q)

        # Zero-pad T and Q so the (emb_ks, emb_hs) windows tile them exactly.
        # olp positions go in front and the remainder at the end.
        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]
        x = F.pad(x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp))  # [B, T, Q, C]

        # intra RNN: along frequency, independently per frame (bidirectional;
        # frequency is not the streaming axis).
        x = self._rnn_sublayer(
            x, self.intra_norm, self.intra_rnn, self.intra_linear
        )  # [B, T, Q, C]

        # inter RNN: along time, independently per frequency bin
        # (unidirectional; see the class docstring for the emb_ks lookahead).
        x = x.transpose(1, 2)  # [B, Q, T, C]
        x = self._rnn_sublayer(
            x, self.inter_norm, self.inter_rnn, self.inter_linear
        )  # [B, Q, T, C]

        # Back to [B, C, T, Q] and crop away the padding.
        x = x.permute(0, 3, 2, 1)[..., olp : olp + old_T, olp : olp + old_Q]

        return self._causal_attention(x) + x
|