Charlie81 committed on
Commit
077e7bc
·
1 Parent(s): 44006e7

attempts to fix more

Browse files
Files changed (2) hide show
  1. myolmoe/modeling_myolmoe.py +49 -14
  2. scripts/eval.py +37 -23
myolmoe/modeling_myolmoe.py CHANGED
@@ -22,12 +22,16 @@ from dataclasses import dataclass, field
22
  from typing import Optional, List, Any
23
  from transformers import PretrainedConfig
24
 
 
 
 
 
25
  @dataclass
26
  class MyOlmoeConfig(PretrainedConfig):
27
  """
28
  Configuration class for MyOlmoe model.
29
  """
30
- model_type: str = "myolmoe"
31
 
32
  # Core model parameters
33
  vocab_size: int = 50304
@@ -72,31 +76,62 @@ class MyOlmoeConfig(PretrainedConfig):
72
  rope_theta: float = 10000.0
73
  rope_scaling: Optional[dict] = None
74
 
75
- # Token IDs
76
  pad_token_id: int = 1
77
  eos_token_id: int = 50279
 
78
 
79
  # Model architecture
80
  architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])
81
 
82
  def __init__(self, **kwargs):
83
- # Remove torch_dtype and other model loading parameters that shouldn't be in config
84
- model_loading_params = ['torch_dtype', 'device_map', 'low_cpu_mem_usage']
 
85
  for param in model_loading_params:
86
  kwargs.pop(param, None)
87
 
88
- # Initialize dataclass fields
89
- for field in self.__dataclass_fields__:
90
- if field in kwargs:
91
- setattr(self, field, kwargs.pop(field))
92
-
93
- # Call parent init with remaining kwargs
 
 
 
 
 
94
  super().__init__(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- def __post_init__(self):
97
- """Post-initialization to ensure compatibility with PretrainedConfig."""
98
- # This is handled in __init__ now
99
- pass
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  logger = logging.get_logger(__name__)
102
 
 
22
  from typing import Optional, List, Any
23
  from transformers import PretrainedConfig
24
 
25
+ from dataclasses import dataclass, field
26
+ from typing import Optional, List, Dict, Any
27
+ from transformers import PretrainedConfig
28
+
29
  @dataclass
30
  class MyOlmoeConfig(PretrainedConfig):
31
  """
32
  Configuration class for MyOlmoe model.
33
  """
34
+ model_type: str = "olmoe" # Keep as "olmoe" to match your trained model
35
 
36
  # Core model parameters
37
  vocab_size: int = 50304
 
76
  rope_theta: float = 10000.0
77
  rope_scaling: Optional[dict] = None
78
 
79
+ # Token IDs - Set proper defaults
80
  pad_token_id: int = 1
81
  eos_token_id: int = 50279
82
+ bos_token_id: int = 1
83
 
84
  # Model architecture
85
  architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])
86
 
87
  def __init__(self, **kwargs):
88
+ # Handle model loading parameters that shouldn't go to config
89
+ model_loading_params = ['torch_dtype', 'device_map', 'low_cpu_mem_usage',
90
+ 'load_in_8bit', 'load_in_4bit', 'quantization_config']
91
  for param in model_loading_params:
92
  kwargs.pop(param, None)
93
 
94
+ # Set defaults for any missing required fields
95
+ if 'pad_token_id' not in kwargs:
96
+ kwargs['pad_token_id'] = self.pad_token_id
97
+ if 'eos_token_id' not in kwargs:
98
+ kwargs['eos_token_id'] = self.eos_token_id
99
+ if 'bos_token_id' not in kwargs:
100
+ kwargs['bos_token_id'] = self.bos_token_id
101
+ if 'architectures' not in kwargs:
102
+ kwargs['architectures'] = ["MyOlmoeForCausalLM"]
103
+
104
+ # Initialize the parent class first
105
  super().__init__(**kwargs)
