# NOTE(review): removed page-scrape artifacts ("Spaces:" / "Sleeping") that
# preceded the module and are not valid Python.
import os
from typing import List, Optional

import torch
import torch.nn as nn
import yaml
from torchvision.models import (
    EfficientNet_V2_L_Weights,
    ResNet152_Weights,
    Swin_B_Weights,
    ViT_L_16_Weights,
    efficientnet_v2_l,
    resnet152,
    swin_b,
    vit_l_16,
)

from src.models.autoencoder import Autoencoder
from src.models.classic_recon import (
    AMPInpainter,
    BicubicInterpolationInpainter,
    LinearInterpolationInpainter,
    NearestNeighborsInpainter,
)
from src.models.our_method.swin_cafm import SwinCAFM
from src.models.prev_methods.gpstruct import GPSTRUCT
from src.models.prev_methods.sstem import SSTEM
from src.models.simple_z_predictor import SimpleZRegressionVisionTransformer
from src.models.unet.unet import ThickUNet, UNet
from src.models.unetr.unetr import UNETR
from src.models.vae import VAE
from src.util.loss import DiceLoss, FocalLoss, ImageInpaintingL1Loss, VAELoss
def parse_config(fp: str) -> dict:
    r"""Load a YAML configuration file into a dictionary.

    Args
        :param fp: path to config file
    Returns
        :return: dict with the parsed configuration
    Raises
        :raise FileNotFoundError: if ``fp`` does not point to an existing file
    """
    # FIX: raise explicitly instead of using `assert`, which is silently
    # stripped when Python runs with optimizations (-O).
    if not os.path.isfile(fp):
        raise FileNotFoundError(f"Error: config file @ {fp} does not exist")
    with open(fp, "r") as f:
        config = yaml.safe_load(f)
    return config
