Charlie81 committed on
Commit
b87e5c8
·
1 Parent(s): 112fb19
Files changed (1) hide show
  1. scripts/eval.py +389 -162
scripts/eval.py CHANGED
@@ -1,271 +1,489 @@
1
  #!/usr/bin/env python3
2
  """
3
- eval.py Evaluation script for modified OLMoE model using lm-evaluation-harness
 
 
 
 
 
 
 
 
 
 
 
4
  """
 
5
  import argparse
6
  import json
7
  import os
8
- from typing import Dict, List, Optional
 
 
 
9
  import torch
10
- from transformers import AutoConfig, AutoTokenizer
 
 
11
  from lm_eval import evaluator
12
- # Remove the problematic import - we don't need get_model
13
- import logging
14
 
15
  # Set up logging
16
- logging.basicConfig(level=logging.INFO)
 
 
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
20
  def parse_args():
21
  """Parse command line arguments."""
22
- parser = argparse.ArgumentParser(description="Evaluate myolmoe model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Model arguments
25
- parser.add_argument("--model_path", type=str, default="/root/.cache/huggingface/hub/models--allenai--OLMoE-7B/snapshots/6d84c48581ece794365f2b8e9cfb043c68ade9c5",
26
- help="Path to the pretrained model")
27
- parser.add_argument("--model_type", type=str, default="hf-auto",
28
- help="Model type for lm-eval")
29
-
30
- # Routing configuration
31
- parser.add_argument("--routing_type", type=str, default="non_deterministic",
32
- choices=["dense", "sparse", "non_deterministic"],
33
- help="Type of routing to use")
34
- parser.add_argument("--router_temperature", type=float, default=1.0,
35
- help="Temperature for non-deterministic routing")
36
- parser.add_argument("--num_experts_per_tok", type=int, default=8,
37
- help="Number of experts per token")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  # Evaluation arguments
40
- parser.add_argument("--tasks", type=str, nargs="+",
41
- default=['mmlu'],
42
- # , 'gsm8k'
43
- # default=["hellaswag", "arc_easy", "arc_challenge", "winogrande"],
44
- help="Tasks to evaluate on")
45
- parser.add_argument("--num_fewshot", type=int, default=0,
46
- help="Number of few-shot examples")
47
- parser.add_argument("--batch_size", type=int, default=64,
48
- help="Batch size for evaluation")
49
- parser.add_argument("--max_batch_size", type=int, default=None,
50
- help="Maximum batch size")
51
- parser.add_argument("--device", type=str, default="cuda",
52
- help="Device to use for evaluation")
53
- parser.add_argument("--dtype", type=str, default="float16",
54
- choices=["float16", "bfloat16", "float32"],
55
- help="Data type for model weights")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Output arguments
58
- parser.add_argument("--output_dir", type=str, default="./eval_results",
59
- help="Directory to save evaluation results")
60
- parser.add_argument("--output_filename", type=str, default=None,
61
- help="Filename for results (auto-generated if not provided)")
 
 
 
 
 
 
 
 
62
 
63
  # Additional arguments
64
- parser.add_argument("--limit", type=int, default=None,
65
- help="Limit number of examples per task")
66
- parser.add_argument("--write_out", action="store_true",
67
- help="Write out individual predictions")
68
- parser.add_argument("--trust_remote_code", action="store_true",
69
- help="Trust remote code when loading model")
70
- parser.add_argument("--verbosity", type=str, default="INFO",
71
- choices=["DEBUG", "INFO", "WARNING", "ERROR"],
72
- help="Logging verbosity level")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  return parser.parse_args()
75
 
76
 
77
- def setup_model_config(model_path: str, routing_config: Dict) -> None:
78
  """
79
- Update model configuration with routing settings.
80
- """
81
- config_path = os.path.join(model_path, "config.json")
82
 
83
- if os.path.exists(config_path):
84
- with open(config_path, 'r') as f:
85
- config = json.load(f)
86
-
87
- # Update routing configuration
88
- config.update(routing_config)
89
 
90
- # Save updated config
91
- with open(config_path, 'w') as f:
92
- json.dump(config, f, indent=2)
93
-
94
- logger.info(f"Updated model config with routing settings: {routing_config}")
95
- else:
96
- logger.warning(f"Config file not found at {config_path}")
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
- def validate_model_setup(model_path: str) -> bool:
100
  """
101
- Validate that the model can be loaded with the current configuration.
 
 
 
 
 
 
102
  """
 
 
 
 
 
 
 
 
 
 
103
  try:
104
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
105
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
106
-
107
- logger.info(f"Model validation successful:")
108
- logger.info(f" - Model type: {config.model_type}")
109
- logger.info(f" - Routing type: {getattr(config, 'routing_type', 'not specified')}")
110
- logger.info(f" - Vocab size: {config.vocab_size}")
111
- logger.info(f" - Hidden size: {config.hidden_size}")
112
- logger.info(f" - Num layers: {config.num_hidden_layers}")
113
- logger.info(f" - Num experts: {getattr(config, 'num_experts', 'not specified')}")
114
-
115
- return True
116
- except Exception as e:
117
- logger.error(f"Model validation failed: {e}")
118
- return False
119
-
120
- def run_evaluation(args) -> Dict:
121
- """Run evaluation with properly wrapped model."""
122
- from transformers import AutoModelForCausalLM
123
- import sys, os
124
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), "myolmoe_model"))
125
-
126
- # 1. Load config and override routing parameters
127
  config = AutoConfig.from_pretrained(
128
  args.model_path,
129
- trust_remote_code=True
130
  )
 
 
