windsornguyen committed on
Commit 41c5b75 · verified · 1 Parent(s): c2e30d2

add: eval script

Files changed (1)
  1. evaluate.py +979 -0
evaluate.py ADDED
@@ -0,0 +1,979 @@
import math
import os
from dataclasses import dataclass
from typing import Optional

from huggingface_hub import hf_hub_download
import lm_eval as evaluator
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from torchtune.modules import RotaryPositionalEmbeddings
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import CausalLMOutput

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    flash_fft_available = False

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    print(f"Unable to import Triton-based flash attention: {e}. No alternative currently available.")

os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

loss_fn = nn.CrossEntropyLoss()


def nearest_power_of_two(n: int, round_up: bool = False) -> int:
    if n <= 1:
        return 1
    return 1 << ((n - 1).bit_length() if round_up else n.bit_length() - 1)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
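
# Illustrative values (not part of the original script):
#   nearest_power_of_two(12)                 -> 8    (rounds down)
#   nearest_power_of_two(12, round_up=True)  -> 16
#   find_multiple(1000, 256)                 -> 1024 (next multiple of 256)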


def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries.reshape(-1, 1) + entries.reshape(1, -1)

    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    else:
        Z = 2.0 / (i_plus_j**3 - i_plus_j)

    return Z
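
# For reference (added note): with 1-based indices, the entries above are
# Z[i, j] = 2 / ((i + j)^3 - (i + j)); the Hankel-L variant instead uses
# 8 / ((i+j+3)(i+j-1)(i+j+1)) and zeroes entries where i + j is odd, since
# the sgn term is 2 for even i + j and 0 for odd.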


def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    Z = get_hankel(seq_len, use_hankel_L).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]
    phi_k *= sigma_k**0.25
    return phi_k.to(device=device, dtype=dtype)
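

# Minimal shape-check sketch (illustrative; the sizes are assumptions, not the
# model's). The filters are the top-K eigenvectors of the Hankel matrix, each
# scaled by sigma**0.25, so the result has shape (seq_len, K).
def _demo_spectral_filters():
    phi = get_spectral_filters(seq_len=64, K=8, dtype=torch.float32)
    assert phi.shape == (64, 8)
    return phi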


class BaseConfigForCausalLM(PretrainedConfig):
    """Base PretrainedConfig class to be decorated with dataclass."""

    model_type = "base_model"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


@dataclass
class FlashSTUConfig(BaseConfigForCausalLM):
    model_type = "FlashSTU"

    # Field defaults
    bsz: int = 1
    dim: int = 1024
    r: int = 1024
    num_heads: int = 12
    num_local_heads: Optional[int] = -1
    num_layers: int = 12
    seq_len: int = 4096
    n: int = 8191
    window_size: int = 2048
    vocab_size: int = 200064
    inter_dim: Optional[int] = 3072
    mlp_scale: Optional[float] = 12.0
    weight_tying: Optional[bool] = True
    bias: Optional[bool] = False
    rope_theta: Optional[float] = 10000.0
    softcap: Optional[float] = 50.0
    num_eigh: Optional[int] = 24
    use_hankel_L: Optional[bool] = False
    use_flash_fft: Optional[bool] = True
    use_tensordot: Optional[bool] = True
    use_attn: Optional[bool] = True
    use_alibi: Optional[bool] = False
    torch_dtype: torch.dtype = torch.bfloat16
    device: torch.device = None

    # Explicit __init__ to handle **kwargs for PretrainedConfig compatibility
    def __init__(
        self,
        bsz: int = 1,
        dim: int = 1024,
        r: int = 1024,
        num_heads: int = 12,
        num_local_heads: Optional[int] = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        n: int = 8191,
        window_size: int = 2048,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = 3072,
        mlp_scale: Optional[float] = 12.0,
        weight_tying: Optional[bool] = True,
        bias: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        softcap: Optional[float] = 50.0,
        num_eigh: Optional[int] = 24,
        use_hankel_L: Optional[bool] = False,
        use_flash_fft: Optional[bool] = True,
        use_tensordot: Optional[bool] = True,
        use_attn: Optional[bool] = True,
        use_alibi: Optional[bool] = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,  # Catch extra arguments like model_type
    ):
        super().__init__(**kwargs)  # Pass kwargs to parent __init__

        # Assign fields from arguments
        self.bsz = bsz
        self.dim = dim
        self.r = r
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.n = n
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.softcap = softcap
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_tensordot = use_tensordot
        self.use_attn = use_attn
        self.use_alibi = use_alibi
        self.torch_dtype = torch_dtype
        self.device = device

        # Explicitly call __post_init__ (dataclass will not, since __init__ is user-defined)
        self.__post_init__()

    def __post_init__(self):
        # Ensure torch_dtype is a torch.dtype object, not a string
        if isinstance(self.torch_dtype, str):
            try:
                self.torch_dtype = getattr(torch, self.torch_dtype)
            except AttributeError:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}")

        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            self.inter_dim = find_multiple(num_hidden, 256)
        self.head_dim = self.dim // self.num_heads

    @classmethod
    def from_name(cls, name: str):
        # presets = {
        #     "tiny": dict(dim=128, num_heads=4, num_layers=2, vocab_size=10000),
        #     "small": dict(dim=256, num_heads=8, num_layers=4, vocab_size=20000),
        #     "gpt2-small": dict(dim=768, num_heads=12, num_layers=12, vocab_size=50257),
        #     # add more as needed
        # }
        # if name not in presets:
        #     raise ValueError(f"Unknown model config name: {name}")
        # return cls(**presets[name])
        raise NotImplementedError("from_name is not implemented yet")
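

# Quick sketch (illustrative; toy sizes, not the released checkpoint's
# hyperparameters) of how __post_init__ resolves derived fields:
def _demo_config():
    cfg = FlashSTUConfig(dim=64, num_heads=4, num_layers=2, seq_len=32, n=63, vocab_size=100)
    assert cfg.head_dim == 16  # dim // num_heads, set in __post_init__
    assert cfg.num_local_heads == cfg.num_heads  # the -1 sentinel is resolved
    return cfg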


class MLP(nn.Module):
    def __init__(self, config: FlashSTUConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x), approximate="tanh"))


class SlidingWindowAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        self.wo.SCALE_INIT = 1

        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        bsz, seq_len, dim = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)

        # Use RoPE when ALiBi is disabled; the two position encodings are exclusive here
        if self.alibi_slopes is None:
            q, k = self.rotary_emb(q), self.rotary_emb(k)

        y = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
            window_size=(self.window_size, 0),
            # softcap=self.softcap,
            alibi_slopes=self.alibi_slopes,
        )

        out = y.reshape(bsz, seq_len, -1)
        out = self.wo(out)

        return out

    def _generate_slopes(self, n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        # If num_heads is a power of 2, generate slopes directly
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=torch.device("cuda"))  # FA ALiBi must be on CUDA
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes
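
# Worked example (illustrative): with num_heads = 12, log2(12) is not an
# integer, so _get_alibi_slopes takes the 8 slopes for the nearest power of
# two and appends 4 interleaved slopes from the n = 16 sequence, giving 12
# slopes in total before the 0.25 interpolation scaling.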


class STU(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Assigned at the top level after model init
        self.stu_filters = None
        self.stu_filters_fft = None  # TODO: Optimization: Precompute FFT of filters

        self.n = config.n
        self.num_eigh = config.num_eigh
        self.d_in = config.dim
        self.d_out = config.dim
        self.r = config.r
        self.use_hankel_L = config.use_hankel_L
        self.use_tensordot = config.use_tensordot
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16) if config.use_flash_fft and flash_fft_available else None
        )

        # TODO: Add dimensionality reduction `r` here.
        if self.use_tensordot:
            self.M_inputs = nn.Parameter(torch.zeros(self.d_in, self.d_out))
            self.M_filters = nn.Parameter(torch.zeros(self.num_eigh, self.d_in))
        else:
            self.M_phi_plus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape

        if self.use_tensordot:
            # Contract inputs and filters over the (K, D) dims first, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft:
                spectral_plus, spectral_minus = self.flash_conv(x_proj, phi_proj, self.flash_fft, self.use_tensordot)
            else:
                spectral_plus, spectral_minus = self.conv(x_proj, phi_proj, self.n, self.use_tensordot)
        else:
            # Convolve inputs and filters first, then contract over the (K, D) dims
            if self.flash_fft:
                U_plus, U_minus = self.flash_conv(x, self.stu_filters, self.flash_fft, self.use_tensordot)
            else:
                U_plus, U_minus = self.conv(x, self.stu_filters, self.n, self.use_tensordot)

            B, L, K, D = U_plus.shape
            spectral_plus = U_plus.reshape(B, L, K * D) @ self.M_phi_plus.reshape(K * D, self.d_out)
            if not self.use_hankel_L:
                spectral_minus = U_minus.reshape(B, L, K * D) @ self.M_phi_minus.reshape(K * D, self.d_out)

        out = spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
        return out

    def conv(
        self, u: torch.Tensor, v: torch.Tensor, n: int, use_tensordot: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Performs convolution via FFT with causal alignment using a negative featurization.

        The input tensor u is modulated by an alternating sign tensor (sgn) that multiplies every other
        time step by -1. This "negative featurization" modulates the phase so that in this implementation
        the correct causal output is obtained by simply slicing the first L elements (i.e. [:seq_len]).
        Note: Using a conventional slice [seq_len-1 : 2*seq_len-1] would yield a flipped alignment,
        resulting in leakage.

        Args:
            u: Input tensor of shape (bsz, seq_len, d_in).
            v: Kernel tensor; expected shape is (seq_len, d_out) if use_tensordot is True.
            n: FFT length (typically 2*seq_len - 1 for linear convolution with implicit right zero-padding).
            use_tensordot: Boolean flag to control kernel reshaping.

        Returns:
            A tuple (U_plus, U_minus) where:
                - U_plus is the primary convolution output.
                - U_minus is the secondary output, corrected by the sign tensor.
        """
        bsz, seq_len, d_in = u.shape

        sgn = torch.full((1, seq_len, 1), 1, device=u.device)
        sgn[:, 1::2] *= -1  # Negative featurization: multiply every other time step by -1

        if use_tensordot:
            _, d_out = v.shape
            v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
        else:
            _, K = v.shape
            sgn = sgn.unsqueeze(-1)
            v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()  # (bsz, seq_len, K, d_in, stack)
            u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

        # Kernel is already float32 for the FFT
        v_fft = torch.fft.rfft(v, n=n, dim=1)

        # Cast the input stack to float32 for the FFT
        U = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()
        U_fft = torch.fft.rfft(U, n=n, dim=1)

        # Slicing the first seq_len outputs yields the proper causal convolution given the negative modulation.
        # Perform the convolution in float32 and cast back.
        U_conv = torch.fft.irfft(v_fft * U_fft, n=n, dim=1)[:, :seq_len].to(u.dtype)
        U_plus, U_minus = torch.unbind(U_conv, dim=-1)
        U_minus = U_minus * sgn

        return U_plus.type_as(u), U_minus.type_as(u)
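
    # What conv() computes, in summary (added note, not in the original script):
    #   U_plus[t]  = sum_{s <= t} v[s] * u[t - s]           (plain causal convolution)
    #   U_minus[t] = sum_{s <= t} (-1)^s * v[s] * u[t - s]  (alternating-sign kernel)
    # The second identity follows from convolving u * sgn and re-multiplying by
    # sgn, since (-1)^t * (-1)^(t - s) = (-1)^s.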

    def flash_conv(
        self,
        u: torch.Tensor,
        v: torch.Tensor,
        flash_fft: FlashFFTConv,
        use_tensordot: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Flash FFT convolution.

        Args:
            u (torch.Tensor): Input tensor of shape `(B, L, d_in)`, where:
                - `B` is the batch size,
                - `L` is the sequence length,
                - `d_in` is the input dimension.
            v (torch.Tensor): Filter tensor of shape `(L, d_out)` if `use_tensordot=True`,
                otherwise `(L, K)`, where `K` is the number of filters.
            flash_fft (FlashFFTConv): An instance of the FlashFFTConv module, used to perform the convolution.
            use_tensordot (bool, optional): If `True`, performs the tensordot approximation (default is `True`).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple `(U_plus, U_minus)`:
                - `U_plus`: Convolved output tensor with positive eigenvalues.
                - `U_minus`: Convolved output tensor with negative eigenvalues.
                - Shapes: `(B, L, d_in)` if `use_tensordot=True`, else `(B, L, K, d_in)`.

        Example:
            >>> u = torch.randn(4, 16, 32)  # (B, L, d_in)
            >>> v = torch.randn(16, 32)     # (L, d_out)
            >>> flash_fft = FlashFFTConv(16, dtype=torch.float32)
            >>> U_plus, U_minus = stu.flash_conv(u, v, flash_fft, use_tensordot=True)  # stu: an STU instance
            >>> print(U_plus.shape, U_minus.shape)
            torch.Size([4, 16, 32]) torch.Size([4, 16, 32])
        """
        bsz, seq_len, d_in = u.shape
        _, K = v.shape

        padded_len = nearest_power_of_two(seq_len, round_up=True)
        pad_len = padded_len - seq_len

        sgn = torch.full((1, 1, padded_len), 1, device=u.device)
        sgn[:, :, 1::2] = -1

        if use_tensordot:
            u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32)
            u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
        else:
            u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1)
            u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

        # FlashFFTConv expects a bfloat16 input and a float32 kernel
        U_conv = flash_fft(u_conv.to(torch.bfloat16), v_padded)

        # Trim the output back to the original sequence length
        U_conv = U_conv[..., :seq_len]
        u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

        if use_tensordot:
            u_minus = u_minus * sgn[:, :, :seq_len]
            U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
        else:
            sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
            U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
            U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

        return U_plus, U_minus


class SlidingWindowAttentionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.swa_norm = nn.LayerNorm(config.dim)
        self.swa = SlidingWindowAttention(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.swa(self.swa_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x


class STULayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.stu_norm = nn.LayerNorm(config.dim)
        self.stu = STU(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.stu(self.stu_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x


class FlashSTU(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()

        for layer_idx in range(config.num_layers):
            # Alternate STU and attention layers; for more complex %-split
            # arrangements, see https://arxiv.org/pdf/2406.07887
            if layer_idx % 2 == 0 or not config.use_attn:
                self.layers.append(STULayer(config))
            else:
                self.layers.append(SlidingWindowAttentionLayer(config))

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if self.config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        x = self.tok_emb(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def setup_filters(
        self,
        spectral_filters: torch.Tensor,
        spectral_filters_fft: torch.Tensor,
    ):
        for layer in self.layers:
            if isinstance(layer, STULayer):
                layer.stu.stu_filters = spectral_filters
                layer.stu.stu_filters_fft = spectral_filters_fft

    def get_num_params(self):
        """Return the total number of parameters in the model."""
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
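

# Minimal end-to-end sketch (illustrative; every size below is an assumption
# for a toy model, not the released checkpoint, and the pure-PyTorch conv path
# is chosen so that no CUDA extensions are required):
def _demo_tiny_flashstu():
    cfg = FlashSTUConfig(
        dim=64,
        num_heads=4,
        num_layers=2,
        seq_len=32,
        n=63,  # 2 * seq_len - 1 for linear convolution
        vocab_size=100,
        inter_dim=128,
        use_attn=False,  # STU layers only, so flash-attn is not needed
        use_flash_fft=False,  # force the PyTorch FFT conv path
        torch_dtype=torch.float32,
    )
    model = FlashSTU(cfg)
    filters = get_spectral_filters(cfg.seq_len, cfg.num_eigh, dtype=torch.float32)
    # Mirror the FFT precomputation done in FlashSTUForCausalLM.__init__ (the
    # FFT tensor is currently an unused TODO field in STU).
    model.setup_filters(filters, torch.fft.rfft(filters, n=cfg.n, dim=1))
    out = model(torch.randint(0, cfg.vocab_size, (1, cfg.seq_len)))
    assert out.logits.shape == (1, cfg.seq_len, cfg.vocab_size)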


def create_base_model_components(model_name_or_path=None, **kwargs):
    """Create the config and spectral filters needed for model initialization."""
    if model_name_or_path is not None:
        config = FlashSTUConfig.from_pretrained(model_name_or_path, **kwargs)
    else:
        config = FlashSTUConfig(**kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    filters = get_spectral_filters(
        seq_len=config.seq_len,
        K=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=device,
        dtype=config.torch_dtype,
    )
    assert filters.dtype == config.torch_dtype, f"filters dtype is {filters.dtype}, expected {config.torch_dtype}"
    return config, filters


class FlashSTUForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with Hugging Face's expected interface."""

    config_class = FlashSTUConfig
    base_model_prefix = "FlashSTU"

    def __init__(self, config):
        super().__init__(config)

        self.flash_stu = FlashSTU(config)
        self.flash_stu.apply(self.flash_stu.init_weights)

        device = (
            config.device
            if config.device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        torch_dtype = config.torch_dtype  # Assumes __post_init__ already converted it to torch.dtype

        spectral_filters = get_spectral_filters(
            seq_len=config.seq_len,
            K=config.num_eigh,
            use_hankel_L=config.use_hankel_L,
            device=device,
            # Note: get_spectral_filters returns float64 by default; cast below
        )
        spectral_filters_fft = torch.fft.rfft(spectral_filters, n=config.n, dim=1)

        # Set up the filters in the model, casting to the target dtype
        self.flash_stu.setup_filters(
            spectral_filters.to(dtype=torch_dtype), spectral_filters_fft.to(dtype=torch_dtype)
        )
        # Note: Moving the entire model to the device happens later, after loading weights.

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        outputs = self.flash_stu(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and a repetition penalty.

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Maximum length of the generated sequence
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest-probability tokens to keep for top-k sampling
            top_p: Cumulative probability cutoff for nucleus sampling
            repetition_penalty: Penalty factor for repeated tokens. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Generated token ids of shape (num_return_sequences, max_length)
        """
        self.eval()  # Set to eval mode
        device = input_ids.device

        # Expand input for multiple sequences
        input_ids = input_ids.repeat(num_return_sequences, 1)
        generated = input_ids

        # Set up a generator for reproducible sampling
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        # Generate tokens until we reach max_length
        with torch.no_grad():
            while generated.size(1) < max_length:
                # Get logits for the next token
                outputs = self.flash_stu(generated)
                next_token_logits = outputs.logits[:, -1, :]

                # Apply the repetition penalty to every token already generated
                if repetition_penalty != 1.0:
                    for i in range(generated.shape[0]):
                        for token in set(generated[i].tolist()):
                            next_token_logits[i, token] /= repetition_penalty

                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Get probabilities
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

                # Top-k sampling
                if top_k > 0:
                    indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
                    probs[indices_to_remove] = 0

                # Nucleus (top-p) sampling
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to also keep the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0

                # Renormalize probabilities
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                # Sample the next token
                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)

                # Append to the generated sequence
                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        return self.flash_stu.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Get the config and create the model
        config, _ = create_base_model_components(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # Download the safetensors file from the Hub
        weights_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="model.safetensors",
            cache_dir=kwargs.get("cache_dir"),
            force_download=kwargs.get("force_download", False),
            proxies=kwargs.get("proxies", None),
            local_files_only=kwargs.get("local_files_only", False),
            use_auth_token=kwargs.get("use_auth_token", None),
            revision=kwargs.get("revision", None),
            subfolder=kwargs.get("subfolder", ""),
        )

        state_dict = load_file(weights_path)

        # Reconstruct weight tying for tok_emb and lm_head
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"

        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict

        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            # This case should not happen if the file is valid
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. "
                "Weight tying cannot be reconstructed."
            )
        # If both are present, assume they were loaded correctly (or were never tied)

        # Prepend 'flash_stu.' to all keys to match the wrapper's state dict
        final_state_dict = {f"flash_stu.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        # Move to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        # Print the parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        return model
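
# Typical usage sketch (the repo id is the one referenced elsewhere in this
# script; `tokenizer` is assumed to come from AutoTokenizer.from_pretrained):
#   model = FlashSTUForCausalLM.from_pretrained("Hazan-Lab/FlashSTU-340M-0428")
#   ids = tokenizer("Hello", return_tensors="pt").input_ids.to(model.device)
#   out = model.generate(ids, max_length=32, num_return_sequences=1)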


# Create the initial config and filters for registration
config, filters = create_base_model_components()

# Register the model with the Auto* classes
AutoConfig.register("FlashSTU", FlashSTUConfig)
AutoModel.register(FlashSTUConfig, FlashSTU)
AutoModelForCausalLM.register(FlashSTUConfig, FlashSTUForCausalLM)

print("Registered FlashSTU model and configuration.")


def run_model_diagnostics(model, tokenizer, device):
    """Run detailed diagnostics to analyze model behavior."""
    print("\nRunning model diagnostics...")

    # Test cases of varying difficulty and length
    test_cases = [
        # Simple completion
        "2 + 2 =",
        # Medium difficulty
        "The capital of France is Paris. The capital of Germany is",
        # Complex reasoning
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        # Pattern completion
        "1, 2, 3, 4,",
        # Long context
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            # Tokenize
            tokens = tokenizer(prompt, return_tensors="pt")
            input_ids = tokens["input_ids"].to(device)

            outputs = model.flash_stu(input_ids, labels=input_ids)

            labels = input_ids.clone()
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(reduction="none")
            token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
                shift_labels.size()
            )

            # Print a token-by-token analysis
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")

            print(f"Average loss: {token_losses.mean().item():.3f}")

            # Generate with different temperatures
            temps = [0.5, 0.7, 1.0]
            print("\nGeneration temperature comparison:")
            for temp in temps:
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")


def validate_model_generation():
    print("\nRunning generation validation test...")

    try:
        from transformers import AutoTokenizer

        # Load the model and tokenizer
        # model_id = "Hazan-Lab/Flash_STU_550M"
        model_id = "Hazan-Lab/FlashSTU-340M-0428"
        model = FlashSTUForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Move to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        # Print the parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {model_id}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        # Run additional diagnostics
        run_model_diagnostics(model, tokenizer, device)

    except Exception as e:
        print(f"\nError during validation: {str(e)}")
        raise


# Evaluation tasks to run (commented entries are available but disabled)
tasks = [
    # "mmlu",
    "hellaswag",
    # "piqa",
    # "siqa",
    # "boolq",
    # "winogrande",
    # "commonsense_qa",
    # "openbookqa",
    # "arc",
    # "arc_easy",
    # "arc_challenge",
    # "triviaqa",
    # "nq_open",
    # "humaneval",
    # "mbpp",
    # "gsm8k",
    # "hendrycks_math",
    # "mathqa",
    # "minerva_math",
    # "score",
    # "asdiv",
    # "agieval",
    # "bigbench",
]

# Few-shot settings per task; -1 means "use the harness default"
tasks_fewshot = {
    "hellaswag": 0,
    # "mmlu": 5,
    # "piqa": 0,
    # "siqa": 0,
    # "boolq": 0,
    # "winogrande": -1,
    # "commonsense_qa": 7,
    # "openbookqa": -1,
    # "arc": -1,
    # "arc_easy": -1,
    # "arc_challenge": -1,
    # "triviaqa": 5,
    # "nq_open": 5,
    # "humaneval": -1,
    # "mbpp": 3,
    # "gsm8k": -1,
    # "hendrycks_math": 4,
    # "mathqa": -1,
    # "minerva_math": -1,
    # "score": -1,
    # "asdiv": -1,
    # "agieval": -1,
    # "bigbench": -1,
}

all_results = {}

# First, validate that generation works
validate_model_generation()

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")
    eval_kwargs = dict(
        model="hf",
        model_args=(
            # "pretrained=Hazan-Lab/Flash_STU_550M,"
            "pretrained=Hazan-Lab/FlashSTU-340M-0428,"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        tasks=[task],
        batch_size="auto",
        device="cuda:0",
    )
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:  # only pass num_fewshot when explicitly set
        eval_kwargs["num_fewshot"] = few_shot_value
    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")