# Registry mapping loss names (as they appear in config files) to loss
# classes. Values are classes, not instances — callers instantiate them.
LOSS_FUNCTIONS = {
    "MSE": nn.MSELoss,
    "L1": nn.L1Loss,
    "CrossEntropy": nn.CrossEntropyLoss,
    "SmoothL1": nn.SmoothL1Loss,
    "Dice": DiceLoss,
    "Focal": FocalLoss,
    "Huber": nn.HuberLoss,
    "VAE": VAELoss,
    "InpaintingL1": ImageInpaintingL1Loss,
}
# Registry mapping model names (as used in config files) to:
#   "fn":      a factory callable that builds the model, and
#   "weights": the torchvision pretrained-weights enum to load, or None for
#              project-local models (which presumably handle their own
#              weight loading — confirm against each model's `get`).
MODELS = {
    "ours": {
        "fn": SwinCAFM.get,
        "weights": None,
    },
    "simple_z_reg_vit": {
        "fn": SimpleZRegressionVisionTransformer.get,
        "weights": None,
    },
    "resnet152": {
        "fn": resnet152,
        "weights": ResNet152_Weights.IMAGENET1K_V2,
    },
    "swin_b": {
        "fn": swin_b,
        "weights": Swin_B_Weights.IMAGENET1K_V1,
    },
    "efficientnet_v2_l": {
        "fn": efficientnet_v2_l,
        "weights": EfficientNet_V2_L_Weights.IMAGENET1K_V1,
    },
    "vit_l_16": {
        "fn": vit_l_16,
        "weights": ViT_L_16_Weights.IMAGENET1K_V1,
    },
    "ae": {"fn": Autoencoder.get, "weights": None},
    "vae": {"fn": VAE.get, "weights": None},
    "unet": {"fn": UNet.get, "weights": None},
    "thick_unet": {"fn": ThickUNet.get, "weights": None},
    "unetr": {"fn": UNETR.get, "weights": None},
    "linear_interpolation": {"fn": LinearInterpolationInpainter.get, "weights": None},
    "bicubic_interpolation": {"fn": BicubicInterpolationInpainter.get, "weights": None},
    "amp_interpolation": {"fn": AMPInpainter.get, "weights": None},
    "nn_interpolation": {"fn": NearestNeighborsInpainter.get, "weights": None},
    "sstem_interpolation": {"fn": SSTEM.get, "weights": None},
    "gpstruct_interpolation": {"fn": GPSTRUCT.get, "weights": None},
}
# Registry mapping optimizer names (as used in config files) to torch
# optimizer classes; callers instantiate with model parameters and lr.
OPTIMIZERS = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
}
class TrainConfig:
    """
    Object representing a config file for a training run of a model.
    """

    def __init__(self, config_fp: str):
        """Load a training-run YAML config and flatten it onto attributes.

        Args:
            config_fp: Path to the YAML config file.

        Raises:
            FileNotFoundError: If ``config_fp`` does not exist.
        """
        if not os.path.isfile(config_fp):
            raise FileNotFoundError(f"Config file not found: {config_fp}")
        config_dict: dict = parse_config(config_fp)
        # --- Global settings ---
        global_cfg: dict = config_dict.get("global", {})
        self.device: int = global_cfg.get("device", 0)
        self.mode: str = global_cfg.get("mode", "train")
        self.formulation: Optional[str] = global_cfg.get("formulation", None)
        # --- Model settings ---
        model_cfg: dict = config_dict.get("model", {})
        self.model_name: str = model_cfg.get("name", "")
        # FIX: annotation was `str` although the default (and typical YAML
        # value) is a bool.
        self.pretrained: bool = model_cfg.get("pretrained", False)
        self.weights: Optional[str] = model_cfg.get("weights", None)
        self.model_config_file: Optional[str] = model_cfg.get("config", None)
        # --- Surrogate settings ---
        surrogate_cfg: dict = config_dict.get("surrogate", {})
        self.surrogate_name: str = surrogate_cfg.get("name", "")
        # NOTE(review): attribute name keeps the historical "surgate" typo
        # because to_dict() and external callers may depend on it.
        self.surgate_weights: Optional[str] = surrogate_cfg.get("weights", None)
        # --- Training settings ---
        # FIX: annotation was `str`; this is the parsed "training" sub-dict.
        training_cfg: dict = config_dict.get("training", {})
        self.train_batch_size: int = training_cfg.get("batch_size", 1)
        self.steps_per_epoch: int = training_cfg.get("steps_per_epoch", 1024)
        self.epochs: int = training_cfg.get("epochs", 2000)
        self.train_loss: Optional[str] = training_cfg.get("loss", None)
        # FIX: annotation was `str`; the learning rate is numeric.
        self.learning_rate: float = training_cfg.get("lr", 1e-4)
        self.optimizer: str = training_cfg.get("optimizer", "Adam")
        # --- Validation settings ---
        validation_cfg: dict = config_dict.get("validation", {})
        self.val_batch_size: int = validation_cfg.get("batch_size", 1)
        self.val_steps_per_epoch: int = validation_cfg.get("steps_per_epoch", 256)
        self.val_loss: Optional[str] = validation_cfg.get("loss", None)
        # --- Dataset settings ---
        dataset_cfg: dict = config_dict.get("dataset", {})
        self.dataset_name: str = dataset_cfg.get("name", "")
        # Sizes are presumably ints (SurrogateEvalConfig annotates them so)
        # — TODO confirm against the dataset loaders.
        self.image_size: Optional[int] = dataset_cfg.get("image_size", None)
        self.crop_size: Optional[int] = dataset_cfg.get("crop_size", None)
        self.num_workers: int = dataset_cfg.get("num_workers", 0)
        self.masking_ratio: int = dataset_cfg.get("masking_ratio", 1)
        # --- Logging settings ---
        logging_cfg: dict = config_dict.get("logging", {})
        self.log_root: str = logging_cfg.get("root", "")
        self.exp_name: str = logging_cfg.get("exp_name", "")
        self.result_columns: List = logging_cfg.get("result_columns", [])
        self.save_weights: bool = logging_cfg.get("save_weights", True)
        self.save_only_best_weights: bool = logging_cfg.get(
            "save_only_best_weights", True
        )
        self.enable_tensorboard: bool = logging_cfg.get("enable_tensorboard", False)
        self.log_figures: bool = logging_cfg.get("log_figures", False)
        self.log_interval: int = logging_cfg.get("log_interval", 1)

    def to_dict(self) -> dict:
        """Return the configuration as a dict mirroring the YAML layout."""
        config_dict = {
            "global": {
                "device": self.device,
                "mode": self.mode,
                "formulation": self.formulation,
            },
            "model": {
                "name": self.model_name,
                "pretrained": self.pretrained,
                "weights": self.weights,
                "config": self.model_config_file,
            },
            "surrogate": {
                "name": self.surrogate_name,
                "weights": self.surgate_weights,
            },
            "training": {
                "batch_size": self.train_batch_size,
                "steps_per_epoch": self.steps_per_epoch,
                "epochs": self.epochs,
                "loss": self.train_loss,
                "lr": self.learning_rate,
                "optimizer": self.optimizer,
            },
            "validation": {
                "batch_size": self.val_batch_size,
                "steps_per_epoch": self.val_steps_per_epoch,
                "loss": self.val_loss,
            },
            "dataset": {
                "name": self.dataset_name,
                "image_size": self.image_size,
                "crop_size": self.crop_size,
                "num_workers": self.num_workers,
                "masking_ratio": self.masking_ratio,
            },
            "logging": {
                "root": self.log_root,
                "exp_name": self.exp_name,
                "result_columns": self.result_columns,
                "save_weights": self.save_weights,
                "save_only_best_weights": self.save_only_best_weights,
                "enable_tensorboard": self.enable_tensorboard,
                "log_figures": self.log_figures,
                "log_interval": self.log_interval,
            },
        }
        return config_dict

    def save_config(self, output_fp: str) -> None:
        """
        Save the current configuration to a YAML file, preserving the original format.

        Args:
            output_fp: Destination path for the YAML file.
        """
        # FIX: delegate to to_dict() instead of duplicating the whole mapping
        # (the original kept two hand-maintained copies that could drift),
        # matching how SurrogateEvalConfig.save_config already works.
        with open(output_fp, "w") as f:
            yaml.safe_dump(
                self.to_dict(), f, default_flow_style=False, sort_keys=False
            )
