alverciito commited on
Commit
8068e2f
·
1 Parent(s): b8b34d9

fix model config bug

Browse files
Files changed (1) hide show
  1. model.py +8 -7
model.py CHANGED
@@ -137,7 +137,7 @@ class SentenceCoseNet(PreTrainedModel):
137
  super().__init__(config)
138
 
139
  # Core PyTorch model
140
- self.model = SegmentationNetwork(self.to_model_config())
141
 
142
  # Initialize weights following HF conventions
143
  self.post_init()
@@ -303,24 +303,25 @@ class SentenceCoseNet(PreTrainedModel):
303
  **kwargs,
304
  )
305
 
306
- def to_model_config(self) -> ModelConfig:
 
307
  """
308
  Convert Hugging Face config to internal ModelConfig.
309
  """
310
  mc = ModelConfig()
311
 
312
  # Core dimensions
313
- mc.vocab_size = self.vocab_size
314
- mc.model_dim = self.emb_dim
315
- mc.valid_padding = self.valid_padding
316
 
317
  # CoSeNet config
318
- mc.cosenet = CoSeNetConfig(**self.cosenet)
319
 
320
  # Transformer stack
321
  mc.transformers = [
322
  TransformerConfig(**cfg)
323
- for cfg in self.transformers
324
  ]
325
 
326
  return mc
 
137
  super().__init__(config)
138
 
139
  # Core PyTorch model
140
+ self.model = SegmentationNetwork(self.to_model_config(config))
141
 
142
  # Initialize weights following HF conventions
143
  self.post_init()
 
303
  **kwargs,
304
  )
305
 
306
+ @staticmethod
307
+ def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
308
  """
309
  Convert Hugging Face config to internal ModelConfig.
310
  """
311
  mc = ModelConfig()
312
 
313
  # Core dimensions
314
+ mc.vocab_size = config.vocab_size
315
+ mc.model_dim = config.emb_dim
316
+ mc.valid_padding = config.valid_padding
317
 
318
  # CoSeNet config
319
+ mc.cosenet = CoSeNetConfig(**config.cosenet)
320
 
321
  # Transformer stack
322
  mc.transformers = [
323
  TransformerConfig(**cfg)
324
+ for cfg in config.transformers
325
  ]
326
 
327
  return mc