AlgoX committed
Commit f6911e7 · 1 Parent(s): 995292c

feat: add mLSTM and sLSTM blocks

Files changed (1):
  model/xlstm.py  +290 -0

model/xlstm.py  ADDED
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

def get_model_device(model):
    return next(iter(model.parameters())).device

class MLSTMCell(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int = 8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.eps = 1e-6

        self.igate_proj = nn.Linear(3 * hidden_size, num_heads, bias=True)
        self.fgate_proj = nn.Linear(3 * hidden_size, num_heads, bias=True)
        self.outnorm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, state):
        batch_size, hidden_size = q.shape

        cell_state, norm_state, max_state = state

        qkv_cat = torch.cat([q, k, v], dim=-1)
        igate_preact = self.igate_proj(qkv_cat)
        fgate_preact = self.fgate_proj(qkv_cat)

        q = q.view(batch_size, self.num_heads, self.head_size)
        k = k.view(batch_size, self.num_heads, self.head_size)
        v = v.view(batch_size, self.num_heads, self.head_size)

        # Stabilize the exponential gates with the running max state m
        log_f = torch.nn.functional.logsigmoid(fgate_preact)

        max_new = torch.maximum(igate_preact, max_state + log_f)

        i_gate = torch.exp(igate_preact - max_new)
        f_gate = torch.exp(log_f + max_state - max_new)

        # Scale keys
        k = k / math.sqrt(self.head_size)

        # Update matrix memory and normalizer
        # C_new = f * C + i * k^T * v
        cell_new = (
            f_gate[:, :, None, None] * cell_state
            + i_gate[:, :, None, None] * k[:, :, :, None] * v[:, :, None]
        )
        # n_new = f * n + i * k
        norm_new = f_gate[:, :, None] * norm_state + i_gate[:, :, None] * k

        # Compute output: h = (q @ C) / (max(|q @ n|, exp(-m)) + eps)
        numerator = torch.einsum("bnh,bnhk->bnk", q, cell_new)
        qn_dotproduct = torch.einsum("bnh,bnh->bn", q, norm_new)
        max_val = torch.exp(-max_new)
        denominator = torch.maximum(qn_dotproduct.abs(), max_val) + self.eps
        out = numerator / denominator[:, :, None]

        out = self.outnorm(out.view(batch_size, self.hidden_size))

        out = out.reshape(batch_size, self.hidden_size)

        assert cell_new.shape == cell_state.shape
        assert norm_new.shape == norm_state.shape
        assert max_new.shape == max_state.shape

        return out, (cell_new, norm_new, max_new)

    def init_state(self, batch_size: int, device: torch.device):
        return (
            torch.zeros(
                batch_size,
                self.num_heads,
                self.head_size,
                self.head_size,
                device=device,
            ),
            torch.zeros(batch_size, self.num_heads, self.head_size, device=device),
            torch.zeros(batch_size, self.num_heads, device=device),
        )

class CausalConv1d(nn.Module):
    # Depthwise 1D convolution applied one step at a time; `state` caches the
    # previous kernel_size - 1 inputs so the convolution stays causal.

    def __init__(self, hidden_size, kernel_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
        )

    def init_state(self, batch_size: int, device: torch.device | None = None):
        if device is None:
            device = get_model_device(self)
        return torch.zeros(
            batch_size, self.hidden_size, self.kernel_size - 1, device=device
        )

    def forward(self, x: torch.Tensor, state: torch.Tensor):
        x_with_state = torch.concat([state, x[:, :, None]], dim=-1)
        out = self.conv(x_with_state)
        new_state = x_with_state[:, :, 1:]
        return out.squeeze(-1), new_state

class BlockLinear(nn.Module):
    # Block-diagonal linear layer: the input is split into `num_blocks` chunks
    # and each chunk is mapped by its own (block_size x block_size) weight.

    def __init__(self, num_blocks: int, hidden_size: int, bias: bool = True):
        super().__init__()
        self.num_blocks = num_blocks
        self.block_size = hidden_size // num_blocks
        self.hidden_size = hidden_size
        # Initialize explicitly; torch.empty would leave the parameters undefined.
        self.weight = nn.Parameter(
            torch.randn(num_blocks, self.block_size, self.block_size)
            / math.sqrt(self.block_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.zeros(self.hidden_size))
        else:
            self.bias = None

    def forward(self, x):
        batch_size = x.shape[0]
        assert x.shape[1] == self.hidden_size
        x = x.view(batch_size, self.num_blocks, self.block_size)
        out = torch.einsum("bnh,nkh->bnk", x, self.weight)
        out = out.reshape(batch_size, self.hidden_size)
        if self.bias is not None:
            out += self.bias
        return out

class MLSTMBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int = 8,
        conv_kernel_size: int = 4,
        qkv_proj_block_size: int = 4,
        expand_factor: int = 2,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.inner_size = expand_factor * hidden_size

        self.norm = nn.LayerNorm(hidden_size, bias=False)

        self.x_proj = nn.Linear(hidden_size, self.inner_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, self.inner_size, bias=False)

        num_blocks = self.inner_size // qkv_proj_block_size
        self.q_proj = BlockLinear(num_blocks, self.inner_size, bias=False)
        self.k_proj = BlockLinear(num_blocks, self.inner_size, bias=False)
        self.v_proj = BlockLinear(num_blocks, self.inner_size, bias=False)

        self.conv1d = CausalConv1d(self.inner_size, kernel_size=conv_kernel_size)

        self.mlstm_cell = MLSTMCell(self.inner_size, num_heads)
        self.proj_down = nn.Linear(self.inner_size, hidden_size, bias=False)
        self.learnable_skip = nn.Parameter(torch.ones(self.inner_size))

        self.head_size = self.inner_size // num_heads

    def forward(self, x: torch.Tensor, state):
        conv_state, recurrent_state = state

        skip = x

        x = self.norm(x)
        x_mlstm = self.x_proj(x)
        x_gate = self.gate_proj(x)

        # q and k come from the causally convolved branch, v from the raw up-projection
        x_conv, new_conv_state = self.conv1d(x_mlstm, conv_state)
        x_mlstm_conv = F.silu(x_conv)

        q = self.q_proj(x_mlstm_conv)
        k = self.k_proj(x_mlstm_conv)
        v = self.v_proj(x_mlstm)

        mlstm_out, new_recurrent_state = self.mlstm_cell(q, k, v, recurrent_state)

        mlstm_out_skip = mlstm_out + (self.learnable_skip * x_mlstm_conv)
        h_state = mlstm_out_skip * F.silu(x_gate)
        y = self.proj_down(h_state)

        return y + skip, (new_conv_state, new_recurrent_state)

    def init_state(self, batch_size: int, device: torch.device):
        return (
            self.conv1d.init_state(batch_size, device),
            self.mlstm_cell.init_state(batch_size, device),
        )

class SLSTMCell(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int = 4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.eps = 1e-6

    def forward(
        self,
        i: torch.Tensor,
        f: torch.Tensor,
        z: torch.Tensor,
        o: torch.Tensor,
        state,
    ):
        cell_state, norm_state, max_state = state

        log_f_plus_m = max_state + torch.nn.functional.logsigmoid(f)

        # Running max stabilizer for the exponential input/forget gates
        max_new = torch.maximum(i, log_f_plus_m)

        # Compute stabilized exponential gates
        o_gate = torch.sigmoid(o)
        i_gate = torch.exp(i - max_new)
        f_gate = torch.exp(log_f_plus_m - max_new)

        cell_new = f_gate * cell_state + i_gate * torch.tanh(z)
        norm_new = f_gate * norm_state + i_gate
        y_new = o_gate * cell_new / (norm_new + self.eps)

        return y_new, (cell_new, norm_new, max_new)

    def init_state(self, batch_size: int, device: torch.device):
        return (
            torch.zeros(batch_size, self.hidden_size, device=device),
            torch.zeros(batch_size, self.hidden_size, device=device),
            torch.zeros(batch_size, self.hidden_size, device=device) - float("inf"),
        )

class SLSTMBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int = 4, conv_kernel_size: int = 4):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.norm = nn.LayerNorm(hidden_size, bias=False)
        self.conv1d = CausalConv1d(hidden_size, kernel_size=conv_kernel_size)
        self.igate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.fgate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.zgate_input = BlockLinear(num_heads, hidden_size, bias=False)
        self.ogate_input = BlockLinear(num_heads, hidden_size, bias=False)

        self.igate_state = BlockLinear(num_heads, hidden_size)
        self.fgate_state = BlockLinear(num_heads, hidden_size)
        self.zgate_state = BlockLinear(num_heads, hidden_size)
        self.ogate_state = BlockLinear(num_heads, hidden_size)

        self.slstm_cell = SLSTMCell(hidden_size, num_heads)
        self.group_norm = nn.GroupNorm(num_groups=num_heads, num_channels=hidden_size)

    def forward(self, x: torch.Tensor, state):
        conv_state, recurrent_state, slstm_state = state

        skip = x
        x = self.norm(x)

        x_conv, new_conv_state = self.conv1d(x, conv_state)
        x_conv_act = F.silu(x_conv)

        i = self.igate_input(x_conv_act) + self.igate_state(recurrent_state)
        f = self.fgate_input(x_conv_act) + self.fgate_state(recurrent_state)
        z = self.zgate_input(x) + self.zgate_state(recurrent_state)
        o = self.ogate_input(x) + self.ogate_state(recurrent_state)

        new_recurrent_state, new_slstm_state = self.slstm_cell(i, f, z, o, slstm_state)
        slstm_out = self.group_norm(new_recurrent_state)

        return slstm_out + skip, (new_conv_state, new_recurrent_state, new_slstm_state)

    def init_state(self, batch_size: int, device: torch.device):
        return (
            self.conv1d.init_state(batch_size, device),
            torch.zeros(batch_size, self.hidden_size, device=device),
            self.slstm_cell.init_state(batch_size, device),
        )

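Below is a minimal usage sketch, not part of the commit, showing how the added blocks can be stepped over a sequence one timestep at a time. It assumes the file above is importable as model.xlstm; the batch size, sequence length, hidden size, and head counts are arbitrary illustrative values.

# Illustrative sketch (not in the commit): drive an MLSTMBlock followed by an
# SLSTMBlock over a sequence of per-timestep embeddings.
import torch
from model.xlstm import MLSTMBlock, SLSTMBlock  # assumes model/ is on the import path

batch_size, seq_len, hidden_size = 2, 16, 64
device = torch.device("cpu")

mlstm = MLSTMBlock(hidden_size, num_heads=8).to(device)
slstm = SLSTMBlock(hidden_size, num_heads=4).to(device)

# Each block carries its own recurrent state tuple
m_state = mlstm.init_state(batch_size, device)
s_state = slstm.init_state(batch_size, device)

x = torch.randn(batch_size, seq_len, hidden_size, device=device)
outputs = []
for t in range(seq_len):
    h = x[:, t]                      # (batch_size, hidden_size) input for step t
    h, m_state = mlstm(h, m_state)   # mLSTM block, residual connection inside
    h, s_state = slstm(h, s_state)   # sLSTM block, residual connection inside
    outputs.append(h)

y = torch.stack(outputs, dim=1)      # (batch_size, seq_len, hidden_size)
print(y.shape)

A full model would presumably interleave several such blocks in a residual stack between an embedding layer and an output head; the init_state/forward pairs above are written for exactly this kind of stepwise (or scanned) recurrence.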