ishanjmukherjee committed
Commit 43539ed · Parent(s): 8182ebe

Copy Python verbatim from vortex

Files changed (10)
  1. attention.py +999 -0
  2. cache.py +62 -0
  3. engine.py +597 -0
  4. generation.py +373 -0
  5. layers.py +272 -0
  6. model.py +937 -0
  7. positional_embeddings.py +114 -0
  8. sample.py +60 -0
  9. special_tokens_map.json +1 -0
  10. utils.py +251 -0
attention.py ADDED
@@ -0,0 +1,999 @@
+ import math
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+
+ from .utils import get_dim_for_local_rank
+
+ # Not bothering with ops right now
+ # try:
+ #     from vortex.ops import (
+ #         local_flash_attn_kvpacked_func,
+ #         local_flash_attn_qkvpacked_func,
+ #         local_flash_attn_varlen_kvpacked_func,
+ #         local_flash_attn_varlen_qkvpacked_func,
+ #         local_flash_attn_with_kvcache,
+ #     )
+ # except ImportError:
+ #     local_flash_attn_varlen_qkvpacked_func, local_flash_attn_varlen_kvpacked_func = (
+ #         None,
+ #         None,
+ #     )
+ #     local_flash_attn_qkvpacked_func, local_flash_attn_kvpacked_func = None, None
+ #     local_flash_attn_with_kvcache = None
+
+ local_flash_attn_varlen_qkvpacked_func, local_flash_attn_varlen_kvpacked_func = (
+     None,
+     None,
+ )
+ local_flash_attn_qkvpacked_func, local_flash_attn_kvpacked_func = None, None
+ local_flash_attn_with_kvcache = None
+
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
+
+ from .rotary import RotaryEmbedding
+
+
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
+ def get_alibi_slopes(nheads):
+     def get_slopes_power_of_2(nheads):
+         start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
+         ratio = start
+         return [start * ratio**i for i in range(nheads)]
+
+     if math.log2(nheads).is_integer():
+         return get_slopes_power_of_2(nheads)
+     else:
+         closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
+         return (
+             get_slopes_power_of_2(closest_power_of_2)
+             + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
+         )
+
+
+ class FlashSelfAttention(nn.Module):
+     """Implement the scaled dot product attention with softmax.
+     Arguments
+     ---------
+         softmax_scale: The temperature to use for the softmax attention.
+                        (default: 1/sqrt(d_keys) where d_keys is computed at
+                        runtime)
+         attention_dropout: The dropout rate to apply to the attention
+                            (default: 0.0)
+     """
+
+     def __init__(
+         self,
+         layer_number,
+         causal=False,
+         softmax_scale=None,
+         attention_dropout=0.0,
+         window_size=(-1, -1),
+         alibi_slopes=None,
+         deterministic=False,
+     ):
+         super().__init__()
+         assert local_flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
+         assert local_flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
+         self.layer_number = layer_number
+         self.causal = causal
+         self.softmax_scale = softmax_scale
+         self.drop = nn.Dropout(attention_dropout)
+         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+         self.window_size = window_size
+         self.deterministic = deterministic
+
+     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
+         """Implements the multihead softmax attention.
+         Arguments
+         ---------
+             qkv: The tensor containing the query, key, and value.
+                 If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
+                 If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
+                 (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+             causal: if passed, will override self.causal
+             cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                 of the sequences in the batch, used to index into qkv.
+             max_seqlen: int. Maximum sequence length in the batch.
+         Returns:
+         --------
+             out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
+                 else (B, S, H, D).
+         """
+         assert qkv.dtype in [torch.float16, torch.bfloat16]
+         assert qkv.is_cuda
+
+         causal = self.causal if causal is None else causal
+         unpadded = cu_seqlens is not None
+         if self.alibi_slopes is not None:
+             self.alibi_slopes = self.alibi_slopes.to(torch.float32)
+         if unpadded:
+             assert cu_seqlens.dtype == torch.int32
+             assert max_seqlen is not None
+             assert isinstance(max_seqlen, int)
+             return local_flash_attn_varlen_qkvpacked_func(
+                 qkv,
+                 cu_seqlens,
+                 max_seqlen,
+                 self.drop.p if self.training else 0.0,
+                 softmax_scale=self.softmax_scale,
+                 causal=causal,
+                 alibi_slopes=self.alibi_slopes,
+                 window_size=self.window_size,
+                 deterministic=self.deterministic,
+             )
+         else:
+             y = local_flash_attn_qkvpacked_func(
+                 qkv,
+                 self.drop.p if self.training else 0.0,
+                 softmax_scale=self.softmax_scale,
+                 causal=causal,
+                 alibi_slopes=self.alibi_slopes,
+                 window_size=self.window_size,
+                 deterministic=self.deterministic,
+             )
+
+         return y
+
+
141
+ class FlashCrossAttention(nn.Module):
142
+ """Implement the scaled dot product attention with softmax.
143
+ Arguments
144
+ ---------
145
+ softmax_scale: The temperature to use for the softmax attention.
146
+ (default: 1/sqrt(d_keys) where d_keys is computed at
147
+ runtime)
148
+ attention_dropout: The dropout rate to apply to the attention
149
+ (default: 0.0)
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ causal=False,
155
+ softmax_scale=None,
156
+ attention_dropout=0.0,
157
+ alibi_slopes=None,
158
+ window_size=(-1, -1),
159
+ deterministic=False,
160
+ ):
161
+ super().__init__()
162
+ assert local_flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
163
+ assert local_flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
164
+ self.causal = causal
165
+ self.softmax_scale = softmax_scale
166
+ self.drop = nn.Dropout(attention_dropout)
167
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
168
+ self.window_size = window_size
169
+ self.deterministic = deterministic
170
+
171
+ def forward(
172
+ self,
173
+ q,
174
+ kv,
175
+ causal=None,
176
+ cu_seqlens=None,
177
+ max_seqlen=None,
178
+ cu_seqlens_k=None,
179
+ max_seqlen_k=None,
180
+ ):
181
+ """Implements the multihead softmax attention.
182
+ Arguments
183
+ ---------
184
+ q: The tensor containing the query. (B, Sq, H, D)
185
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
186
+ causal: if passed, will override self.causal
187
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
188
+ of the sequences in the batch, used to index into q.
189
+ max_seqlen: int. Maximum sequence length in the batch of q.
190
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
191
+ of the sequences in the batch, used to index into kv.
192
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
193
+ """
194
+ assert q.dtype in [torch.float16, torch.bfloat16]
195
+ assert q.is_cuda and kv.is_cuda
196
+ causal = self.causal if causal is None else causal
197
+ unpadded = cu_seqlens is not None
198
+ if self.alibi_slopes is not None:
199
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
200
+ if unpadded:
201
+ assert cu_seqlens.dtype == torch.int32
202
+ assert max_seqlen is not None
203
+ assert isinstance(max_seqlen, int)
204
+ assert cu_seqlens_k is not None
205
+ assert cu_seqlens_k.dtype == torch.int32
206
+ assert max_seqlen_k is not None
207
+ assert isinstance(max_seqlen_k, int)
208
+ return local_flash_attn_varlen_kvpacked_func(
209
+ q,
210
+ kv,
211
+ cu_seqlens,
212
+ cu_seqlens_k,
213
+ max_seqlen,
214
+ max_seqlen_k,
215
+ self.drop.p if self.training else 0.0,
216
+ softmax_scale=self.softmax_scale,
217
+ causal=causal,
218
+ alibi_slopes=self.alibi_slopes,
219
+ window_size=self.window_size,
220
+ deterministic=self.deterministic,
221
+ )
222
+ else:
223
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
224
+ seqlen_k = kv.shape[1]
225
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
226
+ return local_flash_attn_kvpacked_func(
227
+ q,
228
+ kv,
229
+ self.drop.p if self.training else 0.0,
230
+ causal=causal,
231
+ softmax_scale=self.softmax_scale,
232
+ alibi_slopes=self.alibi_slopes,
233
+ window_size=self.window_size,
234
+ deterministic=self.deterministic,
235
+ )
236
+
237
+
238
+ class SelfAttention(nn.Module):
239
+ """Implement the scaled dot product attention with softmax.
240
+ Arguments
241
+ ---------
242
+ softmax_scale: The temperature to use for the softmax attention.
243
+ (default: 1/sqrt(d_keys) where d_keys is computed at
244
+ runtime)
245
+ attention_dropout: The dropout rate to apply to the attention
246
+ (default: 0.0)
247
+ """
248
+
249
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
250
+ super().__init__()
251
+ self.causal = causal
252
+ self.softmax_scale = softmax_scale
253
+ self.drop = nn.Dropout(attention_dropout)
254
+
255
+ def forward(self, qkv, causal=None, key_padding_mask=None):
256
+ """Implements the multihead softmax attention.
257
+ Arguments
258
+ ---------
259
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
260
+ causal: if passed, will override self.causal
261
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
262
+ False means to mask out. (B, S)
263
+ """
264
+ q, k, v = qkv.unbind(dim=2) # each: (B, T, H, D)
265
+ q = q.permute(0, 2, 1, 3) # (B, H, T, D)
266
+ k = k.permute(0, 2, 1, 3)
267
+ v = v.permute(0, 2, 1, 3)
268
+ batch_size, num_heads, seqlen, d = q.shape
269
+
270
+ scale = self.softmax_scale if self.softmax_scale is not None else 1.0 / math.sqrt(d)
271
+ q = q * (scale * math.sqrt(d))
272
+
273
+ attn_mask = None
274
+ if key_padding_mask is not None:
275
+ attn_mask = torch.where(
276
+ repeat(key_padding_mask, "b s -> b t s", t=seqlen),
277
+ 0.0,
278
+ -10000.0,
279
+ )
280
+
281
+ output = torch.nn.functional.scaled_dot_product_attention(
282
+ q,
283
+ k,
284
+ v,
285
+ attn_mask=attn_mask,
286
+ dropout_p=self.drop.p if self.training else 0.0,
287
+ is_causal=(self.causal if causal is None else causal),
288
+ )
289
+
290
+ output = output.permute(0, 2, 1, 3)
291
+ return output
292
+
293
+
294
+ class CrossAttention(nn.Module):
295
+ """Implement the scaled dot product attention with softmax.
296
+ Arguments
297
+ ---------
298
+ softmax_scale: The temperature to use for the softmax attention.
299
+ (default: 1/sqrt(d_keys) where d_keys is computed at
300
+ runtime)
301
+ attention_dropout: The dropout rate to apply to the attention
302
+ (default: 0.0)
303
+ """
304
+
305
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
306
+ super().__init__()
307
+ self.causal = causal
308
+ self.softmax_scale = softmax_scale
309
+ self.drop = nn.Dropout(attention_dropout)
310
+
311
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
312
+ """Implements the multihead softmax attention.
313
+ Arguments
314
+ ---------
315
+ q: The tensor containing the query. (B, Sq, H, D)
316
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
317
+ causal: if passed, will override self.causal
318
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
319
+ False means to mask out. (B, Sk)
320
+ """
321
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
322
+ causal = self.causal if causal is None else causal
323
+ seqlen_k = kv.shape[1]
324
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
325
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
326
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
327
+ k, v = kv.unbind(dim=2)
328
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
329
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
330
+ if key_padding_mask is not None:
331
+ padding_mask = torch.full(
332
+ (batch_size, seqlen_k),
333
+ -10000.0,
334
+ dtype=scores.dtype,
335
+ device=scores.device,
336
+ )
337
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
338
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
339
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
340
+ if causal:
341
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
342
+ row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
343
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
344
+ sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
345
+ causal_mask = col_idx > row_idx + sk - seqlen_q
346
+ scores = scores.masked_fill(causal_mask, -10000.0)
347
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
348
+ attention_drop = self.drop(attention)
349
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
350
+ return output
351
+
352
+
353
+ class LinearResidual(nn.Linear):
354
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
355
+
356
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
357
+ return super().forward(input), input
358
+
359
+
360
+ def _update_kv_cache(kv, inference_params, layer_idx):
361
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
362
+ # Pre-allocate memory for key-values for inference.
363
+ num_heads, head_dim = kv.shape[-2:]
364
+ if layer_idx not in inference_params.key_value_memory_dict:
365
+ kv_cache = torch.empty(
366
+ inference_params.max_batch_size,
367
+ inference_params.max_seqlen,
368
+ 2,
369
+ num_heads,
370
+ head_dim,
371
+ dtype=kv.dtype,
372
+ device=kv.device,
373
+ )
374
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
375
+ else:
376
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
377
+ # Adjust key and value for inference
378
+ batch_start = inference_params.batch_size_offset
379
+ batch_end = batch_start + kv.shape[0]
380
+ sequence_start = inference_params.seqlen_offset
381
+ sequence_end = sequence_start + kv.shape[1]
382
+ assert batch_end <= kv_cache.shape[0]
383
+ assert sequence_end <= kv_cache.shape[1]
384
+ assert kv_cache is not None
385
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
386
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
387
+
388
+
389
+ class MHA(nn.Module):
390
+ """Multi-head self-attention and cross-attention"""
391
+
392
+ def __init__(
393
+ self,
394
+ embed_dim,
395
+ num_heads,
396
+ num_heads_kv=None,
397
+ cross_attn=False,
398
+ qkv_proj_bias=True,
399
+ out_proj_bias=True,
400
+ dropout=0.0,
401
+ softmax_scale=None,
402
+ causal=False,
403
+ layer_idx=None,
404
+ dwconv=False,
405
+ rotary_emb_dim=0,
406
+ rotary_emb_base=10000.0,
407
+ rotary_emb_scale_base=None,
408
+ rotary_emb_interleaved=False,
409
+ use_alibi=False,
410
+ window_size=(-1, -1),
411
+ fused_bias_fc=False,
412
+ use_flash_attn=False,
413
+ return_residual=False,
414
+ checkpointing=False,
415
+ device=None,
416
+ dtype=None,
417
+ ) -> None:
418
+ """
419
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
420
+ return_residual: whether to return the input x along with the output. This is for
421
+ performance reason: for post-norm architecture, returning the input allows us
422
+ to fuse the backward of nn.Linear with the residual connection.
423
+ """
424
+ factory_kwargs = {"device": device, "dtype": dtype}
425
+ super().__init__()
426
+ self.embed_dim = embed_dim
427
+ self.cross_attn = cross_attn
428
+ self.causal = causal
429
+ self.layer_idx = layer_idx
430
+ self.dwconv = dwconv
431
+ self.rotary_emb_dim = rotary_emb_dim
432
+ self.use_flash_attn = use_flash_attn
433
+ self.return_residual = return_residual
434
+ self.checkpointing = checkpointing
435
+ if use_alibi:
436
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
437
+ alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
438
+ else:
439
+ alibi_slopes = None
440
+ if window_size != (-1, -1):
441
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
442
+
443
+ self.num_heads = num_heads
444
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
445
+ assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
446
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
447
+ self.head_dim = self.embed_dim // num_heads
448
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
449
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
450
+
451
+ if self.rotary_emb_dim > 0:
452
+ assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
453
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
454
+ self.rotary_emb = RotaryEmbedding(
455
+ self.rotary_emb_dim,
456
+ base=rotary_emb_base,
457
+ scale_base=rotary_emb_scale_base,
458
+ interleaved=rotary_emb_interleaved,
459
+ device=device,
460
+ )
461
+
462
+ if fused_bias_fc and FusedDense is None:
463
+ raise ImportError("fused_dense is not installed")
464
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
465
+ linear_resid_cls = LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
466
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
467
+ inner_attn_cls = (
468
+ partial(
469
+ FlashSelfAttention,
470
+ layer_number=self.layer_idx,
471
+ alibi_slopes=alibi_slopes,
472
+ window_size=window_size,
473
+ )
474
+ if use_flash_attn
475
+ else SelfAttention
476
+ )
477
+ inner_cross_attn_cls = (
478
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
479
+ if use_flash_attn
480
+ else CrossAttention
481
+ )
482
+ if not self.cross_attn:
483
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
484
+ else:
485
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
486
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
487
+ if self.dwconv:
488
+ if self.num_heads_kv == self.num_heads:
489
+ self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim)
490
+ else:
491
+ self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim)
492
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
493
+ self.inner_attn = inner_attn_cls(
494
+ causal=causal,
495
+ softmax_scale=softmax_scale,
496
+ attention_dropout=dropout,
497
+ )
498
+ self.inner_cross_attn = inner_cross_attn_cls(
499
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
500
+ )
501
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
502
+
503
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
504
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
505
+ device = self.out_proj.weight.device
506
+ return torch.empty(
507
+ batch_size,
508
+ max_seqlen,
509
+ 2,
510
+ self.num_heads_kv,
511
+ self.head_dim,
512
+ dtype=dtype,
513
+ device=device,
514
+ )
515
+
516
+ def _update_kv_cache(self, kv, inference_params):
517
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
518
+ assert not self.dwconv, "Generation does not support dwconv yet"
519
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
520
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
521
+
522
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
523
+ """
524
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
525
+ q: (batch_size, seqlen_q, nheads, head_dim)
526
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
527
+ """
528
+ assert inference_params is not None and inference_params.seqlen_offset > 0
529
+ assert self.use_flash_attn
530
+ if self.rotary_emb_dim > 0:
531
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
532
+ self.rotary_emb._update_cos_sin_cache(inference_params.max_seqlen, device=q.device, dtype=q.dtype)
533
+ rotary_cos, rotary_sin = (
534
+ self.rotary_emb._cos_cached,
535
+ self.rotary_emb._sin_cached,
536
+ )
537
+ else:
538
+ rotary_cos, rotary_sin = None, None
539
+ batch = q.shape[0]
540
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
541
+ cache_seqlens = (
542
+ inference_params.lengths_per_sample[:batch]
543
+ if inference_params.lengths_per_sample is not None
544
+ else inference_params.seqlen_offset
545
+ )
546
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
547
+ context = local_flash_attn_with_kvcache(
548
+ q,
549
+ kv_cache[:, :, 0],
550
+ kv_cache[:, :, 1],
551
+ kv[:, :, 0],
552
+ kv[:, :, 1],
553
+ rotary_cos=rotary_cos,
554
+ rotary_sin=rotary_sin,
555
+ cache_seqlens=cache_seqlens,
556
+ softmax_scale=self.inner_cross_attn.softmax_scale,
557
+ causal=self.inner_cross_attn.causal,
558
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
559
+ alibi_slopes=alibi_slopes,
560
+ )
561
+ return context
562
+
563
+ def _update_kvcache_attention(self, q, kv, inference_params):
564
+ """Write kv to inference_params, then do attention"""
565
+ if inference_params.seqlen_offset == 0 or local_flash_attn_with_kvcache is None or not self.use_flash_attn:
566
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
567
+ kv = self._update_kv_cache(kv, inference_params)
568
+ return self.inner_cross_attn(q, kv)
569
+ else:
570
+ batch = q.shape[0]
571
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
572
+ cache_seqlens = (
573
+ inference_params.lengths_per_sample[:batch]
574
+ if inference_params.lengths_per_sample is not None
575
+ else inference_params.seqlen_offset
576
+ )
577
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
578
+ return local_flash_attn_with_kvcache(
579
+ q,
580
+ kv_cache[:, :, 0],
581
+ kv_cache[:, :, 1],
582
+ kv[:, :, 0],
583
+ kv[:, :, 1],
584
+ cache_seqlens=cache_seqlens,
585
+ softmax_scale=self.inner_cross_attn.softmax_scale,
586
+ causal=self.inner_cross_attn.causal,
587
+ alibi_slopes=alibi_slopes,
588
+ )
589
+
590
+ def forward(
591
+ self,
592
+ x,
593
+ x_kv=None,
594
+ key_padding_mask=None,
595
+ cu_seqlens=None,
596
+ max_seqlen=None,
597
+ mixer_subset=None,
598
+ inference_params=None,
599
+ **kwargs,
600
+ ):
601
+ """
602
+ Arguments:
603
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
604
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
605
+ is the is the sum of the sequence lengths in the batch.
606
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
607
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
608
+ of the sequences in the batch, used to index into x. Only applicable when using
609
+ FlashAttention.
610
+ max_seqlen: int. Maximum sequence length in the batch.
611
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
612
+ (batch, seqlen). Only applicable when not using FlashAttention.
613
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
614
+ before applying the query projection. Useful for e.g., ViT where we only care
615
+ about the CLS token in the last layer.
616
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
617
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
618
+ """
619
+ if cu_seqlens is not None:
620
+ assert max_seqlen is not None
621
+ assert key_padding_mask is None
622
+ assert self.use_flash_attn
623
+ assert not self.dwconv
624
+ assert self.rotary_emb_dim == 0
625
+ if key_padding_mask is not None:
626
+ assert cu_seqlens is None
627
+ assert max_seqlen is None
628
+ assert not self.use_flash_attn
629
+ if inference_params is not None:
630
+ assert key_padding_mask is None
631
+ assert cu_seqlens is None and max_seqlen is None
632
+ assert not self.dwconv
633
+
634
+ kwargs = (
635
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
636
+ if self.use_flash_attn
637
+ else {"key_padding_mask": key_padding_mask, **kwargs}
638
+ )
639
+ seqlen_offset = (
640
+ 0
641
+ if inference_params is None
642
+ else (
643
+ inference_params.lengths_per_sample
644
+ if inference_params.lengths_per_sample is not None
645
+ else inference_params.seqlen_offset
646
+ )
647
+ )
648
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
649
+ batch, seqlen = x.shape[:2]
650
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
651
+ assert x_kv is None and mixer_subset is None
652
+ if not self.return_residual:
653
+ qkv = self.Wqkv(x)
654
+ else:
655
+ qkv, x = self.Wqkv(x)
656
+ if self.dwconv:
657
+ qkv = rearrange(
658
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
659
+ "b d s -> b s d",
660
+ ).contiguous()
661
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
662
+ if (
663
+ inference_params is None
664
+ or inference_params.seqlen_offset == 0
665
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
666
+ or not self.use_flash_attn
667
+ ):
668
+ if self.rotary_emb_dim > 0:
669
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
670
+ if inference_params is None:
671
+ if not self.checkpointing:
672
+ context = self.inner_attn(qkv, **kwargs)
673
+ else:
674
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
675
+ else:
676
+ context = self._update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
677
+ else:
678
+ context = self._apply_rotary_update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
679
+ else:
680
+ if self.cross_attn:
681
+ if not self.return_residual:
682
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
683
+ kv = self.Wkv(x_kv if x_kv is not None else x)
684
+ else:
685
+ if x_kv is not None:
686
+ kv, x_kv = self.Wkv(x_kv)
687
+ else:
688
+ kv, x = self.Wkv(x)
689
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
690
+ else:
691
+ assert self.num_heads_kv != self.num_heads
692
+ if not self.return_residual:
693
+ qkv = self.Wqkv(x)
694
+ else:
695
+ qkv, x = self.Wqkv(x)
696
+ q = qkv[..., : self.num_heads * self.head_dim]
697
+ kv = qkv[..., self.num_heads * self.head_dim :]
698
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
699
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
700
+ if self.dwconv:
701
+ q = rearrange(
702
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
703
+ "b d s -> b s d",
704
+ ).contiguous()
705
+ kv = rearrange(
706
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
707
+ "b d s -> b s d",
708
+ ).contiguous()
709
+ if (
710
+ inference_params is None
711
+ or inference_params.seqlen_offset == 0
712
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
713
+ or not self.use_flash_attn
714
+ ):
715
+ if self.rotary_emb_dim > 0:
716
+ q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
717
+ if inference_params is None:
718
+ if not self.checkpointing:
719
+ context = self.inner_cross_attn(q, kv, **kwargs)
720
+ else:
721
+ context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
722
+ else:
723
+ context = self._update_kvcache_attention(q, kv, inference_params)
724
+ else:
725
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
726
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
727
+ return out if not self.return_residual else (out, x)
728
+
729
+
730
+ class ParallelMHA(nn.Module):
731
+ """Multi-head self-attention and cross-attention"""
732
+
733
+ def __init__(
734
+ self,
735
+ embed_dim,
736
+ num_heads,
737
+ process_group,
738
+ num_heads_kv=None,
739
+ qkv_proj_bias=True,
740
+ out_proj_bias=True,
741
+ dropout=0.0,
742
+ softmax_scale=None,
743
+ causal=False,
744
+ layer_idx=None,
745
+ rotary_emb_dim=0,
746
+ rotary_emb_base=10000.0,
747
+ rotary_emb_scale_base=None,
748
+ rotary_emb_interleaved=False,
749
+ use_alibi=False,
750
+ window_size=(-1, -1),
751
+ use_flash_attn=False,
752
+ checkpointing=False,
753
+ sequence_parallel=True,
754
+ device=None,
755
+ dtype=None,
756
+ ) -> None:
757
+ factory_kwargs = {"device": device, "dtype": dtype}
758
+ super().__init__()
759
+ self.embed_dim = embed_dim
760
+ self.causal = causal
761
+ self.layer_idx = layer_idx
762
+ self.rotary_emb_dim = rotary_emb_dim
763
+ self.use_flash_attn = use_flash_attn
764
+ self.checkpointing = checkpointing
765
+ self.process_group = process_group
766
+ self.world_size = process_group.size()
767
+ self.local_rank = torch.distributed.get_rank(process_group)
768
+
769
+ self.num_heads = num_heads
770
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
771
+
772
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
773
+ assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
774
+
775
+ self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
776
+ self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads_kv, self.world_size, self.local_rank)
777
+ self.head_dim = self.embed_dim // num_heads
778
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
779
+
780
+ if use_alibi:
781
+ assert use_flash_attn, "ALiBi code path requires flash_attn"
782
+ num_heads_local = math.ceil(self.num_heads / self.world_size)
783
+ alibi_slopes = torch.tensor(
784
+ get_alibi_slopes(num_heads)[
785
+ self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
786
+ ],
787
+ device=device,
788
+ )
789
+ else:
790
+ alibi_slopes = None
791
+ if window_size != (-1, -1):
792
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
793
+
794
+ if self.rotary_emb_dim > 0:
795
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
796
+ self.rotary_emb = RotaryEmbedding(
797
+ self.rotary_emb_dim,
798
+ base=rotary_emb_base,
799
+ scale_base=rotary_emb_scale_base,
800
+ interleaved=rotary_emb_interleaved,
801
+ device=device,
802
+ )
803
+
804
+ if ColumnParallelLinear is None or RowParallelLinear is None:
805
+ raise ImportError("fused_dense is not installed")
806
+ self.Wqkv = ColumnParallelLinear(
807
+ embed_dim,
808
+ qkv_dim,
809
+ process_group,
810
+ bias=qkv_proj_bias,
811
+ sequence_parallel=sequence_parallel,
812
+ multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
813
+ **factory_kwargs,
814
+ )
815
+ inner_attn_cls = (
816
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
817
+ if use_flash_attn
818
+ else SelfAttention
819
+ )
820
+ inner_cross_attn_cls = (
821
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
822
+ if use_flash_attn
823
+ else CrossAttention
824
+ )
825
+ self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
826
+ self.inner_cross_attn = inner_cross_attn_cls(
827
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
828
+ )
829
+ self.out_proj = RowParallelLinear(
830
+ embed_dim,
831
+ embed_dim,
832
+ process_group,
833
+ bias=out_proj_bias,
834
+ sequence_parallel=sequence_parallel,
835
+ multiple_of=self.head_dim,
836
+ **factory_kwargs,
837
+ )
838
+
839
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
840
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
841
+ device = self.out_proj.weight.device
842
+ return torch.empty(
843
+ batch_size,
844
+ max_seqlen,
845
+ 2,
846
+ self.num_heads_kv_per_rank,
847
+ self.head_dim,
848
+ dtype=dtype,
849
+ device=device,
850
+ )
851
+
852
+ def _update_kv_cache(self, kv, inference_params):
853
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
854
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
855
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
856
+
857
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
858
+ """
859
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
860
+ q: (batch_size, seqlen_q, nheads, head_dim)
861
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
862
+ """
863
+ assert inference_params is not None and inference_params.seqlen_offset > 0
864
+ assert self.use_flash_attn
865
+ if self.rotary_emb_dim > 0:
866
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
867
+ self.rotary_emb._update_cos_sin_cache(inference_params.max_seqlen, device=q.device, dtype=q.dtype)
868
+ rotary_cos, rotary_sin = (
869
+ self.rotary_emb._cos_cached,
870
+ self.rotary_emb._sin_cached,
871
+ )
872
+ else:
873
+ rotary_cos, rotary_sin = None, None
874
+ batch = q.shape[0]
875
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
876
+ cache_seqlens = (
877
+ inference_params.lengths_per_sample[:batch]
878
+ if inference_params.lengths_per_sample is not None
879
+ else inference_params.seqlen_offset
880
+ )
881
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
882
+ context = local_flash_attn_with_kvcache(
883
+ q,
884
+ kv_cache[:, :, 0],
885
+ kv_cache[:, :, 1],
886
+ kv[:, :, 0],
887
+ kv[:, :, 1],
888
+ rotary_cos=rotary_cos,
889
+ rotary_sin=rotary_sin,
890
+ cache_seqlens=cache_seqlens,
891
+ softmax_scale=self.inner_cross_attn.softmax_scale,
892
+ causal=self.inner_cross_attn.causal,
893
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
894
+ alibi_slopes=alibi_slopes,
895
+ )
896
+ return context
897
+
898
+ def _update_kvcache_attention(self, q, kv, inference_params):
899
+ """Write kv to inference_params, then do attention"""
900
+ if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
901
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
902
+ kv = self._update_kv_cache(kv, inference_params)
903
+ return self.inner_cross_attn(q, kv)
904
+ else:
905
+ batch = q.shape[0]
906
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
907
+ cache_seqlens = (
908
+ inference_params.lengths_per_sample[:batch]
909
+ if inference_params.lengths_per_sample is not None
910
+ else inference_params.seqlen_offset
911
+ )
912
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
913
+ context = local_flash_attn_with_kvcache(
914
+ q,
915
+ kv_cache[:, :, 0],
916
+ kv_cache[:, :, 1],
917
+ kv[:, :, 0],
918
+ kv[:, :, 1],
919
+ cache_seqlens=cache_seqlens,
920
+ softmax_scale=self.inner_cross_attn.softmax_scale,
921
+ causal=self.inner_cross_attn.causal,
922
+ alibi_slopes=alibi_slopes,
923
+ )
924
+ return context
925
+
926
+ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
927
+ """
928
+ Arguments:
929
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
930
+ If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
931
+ split x during sequence parallel, we split the batch * seqlen dimension
932
+ (in case batch is small).
933
+ """
934
+ qkv = self.Wqkv(x)
935
+ if seqlen is not None:
936
+ qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
937
+ seqlen_offset = (
938
+ 0
939
+ if inference_params is None
940
+ else (
941
+ inference_params.lengths_per_sample
942
+ if inference_params.lengths_per_sample is not None
943
+ else inference_params.seqlen_offset
944
+ )
945
+ )
946
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
947
+ if self.num_heads_kv == self.num_heads:
948
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
949
+ if (
950
+ inference_params is None
951
+ or inference_params.seqlen_offset == 0
952
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
953
+ or not self.use_flash_attn
954
+ ):
955
+ if self.rotary_emb_dim > 0:
956
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
957
+ if inference_params is None:
958
+ if not self.checkpointing:
959
+ context = self.inner_attn(qkv, **kwargs)
960
+ else:
961
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
962
+ else:
963
+ context = self._update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
964
+ else:
965
+ context = self._apply_rotary_update_kvcache_attention(qkv[:, :, 0], qkv[:, :, 1:], inference_params)
966
+ else:
967
+ q = rearrange(
968
+ qkv[..., : self.num_heads_per_rank * self.head_dim],
969
+ "... (h d) -> ... h d",
970
+ d=self.head_dim,
971
+ )
972
+ kv = rearrange(
973
+ qkv[..., self.num_heads_per_rank * self.head_dim :],
974
+ "... (two hkv d) -> ... two hkv d",
975
+ two=2,
976
+ d=self.head_dim,
977
+ )
978
+ if (
979
+ inference_params is None
980
+ or inference_params.seqlen_offset == 0
981
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
982
+ or not self.use_flash_attn
983
+ ):
984
+ if self.rotary_emb_dim > 0:
985
+ q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
986
+ if inference_params is None:
987
+ if not self.checkpointing:
988
+ context = self.inner_cross_attn(q, kv, **kwargs)
989
+ else:
990
+ context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
991
+ else:
992
+ context = self._update_kvcache_attention(q, kv, inference_params)
993
+ else:
994
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
995
+ context = rearrange(context, "b s h d -> b s (h d)")
996
+ if seqlen is not None:
997
+ context = rearrange(context, "b s d -> (b s) d")
998
+ out = self.out_proj(context)
999
+ return out
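A quick sanity check of get_alibi_slopes from the top of attention.py (a sketch, not part of the committed file; the package name in the import is a hypothetical placeholder). For a power-of-two head count the slopes are the geometric sequence 2^-(i+1); other head counts interleave slopes from the surrounding power-of-two sizes:

    # Hypothetical package name; adjust to wherever these files actually live.
    from evo_model.attention import get_alibi_slopes

    # Power-of-two head count: slopes 2^-(i+1) for i = 0..7.
    print(get_alibi_slopes(8))
    # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

    # Non-power-of-two head count: 8 base slopes plus 4 interleaved from the
    # 16-head sequence, one slope per head.
    print(len(get_alibi_slopes(12)))  # 12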
cache.py ADDED
@@ -0,0 +1,62 @@
+ # Copied verbatim from vortex
+ # Copyright (c) 2024, Michael Poli.
+
+
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from torch import Tensor
+
+
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
+ @dataclass
+ class InferenceParams:
+     """Inference parameters that are passed to the main model in order
+     to efficiently calculate and store the context during inference."""
+
+     max_seqlen: int
+     max_batch_size: int
+     seqlen_offset: int = 0
+     batch_size_offset: int = 0
+     key_value_memory_dict: dict = field(default_factory=dict)
+     lengths_per_sample: Optional[Tensor] = None
+
+     def reset(self, max_seqlen, max_batch_size):
+         self.max_seqlen = max_seqlen
+         self.max_batch_size = max_batch_size
+         self.seqlen_offset = 0
+         if self.lengths_per_sample is not None:
+             self.lengths_per_sample.zero_()
+
+
+ @dataclass
+ class HyenaCascadeIIRInferenceParams:
+     """Inference parameters passed to long Hyena blocks with recurrent mode."""
+
+     fir_filter_length: int = 3
+     state_dim: int = 16
+     seqlen_offset: int = 0
+     fir_state_dict: dict = field(default_factory=dict)
+     state_dict: dict = field(default_factory=dict)
+
+     def reset(self):
+         self.fir_filter_length = 3
+         self.state_dim = 16
+         self.seqlen_offset = 0
+
+
+ @dataclass
+ class HyenaCascadeFIRInferenceParams:
+     """Inference parameters passed to short and medium Hyena blocks."""
+
+     fir_filter_length: int = 3
+     fir_inner_filter_length: int = 4
+     seqlen_offset: int = 0
+     fir_inner_state_dict: dict = field(default_factory=dict)
+     fir_state_dict: dict = field(default_factory=dict)
+     state_dict: dict = field(default_factory=dict)
+
+     def reset(self):
+         self.fir_filter_length = 3
+         self.fir_inner_filter_length = 4
+         self.seqlen_offset = 0
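A minimal usage sketch (not part of the committed files; the package name in the imports is a hypothetical placeholder): InferenceParams is the object that attention.py's _update_kv_cache reads and writes during generation, with the caller advancing seqlen_offset between steps:

    import torch
    # Hypothetical package name; adjust to the actual layout of these files.
    from evo_model.attention import MHA
    from evo_model.cache import InferenceParams

    # One self-attention layer on CPU/float32 (the pure-PyTorch, non-flash path).
    mha = MHA(embed_dim=256, num_heads=8, causal=True, layer_idx=0)

    params = InferenceParams(max_seqlen=128, max_batch_size=2)
    # Optional: pre-allocate the per-layer KV cache that _update_kv_cache
    # would otherwise create lazily on first use.
    params.key_value_memory_dict[0] = mha.allocate_inference_cache(2, 128)

    x = torch.randn(2, 16, 256)
    out = mha(x, inference_params=params)    # prefill: caches positions 0..15
    params.seqlen_offset += x.shape[1]

    nxt = torch.randn(2, 1, 256)
    out = mha(nxt, inference_params=params)  # one decode step against the cache
    params.seqlen_offset += 1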
engine.py ADDED
@@ -0,0 +1,597 @@
+ # Copied verbatim from vortex
+ # Copyright (c) 2024, Michael Poli.
+
+ import gc
+
+ import torch
+ import torch.nn.functional as F
+
+ try:
+     pass
+ except:
+     pass
+ from .utils import column_split
+ from .rich_logging import activations_logger
+
+ IIR_PREFILL_MODES = [
+     "recurrence",
+     "modal-fft",
+     "hybrid-modal-recurrence",
+     "modal-scan",
+     "canonical-fft",
+     "iir-fir-caching",
+ ]
+
+
+ def adjust_filter_shape_for_broadcast(u, h):
+     h = h.squeeze()  # Standardize to [D, L] from [1, D, L] and [D, 1, L]
+
+     # Case: u: [B, D, L], k_f: [D, L]
+     if len(u.shape) > len(h.shape):
+         h = h.unsqueeze(0)
+
+     # Case: u: [B, D1, D2, L], k_f: [B, D, L]
+     if len(u.shape) > 3:
+         h = h.unsqueeze(1)
+     return h
+
+
+ def fftconv_func(
+     u,
+     k,
+     D,
+     dropout_mask,
+     gelu=True,
+     k_rev=None,
+     bidirectional=False,
+     print_activations=False,
+     layer_idx=None,
+     **kwargs,
+ ):
+     seqlen = u.shape[-1]
+     fft_size = 2 * seqlen
+
+     k_f = torch.fft.rfft(k, n=fft_size) / fft_size
+     k_f = adjust_filter_shape_for_broadcast(u, k_f)
+     k = k.squeeze()
+
+     if bidirectional:
+         u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
+         k, k2 = k.split(k.shape[1] // 2, dim=1)
+         k2_f = torch.fft.rfft(k2, n=fft_size) / fft_size
+         y1 = u_f * k_f
+         y2 = u_f.conj() * k2_f.conj()
+
+         y = torch.fft.irfft(y1 + y2, n=fft_size, norm="forward")[..., :seqlen]
+
+     else:
+         if k_rev is not None:
+             k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
+             k_f = k_f + k_rev_f.conj()
+
+         u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
+
+         y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
+
+     if print_activations:
+         activations_logger.info(f"post fftconv pre bias {y} {y.min()} {y.max()}")
+
+     out = y + u * D.unsqueeze(-1)
+
+     if print_activations:
+         activations_logger.info(f"post fftconv post bias {out} {out.min()} {out.max()}")
+
+     return out.to(dtype=u.dtype)
+
+
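A small numerical check of fftconv_func above (a sketch, not part of the commit; the package name in the import is a hypothetical placeholder): for a short filter it should agree with a causal depthwise conv1d plus the u * D skip term:

    import torch
    import torch.nn.functional as F
    # Hypothetical package name; adjust to the actual layout of these files.
    from evo_model.engine import fftconv_func

    B, C, L, K = 2, 4, 32, 5
    u = torch.randn(B, C, L)
    k = torch.zeros(C, L)
    k[:, :K] = torch.randn(C, K)      # short causal filter, zero-padded to length L
    D = torch.randn(C)

    y_fft = fftconv_func(u, k, D, dropout_mask=None, gelu=False)

    # Reference: depthwise causal conv1d (cross-correlation with the flipped
    # kernel and left padding) plus the skip connection.
    w = k[:, :K].flip(-1).unsqueeze(1)                      # (C, 1, K)
    y_ref = F.conv1d(F.pad(u, (K - 1, 0)), w, groups=C) + u * D.unsqueeze(-1)
    print(torch.allclose(y_fft, y_ref, atol=1e-4))          # expected: True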
87
+ def canonicalize_modal_system(poles, residues):
88
+ """Canonicalize a modal system.
89
+
90
+ Args:
91
+ poles (Tensor): The poles of the system.
92
+ residues (Tensor): The residues of the system.
93
+
94
+ Returns:
95
+ Tuple[Tensor, Tensor]: The canonicalized poles and residues.
96
+ """
97
+ raise NotImplementedError
98
+
99
+
100
+ def list_tensors(idx):
101
+ for obj in gc.get_objects():
102
+ try:
103
+ if torch.is_tensor(obj) and isinstance(obj, torch.Tensor):
104
+ # dump to log
105
+ print(type(obj), obj.size())
106
+ el = obj[0]
107
+ with open(f"tensors_{idx}.txt", "a") as f:
108
+ f.write(f"{type(obj)} {obj.size()} {el}\n")
109
+ except Exception:
110
+ pass
111
+
112
+
113
+ class HyenaInferenceEngine:
114
+ def __init__(
115
+ self,
116
+ fir_fn=None,
117
+ iir_prefill_style="modal-fft",
118
+ layer_idx=None,
119
+ ground_truth_activations_path=None,
120
+ print_activations=False,
121
+ hyena_flip_x1x2=False,
122
+ ) -> None:
123
+ self.fir_fn = fir_fn
124
+ assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
125
+ self.iir_prefill_style = iir_prefill_style
126
+ self.layer_idx = layer_idx
127
+ self.low_mem_mode = False
128
+ self.ground_truth_activations_path = ground_truth_activations_path
129
+ self.print_activations = print_activations
130
+ self.hyena_flip_x1x2 = hyena_flip_x1x2
131
+
132
+ def parallel_fir(
133
+ self,
134
+ fir_fn,
135
+ u,
136
+ weight,
137
+ bias,
138
+ L,
139
+ dims,
140
+ groups=None,
141
+ gated_bias=False,
142
+ column_split_hyena=False,
143
+ dim_last=True,
144
+ fir_length=3,
145
+ gate=False,
146
+ inference_params=None,
147
+ prefill_mode=None,
148
+ padding_mask=None,
149
+ ):
150
+ L = u.shape[1] if dim_last else u.shape[2]
151
+ if gate:
152
+ hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
153
+ # Compatibility with training infra that column splits the projections
154
+ if column_split_hyena:
155
+ x2, x1, v = column_split(u, num_attention_heads, hidden_size_per_attention_head)
156
+ else:
157
+ x2, x1, v = u.split([hidden_size, hidden_size, hidden_size], dim=1)
158
+ if self.hyena_flip_x1x2:
159
+ x1, x2 = x2, x1
160
+ u = x1 * v
161
+
162
+ if self.print_activations:
163
+ activations_logger.info(f"q: {x2}, {x2.min()}, {x2.max()}")
164
+ activations_logger.info(f"k: {x1}, {x1.min()}, {x1.max()}")
165
+ activations_logger.info(f"v: {v}, {v.min()}, {v.max()}")
166
+ activations_logger.info(f"post pregate: {u}, {u.min()}, {u.max()}")
167
+
168
+ # prepare input layout, dimensions and dispatch to fir kernel
169
+ # Deprecated
170
+ if fir_fn != torch.nn.functional.conv1d:
171
+ if dim_last:
172
+ u = u.permute(0, 2, 1) # B, D, L
173
+ z = fir_fn(u)[:, :L] # B, L, D
174
+
175
+ elif fir_length >= 128:
176
+ with torch.autocast("cuda"):
177
+ z = fftconv_func(
178
+ u.to(torch.float32),
179
+ weight[:, :, :L].to(torch.float32),
180
+ bias,
181
+ None,
182
+ gelu=False,
183
+ bidirectional=False,
184
+ print_activations=self.print_activations,
185
+ groups=groups,
186
+ layer_idx=self.layer_idx,
187
+ )
188
+ z = z.to(u.dtype)
189
+ else:
190
+ if dim_last:
191
+ u = u.permute(0, 2, 1) # B, D, L
192
+
193
+ if groups is None:
194
+ g = u.shape[1]
195
+ else:
196
+ g = groups
197
+
198
+ z = fir_fn(
199
+ u.to(torch.float32),
200
+ weight.to(torch.float32),
201
+ bias=None,
202
+ stride=1,
203
+ padding=fir_length - 1,
204
+ groups=u.shape[1], # always set to D, regardless of filter grouping
205
+ )[..., :L]
206
+ if self.print_activations:
207
+ activations_logger.info(f"post filter: {z}, {z.min()}, {z.max()}")
208
+
209
+ z = z.to(u.dtype)
210
+
211
+ if gated_bias is False:
212
+ if self.print_activations:
213
+ activations_logger.info(f"post dw conv {z} {z.min()} {z.max()}")
214
+ # if self.ground_truth_activations_path:
215
+ # z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_dw_conv_{self.layer_idx}.pt")
216
+ # z_savanna = z_savanna.permute(1, 2, 0)
217
+ # z_diff = (z.squeeze() - z_savanna.squeeze()).abs().max()
218
+ # activations_logger.info(f"dw_conv_diff: {z_diff}")
219
+
220
+ if bias is not None:
221
+ if gated_bias:
222
+ z = z + bias[None, :, None] * u
223
+ else:
224
+ z = z + bias[None, :, None]
225
+
226
+ # handle padding post fir, the only place with biases
227
+ if type(padding_mask) == torch.Tensor:
228
+ z = z * padding_mask[:, None]
229
+
230
+ if gate:
231
+ # if self.layer_idx == 1:
232
+ # breakpoint()
233
+ z = x2 * z
234
+
235
+ if self.print_activations:
236
+ activations_logger.info(f"hyena filter: {weight}, {weight.min()}, {weight.max()}")
237
+ activations_logger.info(f"post postgate: {z}, {z.min()}, {z.max()}")
238
+ # if self.ground_truth_activations_path:
239
+ # q_savanna = torch.load(f"{self.ground_truth_activations_path}/q_{self.layer_idx}.pt")
240
+ # k_savanna = torch.load(f"{self.ground_truth_activations_path}/k_{self.layer_idx}.pt")
241
+ # v_savanna = torch.load(f"{self.ground_truth_activations_path}/v_{self.layer_idx}.pt")
242
+
243
+ # q_diff = (x2 - q_savanna).abs()
244
+ # k_diff = (x1 - k_savanna).abs()
245
+ # v_diff = (v - v_savanna).abs()
246
+
247
+ # activations_logger.info(f"q_diff: {q_diff.max()}, {q_diff.mean()}")
248
+ # activations_logger.info(f"k_diff: {k_diff.max()}, {k_diff.mean()}")
249
+ # activations_logger.info(f"v_diff: {v_diff.max()}, {v_diff.mean()}")
250
+
251
+ # h_savanna = torch.load(f"/home/zymrael/checkpoints/evo2/activations/savanna/hyena_filter_{self.layer_idx}.pt")
252
+ # h_diff = (weight[..., :h_savanna.shape[-1]].squeeze() - h_savanna.squeeze()).abs()
253
+
254
+ # activations_logger.info(f"h_diff: {h_diff.max()}, {h_diff.mean()}")
255
+
256
+ if inference_params is not None:
257
+ fir_state = u[..., -fir_length + 1 :]
258
+ else:
259
+ fir_state = None
260
+
261
+ return z, fir_state
262
+
263
+ def parallel_iir(
264
+ self,
265
+ z_pre,
266
+ h,
267
+ D,
268
+ L,
269
+ poles,
270
+ residues,
271
+ t,
272
+ dims,
273
+ layer_idx,
274
+ inference_params=None,
275
+ prefill_style="fft",
276
+ fftconv_fn=None,
277
+ padding_mask=None,
278
+ use_flashfft=False,
279
+ column_split_hyena=False,
280
+ long_fir_threshold=None,
281
+ ):
282
+ """Compute the output state of the short convolutional filter."""
283
+ fft_size = 2 * L
284
+ hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
285
+ # Compatibility with training infra that column splits the projections
286
+ if column_split_hyena:
287
+ z = z_pre.reshape(
288
+ z_pre.shape[0],
289
+ num_attention_heads,
290
+ 3 * hidden_size_per_attention_head,
291
+ z_pre.shape[2],
292
+ )
293
+ x2, x1, v = (
294
+ z[:, :, :hidden_size_per_attention_head],
295
+ z[
296
+ :,
297
+ :,
298
+ hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
299
+ ],
300
+ z[:, :, 2 * hidden_size_per_attention_head :],
301
+ )
302
+ x2, x1, v = (
303
+ x2.reshape(x2.shape[0], -1, x2.shape[-1]),
304
+ x1.reshape(x1.shape[0], -1, x1.shape[-1]),
305
+ v.reshape(v.shape[0], -1, v.shape[-1]),
306
+ )
307
+ else:
308
+ x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
309
+
310
+ if self.hyena_flip_x1x2:
311
+ x1, x2 = x2, x1
312
+
313
+ x1v = x1 * v
314
+
315
+ if inference_params is not None and prefill_style == "recurrence":
316
+ y = self.prefill_via_direct_recurrence(
317
+ inference_params=inference_params,
318
+ x1v=x1v,
319
+ L=L,
320
+ poles=poles,
321
+ residues=residues,
322
+ )
323
+
324
+ else:
325
+ if use_flashfft and (L % 2) == 0: # only works with even L
326
+ y = fftconv_fn(
327
+ x1v.to(dtype=torch.bfloat16).contiguous(),
328
+ h.to(dtype=torch.float32),
329
+ )
330
+ X_s = None
331
+
332
+ elif long_fir_threshold is None:
333
+ H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
334
+ X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
335
+ X = X_s[..., : H.shape[-1]]
336
+ if len(z_pre.shape) > 3:
337
+ H = H.unsqueeze(1)
338
+ y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
339
+
340
+ else:
341
+ assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
342
+ h = h[0][:, None] # rearrange to d, 1, l for depthwise conv1d
343
+ h = h[..., :long_fir_threshold]
344
+ y = F.conv1d(
345
+ x1v,
346
+ h.to(dtype=x1v.dtype),
347
+ stride=1,
348
+ groups=x1v.shape[1],
349
+ padding=h.shape[-1] - 1,
350
+ )[..., :L]
351
+ # if self.layer_idx == 2:
352
+ # breakpoint()
353
+ y = y.to(dtype=x1v.dtype)
354
+ y = (y + x1v * D.unsqueeze(-1)) * x2
355
+
356
+ if self.print_activations:
357
+ activations_logger.info(f"hyena filter: {h}, {h.min()}, {h.max()}")
358
+ activations_logger.info(f"post hyena iir gate: {y}, {y.min()}, {y.max()}")
359
+ activations_logger.info(f"q: {x2}, {x2.min()}, {x2.max()}")
360
+ activations_logger.info(f"k: {x1}, {x1.min()}, {x1.max()}")
361
+ activations_logger.info(f"v: {v}, {v.min()}, {v.max()}")
362
+ # if self.ground_truth_activations_path:
363
+ # q_savanna = torch.load(f"{self.ground_truth_activations_path}/q_{self.layer_idx}.pt")
364
+ # k_savanna = torch.load(f"{self.ground_truth_activations_path}/k_{self.layer_idx}.pt")
365
+ # v_savanna = torch.load(f"{self.ground_truth_activations_path}/v_{self.layer_idx}.pt")
366
+
367
+ # q_diff = (x2 - q_savanna).abs()
368
+ # k_diff = (x1 - k_savanna).abs()
369
+ # v_diff = (v - v_savanna).abs()
370
+
371
+ # activations_logger.info(f"q_diff: {q_diff.max()}, {q_diff.mean()}")
372
+ # activations_logger.info(f"k_diff: {k_diff.max()}, {k_diff.mean()}")
373
+ # activations_logger.info(f"v_diff: {v_diff.max()}, {v_diff.mean()}")
374
+
375
+ # h_savanna = torch.load(f"/home/zymrael/checkpoints/evo2/activations/savanna/hyena_filter_{self.layer_idx}.pt")
376
+
377
+ # h_diff = (h[..., :h_savanna.shape[-1]].squeeze() - h_savanna.squeeze()).abs()
378
+ # activations_logger.info(f"h_diff: {h_diff.max()}, {h_diff.mean()}")
379
+
380
+ if inference_params is not None:
381
+ if prefill_style == "fft":
382
+ self.prefill_via_modal_fft(
383
+ inference_params=inference_params,
384
+ x1v=x1v,
385
+ X_s=X_s,
386
+ L=L,
387
+ t=t,
388
+ poles=poles,
389
+ dims=dims,
390
+ layer_idx=layer_idx,
391
+ use_flashfft=use_flashfft,
392
+ fftconv_fn=fftconv_fn,
393
+ )
394
+
395
+ elif prefill_style == "recurrence":
396
+ # recurrent prefill is done before
397
+ pass
398
+ else:
399
+ raise NotImplementedError
400
+ if self.low_mem_mode:
401
+ # TODO: smarter gc
402
+ del z_pre, x2, x1, v, x1v, h, poles, residues
403
+ torch.cuda.empty_cache()
404
+
405
+ return y.permute(0, 2, 1)
406
+
407
+ def step_fir(self, u, fir_state, weight, bias=None, gated_bias=False, flip_filter=False):
408
+ """Steps forward FIR filters in the architecture.
409
+
410
+ FIR filters generally include truncated convolutions in Hyena with an explicit or hybrid time-domain parametrization:
411
+ * Short FIR filters in Hyena featurizers
412
+ * Short and medium FIR filters in Hyena operators
413
+
414
+ Note:
415
+ `fir_state` contains the last (FIR filter length - 1) elements of `u`: `u_{L-2}, u_{L-1}, ...`
416
+ We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]`.
417
+ """
418
+ weight = weight.squeeze()
419
+
420
+ cache_size = fir_state.shape[-1]
421
+ filter_length = weight.shape[-1]
422
+ if flip_filter:
423
+ weight = weight.flip(-1)
424
+ weight = weight[..., -cache_size - 1 :].unsqueeze(0)
425
+ else:
426
+ weight = weight[..., : cache_size + 1].unsqueeze(0)
427
+
428
+ input_dtype = u.dtype
429
+ weight = weight.to(torch.float32)
430
+ u = u.to(torch.float32)
431
+ fir_state = fir_state.to(torch.float32)
432
+ bias = bias.to(torch.float32) if bias is not None else None
433
+
434
+ h0, h = weight[..., -1], weight[..., :-1]
435
+ y = h0 * u + torch.sum(fir_state * h, dim=-1)
436
+
437
+ if bias is not None:
438
+ if gated_bias:
439
+ y = y + bias * u
440
+ else:
441
+ y = y + bias
442
+
443
+ # Update the state
444
+ if cache_size < filter_length - 1:
445
+ fir_state = torch.cat([fir_state, u[..., None]], dim=-1)
446
+ else:
447
+ fir_state = torch.roll(fir_state, -1, dims=2)
448
+ fir_state[..., -1] = u
449
+
450
+ return y.to(input_dtype), fir_state
451
+
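# For intuition, a minimal sketch (editorial, not part of the vortex source; shapes and
# values are made up) of the single-step FIR update that `step_fir` performs, ignoring
# bias, dtype casts, and `flip_filter`:
import torch

b, d, K = 2, 4, 3                      # hypothetical batch, channels, filter length
weight = torch.randn(d, K)             # per-channel FIR taps
fir_state = torch.randn(b, d, K - 1)   # cache of the last K-1 inputs
u_t = torch.randn(b, d)                # current input step

# One step: newest tap applied to u_t, older taps applied to the cached inputs.
h0, h = weight[..., -1], weight[..., :-1]
y_t = h0 * u_t + torch.sum(fir_state * h, dim=-1)

# Roll the cache so it again holds the most recent K-1 inputs.
fir_state = torch.cat([fir_state[..., 1:], u_t[..., None]], dim=-1)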
452
+ def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
453
+ # TODO: kernelize
454
+ x1v = x1 * v
455
+ poles = torch.exp(poles) # poles arg contains log_poles
456
+ poles = poles[..., 0][None] # squeeze dummy seqlen dim and add dummy batch dim
457
+ residues = residues[None] # add dummy batch dim
458
+ iir_state = poles * iir_state + x1v[..., None]
459
+
460
+ res_state = torch.sum(residues * iir_state, dim=-1)
461
+
462
+ if iir_groups > 1:
463
+ raise NotImplementedError
464
+ # if self.layer_idx == 2:
465
+ # breakpoint()
466
+ y = x2 * (res_state + D * x1v)
467
+
468
+ return y, iir_state
469
+
470
+ def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
471
+ """Turns the IIR filter into a FIR and uses a cache for decoding."""
472
+ raise NotImplementedError(":)")
473
+
474
+ def prefill_via_direct_recurrence(self, inference_params, x1v, L, residues, poles, *args, **kwargs) -> torch.Tensor:
475
+ """
476
+ Compute the IIR state via explicit recurrence (modal form)
477
+
478
+ This is the most memory efficient prefilling method for Hyena filters.
479
+
480
+ Note:
481
+ dtypes: [state: float32, poles: float32, x1v: bfloat16, output: bfloat16]
482
+ """
483
+ state_dim = poles.shape[1]
484
+ x1v_ = x1v[..., None, None] # b, d, l, sdim, reim
485
+ x1v_ = x1v_.repeat(1, 1, 1, state_dim, 2) # b, d, l, sdim, reim
486
+ x1v_[..., 1] = 0
487
+
488
+ state = 0 * x1v_[:, :, 0]
489
+ output = 0 * x1v_[:, :, :, 0, 0] # b, d, l
490
+
491
+ # suppress dummy seqlen dimension
492
+ poles = poles[:, :, 0][None]
493
+ residues = residues[:, :, 0][None].repeat(x1v_.shape[0], 1, 1, 1) # b, d, sdim, reim
494
+
495
+ # state: b, d, sdim, reim
496
+ # poles: 1, d, sdim, reim
497
+ # x1v_: b, d, l, sdim, reim
498
+ for i in range(L):
499
+ state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
500
+ state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
501
+ output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0] # .real
502
+
503
+ inference_params.state_dict[self.layer_idx] = state.to(dtype=torch.float32)
504
+
505
+ return output
506
+
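# The split real/imaginary updates in `prefill_via_direct_recurrence` are the expanded
# form of a complex first-order recurrence. A rough illustrative equivalent using torch
# complex tensors (editorial sketch; shapes and values are assumptions, not vortex code):
import torch

b, d, L, sdim = 1, 2, 5, 3
poles = 0.1 * torch.randn(d, sdim, dtype=torch.cfloat)   # complex poles p_s
residues = torch.randn(d, sdim, dtype=torch.cfloat)      # complex residues R_s
x = torch.randn(b, d, L)

state = torch.zeros(b, d, sdim, dtype=torch.cfloat)
output = torch.zeros(b, d, L)
for i in range(L):
    state = poles * state + x[..., i, None]              # s_t = p * s_{t-1} + x_t
    output[..., i] = torch.sum(residues * state, dim=-1).real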
507
+ def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
508
+ """
509
+ Compute the IIR state via hybrid recurrence-convolution over blocks
510
+ """
511
+ raise NotImplementedError(":)")
512
+
513
+ def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
514
+ raise NotImplementedError
515
+
516
+ def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
517
+ """
518
+ Compute the IIR state via a single FFT
519
+
520
+ This is the most memory efficient "parallelized" prefilling method for Hyena.
521
+
522
+ From: https://arxiv.org/abs/2310.18780
523
+ """
524
+ raise NotImplementedError(":)")
525
+
526
+ def prefill_via_modal_fft(
527
+ self,
528
+ inference_params,
529
+ x1v,
530
+ L,
531
+ poles,
532
+ t,
533
+ dims,
534
+ layer_idx,
535
+ X_s=None,
536
+ use_flashfft=False,
537
+ fftconv_fn=None,
538
+ state_dtype=torch.float32,
539
+ *args,
540
+ **kwargs,
541
+ ):
542
+ """
543
+ Compute the IIR state via a single FFT
544
+ """
545
+ # When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft",
546
+ # we split the filter into poles and residues and reuse FFT computation on the input.
547
+ hidden_size, _, _, state_size, hyena_filter_groups = dims
548
+
549
+ assert X_s is not None
550
+ bs = x1v.shape[0]
551
+ fft_size = 2 * L
552
+ # poles = torch.view_as_complex(poles.to(torch.float32))
553
+ state_s = (poles.to(torch.float32) * t).exp()
554
+
555
+ # state_s = poles**t
556
+ state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # B, D, state_dim, 2 * L
557
+ if hyena_filter_groups > 1:
558
+ state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
559
+ state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
560
+ inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)
561
+
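# Sanity check on the FFT prefill (editorial sketch, not vortex code): the modal state
# after L steps, sum_t p**(L-1-t) * x_t, is one entry of the causal convolution of the
# input with the geometric sequence p**t, which the FFT product above evaluates at once.
import torch

L = 8
p = torch.tensor(0.9 + 0.0j)                 # a single modal pole, for illustration
x = torch.randn(L, dtype=torch.cfloat)

# Direct recurrence: s_t = p * s_{t-1} + x_t
s = torch.zeros((), dtype=torch.cfloat)
for t in range(L):
    s = p * s + x[t]

# FFT route: convolve x with [p^0, p^1, ..., p^{L-1}] and read off position L-1.
fft_size = 2 * L
kernel = p ** torch.arange(L, dtype=torch.float32)
conv = torch.fft.ifft(torch.fft.fft(x, n=fft_size) * torch.fft.fft(kernel, n=fft_size))
assert torch.allclose(s, conv[L - 1], atol=1e-4)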
562
+ def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
563
+ """
564
+ Compute the IIR state given an input `u` and log_poles of the modal system.
565
+ """
566
+ bs = u.shape[0]
567
+ fft_size = 2 * L
568
+ U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
569
+ fft_size = 2 * L
570
+ x = (log_poles * t).exp()
571
+ # [batch, hidden_size, state_dim, 2 * seqlen]
572
+ X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
573
+ state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
574
+ return state
575
+
576
+
577
+ # This class doesn't appear to be used anywhere; consider commenting it out or removing it.
578
+ class HyenaFilter:
579
+ """Handles Hyena filter computations including FFT and direct convolution."""
580
+
581
+ def __init__(self, use_flash_fft=False):
582
+ self.use_flash_fft = use_flash_fft
583
+
584
+ def fft_conv(self, u, k, D, **kwargs):
585
+ """FFT-based convolution implementation."""
586
+ seqlen = u.shape[-1]
587
+ fft_size = 2 * seqlen
588
+
589
+ k_f = self._prepare_filter(k, u, fft_size)
590
+ y = self._compute_fft_conv(u, k_f, fft_size, seqlen, **kwargs)
591
+
592
+ return y + u * D.unsqueeze(-1)
593
+
594
+ def _prepare_filter(self, k, u, fft_size):
595
+ """Prepare filter for FFT convolution."""
596
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
597
+ return adjust_filter_shape_for_broadcast(u, k_f)
generation.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copied verbatim from vortex
2
+ # Copyright (c) 2024, Michael Poli.
3
+
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+ import sys
8
+ import numpy as np
9
+
10
+ from .sample import sample
11
+ from .tokenizer import CharLevelTokenizer
12
+ from .utils import print_rank_0
13
+
14
+
15
+ class Generator:
16
+ def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
17
+ self.model = model
18
+ self.tokenizer = tokenizer
19
+ self.top_k = top_k
20
+ self.top_p = top_p
21
+ self.temperature = temperature
22
+ self.untils = ["\n\n"]
23
+
24
+ def generate(
25
+ self,
26
+ device: str,
27
+ input_string: str = None,
28
+ input_ids: torch.Tensor = None,
29
+ num_tokens: int = 32,
30
+ cached_generation: bool = True,
31
+ force_prompt_threshold: int = None,
32
+ max_seqlen: int = None,
33
+ print_generation: bool = True,
34
+ verbose: bool = False,
35
+ skip_special_tokens: bool = False,
36
+ stop_at_eos: bool = True,
37
+ inference_params_dict: dict = None,
38
+ token_callback=lambda i: None,
39
+ ) -> tuple[torch.Tensor, torch.Tensor, dict]:
40
+ """
41
+ Generates using the model with optional cached sampling replay.
42
+
43
+ This method enables passing in and returning the `inference_params_dict` for
44
+ replaying cached sampling from a given state, for example for beam search.
45
+
46
+ Args:
47
+ device: The device to run the model on.
48
+ input_string: The input prompt to generate from.
49
+ input_ids: The input prompt token ids to generate from.
50
+ num_tokens: The number of tokens to generate.
51
+ cached_generation: Whether to use cached generation. Defaults to True.
52
+ force_prompt_threshold: Number of tokens to prefill in parallel before
53
+ switching to prompt forcing. Used to reduce peak memory usage and
54
+ support longer prompts. Defaults to None.
55
+ max_seqlen: Maximum sequence length to generate. Determines the max size
56
+ of the cache if larger. Otherwise automatically determined using
57
+ prompt length + num_tokens. Defaults to None.
58
+ print_generation: Whether to print generated tokens. Defaults to True.
59
+ verbose: Whether to print verbose output. Defaults to False.
60
+ skip_special_tokens: Whether to skip special tokens. Defaults to False.
61
+ stop_at_eos: Whether to stop generation at EOS token. Defaults to True.
62
+ inference_params_dict: Dictionary of inference parameters to use for
63
+ replaying cached sampling. Defaults to None.
64
+ token_callback: Optional callback function called after each token is
65
+ generated. Defaults to a no-op.
66
+
67
+ Returns:
68
+ tuple: The generated token ids, the per-step logits, and the inference parameters
69
+ dictionary, which can be used to replay the exact same sampling sequence.
70
+ """
71
+ if isinstance(self.tokenizer.eos, int):
72
+ eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
73
+ else:
74
+ eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
75
+
76
+ if input_ids is None:
77
+ input = self.tokenizer.tokenize(input_string)
78
+ if isinstance(input, list):
79
+ input = torch.LongTensor(input).unsqueeze(0).to(device)
80
+ else:
81
+ input = input.unsqueeze(0).to(device)
82
+ else:
83
+ input = input_ids
84
+ x = input
85
+
86
+ if max_seqlen is not None:
87
+ x = x[:, -max_seqlen:]
88
+
89
+ num_tokens = int(num_tokens)
90
+ batch_size = x.shape[0]
91
+
92
+ prompt_length = x.shape[1]
93
+ prompt_forcing = inference_params_dict is None and force_prompt_threshold is not None and prompt_length > force_prompt_threshold
94
+ if prompt_forcing:
95
+ forced_prompt_length = prompt_length - force_prompt_threshold
96
+ x_force = x[:, force_prompt_threshold:]
97
+ x = x[:, :force_prompt_threshold]
98
+ else:
99
+ forced_prompt_length = 0
100
+ tot_length = prompt_length + num_tokens
101
+ if max_seqlen is not None:
102
+ if max_seqlen > tot_length:
103
+ tot_length = max_seqlen
104
+
105
+ generation = torch.empty(
106
+ x.shape[0],
107
+ num_tokens,
108
+ dtype=torch.long,
109
+ device=x.device,
110
+ )
111
+
112
+ scores = torch.empty(
113
+ x.shape[0],
114
+ num_tokens,
115
+ self.tokenizer.vocab_size,
116
+ dtype=torch.float,
117
+ device=x.device,
118
+ )
119
+
120
+ if inference_params_dict is not None:
121
+ cached_generation = True
122
+ prefilled = True
123
+ # Ensure that the cached data is loaded on the correct device.
124
+ if any(data.device != x.device for data in inference_params_dict["hcl"].fir_state_dict.values()):
125
+ for key, data in inference_params_dict["mha"].key_value_memory_dict.items():
126
+ inference_params_dict["mha"].key_value_memory_dict[key] = data.to(x.device)
127
+ for key, data in inference_params_dict["hcl"].fir_state_dict.items():
128
+ inference_params_dict["hcl"].fir_state_dict[key] = data.to(x.device)
129
+ for key, data in inference_params_dict["hcl"].state_dict.items():
130
+ inference_params_dict["hcl"].state_dict[key] = data.to(x.device)
131
+ for key, data in inference_params_dict["hcm"].fir_inner_state_dict.items():
132
+ inference_params_dict["hcm"].fir_inner_state_dict[key] = data.to(x.device)
133
+ for key, data in inference_params_dict["hcm"].fir_state_dict.items():
134
+ inference_params_dict["hcm"].fir_state_dict[key] = data.to(x.device)
135
+ for key, data in inference_params_dict["hcm"].state_dict.items():
136
+ inference_params_dict["hcm"].state_dict[key] = data.to(x.device)
137
+ for key, data in inference_params_dict["hcs"].fir_state_dict.items():
138
+ inference_params_dict["hcs"].fir_state_dict[key] = data.to(x.device)
139
+ for key, data in inference_params_dict["hcs"].fir_inner_state_dict.items():
140
+ inference_params_dict["hcs"].fir_inner_state_dict[key] = data.to(x.device)
141
+ for key, data in inference_params_dict["hcs"].state_dict.items():
142
+ inference_params_dict["hcs"].state_dict[key] = data.to(x.device)
143
+ inference_params_dict["mha"].max_batch_size = batch_size
144
+ elif cached_generation:
145
+ inference_params_dict = self.model.initialize_inference_params(max_seqlen=tot_length)
146
+ inference_params_dict["mha"].max_batch_size = batch_size
147
+ prefilled = False
148
+ else:
149
+ inference_params_dict = None
150
+ prefilled = False
151
+
152
+ if verbose:
153
+ mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
154
+ print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
155
+ print_rank_0("Starting generation...")
156
+ if input_string is not None:
157
+ print_rank_0("Prompt: " + input_string)
158
+ else:
159
+ print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")
160
+
161
+ i = 0
162
+ for i in range(forced_prompt_length + num_tokens):
163
+ post_prefill = prefilled or (cached_generation and i > 0)
164
+
165
+ # prefill then process only the last token
166
+ if post_prefill:
167
+ x = x[:, -1:]
168
+ seqlen_offset = inference_params_dict["mha"].seqlen_offset
169
+
170
+ if seqlen_offset == 0:
171
+ if prompt_forcing:
172
+ seqlen_offset = force_prompt_threshold
173
+ else:
174
+ seqlen_offset = input.shape[-1]
175
+ inference_params_dict["mha"].seqlen_offset = seqlen_offset
176
+ inference_params_dict["hcl"].seqlen_offset = seqlen_offset
177
+ inference_params_dict["hcm"].seqlen_offset = seqlen_offset
178
+ inference_params_dict["hcs"].seqlen_offset = seqlen_offset
179
+ else:
180
+ inference_params_dict["mha"].seqlen_offset += 1
181
+ inference_params_dict["hcl"].seqlen_offset += 1
182
+ inference_params_dict["hcm"].seqlen_offset += 1
183
+ inference_params_dict["hcs"].seqlen_offset += 1
184
+
185
+ # do forward pass with no gradient
186
+ with torch.inference_mode():
187
+ logits, inference_params_dict = self.model(
188
+ x,
189
+ inference_params_dict=inference_params_dict,
190
+ )
191
+
192
+ token_callback(i)
193
+
194
+ last_logits = logits[:, -1]
195
+
196
+ if prompt_forcing and i < forced_prompt_length:
197
+ new_idx = x_force[:, i]
198
+ else:
199
+ new_idx = sample(
200
+ last_logits,
201
+ top_k=self.top_k,
202
+ top_p=self.top_p,
203
+ temperature=self.temperature,
204
+ )
205
+
206
+ if stop_at_eos and (generation[0, -1:] == eos_token_ids).all():
207
+ print("Stopping generation at EOS")
208
+
209
+ if print_generation and verbose and batch_size == 1:
210
+ print(
211
+ f"{self.tokenizer.detokenize([new_idx.item()])}",
212
+ end=" ",
213
+ flush=True,
214
+ )
215
+
216
+ if prompt_forcing:
217
+ if i >= forced_prompt_length:
218
+ scores[:, i - forced_prompt_length] = last_logits
219
+ generation[:, i - forced_prompt_length] = new_idx
220
+ else:
221
+ scores[:, i] = last_logits
222
+ generation[:, i] = new_idx
223
+
224
+ if post_prefill:
225
+ x = new_idx[:, None]
226
+ else:
227
+ x = torch.cat([x, new_idx[:, None]], dim=-1)
228
+
229
+ if verbose:
230
+ y = self.tokenizer.detokenize_batch(generation[:, : i + 1])
231
+
232
+ for until in self.untils:
233
+ if until in y:
234
+ y = y.split(until)[0]
235
+ break
236
+
237
+ print(f"\nInput: {input_string}, Output: {y}")
238
+
239
+ mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
240
+ print(f"Memory after generation: {mem_end} GB")
241
+
242
+ return generation[:, : i + 1], scores[:, : i + 1], inference_params_dict
243
+
244
+
245
+ def logits_to_logprobs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
246
+ """Convert logits to log probabilities."""
247
+ probs = torch.log_softmax(logits, dim=-1)
248
+ return torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1)
249
+
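# A small usage sketch for `logits_to_logprobs` (editorial; the tensors are made up):
import torch

logits = torch.tensor([[[2.0, 0.0, -1.0], [0.5, 0.5, 0.5]]])  # (batch=1, seq=2, vocab=3)
tokens = torch.tensor([[0, 2]])                                # token ids actually chosen
logprobs = logits_to_logprobs(logits, tokens)                  # (1, 2) per-token log-probabilities
print(logprobs.exp())                                          # probabilities of the chosen tokens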
250
+
251
+ def prepare_batch(
252
+ seqs: list[str], tokenizer: CharLevelTokenizer, prepend_bos: bool = False, device: str = "cuda:0"
253
+ ) -> tuple[torch.Tensor, list[int]]:
254
+ """Prepare a batch of sequences for the model."""
255
+ if prepend_bos:
256
+ seqs = [tokenizer.bos + seq for seq in seqs]
257
+
258
+ tokens = [tokenizer.tokenize(seq) for seq in seqs]
259
+ if isinstance(tokens[0], list):
260
+ tokens = [torch.tensor(t, dtype=torch.long) for t in tokens]
261
+
262
+ max_len = max(len(t) for t in tokens)
263
+ batch = torch.zeros((len(tokens), max_len), dtype=torch.long)
264
+
265
+ for i, t in enumerate(tokens):
266
+ batch[i, : len(t)] = t
267
+
268
+ return batch.to(device), [len(t) for t in tokens]
269
+
270
+
271
+ @dataclass(kw_only=True)
272
+ class GenerationOutput:
273
+ sequences: list[str]
274
+ logits: list[torch.Tensor]
275
+ logprobs_mean: list[float]
276
+
277
+
278
+ def generate(
279
+ *,
280
+ prompt_seqs: list[str],
281
+ model,
282
+ tokenizer: CharLevelTokenizer,
283
+ n_tokens: int = 100,
284
+ temperature: float = 0.0,
285
+ top_k: int = 1,
286
+ top_p: float = 1.0,
287
+ batched: bool = True,
288
+ prepend_bos: bool = False,
289
+ force_prompt_threshold: int = 1000,
290
+ cached_generation: bool = True,
291
+ verbose: int = 1,
292
+ device: str = "cuda:0",
293
+ **kwargs,
294
+ ) -> GenerationOutput:
295
+ """
296
+ Performs generation from a list of prompts.
297
+ If all prompts are the same length, this can do batched generation.
298
+ Also supports cached generation for efficient sampling.
299
+ """
300
+ model.eval()
301
+
302
+ g = Generator(
303
+ model,
304
+ tokenizer,
305
+ top_k=top_k,
306
+ top_p=top_p,
307
+ temperature=temperature,
308
+ )
309
+
310
+ uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)
311
+
312
+ if batched and uniform_lengths:
313
+ input_ids_list = [
314
+ prepare_batch(
315
+ prompt_seqs,
316
+ tokenizer,
317
+ prepend_bos=prepend_bos,
318
+ device=device,
319
+ )[0]
320
+ ]
321
+ else:
322
+ sys.stderr.write("WARNING: Batched generation is turned off.\n")
323
+ input_ids_list = [
324
+ prepare_batch(
325
+ [prompt_seq],
326
+ tokenizer,
327
+ prepend_bos=prepend_bos,
328
+ device=device,
329
+ )[0]
330
+ for prompt_seq in prompt_seqs
331
+ ]
332
+
333
+ generated_seqs, generated_scores, logitss = [], [], []
334
+ for input_ids in input_ids_list:
335
+ batch_size = input_ids.shape[0]
336
+
337
+ output_ids, logits, _ = g.generate(
338
+ input_ids=input_ids,
339
+ num_tokens=n_tokens,
340
+ device=device,
341
+ print_generation=(verbose > 1),
342
+ verbose=(verbose > 1),
343
+ stop_at_eos=False,
344
+ force_prompt_threshold=force_prompt_threshold,
345
+ cached_generation=cached_generation,
346
+ **kwargs,
347
+ )
348
+
349
+ if verbose > 1:
350
+ print("input_ids.shape", input_ids.shape)
351
+ print("output_ids.shape", output_ids.shape)
352
+ print("logits.shape", logits.shape)
353
+
354
+ generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
355
+ assert len(generated_seqs_batch) == batch_size
356
+ generated_seqs += generated_seqs_batch
357
+ logitss.append(logits)
358
+
359
+ logprobs = logits_to_logprobs(logits, output_ids)
360
+ logprobs = logprobs.float().cpu().numpy()
361
+
362
+ generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]
363
+
364
+ assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
365
+ if verbose:
366
+ for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
367
+ print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')
368
+
369
+ return GenerationOutput(
370
+ sequences=generated_seqs,
371
+ logits=logitss,
372
+ logprobs_mean=generated_scores,
373
+ )
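# Editorial usage sketch for the module-level `generate` helper above: `model` and
# `tokenizer` are assumed to be a loaded StripedHyena and CharLevelTokenizer from this
# repo, and the prompt strings are placeholders.
out = generate(
    prompt_seqs=["ACGTACGT", "ACGTTTTT"],   # equal lengths, so batched generation is used
    model=model,
    tokenizer=tokenizer,
    n_tokens=16,
    temperature=1.0,
    top_k=4,
    device="cuda:0",
    verbose=1,
)
print(out.sequences)        # generated continuations
print(out.logprobs_mean)    # mean per-token log-probability of each continuation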
layers.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copied verbatim from vortex (minus the commented out code)
2
+ # Copyright (c) 2024, Michael Poli.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from typing import Callable
9
+ from .utils import grab_first_if_tuple
10
+
11
+ from transformer_engine.pytorch import Linear
12
+ from transformer_engine.common.recipe import Format, DelayedScaling
13
+ import transformer_engine.pytorch as te
14
+
15
+ # Not bothering with ops right now (which is an interface with custom Triton
16
+ # kernels)
17
+ # try:
18
+ # from hyena_ops import hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd
19
+ # except ImportError:
20
+ # hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd = None, None, None
21
+
22
+ hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd = None, None, None
23
+
24
+
25
+ def set_format_recipe():
26
+ fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass
27
+ fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
28
+ return fp8_format, fp8_recipe
29
+
30
+
31
+ class TELinear(Linear):
32
+ """
33
+ Wrapper for Transformer-Engine's `Linear` layer.
34
+
35
+ Note that if Megatron's parallel_state has not been initialized
36
+ yet, the tp_group passed to TE will be None and must be set later
37
+ via set_tensor_parallel_group().
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ input_size: int,
43
+ output_size: int,
44
+ init_method: Callable,
45
+ bias: bool = True,
46
+ skip_bias_add: bool = False,
47
+ use_fp8: bool = False,
48
+ **kwargs,
49
+ ):
50
+ # Parameters are initialized at higher precision even if fp8
51
+ # is used
52
+ params_dtype = torch.bfloat16
53
+
54
+ # TE returns a zero length Tensor when bias=False and
55
+ # return_bias=True, but we prefer None. So in that case we
56
+ # tell TE to not return the bias, and return None
57
+ # ourselves. This way our forward always returns two values
58
+ # and we don't have to deal with the zero length Tensor.
59
+ self.te_return_bias = skip_bias_add and bias
60
+
61
+ self.use_fp8_input_projections = use_fp8
62
+ if use_fp8:
63
+ self.fp8_format, self.fp8_recipe = set_format_recipe()
64
+
65
+ super().__init__(
66
+ in_features=input_size,
67
+ out_features=output_size,
68
+ sequence_parallel=False,
69
+ fuse_wgrad_accumulation=False,
70
+ tp_group=None,
71
+ tp_size=1,
72
+ init_method=init_method,
73
+ params_dtype=params_dtype,
74
+ parallel_mode=None,
75
+ bias=bias,
76
+ return_bias=self.te_return_bias,
77
+ **kwargs,
78
+ )
79
+
80
+ def forward(self, x):
81
+ if self.use_fp8_input_projections:
82
+ with te.fp8_autocast(enabled=True, fp8_recipe=self.fp8_recipe):
83
+ out = super().forward(x)
84
+ else:
85
+ out = super().forward(x)
86
+
87
+ # TE only returns a tuple when return_bias is True, otherwise
88
+ # it returns a single Tensor, we always want to return two
89
+ # values regardless of the arguments.
90
+ if self.te_return_bias:
91
+ return out
92
+ return out, None
93
+
94
+
95
+ class FlexLinear:
96
+ """
97
+ Megatron and Transformer Engine linear layer compatible with fp8, bf16, fp16 and fp32
98
+ """
99
+
100
+ def __new__(
101
+ self,
102
+ input_size,
103
+ output_size,
104
+ config,
105
+ parallel_mode: str,
106
+ bias: bool = False,
107
+ skip_bias_add: bool = True,
108
+ use_fp8: bool = False,
109
+ input_is_parallel=False, # for row parallel
110
+ gather_output: bool = True, # for column parallel
111
+ parallel_output: bool = False, # for row parallel
112
+ **kwargs,
113
+ ):
114
+ # use_fp8 = config.use_fp8_linears
115
+ self.config = config
116
+ instance = None
117
+
118
+ if use_fp8:
119
+ instance = TELinear(
120
+ input_size=input_size,
121
+ output_size=output_size,
122
+ config=self.config,
123
+ parallel_mode=parallel_mode,
124
+ bias=bias,
125
+ skip_bias_add=skip_bias_add,
126
+ **kwargs,
127
+ )
128
+
129
+ return instance
130
+
131
+
132
+ class RMSNorm(torch.nn.Module):
133
+ def __init__(self, config):
134
+ super(RMSNorm, self).__init__()
135
+ self.eps, self.hidden_size = config.eps, config.hidden_size
136
+ self.scale = torch.nn.Parameter(torch.ones(self.hidden_size, dtype=config.params_dtype))
137
+ self.register_parameter("scale", self.scale)
138
+ self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)
139
+
140
+ if self.use_flash_rmsnorm:
141
+ from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
142
+
143
+ self.rmsnorm_func = rmsnorm_func
144
+
145
+ def forward(self, x):
146
+ if self.use_flash_rmsnorm:
147
+ return self.rmsnorm_func(x, self.scale, self.eps)
148
+ else:
149
+ y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
150
+ return self.scale * y
151
+
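# For reference (editorial note): `x.norm(2, dim=-1, keepdim=True) * hidden_size ** (-1/2)`
# equals the root-mean-square of `x`, so the forward above is RMSNorm with `eps` added
# outside the square root. A quick numerical check of that equivalence:
import torch

d = 8
x = torch.randn(3, d)
eps = 1e-6

rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()                    # root-mean-square of x
y_ref = x / (rms + eps)
y_fwd = x / (x.norm(2, dim=-1, keepdim=True) * d ** (-0.5) + eps)   # same expression as forward()
assert torch.allclose(y_ref, y_fwd, atol=1e-5)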
152
+
153
+ class ParallelGatedMLP(nn.Module):
154
+ def __init__(
155
+ self,
156
+ config,
157
+ layer_idx,
158
+ ):
159
+ super().__init__()
160
+
161
+ self.layer_idx = layer_idx
162
+ multiple_of = config.get("inner_size_multiple_of", 64)
163
+ self.act_type = config.get("mlp_activation", "gelu")
164
+ if self.act_type == "gelu":
165
+ self.act = F.gelu
166
+ elif self.act_type == "silu":
167
+ self.act = F.silu
168
+ else:
169
+ raise NotImplementedError
170
+
171
+ if self.layer_idx > 0 and config.get("evo2_style_activations", False):
172
+ self.act = nn.Identity()
173
+
174
+ self.multiple_of = multiple_of * config.model_parallel_size
175
+
176
+ inner_size = int(2 * config.hidden_size * 4 / 3)
177
+ inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
178
+ inner_size = config.get("inner_mlp_size", inner_size)
179
+
180
+ self.l1 = nn.Linear(
181
+ in_features=config.hidden_size,
182
+ out_features=inner_size,
183
+ bias=False,
184
+ )
185
+ self.l2 = nn.Linear(
186
+ in_features=config.hidden_size,
187
+ out_features=inner_size,
188
+ bias=False,
189
+ )
190
+ self.l3 = nn.Linear(
191
+ in_features=inner_size,
192
+ out_features=config.hidden_size,
193
+ bias=False,
194
+ )
195
+
196
+ def forward(self, z):
197
+ z1, z2 = self.l1(z), self.l2(z)
198
+ z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
199
+ y = self.l3(self.act(z1) * z2)
200
+ return grab_first_if_tuple(y)
201
+
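# A small worked example (editorial) of the inner width computation above, assuming
# hidden_size=4096, the default multiple_of=64, and model_parallel_size=1:
hidden_size, multiple_of = 4096, 64

inner_size = int(2 * hidden_size * 4 / 3)                                   # 10922
inner_size = multiple_of * ((inner_size + multiple_of - 1) // multiple_of)  # rounded up to a multiple of 64
print(inner_size)  # 10944, unless the config overrides it via "inner_mlp_size"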
202
+
203
+ class Embedding(nn.Module):
204
+ _train_dtype = "bf16"
205
+
206
+ def __init__(self, config):
207
+ super().__init__()
208
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
209
+
210
+ def embed(self, input_ids, position_ids=None, tokentype_ids=None):
211
+ embeddings = self.word_embeddings(input_ids)
212
+ return embeddings
213
+
214
+ def unembed(self, u):
215
+ weight = self.word_embeddings.weight
216
+ return torch.matmul(u, weight.t())
217
+
218
+
219
+ class VocabParallelEmbedding(nn.Embedding):
220
+ "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"
221
+
222
+ def __init__(self, config):
223
+ vocab_size, process_group, padding_idx = (
224
+ config.vocab_size,
225
+ config.get("process_group", None),
226
+ config.get("padding_idx", None),
227
+ )
228
+ self.process_group = process_group
229
+ if process_group is not None:
230
+ world_size = torch.distributed.get_world_size(process_group)
231
+ if vocab_size % world_size != 0:
232
+ raise ValueError(f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})")
233
+ if world_size > 1 and padding_idx is not None:
234
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
235
+ else:
236
+ world_size = 1
237
+ super().__init__(
238
+ vocab_size // world_size,
239
+ embedding_dim=config.hidden_size,
240
+ padding_idx=padding_idx,
241
+ )
242
+
243
+ def forward(self, input: Tensor) -> Tensor:
244
+ if self.process_group is None:
245
+ return super().forward(input)
246
+ else:
247
+ rank = torch.distributed.get_rank(self.process_group)
248
+ vocab_size = self.num_embeddings
249
+ vocab_start_index, vocab_end_index = (
250
+ rank * vocab_size,
251
+ (rank + 1) * vocab_size,
252
+ )
253
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
254
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
255
+ input = input - vocab_start_index
256
+ input[input_ids_mask] = 0
257
+ embeddings = super().forward(input)
258
+ embeddings[input_ids_mask] = 0.0
259
+ # Reduce to the global process group
260
+ torch.distributed.all_reduce(embeddings, group=self.process_group)
261
+ return embeddings
262
+
263
+ def unembed(self, u: Tensor) -> Tensor:
264
+ if self.process_group is None:
265
+ return u @ self.weight.T
266
+ else:
267
+ raise NotImplementedError
268
+
269
+
270
+ class VocabParallelUnembedding(VocabParallelEmbedding):
271
+ def forward(self, input: Tensor) -> Tensor:
272
+ return self.unembed(input)
model.py ADDED
@@ -0,0 +1,937 @@
1
+ # Copied verbatim from vortex
2
+
3
+ # Copyright (c) 2024, Michael Poli.
4
+
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .cache import (
11
+ InferenceParams,
12
+ HyenaCascadeFIRInferenceParams,
13
+ HyenaCascadeIIRInferenceParams,
14
+ )
15
+ from .engine import HyenaInferenceEngine
16
+ from .layers import (
17
+ ParallelGatedMLP,
18
+ RMSNorm,
19
+ VocabParallelEmbedding,
20
+ VocabParallelUnembedding,
21
+ TELinear,
22
+ )
23
+ from .utils import (
24
+ Lambda,
25
+ column_split,
26
+ interleave,
27
+ print_rank_0,
28
+ move_to_device,
29
+ fixup_fp8_extra_states,
30
+ fixup_te_workspace,
31
+ )
32
+ from .rich_logging import activations_logger, enable_activations_logging
33
+
34
+ import logging
35
+ from tqdm import tqdm
36
+
37
+ from .attention import MHA
38
+
39
+ try:
40
+ from vortex.model.positional_embeddings import swap_mha_rope
41
+ except ImportError:
42
+ "could not import swap_mha_rope from src.positional_embeddings"
43
+
44
+
45
+ class AttentionBlock(nn.Module):
46
+ def __init__(self, config, layer_idx) -> None:
47
+ super().__init__()
48
+ self.config = config
49
+ self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
50
+ self.layer_idx = layer_idx
51
+ self.print_activations = config.get("print_activations", False)
52
+ self.proj_groups = config.get("proj_groups", 1)
53
+ dtype = config.get("attn_block_dtype", torch.bfloat16)
54
+ mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
55
+ self.num_attention_heads = config.num_attention_heads
56
+ self.hidden_size = config.hidden_size
57
+ self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads
58
+
59
+ self.counter = 0
60
+ self.inner_mha_cls = MHA(
61
+ embed_dim=config.hidden_size,
62
+ num_heads=config.num_attention_heads,
63
+ num_heads_kv=config.num_attention_heads // self.proj_groups,
64
+ rotary_emb_dim=config.hidden_size // config.num_attention_heads,
65
+ qkv_proj_bias=config.get("qkv_proj_bias", True),
66
+ rotary_emb_base=config.get("rotary_emb_base", 1000000),
67
+ causal=True,
68
+ layer_idx=layer_idx,
69
+ out_proj_bias=config.get("mha_out_proj_bias", True),
70
+ use_flash_attn=self.config.use_flash_attn,
71
+ ).to(dtype=dtype)
72
+
73
+ # check if using interpolated rotary pos emb from config, and swap the rope emb
74
+ if config.get("use_interpolated_rotary_pos_emb", False):
75
+ swap_mha_rope(
76
+ mha=self.inner_mha_cls,
77
+ kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
78
+ )
79
+
80
+ if self.config.get("smeared_gqa", False):
81
+ self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
82
+ self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
83
+
84
+ self.mlp = ParallelGatedMLP(config, layer_idx).to(dtype=mlp_dtype)
85
+
86
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
87
+ if (
88
+ type(padding_mask) == torch.Tensor
89
+ ): # workaround for masking bug in FA. This works because Wqkv does not have bias
90
+ # and attention scores will be also automatically zeroed.
91
+ u = u * padding_mask[..., None]
92
+
93
+ if self.print_activations:
94
+ activations_logger.info(f"pre mha: {u}")
95
+
96
+ u = (
97
+ self.inner_mha_cls(
98
+ self.pre_norm(u),
99
+ inference_params=inference_params,
100
+ )
101
+ + u
102
+ )
103
+ if self.print_activations:
104
+ activations_logger.info(f"post mha: {u}")
105
+
106
+ if type(padding_mask) == torch.Tensor: # guard against bias
107
+ u = u * padding_mask[..., None]
108
+
109
+ if self.print_activations:
110
+ activations_logger.info(f"pre mlp: {u} {u.min()} {u.max()} {self.mlp.__class__}")
111
+ activations_logger.info(
112
+ f"post mlp norm: {self.post_norm(u)} {self.post_norm(u).min()} {self.post_norm(u).max()}"
113
+ )
114
+ activations_logger.info(
115
+ f"post mlp: {self.mlp(self.post_norm(u))} {self.mlp(self.post_norm(u)).min()} {self.mlp(self.post_norm(u)).max()}"
116
+ )
117
+
118
+ u = self.mlp(self.post_norm(u)) + u
119
+ return u, None
120
+
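# Schematically (editorial sketch, not vortex code), AttentionBlock.forward is the
# standard pre-norm residual pattern; stripped of logging and padding handling it is:
def attention_block(u, pre_norm, post_norm, mha, mlp):
    u = mha(pre_norm(u)) + u      # attention sub-block with residual
    u = mlp(post_norm(u)) + u     # gated-MLP sub-block with residual
    return u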
121
+
122
+ class HyenaCascade(nn.Module):
123
+ def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter_length=None) -> None:
124
+ super().__init__()
125
+ self.config = config
126
+ self.layer_idx = layer_idx
127
+ self.hyena_filter_groups = hyena_filter_groups
128
+ self.print_activations = config.get("print_activations", False)
129
+ self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)
130
+
131
+ self.use_flashfft = config.get("use_flashfft", False)
132
+ self.state_size = config.state_size
133
+ self.hidden_size = config.hidden_size
134
+ self.num_filters = config.num_filters
135
+ self.inference_mode = config.get("inference_mode", True)
136
+ self.counter = 0
137
+ self.column_split_hyena = config.get("column_split_hyena", True)
138
+ self.hyena_flip_x1x2 = config.get("hyena_flip_x1x2", False)
139
+
140
+ assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
141
+
142
+ # attention heads are not used except to split post short_filter
143
+ # projections in the same way as the checkpoint
144
+ self.num_attention_heads = config.num_attention_heads
145
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
146
+
147
+ self.fir_inner_filter_length = fir_inner_filter_length
148
+ self.short_filter_length = config.short_filter_length
149
+ self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
150
+ self.short_filter_bias = nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
151
+
152
+ self.engine = HyenaInferenceEngine(
153
+ layer_idx=layer_idx,
154
+ ground_truth_activations_path=self.ground_truth_activations_path,
155
+ print_activations=self.print_activations,
156
+ hyena_flip_x1x2=config.get("hyena_flip_x1x2", False),
157
+ )
158
+ self.use_flash_depthwise = config.get("use_flash_depthwise", False)
159
+ self.data_dtype = None
160
+
161
+ if self.use_flash_depthwise:
162
+ try:
163
+ from flashfftconv import FlashDepthwiseConv1d
164
+
165
+ self.fir_fn = FlashDepthwiseConv1d(
166
+ channels=3 * self.hidden_size,
167
+ kernel_size=self.short_filter_length,
168
+ padding=self.short_filter_length - 1,
169
+ weights=self.short_filter_weight,
170
+ bias=self.short_filter_bias,
171
+ device=None,
172
+ dtype=self.config.get("depthwise_dtype", torch.bfloat16),
173
+ )
174
+ except ImportError:
175
+ "flashfftconv not installed"
176
+ else:
177
+ self.fir_fn = F.conv1d
178
+
179
+ self.fir_inner_fn = F.conv1d
180
+
181
+ self.fftconv_fn = None
182
+ self.long_fir_threshold = config.get("long_fir_threshold", None)
183
+ if self.long_fir_threshold is not None:
184
+ assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"
185
+
186
+ self.num_systems = self.hyena_filter_groups
187
+ self.channels_per_group = self.hidden_size // self.hyena_filter_groups
188
+
189
+ if self.fir_inner_filter_length:
190
+ self.h = nn.Parameter(torch.randn(self.hyena_filter_groups, 1, fir_inner_filter_length))
191
+
192
+ if fir_inner_filter_length >= 128:
193
+ self.D = nn.Parameter(torch.zeros(self.hidden_size))
194
+
195
+ if fir_inner_filter_length < 128:
196
+ self.D = None
197
+
198
+ else:
199
+ log_poles = torch.randn(self.num_systems, self.state_size, 1, dtype=torch.float32)
200
+
201
+ # TODO: bring over init from internals
202
+ # poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
203
+ # poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)
204
+
205
+ self.log_poles = nn.Parameter(log_poles)
206
+ self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, dtype=torch.float32))
207
+ self.D = nn.Parameter(torch.zeros(self.hidden_size))
208
+ self.h = None
209
+ self.t = None
210
+
211
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
212
+ if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
213
+ return self.sequential_forward(u, inference_params)
214
+
215
+ else:
216
+ return self.parallel_forward(u, inference_params, padding_mask)
217
+
218
+ def parallel_forward(self, u, inference_params=None, padding_mask=None):
219
+ L = u.shape[1]
220
+ dims = (
221
+ self.hidden_size,
222
+ self.num_attention_heads,
223
+ self.hidden_size_per_attention_head,
224
+ self.state_size,
225
+ self.hyena_filter_groups,
226
+ )
227
+ if self.print_activations:
228
+ activations_logger.info(f"pre 1 parallel fir: {u}, {u.min()}, {u.max()}")
229
+
230
+ z_pre, fir_state = self.engine.parallel_fir(
231
+ self.fir_fn,
232
+ u,
233
+ self.short_filter_weight,
234
+ self.short_filter_bias,
235
+ L,
236
+ dims=dims,
237
+ gate=False,
238
+ column_split_hyena=self.column_split_hyena,
239
+ fir_length=self.short_filter_length,
240
+ inference_params=inference_params,
241
+ padding_mask=padding_mask,
242
+ dim_last=True,
243
+ )
244
+
245
+ if inference_params:
246
+ inference_params.fir_state_dict[self.layer_idx] = fir_state
247
+
248
+ if self.config.interleave:
249
+ z_pre = interleave(z_pre)
250
+
251
+ if self.h is None:
252
+ h, _, _, _ = self.compute_filter(L, u.device)
253
+ else:
254
+ h = self.h
255
+
256
+ D = self.D
257
+
258
+ if self.hyena_filter_groups > 1:
259
+ h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 0)
260
+
261
+ # if inference_params is not None, we plan to perform generation:
262
+ # prefilling is handled by the engine.
263
+ if self.fir_inner_filter_length is not None:
264
+ if self.print_activations:
265
+ activations_logger.info(
266
+ f"pre 2 parallel fir: {z_pre}, {z_pre.min()}, {z_pre.max()}, {self.fir_inner_filter_length}"
267
+ )
268
+ y, fir_inner_state = self.engine.parallel_fir(
269
+ self.fir_inner_fn,
270
+ z_pre,
271
+ h,
272
+ D,
273
+ L,
274
+ dims=dims,
275
+ gate=True,
276
+ gated_bias=self.fir_inner_filter_length >= 128,
277
+ dim_last=False,
278
+ column_split_hyena=self.column_split_hyena,
279
+ fir_length=self.fir_inner_filter_length,
280
+ inference_params=inference_params,
281
+ padding_mask=padding_mask,
282
+ groups=self.hyena_filter_groups,
283
+ )
284
+ if self.print_activations:
285
+ activations_logger.info(f"post 2 parallel fir: {y}, {y.min()}, {y.max()}")
286
+ y = y.permute(0, 2, 1)
287
+ if inference_params:
288
+ inference_params.fir_inner_state_dict[self.layer_idx] = fir_inner_state
289
+ else:
290
+ if self.print_activations:
291
+ activations_logger.info(f"pre 2 parallel iir: {z_pre}, {z_pre.min()}, {z_pre.max()}")
292
+ y = self.engine.parallel_iir(
293
+ z_pre,
294
+ h,
295
+ D,
296
+ L,
297
+ t=self.t,
298
+ poles=self.log_poles,
299
+ residues=self.residues,
300
+ dims=dims,
301
+ inference_params=inference_params,
302
+ layer_idx=self.layer_idx,
303
+ prefill_style=self.config.get("prefill_style", "fft"),
304
+ use_flashfft=self.use_flashfft,
305
+ fftconv_fn=self.fftconv_fn,
306
+ column_split_hyena=self.column_split_hyena,
307
+ long_fir_threshold=self.long_fir_threshold,
308
+ padding_mask=padding_mask,
309
+ )
310
+ if self.print_activations:
311
+ activations_logger.info(f"post 2 parallel iir: {y}, {y.min()}, {y.max()}")
312
+
313
+ return y, inference_params
314
+
315
+ def sequential_forward(self, u, inference_params):
316
+ if self.data_dtype is None:
317
+ self.data_dtype = u.dtype
318
+
319
+ if len(u.shape) > 2:
320
+ u = u[:, -1]
321
+
322
+ z_pre, fir_state = self.engine.step_fir(
323
+ u,
324
+ inference_params.fir_state_dict[self.layer_idx],
325
+ weight=self.short_filter_weight,
326
+ bias=self.short_filter_bias,
327
+ )
328
+ inference_params.fir_state_dict[self.layer_idx] = fir_state
329
+
330
+ if self.config.interleave:
331
+ z_pre = interleave(z_pre)
332
+
333
+ x2, x1, v = (
334
+ column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
335
+ if self.column_split_hyena
336
+ else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
337
+ )
338
+
339
+ if self.hyena_flip_x1x2:
340
+ x1, x2 = x2, x1
341
+
342
+ if self.fir_inner_filter_length is not None:
343
+ if self.hyena_filter_groups > 1:
344
+ h = self.h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 0)
345
+ else:
346
+ h = self.h
347
+
348
+ y, fir_inner_state = self.engine.step_fir(
349
+ x1 * v,
350
+ inference_params.fir_inner_state_dict[self.layer_idx],
351
+ weight=h,
352
+ bias=self.D,
353
+ flip_filter=self.fir_inner_filter_length >= 128,
354
+ gated_bias=self.fir_inner_filter_length >= 128,
355
+ )
356
+ y = y * x2
357
+ inference_params.fir_inner_state_dict[self.layer_idx] = fir_inner_state
358
+ else:
359
+ y, iir_state = self.engine.step_iir(
360
+ x2,
361
+ x1,
362
+ v,
363
+ self.D,
364
+ self.residues,
365
+ self.log_poles,
366
+ inference_params.state_dict[self.layer_idx],
367
+ iir_groups=1,
368
+ )
369
+ inference_params.state_dict[self.layer_idx] = iir_state
370
+
371
+ y = y.to(dtype=self.data_dtype)
372
+ return y[:, None], inference_params
373
+
374
+ def update_time(self, L, device):
375
+ """
376
+ Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
377
+ If L is greater than the length of the previous batch, then the time vector is
378
+ reinitialized. Otherwise, the time vector is truncated from cache.
379
+ """
380
+ if self.t is None:
381
+ self.t = torch.arange(L, device=device)[None, None]
382
+ elif self.t.shape[-1] < L:
383
+ self.t = torch.arange(L, device=device)[None, None]
384
+ else:
385
+ self.t = self.t[..., :L]
386
+
387
+ def compute_filter(self, L, device):
388
+ self.update_time(L, device)
389
+ filter_dtype = torch.float32
390
+ residues, log_poles = (
391
+ self.residues.to(filter_dtype),
392
+ self.log_poles.to(filter_dtype),
393
+ )
394
+ h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] # B, D, L
395
+ return h, filter_dtype, log_poles, residues
396
+
397
+
398
+ class ParallelGatedConvBlock(nn.Module):
399
+ def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter_length=None) -> None:
400
+ super().__init__()
401
+ self.config = config
402
+ self.layer_idx = layer_idx
403
+ self.print_activations = config.get("print_activations", False)
404
+ self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)
405
+ self.low_mem_mode = config.get("low_mem_mode", False)
406
+ self.fir_inner_filter_length = fir_inner_filter_length
407
+ self.hyena_filter_groups = hyena_filter_groups if hyena_filter_groups is not None else config.hidden_size
408
+ dtype = config.get("hyena_block_dtype", torch.bfloat16)
409
+ mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
410
+ self.pre_norm, self.post_norm = (
411
+ RMSNorm(config).to(dtype=dtype),
412
+ RMSNorm(config).to(dtype=dtype),
413
+ )
414
+ self.filter = HyenaCascade(
415
+ config,
416
+ layer_idx,
417
+ hyena_filter_groups=self.hyena_filter_groups,
418
+ fir_inner_filter_length=fir_inner_filter_length,
419
+ ).to(dtype=dtype)
420
+
421
+ # For posterity/debugging: TELinear can be easily replaced by
422
+ # nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.qkv_proj_bias).to(dtype=dtype)
423
+ # which sometimes is very useful when debugging FP8.
424
+ self.projections = TELinear(
425
+ config.hidden_size,
426
+ 3 * config.hidden_size,
427
+ bias=config.qkv_proj_bias,
428
+ init_method=torch.nn.init.xavier_uniform_,
429
+ use_fp8=config.get("use_fp8_input_projections", False),
430
+ )
431
+
432
+ self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.hyena_out_proj_bias).to(
433
+ dtype
434
+ )
435
+ self.mlp = ParallelGatedMLP(config, layer_idx).to(dtype=mlp_dtype)
436
+
437
+ # self.proj_norm_fn = self.proj_norm
438
+ # self.res_mlp_norm_fn = self.res_mlp_norm
439
+
440
+ if self.config.get("compile", False):
441
+ self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
442
+ self.res_mlp_norm_fn = torch.compile(
443
+ self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
444
+ )
445
+
446
+ def pad_to_multiple(self, x, multiple=16):
447
+ """Pad input tensor to multiple of 16 only when FP8 is enabled"""
448
+ if not self.config.get("use_fp8_input_projections", False):
449
+ return x
450
+
451
+ batch_size, seq_len, hidden_dim = x.size()
452
+ pad_len = (multiple - (seq_len % multiple)) % multiple
453
+ if pad_len == 0:
454
+ return x
455
+ return F.pad(x, (0, 0, 0, pad_len))
456
+
457
+ def proj_norm(self, x):
458
+ if self.print_activations:
459
+ activations_logger.info(f"pre mixer norm: {x} {x.min()} {x.max()} {self.projections.__class__}")
460
+ activations_logger.info(
461
+ f"post mixer norm: {self.pre_norm(x)} {self.pre_norm(x).min()} {self.pre_norm(x).max()}"
462
+ )
463
+
464
+ if self.ground_truth_activations_path:
465
+ pre_norm_savanna = torch.load(
466
+ f"{self.ground_truth_activations_path}/pre_mixer_norm_{self.layer_idx}.pt"
467
+ )
468
+ post_norm_savanna = torch.load(
469
+ f"{self.ground_truth_activations_path}/post_mixer_norm_{self.layer_idx}.pt"
470
+ )
471
+
472
+ activation_diff = (x.squeeze() - pre_norm_savanna.squeeze()).abs()
473
+ activations_logger.info(
474
+ f"pre mixer norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
475
+ )
476
+ activation_diff = (self.pre_norm(x).squeeze() - post_norm_savanna.squeeze()).abs()
477
+ activations_logger.info(
478
+ f"post mixer norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
479
+ )
480
+ activations_logger.info(
481
+ f"pre norm scale: {self.pre_norm.scale}, {self.pre_norm.scale.min()}, {self.pre_norm.scale.max()}"
482
+ )
483
+
484
+ normalized = self.pre_norm(x)
485
+ normalized = self.pad_to_multiple(normalized)
486
+ with torch.cuda.device(x.device):
487
+ projected = self.projections(normalized)
488
+
489
+ if isinstance(projected, tuple):
490
+ projected = projected[0]
491
+
492
+ original_seq_len = x.size(1)
493
+ # Slice back to original sequence length if padding was added
494
+ if projected.size(1) > original_seq_len:
495
+ projected = projected[:, :original_seq_len, :]
496
+
497
+ return projected
498
+
499
+ def res_mlp_norm(self, x):
500
+ if self.print_activations:
501
+ activations_logger.info(f"pre mlp: {x} {x.min()} {x.max()} {self.mlp.__class__}")
502
+ activations_logger.info(
503
+ f"post mlp norm: {self.post_norm(x)} {self.post_norm(x).min()} {self.post_norm(x).max()}"
504
+ )
505
+ activations_logger.info(
506
+ f"post mlp: {self.mlp(self.post_norm(x))} {self.mlp(self.post_norm(x)).min()} {self.mlp(self.post_norm(x)).max()}"
507
+ )
508
+ if self.ground_truth_activations_path:
509
+ pre_mlp_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_mlp_{self.layer_idx}.pt")
510
+ post_mlp_savanna = torch.load(f"{self.ground_truth_activations_path}/post_mlp_norm_{self.layer_idx}.pt")
511
+
512
+ activation_diff = (x.squeeze() - pre_mlp_savanna.squeeze()).abs()
513
+ activations_logger.info(f"pre mlp activation_diff: {activation_diff.max()}, {activation_diff.mean()}")
514
+ activation_diff = (self.post_norm(x).squeeze() - post_mlp_savanna.squeeze()).abs()
515
+ activations_logger.info(
516
+ f"post mlp norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
517
+ )
518
+ return self.mlp(self.post_norm(x)) + x
519
+
520
+ def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
521
+ z = self.proj_norm(u)
522
+
523
+ if type(padding_mask) == torch.Tensor: # guard against bias
524
+ z = z * padding_mask[..., None]
525
+
526
+ if self.print_activations:
527
+ activations_logger.info(f"pre filter: {z} {z.min()} {z.max()} {self.filter.__class__}")
528
+ if self.ground_truth_activations_path:
529
+ z_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_filter_{self.layer_idx}.pt")
530
+ activation_diff = (z - z_savanna.squeeze()).abs()
531
+ activations_logger.info(
532
+ f"pre filter activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
533
+ )
534
+ z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)
535
+
536
+ if self.print_activations:
537
+ activations_logger.info(f"post postgate: {z} {z.min()} {z.max()} {self.filter.__class__}")
538
+ activations_logger.info(
539
+ f"post out proj: {self.out_filter_dense(z)} {self.out_filter_dense(z).min()} {self.out_filter_dense(z).max()} {self.out_filter_dense.__class__}"
540
+ )
541
+ activations_logger.info(
542
+ f"post mixer dense and residual: {self.out_filter_dense(z) + u} {(self.out_filter_dense(z) + u).min()} {(self.out_filter_dense(z) + u).max()}"
543
+ )
544
+ activations_logger.info(
545
+ f"post mixer dense: {self.out_filter_dense(z)} {self.out_filter_dense(z).min()} {self.out_filter_dense(z).max()}"
546
+ )
547
+ activations_logger.info(f"post mixer: {z} {z.min()} {z.max()}")
548
+ if self.ground_truth_activations_path:
549
+ z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_filter_{self.layer_idx}.pt")
550
+ activation_diff = (z - z_savanna.squeeze()).abs()
551
+ activations_logger.info(
552
+ f"post filter activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
553
+ )
554
+
555
+ z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_out_proj_{self.layer_idx}.pt")
556
+ z_ = F.linear(z, self.out_filter_dense.weight)
557
+ activation_diff = (z_ - z_savanna.squeeze()).abs()
558
+ activations_logger.info(
559
+ f"post out proj activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
560
+ )
561
+
562
+ z_in = self.out_filter_dense(z) + u
563
+
564
+ # if self.layer_idx == 0:
565
+ # z_in = z_savanna.squeeze() + u + self.out_filter_dense.bias
566
+
567
+ if type(padding_mask) == torch.Tensor: # guard against bias
568
+ z_in = z_in * padding_mask[..., None]
569
+
570
+ y = self.res_mlp_norm(z_in)
571
+
572
+ return y, inference_params
573
+
574
+
575
+ def get_block(config, layer_idx, flash_fft=None):
576
+ if layer_idx in config.attn_layer_idxs:
577
+ return AttentionBlock(config, layer_idx)
578
+ elif layer_idx in config.hcl_layer_idxs:
579
+ block = ParallelGatedConvBlock(config, layer_idx)
580
+ if config.get("use_flashfft", "False"):
581
+ block.filter.fftconv_fn = flash_fft
582
+ return block
583
+ elif layer_idx in config.hcm_layer_idxs:
584
+ block = ParallelGatedConvBlock(
585
+ config,
586
+ layer_idx,
587
+ hyena_filter_groups=config.hcm_filter_groups,
588
+ fir_inner_filter_length=config.hcm_filter_length,
589
+ )
590
+ return block
591
+ elif layer_idx in config.hcs_layer_idxs:
592
+ block = ParallelGatedConvBlock(
593
+ config,
594
+ layer_idx,
595
+ hyena_filter_groups=config.hcs_filter_groups,
596
+ fir_inner_filter_length=config.hcs_filter_length,
597
+ )
598
+ return block
599
+ else:
600
+ raise NotImplementedError
601
+
602
+
603
+ class StripedHyena(nn.Module):
604
+ def __init__(self, config):
605
+ super().__init__()
606
+ fixup_te_workspace() # Workaround global cublas workspaces in TE
607
+
608
+ self.config = config
609
+ self.print_activations = config.get("print_activations", False)
610
+
611
+ if self.print_activations:
612
+ enable_activations_logging()
613
+ self.logger = logging.getLogger(self.__class__.__name__)
614
+
615
+ self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)
616
+ self.logger.info(f"Initializing StripedHyena with config: {config}")
617
+
618
+ with torch.device("cuda:0" if torch.cuda.is_available() else "cpu"):
619
+ self.embedding_layer = VocabParallelEmbedding(config)
620
+
621
+ if config.get("use_flashfft", "True"):
622
+ try:
623
+ from flashfftconv import FlashFFTConv
624
+
625
+ self.flash_fft = FlashFFTConv(config.seqlen, dtype=torch.bfloat16)
626
+ except ImportError:
627
+ "flashfftconv not installed"
628
+ else:
629
+ self.flash_fft = None
630
+ if not self.config.get('evo2_style_activations', False):
631
+ self.logger.warning(
632
+ "⚠️ Not using Evo2 style activations ⚠️\n"
633
+ "⚠️ Set 'evo2_style_activations: True' in config if you are using Evo 2 checkpoints ⚠️"
634
+ )
635
+ self.logger.info(f"Initializing {config.num_layers} blocks...")
636
+ self.blocks = nn.ModuleList()
637
+ self.block_idx_to_device = {}
638
+
639
+ # Calculate layers per GPU
640
+ num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
641
+ layers_per_gpu = math.ceil(config.num_layers / num_gpus)
642
+ self.logger.info(f"Distributing across {num_gpus} GPUs, approximately {layers_per_gpu} layers per GPU")
643
+
644
+ for layer_idx in tqdm(range(config.num_layers)):
645
+ # Determine which GPU should handle this layer
646
+ device_idx = min(layer_idx // layers_per_gpu, num_gpus - 1)
647
+ device = f"cuda:{device_idx}" if torch.cuda.is_available() else "cpu"
648
+
649
+ with torch.device(device):
650
+ # TELinear uses `device="cuda"` device to allocate empty bias
651
+ # tensor. This makes sure that the empty tensor is allocated on the
652
+ # correct device. (torch.device(), unlike torch.cuda.device(),
653
+ # doesn't override current CUDA device.)
654
+ with torch.cuda.device(device):
655
+ block = get_block(config, layer_idx, flash_fft=self.flash_fft)
656
+ move_to_device(block, device)
657
+
658
+ self.blocks.append(block)
659
+ self.block_idx_to_device[layer_idx] = device
660
+ self.logger.info(f"Assigned {layer_idx=} to {device=}")
661
+ self.logger.info(
662
+ f"Parameter count for block {layer_idx}: {sum(p.numel() for p in self.blocks[-1].parameters())}"
663
+ )
664
+
665
+ with torch.device(self.block_idx_to_device[0]):
666
+ with torch.cuda.device(self.block_idx_to_device[0]):
667
+ self.norm = RMSNorm(config) if config.get("final_norm", True) else None
668
+ if config.tie_embeddings:
669
+ # Lambda is used so that callers go through forward(), which in
670
+ # turn is needed for PyTorch hooks to work properly.
671
+ self.unembed = Lambda(self.embedding_layer.unembed)
672
+ else:
673
+ if config.tie_embeddings:
674
+ # Technically we can support this mode, just need to
675
+ # copy tensors across GPUs then. But let's implement it
676
+ # once/if needed.
677
+ self.logger.info("Ignoring tie_embeddings for now.")
678
+ self.unembed = VocabParallelUnembedding(config)
679
+
680
+ self.logger.info("Initialized model")
681
+
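As a quick illustration of the block-to-GPU assignment above, a standalone sketch with made-up sizes (not part of the module):

import math

# e.g. 32 layers over 3 visible GPUs: ceil(32 / 3) = 11 layers per GPU, with the
# min(..., num_gpus - 1) clamp keeping any tail layers on the last device.
num_layers, num_gpus = 32, 3
layers_per_gpu = math.ceil(num_layers / num_gpus)
assignment = {i: f"cuda:{min(i // layers_per_gpu, num_gpus - 1)}" for i in range(num_layers)}
print(assignment)  # layers 0-10 -> cuda:0, 11-21 -> cuda:1, 22-31 -> cuda:2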
682
+ def forward(self, x, inference_params_dict=None, padding_mask=None):
683
+ L = x.shape[1]
684
+ if self.print_activations:
685
+ activations_logger.info(f"pre embedding: {x}, {x.min()}, {x.max()}")
686
+
687
+ x = self.embedding_layer(x)
688
+
689
+ if self.print_activations:
690
+ activations_logger.info(f"post embedding: {x}, {x.min()}, {x.max()}")
691
+
692
+ if inference_params_dict is not None:
693
+ x, inference_params_dict_out = self.stateful_forward(
694
+ x,
695
+ inference_params_dict=inference_params_dict,
696
+ )
697
+ else:
698
+ x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
699
+
700
+ if self.print_activations:
701
+ activations_logger.info(f"pre norm: {x}, {x.min()}, {x.max()}")
702
+
703
+ # By convention, we return results on the first device
704
+ x = x.to(self.block_idx_to_device[0])
705
+ x = self.norm(x)
706
+
707
+ if self.print_activations:
708
+ activations_logger.info(f"post norm: {x}, {x.min()}, {x.max(), {self.norm.scale}}")
709
+
710
+ x = self.unembed(x)
711
+ return x, inference_params_dict_out
712
+
713
+ def block_idx_to_name(self, block_idx):
714
+ if block_idx in self.config.attn_layer_idxs:
715
+ return "mha"
716
+ elif block_idx in self.config.hcl_layer_idxs:
717
+ return "hcl"
718
+ elif block_idx in self.config.hcm_layer_idxs:
719
+ return "hcm"
720
+ elif block_idx in self.config.hcs_layer_idxs:
721
+ return "hcs"
722
+ else:
723
+ raise ValueError(f"Block index {block_idx} not found")
724
+
725
+ def cross_device_transfer(self, x, block_idx):
726
+ if self.block_idx_to_device[max(block_idx - 1, 0)] != self.block_idx_to_device[block_idx]:
727
+ x = x.to(self.block_idx_to_device[block_idx])
728
+ return x
729
+
730
+ def stateful_forward(self, x, inference_params_dict=None):
731
+ for block_idx, block in enumerate(self.blocks):
732
+ inference_params = inference_params_dict[self.block_idx_to_name(block_idx)]
733
+
734
+ if self.print_activations:
735
+ activations_logger.info(f"pre block {block_idx}: {x}, {x.min()}, {x.max()} {block.__class__}")
736
+ if self.ground_truth_activations_path:
737
+ x_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_block_{block_idx}.pt")
738
+ activation_diff = (x - x_savanna.squeeze()).abs()
739
+ activations_logger.info(
740
+ f"pre block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
741
+ )
742
+
743
+ x = self.cross_device_transfer(x, block_idx)
744
+ x, _ = block(x, inference_params=inference_params)
745
+
746
+ if self.print_activations:
747
+ activations_logger.info(f"post block {block_idx}: {x}, {x.min()}, {x.max()}")
748
+ if self.ground_truth_activations_path:
749
+ x_savanna = torch.load(f"{self.ground_truth_activations_path}/post_block_{block_idx}.pt")
750
+ activation_diff = (x - x_savanna.squeeze()).abs()
751
+ activations_logger.info(
752
+ f"post block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
753
+ )
754
+
755
+ return x, inference_params_dict
756
+
757
+ def stateless_forward(self, x, padding_mask=None):
758
+ if type(padding_mask) == torch.Tensor:
759
+ x = x * padding_mask[..., None]
760
+
761
+ for block_idx, block in enumerate(self.blocks):
762
+ if self.print_activations:
763
+ activations_logger.info(f"pre block {block_idx}: {x}, {x.min()}, {x.max()} {block.__class__}")
764
+ if self.ground_truth_activations_path:
765
+ x_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_block_{block_idx}.pt")
766
+ activation_diff = (x - x_savanna.squeeze()).abs()
767
+ activations_logger.info(
768
+ f"pre block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
769
+ )
770
+
771
+ x = self.cross_device_transfer(x, block_idx)
772
+ x, _ = block(x, inference_params=None, padding_mask=padding_mask)
773
+
774
+ if self.print_activations:
775
+ activations_logger.info(f"post block {block_idx}: {x}, {x.min()}, {x.max()}")
776
+ if self.ground_truth_activations_path:
777
+ x_savanna = torch.load(f"{self.ground_truth_activations_path}/post_block_{block_idx}.pt")
778
+ activation_diff = (x - x_savanna.squeeze()).abs()
779
+ activations_logger.info(
780
+ f"post block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
781
+ )
782
+
783
+ return x, None
784
+
785
+ def initialize_inference_params(self, max_seqlen=None):
786
+ ## Input seqlen takes priority over config!
787
+ ## WARNING: This avoids potential errors but means the model can be used beyond the length it was trained at
788
+ config_seqlen = self.config.get("max_seqlen", None)
789
+ if config_seqlen is None:
790
+ print("No max_seqlen found in config!!! using default value of 8192")
791
+ config_seqlen = 8192
792
+ new_max_seqlen = max_seqlen if max_seqlen is not None else config_seqlen
793
+ # self.config["max_seqlen"] = new_max_seqlen
794
+ ## Note: changing the stored config max_seqlen will change the max_seqlen used in flash attention, leading to minor logit differences
795
+ print(f"Initializing inference params with max_seqlen={new_max_seqlen}")
796
+
797
+ inference_params_dict = {
798
+ "mha": InferenceParams(
799
+ max_seqlen=new_max_seqlen,
800
+ max_batch_size=self.config.get("max_batch_size", 1),
801
+ seqlen_offset=0,
802
+ ),
803
+ "hcl": HyenaCascadeIIRInferenceParams(
804
+ fir_filter_length=self.config.short_filter_length,
805
+ state_dim=self.config.state_size,
806
+ seqlen_offset=0,
807
+ ),
808
+ "hcm": HyenaCascadeFIRInferenceParams(
809
+ fir_filter_length=self.config.short_filter_length,
810
+ fir_inner_filter_length=self.config.hcm_filter_length,
811
+ seqlen_offset=0,
812
+ ),
813
+ "hcs": HyenaCascadeFIRInferenceParams(
814
+ fir_filter_length=self.config.short_filter_length,
815
+ fir_inner_filter_length=self.config.hcs_filter_length,
816
+ seqlen_offset=0,
817
+ ),
818
+ }
819
+ return inference_params_dict
820
+
821
+ def precompute_filters(self, L, device):
822
+ for block_idx, block in enumerate(self.blocks):
823
+ if type(block) == ParallelGatedConvBlock:
824
+ if type(block.filter) == HyenaCascade:
825
+ L = block.filter.long_fir_threshold or L
826
+ print_rank_0(f"Precomputing filters, L={L}...")
827
+
828
+ filter_dtype = torch.float16 if L >= 2048 else torch.float32
829
+
830
+ block.filter._set_time(L, device)
831
+ residues, poles = (
832
+ block.filter.residues.to(torch.float16),
833
+ block.filter.poles.to(torch.float16),
834
+ )
835
+
836
+ block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
837
+ block.filter.h = block.filter.h.to(dtype=filter_dtype)
838
+
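A simplified, self-contained sketch of the kernel materialization above, with made-up pole/residue values and complex tensors used directly (the module itself stores poles and residues as real view pairs): the long-convolution kernel is a sum of damped complex exponentials, h[t] = Re(sum_k residues_k * poles_k ** t).

import torch

t = torch.arange(16, dtype=torch.float32)[None, None, :]            # (1, 1, L)
poles = torch.tensor([0.95 + 0.05j, 0.80 - 0.10j])[None, :, None]   # (1, K, 1)
residues = torch.tensor([1.00 + 0.00j, 0.50 + 0.50j])[None, :, None]
# p**t computed as exp(t * log(p)) to keep the broadcasting explicit
h = (residues * torch.exp(t * torch.log(poles))).real.sum(1)         # (1, L)
print(h.shape, h[0, :4])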
839
+ def load_poles_residues(self, path):
840
+ "Load different poles and residues for each layer."
841
+ for block_idx, block in enumerate(self.blocks):
842
+ if type(block) == ParallelGatedConvBlock:
843
+ if type(block.filter) == HyenaCascade:
844
+ self.logger.info(f"Loading approximatepoles and residues for block {block_idx}")
845
+ poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
846
+ poles = torch.view_as_real(poles)
847
+ residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
848
+ residues = torch.view_as_real(residues)
849
+ poles = poles.permute(1, 0, 2).unsqueeze(-2)
850
+ residues = residues.permute(1, 0, 2).unsqueeze(-2)
851
+
852
+ block.filter.poles = nn.Parameter(poles)
853
+ block.filter.residues = nn.Parameter(residues)
854
+
855
+ def custom_load_state_dict(self, state_dict, strict=True):
856
+ """
857
+ Post-processes the state_dict to convert savanna checkpoints to vortex checkpoints.
858
+ """
859
+ self.logger.debug(f"Loading state dict: {state_dict}, (ignoring extra keys) with strict: {strict}")
860
+ model_dict = self.state_dict()
861
+
862
+ # Find keys that are in model_dict but not in state_dict
863
+ missing_in_state_dict = model_dict.keys() - state_dict.keys()
864
+ # Find keys that are in state_dict but not in model_dict
865
+ extra_in_state_dict = state_dict.keys() - model_dict.keys()
866
+
867
+ if missing_in_state_dict:
868
+ print(f"Keys missing in state_dict: {missing_in_state_dict}")
869
+ if extra_in_state_dict:
870
+ print(f"Extra keys in state_dict: {extra_in_state_dict}")
871
+
872
+ filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}
873
+
874
+ if all("._extra_state" in k for k in missing_in_state_dict):
875
+ self.logger.info("Checkpoint has no FP8 extra state, will be using initial state.")
876
+ for k in missing_in_state_dict:
877
+ filtered_dict[k] = None
878
+
879
+ self.load_state_dict(filtered_dict, strict=strict)
880
+ fixup_fp8_extra_states(self)
881
+
882
+ if self.config.get("column_split", True):
883
+ self.logger.info("Adjusting Wqkv for column split (permuting rows)")
884
+ for layer_idx, block in enumerate(self.blocks):
885
+ if type(block) == AttentionBlock:
886
+ target_device = block.inner_mha_cls.Wqkv.weight.device
887
+
888
+ Wqkv = state_dict[f"blocks.{layer_idx}.inner_mha_cls.Wqkv.weight"]
889
+ try:
890
+ bias = state_dict[f"blocks.{layer_idx}.inner_mha_cls.Wqkv.bias"]
891
+ except KeyError:
892
+ bias = None
893
+
894
+ size_att_head = block.hidden_size_per_attention_head
895
+
896
+ Wqkv = Wqkv.permute(1, 0)
897
+ Wqkv = Wqkv.reshape(block.hidden_size, block.num_attention_heads, 3, size_att_head)
898
+ Wq, Wk, Wv = Wqkv.unbind(dim=-2)
899
+ Wq = Wq.reshape(block.hidden_size, -1)
900
+ Wk = Wk.reshape(block.hidden_size, -1)
901
+ Wv = Wv.reshape(block.hidden_size, -1)
902
+ Wqkv = torch.cat([Wq, Wk, Wv], dim=-1)
903
+ Wqkv = Wqkv.permute(1, 0)
904
+
905
+ # Single device transfer at the end
906
+ block.inner_mha_cls.Wqkv.weight.data = Wqkv.to(target_device)
907
+
908
+ if bias is not None:
909
+ bias = bias.cpu() # Process on CPU
910
+ bias = bias.reshape(block.num_attention_heads, 3, size_att_head)
911
+ bias_q, bias_k, bias_v = bias.unbind(dim=-2)
912
+ bias_q = bias_q.reshape(block.hidden_size)
913
+ bias_k = bias_k.reshape(block.hidden_size)
914
+ bias_v = bias_v.reshape(block.hidden_size)
915
+ bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
916
+ try:
917
+ block.inner_mha_cls.Wqkv.bias.data = bias.to(target_device)
918
+ except AttributeError:
919
+ pass
920
+
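A toy-sized sketch (hypothetical shapes; the real blocks use the checkpoint's hidden size and head count) of the Wqkv regrouping performed above: rows stored per head as interleaved (q, k, v) slices are rearranged into contiguous [Wq; Wk; Wv] blocks.

import torch

hidden_size, num_heads = 4, 2
head_dim = hidden_size // num_heads
Wqkv = torch.randn(3 * hidden_size, hidden_size)   # (out_features, in_features)

W = Wqkv.permute(1, 0).reshape(hidden_size, num_heads, 3, head_dim)
Wq, Wk, Wv = W.unbind(dim=-2)
W_split = torch.cat(
    [Wq.reshape(hidden_size, -1), Wk.reshape(hidden_size, -1), Wv.reshape(hidden_size, -1)],
    dim=-1,
).permute(1, 0)
print(W_split.shape)  # still (3 * hidden_size, hidden_size), rows grouped as q | k | v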
921
+ def to_bfloat16_except_pr_lc(self, to_float32=False):
922
+ """Convert all parameters to bfloat16 except for the poles and residues.
923
+
924
+ Particularly important for longer prompts.
925
+ """
926
+ excluded_shapes = [(4096, 1, 128)]
927
+ for k, p in self.named_parameters():
928
+ if "projections" not in k: # avoid TE linears
929
+ if "log_poles" not in k and "residues" not in k and p.shape not in excluded_shapes:
930
+ p.data = p.data.to(torch.bfloat16)
931
+ else:
932
+ if to_float32:
933
+ p.data = p.data.to(torch.float32)
934
+ for k, b in self.named_buffers():
935
+ if "inv_freq" in k:
936
+ if to_float32:
937
+ b.data = b.data.to(torch.float32)
positional_embeddings.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copied verbatim from vortex
2
+ """
3
+ Armin Thomas, Jan 2023. Modified by Eric Nguyen.
4
+
5
+ Wrappers for linearly interpolated rope embeddings to use inside of MHA layers of Flash Attn.
6
+
7
+ """
8
+
9
+ import torch
10
+ from einops import rearrange
11
+ from .rotary import RotaryEmbedding
12
+
13
+
14
+ # simple wrapper for flash-attn RoPE with linear scaling:
15
+ class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ scaling_factor: float = 1.0,
20
+ base=10000.0,
21
+ interleaved=False,
22
+ scale_base=None,
23
+ pos_idx_in_fp32=True,
24
+ device=None,
25
+ ):
26
+ super().__init__(
27
+ dim=dim,
28
+ base=base,
29
+ interleaved=interleaved,
30
+ scale_base=scale_base,
31
+ pos_idx_in_fp32=pos_idx_in_fp32,
32
+ device=device,
33
+ )
34
+ self._linear_scaling_factor = scaling_factor
35
+
36
+ # adapted from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
37
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
38
+ # Reset the tables if the sequence length has changed,
39
+ # if we're on a new device (possibly due to tracing for instance),
40
+ # or if we're switching from inference mode to training
41
+ if (
42
+ seqlen > self._seq_len_cached
43
+ or self._cos_cached is None
44
+ or self._cos_cached.device != device
45
+ or self._cos_cached.dtype != dtype
46
+ or (self.training and self._cos_cached.is_inference())
47
+ ):
48
+ self._seq_len_cached = seqlen
49
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
50
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
51
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
52
+ if self.pos_idx_in_fp32:
53
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
54
+ # linear scaling:
55
+ t = t / self._linear_scaling_factor
56
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
57
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
58
+ # cos & sin output to change significantly.
59
+ # We want to recompute self.inv_freq if it was not loaded in fp32
60
+ if self.inv_freq.dtype != torch.float32:
61
+ inv_freq = self._compute_inv_freq(device=device)
62
+ else:
63
+ inv_freq = self.inv_freq
64
+ else:
65
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
66
+ # linear scaling:
67
+ t = t / self._linear_scaling_factor
68
+ inv_freq = self.inv_freq
69
+ # Don't do einsum, it converts fp32 to fp16 under AMP
70
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
71
+ freqs = torch.outer(t, inv_freq)
72
+ if self.scale is None:
73
+ self._cos_cached = torch.cos(freqs).to(dtype)
74
+ self._sin_cached = torch.sin(freqs).to(dtype)
75
+ else:
76
+ power = (
77
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
78
+ ) / self.scale_base
79
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
80
+ # We want the multiplication by scale to happen in fp32
81
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
82
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
83
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
84
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
85
+
86
+
87
+ # swap out RoPE of existing mha:
88
+ def swap_mha_rope(
89
+ mha,
90
+ new_rope: torch.nn.Module = LinearlyScaledRotaryEmbedding,
91
+ kwargs_new_rope: dict = None,
92
+ ):
93
+ # determine mha dtype and device:
94
+ dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
95
+ device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
96
+ # determine RoPE settings:
97
+ kwargs_old_rope = dict(
98
+ dim=mha.rotary_emb.dim,
99
+ base=mha.rotary_emb.base,
100
+ interleaved=mha.rotary_emb.interleaved,
101
+ scale_base=mha.rotary_emb.scale_base,
102
+ pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
103
+ device=mha.rotary_emb.inv_freq.device,
104
+ )
105
+ # delete old RoPE:
106
+ del mha.rotary_emb
107
+ # create new RoPE:
108
+ kwargs_new_rope = kwargs_new_rope or {"scaling_factor": 1.0}
109
+ scaled_rope = new_rope(**kwargs_new_rope, **kwargs_old_rope).to(dtype)
110
+ # attach new RoPE to mha:
111
+ mha.rotary_emb = scaled_rope
112
+ # make sure the new RoPE is correctly registered:
113
+ assert isinstance(mha.rotary_emb, new_rope)
114
+ return mha
sample.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copied verbatim from vortex
2
+ import torch
3
+
4
+
5
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
6
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
7
+ def modify_logits_for_top_k_filtering(logits, top_k):
8
+ """Set the logits for none top-k values to -inf. Done in-place."""
9
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
10
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
11
+
12
+
13
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
14
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
15
+ def modify_logits_for_top_p_filtering(logits, top_p):
16
+ """Set the logits for none top-p values to -inf. Done in-place."""
17
+ if top_p <= 0.0 or top_p >= 1.0:
18
+ return
19
+
20
+ # First sort and calculate cumulative sum of probabilities.
21
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
22
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
23
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
24
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
25
+ # scatter sorted tensors to original indexing
26
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
27
+ logits.masked_fill_(indices_to_remove, float("-inf"))
28
+
29
+
30
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
31
+ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
32
+ """Sample from top-k logits.
33
+ Arguments:
34
+ logits: Tensor of shape (batch_size, vocab_size)
35
+ """
36
+ logits = torch.nan_to_num(logits)
37
+ logits = torch.where(logits == float("-inf"), 0, logits)
38
+ logits = torch.where(logits == float("inf"), 0, logits)
39
+
40
+ if top_k == 1: # Short-circuit for greedy decoding
41
+ return logits.argmax(dim=-1)
42
+ else:
43
+ if top_p > 0.0:
44
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
45
+ if top_k > 0:
46
+ top_k = min(top_k, logits.size(-1)) # Safety check
47
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
48
+ if temperature != 1.0:
49
+ logits_top /= temperature
50
+ modify_logits_for_top_p_filtering(logits_top, top_p)
51
+
52
+ return indices[
53
+ torch.arange(indices.shape[0], device=indices.device),
54
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
55
+ ]
56
+ else:
57
+ # Clone so that when we modify for top_p we don't change the original logits
58
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
59
+ modify_logits_for_top_p_filtering(logits_top, top_p)
60
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
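A minimal usage sketch for sample() with made-up logits (top_k=1 short-circuits to argmax; otherwise top-k selection, temperature, then top-p filtering are applied before the multinomial draw):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])            # (batch_size=1, vocab_size=4)
greedy = sample(logits, top_k=1)                           # -> tensor([0])
stochastic = sample(logits, top_k=3, top_p=0.9, temperature=0.8)
print(greedy, stochastic)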
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {}
utils.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copied verbatim from vortex
2
+ import torch
3
+ import logging
4
+
5
+ log = logging.getLogger(__name__)
6
+
7
+
8
+ def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
9
+ """Get the dim for the local rank derived from splitting dim on world_size processes.
10
+
11
+ The split may not be even across the world_size processes.
12
+ """
13
+ multiple = dim // multiple_of
14
+ div = multiple // world_size
15
+ mod = multiple % world_size
16
+ local_multiple = div + int(local_rank < mod)
17
+ return local_multiple * multiple_of
18
+
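Worked example with made-up sizes of the uneven split: the first `mod` ranks absorb the remainder, and `multiple_of` scales the result back up.

# 10 units over 3 ranks -> [4, 3, 3]; with multiple_of=64, dim=640 -> [256, 192, 192]
print([get_dim_for_local_rank(10, 3, rank) for rank in range(3)])
print([get_dim_for_local_rank(640, 3, rank, multiple_of=64) for rank in range(3)])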
19
+
20
+ def grab_first_if_tuple(x):
21
+ if x.__class__.__name__ == "tuple":
22
+ return x[0]
23
+ else:
24
+ return x
25
+
26
+
27
+ def interleave(z_pre):
28
+ if len(z_pre.shape) == 3: # non-cached
29
+ x1 = z_pre[:, 0::3, :]
30
+ x2 = z_pre[:, 1::3, :]
31
+ v = z_pre[:, 2::3, :]
32
+ z_pre = torch.cat([x1, x2, v], dim=1)
33
+ return z_pre
34
+ else:
35
+ x1 = z_pre[..., 0::3]
36
+ x2 = z_pre[..., 1::3]
37
+ v = z_pre[..., 2::3]
38
+ z_pre = torch.concat([x1, x2, v], dim=-1)
39
+ return z_pre
40
+
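Toy example of interleave(): channels stored as repeating (x1, x2, v) triples are regrouped into three contiguous blocks along the channel dimension.

import torch

z = torch.arange(6).reshape(1, 6, 1)      # channel order: x1a, x2a, va, x1b, x2b, vb
print(interleave(z).squeeze())            # -> tensor([0, 3, 1, 4, 2, 5])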
41
+
42
+ def column_split(x, num_heads, head_size):
43
+ """Split a tensor with `num_heads` alongside the head dimension, instead of
44
+ across heads. Fixed to three (x2, x1, v) projections.
45
+ """
46
+ # FIXME: merge cases
47
+ if len(x.shape) == 2:
48
+ x_reshaped = x.reshape(
49
+ x.shape[0],
50
+ num_heads,
51
+ 3 * head_size,
52
+ )
53
+
54
+ x2, x1, v = (
55
+ x_reshaped[..., :head_size],
56
+ x_reshaped[..., head_size : 2 * head_size],
57
+ x_reshaped[..., 2 * head_size :],
58
+ )
59
+ x2, x1, v = (
60
+ x2.reshape(x2.shape[0], -1),
61
+ x1.reshape(x1.shape[0], -1),
62
+ v.reshape(v.shape[0], -1),
63
+ )
64
+ return x2, x1, v
65
+ else:
66
+ x = x.reshape(
67
+ x.shape[0],
68
+ num_heads,
69
+ 3 * head_size,
70
+ x.shape[2],
71
+ )
72
+ x2, x1, v = (
73
+ x[:, :, :head_size],
74
+ x[
75
+ :,
76
+ :,
77
+ head_size : 2 * head_size,
78
+ ],
79
+ x[:, :, 2 * head_size :],
80
+ )
81
+ x2, x1, v = (
82
+ x2.reshape(x2.shape[0], -1, x2.shape[-1]),
83
+ x1.reshape(x1.shape[0], -1, x1.shape[-1]),
84
+ v.reshape(v.shape[0], -1, v.shape[-1]),
85
+ )
86
+ return x2, x1, v
87
+
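Toy example of column_split() on a 2D (single-token, cached) input with 2 heads of size 2: each head's 3*head_size slab is carved into (x2, x1, v), then heads are flattened back together.

import torch

x = torch.arange(12.0).reshape(1, 12)
x2, x1, v = column_split(x, num_heads=2, head_size=2)
print(x2, x1, v)  # [[0, 1, 6, 7]], [[2, 3, 8, 9]], [[4, 5, 10, 11]]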
88
+
89
+ def load_checkpoint(model, checkpoint_path):
90
+ if checkpoint_path is None:
91
+ log.warning("Using random weights (dry-run)")
92
+ return
93
+ log.info(f"Loading {checkpoint_path}")
94
+
95
+ # We must allowlist BytesIO, as fp8-enabled checkpoints store this type
96
+ # in Transformer Engine layers' _extra_state keys. Otherwise, loading with
97
+ # weights_only=True will reject the checkpoint.
98
+ import io
99
+
100
+ torch.serialization.add_safe_globals([io.BytesIO])
101
+
102
+ with torch.inference_mode():
103
+ state = torch.load(
104
+ checkpoint_path,
105
+ # Make sure we override device location that is specified in the
106
+ # checkpoint dictionary (e.g. checkpoints may have "cuda:0"
107
+ # as a location for all layers, which then wouldn't work for
108
+ # multi-GPU case.)
109
+ map_location="cpu",
110
+ # This is an optimization: with that, we don't actually read
111
+ # whole checkpoints dictionary from disk to CPU memory in one
112
+ # go; instead, pytorch would only load relevant layers to CPU
113
+ # memory when we are about to copy them to GPU.
114
+ mmap=True,
115
+ # Make sure PyTorch is not issuing a warning regarding potential
116
+ # security issues.
117
+ weights_only=True,
118
+ )
119
+ model.to_bfloat16_except_pr_lc(to_float32=True)
120
+
121
+ model.custom_load_state_dict(state)
122
+
123
+ model.to_bfloat16_except_pr_lc()
124
+
125
+
126
+ def move_to_device(module, device):
127
+ """Recursively moves all parameters and buffers to the specified device."""
128
+ for child in module.children():
129
+ move_to_device(child, device)
130
+
131
+ for param in module.parameters(recurse=False):
132
+ if param.device != device:
133
+ param.data = param.data.to(device)
134
+
135
+ for buf in module.buffers(recurse=False):
136
+ if buf.device != device:
137
+ buf.data = buf.data.to(device)
138
+
139
+ module.to(device)
140
+
141
+
142
+ def fixup_fp8_extra_states(module):
143
+ """Recursively fixes device location of TE's Linear fp8 extra states."""
144
+ for child in module.children():
145
+ fixup_fp8_extra_states(child)
146
+
147
+ # TE Linear uses default "cuda" device to load extra state, which causes
148
+ # trouble when the layer is moved to another GPU. Instead, this is how
149
+ # TE Linear should load extra_state: using parameters' device.
150
+ torch_load = torch.load
151
+
152
+ def overridden_load(state, map_location):
153
+ device = next(module.parameters()).device
154
+ return torch_load(state, map_location=device)
155
+
156
+ if hasattr(module, "fp8_meta"):
157
+ log.debug(f"Reloading fp8 extra state to a proper device for {module}")
158
+ from unittest.mock import patch
159
+
160
+ with patch("torch.load", new=overriden_load):
161
+ module.set_extra_state(module.get_extra_state())
162
+
163
+
164
+ def fixup_te_workspace():
165
+ """TE uses single workspace tensor for all calls, disregarding that inputs
166
+ may be on separate GPUs. This patches TE's Linear module to use per-device
167
+ workspaces."""
168
+ from functools import lru_cache
169
+
170
+ @lru_cache
171
+ def te_cublas_get_workspace_per_device(device):
172
+ log.info(f"Fixup applied: Allocating cublas workspace for {device=}")
173
+ import transformer_engine.pytorch.module.base as tebase
174
+
175
+ with torch.cuda.device(device):
176
+ tebase._cublas_workspace = None # Force get_workspace() to reallocate tensor
177
+ return tebase.get_workspace()
178
+
179
+ def get_workspace():
180
+ return te_cublas_get_workspace_per_device(torch.cuda.current_device())
181
+
182
+ import transformer_engine.pytorch.module.linear as telinear
183
+
184
+ telinear.get_workspace = get_workspace
185
+
186
+
187
+ def get_init_from_string(init_str):
188
+ if type(init_str) == str:
189
+ if init_str == "torch.nn.init.zeros_":
190
+ return torch.nn.init.zeros_
191
+ elif init_str == "torch.nn.init.xavier_uniform_":
192
+ return torch.nn.init.xavier_uniform_
193
+ elif init_str == "torch.nn.init.xavier_normal_":
194
+ return torch.nn.init.xavier_normal_
195
+ else:
196
+ raise ValueError(f"Unrecognized init {init_str}")
197
+
198
+
199
+ def print_rank_0(message, debug=False, end="\n"):
200
+ """Print from rank 0 only."""
201
+ if torch.distributed.is_initialized():
202
+ if torch.distributed.get_rank() == 0:
203
+ print(message, flush=True, end=end)
204
+ else:
205
+ print(message, flush=True, end=end)
206
+
207
+
208
+ class dotdict(dict):
209
+ """dot.notation access to dictionary attributes"""
210
+
211
+ __getattr__ = dict.get
212
+ __setattr__ = dict.__setitem__
213
+ __delattr__ = dict.__delitem__
214
+
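Usage note for dotdict: attribute access falls through to the dict, and because __getattr__ is dict.get, missing keys return None instead of raising, which is what the config.get-style defaults above rely on.

cfg = dotdict({"num_layers": 2, "use_flashfft": False})
print(cfg.num_layers, cfg.get("max_seqlen", 8192), cfg.not_set)  # 2 8192 None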
215
+
216
+ def ensure_divisibility(numerator, denominator):
217
+ """Ensure that numerator is divisible by the denominator."""
218
+ assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
219
+
220
+
221
+ def divide(numerator, denominator):
222
+ """Ensure that numerator is divisible by the denominator and return
223
+ the division value."""
224
+ ensure_divisibility(numerator, denominator)
225
+ return numerator // denominator
226
+
227
+
228
+ class Lambda(torch.nn.Module):
229
+ def __init__(self, func):
230
+ super().__init__()
231
+ self.func = func
232
+
233
+ def forward(self, x):
234
+ return self.func(x)
235
+
236
+
237
+ class VocabUtility:
238
+ """Split the vocabulary into `world_size` chunks amd return the
239
+ first and last index of the vocabulary belonging to the `rank`
240
+ partition. Note that indices are in [first, last)."""
241
+
242
+ @staticmethod
243
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
244
+ index_f = rank * per_partition_vocab_size
245
+ index_l = index_f + per_partition_vocab_size
246
+ return index_f, index_l
247
+
248
+ @staticmethod
249
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
250
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
251
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
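Worked example with a hypothetical vocabulary size: a 512-token vocabulary over 4 ranks yields contiguous [first, last) slices of 128 ids each.

print([VocabUtility.vocab_range_from_global_vocab_size(512, rank, 4) for rank in range(4)])
# -> [(0, 128), (128, 256), (256, 384), (384, 512)]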