import os import yaml import torch import torch.nn as nn from typing import List, Optional from src.util.loss import DiceLoss, FocalLoss, VAELoss, ImageInpaintingL1Loss from torchvision.models import resnet152, swin_b, efficientnet_v2_l, vit_l_16 from torchvision.models import ( ResNet152_Weights, Swin_B_Weights, EfficientNet_V2_L_Weights, ViT_L_16_Weights, ) from src.models.simple_z_predictor import SimpleZRegressionVisionTransformer from src.models.autoencoder import Autoencoder from src.models.vae import VAE from src.models.unet.unet import UNet, ThickUNet from src.models.unetr.unetr import UNETR from src.models.classic_recon import ( LinearInterpolationInpainter, BicubicInterpolationInpainter, AMPInpainter, NearestNeighborsInpainter, ) from src.models.our_method.swin_cafm import SwinCAFM from src.models.prev_methods.sstem import SSTEM from src.models.prev_methods.gpstruct import GPSTRUCT def parse_config(fp: str) -> dict: r""" Args :param fp: path to config file Returns :return: dict """ assert os.path.isfile(fp), f"Error: config file @ {fp} does not exist" with open(fp, "r") as f: config = yaml.safe_load(f) return config LOSS_FUNCTIONS = { "MSE": nn.MSELoss, "L1": nn.L1Loss, "CrossEntropy": nn.CrossEntropyLoss, "SmoothL1": nn.SmoothL1Loss, "Dice": DiceLoss, "Focal": FocalLoss, "Huber": nn.HuberLoss, "VAE": VAELoss, "InpaintingL1": ImageInpaintingL1Loss, } MODELS = { "ours": { "fn": SwinCAFM.get, "weights": None, }, "simple_z_reg_vit": { "fn": SimpleZRegressionVisionTransformer.get, "weights": None, }, "resnet152": { "fn": resnet152, "weights": ResNet152_Weights.IMAGENET1K_V2, }, "swin_b": { "fn": swin_b, "weights": Swin_B_Weights.IMAGENET1K_V1, }, "efficientnet_v2_l": { "fn": efficientnet_v2_l, "weights": EfficientNet_V2_L_Weights.IMAGENET1K_V1, }, "vit_l_16": { "fn": vit_l_16, "weights": ViT_L_16_Weights.IMAGENET1K_V1, }, "ae": {"fn": Autoencoder.get, "weights": None}, "vae": {"fn": VAE.get, "weights": None}, "unet": {"fn": UNet.get, "weights": None}, "thick_unet": {"fn": ThickUNet.get, "weights": None}, "unetr": {"fn": UNETR.get, "weights": None}, "linear_interpolation": {"fn": LinearInterpolationInpainter.get, "weights": None}, "bicubic_interpolation": {"fn": BicubicInterpolationInpainter.get, "weights": None}, "amp_interpolation": {"fn": AMPInpainter.get, "weights": None}, "nn_interpolation": {"fn": NearestNeighborsInpainter.get, "weights": None}, "sstem_interpolation": {"fn": SSTEM.get, "weights": None}, "gpstruct_interpolation": {"fn": GPSTRUCT.get, "weights": None}, } OPTIMIZERS = { "Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW, "SGD": torch.optim.SGD, } class TrainConfig: """ Object representing a config file for a training run of a model. """ def __init__(self, config_fp: str): if not os.path.isfile(config_fp): raise FileNotFoundError(f"Config file not found: {config_fp}") config_dict: dict = parse_config(config_fp) # --- Global settings --- global_cfg: dict = config_dict.get("global", {}) self.device: int = global_cfg.get("device", 0) self.mode: str = global_cfg.get("mode", "train") self.formulation: Optional[str] = global_cfg.get("formulation", None) # --- Model settings --- model_cfg: dict = config_dict.get("model", {}) self.model_name: str = model_cfg.get("name", "") self.pretrained: str = model_cfg.get("pretrained", False) self.weights: str = model_cfg.get("weights", None) self.model_config_file: str = model_cfg.get("config", None) # --- Surrogate settings --- surrogate_cfg: dict = config_dict.get("surrogate", {}) self.surrogate_name: str = surrogate_cfg.get("name", "") self.surgate_weights: str = surrogate_cfg.get("weights", None) # --- Training settings --- training_cfg: str = config_dict.get("training", {}) self.train_batch_size: int = training_cfg.get("batch_size", 1) self.steps_per_epoch: int = training_cfg.get("steps_per_epoch", 1024) self.epochs: int = training_cfg.get("epochs", 2000) self.train_loss: Optional[str] = training_cfg.get("loss", None) self.learning_rate: str = training_cfg.get("lr", 1e-4) self.optimizer: str = training_cfg.get("optimizer", "Adam") # --- Validation settings --- validation_cfg: dict = config_dict.get("validation", {}) self.val_batch_size: int = validation_cfg.get("batch_size", 1) self.val_steps_per_epoch: int = validation_cfg.get("steps_per_epoch", 256) self.val_loss: Optional[str] = validation_cfg.get("loss", None) # --- Dataset settings --- dataset_cfg: dict = config_dict.get("dataset", {}) self.dataset_name: str = dataset_cfg.get("name", "") self.image_size: Optional[str] = dataset_cfg.get("image_size", None) self.crop_size: Optional[str] = dataset_cfg.get("crop_size", None) self.num_workers: int = dataset_cfg.get("num_workers", 0) self.masking_ratio: int = dataset_cfg.get("masking_ratio", 1) # --- Logging settings --- logging_cfg: dict = config_dict.get("logging", {}) self.log_root: str = logging_cfg.get("root", "") self.exp_name: str = logging_cfg.get("exp_name", "") self.result_columns: List = logging_cfg.get("result_columns", []) self.save_weights: bool = logging_cfg.get("save_weights", True) self.save_only_best_weights: bool = logging_cfg.get( "save_only_best_weights", True ) self.enable_tensorboard: bool = logging_cfg.get("enable_tensorboard", False) self.log_figures: bool = logging_cfg.get("log_figures", False) self.log_interval: int = logging_cfg.get("log_interval", 1) def to_dict(self) -> dict: config_dict = { "global": { "device": self.device, "mode": self.mode, "formulation": self.formulation, }, "model": { "name": self.model_name, "pretrained": self.pretrained, "weights": self.weights, "config": self.model_config_file, }, "surrogate": { "name": self.surrogate_name, "weights": self.surgate_weights, }, "training": { "batch_size": self.train_batch_size, "steps_per_epoch": self.steps_per_epoch, "epochs": self.epochs, "loss": self.train_loss, "lr": self.learning_rate, "optimizer": self.optimizer, }, "validation": { "batch_size": self.val_batch_size, "steps_per_epoch": self.val_steps_per_epoch, "loss": self.val_loss, }, "dataset": { "name": self.dataset_name, "image_size": self.image_size, "crop_size": self.crop_size, "num_workers": self.num_workers, "masking_ratio": self.masking_ratio, }, "logging": { "root": self.log_root, "exp_name": self.exp_name, "result_columns": self.result_columns, "save_weights": self.save_weights, "save_only_best_weights": self.save_only_best_weights, "enable_tensorboard": self.enable_tensorboard, "log_figures": self.log_figures, "log_interval": self.log_interval, }, } return config_dict def save_config(self, output_fp: str) -> None: """ Save the current configuration to a YAML file, preserving the original format. """ config_dict = { "global": { "device": self.device, "mode": self.mode, "formulation": self.formulation, }, "model": { "name": self.model_name, "pretrained": self.pretrained, "weights": self.weights, "config": self.model_config_file, }, "surrogate": { "name": self.surrogate_name, "weights": self.surgate_weights, }, "training": { "batch_size": self.train_batch_size, "steps_per_epoch": self.steps_per_epoch, "epochs": self.epochs, "loss": self.train_loss, "lr": self.learning_rate, "optimizer": self.optimizer, }, "validation": { "batch_size": self.val_batch_size, "steps_per_epoch": self.val_steps_per_epoch, "loss": self.val_loss, }, "dataset": { "name": self.dataset_name, "image_size": self.image_size, "crop_size": self.crop_size, "num_workers": self.num_workers, "masking_ratio": self.masking_ratio, }, "logging": { "root": self.log_root, "exp_name": self.exp_name, "result_columns": self.result_columns, "save_weights": self.save_weights, "save_only_best_weights": self.save_only_best_weights, "enable_tensorboard": self.enable_tensorboard, "log_figures": self.log_figures, "log_interval": self.log_interval, }, } # save config with open(output_fp, "w") as f: yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) class EvalConfig: """ Object representing a config file for a evaluation run of a model. """ def __init__(self, config_fp: str): if not os.path.isfile(config_fp): raise FileNotFoundError(f"Config file not found: {config_fp}") config_dict: dict = parse_config(config_fp) # --- Global settings --- global_cfg: dict = config_dict.get("global", {}) self.device: int = global_cfg.get("device", 0) self.mode: str = global_cfg.get("mode", "train") self.formulation: Optional[str] = global_cfg.get("formulation", None) # --- Model settings --- model_cfg: dict = config_dict.get("model", {}) self.model_name: str = model_cfg.get("name", "") self.pretrained: str = model_cfg.get("pretrained", False) self.weights: str = model_cfg.get("weights", None) self.model_config_file: str = model_cfg.get("config", None) # --- Validation settings --- validation_cfg: dict = config_dict.get("validation", {}) self.val_batch_size: int = validation_cfg.get("batch_size", 1) self.val_steps_per_epoch: int = validation_cfg.get("steps_per_epoch", 256) self.val_loss: Optional[str] = validation_cfg.get("loss", None) # --- Dataset settings --- dataset_cfg: dict = config_dict.get("dataset", {}) self.dataset_name: str = dataset_cfg.get("name", "") self.image_size: Optional[str] = dataset_cfg.get("image_size", None) self.crop_size: Optional[str] = dataset_cfg.get("crop_size", None) self.num_workers: int = dataset_cfg.get("num_workers", 0) self.masking_ratio: int = dataset_cfg.get("masking_ratio", 1) # --- Logging settings --- logging_cfg: dict = config_dict.get("logging", {}) self.log_root: str = logging_cfg.get("root", "") self.exp_name: str = logging_cfg.get("exp_name", "") self.result_columns: List = logging_cfg.get("result_columns", []) self.save_weights: bool = logging_cfg.get("save_weights", True) self.save_only_best_weights: bool = logging_cfg.get( "save_only_best_weights", True ) self.enable_tensorboard: bool = logging_cfg.get("enable_tensorboard", False) self.log_figures: bool = logging_cfg.get("log_figures", False) self.log_interval: int = logging_cfg.get("log_interval", 1) def to_dict(self) -> dict: config_dict = { "global": { "device": self.device, "mode": self.mode, "formulation": self.formulation, }, "model": { "name": self.model_name, "pretrained": self.pretrained, "weights": self.weights, "config": self.model_config_file, }, "validation": { "batch_size": self.val_batch_size, "steps_per_epoch": self.val_steps_per_epoch, "loss": self.val_loss, }, "dataset": { "name": self.dataset_name, "image_size": self.image_size, "crop_size": self.crop_size, "num_workers": self.num_workers, "masking_ratio": self.masking_ratio, }, "logging": { "root": self.log_root, "exp_name": self.exp_name, "result_columns": self.result_columns, "save_weights": self.save_weights, "save_only_best_weights": self.save_only_best_weights, "enable_tensorboard": self.enable_tensorboard, "log_figures": self.log_figures, "log_interval": self.log_interval, }, } return config_dict def save_config(self, output_fp: str) -> None: """ Save the current configuration to a YAML file, preserving the original format. """ config_dict = { "global": { "device": self.device, "mode": self.mode, "formulation": self.formulation, }, "model": { "name": self.model_name, "pretrained": self.pretrained, "weights": self.weights, "config": self.model_config_file, }, "validation": { "batch_size": self.val_batch_size, "steps_per_epoch": self.val_steps_per_epoch, "loss": self.val_loss, }, "dataset": { "name": self.dataset_name, "image_size": self.image_size, "crop_size": self.crop_size, "num_workers": self.num_workers, "masking_ratio": self.masking_ratio, }, "logging": { "root": self.log_root, "exp_name": self.exp_name, "result_columns": self.result_columns, "save_weights": self.save_weights, "save_only_best_weights": self.save_only_best_weights, "enable_tensorboard": self.enable_tensorboard, "log_figures": self.log_figures, "log_interval": self.log_interval, }, } # save config with open(output_fp, "w") as f: yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) class SurrogateEvalConfig: """ Object representing a config file for an evaluation run of an older-surrogate model pair. """ def __init__(self, config_fp: str): if not os.path.isfile(config_fp): raise FileNotFoundError(f"Config file not found: {config_fp}") config_dict: dict = parse_config(config_fp) # --- Global settings --- global_cfg: dict = config_dict.get("global", {}) self.device: int = global_cfg.get("device", 0) self.mode: str = global_cfg.get("mode", "train") self.formulation: Optional[str] = global_cfg.get("formulation", None) # --- Denoising model settings --- denoising_model_cfg: dict = config_dict.get("denoising_model", {}) self.denoising_model_name: str = denoising_model_cfg.get("name", "") self.denoising_model_pretrained: bool = denoising_model_cfg.get( "pretrained", False ) self.denoising_model_weights: Optional[str] = denoising_model_cfg.get( "weights", None ) self.denoising_model_config_file: Optional[str] = denoising_model_cfg.get( "config", None ) # --- Older surrogate model settings --- older_surrogate_model_cfg: dict = config_dict.get("older_surrogate_model", {}) self.older_surrogate_model_name: str = older_surrogate_model_cfg.get("name", "") self.older_surrogate_model_pretrained: bool = older_surrogate_model_cfg.get( "pretrained", False ) self.older_surrogate_model_weights: Optional[str] = ( older_surrogate_model_cfg.get("weights", None) ) self.older_surrogate_model_config_file: Optional[str] = ( older_surrogate_model_cfg.get("config", None) ) # --- Training settings --- training_cfg: dict = config_dict.get("training", {}) self.train_batch_size: int = training_cfg.get("batch_size", 1) self.train_steps_per_epoch: int = training_cfg.get("steps_per_epoch", 1024) self.epochs: int = training_cfg.get("epochs", 100) self.train_loss: Optional[str] = training_cfg.get("loss", None) self.lr: float = training_cfg.get("lr", 1e-4) self.optimizer: Optional[str] = training_cfg.get("optimizer", "Adam") # --- Validation settings --- validation_cfg: dict = config_dict.get("validation", {}) self.val_batch_size: int = validation_cfg.get("batch_size", 1) self.val_steps_per_epoch: int = validation_cfg.get("steps_per_epoch", 256) self.val_loss: Optional[str] = validation_cfg.get("loss", None) # --- Dataset settings --- dataset_cfg: dict = config_dict.get("dataset", {}) self.dataset_name: str = dataset_cfg.get("name", "") self.image_size: Optional[int] = dataset_cfg.get("image_size", None) self.crop_size: Optional[int] = dataset_cfg.get("crop_size", None) self.num_workers: int = dataset_cfg.get("num_workers", 0) self.masking_ratio: int = dataset_cfg.get("masking_ratio", 1) # --- Logging settings --- logging_cfg: dict = config_dict.get("logging", {}) self.log_root: str = logging_cfg.get("root", "") self.exp_name: str = logging_cfg.get("exp_name", "") self.result_columns: List = logging_cfg.get("result_columns", []) self.save_weights: bool = logging_cfg.get("save_weights", True) self.save_only_best_weights: bool = logging_cfg.get( "save_only_best_weights", True ) self.enable_tensorboard: bool = logging_cfg.get("enable_tensorboard", False) self.log_figures: bool = logging_cfg.get("log_figures", False) self.log_interval: int = logging_cfg.get("log_interval", 1) def to_dict(self) -> dict: """ Return the configuration as a dictionary matching the YAML file structure. """ config_dict = { "global": { "device": self.device, "mode": self.mode, "formulation": self.formulation, }, "denoising_model": { "name": self.denoising_model_name, "pretrained": self.denoising_model_pretrained, "weights": self.denoising_model_weights, "config": self.denoising_model_config_file, }, "older_surrogate_model": { "name": self.older_surrogate_model_name, "pretrained": self.older_surrogate_model_pretrained, "weights": self.older_surrogate_model_weights, "config": self.older_surrogate_model_config_file, }, "training": { "batch_size": self.train_batch_size, "steps_per_epoch": self.train_steps_per_epoch, "epochs": self.epochs, "loss": self.train_loss, "lr": self.lr, "optimizer": self.optimizer, }, "validation": { "batch_size": self.val_batch_size, "steps_per_epoch": self.val_steps_per_epoch, "loss": self.val_loss, }, "dataset": { "name": self.dataset_name, "image_size": self.image_size, "crop_size": self.crop_size, "num_workers": self.num_workers, "masking_ratio": self.masking_ratio, }, "logging": { "root": self.log_root, "exp_name": self.exp_name, "result_columns": self.result_columns, "save_weights": self.save_weights, "save_only_best_weights": self.save_only_best_weights, "enable_tensorboard": self.enable_tensorboard, "log_figures": self.log_figures, "log_interval": self.log_interval, }, } return config_dict def save_config(self, output_fp: str) -> None: """ Save the current configuration to a YAML file, preserving the original structure. """ config_dict = self.to_dict() with open(output_fp, "w") as f: yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) class ModelConfig: """ Object representing a config file for a SwinIR model. """ def __init__(self, config_fp: str): if not os.path.isfile(config_fp): raise FileNotFoundError(f"Config file not found: {config_fp}") config_dict: dict = parse_config(config_fp) # --- Top-level setting: weights_fp --- self.weights_fp: str = config_dict.get("weights_fp", "") # --- Hyperparameters --- hyperparams: dict = config_dict.get("hyperparams", {}) self.upscale: int = hyperparams.get("upscale", 8) self.img_size: List[int] = hyperparams.get("img_size", [128, 128]) self.window_size: int = hyperparams.get("window_size", 8) self.img_range: float = hyperparams.get("img_range", 1.0) self.depths: List[int] = hyperparams.get("depths", [8, 8, 8, 8, 8, 8]) self.embed_dim: int = hyperparams.get("embed_dim", 180) self.num_heads: List[int] = hyperparams.get("num_heads", [6, 6, 6, 6, 6, 6]) self.mlp_ratio: int = hyperparams.get("mlp_ratio", 2) self.upsampler: str = hyperparams.get("upsampler", "no_upscale") self.resi_connection: str = hyperparams.get("resi_connection", "1conv") self.drop_path_rate = ( config_dict.get("hyperparams", {}).get("drop_path_rate", 0.1), ) # layer norm self.layer_norm_str = config_dict.get("hyperparams", {}).get("norm_layer", None) layer_norm = ( torch.nn.LayerNorm if self.layer_norm_str == "torch.nn.LayerNorm" else None ) self.norm_layer = layer_norm def to_dict(self) -> dict: config_dict = { "weights_fp": self.weights_fp, "hyperparams": { "upscale": self.upscale, "img_size": self.img_size, "window_size": self.window_size, "img_range": self.img_range, "depths": self.depths, "embed_dim": self.embed_dim, "num_heads": self.num_heads, "mlp_ratio": self.mlp_ratio, "drop_path_rate": self.drop_path_rate, "norm_layer": self.layer_norm_str, "upsampler": self.upsampler, "resi_connection": self.resi_connection, }, } return config_dict def save_config(self, output_fp: str) -> None: """ Save the current configuration to a YAML file, preserving the original format. """ config_dict = { "weights_fp": self.weights_fp, "hyperparams": { "upscale": self.upscale, "img_size": self.img_size, "window_size": self.window_size, "img_range": self.img_range, "depths": self.depths, "embed_dim": self.embed_dim, "num_heads": self.num_heads, "mlp_ratio": self.mlp_ratio, "upsampler": self.upsampler, "resi_connection": self.resi_connection, }, } with open(output_fp, "w") as f: yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False)