106
+
107
+ # Then set dataclass fields from remaining kwargs or defaults
108
+ for field_name, field_def in self.__dataclass_fields__.items():
109
+ if hasattr(self, field_name):
110
+ continue # Already set by parent
111
+ if field_name in kwargs:
112
+ setattr(self, field_name, kwargs[field_name])
113
+ else:
114
+ # Use default value from dataclass field
115
+ if field_def.default != field_def.default_factory:
116
+ setattr(self, field_name, field_def.default)
117
+ elif field_def.default_factory != field_def.default_factory: # type: ignore
118
+ setattr(self, field_name, field_def.default_factory())
119
 
120
+ @classmethod
121
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
122
+ """Override from_pretrained to handle the model type properly."""
123
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
124
+
125
+ # Keep the original model_type from the saved config
126
+ # This allows loading models trained with "olmoe" type
127
+ if 'model_type' in config_dict:
128
+ original_model_type = config_dict['model_type']
129
+ # But register with the class model_type for compatibility
130
+ if original_model_type == "olmoe":
131
+ config_dict['model_type'] = "olmoe" # Keep as olmoe
132
+
133
+ return cls.from_dict(config_dict, **kwargs)
134
+
135
 
136
  logger = logging.get_logger(__name__)
137
 
scripts/eval.py CHANGED
@@ -183,12 +183,6 @@ def load_transformers_model(args) -> HFLM:
183
  def load_custom_model(args) -> HFLM:
184
  """
185
  Load custom MyOLMoE model (uses top-k routing by default).
186
-
187
- Args:
188
- args: Parsed command line arguments
189
-
190
- Returns:
191
- HFLM: Wrapped model ready for evaluation
192
  """
193
  logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
194
  logger.info("Using top-k routing (default)")
@@ -205,15 +199,11 @@ def load_custom_model(args) -> HFLM:
205
  from modeling_myolmoe import MyOlmoeForCausalLM, MyOlmoeConfig
206
  logger.info("Successfully imported MyOlmoeForCausalLM and MyOlmoeConfig")
207
 
208
- # Check if config is a dataclass
209
- if not hasattr(MyOlmoeConfig, '__dataclass_fields__'):
210
- logger.warning("MyOlmoeConfig is not a dataclass, this may cause issues")
211
-
212
- # Register the custom model class with the correct config
213
  from transformers import AutoConfig, AutoModelForCausalLM
214
- AutoConfig.register("myolmoe", MyOlmoeConfig)
215
- AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM)
216
- logger.info("Registered MyOlmoeForCausalLM with MyOlmoeConfig")
217
 
218
  except ImportError as e:
219
  logger.error(f"Failed to import custom model: {e}")
@@ -224,31 +214,51 @@ def load_custom_model(args) -> HFLM:
224
  logger.info("Loading model manually to avoid wrapper issues...")
225
 
226
  try:
227
- # Load tokenizer
228
  tokenizer = AutoTokenizer.from_pretrained(
229
  args.model_path,
230
  trust_remote_code=args.trust_remote_code
231
  )
232
 
233
- # Load config using the custom config class
234
- model_config = MyOlmoeConfig.from_pretrained(
235
  args.model_path,
236
  trust_remote_code=args.trust_remote_code
237
  )
238
 
239
- # Debug information
240
  logger.info(f"Loaded config type: {type(model_config)}")
241
  logger.info(f"Config model_type: {model_config.model_type}")
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Load model instance
244
- model_instance = MyOlmoeForCausalLM.from_pretrained(
245
  args.model_path,
246
- config=model_config,
247
- trust_remote_code=args.trust_remote_code,
248
- torch_dtype=torch.bfloat16 if args.dtype == "bfloat16" else "auto"
249
  )
250
 