class EvalConfig:
    """
    Object representing a config file for an evaluation run of a model.
    """

    def __init__(self, config_fp: str):
        """Load an evaluation-run YAML config and flatten it onto attributes.

        Args:
            config_fp: Path to the YAML config file.

        Raises:
            FileNotFoundError: If ``config_fp`` does not exist.
        """
        if not os.path.isfile(config_fp):
            raise FileNotFoundError(f"Config file not found: {config_fp}")
        config_dict: dict = parse_config(config_fp)
        # --- Global settings ---
        global_cfg: dict = config_dict.get("global", {})
        self.device: int = global_cfg.get("device", 0)
        # NOTE(review): default mode is "train" even for eval configs —
        # preserved for backward compatibility; confirm this is intended.
        self.mode: str = global_cfg.get("mode", "train")
        self.formulation: Optional[str] = global_cfg.get("formulation", None)
        # --- Model settings ---
        model_cfg: dict = config_dict.get("model", {})
        self.model_name: str = model_cfg.get("name", "")
        # FIX: annotation was `str` although the default is a bool.
        self.pretrained: bool = model_cfg.get("pretrained", False)
        self.weights: Optional[str] = model_cfg.get("weights", None)
        self.model_config_file: Optional[str] = model_cfg.get("config", None)
        # --- Validation settings ---
        validation_cfg: dict = config_dict.get("validation", {})
        self.val_batch_size: int = validation_cfg.get("batch_size", 1)
        self.val_steps_per_epoch: int = validation_cfg.get("steps_per_epoch", 256)
        self.val_loss: Optional[str] = validation_cfg.get("loss", None)
        # --- Dataset settings ---
        dataset_cfg: dict = config_dict.get("dataset", {})
        self.dataset_name: str = dataset_cfg.get("name", "")
        # Sizes are presumably ints (SurrogateEvalConfig annotates them so)
        # — TODO confirm against the dataset loaders.
        self.image_size: Optional[int] = dataset_cfg.get("image_size", None)
        self.crop_size: Optional[int] = dataset_cfg.get("crop_size", None)
        self.num_workers: int = dataset_cfg.get("num_workers", 0)
        self.masking_ratio: int = dataset_cfg.get("masking_ratio", 1)
        # --- Logging settings ---
        logging_cfg: dict = config_dict.get("logging", {})
        self.log_root: str = logging_cfg.get("root", "")
        self.exp_name: str = logging_cfg.get("exp_name", "")
        self.result_columns: List = logging_cfg.get("result_columns", [])
        self.save_weights: bool = logging_cfg.get("save_weights", True)
        self.save_only_best_weights: bool = logging_cfg.get(
            "save_only_best_weights", True
        )
        self.enable_tensorboard: bool = logging_cfg.get("enable_tensorboard", False)
        self.log_figures: bool = logging_cfg.get("log_figures", False)
        self.log_interval: int = logging_cfg.get("log_interval", 1)

    def to_dict(self) -> dict:
        """Return the configuration as a dict mirroring the YAML layout."""
        config_dict = {
            "global": {
                "device": self.device,
                "mode": self.mode,
                "formulation": self.formulation,
            },
            "model": {
                "name": self.model_name,
                "pretrained": self.pretrained,
                "weights": self.weights,
                "config": self.model_config_file,
            },
            "validation": {
                "batch_size": self.val_batch_size,
                "steps_per_epoch": self.val_steps_per_epoch,
                "loss": self.val_loss,
            },
            "dataset": {
                "name": self.dataset_name,
                "image_size": self.image_size,
                "crop_size": self.crop_size,
                "num_workers": self.num_workers,
                "masking_ratio": self.masking_ratio,
            },
            "logging": {
                "root": self.log_root,
                "exp_name": self.exp_name,
                "result_columns": self.result_columns,
                "save_weights": self.save_weights,
                "save_only_best_weights": self.save_only_best_weights,
                "enable_tensorboard": self.enable_tensorboard,
                "log_figures": self.log_figures,
                "log_interval": self.log_interval,
            },
        }
        return config_dict

    def save_config(self, output_fp: str) -> None:
        """
        Save the current configuration to a YAML file, preserving the original format.

        Args:
            output_fp: Destination path for the YAML file.
        """
        # FIX: delegate to to_dict() instead of duplicating the whole mapping
        # (the original kept two hand-maintained copies that could drift),
        # matching how SurrogateEvalConfig.save_config already works.
        with open(output_fp, "w") as f:
            yaml.safe_dump(
                self.to_dict(), f, default_flow_style=False, sort_keys=False
            )
