Charlie81 committed on
Commit
927465d
·
1 Parent(s): 077e7bc

Revert "18k checkpoint"

Browse files

This reverts commit 555d9b9825121563de4138cbe24ac43bd5bf5f89.

Files changed (2) hide show
  1. myolmoe/modeling_myolmoe.py +105 -155
  2. scripts/eval.py +50 -107
myolmoe/modeling_myolmoe.py CHANGED
@@ -14,124 +14,107 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
14
  from transformers.modeling_utils import PreTrainedModel
15
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
16
  from transformers.utils import logging
17
- # from transformers.models.olmoe.configuration_olmoe import MyOlmoeConfig
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.modeling_rope_utils import rope_config_validation
20
 
21
- from dataclasses import dataclass, field
22
- from typing import Optional, List, Any
23
- from transformers import PretrainedConfig
24
-
25
- from dataclasses import dataclass, field
26
- from typing import Optional, List, Dict, Any
27
- from transformers import PretrainedConfig
28
-
29
@dataclass
class MyOlmoeConfig(PretrainedConfig):
    """
    Configuration class for the MyOlmoe model.

    Declared as a dataclass so that model hyperparameters double as typed,
    defaulted fields; the explicit ``__init__`` below takes precedence over the
    dataclass-generated one and bridges the fields into the
    ``PretrainedConfig`` init protocol (which expects ``**kwargs``).
    """
    # Keep "olmoe" so checkpoints saved/trained under that model_type still load.
    model_type: str = "olmoe"

    # Core model parameters
    vocab_size: int = 50304
    hidden_size: int = 2048
    intermediate_size: int = 1024
    num_hidden_layers: int = 16
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    max_position_embeddings: int = 4096

    # Expert parameters
    num_experts: int = 64
    num_experts_per_tok: int = 2
    num_small_experts: int = 0
    small_expert_count: int = 64
    small_expert_intermediate_ratio: int = 16
    small_expert_intermediate_size: int = 0
    small_expert_sparsity_coef: float = 0.1
    small_expert_strategy: str = "constant"
    max_small_expert_count: int = 64

    # Attention parameters
    attention_bias: bool = False
    attention_dropout: float = 0.0
    clip_qkv: Optional[float] = None

    # Normalization and activation
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-05
    norm_topk_prob: bool = False

    # Router parameters
    router_aux_loss_coef: float = 0.01
    output_router_logits: bool = False

    # Training parameters
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False
    use_cache: bool = True

    # RoPE parameters
    rope_theta: float = 10000.0
    rope_scaling: Optional[dict] = None

    # Token IDs - set proper defaults
    pad_token_id: int = 1
    eos_token_id: int = 50279
    bos_token_id: int = 1

    # Model architecture
    architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])

    def __init__(self, **kwargs):
        """Initialize from arbitrary kwargs, filling gaps from dataclass defaults.

        Model-loading parameters (dtype, device placement, quantization) are
        stripped first: they belong to ``from_pretrained``, not to the config.
        """
        from dataclasses import MISSING

        model_loading_params = ('torch_dtype', 'device_map', 'low_cpu_mem_usage',
                                'load_in_8bit', 'load_in_4bit', 'quantization_config')
        for param in model_loading_params:
            kwargs.pop(param, None)

        # Set defaults for required fields; class-level dataclass defaults are
        # readable as class attributes here, before instance init.
        kwargs.setdefault('pad_token_id', self.pad_token_id)
        kwargs.setdefault('eos_token_id', self.eos_token_id)
        kwargs.setdefault('bos_token_id', self.bos_token_id)
        kwargs.setdefault('architectures', ["MyOlmoeForCausalLM"])

        # Initialize the parent class first.
        super().__init__(**kwargs)

        # Then set any dataclass field the parent did not already set, taking
        # the value from kwargs or from the field's declared default.
        # BUG FIX: the original compared `field_def.default != field_def.default_factory`
        # (a meaningless comparison) and `default_factory != default_factory`
        # (always False), so factory-defaulted fields were never initialized.
        # The correct sentinel for "no default" is dataclasses.MISSING.
        for field_name, field_def in self.__dataclass_fields__.items():
            if hasattr(self, field_name):
                continue  # Already set by parent
            if field_name in kwargs:
                setattr(self, field_name, kwargs[field_name])
            elif field_def.default is not MISSING:
                setattr(self, field_name, field_def.default)
            elif field_def.default_factory is not MISSING:
                setattr(self, field_name, field_def.default_factory())

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Override from_pretrained to handle the model type properly."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # Keep the original model_type from the saved config so that models
        # trained with the "olmoe" type remain loadable.
        if 'model_type' in config_dict:
            original_model_type = config_dict['model_type']
            if original_model_type == "olmoe":
                config_dict['model_type'] = "olmoe"  # Keep as olmoe

        return cls.from_dict(config_dict, **kwargs)