131
  config.routing_type = args.routing_type
132
  config.router_temperature = args.router_temperature
133
  config.num_experts_per_tok = args.num_experts_per_tok
134
-
135
-
136
- # 2. Load model with updated config
137
- torch_dtype = {
138
- "float16": torch.float16,
139
- "bfloat16": torch.bfloat16,
140
- "float32": torch.float32
141
- }[args.dtype]
142
-
143
- from modeling_myolmoe import MyOLMoEForCausalLM
144
-
 
 
 
 
 
 
145
  hf_model = MyOLMoEForCausalLM.from_pretrained(
146
  args.model_path,
147
  config=config,
148
  torch_dtype=torch_dtype,
149
- device_map="auto"
 
150
  ).eval()
151
-
152
-
153
- # 3. Wrap the Hugging Face model in HFLM
154
- eval_model = HFLM(
155
- pretrained=hf_model, # Pass the initialized model
156
  device=args.device,
157
  batch_size=args.batch_size,
158
  max_batch_size=args.max_batch_size,
159
  dtype=args.dtype
160
  )
 
 
 
161
 
162
- # 4. Run evaluation with the wrapped model
163
- results = evaluator.simple_evaluate(
164
- model=eval_model, # Pass the wrapped model
165
- tasks=args.tasks,
166
- num_fewshot=args.num_fewshot,
167
- limit=args.limit,
168
- write_out=args.write_out,
169
- verbosity=args.verbosity,
170
- )
171
 
172
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
- import numpy as np
175
- import torch
176
 
177
- def make_serializable(obj):
 
 
 
 
 
 
 
 
 
178
  if isinstance(obj, dict):
179
  return {k: make_serializable(v) for k, v in obj.items()}
180
  elif isinstance(obj, list):
181
  return [make_serializable(v) for v in obj]
182
  elif isinstance(obj, tuple):
183
  return tuple(make_serializable(v) for v in obj)
184
- # NumPy scalars
185
  elif isinstance(obj, (np.integer, np.floating)):
186
  return obj.item()
187
- # NumPy dtypes
188
  elif isinstance(obj, np.dtype):
189
  return str(obj)
190
- # PyTorch tensor → list
191
  elif isinstance(obj, torch.Tensor):
192
  return obj.tolist()
193
- # PyTorch dtype (e.g. torch.float16)
194
  elif isinstance(obj, torch.dtype):
195
  return str(obj)
196
- # Anything else leave alone
197
  else:
198
  return obj
199
 
200
- def save_results(results: Dict, args) -> str:
201
- """Save evaluation results to file, after converting to JSON-safe types, and print them."""
202
- os.makedirs(args.output_dir, exist_ok=True)
203
 
204
- # build filename exactly as before…
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  if args.output_filename is None:
206
  model_name = os.path.basename(args.model_path.rstrip('/'))
207
  tasks_str = "_".join(args.tasks[:3])
208
  if len(args.tasks) > 3:
209
  tasks_str += f"_and_{len(args.tasks)-3}_more"
210
- filename = f"{model_name}_{args.routing_type}_{tasks_str}_results.json"
 
 
 
 
211
  else:
212
  filename = args.output_filename
 
213
  if not filename.endswith('.json'):
214
  filename += '.json'
 