class SurrogateEvalConfig:
    """
    Object representing a config file for an evaluation run of an older-surrogate model pair.
    """

    def __init__(self, config_fp: str):
        """Load the YAML config at ``config_fp`` and flatten it onto attributes.

        Raises:
            FileNotFoundError: If ``config_fp`` does not exist.
        """
        # Fail fast on a missing file before attempting to parse it.
        if not os.path.isfile(config_fp):
            raise FileNotFoundError(f"Config file not found: {config_fp}")
        cfg: dict = parse_config(config_fp)

        # Global section.
        section = cfg.get("global", {})
        self.device: int = section.get("device", 0)
        self.mode: str = section.get("mode", "train")
        self.formulation: Optional[str] = section.get("formulation", None)

        # Denoising-model section.
        section = cfg.get("denoising_model", {})
        self.denoising_model_name: str = section.get("name", "")
        self.denoising_model_pretrained: bool = section.get("pretrained", False)
        self.denoising_model_weights: Optional[str] = section.get("weights", None)
        self.denoising_model_config_file: Optional[str] = section.get("config", None)

        # Older-surrogate-model section.
        section = cfg.get("older_surrogate_model", {})
        self.older_surrogate_model_name: str = section.get("name", "")
        self.older_surrogate_model_pretrained: bool = section.get("pretrained", False)
        self.older_surrogate_model_weights: Optional[str] = section.get("weights", None)
        self.older_surrogate_model_config_file: Optional[str] = section.get(
            "config", None
        )

        # Training section.
        section = cfg.get("training", {})
        self.train_batch_size: int = section.get("batch_size", 1)
        self.train_steps_per_epoch: int = section.get("steps_per_epoch", 1024)
        self.epochs: int = section.get("epochs", 100)
        self.train_loss: Optional[str] = section.get("loss", None)
        self.lr: float = section.get("lr", 1e-4)
        self.optimizer: Optional[str] = section.get("optimizer", "Adam")

        # Validation section.
        section = cfg.get("validation", {})
        self.val_batch_size: int = section.get("batch_size", 1)
        self.val_steps_per_epoch: int = section.get("steps_per_epoch", 256)
        self.val_loss: Optional[str] = section.get("loss", None)

        # Dataset section.
        section = cfg.get("dataset", {})
        self.dataset_name: str = section.get("name", "")
        self.image_size: Optional[int] = section.get("image_size", None)
        self.crop_size: Optional[int] = section.get("crop_size", None)
        self.num_workers: int = section.get("num_workers", 0)
        self.masking_ratio: int = section.get("masking_ratio", 1)

        # Logging section.
        section = cfg.get("logging", {})
        self.log_root: str = section.get("root", "")
        self.exp_name: str = section.get("exp_name", "")
        self.result_columns: List = section.get("result_columns", [])
        self.save_weights: bool = section.get("save_weights", True)
        self.save_only_best_weights: bool = section.get(
            "save_only_best_weights", True
        )
        self.enable_tensorboard: bool = section.get("enable_tensorboard", False)
        self.log_figures: bool = section.get("log_figures", False)
        self.log_interval: int = section.get("log_interval", 1)

    def to_dict(self) -> dict:
        """
        Return the configuration as a dictionary matching the YAML file structure.
        """
        # Build each section separately, then assemble in YAML order.
        global_section = {
            "device": self.device,
            "mode": self.mode,
            "formulation": self.formulation,
        }
        denoising_section = {
            "name": self.denoising_model_name,
            "pretrained": self.denoising_model_pretrained,
            "weights": self.denoising_model_weights,
            "config": self.denoising_model_config_file,
        }
        older_surrogate_section = {
            "name": self.older_surrogate_model_name,
            "pretrained": self.older_surrogate_model_pretrained,
            "weights": self.older_surrogate_model_weights,
            "config": self.older_surrogate_model_config_file,
        }
        training_section = {
            "batch_size": self.train_batch_size,
            "steps_per_epoch": self.train_steps_per_epoch,
            "epochs": self.epochs,
            "loss": self.train_loss,
            "lr": self.lr,
            "optimizer": self.optimizer,
        }
        validation_section = {
            "batch_size": self.val_batch_size,
            "steps_per_epoch": self.val_steps_per_epoch,
            "loss": self.val_loss,
        }
        dataset_section = {
            "name": self.dataset_name,
            "image_size": self.image_size,
            "crop_size": self.crop_size,
            "num_workers": self.num_workers,
            "masking_ratio": self.masking_ratio,
        }
        logging_section = {
            "root": self.log_root,
            "exp_name": self.exp_name,
            "result_columns": self.result_columns,
            "save_weights": self.save_weights,
            "save_only_best_weights": self.save_only_best_weights,
            "enable_tensorboard": self.enable_tensorboard,
            "log_figures": self.log_figures,
            "log_interval": self.log_interval,
        }
        return {
            "global": global_section,
            "denoising_model": denoising_section,
            "older_surrogate_model": older_surrogate_section,
            "training": training_section,
            "validation": validation_section,
            "dataset": dataset_section,
            "logging": logging_section,
        }

    def save_config(self, output_fp: str) -> None:
        """
        Save the current configuration to a YAML file, preserving the original structure.
        """
        with open(output_fp, "w") as f:
            yaml.safe_dump(
                self.to_dict(), f, default_flow_style=False, sort_keys=False
            )
