| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from local.AllHeadPReLULayerNormalization4DCF import AllHeadPReLULayerNormalization4DCF |
| from local.LayerNormalization4DCF import LayerNormalization4DCF |
| from local.get_layer_from_string import get_layer |
|
|
|
|
class StreamingGridNetV2Block(nn.Module):
    """Streaming-causal GridNetV2 block.

    Differences from the offline ``GridNetV2Block``:

    * ``inter_rnn`` (the time-axis LSTM) is unidirectional.  ``intra_rnn`` stays
      bidirectional because it operates across frequency, not time.
    * The temporal self-attention is causal: frame ``t`` only attends to frames
      ``<= t``.  Causality is requested via ``is_causal=True`` of
      ``F.scaled_dot_product_attention``; no mask tensor is stored on the
      module, so the same module works for any sequence length.

    All other tensor shapes match the offline block, so the encoder, decoder and
    cross-attention weights from the existing USEF-TP checkpoint warm-start
    directly.  Only ``inter_rnn`` (bi -> uni) and ``inter_linear`` (input dim
    halved) differ in parameter count.

    NOTE(review): when ``emb_ks > 1`` every time-axis LSTM step consumes a
    window of ``emb_ks`` frames and its projection writes back to all of them,
    so an output frame can depend on up to ``emb_ks - 1`` later input frames.
    Strict per-frame causality therefore appears to hold only for
    ``emb_ks == emb_hs == 1`` -- confirm against the intended latency budget.
    """

    def __getitem__(self, key):
        """Return the attribute / sub-module called ``key`` (``block["name"]``)."""
        return getattr(self, key)

    def __init__(self, emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels,
                 n_head=4, approx_qk_dim=1024, activation="prelu", eps=1e-5):
        """Build the intra-/inter-RNN paths and the attention sub-modules.

        Args:
            emb_dim: Channel count ``C`` of the input feature map.
            emb_ks: Window (kernel) size used when unfolding the scanned axis.
            emb_hs: Hop (stride) between consecutive windows.
            n_freqs: Size handed to the 4-D normalisation layers; presumably
                the number of frequency bins ``Q`` of the input -- TODO confirm.
            hidden_channels: Hidden size of both LSTMs.
            n_head: Number of attention heads; must divide ``emb_dim``.
            approx_qk_dim: Target query/key width.  The per-head channel count
                is ``ceil(approx_qk_dim / n_freqs)``, hence only approximate.
            activation: Must be ``"prelu"``.
            eps: Epsilon forwarded to every normalisation layer.
        """
        super().__init__()
        assert activation == "prelu", "only the 'prelu' activation is supported"

        in_channels = emb_dim * emb_ks

        def make_projection(rnn_out_dim):
            # Non-overlapping windows can be mapped back with a plain linear
            # layer; overlapping windows need a transposed conv (overlap-add).
            if emb_ks == emb_hs:
                return nn.Linear(rnn_out_dim, in_channels)
            return nn.ConvTranspose1d(rnn_out_dim, emb_dim, emb_ks, stride=emb_hs)

        # Frequency-axis path: bidirectional, since no time is crossed.
        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        self.intra_linear = make_projection(hidden_channels * 2)

        # Time-axis path: unidirectional, so the projection input is halved.
        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=False
        )
        self.inter_linear = make_projection(hidden_channels)

        qk_dim = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        assert emb_dim % n_head == 0, "emb_dim must be divisible by n_head"

        # Q, K and V share one recipe: 1x1 conv followed by an all-head norm.
        # The registration order (conv_Q, norm_Q, conv_K, norm_K, conv_V,
        # norm_V) is kept so state-dict layout and seeded init are unchanged.
        for tag, per_head in (("Q", qk_dim), ("K", qk_dim),
                              ("V", emb_dim // n_head)):
            self.add_module(
                "attn_conv_" + tag, nn.Conv2d(emb_dim, n_head * per_head, 1)
            )
            self.add_module(
                "attn_norm_" + tag,
                AllHeadPReLULayerNormalization4DCF(
                    (n_head, per_head, n_freqs), eps=eps
                ),
            )

        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _rnn_path(self, feats, norm, rnn, proj):
        """Run norm -> windowing -> LSTM -> projection with a residual add.

        Args:
            feats: ``[B, A, L, C]``; the LSTM scans axis ``L`` independently
                for every ``(B, A)`` pair.
            norm: Layer norm over the channel axis.
            rnn: LSTM fed ``emb_ks * C`` features per step.
            proj: Maps the LSTM output back to ``C`` channels per position.
        Returns:
            ``[B, A, L, C]`` tensor, ``feats`` plus the path output.
        """
        B, A, L, C = feats.shape
        hidden = norm(feats)
        # ``reshape`` equals ``view`` for contiguous tensors but does not
        # depend on the norm layer returning a contiguous result.
        if self.emb_ks == self.emb_hs:
            hidden = hidden.reshape(B * A, -1, self.emb_ks * C)  # [BA, L//ks, ks*C]
            hidden, _ = rnn(hidden)
            hidden = proj(hidden)                                # [BA, L//ks, ks*C]
            hidden = hidden.reshape(B, A, L, C)
        else:
            hidden = hidden.reshape(B * A, L, C).transpose(1, 2)  # [BA, C, L]
            hidden = F.unfold(
                hidden[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)
            )                                                    # [BA, C*ks, n_win]
            hidden, _ = rnn(hidden.transpose(1, 2))              # [BA, n_win, H]
            hidden = proj(hidden.transpose(1, 2))                # [BA, C, L]
            hidden = hidden.reshape(B, A, C, L).transpose(-2, -1)
        return hidden + feats

    def _causal_attention(self, feats):
        """Apply causal multi-head self-attention along the time axis.

        Args:
            feats: ``[B, C, T, Q]``
        Returns:
            ``[B, C, T, Q]`` output of ``attn_concat_proj``.
        """
        B, _, n_frames, n_bins = feats.shape

        # Norm layers presumably return [B, n_head, D, T, Q] -- TODO confirm.
        queries = self.attn_norm_Q(self.attn_conv_Q(feats))
        keys = self.attn_norm_K(self.attn_conv_K(feats))
        values = self.attn_norm_V(self.attn_conv_V(feats))

        # Fold the head axis into the batch axis: [B * n_head, D, T, Q].
        queries = queries.reshape(-1, *queries.shape[2:])
        keys = keys.reshape(-1, *keys.shape[2:])
        values = values.reshape(-1, *values.shape[2:])

        # One token per frame; channels and bins form the feature vector.
        queries = queries.transpose(1, 2).flatten(start_dim=2)  # [B', T, D*Q]
        keys = keys.transpose(1, 2).flatten(start_dim=2)        # [B', T, D*Q]
        values = values.transpose(1, 2)                         # [B', T, Dv, Q]
        value_shape = values.shape
        values = values.flatten(start_dim=2)                    # [B', T, Dv*Q]

        # ``is_causal=True`` hides every frame later than the query frame.
        context = F.scaled_dot_product_attention(
            queries, keys, values, is_causal=True
        )

        context = context.reshape(value_shape).transpose(1, 2)  # [B', Dv, T, Q]
        head_dim = context.shape[1]
        context = context.reshape(B, self.n_head * head_dim, n_frames, n_bins)
        return self.attn_concat_proj(context)

    def forward(self, x):
        """Forward pass for full-sequence (training) mode.

        Args:
            x: [B, C, T, Q]
        Returns:
            out: [B, C, T, Q]
        """
        _, _, old_T, old_Q = x.shape

        # Pad both scanned axes so an integer number of windows fits.
        olp = self.emb_ks - self.emb_hs
        T = (
            math.ceil((old_T + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )
        Q = (
            math.ceil((old_Q + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )

        x = x.permute(0, 2, 3, 1)                                 # [B, old_T, old_Q, C]
        x = F.pad(x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp))  # [B, T, Q, C]

        # Intra path: bidirectional LSTM along frequency for every frame.
        intra = self._rnn_path(x, self.intra_norm, self.intra_rnn, self.intra_linear)
        intra = intra.transpose(1, 2)                             # [B, Q, T, C]

        # Inter path: unidirectional LSTM along time for every frequency bin.
        inter = self._rnn_path(
            intra, self.inter_norm, self.inter_rnn, self.inter_linear
        )
        inter = inter.permute(0, 3, 2, 1)                         # [B, C, T, Q]
        inter = inter[..., olp : olp + old_T, olp : olp + old_Q]  # drop the padding

        # Causal attention over time with a residual connection.
        return self._causal_attention(inter) + inter
|
|