Upload model
Browse files
chatNT.py
CHANGED
|
@@ -588,6 +588,19 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 588 |
config_class = ChatNTConfig
|
| 589 |
|
| 590 |
def __init__(self, config: ChatNTConfig) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
super().__init__(config=config)
|
| 592 |
self.gpt_config = config.gpt_config
|
| 593 |
self.esm_config = config.esm_config
|
|
|
|
| 588 |
config_class = ChatNTConfig
|
| 589 |
|
| 590 |
def __init__(self, config: ChatNTConfig) -> None:
|
| 591 |
+
if isinstance(config, dict):
|
| 592 |
+
# If config is a dictionary instead of ChatNTConfig (which can happen
|
| 593 |
+
# depending how the config was saved), we convert it to the config
|
| 594 |
+
config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
|
| 595 |
+
**config["gpt_config"]["rope_config"]
|
| 596 |
+
)
|
| 597 |
+
config["gpt_config"] = GptConfig(**config["gpt_config"])
|
| 598 |
+
config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
|
| 599 |
+
config["perceiver_resampler_config"] = PerceiverResamplerConfig(
|
| 600 |
+
**config["perceiver_resampler_config"]
|
| 601 |
+
)
|
| 602 |
+
config = ChatNTConfig(**config) # type: ignore
|
| 603 |
+
|
| 604 |
super().__init__(config=config)
|
| 605 |
self.gpt_config = config.gpt_config
|
| 606 |
self.esm_config = config.esm_config
|