251
- # Create HFLM with pre-loaded model
 
 
252
  model = HFLM(
253
  pretrained=model_instance,
254
  tokenizer=tokenizer,
@@ -259,11 +269,15 @@ def load_custom_model(args) -> HFLM:
259
 
260
  except Exception as e:
261
  logger.error(f"Failed to load custom model: {e}")
 
 
 
262
  raise
263
 
264
  logger.info("Custom model loaded successfully")
265
  return model
266
 
 
267
  def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
268
  """
269
  Validate model configuration and return key information.
 
183
  def load_custom_model(args) -> HFLM:
184
  """
185
  Load custom MyOLMoE model (uses top-k routing by default).
 
 
 
 
 
 
186
  """
187
  logger.info(f"Loading custom MyOLMoE model: {args.model_path}")
188
  logger.info("Using top-k routing (default)")
 
199
  from modeling_myolmoe import MyOlmoeForCausalLM, MyOlmoeConfig
200
  logger.info("Successfully imported MyOlmoeForCausalLM and MyOlmoeConfig")
201
 
202
+ # IMPORTANT: Register with "olmoe" since that's what your model was trained with
 
 
 
 
203
  from transformers import AutoConfig, AutoModelForCausalLM
204
+ AutoConfig.register("olmoe", MyOlmoeConfig, exist_ok=True) # Use exist_ok=True
205
+ AutoModelForCausalLM.register(MyOlmoeConfig, MyOlmoeForCausalLM, exist_ok=True)
206
+ logger.info("Registered MyOlmoeForCausalLM with MyOlmoeConfig for 'olmoe' type")
207
 
208
  except ImportError as e:
209
  logger.error(f"Failed to import custom model: {e}")
 
214
  logger.info("Loading model manually to avoid wrapper issues...")
215
 
216
  try:
217
+ # Load tokenizer first
218
  tokenizer = AutoTokenizer.from_pretrained(
219
  args.model_path,
220
  trust_remote_code=args.trust_remote_code
221
  )
222
 
223
+ # Load config - this should now work with the olmoe type
224
+ model_config = AutoConfig.from_pretrained(
225
  args.model_path,
226
  trust_remote_code=args.trust_remote_code
227
  )
228
 
 
229
  logger.info(f"Loaded config type: {type(model_config)}")
230
  logger.info(f"Config model_type: {model_config.model_type}")
231
 
232
+ # Verify the config is properly initialized
233
+ if not hasattr(model_config, '__dataclass_fields__'):
234
+ logger.warning("Config is not recognized as a dataclass, attempting to recreate...")
235
+ # Recreate config as proper dataclass instance
236
+ config_dict = model_config.to_dict()
237
+ model_config = MyOlmoeConfig(**config_dict)
238
+
239
+ # Prepare model loading kwargs
240
+ model_kwargs = {
241
+ 'config': model_config,
242
+ 'trust_remote_code': args.trust_remote_code,
243
+ }
244
+
245
+ # Add torch_dtype if specified
246
+ if args.dtype == "bfloat16":
247
+ model_kwargs['torch_dtype'] = torch.bfloat16
248
+ elif args.dtype == "float16":
249
+ model_kwargs['torch_dtype'] = torch.float16
250
+ elif args.dtype == "float32":
251
+ model_kwargs['torch_dtype'] = torch.float32
252
+
253
  # Load model instance
254
+ model_instance = AutoModelForCausalLM.from_pretrained(
255
  args.model_path,
256
+ **model_kwargs
 
 
257
  )
258
 
259
+ logger.info(f"Loaded model type: {type(model_instance)}")
260
+
261
+ # Create HFLM wrapper
262
  model = HFLM(
263
  pretrained=model_instance,
264
  tokenizer=tokenizer,
 
269
 
270
  except Exception as e:
271
  logger.error(f"Failed to load custom model: {e}")
272
+ logger.error(f"Error type: {type(e)}")
273
+ import traceback
274
+ logger.error(f"Traceback: {traceback.format_exc()}")
275
  raise
276
 
277
  logger.info("Custom model loaded successfully")
278
  return model
279
 
280
+
281
  def validate_model_config(model_path: str, trust_remote_code: bool = False) -> Dict[str, Any]:
282
  """
283
  Validate model configuration and return key information.