File size: 7,797 Bytes
eefb734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from local.AllHeadPReLULayerNormalization4DCF import AllHeadPReLULayerNormalization4DCF
from local.LayerNormalization4DCF import LayerNormalization4DCF
from local.get_layer_from_string import get_layer


class StreamingGridNetV2Block(nn.Module):
    """Streaming-causal GridNetV2 block.

    Differences from the offline ``GridNetV2Block``:

    * ``inter_rnn`` (the time-axis LSTM) is unidirectional. ``intra_rnn`` stays
      bidirectional because it operates across frequency, not time.
    * The temporal self-attention is causal, so frame ``t`` only attends to
      frames ``<= t``. Causality is requested through ``is_causal=True`` of
      ``F.scaled_dot_product_attention``, so the same module works for any
      sequence length.

    All other tensor shapes match the offline block, so the encoder, decoder and
    cross-attention weights from the existing USEF-TP checkpoint warm-start
    directly. Only ``inter_rnn`` (bi -> uni) and ``inter_linear`` (input dim
    halved) differ in parameter count.

    NOTE(review): the time axis is padded with ``emb_ks - emb_hs`` frames on
    *both* sides and the inter-RNN consumes ``emb_ks``-frame windows, so when
    ``emb_ks > 1`` the inter-RNN path appears to see a few future frames
    (up to ``emb_ks - emb_hs`` when the windows overlap). Strict per-frame
    causality of that path looks guaranteed only for ``emb_ks == 1`` — confirm
    against the intended streaming latency.

    Args:
        emb_dim: Number of embedding channels ``C`` of the input.
        emb_ks: Window (kernel) size used to group frames/bins for the LSTMs.
        emb_hs: Hop (stride) between consecutive windows.
        n_freqs: Number of frequency bins ``Q`` the attention norms expect.
        hidden_channels: Hidden size of both LSTMs.
        n_head: Number of attention heads; must divide ``emb_dim``.
        approx_qk_dim: Target size of the flattened query/key feature; the
            per-head channel count is ``ceil(approx_qk_dim / n_freqs)``.
        activation: Activation name; only ``"prelu"`` is supported.
        eps: Epsilon forwarded to every normalization layer.

    Raises:
        ValueError: If ``activation`` is not ``"prelu"`` or ``emb_dim`` is not
            divisible by ``n_head``.
    """

    def __getitem__(self, key):
        """Return the attribute or sub-module registered under ``key``."""
        return getattr(self, key)

    def __init__(self, emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels,
                 n_head=4, approx_qk_dim=1024, activation="prelu", eps=1e-5):
        super().__init__()

        # Explicit exceptions instead of ``assert``: asserts are removed under
        # ``python -O``, which would let a mis-configured block be built and
        # fail much later with an obscure shape error.
        if activation != "prelu":
            raise ValueError(
                f"Only activation='prelu' is supported, got {activation!r}"
            )
        if emb_dim % n_head != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by n_head ({n_head})"
            )

        in_channels = emb_dim * emb_ks
        head_dim = emb_dim // n_head

        # Intra-frame path: bidirectional LSTM along the frequency axis.
        self.intra_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=True
        )
        if emb_ks == emb_hs:
            # Non-overlapping windows: a Linear maps each window straight back.
            self.intra_linear = nn.Linear(hidden_channels * 2, in_channels)
        else:
            # Overlapping windows: ConvTranspose1d overlap-adds them back.
            self.intra_linear = nn.ConvTranspose1d(
                hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs
            )

        # Inter-frame path: unidirectional LSTM along the time axis.
        self.inter_norm = nn.LayerNorm(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(
            in_channels, hidden_channels, 1, batch_first=True, bidirectional=False
        )
        if emb_ks == emb_hs:
            self.inter_linear = nn.Linear(hidden_channels, in_channels)
        else:
            self.inter_linear = nn.ConvTranspose1d(
                hidden_channels, emb_dim, emb_ks, stride=emb_hs
            )

        # Per-head channel count for queries/keys.
        E = math.ceil(approx_qk_dim * 1.0 / n_freqs)

        # Registration order (conv_Q, norm_Q, conv_K, norm_K, conv_V, norm_V)
        # is kept so parameter ordering and seeded initialisation are unchanged.
        for name, channels in (("Q", E), ("K", E), ("V", head_dim)):
            self.add_module(
                f"attn_conv_{name}", nn.Conv2d(emb_dim, n_head * channels, 1)
            )
            self.add_module(
                f"attn_norm_{name}",
                AllHeadPReLULayerNormalization4DCF(
                    (n_head, channels, n_freqs), eps=eps
                ),
            )

        self.add_module(
            "attn_concat_proj",
            nn.Sequential(
                nn.Conv2d(emb_dim, emb_dim, 1),
                get_layer(activation)(),
                LayerNormalization4DCF((emb_dim, n_freqs), eps=eps),
            ),
        )

        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs
        self.n_head = n_head

    def _padded_len(self, length, olp):
        """Return ``length`` plus ``olp`` padding per side, rounded up to fit whole windows."""
        return (
            math.ceil((length + 2 * olp - self.emb_ks) / self.emb_hs) * self.emb_hs
            + self.emb_ks
        )

    def _rnn_pass(self, x, norm, rnn, linear):
        """Run norm -> windowed LSTM -> projection along axis 2, with a residual.

        Args:
            x: [B, D1, D2, C]; the LSTM runs along ``D2`` with ``B * D1`` as batch.
            norm: Normalization applied over the last (channel) dimension.
            rnn: LSTM consuming ``emb_ks * C`` features per step.
            linear: ``nn.Linear`` (``emb_ks == emb_hs``) or ``nn.ConvTranspose1d``
                mapping the LSTM output back to ``C`` channels per position.
        Returns:
            Tensor of shape [B, D1, D2, C] equal to ``x`` plus the RNN branch.
        """
        B, D1, D2, C = x.shape
        h = norm(x)  # [B, D1, D2, C]
        if self.emb_ks == self.emb_hs:
            # Group emb_ks consecutive positions into one LSTM step.
            h = h.view([B * D1, -1, self.emb_ks * C])
            h, _ = rnn(h)
            h = linear(h)
            h = h.view([B, D1, D2, C])
        else:
            h = h.view([B * D1, D2, C])
            h = h.transpose(1, 2)  # [B*D1, C, D2]
            # Sliding windows of emb_ks positions with hop emb_hs.
            h = F.unfold(h[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1))
            h = h.transpose(1, 2)  # [B*D1, n_windows, C*emb_ks]
            h, _ = rnn(h)
            h = h.transpose(1, 2)  # [B*D1, H, n_windows]
            h = linear(h)  # overlap-add back to D2 positions
            h = h.view([B, D1, C, D2])
            h = h.transpose(-2, -1)  # [B, D1, D2, C]
        return h + x

    def _causal_time_attention(self, feats):
        """Apply multi-head causal self-attention over the time axis.

        Args:
            feats: [B, C, T, Q]
        Returns:
            Output of ``attn_concat_proj``, [B, C, T, Q] (no residual added).
        """
        B, _, T, Q = feats.shape

        # Norm outputs are presumably [B, n_head, C', T, Q]; the view below
        # merges the first two dimensions into B' = B * n_head.
        q = self["attn_norm_Q"](self["attn_conv_Q"](feats))
        k = self["attn_norm_K"](self["attn_conv_K"](feats))
        v = self["attn_norm_V"](self["attn_conv_V"](feats))
        q = q.view(-1, *q.shape[2:])
        k = k.view(-1, *k.shape[2:])
        v = v.view(-1, *v.shape[2:])

        q = q.transpose(1, 2).flatten(start_dim=2)  # [B', T, C*Q]
        k = k.transpose(1, 2).flatten(start_dim=2)  # [B', T, C*Q]
        v = v.transpose(1, 2)                       # [B', T, C, Q]
        v_shape = v.shape
        v = v.flatten(start_dim=2)                  # [B', T, C*Q]

        # is_causal=True makes SDPA apply the lower-triangular mask itself, so
        # no T x T mask tensor is built here. SDPA picks its own backend and
        # falls back to the plain math implementation when a fused kernel's
        # constraints are not met.
        # NOTE(review): inputs are 3-D and the Q/K feature width differs from
        # V's — confirm with a profiler whether a fused kernel is selected.
        v = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # [B', T, C*Q]

        v = v.reshape(v_shape)  # [B', T, C, Q]
        v = v.transpose(1, 2)   # [B', C, T, Q]
        head_dim = v.shape[1]

        # Fold the heads back into the channel dimension.
        merged = v.contiguous().view([B, self.n_head * head_dim, T, Q])
        return self["attn_concat_proj"](merged)

    def forward(self, x):
        """Forward pass for full-sequence (training) mode.

        Args:
            x: [B, C, T, Q]
        Returns:
            out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape

        # Overlap between consecutive windows; both axes get olp frames of
        # padding in front and enough at the back to fit whole windows.
        olp = self.emb_ks - self.emb_hs
        T = self._padded_len(old_T, olp)
        Q = self._padded_len(old_Q, olp)

        x = x.permute(0, 2, 3, 1)  # [B, old_T, old_Q, C]
        x = F.pad(x, (0, 0, olp, Q - old_Q - olp, olp, T - old_T - olp))  # [B, T, Q, C]

        # intra RNN (over frequency, bidirectional — frequency is not the streaming axis)
        intra_rnn = self._rnn_pass(
            x, self.intra_norm, self.intra_rnn, self.intra_linear
        )  # [B, T, Q, C]
        intra_rnn = intra_rnn.transpose(1, 2)  # [B, Q, T, C]

        # inter RNN (over time, unidirectional)
        inter_rnn = self._rnn_pass(
            intra_rnn, self.inter_norm, self.inter_rnn, self.inter_linear
        )  # [B, Q, T, C]

        inter_rnn = inter_rnn.permute(0, 3, 2, 1)  # [B, C, T, Q]
        # Drop the padding added above.
        inter_rnn = inter_rnn[..., olp : olp + old_T, olp : olp + old_Q]

        # Causal self-attention over time, added residually.
        out = self._causal_time_attention(inter_rnn) + inter_rnn
        return out