"""
Training script for `sparse-bto` models.
"""
import torch
import yaml
import wandb
import importlib
import argparse
import warnings
from tqdm import tqdm
from pathlib import Path
from typing import Any, Optional
from torch.utils.data import DataLoader
from src.util.metrics import PSNR, SSIM
from src.util.logger import Logger
from src.util.config import LOSS_FUNCTIONS, OPTIMIZERS, MODELS
from src.util.moment_based import cal_moment_based_errs
from src.models.our_method.swin_cafm import SwinCAFM
from src.datasets.mos2_sr import BTOSRDataset, BTO_MANY_RES
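# silence all warnings to keep console output readable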
warnings.simplefilter("ignore")


def _init_module_from_target(mod_config: dict, *, additional_args: Optional[dict] = None) -> Any:
"""
Init a module from a module config dict,
expect keywords `target` and `args`.
"""
mod_path, cls_name = mod_config["target"].rsplit(".", 1)
module = importlib.import_module(mod_path)
cls = getattr(module, cls_name)
    # copy before updating so the run config is not mutated in place
    args: dict = dict(mod_config.get("args", {}))
    args.update(additional_args or {})
return cls(**args)


def train(config: dict) -> None:
logger = _init_module_from_target(config["logger"])
    # optional Weights & Biases run initialization
    use_wandb = bool(config['wandb']['use_wandb'])
    if use_wandb:
        _init_module_from_target(config['wandb']['login'])
        _init_module_from_target(config['wandb']['init'])
# init datasets/dataloaders
train_dataset = _init_module_from_target(config['train_args']['dataset'])
val_dataset = _init_module_from_target(config['val_args']['dataset'])
    train_dataloader = DataLoader(train_dataset, batch_size=int(config['train_args']['batch_size']), shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=int(config['val_args']['batch_size']), shuffle=False)
# init loss
train_loss = _init_module_from_target(config['train_args']['loss'])
val_loss = _init_module_from_target(config['val_args']['loss'])
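    # init model (training assumes a CUDA-capable GPU is available)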
model: torch.nn.Module = _init_module_from_target(config['model'])
model.float().cuda()
# init optim
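    # (model parameters are passed into the optimizer constructor via `additional_args`)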
    optimizer: torch.optim.Optimizer = _init_module_from_target(
        config['train_args']['optimizer'], additional_args={"params": model.parameters()}
    )
# for weight saving
best_validation_loss = float("inf")
# main training loop
    num_epochs = int(config['train_args']['num_epochs'])
    for epoch in range(num_epochs):
# train
model.train()
        for step, item in tqdm(
            enumerate(train_dataloader),
            desc=f"🚀 Training Epoch: {epoch + 1}/{num_epochs}",
            total=int(config['train_args']['dataset']['args']['steps_per_epoch']),
        ):
            X: torch.Tensor = item["X"].float().cuda()
            X_sparse: torch.Tensor = item["X_sparse"].float().cuda()
# zero gradients
optimizer.zero_grad()
# ---- forward: p(y | y_sparse) ----
X_hat: torch.Tensor = model(X_sparse)
loss: torch.Tensor = train_loss(X_hat, X)
loss.backward()
optimizer.step()
# ---- log ----
# calculate moment-based errors
mb_errs = cal_moment_based_errs(X_hat, X)
            train_mb_errs = {f'train_{k}': v for k, v in mb_errs.items()}
X = torch.clip(X, 0, 1)
X_hat = torch.clip(X_hat, 0, 1)
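            # ---- add dummy dims for PSNR/SSIM ----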
            X_il: torch.Tensor = X.unsqueeze(1).repeat(1, 3, 1, 1)
            X_hat_il: torch.Tensor = X_hat.unsqueeze(1).repeat(1, 3, 1, 1)
psnr = PSNR(X_il, X_hat_il, (0, 1))
ssim = SSIM(X_il, X_hat_il, (0, 1))
            logger.log(
                global_train_step=len(train_dataloader) * epoch + step,
                global_val_step=None,
                epoch=epoch,
                train_loss=loss.item(),
                val_loss=None,
            )
            if use_wandb:
log = {
"epoch": epoch,
"train_l1_loss": loss.item(),
"train_psnr": psnr,
"train_ssim": ssim,
}
log.update(train_mb_errs)
wandb.log(log)
# log figures every 100 steps
if step % 100 != 0:
continue
triplet_name = f"train_epoch_{epoch}_step_{step}.png"
if isinstance(logger, Logger) and bool(config['wandb']['use_wandb']) == True:
fig = logger.log_colorized_tensors(
(X, "Target (X)"),
(X_sparse, "Model Input (X_sparse)"),
(X_hat, "Model Prediction"),
file_name=triplet_name,
)
wandb.log({"Train Qualitative Results": wandb.Image(fig)})
# validate
model.eval()
        running_val_loss = 0.0
with torch.no_grad():
            for step, item in tqdm(
                enumerate(val_dataloader),
                desc=f"🚀 Validation Epoch: {epoch + 1}/{num_epochs}",
                total=int(config['val_args']['dataset']['args']['steps_per_epoch']),
            ):
                X: torch.Tensor = item["X"].float().cuda()
                X_sparse: torch.Tensor = item["X_sparse"].float().cuda()
# ---- forward: p(y | y_sparse) ----
X_hat: torch.Tensor = model(X_sparse)
                loss = val_loss(X_hat, X)
                # accumulate validation loss over every batch
                running_val_loss += loss.item()
# calculate moment-based errors
mb_errs = cal_moment_based_errs(X_hat, X)
                val_mb_errs = {f'val_{k}': v for k, v in mb_errs.items()}
X = torch.clip(X, 0, 1)
X_hat = torch.clip(X_hat, 0, 1)
# ---- add dummy dims for PSNR/SSIM ----
X_il: torch.Tensor = X.unsqueeze(1).repeat(1, 3, 1, 1)
X_hat_il: torch.Tensor = X_hat.unsqueeze(1).repeat(1, 3, 1, 1)
psnr = PSNR(X_il, X_hat_il, (0, 1))
ssim = SSIM(X_il, X_hat_il, (0, 1))
                logger.log(
                    global_train_step=None,
                    global_val_step=len(val_dataloader) * epoch + step,
                    epoch=epoch,
                    train_loss=None,
                    val_loss=loss.item(),
                )
                if use_wandb:
log = {
"epoch": epoch,
"val_l1_loss": loss.item(),
"val_psnr": psnr,
"val_ssim": ssim,
}
log.update(val_mb_errs)
wandb.log(log)
# log figures every 100 steps
if step % 100 != 0:
continue
                if use_wandb and isinstance(logger, Logger):
triplet_name = f"val_epoch_{epoch}_step_{step}.png"
fig = logger.log_colorized_tensors(
(X, "Target (X)"),
(X_sparse, "Model Input (X_sparse)"),
(X_hat, "Model Prediction (X_hat)"),
file_name=triplet_name,
)
wandb.log({"Val Qualitative Results": wandb.Image(fig)})
total_val_steps = int(config['val_args']['dataset']['args']['steps_per_epoch'])
avg_val_loss = running_val_loss / total_val_steps
# if best validation perf, save model weights
if avg_val_loss < best_validation_loss:
best_validation_loss = avg_val_loss
logger.save_weights(model, f"best_epoch_{epoch}")


def main(config: dict) -> None:
    train(config)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, help="Experiment run .yaml config.", default="")
args = parser.parse_args()
    assert str(args.config).endswith(".yaml"), "Error: run config must be a `.yaml` file."
    config_path = Path(str(args.config))
    assert config_path.is_file(), "Error: config is not a valid file."
    try:
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
    except Exception as e:
        print(f"Error: exception opening config: {e}")
        raise
main(config)