windsornguyen commited on
Commit
dc91c82
·
verified ·
1 Parent(s): 9beefc3

add: eval script

Browse files
Files changed (1) hide show
  1. evaluate.py +620 -0
evaluate.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ from huggingface_hub import hf_hub_download
6
+ import lm_eval as evaluator
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from safetensors.torch import load_file
11
+ from torchtune.modules import RotaryPositionalEmbeddings
12
+ from transformers import (
13
+ AutoConfig,
14
+ AutoModel,
15
+ AutoModelForCausalLM,
16
+ PreTrainedModel,
17
+ PretrainedConfig,
18
+ )
19
+ from transformers.modeling_outputs import CausalLMOutput
20
+ from flash_attn import flash_attn_func
21
+
22
# Allow code-eval tasks (e.g. humaneval/mbpp) to execute model-generated code,
# and silence tokenizer fork warnings during batched evaluation.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Shared cross-entropy criterion used by Transformer.forward below.
loss_fn = nn.CrossEntropyLoss()
26
+
27
+
28
class Attention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings and FlashAttention.

    Supports grouped-query attention (GQA): queries use ``num_heads`` heads while
    keys/values use ``num_local_heads`` heads (equal by default, in which case
    this is standard multi-head attention).
    """

    def __init__(self, config):
        super().__init__()
        # Fix: project k/v to num_local_heads * head_dim rather than the full
        # model dim, so GQA configs (num_local_heads < num_heads) produce
        # correctly shaped key/value tensors. Identical to the previous
        # behavior when num_local_heads == num_heads (the default).
        kv_dim = config.num_local_heads * config.head_dim
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, kv_dim)
        self.wv = nn.Linear(config.dim, kv_dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Marks the residual-output projection for depth-scaled initialization
        # (see Transformer.init_weights).
        self.wo.SCALE_INIT = 1

        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads

        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Input of shape (bsz, seq_len, dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        bsz, seq_len, dim = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # flash_attn_func expects (bsz, seq_len, num_heads, head_dim) layout.
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        q, k = self.rotary_emb(q), self.rotary_emb(k)

        y = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
        )

        out = y.reshape(bsz, seq_len, -1)
        out = self.wo(out)

        return out
68
+
69
+
70
def find_multiple(n: int, k: int) -> int:
    """Round *n* up to the nearest multiple of *k* (returns *n* if it already is one)."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
74
+
75
+
76
class BaseConfigForCausalLM(PretrainedConfig):
    """Base PretrainedConfig class to be decorated with dataclass"""

    # Generic fallback identifier; concrete subclasses override this
    # (e.g. TransformerConfig sets model_type = "Transformer").
    model_type = "base_model"