215
  output_path = os.path.join(args.output_dir, filename)
216
-
 
217
  metadata = {
218
  "model_path": args.model_path,
219
- "routing_type": args.routing_type,
220
- "router_temperature": args.router_temperature,
221
- "num_experts_per_tok": args.num_experts_per_tok,
222
  "tasks": args.tasks,
223
  "num_fewshot": args.num_fewshot,
224
  "batch_size": args.batch_size,
225
  "device": args.device,
226
  "dtype": args.dtype,
 
227
  }
 
 
 
 
 
 
 
 
 
228
  results_with_metadata = {
229
  "metadata": metadata,
230
  "results": results
231
  }
232
-
233
- # convert everything
234
- serializable = make_serializable(results_with_metadata)
235
-
236
- # write to disk
237
  with open(output_path, 'w') as f:
238
- json.dump(serializable, f, indent=2)
239
-
240
  logger.info(f"Results saved to {output_path}")
241
  return output_path
242
 
243
 
244
-
245
-
246
- def print_summary(results: Dict, routing_type: str) -> None:
247
  """
248
- Print a summary of evaluation results.
 
 
 
 
249
  """
250
- print(f"\n{'='*60}")
251
- print(f"EVALUATION SUMMARY - Routing: {routing_type.upper()}")
252
- print(f"{'='*60}")
 
 
 
 
 
253
 
254
  if "results" in results:
255
  for task, metrics in results["results"].items():
256
  if isinstance(metrics, dict):
257
- print(f"\n{task.upper()}:")
258
  for metric, value in metrics.items():
259
- if isinstance(value, (int, float)):
260
- if metric.endswith('_stderr'):
261
- continue # Skip stderr for summary
262
  stderr_key = f"{metric}_stderr"
263
  stderr = metrics.get(stderr_key, 0)
264
- print(f" {metric}: {value:.4f} (±{stderr:.4f})")
265
- else:
266
- print(f" {metric}: {value}")
267
 
268
- print(f"\n{'='*60}")
269
 
270
 
271
  def main():
@@ -279,6 +497,10 @@ def main():
279
  logger.setLevel(numeric_level)
280
 
281
  try:
 
 
 
 
282
  # Run evaluation
283
  results = run_evaluation(args)
284
 
@@ -286,13 +508,18 @@ def main():
286
  output_path = save_results(results, args)
287
 
288
  # Print summary
289
- print_summary(results, args.routing_type)
290
 
291
- logger.info("Evaluation completed successfully!")
 
292
 
 
 
 
293
  except Exception as e:
294
- logger.error(f"Evaluation failed: {e}")
295
- raise
 
296
 
297
 
298
  if __name__ == "__main__":
 
1
  #!/usr/bin/env python3
2
  """
3
+ eval.py - Evaluation script for OLMoE models using lm-evaluation-harness
4
+
5
+ This script supports evaluation of both:
6
+ 1. Standard Transformers OLMoE models
7
+ 2. Custom MyOLMoE models with modified routing
8
+
9
+ Usage Examples:
10
+ # Evaluate standard OLMoE model
11
+ python eval.py --model_type transformers --tasks mmlu hellaswag
12
+
13
+ # Evaluate custom MyOLMoE model with non-deterministic routing
14
+ python eval.py --model_type custom --routing_type non_deterministic --tasks mmlu
15
  """
16
+
17
  import argparse
18
  import json
19
  import os
20
+ import sys
21
+ import logging
22
+ from typing import Dict, List, Optional, Any
23
+ import numpy as np
24
  import torch
25
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
26
+
27
+ # lm-eval imports
28
  from lm_eval import evaluator
29
+ from lm_eval.models.huggingface import HFLM
 
30
 
31
  # Set up logging
32
+ logging.basicConfig(
33
+ level=logging.INFO,
34
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
35
+ )
36
  logger = logging.getLogger(__name__)
37
 
38
 
