File size: 3,471 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from omegaconf import OmegaConf
from omegaconf import MISSING

from dataclasses import dataclass, field, asdict
from typing import Any, Dict, Tuple, Optional


@dataclass
class DataLoading:
    """Batch-loading options with their default values.

    NOTE(review): the field names mirror common data-loader keyword
    arguments; the consumer is not visible in this file - confirm there.
    """

    # Number of samples per batch.
    batch_size: int = 64
    # Whether sample order is shuffled.
    shuffle: bool = True
    # Worker count; defaults to 0.
    num_workers: int = 0
    pin_memory: bool = False
    persistent_workers: bool = False


@dataclass
class DataPreprocessing:
    """Preprocessing options: parser choice, label mapping, resize size, colour mode."""

    # Mandatory: has no default (OmegaConf MISSING marker), so the user
    # config must provide it.
    data_parser: str = MISSING
    # Label name -> integer id; each instance starts with its own empty dict.
    label_mapping: Dict[str, int] = field(default_factory=dict)
    # Pair of integers giving the resize target.
    # NOTE(review): (height, width) vs (width, height) order is not visible here.
    resize_dims: Tuple[int, int] = (128, 128)
    # auto-completed from input_channels in SubnetworkParams
    data_mode: str = "rgb"
    # NOTE(review): presumably a CSV path only needed for MedMNIST data - confirm.
    medmnist_csv_file: Optional[str] = None


@dataclass
class DataParams:
    """Dataset location plus the nested preprocessing and loading sections."""

    # Both of the following are mandatory (OmegaConf MISSING marker): they
    # have no default and must come from the user config.
    dataset_path: str = MISSING
    images_extension: str = MISSING
    # Nested sections; each instance gets its own freshly built defaults.
    data_preprocessing: DataPreprocessing = field(default_factory=DataPreprocessing)
    data_loading: DataLoading = field(default_factory=DataLoading)


@dataclass
class TrainingParams:
    """Training hyper-parameters and early-stopping settings with defaults.

    NOTE(review): the set of accepted string values (mode, loss, optimizer,
    early-stopping monitor/mode) is defined by the consumer of this config,
    which is not visible in this file.
    """

    # Task / objective selection.
    mode: str = "binary"
    loss: str = "binary_cross_entropy"

    # Optimisation schedule.
    epochs: int = 10
    optimizer: str = "adam"
    learning_rate: float = 0.001
    momentum: float = 0.0
    weight_decay: float = 0.001

    # Early stopping.
    early_stopping_patience: int = 4
    early_stopping_monitor: str = "val_loss"
    early_stopping_mode: str = "min"


@dataclass
class SubnetworkParams:
    """Architecture settings for a single subnetwork, with defaults."""

    architecture: str = "base_one"
    input_channels: int = 3
    # Disabled field, kept for reference:
    # fc_pred_units: int = 1  # num_classes
    base_channels: int = 32
    fc_hidden_units: int = 64
    pred_activation: str = "tanh"


@dataclass
class EPUCNNParams:
    """Model-level settings: subnetwork count, class count and shared subnetwork config."""

    num_subnetworks: int = 4
    num_classes: int = 1
    epu_activation: str = "sigmoid"
    # Nested section; each instance gets its own default SubnetworkParams.
    subnetwork_config: SubnetworkParams = field(default_factory=SubnetworkParams)


@dataclass
class IUSConfig:
    """Top-level experiment configuration.

    Groups the model, training and data sections together with run
    bookkeeping fields (log/checkpoint directories, experiment name,
    timestamp, saved-folder name). Use :meth:`from_yaml` to build an
    instance from a YAML file layered over these defaults.
    """

    # Nested sections; each instance gets its own freshly built defaults.
    model: EPUCNNParams = field(default_factory=EPUCNNParams)
    train_params: TrainingParams = field(default_factory=TrainingParams)
    data_params: DataParams = field(default_factory=DataParams)

    # Run bookkeeping.
    log_dir: Optional[str] = "./logs"
    checkpoint_dir: Optional[str] = "./checkpoints"
    experiment_name: Optional[str] = "ius_experiment"
    timestamp: Optional[str] = None
    experiment_saved_folder_name: Optional[str] = None

    @staticmethod
    def from_yaml(filepath: str) -> "IUSConfig":
        """Build a config from a YAML file merged over the schema defaults.

        Args:
            filepath: Path of the YAML file to load.

        Returns:
            The merged configuration as an ``IUSConfig`` instance.

        Raises:
            Whatever OmegaConf raises while loading, merging or converting,
            e.g. for an unreadable file or (per the OmegaConf docs) for
            mandatory values that are still unset.
        """
        yaml_cfg = OmegaConf.load(filepath)
        schema = OmegaConf.structured(IUSConfig)
        loaded = OmegaConf.to_object(OmegaConf.merge(schema, yaml_cfg))
        if isinstance(loaded, IUSConfig):
            return loaded
        # Fallback kept from the original code: if the conversion produced
        # a plain mapping instead, construct the dataclass manually.
        return IUSConfig(**loaded)

    def to_dict(self) -> Dict[str, Any]:
        """Return the model, training and data sections as plain dicts.

        ``dataclasses.asdict`` is used so that nested sections (for example
        ``model.subnetwork_config`` or ``data_params.data_loading``) are
        converted recursively, and so that the result is a copy: mutating
        the returned dictionaries does not modify this configuration.
        (The previous ``vars()``-based version returned each section's
        live ``__dict__`` and left nested sections as dataclass objects.)

        Returns:
            A dict with the keys ``"model"``, ``"train_params"`` and
            ``"data_params"``. The bookkeeping fields (``log_dir`` etc.)
            are intentionally not included, matching the original output.
        """
        return {
            "model": asdict(self.model),
            "train_params": asdict(self.train_params),
            "data_params": asdict(self.data_params),
        }

    def __repr__(self) -> str:
        """Render the whole configuration as YAML text."""
        return OmegaConf.to_yaml(OmegaConf.structured(self), resolve=True)


if __name__ == "__main__":
    # Smoke check: load the training config and print each section.
    cfg_file = "configs/train_config.yaml"
    cfg = IUSConfig.from_yaml(cfg_file)
    for section_name in ("model", "train_params", "data_params"):
        print(getattr(cfg, section_name))

    # To inspect the config saved with a past experiment, point
    # IUSConfig.from_yaml at that experiment's user_config.yaml under
    # checkpoints/ and print the result.