80
+
81
+
82
@dataclass
class TransformerConfig(BaseConfigForCausalLM):
    """Hyperparameters for the decoder-only Transformer.

    Note: the explicit ``__init__`` below takes precedence over the
    dataclass-generated one; the class-level fields serve as a typed schema
    of the available options and their defaults.
    """

    model_type = "Transformer"

    # Define fields with defaults (as before)
    bsz: int = 1
    dim: int = 768
    num_heads: int = 12
    num_local_heads: int = -1  # -1 => resolved to num_heads (no GQA)
    num_layers: int = 12
    seq_len: int = 4096
    vocab_size: int = 200064
    inter_dim: Optional[int] = None  # None => derived in _post_init_logic
    mlp_scale: float = 12.0
    weight_tying: bool = True
    bias: bool = False
    rope_theta: float = 10000.0
    torch_dtype: str = "torch.bfloat16"
    device: Optional[str] = None
    head_dim: Optional[int] = None  # always recomputed as dim // num_heads

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 768,
        num_heads: int = 12,
        num_local_heads: int = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = None,
        mlp_scale: float = 12.0,
        weight_tying: bool = True,
        bias: bool = False,
        rope_theta: float = 10000.0,
        torch_dtype: str = "torch.bfloat16",
        device: Optional[str] = None,
        head_dim: Optional[int] = None,
        **kwargs,
    ):
        """Store all options, forward extras to PretrainedConfig, then resolve
        derived fields via _post_init_logic."""
        super().__init__(**kwargs)

        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.torch_dtype = torch_dtype
        self.device = device
        self.head_dim = head_dim

        self._post_init_logic()

    def _post_init_logic(self):
        """Resolve defaulted/derived fields and normalize dtype/device.

        Raises:
            ValueError: if num_heads is not positive or torch_dtype is invalid.
        """
        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            # SwiGLU-style sizing: 2/3 of (mlp_scale * dim), rounded up to a
            # multiple of 256 for hardware-friendly matmul shapes.
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            multiple = 256
            self.inter_dim = find_multiple(num_hidden, multiple) if num_hidden > 0 else multiple

        if self.num_heads > 0:
            self.head_dim = self.dim // self.num_heads
        else:
            raise ValueError("num_heads must be positive")

        # Accept "torch.bfloat16", "bfloat16", or an actual torch.dtype.
        if isinstance(self.torch_dtype, str):
            dtype_str = self.torch_dtype.replace("torch.", "")
            try:
                self.torch_dtype = getattr(torch, dtype_str)
            except AttributeError as err:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}") from err
        elif not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError(f"torch_dtype must be a string or torch.dtype, got {type(self.torch_dtype)}")

        if isinstance(self.device, str):
            self.device = torch.device(self.device)

    @classmethod
    def from_name(cls, name: str):
        """Build a config from a preset name.

        Fix: raise instead of printing "Not yet implemented" and returning
        None, so callers fail loudly rather than receiving a silent None.
        """
        raise NotImplementedError("TransformerConfig.from_name is not implemented yet")
172
+
173
+
174
class MLP(nn.Module):
    """Two-layer feed-forward block with a tanh-approximated GELU activation."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        # Marks the residual-output projection for depth-scaled initialization.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up to inter_dim, apply GELU (tanh approx), project back down."""
        hidden = F.gelu(self.w1(x), approximate="tanh")
        return self.w2(hidden)
183
+
184
+
185
class TransformerLayer(nn.Module):
    """Pre-norm decoder layer: attention then MLP, each on a residual path."""

    def __init__(self, config):
        super().__init__()
        self.attn_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.attn = Attention(config)
        self.mlp_norm = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run one pre-norm attention + feed-forward step with residual adds."""
        x = self.attn(self.attn_norm(x)) + x
        return self.mlp(self.mlp_norm(x)) + x
197
+
198
+
199
class Transformer(nn.Module):
    """Decoder-only Transformer LM: token embedding -> N layers -> final norm -> LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()
        for _ in range(config.num_layers):
            self.layers.append(TransformerLayer(config))
        self.norm_f = nn.LayerNorm(config.dim, dtype=config.torch_dtype)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if self.config.weight_tying:
            # Tie the input embedding to the output projection (shared Parameter).
            self.tok_emb.weight = self.lm_head.weight

        # Base init std of 1/sqrt(dim), used by init_weights below.
        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Initialize Linear/Embedding weights with N(0, std); modules flagged
        with SCALE_INIT (residual-output projections) get std scaled by
        1/sqrt(2 * num_layers). Intended for use via ``model.apply``."""
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        """Compute logits over the vocabulary (and a loss when labels are given).

        NOTE(review): the loss compares logits[t] against labels[t] with no
        one-token shift — this assumes the caller supplies pre-shifted labels.
        run_model_diagnostics below shifts manually; confirm the training
        pipeline did the same before trusting this loss.
        """
        x = self.tok_emb(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def get_num_params(self):
        """
        Return the total number of parameters in the model (embeddings
        included; tied tok_emb/lm_head weights are counted once since they
        share a single Parameter).
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
251
+
252
+
253
def create_base_model_components(model_name_or_path=None, **kwargs):
    """Return a TransformerConfig, loaded from a checkpoint when a path is given,
    otherwise built directly from keyword arguments."""
    if model_name_or_path is None:
        return TransformerConfig(**kwargs)
    return TransformerConfig.from_pretrained(model_name_or_path, **kwargs)