39
def parse_args():
    """Build and parse the command-line arguments for the evaluation run."""
    usage_examples = """
Examples:
    # Standard OLMoE evaluation
    python eval.py --model_type transformers --tasks mmlu arc_easy

    # Custom MyOLMoE with non-deterministic routing
    python eval.py --model_type custom --routing_type non_deterministic \\
        --router_temperature 0.8 --tasks mmlu hellaswag

    # Dense routing evaluation
    python eval.py --model_type custom --routing_type dense --tasks gsm8k
"""
    parser = argparse.ArgumentParser(
        description="Evaluate OLMoE models using lm-evaluation-harness",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # --- Model selection ---
    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model_path", type=str,
                             default="allenai/OLMoE-1B-7B-0924",
                             help="Path or name of the pretrained model")
    model_group.add_argument("--model_type", type=str, default="transformers",
                             choices=["transformers", "custom"],
                             help="Model type: 'transformers' for standard OLMoE, 'custom' for MyOLMoE")
    model_group.add_argument("--custom_model_path", type=str,
                             default="./myolmoe_model",
                             help="Path to custom MyOLMoE model code (when using --model_type custom)")

    # --- Routing (only meaningful for custom models) ---
    routing_group = parser.add_argument_group("routing")
    routing_group.add_argument("--routing_type", type=str, default="sparse",
                               choices=["dense", "sparse", "non_deterministic"],
                               help="Routing type (only used with custom models)")
    routing_group.add_argument("--router_temperature", type=float, default=1.0,
                               help="Temperature for non-deterministic routing")
    routing_group.add_argument("--num_experts_per_tok", type=int, default=8,
                               help="Number of experts per token")

    # --- Evaluation controls ---
    eval_group = parser.add_argument_group("evaluation")
    eval_group.add_argument("--tasks", type=str, nargs="+", default=["mmlu"],
                            help="Tasks to evaluate on (e.g., mmlu, hellaswag, arc_easy, gsm8k)")
    eval_group.add_argument("--num_fewshot", type=int, default=0,
                            help="Number of few-shot examples")
    eval_group.add_argument("--batch_size", type=int, default=8,
                            help="Batch size for evaluation")
    eval_group.add_argument("--max_batch_size", type=int, default=None,
                            help="Maximum batch size (auto if None)")
    eval_group.add_argument("--device", type=str, default="auto",
                            help="Device to use ('auto', 'cuda', 'cpu')")
    eval_group.add_argument("--dtype", type=str, default="auto",
                            choices=["auto", "float16", "bfloat16", "float32"],
                            help="Data type for model weights")

    # --- Output ---
    out_group = parser.add_argument_group("output")
    out_group.add_argument("--output_dir", type=str, default="./eval_results",
                           help="Directory to save evaluation results")
    out_group.add_argument("--output_filename", type=str, default=None,
                           help="Custom filename for results (auto-generated if not provided)")

    # --- Misc ---
    misc_group = parser.add_argument_group("misc")
    misc_group.add_argument("--limit", type=int, default=None,
                            help="Limit number of examples per task (for testing)")
    misc_group.add_argument("--write_out", action="store_true",
                            help="Write out individual predictions to files")
    misc_group.add_argument("--trust_remote_code", action="store_true",
                            help="Trust remote code when loading model")
    misc_group.add_argument("--verbosity", type=str, default="INFO",
                            choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                            help="Logging verbosity level")

    return parser.parse_args()
180
 
181
 
182
def load_transformers_model(args) -> HFLM:
    """Load a standard Transformers OLMoE model wrapped for lm-eval.

    Args:
        args: Parsed command line arguments.

    Returns:
        HFLM: Wrapped model ready for evaluation.
    """
    logger.info(f"Loading Transformers OLMoE model: {args.model_path}")

    # HFLM handles model/tokenizer loading itself when given a path or name.
    wrapped = HFLM(
        pretrained=args.model_path,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
    )

    logger.info("Transformers model loaded successfully")
    return wrapped
206
 
207
 