134
-
 
 
 
 
 
 
 
 
 
 
135
 
136
  logger = logging.get_logger(__name__)
137
 
@@ -203,7 +186,7 @@ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
203
 
204
 
205
  class OlmoeRotaryEmbedding(nn.Module):
206
- def __init__(self, config: MyOlmoeConfig, device=None):
207
  super().__init__()
208
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
209
  self.rope_type = config.rope_scaling.get(
@@ -289,7 +272,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
289
 
290
 
291
  class OlmoeAttention(nn.Module):
292
- def __init__(self, config: MyOlmoeConfig, layer_idx: Optional[int] = None):
293
  super().__init__()
294
  self.config = config
295
  self.layer_idx = layer_idx
@@ -574,14 +557,11 @@ class OlmoeSparseMoeBlock(nn.Module):
574
  self.num_experts = config.num_experts
575
  self.top_k = config.num_experts_per_tok
576
  self.norm_topk_prob = config.norm_topk_prob
577
-
578
- #########
579
- self.register_buffer('expert_usage_counts', torch.zeros(config.num_experts + config.max_small_expert_count, dtype=torch.long))
580
- self.expert_usage_counts: torch.Tensor # For type hinting
581
- #########
582
 
 
583
  in_second_half = layer_idx >= self.total_layers // 2
584
 
 
585
  if in_second_half:
586
  second_half_idx = layer_idx - (self.total_layers // 2)
587
  num_second_half_blocks = self.total_layers - (self.total_layers // 2)
@@ -589,6 +569,7 @@ class OlmoeSparseMoeBlock(nn.Module):
589
  if config.small_expert_strategy == "constant":
590
  self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
591
  elif config.small_expert_strategy == "increment":
 
592
  self.num_small_experts = (
593
  (second_half_idx + 1) * config.max_small_expert_count // ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
594
  )
@@ -629,12 +610,6 @@ class OlmoeSparseMoeBlock(nn.Module):
629
  if self.norm_topk_prob:
630
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
631
 
632
- #########
633
- expert_indices = selected_experts.flatten()
634
- unique_experts, counts = torch.unique(expert_indices, return_counts=True)
635
- self.expert_usage_counts[unique_experts] += counts.to(self.expert_usage_counts.device)
636
- #########
637
-
638
  final_hidden_states = torch.zeros_like(hidden_states)
639
  expert_mask = torch.nn.functional.one_hot(
640
  selected_experts,
@@ -656,35 +631,10 @@ class OlmoeSparseMoeBlock(nn.Module):
656
  final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))
657
 
658
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
659
-
660
- #########
661
- def __del__(self):
662
- # Print expert usage statistics when the block is deconstructed
663
- if hasattr(self, 'expert_usage_counts'):
664
- total_usage = self.expert_usage_counts.sum().item()
665
- if total_usage > 0:
666
- print(f"\nExpert Usage Statistics for Layer {self.layer_idx}:")
667
- print(f"Total tokens processed: {total_usage}")
668
-
669
- # Regular experts
670
- if self.num_experts > 0:
671
- regular_usage = self.expert_usage_counts[:self.num_experts]
672
- print("\nRegular Experts:")
673
- for i, count in enumerate(regular_usage):
674
- print(f"Expert {i}: {count.item()} uses ({count.item()/total_usage:.2%})")
675
-
676
- # Small experts
677
- if self.num_small_experts > 0:
678
- small_usage = self.expert_usage_counts[self.num_experts:self.num_experts+self.num_small_experts]
679
- print("\nSmall Experts:")
680
- for i, count in enumerate(small_usage):
681
- print(f"Small Expert {i}: {count.item()} uses ({count.item()/total_usage:.2%})")
682
-
683
- print("\n")
684
- #########
685
 
686
  class OlmoeDecoderLayer(nn.Module):
687
- def __init__(self, config: MyOlmoeConfig, layer_idx: int):
688
  super().__init__()
689
  self.hidden_size = config.hidden_size
690
  self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](
@@ -740,7 +690,7 @@ class OlmoeDecoderLayer(nn.Module):
740
 
741
 
742
  class OlmoePreTrainedModel(PreTrainedModel):
743
- config_class = MyOlmoeConfig
744
  base_model_prefix = "model"
745
  supports_gradient_checkpointing = True
746
  _no_split_modules = ["OlmoeDecoderLayer"]
@@ -766,7 +716,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
766
 
767
 
768
  class OlmoeModel(OlmoePreTrainedModel):
769
- def __init__(self, config: MyOlmoeConfig):
770
  super().__init__(config)
771
  self.padding_idx = config.pad_token_id
772
  self.vocab_size = config.vocab_size
@@ -1171,4 +1121,4 @@ class MyOlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
1171
  router_logits=outputs.router_logits,
1172
  )
1173
 
1174
- __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel", "MyOlmoeConfig"]
 
14
  from transformers.modeling_utils import PreTrainedModel
15
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
16
  from transformers.utils import logging
17
+ # from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.modeling_rope_utils import rope_config_validation
20
 
21
class OlmoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OlmoeModel`],
    extended with small-expert parameters for the MoE layers.

    Args:
        small_expert_intermediate_ratio (`int`, *optional*, defaults to 64):
            Ratio controlling the intermediate size of small experts relative to
            regular experts.
        small_expert_count (`int`, *optional*, defaults to 64):
            Number of small experts.
        small_expert_sparsity_coef (`float`, *optional*, defaults to 0.1):
            Coefficient for the small-expert load-balancing loss.
        small_expert_strategy (`str`, *optional*, defaults to `"constant"`):
            How small experts are distributed over layers; `"constant"` or
            `"increment"`.
        max_small_expert_count (`int`, *optional*, defaults to 64):
            Upper bound on the total number of small experts across layers.

    All remaining arguments follow the standard OLMoE configuration (vocab size,
    hidden sizes, attention/rope settings, router settings, token ids).
    """
    # Keep "olmoe" so checkpoints saved under that model_type still load.
    model_type = "olmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=2048,
        intermediate_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        num_experts_per_tok=8,
        num_experts=64,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        small_expert_intermediate_ratio=64,
        small_expert_count=64,
        small_expert_sparsity_coef=0.1,
        small_expert_strategy="constant",  # alternative: "increment"
        max_small_expert_count=64,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: older configs omit num_key_value_heads
        # (no grouped-query attention), so fall back to full multi-head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob

        # Small expert parameters
        self.small_expert_intermediate_ratio = small_expert_intermediate_ratio
        self.small_expert_count = small_expert_count
        self.small_expert_sparsity_coef = small_expert_sparsity_coef
        self.small_expert_strategy = small_expert_strategy
        self.max_small_expert_count = max_small_expert_count

        # Validate the correctness of rotary position embeddings parameters.
        # Older configs use the legacy "type" key; normalize it to "rope_type"
        # before running the shared validator.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
118
 
119
  logger = logging.get_logger(__name__)
120
 
 
186
 
187
 
188
  class OlmoeRotaryEmbedding(nn.Module):
189
+ def __init__(self, config: OlmoeConfig, device=None):
190
  super().__init__()
191
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
192
  self.rope_type = config.rope_scaling.get(
 
272
 
273
 
274
  class OlmoeAttention(nn.Module):
275
+ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
276
  super().__init__()
277
  self.config = config
278
  self.layer_idx = layer_idx
 
557
  self.num_experts = config.num_experts
558
  self.top_k = config.num_experts_per_tok
559
  self.norm_topk_prob = config.norm_topk_prob
 
 
 
 
 
560
 
561
+ # Determine if this block is in the second half
562
  in_second_half = layer_idx >= self.total_layers // 2
563
 
564
+ # Determine small expert count for this layer
565
  if in_second_half:
566
  second_half_idx = layer_idx - (self.total_layers // 2)
567
  num_second_half_blocks = self.total_layers - (self.total_layers // 2)
 
569
  if config.small_expert_strategy == "constant":
570
  self.num_small_experts = config.max_small_expert_count // num_second_half_blocks
571
  elif config.small_expert_strategy == "increment":
572
+ # Linearly scale small experts from 1 to max_small_expert_count
573
  self.num_small_experts = (
574
  (second_half_idx + 1) * config.max_small_expert_count // ((num_second_half_blocks * (num_second_half_blocks + 1)) // 2)
575
  )
 
610
  if self.norm_topk_prob:
611
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
612
 
 
 
 
 
 
 
613
  final_hidden_states = torch.zeros_like(hidden_states)
614
  expert_mask = torch.nn.functional.one_hot(
615
  selected_experts,
 
631
  final_hidden_states.index_add_(0, top_x, current_output.to(hidden_states.dtype))
632
 
633
  return final_hidden_states.view(batch_size, sequence_length, hidden_dim), combined_logits
634
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
  class OlmoeDecoderLayer(nn.Module):
637
+ def __init__(self, config: OlmoeConfig, layer_idx: int):
638
  super().__init__()
639
  self.hidden_size = config.hidden_size
640
  self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](
 
690
 
691
 
692
  class OlmoePreTrainedModel(PreTrainedModel):
693
+ config_class = OlmoeConfig
694
  base_model_prefix = "model"
695
  supports_gradient_checkpointing = True
696
  _no_split_modules = ["OlmoeDecoderLayer"]
 
716
 
717
 
718
  class OlmoeModel(OlmoePreTrainedModel):
719
+ def __init__(self, config: OlmoeConfig):
720
  super().__init__(config)
721
  self.padding_idx = config.pad_token_id
722
  self.vocab_size = config.vocab_size
 
1121
  router_logits=outputs.router_logits,
1122
  )
1123
 
1124
+ __all__ = ["MyOlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel", "OlmoeConfig"]
scripts/eval.py CHANGED
@@ -183,6 +183,12 @@ def load_transformers_model(args) -> HFLM:
183
  def load_custom_model(args) -> HFLM:
184
  """
185
  Load custom MyOLMoE model (uses top-k routing by default).
 
 
 
 
 
 
186
  """
187
  logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
188
  logger.info("Using top-k routing (default)")
@@ -195,84 +201,49 @@ def load_custom_model(args) -> HFLM:
195
  logger.warning(f"Custom model path not found: {args.custom_model_path}")
196
 
197
  try:
198
- # Import custom model class and config
199
- from modeling_myolmoe import MyOlmoeForCausalLM, MyOlmoeConfig
200
- logger.info("Successfully imported MyOlmoeForCausalLM and MyOlmoeConfig")
201
-
202
- # IMPORTANT: Register with "olmoe" since that's what your model was trained with
203
- from transformers import AutoConfig, AutoModelForCausalLM
204
- AutoConfig.register("olmoe", MyOlmoeConfig, exist_ok=True) # Use exist_ok=True
205
- AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM, exist_ok=True)
206
- logger.info("Registered MyOlmoeForCausalLM with MyOlmoeConfig for 'olmoe' type")
207
-
208
  except ImportError as e:
209
  logger.error(f"Failed to import custom model: {e}")
210
  logger.error("Make sure the custom model code is available in the specified path")
211
  raise
212
 
213
- # Load model manually to avoid wrapper issues
214
- logger.info("Loading model manually to avoid wrapper issues...")
 
 
 
215
 
216
- try:
217
- # Load tokenizer first
218
- tokenizer = AutoTokenizer.from_pretrained(
219
- args.model_path,
220
- trust_remote_code=args.trust_remote_code
221
- )
222
-
223
- # Load config - this should now work with the olmoe type
224
- model_config = AutoConfig.from_pretrained(
225
- args.model_path,
226
- trust_remote_code=args.trust_remote_code
227
- )
228
-
229
- logger.info(f"Loaded config type: {type(model_config)}")
230
- logger.info(f"Config model_type: {model_config.model_type}")
231
-
232
- # Verify the config is properly initialized
233
- if not hasattr(model_config, '__dataclass_fields__'):
234
- logger.warning("Config is not recognized as a dataclass, attempting to recreate...")
235
- # Recreate config as proper dataclass instance
236
- config_dict = model_config.to_dict()
237
- model_config = MyOlmoeConfig(**config_dict)
238
-
239
- # Prepare model loading kwargs
240
- model_kwargs = {
241
- 'config': model_config,
242
- 'trust_remote_code': args.trust_remote_code,
243
- }
244
-
245
- # Add torch_dtype if specified
246
- if args.dtype == "bfloat16":
247
- model_kwargs['torch_dtype'] = torch.bfloat16
248
- elif args.dtype == "float16":
249
- model_kwargs['torch_dtype'] = torch.float16
250
- elif args.dtype == "float32":
251
- model_kwargs['torch_dtype'] = torch.float32
252
-
253
- # Load model instance
254
- model_instance = AutoModelForCausalLM.from_pretrained(
255
- args.model_path,
256
- **model_kwargs
257
- )
258
-
259
- logger.info(f"Loaded model type: {type(model_instance)}")
260
-
261
- # Create HFLM wrapper
262
- model = HFLM(
263
- pretrained=model_instance,
264
- tokenizer=tokenizer,
265
- device=args.device,
266
- batch_size=args.batch_size,
267
- max_batch_size=args.max_batch_size
268
- )
269
-
270
- except Exception as e:
271
- logger.error(f"Failed to load custom model: {e}")
272
- logger.error(f"Error type: {type(e)}")
273
- import traceback
274
- logger.error(f"Traceback: {traceback.format_exc()}")
275
- raise
276
 
277
  logger.info("Custom model loaded successfully")
278
  return model
@@ -368,41 +339,13 @@ def run_evaluation(args) -> Dict[str, Any]:
368
  logger.info(f"Few-shot examples: {args.num_fewshot}")
369
  logger.info(f"Batch size: {args.batch_size}")
370
 
371
- # Debug information - FIXED
372
- print("Type of model being passed:", type(model))
373
- if hasattr(model, '_model') and hasattr(model._model, 'config'):
374
- print("Model config:", model._model.config)
375
- elif hasattr(model, 'config'):
376
- print("Model config:", model.config)
377
- else:
378
- print("Model config: Not accessible")
379
-
380
- # Ensure model is properly initialized
381
- if hasattr(model, '_model') and model._model is not None:
382
- logger.info("Model is properly loaded and wrapped")
383
- else:
384
- logger.warning("Model wrapper may not be properly initialized")
385
-
386
- try:
387
- results = evaluator.simple_evaluate(
388
- model=model,
389
- tasks=args.tasks,
390
- num_fewshot=args.num_fewshot,
391
- limit=args.limit,
392
- write_out=args.write_out,
393
- )
394
- except Exception as e:
395
- logger.error(f"Evaluation failed with error: {e}")
396
- logger.error("This might be due to model registration or configuration issues")
397
-
398
- # Additional debugging
399
- logger.error(f"Model type: {type(model)}")
400
- if hasattr(model, '_model'):
401
- logger.error(f"Internal model type: {type(model._model)}")
402
- if hasattr(model._model, 'config'):
403
- logger.error(f"Internal model config type: {type(model._model.config)}")
404
-
405
- raise
406
 
407
  logger.info("Evaluation completed successfully")
408
  return results
 
183
  def load_custom_model(args) -> HFLM:
184
  """
185
  Load custom MyOLMoE model (uses top-k routing by default).
186
+
187
+ Args:
188
+ args: Parsed command line arguments
189
+
190
+ Returns:
191
+ HFLM: Wrapped model ready for evaluation
192
  """
193
  logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
194
  logger.info("Using top-k routing (default)")
 
201
  logger.warning(f"Custom model path not found: {args.custom_model_path}")
202
 
203
  try:
204
+ # Import custom model class
205
+ from modeling_myolmoe import MyOlmoeForCausalLM
206
+ logger.info("Successfully imported MyOlmoeForCausalLM")
 
 
 
 
 
 
 
207
  except ImportError as e:
208
  logger.error(f"Failed to import custom model: {e}")
209
  logger.error("Make sure the custom model code is available in the specified path")
210
  raise
211
 
212
+ # Load model configuration
213
+ config = AutoConfig.from_pretrained(
214
+ args.model_path,
215
+ trust_remote_code=args.trust_remote_code
216
+ )
217
 
218
+ logger.info("Model will use default top-k routing configuration")
219
+
220
+ # Determine torch dtype
221
+ if args.dtype == "auto":
222
+ torch_dtype = "auto"
223
+ else:
224
+ torch_dtype = {
225
+ "float16": torch.float16,
226
+ "bfloat16": torch.bfloat16,
227
+ "float32": torch.float32
228
+ }[args.dtype]
229
+
230
+ # Load the custom model
231
+ hf_model = MyOlmoeForCausalLM.from_pretrained(
232
+ args.model_path,
233
+ config=config,
234
+ torch_dtype=torch_dtype,
235
+ device_map="auto" if args.device == "auto" else None,
236
+ trust_remote_code=args.trust_remote_code
237
+ ).eval()
238
+
239
+ # Wrap in HFLM
240
+ model = HFLM(
241
+ pretrained=hf_model,
242
+ device=args.device,
243
+ batch_size=args.batch_size,
244
+ max_batch_size=args.max_batch_size,
245
+ dtype=args.dtype
246
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  logger.info("Custom model loaded successfully")
249
  return model
 
339
  logger.info(f"Few-shot examples: {args.num_fewshot}")
340
  logger.info(f"Batch size: {args.batch_size}")
341
 
342
+ results = evaluator.simple_evaluate(
343
+ model=model,
344
+ tasks=args.tasks,
345
+ num_fewshot=args.num_fewshot,
346
+ limit=args.limit,
347
+ write_out=args.write_out,
348
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
  logger.info("Evaluation completed successfully")
351
  return results