260
+
261
+
262
class TransformerForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface"""

    config_class = TransformerConfig
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)

        self.transformer = Transformer(config)
        self.transformer.apply(self.transformer.init_weights)

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        """Delegate to the inner Transformer.

        ``attention_mask`` is accepted for HF interface compatibility but is
        not used: the inner model is always causal and ignores padding masks.
        """
        outputs = self.transformer(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and repetition penalty.

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Maximum length of generated sequence
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest probability tokens to keep for top-k sampling
            top_p: Cumulative probability cutoff for nucleus sampling
            repetition_penalty: Penalty factor for repeating tokens. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Generated token ids of shape (num_return_sequences, max_length)
        """
        self.eval()  # Set to eval mode
        device = input_ids.device

        # Expand input for multiple sequences
        input_ids = input_ids.repeat(num_return_sequences, 1)
        generated = input_ids

        # Set up generator for reproducible sampling
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        # Generate tokens until we reach max_length
        with torch.no_grad():
            while generated.size(1) < max_length:
                # Get logits for next token
                outputs = self.transformer(generated)
                next_token_logits = outputs.logits[:, -1, :]

                # Apply repetition penalty.
                # Fix: the previous code tested `token in next_token_logits[i]`,
                # which checks membership among logit *values* (tensor `in`)
                # rather than indexing by token id, so the penalty was applied
                # essentially at random; it also divided negative logits, which
                # makes repeated tokens MORE likely. Use the standard
                # CTRL-style rule: multiply negative scores, divide positive.
                if repetition_penalty != 1.0:
                    for i in range(generated.shape[0]):
                        for token_id in set(generated[i].tolist()):
                            score = next_token_logits[i, token_id]
                            if score < 0:
                                next_token_logits[i, token_id] = score * repetition_penalty
                            else:
                                next_token_logits[i, token_id] = score / repetition_penalty

                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Get probabilities
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

                # Top-k sampling: zero out everything below the k-th largest prob
                if top_k > 0:
                    indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
                    probs[indices_to_remove] = 0

                # Nucleus (top-p) sampling
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to keep also the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0

                # Renormalize probabilities (clamp guards against an all-zero row)
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                # Sample next token
                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)

                # Append to generated sequence
                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        """Total parameter count of the wrapped Transformer."""
        return self.transformer.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config + safetensors weights from the hub, reconstructing the
        tok_emb/lm_head weight tie if only one of the two tensors was saved.

        Returns the model on CUDA (if available) in bfloat16, in eval mode.
        """
        # Get config and create model
        config = create_base_model_components(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # Download safetensors file from hub
        # NOTE(review): `use_auth_token` is deprecated in newer huggingface_hub
        # releases in favor of `token` — update when bumping the dependency.
        weights_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="model.safetensors",
            cache_dir=kwargs.get("cache_dir"),
            force_download=kwargs.get("force_download", False),
            proxies=kwargs.get("proxies", None),
            local_files_only=kwargs.get("local_files_only", False),
            use_auth_token=kwargs.get("use_auth_token", None),
            revision=kwargs.get("revision", None),
            subfolder=kwargs.get("subfolder", ""),
        )

        # Load the state dict and metadata from safetensors
        state_dict = load_file(weights_path)

        # Reconstruct weight tying for tok_emb and lm_head specifically
        # (safetensors stores shared tensors only once, so one key may be missing)
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"

        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict

        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            # This case should ideally not happen if the file is valid
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )
        # If both are present, assume they are loaded correctly (or were never tied)

        # Prepend prefix to all keys to match wrapper's state dict
        final_state_dict = {f"{cls.base_model_prefix}.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        # Move to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        # Print parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        return model
428
+
429
+
430
# Create initial config using the correct class
config = TransformerConfig()  # instantiated as a sanity check; evaluation reloads its own config

# Register models with correct names so AutoConfig/AutoModel/AutoModelForCausalLM
# can resolve model_type="Transformer" to these locally defined classes
# (required for lm-eval's "hf" model wrapper with trust_remote_code).
AutoConfig.register("Transformer", TransformerConfig)
AutoModel.register(TransformerConfig, Transformer)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)

print("Registered Transformer model and configuration.")
439
+
440
+
441
def run_model_diagnostics(model, tokenizer, device):
    """Run detailed diagnostics to analyze model behavior.

    For each probe prompt: print per-token loss over the prompt, then compare
    greedy-ish generations at several sampling temperatures.
    """
    print("\nRunning model diagnostics...")

    # Probe prompts spanning arithmetic, factual recall, reasoning,
    # pattern continuation, and a longer context.
    test_cases = [
        "2 + 2 =",
        "The capital of France is Paris. The capital of Germany is",
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        "1, 2, 3, 4,",
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            # Tokenize the prompt and move ids onto the model's device
            encoded = tokenizer(prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(device)

            outputs = model.transformer(input_ids, labels=input_ids)

            # Per-position loss: align logits[t] with labels[t+1]
            labels = input_ids.clone()
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            per_token_ce = nn.CrossEntropyLoss(reduction="none")
            token_losses = per_token_ce(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            ).view(shift_labels.size())

            # Token-by-token breakdown (skip the first token — it has no target)
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")

            print(f"Average loss: {token_losses.mean().item():.3f}")

            # Compare generations at a few sampling temperatures
            print("\nGeneration temperature comparison:")
            for temp in (0.5, 0.7, 1.0):
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")
503
+
504
+
505
def validate_model_generation():
    """Smoke-test checkpoint loading and generation, then run full diagnostics."""
    print("\nRunning generation validation test...")

    try:
        from transformers import AutoTokenizer

        # Load the checkpoint and its tokenizer
        checkpoint = "Hazan-Lab/Transformer-340M-0428"
        lm = TransformerForCausalLM.from_pretrained(checkpoint)
        tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

        # Place the model on GPU when one is available, in bfloat16
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        lm = lm.to(device=target_device, dtype=torch.bfloat16)
        lm.eval()

        # Sanity-check the parameter count
        param_count = lm.get_num_params()
        print(f"\nModel loaded: {checkpoint}")
        print(f"Parameter count: {param_count / 1e6:.2f}M")

        # Hand off to the diagnostics suite
        run_model_diagnostics(lm, tok, target_device)

    except Exception as err:
        print(f"\nError during validation: {str(err)}")
        raise
532
+
533
+
534
# Run evaluation tasks
# Only uncommented tasks are evaluated; the rest are kept for easy re-enabling.
# NOTE(review): "gms8k" below looks like a typo for "gsm8k" — fix before enabling.
tasks = [
    "hellaswag",
    # "mmlu",
    # "piqa",
    # "siqa",
    # "boolq",
    # "winogrande",
    # "commonsense_qa",
    # "openbookqa",
    # "arc",
    # "arc_easy",
    # "arc_challenge",
    # "triviaqa",
    # "nq_open",
    # "humaneval",
    # "mbpp",
    # "gms8k",
    # "hendrycks_math",
    # "mathqa",
    # "minerva_math",
    # "score",
    # "asdiv",
    # "agieval",
    # "bigbench",
]

# Few-shot counts per task; -1 means "use the task's default" (see loop below).
tasks_fewshot = {
    "hellaswag": 0,
    # "mmlu": 5,
    # "piqa": 0,
    # "siqa": 0,
    # "boolq": 0,
    # "winogrande": -1,
    # "commonsense_qa": 7,
    # "openbookqa": -1,
    # "arc": -1,
    # "arc_easy": -1,
    # "arc_challenge": -1,
    # "triviaqa": 5,
    # "nq_open": 5,
    # "humaneval": -1,
    # "mbpp": 3,
    # "gms8k": -1,
    # "hendrycks_math": 4,
    # "mathqa": -1,
    # "minerva_math": -1,
    # "score": -1,
    # "asdiv": -1,
    # "agieval": -1,
    # "bigbench": -1,
}

# Accumulates {task_name: metrics dict} across the loop below.
all_results = {}

# First validate generation works
validate_model_generation()
model_id = "Hazan-Lab/Transformer-340M-0428"

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")
    # lm-eval's "hf" wrapper reloads the model from the hub itself; the
    # AutoModelForCausalLM registration above makes trust_remote_code resolve
    # to the classes defined in this file.
    eval_kwargs = dict(
        model="hf",
        model_args=(
            f"pretrained={model_id},"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        tasks=[task],
        batch_size="auto",
        device="cuda:0",
    )
    # Only pass num_fewshot when explicitly configured (-1 keeps the default).
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value
    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")