208
def load_custom_model(args) -> HFLM:
    """
    Load custom MyOLMoE model with routing configuration.

    Args:
        args: Parsed command line arguments

    Returns:
        HFLM: Wrapped model ready for evaluation
    """
    logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
    logger.info(f"Routing configuration: {args.routing_type}")

    # Add custom model path to Python path so `modeling_myolmoe` can be imported.
    # NOTE(review): each call prepends another copy of the same entry to sys.path;
    # harmless for a one-shot script, but worth guarding if this becomes a library.
    if os.path.exists(args.custom_model_path):
        sys.path.insert(0, args.custom_model_path)
        logger.info(f"Added {args.custom_model_path} to Python path")
    else:
        logger.warning(f"Custom model path not found: {args.custom_model_path}")

    try:
        # Import custom model class (only resolvable after the sys.path insert above)
        from modeling_myolmoe import MyOLMoEForCausalLM
        logger.info("Successfully imported MyOLMoEForCausalLM")
    except ImportError as e:
        logger.error(f"Failed to import custom model: {e}")
        logger.error("Make sure the custom model code is available in the specified path")
        raise

    # Load and configure model: start from the checkpoint's config, then override
    # the routing-specific fields from the CLI before instantiating weights.
    config = AutoConfig.from_pretrained(
        args.model_path,
        trust_remote_code=args.trust_remote_code
    )

    # Override routing configuration
    config.routing_type = args.routing_type
    config.router_temperature = args.router_temperature
    config.num_experts_per_tok = args.num_experts_per_tok

    logger.info(f"Model config updated:")
    logger.info(f"  - routing_type: {config.routing_type}")
    logger.info(f"  - router_temperature: {config.router_temperature}")
    logger.info(f"  - num_experts_per_tok: {config.num_experts_per_tok}")

    # Determine torch dtype: "auto" is passed through so from_pretrained picks the
    # checkpoint's native dtype; otherwise map the CLI string to a torch dtype.
    if args.dtype == "auto":
        torch_dtype = "auto"
    else:
        torch_dtype = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32
        }[args.dtype]

    # Load the custom model.
    # NOTE(review): device_map="auto" only applies when --device auto; otherwise the
    # model loads without a device map and HFLM is expected to place it. Also confirm
    # HFLM accepts the literal string "auto" for its `device` argument below.
    hf_model = MyOLMoEForCausalLM.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch_dtype,
        device_map="auto" if args.device == "auto" else None,
        trust_remote_code=args.trust_remote_code
    ).eval()

    # Wrap the already-instantiated model in HFLM for lm-evaluation-harness.
    model = HFLM(
        pretrained=hf_model,
        device=args.device,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        dtype=args.dtype
    )

    logger.info("Custom model loaded successfully")
    return model
283
 
 
 
 
 
 
 
 
 
 
284
 
285
def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
    """
    Validate model configuration and return key information.

    Args:
        model_path: Path to the model
        trust_remote_code: Whether to trust remote code

    Returns:
        Dict containing model configuration information
    """
    try:
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        # Loading the tokenizer is itself the validation step; the object is not used.
        AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        # (report key, config attribute, fallback when the attribute is absent)
        probes = [
            ("model_type", "model_type", "unknown"),
            ("vocab_size", "vocab_size", "unknown"),
            ("hidden_size", "hidden_size", "unknown"),
            ("num_layers", "num_hidden_layers", "unknown"),
            ("num_experts", "num_experts", "not specified"),
            ("routing_type", "routing_type", "default"),
        ]
        model_info = {name: getattr(config, attr, fallback) for name, attr, fallback in probes}

        logger.info("Model validation successful:")
        for key, value in model_info.items():
            logger.info(f"  {key}: {value}")

        return model_info

    except Exception as e:
        logger.error(f"Model validation failed: {e}")
        raise
318
 
 
 
319
 
