alverciito commited on
Commit ·
8068e2f
1
Parent(s): b8b34d9
fix model config bug
Browse files
model.py
CHANGED
|
@@ -137,7 +137,7 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 137 |
super().__init__(config)
|
| 138 |
|
| 139 |
# Core PyTorch model
|
| 140 |
-
self.model = SegmentationNetwork(self.to_model_config())
|
| 141 |
|
| 142 |
# Initialize weights following HF conventions
|
| 143 |
self.post_init()
|
|
@@ -303,24 +303,25 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 303 |
**kwargs,
|
| 304 |
)
|
| 305 |
|
| 306 |
-
|
|
|
|
| 307 |
"""
|
| 308 |
Convert Hugging Face config to internal ModelConfig.
|
| 309 |
"""
|
| 310 |
mc = ModelConfig()
|
| 311 |
|
| 312 |
# Core dimensions
|
| 313 |
-
mc.vocab_size =
|
| 314 |
-
mc.model_dim =
|
| 315 |
-
mc.valid_padding =
|
| 316 |
|
| 317 |
# CoSeNet config
|
| 318 |
-
mc.cosenet = CoSeNetConfig(**
|
| 319 |
|
| 320 |
# Transformer stack
|
| 321 |
mc.transformers = [
|
| 322 |
TransformerConfig(**cfg)
|
| 323 |
-
for cfg in
|
| 324 |
]
|
| 325 |
|
| 326 |
return mc
|
|
|
|
| 137 |
super().__init__(config)
|
| 138 |
|
| 139 |
# Core PyTorch model
|
| 140 |
+
self.model = SegmentationNetwork(self.to_model_config(config))
|
| 141 |
|
| 142 |
# Initialize weights following HF conventions
|
| 143 |
self.post_init()
|
|
|
|
| 303 |
**kwargs,
|
| 304 |
)
|
| 305 |
|
| 306 |
+
@staticmethod
|
| 307 |
+
def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
|
| 308 |
"""
|
| 309 |
Convert Hugging Face config to internal ModelConfig.
|
| 310 |
"""
|
| 311 |
mc = ModelConfig()
|
| 312 |
|
| 313 |
# Core dimensions
|
| 314 |
+
mc.vocab_size = config.vocab_size
|
| 315 |
+
mc.model_dim = config.emb_dim
|
| 316 |
+
mc.valid_padding = config.valid_padding
|
| 317 |
|
| 318 |
# CoSeNet config
|
| 319 |
+
mc.cosenet = CoSeNetConfig(**config.cosenet)
|
| 320 |
|
| 321 |
# Transformer stack
|
| 322 |
mc.transformers = [
|
| 323 |
TransformerConfig(**cfg)
|
| 324 |
+
for cfg in config.transformers
|
| 325 |
]
|
| 326 |
|
| 327 |
return mc
|