"""
Training script for `sparse-bto` models.
"""
import argparse
import importlib
import warnings
from pathlib import Path
from typing import Any
from typing import Optional

import torch
import wandb
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.datasets.mos2_sr import BTOSRDataset, BTO_MANY_RES
from src.models.our_method.swin_cafm import SwinCAFM
from src.util.config import LOSS_FUNCTIONS, OPTIMIZERS, MODELS
from src.util.logger import Logger
from src.util.metrics import PSNR, SSIM
from src.util.moment_based import cal_moment_based_errs
# Blanket-suppress all warnings (torch/tqdm tend to be noisy).
# NOTE(review): this also hides deprecation warnings — consider narrowing
# the filter to specific categories.
warnings.simplefilter("ignore")
| def _init_module_from_target(mod_config: dict, *, additional_args: dict={}) -> Any: | |
| """ | |
| Init a module from a module config dict, | |
| expect keywords `target` and `args`. | |
| """ | |
| mod_path, cls_name = mod_config["target"].rsplit(".", 1) | |
| module = importlib.import_module(mod_path) | |
| cls = getattr(module, cls_name) | |
| args: dict = mod_config.get("args", {}) | |
| args.update(additional_args) | |
| return cls(**args) | |
def train(config: dict) -> None:
    """Run the full training/validation loop described by `config`.

    Expected top-level config keys (sub-configs follow the
    `{"target": ..., "args": ...}` convention of `_init_module_from_target`):
      - logger:      project logger config
      - wandb:       {"use_wandb": bool, "login": ..., "init": ...}
      - train_args:  {"dataset", "batch_size", "loss", "optimizer", "num_epochs"}
      - val_args:    {"dataset", "batch_size", "loss"}
      - model:       model config

    Side effects: moves the model to CUDA, logs metrics (and qualitative
    figures every 100 steps) to the logger and optionally to wandb, and
    saves weights via `logger.save_weights` whenever the average
    validation loss improves.
    """
    logger = _init_module_from_target(config["logger"])
    use_wandb = bool(config["wandb"]["use_wandb"])
    # optional wandb run initialization
    if use_wandb:
        _init_module_from_target(config["wandb"]["login"])
        _init_module_from_target(config["wandb"]["init"])
    # ---- init datasets / dataloaders ----
    train_dataset = _init_module_from_target(config["train_args"]["dataset"])
    val_dataset = _init_module_from_target(config["val_args"]["dataset"])
    # NOTE(review): shuffle=False on the *training* loader — presumably the
    # dataset samples randomly itself (`steps_per_epoch` suggests a sampled
    # dataset); confirm, otherwise training order is fixed every epoch.
    train_dataloader = DataLoader(train_dataset, batch_size=int(config["train_args"]["batch_size"]), shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=int(config["val_args"]["batch_size"]), shuffle=False)
    # ---- init losses, model, optimizer ----
    train_loss = _init_module_from_target(config["train_args"]["loss"])
    val_loss = _init_module_from_target(config["val_args"]["loss"])
    model: torch.nn.Module = _init_module_from_target(config["model"])
    model.float().cuda()
    optimizer: torch.optim.Optimizer = _init_module_from_target(
        config["train_args"]["optimizer"],
        additional_args={"params": model.parameters()},
    )
    # hoisted config scalars (were re-parsed from the config every iteration)
    num_epochs = int(config["train_args"]["num_epochs"])
    train_steps = int(config["train_args"]["dataset"]["args"]["steps_per_epoch"])
    val_steps = int(config["val_args"]["dataset"]["args"]["steps_per_epoch"])
    # best average validation loss so far, for weight checkpointing
    best_validation_loss = float("inf")
    # main training loop
    for epoch in range(num_epochs):
        # ---- train ----
        model.train()
        for step, item in tqdm(enumerate(train_dataloader), desc=f"🚀 Training Epoch: {epoch + 1}/{num_epochs}", total=train_steps):
            X: torch.Tensor = item["X"].float().cuda()
            X_sparse: torch.Tensor = item["X_sparse"].float().cuda()
            # zero gradients
            optimizer.zero_grad()
            # ---- forward: p(y | y_sparse) ----
            X_hat: torch.Tensor = model(X_sparse)
            loss: torch.Tensor = train_loss(X_hat, X)
            loss.backward()
            optimizer.step()
            # ---- log ----
            # moment-based errors on the raw (pre-clip) prediction
            mb_errs = cal_moment_based_errs(X_hat, X)
            train_mb_errs = {"train_" + k: v for k, v in mb_errs.items()}
            # clip to [0, 1] and repeat to 3 channels for PSNR/SSIM
            X = torch.clip(X, 0, 1)
            X_hat = torch.clip(X_hat, 0, 1)
            X_il: torch.Tensor = X.unsqueeze(1).repeat(1, 3, 1, 1)
            X_hat_il: torch.Tensor = X_hat.unsqueeze(1).repeat(1, 3, 1, 1)
            psnr = PSNR(X_il, X_hat_il, (0, 1))
            ssim = SSIM(X_il, X_hat_il, (0, 1))
            logger.log(
                global_train_step=len(train_dataloader) * epoch + step,
                global_val_step=None,
                epoch=epoch,
                train_loss=loss.item(),
                val_loss=None,
            )
            if use_wandb:
                log = {
                    "epoch": epoch,
                    "train_l1_loss": loss.item(),
                    "train_psnr": psnr,
                    "train_ssim": ssim,
                }
                log.update(train_mb_errs)
                wandb.log(log)
            # log qualitative figures every 100 steps
            if step % 100 == 0 and use_wandb and isinstance(logger, Logger):
                triplet_name = f"train_epoch_{epoch}_step_{step}.png"
                fig = logger.log_colorized_tensors(
                    (X, "Target (X)"),
                    (X_sparse, "Model Input (X_sparse)"),
                    (X_hat, "Model Prediction"),
                    file_name=triplet_name,
                )
                wandb.log({"Train Qualitative Results": wandb.Image(fig)})
        # ---- validate ----
        model.eval()
        running_val_loss = 0.0
        seen_val_batches = 0
        with torch.no_grad():
            for step, item in tqdm(enumerate(val_dataloader), desc=f"🚀 Validation Epoch: {epoch + 1}/{num_epochs}", total=val_steps):
                X: torch.Tensor = item["X"].float().cuda()
                X_sparse: torch.Tensor = item["X_sparse"].float().cuda()
                # ---- forward: p(y | y_sparse) ----
                X_hat: torch.Tensor = model(X_sparse)
                loss = val_loss(X_hat, X)
                # accumulate *every* batch — the original `continue` used for
                # the every-100-steps figure logging came before the
                # accumulation, so most batches were dropped from the average
                running_val_loss += loss.item()
                seen_val_batches += 1
                # moment-based errors on the raw (pre-clip) prediction
                mb_errs = cal_moment_based_errs(X_hat, X)
                val_mb_errs = {"val_" + k: v for k, v in mb_errs.items()}
                X = torch.clip(X, 0, 1)
                X_hat = torch.clip(X_hat, 0, 1)
                # ---- add dummy dims for PSNR/SSIM ----
                X_il: torch.Tensor = X.unsqueeze(1).repeat(1, 3, 1, 1)
                X_hat_il: torch.Tensor = X_hat.unsqueeze(1).repeat(1, 3, 1, 1)
                psnr = PSNR(X_il, X_hat_il, (0, 1))
                ssim = SSIM(X_il, X_hat_il, (0, 1))
                logger.log(
                    global_train_step=None,
                    global_val_step=len(val_dataloader) * epoch + step,
                    epoch=epoch,
                    train_loss=None,
                    val_loss=loss.item(),
                )
                if use_wandb:
                    log = {
                        "epoch": epoch,
                        "val_l1_loss": loss.item(),
                        "val_psnr": psnr,
                        "val_ssim": ssim,
                    }
                    log.update(val_mb_errs)
                    wandb.log(log)
                # log qualitative figures every 100 steps
                if step % 100 == 0 and use_wandb and isinstance(logger, Logger):
                    triplet_name = f"val_epoch_{epoch}_step_{step}.png"
                    fig = logger.log_colorized_tensors(
                        (X, "Target (X)"),
                        (X_sparse, "Model Input (X_sparse)"),
                        (X_hat, "Model Prediction (X_hat)"),
                        file_name=triplet_name,
                    )
                    wandb.log({"Val Qualitative Results": wandb.Image(fig)})
        # average over the batches actually iterated — the original divided by
        # the configured steps_per_epoch, which skews the average whenever the
        # dataloader length differs from the config value
        avg_val_loss = running_val_loss / max(seen_val_batches, 1)
        # if best validation perf, save model weights
        if avg_val_loss < best_validation_loss:
            best_validation_loss = avg_val_loss
            logger.save_weights(model, f"best_epoch_{epoch}")
def main(config: dict) -> None:
    """Entry point: run the training loop with the parsed YAML run config."""
    train(config)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="Experiment run .yaml config.", default="")
    args = parser.parse_args()
    config_path = Path(str(args.config))
    # Validate via parser.error instead of assert: asserts are stripped
    # under `python -O`, and parser.error prints usage + exits cleanly.
    if config_path.suffix != ".yaml":
        parser.error("Error: run config must be a `.yaml` file.")
    if not config_path.is_file():
        parser.error("Error: config is not a valid file.")
    try:
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
    except Exception as e:
        print(f"Error: exception opening config: {e}")
        # Re-raise the original exception — the previous `raise Exception()`
        # discarded the real error and its traceback.
        raise
    main(config)