320
def make_serializable(obj: Any) -> Any:
    """
    Recursively convert *obj* into a JSON-serializable structure.

    Containers are converted element-wise; NumPy/PyTorch scalars, arrays,
    tensors, and dtypes are converted to native Python values or strings.
    Anything unrecognized is returned unchanged.

    Args:
        obj: Object to convert

    Returns:
        JSON-serializable version of the object
    """
    if isinstance(obj, dict):
        return {k: make_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [make_serializable(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(make_serializable(v) for v in obj)
    # FIX: sets previously fell through to `else` and broke json.dump.
    elif isinstance(obj, (set, frozenset)):
        return [make_serializable(v) for v in obj]
    # FIX: ndarrays previously fell through to `else` and broke json.dump.
    # tolist() already yields nested native Python scalars.
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    # FIX: np.bool_ is neither np.integer nor np.floating; handle it explicitly.
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    elif isinstance(obj, np.dtype):
        return str(obj)
    elif isinstance(obj, torch.Tensor):
        return obj.tolist()
    elif isinstance(obj, torch.dtype):
        return str(obj)
    else:
        return obj
346
 
 
 
 
347
 
348
def run_evaluation(args) -> Dict[str, Any]:
    """
    Run evaluation on the specified model.

    Args:
        args: Parsed command line arguments

    Returns:
        Dict containing evaluation results
    """
    logger.info("Starting evaluation...")

    # Sanity-check the model/tokenizer before committing to a full load.
    validate_model_config(args.model_path, args.trust_remote_code)

    # Dispatch on model type to pick the loader.
    loaders = {
        "transformers": load_transformers_model,
        "custom": load_custom_model,
    }
    if args.model_type not in loaders:
        raise ValueError(f"Unknown model type: {args.model_type}")
    model = loaders[args.model_type](args)

    logger.info(f"Running evaluation on tasks: {args.tasks}")
    logger.info(f"Few-shot examples: {args.num_fewshot}")
    logger.info(f"Batch size: {args.batch_size}")

    results = evaluator.simple_evaluate(
        model=model,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
        write_out=args.write_out,
        verbosity=args.verbosity,
    )

    logger.info("Evaluation completed successfully")
    return results
387
+
388
+
389
def save_results(results: Dict[str, Any], args) -> str:
    """
    Save evaluation results to file.

    Args:
        results: Evaluation results
        args: Parsed command line arguments

    Returns:
        str: Path to saved results file
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # Auto-generate a descriptive filename unless the caller provided one.
    filename = args.output_filename
    if filename is None:
        model_name = os.path.basename(args.model_path.rstrip('/'))
        tasks_str = "_".join(args.tasks[:3])
        if len(args.tasks) > 3:
            tasks_str += f"_and_{len(args.tasks)-3}_more"
        variant = args.routing_type if args.model_type == "custom" else "transformers"
        filename = f"{model_name}_{variant}_{tasks_str}_results.json"
    if not filename.endswith('.json'):
        filename += '.json'

    output_path = os.path.join(args.output_dir, filename)

    # Record the run configuration alongside the raw results.
    metadata = {
        "model_path": args.model_path,
        "model_type": args.model_type,
        "tasks": args.tasks,
        "num_fewshot": args.num_fewshot,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": args.dtype,
        "limit": args.limit,
    }
    # Routing knobs only matter for custom models.
    if args.model_type == "custom":
        metadata.update({
            "routing_type": args.routing_type,
            "router_temperature": args.router_temperature,
            "num_experts_per_tok": args.num_experts_per_tok,
        })

    payload = make_serializable({
        "metadata": metadata,
        "results": results,
    })

    with open(output_path, 'w') as f:
        json.dump(payload, f, indent=2)

    logger.info(f"Results saved to {output_path}")
    return output_path
455
 
456
 
457
def print_summary(results: Dict[str, Any], args) -> None:
    """
    Print a formatted summary of evaluation results.

    Args:
        results: Evaluation results
        args: Parsed command line arguments
    """
    rule = '=' * 80

    # Header block describing the run.
    header = [f"\n{rule}", "EVALUATION SUMMARY",
              f"Model: {args.model_path}", f"Type: {args.model_type.upper()}"]
    if args.model_type == "custom":
        header.append(f"Routing: {args.routing_type.upper()}")
    header.append(f"Tasks: {', '.join(args.tasks)}")
    header.append(rule)
    for line in header:
        print(line)

    # Per-task metric lines: numeric metrics only, stderr shown inline.
    if "results" in results:
        for task, metrics in results["results"].items():
            if not isinstance(metrics, dict):
                continue
            print(f"\n📊 {task.upper()}:")
            for name, value in metrics.items():
                if not isinstance(value, (int, float)) or name.endswith('_stderr'):
                    continue
                err = metrics.get(f"{name}_stderr", 0)
                print(f"  {name:.<20} {value:.4f} (±{err:.4f})")
    else:
        print("\n⚠️ No results found in evaluation output")

    print(f"\n{rule}")
487
 
488
 
489
  def main():
 
497
  logger.setLevel(numeric_level)
498
 
499
  try:
500
+ logger.info("="*80)
501
+ logger.info("Starting OLMoE Model Evaluation")
502
+ logger.info("="*80)
503
+
504
  # Run evaluation
505
  results = run_evaluation(args)
506
 
 
508
  output_path = save_results(results, args)
509
 
510
  # Print summary
511
+ print_summary(results, args)
512
 
513
+ logger.info(f"Evaluation completed successfully!")
514
+ logger.info(f"📁 Results saved to: {output_path}")
515
 
516
+ except KeyboardInterrupt:
517
+ logger.info("Evaluation interrupted by user")
518
+ sys.exit(1)
519
  except Exception as e:
520
+ logger.error(f"Evaluation failed: {e}")
521
+ logger.debug("Full traceback:", exc_info=True)
522
+ sys.exit(1)
523
 
524
 
525
  if __name__ == "__main__":