class ModelConfig:
    """
    Object representing a config file for a SwinIR model.
    """

    def __init__(self, config_fp: str):
        """Load a SwinIR model YAML config and flatten it onto attributes.

        Args:
            config_fp: Path to the YAML config file.

        Raises:
            FileNotFoundError: If ``config_fp`` does not exist.
        """
        if not os.path.isfile(config_fp):
            raise FileNotFoundError(f"Config file not found: {config_fp}")
        config_dict: dict = parse_config(config_fp)
        # --- Top-level setting: weights_fp ---
        self.weights_fp: str = config_dict.get("weights_fp", "")
        # --- Hyperparameters ---
        hyperparams: dict = config_dict.get("hyperparams", {})
        self.upscale: int = hyperparams.get("upscale", 8)
        self.img_size: List[int] = hyperparams.get("img_size", [128, 128])
        self.window_size: int = hyperparams.get("window_size", 8)
        self.img_range: float = hyperparams.get("img_range", 1.0)
        self.depths: List[int] = hyperparams.get("depths", [8, 8, 8, 8, 8, 8])
        self.embed_dim: int = hyperparams.get("embed_dim", 180)
        self.num_heads: List[int] = hyperparams.get("num_heads", [6, 6, 6, 6, 6, 6])
        self.mlp_ratio: int = hyperparams.get("mlp_ratio", 2)
        self.upsampler: str = hyperparams.get("upsampler", "no_upscale")
        self.resi_connection: str = hyperparams.get("resi_connection", "1conv")
        # BUG FIX: the original wrapped this value in a 1-tuple via a stray
        # trailing comma, so consumers received (0.1,) instead of a float
        # (and yaml.safe_dump cannot serialize a tuple).
        self.drop_path_rate: float = hyperparams.get("drop_path_rate", 0.1)
        # Resolve the norm layer from its string name; only
        # "torch.nn.LayerNorm" is recognized, anything else maps to None.
        self.layer_norm_str: Optional[str] = hyperparams.get("norm_layer", None)
        self.norm_layer = (
            torch.nn.LayerNorm if self.layer_norm_str == "torch.nn.LayerNorm" else None
        )

    def to_dict(self) -> dict:
        """Return the configuration as a dict mirroring the YAML layout.

        ``norm_layer`` is serialized via its original string name
        (``layer_norm_str``), not the resolved class.
        """
        config_dict = {
            "weights_fp": self.weights_fp,
            "hyperparams": {
                "upscale": self.upscale,
                "img_size": self.img_size,
                "window_size": self.window_size,
                "img_range": self.img_range,
                "depths": self.depths,
                "embed_dim": self.embed_dim,
                "num_heads": self.num_heads,
                "mlp_ratio": self.mlp_ratio,
                "drop_path_rate": self.drop_path_rate,
                "norm_layer": self.layer_norm_str,
                "upsampler": self.upsampler,
                "resi_connection": self.resi_connection,
            },
        }
        return config_dict

    def save_config(self, output_fp: str) -> None:
        """
        Save the current configuration to a YAML file, preserving the original format.

        Args:
            output_fp: Destination path for the YAML file.
        """
        # FIX: delegate to to_dict() — the original duplicated the mapping and
        # silently dropped "drop_path_rate" and "norm_layer", so a saved
        # config could not round-trip through __init__.
        with open(output_fp, "w") as f:
            yaml.safe_dump(
                self.to_dict(), f, default_flow_style=False, sort_keys=False
            )