| from omegaconf import OmegaConf |
| from omegaconf import MISSING |
|
|
| from dataclasses import dataclass, field, asdict |
| from typing import Any, Dict, Tuple, Optional |
|
|
|
|
| @dataclass |
| class DataLoading: |
| batch_size: int = 64 |
| shuffle: bool = True |
| num_workers: int = 0 |
| pin_memory: bool = False |
| persistent_workers: bool = False |
|
|
|
|
| @dataclass |
| class DataPreprocessing: |
| data_parser: str = MISSING |
| label_mapping: Dict[str, int] = field(default_factory=dict) |
| resize_dims: Tuple[int, int] = (128, 128) |
| data_mode: str = "rgb" |
| medmnist_csv_file: Optional[str] = None |
|
|
|
|
| @dataclass |
| class DataParams: |
| dataset_path: str = MISSING |
| images_extension: str = MISSING |
| data_preprocessing: DataPreprocessing = field(default_factory=DataPreprocessing) |
| data_loading: DataLoading = field(default_factory=DataLoading) |
|
|
|
|
| @dataclass |
| class TrainingParams: |
| mode: str = "binary" |
| loss: str = "binary_cross_entropy" |
| epochs: int = 10 |
| optimizer: str = "adam" |
| learning_rate: float = 0.001 |
| momentum: float = 0.0 |
| weight_decay: float = 0.001 |
| early_stopping_patience: int = 4 |
| early_stopping_monitor: str = "val_loss" |
| early_stopping_mode: str = "min" |
|
|
|
|
| @dataclass |
| class SubnetworkParams: |
| architecture: str = "base_one" |
| input_channels: int = 3 |
| |
| base_channels: int = 32 |
| fc_hidden_units: int = 64 |
| pred_activation: str = "tanh" |
|
|
|
|
| @dataclass |
| class EPUCNNParams: |
| num_subnetworks: int = 4 |
| num_classes: int = 1 |
| epu_activation: str = "sigmoid" |
| subnetwork_config: SubnetworkParams = field(default_factory=SubnetworkParams) |
|
|
|
|
| @dataclass |
| class IUSConfig: |
| model: EPUCNNParams = field(default_factory=EPUCNNParams) |
| train_params: TrainingParams = field(default_factory=TrainingParams) |
| data_params: DataParams = field(default_factory=DataParams) |
| log_dir: Optional[str] = "./logs" |
| checkpoint_dir: Optional[str] = "./checkpoints" |
| experiment_name: Optional[str] = "ius_experiment" |
| timestamp: Optional[str] = None |
| experiment_saved_folder_name: Optional[str] = None |
|
|
| @staticmethod |
| def from_yaml(filepath: str) -> "IUSConfig": |
| yaml_cfg = OmegaConf.load(filepath) |
| merged = OmegaConf.merge(OmegaConf.structured(IUSConfig), yaml_cfg) |
| merged = OmegaConf.to_object(merged) |
| if isinstance(merged, IUSConfig): |
| cfg = merged |
| else: |
| cfg = IUSConfig(**merged) |
| return cfg |
|
|
| def to_dict(self) -> Dict[str, Any]: |
| return { |
| "model": vars(self.model), |
| "train_params": vars(self.train_params), |
| "data_params": vars(self.data_params), |
| } |
|
|
| def __repr__(self): |
| return OmegaConf.to_yaml(OmegaConf.structured(self), resolve=True) |
|
|
|
|
| if __name__ == "__main__": |
| cfg_file = "configs/train_config.yaml" |
| cfg = IUSConfig.from_yaml(cfg_file) |
| print(cfg.model) |
| print(cfg.train_params) |
| print(cfg.data_params) |
|
|
| |
| |
| |
| |
|
|