import os
import wandb
import argparse
import warnings
import torch
import torch.nn as nn
from tqdm import tqdm
from pathlib import Path
from rich.console import Console
from typing import Optional
from torch.utils.data import DataLoader
from src.util.metrics import (
    PSNR,
    SSIM,
    RMSE_surface_roughness_l1,
)
from src.models.our_method.swin_cafm import SwinCAFM
from src.datasets.mos2_sr import (
    BTOSRDataset,
    UnifiedMOS2SRDataset,
    MOS2SRDataset,
    MOS2_SEF_FULL_RES_SRC_DIR,
    MOS2_SILICON_DIR,
    MOS2_SAPPHIRE_DIR,
    MOS2_SYNTHETIC,
    BTO_MANY_RES,
)
from src.util.logger import ExperimentLogger
from src.util.config import (
    TrainConfig,
    ModelConfig,
    LOSS_FUNCTIONS,
    OPTIMIZERS,
    MODELS,
)
from src.util.loss import roughness_loss, rotation_invariant_l1_loss

warnings.simplefilter("always")
torch.multiprocessing.set_sharing_strategy("file_system")

TRAIN_CONFIG_FP = os.path.abspath("configs/train.yaml")
CONSOLE = Console()


def setup_logger(
    train_config: TrainConfig, model_config: Optional[ModelConfig]
) -> ExperimentLogger:
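    """Build an ExperimentLogger from the training config (and optional model
    config) and register the result columns it should track."""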
    logger = ExperimentLogger(
        train_config_dict=train_config.to_dict(),
        model_config_dict=model_config.to_dict() if model_config is not None else None,
        root=train_config.log_root,
        exp_name=train_config.exp_name,
        log_interval=train_config.log_interval,
    )
    logger.add_result_columns(train_config.result_columns)
    return logger


def create_model(config: TrainConfig) -> nn.Module:
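    """Instantiate the model named in the config. Note: currently bypassed in
    train(), which loads a full checkpoint from --weights instead."""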
    model_fn = MODELS[config.model_name]["fn"]
    model_weights = MODELS[config.model_name]["weights"]
    if model_weights:
        model = model_fn(weights=model_weights)
    elif config.model_name == "hiera":
        # for "hiera", MODELS stores an already-instantiated backbone
        model = model_fn
        model.freeze()
    else:
        model = model_fn()
    assert isinstance(model, nn.Module)
    return model


def create_dataloader(
    args: argparse.Namespace, config: TrainConfig, split: str
) -> DataLoader:
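    """Build a DataLoader for the requested dataset and split. Train splits
    draw steps_per_epoch * train_batch_size samples per epoch."""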
    assert str(args.dataset) in [
        "all",
        "synthetic",
        "bto",
        "mos2-sef",
        "sapphire",
        "silicon",
    ]
    src_dir = {
        "all": None,
        "synthetic": MOS2_SYNTHETIC,
        "mos2-sef": MOS2_SEF_FULL_RES_SRC_DIR,
        "sapphire": MOS2_SAPPHIRE_DIR,
        "silicon": MOS2_SILICON_DIR,
        "bto": BTO_MANY_RES,
    }[args.dataset]
    dataset = None
    if str(args.dataset) == "all":
        dataset = UnifiedMOS2SRDataset(
            split=split,
            steps_per_epoch=(
                int(config.steps_per_epoch * config.train_batch_size)
                if split == "train"
                else config.val_steps_per_epoch
            ),
            upsample_factor=int(args.upsample_factor),
        )
    elif str(args.dataset) == "bto":
        dataset = BTOSRDataset(
            steps_per_epoch=(
                int(config.steps_per_epoch * config.train_batch_size)
                if split == "train"
                else config.val_steps_per_epoch
            ),
            upsample_factor=int(args.upsample_factor),
        )
    else:
        dataset = MOS2SRDataset(
            src_dir=src_dir,
            split=split,
            steps_per_epoch=(
                int(config.steps_per_epoch * config.train_batch_size)
                if split == "train"
                else config.val_steps_per_epoch
            ),
            upsample_factor=int(args.upsample_factor),
        )
    return DataLoader(
        dataset,
        batch_size=(
            config.train_batch_size if split == "train" else config.val_batch_size
        ),
        # NOTE: shuffle=False even for training -- the datasets appear to
        # sample randomly on the fly (hence steps_per_epoch above)
        shuffle=False,
        num_workers=config.num_workers,
    )


def train(
    args: argparse.Namespace,
    config: TrainConfig,
    model_config: Optional[ModelConfig] = None,
) -> None:
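    """Run the training/validation loop: fine-tune a pretrained checkpoint,
    log metrics and qualitative figures to the ExperimentLogger and wandb,
    and checkpoint weights based on validation loss."""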
    logger = setup_logger(config, model_config)
    # wandb login; expects WANDB_API_KEY to be set in the environment
    wandb.login()
    wandb.init(
        entity="team-levi",
        project="sparse-cafm",
        config=config.to_dict(),
        name=str(args.exp_name),
    )
    # HACK: just loading a torch .pth file
    # model = create_model(config)
    # model = SwinCAFM.init_from_config(model_config.to_dict())
    model = torch.load(str(args.weights))
    train_dataloader = create_dataloader(args, config, "train")
    val_dataloader = create_dataloader(args, config, "val")
    # define loss function and optimizer
    # NOTE: the configured losses are currently bypassed by the hardcoded
    # roughness-based losses in the loops below
    train_loss: torch.nn.Module = LOSS_FUNCTIONS[config.train_loss]()
    val_loss: torch.nn.Module = LOSS_FUNCTIONS[config.val_loss]()
    # used to save model checkpoints
    best_val_loss = float("inf")
    num_epochs = config.epochs
    device = config.device
    # as per: https://arxiv.org/pdf/2404.00722
    optimizer = torch.optim.Adam(model.parameters(), lr=float(config.learning_rate))
    # assert isinstance(model, SwinCAFM)
    # HACK: randomly init weights
    # model.apply(model._init_weights)
    model.cuda(device)
    model.float()
    # ---------- training loop ----------
    for epoch in range(num_epochs):
        model.train()
        for step, batch in enumerate(
            tqdm(train_dataloader, desc=f"Training: Epoch {epoch+1}/{num_epochs}")
        ):
            # inputs/targets are in [0, 1]
            # NOTE: manually specifying X vs y
            X = batch["X"].float().cuda()
            X_sparse = batch["X_sparse"].float().cuda()
            # zero gradients
            optimizer.zero_grad()
            # ---- forward: p(y | y_sparse) ----
            X_hat: torch.Tensor = model(X_sparse)
            assert isinstance(train_dataloader.dataset, BTOSRDataset)
            rmse_sr_loss = RMSE_surface_roughness_l1(
                X,
                X_hat,
                train_dataloader.dataset.topo_maps_min,
                train_dataloader.dataset.topo_maps_max,
            )
            # --- L1 ----
            # loss = torch.nn.functional.l1_loss(X, X_hat)
            # --- L1 + surface_roughness ----
            # EPS = 1.5
            # loss = torch.nn.functional.l1_loss(X, X_hat) + (EPS * rmse_sr_loss)
            # --- rotation-invariant L1 ---
            loss = rotation_invariant_l1_loss(
                model,
                X,
                X_sparse,
                train_dataloader.dataset.topo_maps_min,
                train_dataloader.dataset.topo_maps_max,
            )
            # backprop and step
            loss.backward()
            optimizer.step()
            # HACK: clip to [0, 1]
            X = torch.clip(X, 0, 1)
            X_hat = torch.clip(X_hat, 0, 1)
            # ---- add dummy channel dims for PSNR/SSIM ----
            X_il: torch.Tensor = X.unsqueeze(1).repeat(1, 3, 1, 1)
            X_hat_il: torch.Tensor = X_hat.unsqueeze(1).repeat(1, 3, 1, 1)
            psnr = PSNR(X_il, X_hat_il, (0, 1))
            ssim = SSIM(X_il, X_hat_il, (0, 1))
            logger.log(
                **{
                    "global_train_step": len(train_dataloader) * epoch + step,
                    "global_val_step": None,
                    "epoch": epoch,
                    "train_loss": loss.item(),
                    "val_loss": None,
                }
            )
            wandb.log(
                {
                    "epoch": epoch,
                    "train_l1_loss": loss.item(),
                    "train_psnr": psnr,
                    "train_ssim": ssim,
                    "train_RMSE_surface_roughness_l1": rmse_sr_loss,
                }
            )
            # log figures every 100 steps
            if step % 100 != 0:
                continue
            triplet_name = f"train_epoch_{epoch}_step_{step}.png"
            fig = logger.log_colorized_tensors(
                (X, "Target (X)"),
                (X_sparse, "Model Input (X_sparse)"),
                (X_hat, "Model Prediction (X_hat)"),
                file_name=triplet_name,
            )
            wandb.log({"Train Qualitative Results": wandb.Image(fig)})
        # validation
        model.eval()
        val_running_loss = 0.0
        num_val_samples = 0
        with torch.no_grad():
            for i, batch in enumerate(
                tqdm(val_dataloader, desc=f"Validation: Epoch {epoch+1}/{num_epochs}")
            ):
                # NOTE: manually specifying X vs y
                X = batch["X"].float().cuda()
                X_sparse = batch["X_sparse"].float().cuda()
                # ---- forward: p(y | y_sparse) ----
                X_hat: torch.Tensor = model(X_sparse)
                assert isinstance(train_dataloader.dataset, BTOSRDataset)
                # normalization bounds come from the training split
                rmse_sr_loss = RMSE_surface_roughness_l1(
                    X,
                    X_hat,
                    train_dataloader.dataset.topo_maps_min,
                    train_dataloader.dataset.topo_maps_max,
                )
                # --- L1 ----
                # loss = val_loss(X_hat, X)
                # --- surface roughness ---
                loss = roughness_loss(
                    X_hat,
                    X,
                    train_dataloader.dataset.topo_maps_min,
                    train_dataloader.dataset.topo_maps_max,
                )
                val_running_loss += loss.item() * X.size(0)
                num_val_samples += X.size(0)
                X = torch.clip(X, 0, 1)
                X_hat = torch.clip(X_hat, 0, 1)
                # ---- add dummy channel dims for PSNR/SSIM ----
                X_il: torch.Tensor = X.unsqueeze(1).repeat(1, 3, 1, 1)
                X_hat_il: torch.Tensor = X_hat.unsqueeze(1).repeat(1, 3, 1, 1)
                psnr = PSNR(X_il, X_hat_il, (0, 1))
                ssim = SSIM(X_il, X_hat_il, (0, 1))
                logger.log(
                    **{
                        "global_train_step": None,
                        "global_val_step": len(val_dataloader) * epoch + i,
                        "epoch": epoch,
                        "train_loss": None,
                        "val_loss": loss.item(),
                    }
                )
                wandb.log(
                    {
                        "epoch": epoch,
                        "val_l1_loss": loss.item(),
                        "val_psnr": psnr,
                        "val_ssim": ssim,
                        "val_RMSE_surface_roughness_l1": rmse_sr_loss,
                    }
                )
                # log figures every 100 steps
                if i % 100 != 0:
                    continue
                triplet_name = f"val_epoch_{epoch}_step_{i}.png"
                fig = logger.log_colorized_tensors(
                    (X, "Target (X)"),
                    (X_sparse, "Model Input (X_sparse)"),
                    (X_hat, "Model Prediction (X_hat)"),
                    file_name=triplet_name,
                )
                wandb.log({"Val Qualitative Results": wandb.Image(fig)})
        # optional: save best/recent model weights
        avg_val_loss = val_running_loss / max(num_val_samples, 1)
        if not bool(config.save_weights):
            continue
        if bool(config.save_only_best_weights):
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                logger.save_weights(model, "best")
            else:
                # NOTE: we overwrite the previous "latest" weights
                logger.save_weights(model, "latest")
        else:
            logger.save_weights(model, f"epoch_{epoch}")


def main(args: argparse.Namespace) -> None:
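    """Load the training (and optional model) config, apply CLI overrides,
    and launch training."""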
    # load training config
    config = TrainConfig(TRAIN_CONFIG_FP)
    config.weights = args.weights
    model_config: Optional[ModelConfig] = None
    # optional: parse model config
    if config.model_config_file is not None:
        model_config_abs_path = os.path.join(
            str(Path(TRAIN_CONFIG_FP).parent), config.model_config_file
        )
        assert os.path.isfile(
            model_config_abs_path
        ), f"Bad path to model config: {model_config_abs_path}"
        model_config = ModelConfig(model_config_abs_path)
    # -------------------- training config args --------------------
    config.exp_name = args.exp_name
    config.log_root = args.root
    # config.learning_rate = str(args.learning_rate)
    # config.train_batch_size = int(args.batch_size)
    # -------------------- model config args --------------------
    if model_config is not None:
        # transformer block depths; e.g., [6, 6, 6, 6, 6, 6]
        model_config.depths = [args.depths] * args.num_blocks
        # num heads per block; e.g., [6, 6, 6, 6, 6, 6]
        model_config.num_heads = [args.num_heads] * args.num_blocks
        # size of shifted-attention window
        model_config.window_size = args.window_size
        model_config.drop_path_rate = args.drop_path_rate
        model_config.norm_layer = args.norm_layer
    args.upsample_factor = int(args.upsample_factor)
    # train
    train(args, config, model_config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # -------------------- training config args --------------------
    parser.add_argument(
        "-e",
        "--exp_name",
        type=str,
        help="Experiment directory name.",
        default="my-experiment",
    )
    parser.add_argument(
        "-r",
        "--root",
        type=str,
        help="Root directory to save experiment in.",
        default="__exps__/",
    )
    parser.add_argument(
        "-ds",
        "--dataset",
        type=str,
        help="['all', 'synthetic', 'bto', 'mos2-sef', 'sapphire', 'silicon']",
        default="mos2-sef",
    )
    parser.add_argument(
        "-ws", "--weights", type=str, help="Path to model checkpoint.", default=""
    )
    parser.add_argument(
        "-fm", "--formulation", type=str, help="['X', 'y', 'both']", default="y"
    )
    # -------------------- model config args --------------------
    parser.add_argument(
        "-dps", "--depths", type=int, help="Depth of each RSTB block.", default=6
    )
    parser.add_argument(
        "-nbs", "--num_blocks", type=int, help="Number of RSTB blocks.", default=6
    )
    parser.add_argument(
        "-nhs",
        "--num_heads",
        type=int,
        help="Number of heads per RSTB block.",
        default=6,
    )
    parser.add_argument(
        "-wsz",
        "--window_size",
        type=int,
        help="Size of shifted attention window.",
        default=8,
    )
    parser.add_argument(
        "-dpr",
        "--drop_path_rate",
        type=float,
        help="Stochastic depth (drop path) rate.",
        default=0.1,
    )
    parser.add_argument(
        "-nlr",
        "--norm_layer",
        type=str,
        help="Import path of the normalization layer.",
        default="torch.nn.LayerNorm",
    )
    # -------------------- ablation args --------------------
    parser.add_argument(
        "-sw",
        "--surrogate_weights",
        type=str,
        help="Path to surrogate model weights.",
        default="",
    )
    parser.add_argument(
        "-lr", "--learning_rate", type=float, help="Learning rate.", default=1e-5
    )
    parser.add_argument(
        "-bs", "--batch_size", type=int, help="Training batch size.", default=1
    )
    parser.add_argument(
        "-sr", "--upsample_factor", type=int, help="SR upsample factor.", default=2
    )
    args = parser.parse_args()
    main(args)
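
# Example invocation (script name, paths, and values are illustrative):
#   python train.py \
#       --exp_name bto-finetune --dataset bto \
#       --weights __exps__/pretrain/weights/best.pth --upsample_factor 4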