Charlie81 commited on
Commit
44006e7
·
1 Parent(s): 353cce5

fix config

Browse files
Files changed (1) hide show
  1. myolmoe/modeling_myolmoe.py +16 -5
myolmoe/modeling_myolmoe.py CHANGED
@@ -79,13 +79,24 @@ class MyOlmoeConfig(PretrainedConfig):
79
  # Model architecture
80
  architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def __post_init__(self):
83
  """Post-initialization to ensure compatibility with PretrainedConfig."""
84
- super().__init__(
85
- pad_token_id=self.pad_token_id,
86
- eos_token_id=self.eos_token_id,
87
- **{k: v for k, v in self.__dict__.items() if not k.startswith('_')}
88
- )
89
 
90
  logger = logging.get_logger(__name__)
91
 
 
79
  # Model architecture
80
  architectures: List[str] = field(default_factory=lambda: ["MyOlmoeForCausalLM"])
81
 
82
+ def __init__(self, **kwargs):
83
+ # Remove torch_dtype and other model loading parameters that shouldn't be in config
84
+ model_loading_params = ['torch_dtype', 'device_map', 'low_cpu_mem_usage']
85
+ for param in model_loading_params:
86
+ kwargs.pop(param, None)
87
+
88
+ # Initialize dataclass fields
89
+ for field in self.__dataclass_fields__:
90
+ if field in kwargs:
91
+ setattr(self, field, kwargs.pop(field))
92
+
93
+ # Call parent init with remaining kwargs
94
+ super().__init__(**kwargs)
95
+
96
  def __post_init__(self):
97
  """Post-initialization to ensure compatibility with PretrainedConfig."""
98
+ # This is handled in __init__ now
99
+ pass
 
 
 
100
 
101
  logger = logging.get_logger(__name__)
102