Instantiated the nested PretrainedConfig correctly. Added a test to demo
Browse files- .gitignore +1 -0
- .vscode/settings.json +4 -0
- src/config.py +15 -0
- tests/test_config.py +13 -0
.gitignore
CHANGED
|
@@ -3,3 +3,4 @@
|
|
| 3 |
pyrightconfig.json
|
| 4 |
*.jpg
|
| 5 |
*.pyc
|
|
|
|
|
|
| 3 |
pyrightconfig.json
|
| 4 |
*.jpg
|
| 5 |
*.pyc
|
| 6 |
+
.env
|
.vscode/settings.json
CHANGED
|
@@ -20,4 +20,8 @@
|
|
| 20 |
// },
|
| 21 |
// },
|
| 22 |
// "isort.args":["--profile", "black"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
}
|
|
|
|
| 20 |
// },
|
| 21 |
// },
|
| 22 |
// "isort.args":["--profile", "black"],
|
| 23 |
+
"python.testing.unittestEnabled": false,
|
| 24 |
+
"python.testing.pytestEnabled": true,
|
| 25 |
+
"python.testing.cwd": "${workspaceFolder}/",
|
| 26 |
+
"python.envFile": "${workspaceFolder}/.env",
|
| 27 |
}
|
src/config.py
CHANGED
|
@@ -99,6 +99,16 @@ class TinyCLIPConfig(PretrainedConfig):
|
|
| 99 |
self.loss_type = loss_type
|
| 100 |
super().__init__(**kwargs)
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
class TrainerConfig(pydantic.BaseModel):
|
| 104 |
epochs: int = 20
|
|
@@ -119,3 +129,8 @@ class TrainerConfig(pydantic.BaseModel):
|
|
| 119 |
|
| 120 |
_model_config: TinyCLIPConfig = TinyCLIPConfig()
|
| 121 |
_data_config: DataConfig = DataConfig()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
self.loss_type = loss_type
|
| 100 |
super().__init__(**kwargs)
|
| 101 |
|
| 102 |
+
@classmethod
|
| 103 |
+
def from_dict(cls, config_dict, **kwargs):
|
| 104 |
+
text_config_dict = config_dict.pop("text_config", {})
|
| 105 |
+
text_config = TinyCLIPTextConfig.from_dict(text_config_dict)
|
| 106 |
+
|
| 107 |
+
vision_config_dict = config_dict.pop("vision_config", {})
|
| 108 |
+
vision_config = TinyCLIPVisionConfig.from_dict(vision_config_dict)
|
| 109 |
+
|
| 110 |
+
return cls(text_config=text_config, vision_config=vision_config, **config_dict, **kwargs)
|
| 111 |
+
|
| 112 |
|
| 113 |
class TrainerConfig(pydantic.BaseModel):
|
| 114 |
epochs: int = 20
|
|
|
|
| 129 |
|
| 130 |
_model_config: TinyCLIPConfig = TinyCLIPConfig()
|
| 131 |
_data_config: DataConfig = DataConfig()
|
| 132 |
+
|
| 133 |
+
def __init__(self, **data):
|
| 134 |
+
super().__init__(**data)
|
| 135 |
+
if "_model_config" in data:
|
| 136 |
+
self._model_config = TinyCLIPConfig.from_dict(data["_model_config"])
|
tests/test_config.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src import config
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def test_trainer_config():
|
| 6 |
+
trainer_config = config.TrainerConfig.model_validate_json(
|
| 7 |
+
json.dumps({"epochs": 21, "_model_config": {"text_config": {"text_model": "test"}}})
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
assert trainer_config.epochs == 21
|
| 11 |
+
assert trainer_config._model_config.text_config.text_model == "test"
|
| 12 |
+
assert hasattr(trainer_config._model_config.text_config, "max_len")
|
| 13 |
+
assert trainer_config._model_config.vision_config == config.TinyCLIPVisionConfig()
|