| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Literal, Optional, Type, TypeVar |
| |
|
| | from dacite import Config, from_dict |
| | from omegaconf import DictConfig, OmegaConf |
| |
|
| | from .dataset import DatasetCfgWrapper |
| | from .dataset.data_module import DataLoaderCfg |
| | from .loss import LossCfgWrapper |
| | from .model.decoder import DecoderCfg |
| | from .model.encoder import EncoderCfg |
| | from .model.model_wrapper import OptimizerCfg, TestCfg, TrainCfg |
| |
|
| |
|
@dataclass
class CheckpointingCfg:
    """Settings for model checkpoint loading and saving."""

    # Path (or wandb-style reference) of a checkpoint to resume from;
    # None means start fresh. Exact accepted formats depend on the caller —
    # TODO confirm against the checkpoint-loading code.
    load: Optional[str]
    # Save a checkpoint every this many training steps.
    every_n_train_steps: int
    # Number of best/most-recent checkpoints to keep (Lightning semantics).
    save_top_k: int
    # If True, save only model weights, not optimizer/trainer state.
    save_weights_only: bool
|
| |
|
@dataclass
class ModelCfg:
    """Configuration for the model's two halves: encoder and decoder."""

    decoder: DecoderCfg
    encoder: EncoderCfg
|
| |
|
@dataclass
class TrainerCfg:
    """Options forwarded to the training loop (Lightning-style trainer)."""

    # Hard cap on optimizer steps for the run.
    max_steps: int
    # How often to run validation: int = every N steps, float = fraction of
    # an epoch, None = framework default.
    val_check_interval: int | float | None
    # Gradient-norm clipping threshold; None disables clipping.
    gradient_clip_val: int | float | None
    # Number of machines participating in distributed training.
    num_nodes: int = 1
    # Accumulate gradients over this many batches before stepping.
    accumulate_grad_batches: int = 1
    # Numerical precision mode for training.
    precision: Literal["32", "16-mixed", "bf16-mixed"] = "32"
|
| |
|
@dataclass
class RootCfg:
    """Top-level typed configuration for the whole application.

    Populated from an OmegaConf DictConfig via load_typed_root_config.
    """

    # Raw wandb settings passed through untyped.
    wandb: dict
    # Whether this run trains or only evaluates.
    mode: Literal["train", "test"]
    # One wrapper per dataset entry (see separate_dataset_cfg_wrappers).
    dataset: list[DatasetCfgWrapper]
    data_loader: DataLoaderCfg
    model: ModelCfg
    optimizer: OptimizerCfg
    checkpointing: CheckpointingCfg
    trainer: TrainerCfg
    # One wrapper per loss entry (see separate_loss_cfg_wrappers).
    loss: list[LossCfgWrapper]
    test: TestCfg
    train: TrainCfg
    # Global RNG seed for reproducibility.
    seed: int
| |
|
| |
|
# Default dacite type hooks: coerce values whose target annotation is
# pathlib.Path (e.g. plain strings out of the OmegaConf container) by
# calling Path() on them during from_dict conversion.
TYPE_HOOKS = {
    Path: Path,
}


# Generic return type for load_typed_config: the dataclass being built.
T = TypeVar("T")
|
| |
|
def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: Optional[dict] = None,
) -> T:
    """Convert an OmegaConf config into an instance of a typed dataclass.

    Args:
        cfg: The raw OmegaConf configuration to convert.
        data_class: The target dataclass type to instantiate.
        extra_type_hooks: Optional additional dacite type hooks, merged on
            top of the module-level TYPE_HOOKS (extras win on collision).

    Returns:
        An instance of ``data_class`` populated from ``cfg``.
    """
    # Fix: the original used a mutable default argument (`= {}`), which is
    # shared across calls; use a None sentinel instead.
    hooks = dict(TYPE_HOOKS)
    if extra_type_hooks:
        hooks.update(extra_type_hooks)
    return from_dict(
        data_class,
        OmegaConf.to_container(cfg),
        config=Config(type_hooks=hooks),
    )
| |
|
| |
|
def separate_loss_cfg_wrappers(joined: dict) -> list[LossCfgWrapper]:
    """Split a joined mapping of loss configs into one typed wrapper each.

    Each (name, config) pair in ``joined`` is converted independently by
    wrapping it in a throwaway dataclass so dacite can resolve the
    LossCfgWrapper union member from the key.
    """

    @dataclass
    class Dummy:
        dummy: LossCfgWrapper

    wrappers = []
    for name, loss_cfg in joined.items():
        single = DictConfig({"dummy": {name: loss_cfg}})
        wrappers.append(load_typed_config(single, Dummy).dummy)
    return wrappers
| |
|
| |
|
def separate_dataset_cfg_wrappers(joined: dict) -> list[DatasetCfgWrapper]:
    """Split a joined mapping of dataset configs into one typed wrapper each.

    Mirrors separate_loss_cfg_wrappers: every (name, config) pair is
    converted on its own through a throwaway dataclass so dacite can pick
    the right DatasetCfgWrapper union member from the key.
    """

    @dataclass
    class Dummy:
        dummy: DatasetCfgWrapper

    wrappers = []
    for name, dataset_cfg in joined.items():
        single = DictConfig({"dummy": {name: dataset_cfg}})
        wrappers.append(load_typed_config(single, Dummy).dummy)
    return wrappers
| |
|
| |
|
def load_typed_root_config(cfg: DictConfig) -> RootCfg:
    """Convert the raw OmegaConf config into a fully typed RootCfg.

    Registers extra dacite hooks so that the joined loss and dataset
    mappings are split into per-entry wrapper lists during conversion.
    """
    extra_hooks = {
        list[LossCfgWrapper]: separate_loss_cfg_wrappers,
        list[DatasetCfgWrapper]: separate_dataset_cfg_wrappers,
    }
    return load_typed_config(cfg, RootCfg, extra_hooks)
| |
|