# ius/utils/omega_parser.py
# Synced from GitHub via hub-sync — commit 99ec8a2 (verified), by pgatoula.
from omegaconf import OmegaConf
from omegaconf import MISSING
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, Tuple, Optional
@dataclass()
class DataLoading:
    """Batching and worker-process settings used when iterating the dataset.

    The field names mirror common ``DataLoader`` keyword arguments; presumably
    the training code forwards them there (not visible in this module).
    """

    batch_size: int = 64
    shuffle: bool = True
    # Worker / memory options, all off by default.
    num_workers: int = 0
    pin_memory: bool = False
    persistent_workers: bool = False
@dataclass()
class DataPreprocessing:
    """Dataset parsing and image preprocessing options."""

    # Mandatory: stays OmegaConf MISSING ("???") unless the config supplies it.
    # Presumably names the parser that reads the dataset — resolved elsewhere.
    data_parser: str = MISSING
    # Label string -> integer id; a fresh empty dict per instance by default.
    label_mapping: Dict[str, int] = field(default_factory=dict)
    # Target image size; axis order (H, W) vs (W, H) — TODO confirm with caller.
    resize_dims: Tuple[int, int] = (128, 128)
    # NOTE(review): originally documented as "auto-completed from
    # input_channels in SubnetworkParams"; no such logic exists in this module,
    # so presumably a caller overwrites it. Defaults to "rgb".
    data_mode: str = "rgb"
    # Optional CSV path for MedMNIST-style datasets; None when not used.
    medmnist_csv_file: Optional[str] = None
@dataclass()
class DataParams:
    """Dataset location together with its preprocessing and loading sub-sections."""

    # Mandatory fields: left as OmegaConf MISSING ("???") unless the config
    # provides them.
    dataset_path: str = MISSING
    images_extension: str = MISSING
    # Nested sections; every DataParams instance builds its own default objects.
    data_preprocessing: DataPreprocessing = field(default_factory=DataPreprocessing)
    data_loading: DataLoading = field(default_factory=DataLoading)
@dataclass()
class TrainingParams:
    """Task mode, loss, optimiser and early-stopping settings for a training run."""

    # Task / objective.
    mode: str = "binary"
    loss: str = "binary_cross_entropy"
    epochs: int = 10
    # Optimiser choice and hyper-parameters; `momentum` is presumably only
    # consulted by momentum-based optimisers.
    optimizer: str = "adam"
    learning_rate: float = 1e-3
    momentum: float = 0.0
    weight_decay: float = 1e-3
    # Early stopping: patience (epochs), the metric name to watch, and whether
    # that metric should be minimised ("min") — interpretation is up to the trainer.
    early_stopping_patience: int = 4
    early_stopping_monitor: str = "val_loss"
    early_stopping_mode: str = "min"
@dataclass()
class SubnetworkParams:
    """Architecture settings for a subnetwork (used as EPUCNNParams.subnetwork_config)."""

    architecture: str = "base_one"
    input_channels: int = 3
    # A former `fc_pred_units` option (annotated "num_classes") is intentionally
    # absent; presumably the prediction width comes from EPUCNNParams.num_classes.
    base_channels: int = 32
    fc_hidden_units: int = 64
    pred_activation: str = "tanh"
@dataclass()
class EPUCNNParams:
    """Model-level settings: subnetwork count, output classes, activation and subnetwork config."""

    num_subnetworks: int = 4
    num_classes: int = 1
    epu_activation: str = "sigmoid"
    # One SubnetworkParams instance (presumably applied to every subnetwork);
    # a new default object is created per EPUCNNParams instance.
    subnetwork_config: SubnetworkParams = field(default_factory=SubnetworkParams)
@dataclass
class IUSConfig:
    """Top-level experiment configuration.

    Bundles the model, training and data sections with run bookkeeping fields
    (log/checkpoint directories, experiment name, timestamp, saved folder name).
    Use :meth:`from_yaml` to build one from a YAML file; keys the YAML omits
    keep the dataclass defaults.
    """

    model: EPUCNNParams = field(default_factory=EPUCNNParams)
    train_params: TrainingParams = field(default_factory=TrainingParams)
    data_params: DataParams = field(default_factory=DataParams)
    log_dir: Optional[str] = "./logs"
    checkpoint_dir: Optional[str] = "./checkpoints"
    experiment_name: Optional[str] = "ius_experiment"
    # Both default to None; nothing in this module fills them in.
    timestamp: Optional[str] = None
    experiment_saved_folder_name: Optional[str] = None

    @classmethod
    def from_yaml(cls, filepath: str) -> "IUSConfig":
        """Load a YAML file and merge it over the structured-config defaults.

        Args:
            filepath: Path to the YAML file, passed unchanged to ``OmegaConf.load``.

        Returns:
            A materialised config instance whose nested sections are dataclass
            instances rather than ``DictConfig`` objects.

        Raises:
            OmegaConf exceptions propagate unchanged — e.g. for keys not declared
            on the schema, values of the wrong type, or mandatory ``MISSING``
            fields the YAML leaves unset.
        """
        yaml_cfg = OmegaConf.load(filepath)
        # Merging onto the structured schema validates keys and types against it.
        merged = OmegaConf.merge(OmegaConf.structured(cls), yaml_cfg)
        obj = OmegaConf.to_object(merged)
        if isinstance(obj, cls):
            return obj
        # Defensive fallback: to_object on a structured config is expected to
        # return a dataclass instance, so this path should not normally run.
        return cls(**obj)

    def to_dict(self) -> Dict[str, Any]:
        """Return the model, training and data sections as plain nested dicts.

        ``dataclasses.asdict`` recurses into nested dataclasses and copies their
        values, so the result can be mutated or serialised without affecting
        this config. The run bookkeeping fields (``log_dir``, ``checkpoint_dir``,
        ``experiment_name``, ``timestamp``, ``experiment_saved_folder_name``)
        are not included.
        """
        # vars() returned each section's live __dict__ (edits leaked back into
        # the config) and left nested sections as dataclass objects.
        return {
            "model": asdict(self.model),
            "train_params": asdict(self.train_params),
            "data_params": asdict(self.data_params),
        }

    def __repr__(self) -> str:
        """Render the whole config as YAML with interpolations resolved."""
        return OmegaConf.to_yaml(OmegaConf.structured(self), resolve=True)
if __name__ == "__main__":
    # Manual smoke check: load a config and print its three main sections.
    # Usage: python omega_parser.py [CONFIG_YAML]
    # Defaults to configs/train_config.yaml when no path is given.
    import sys

    cfg_file = sys.argv[1] if len(sys.argv) > 1 else "configs/train_config.yaml"
    cfg = IUSConfig.from_yaml(cfg_file)
    print(cfg.model)
    print(cfg.train_params)
    print(cfg.data_params)