| import os
|
| import yaml
|
| import torch
|
| import argparse
|
| from tqdm import tqdm
|
| from torch.utils.data import DataLoader
|
| from safetensors.torch import save_file, load_file
|
|
|
| from interposer import InterposerModel
|
| from dataset import LatentDataset, FileLatentDataset, load_evals
|
| from vae import load_vae
|
|
|
# Let cuDNN autotune conv algorithms; pays off because latent batches in this
# script have a fixed shape (cost is a one-time benchmark pass per shape).
torch.backends.cudnn.benchmark = True
# Fixed seed so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(0)
|
|
|
def parse_args():
    """Parse CLI arguments and merge in the YAML training config.

    Returns an argparse.Namespace with:
      - ``dataset``: Namespace built from the config's "dataset" section
      - ``model``:   Namespace built from the config's "model" section
      - every remaining top-level config key as a plain attribute
    """
    parser = argparse.ArgumentParser(description="Train latent interposer model")
    # required=True: without it a missing --config only fails later as
    # open(None) with a confusing TypeError instead of a clean argparse error.
    parser.add_argument("--config", required=True, help="Config for training")
    args = parser.parse_args()
    with open(args.config, encoding="utf-8") as f:
        conf = yaml.safe_load(f)
    args.dataset = argparse.Namespace(**conf.pop("dataset"))
    args.model = argparse.Namespace(**conf.pop("model"))
    # Flatten the remaining config keys alongside the CLI args.
    return argparse.Namespace(**vars(args), **conf)
|
|
|
def eval_images(model, vae, evals):
    """Decode the model's eval predictions into images for TensorBoard.

    Returns a dict mapping "eval/<name>" to the first decoded image of each
    eval entry (CPU, float32).
    """
    predictions = eval_model(model, evals, loss=False)
    return {
        f"eval/{key}": vae.decode(latent).cpu().float()[0]
        for key, latent in predictions.items()
    }
|
|
|
def eval_model(model, evals, loss=True):
    """Run the model over the eval set without gradients.

    Args:
        model: interposer network to evaluate.
        evals: mapping of name -> {"src": tensor, "dst": tensor}.
        loss: when True, return the mean L1 loss across all eval pairs as a
            float; when False, return a dict of name -> predicted latents.
    """
    # Derive the device from the model itself instead of relying on the
    # module-level `args` global (the training script keeps them in sync via
    # model.to(args.device)). Falls back to CPU for parameter-less models.
    device = next(model.parameters(), torch.empty(0)).device
    was_training = model.training
    model.eval()
    preds = {}
    losses = []
    for name, data in evals.items():
        src = data["src"].to(device)
        dst = data["dst"].to(device)
        with torch.no_grad():
            pred = model(src)
            if loss:
                # BUG FIX: the original rebound the `loss` flag to this tensor;
                # an L1 loss of exactly 0 (falsy 0-d tensor) would then flip
                # the branch mid-loop. Keep the per-item value under its own name.
                losses.append(torch.nn.functional.l1_loss(dst, pred))
            else:
                preds[name] = pred
    # Restore the caller's mode instead of unconditionally forcing train().
    model.train(was_training)
    if loss:
        return (sum(losses) / len(losses)).data.item()
    return preds
|
|
|
|
|
| def weights_init(m):
|
| classname = m.__class__.__name__
|
| if classname.find('Conv') != -1:
|
| torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
|
| elif classname.find('BatchNorm') != -1:
|
| torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
|
| torch.nn.init.constant_(m.bias.data, 0)
|
|
|
if __name__ == "__main__":
    args = parse_args()
    base_name = f"models/{args.model.src}-to-{args.model.dst}_interposer-{args.model.rev}"

    # Dataset source may be a single packed file or a directory of latents.
    if os.path.isfile(args.dataset.src):
        dataset = FileLatentDataset(
            args.dataset.src,
            args.dataset.dst,
        )
    elif os.path.isdir(args.dataset.src):
        dataset = LatentDataset(
            args.dataset.src,
            args.dataset.dst,
            args.dataset.preload
        )
    else:
        raise OSError(f"Missing dataset source {args.dataset.src}")
    loader = DataLoader(
        dataset,
        batch_size = args.batch,
        shuffle = True,
        drop_last = True,
        pin_memory = False,
        num_workers = 0,
    )

    try:
        evals = load_evals(args.dataset.evals)
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit during startup.
        print("No evals, fallback to dataset.")
        # NOTE(review): dataset[0] looks like a single {"src","dst"} sample,
        # while eval_model expects a dict of name -> {"src","dst"} — confirm
        # this fallback's shape against the dataset's __getitem__.
        evals = dataset[0]

    crit = torch.nn.L1Loss()
    optim_args = {
        "lr": args.optim["lr"],
        "betas": (args.optim["beta1"], args.optim["beta2"])
    }

    # Forward model: src latent space -> dst latent space.
    model = InterposerModel(**args.model.args)
    model.apply(weights_init)
    model.to(args.device)
    optim = torch.optim.AdamW(model.parameters(), **optim_args)

    # Reverse model (dst -> src), used only for the round-trip losses.
    model_back = InterposerModel(
        ch_in = args.model.args["ch_out"],
        ch_mid = args.model.args["ch_mid"],
        ch_out = args.model.args["ch_in"],
        scale = 1.0 / args.model.args["scale"],
        blocks = args.model.args["blocks"],
    )
    model_back.apply(weights_init)
    model_back.to(args.device)
    optim_back = torch.optim.AdamW(model_back.parameters(), **optim_args)

    # Constant LR for the first args.fconst steps, then cosine decay to ~0.
    scheduler = None
    if args.cosine:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim,
            T_max = (args.steps - args.fconst),
            eta_min = 1e-8,
        )

    # Decoder-only VAE, needed solely for logging eval images.
    vae = None
    if args.save_image:
        vae = load_vae(args.model.dst, device=args.device, dtype=torch.float16, dec_only=True)

    import time
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=f"{base_name}_{int(time.time())}")

    pbar = tqdm(total=args.steps)
    while pbar.n < args.steps:
        for batch in loader:
            src = batch.get("src").to(args.device)
            dst = batch.get("dst").to(args.device)

            # --- forward model step: prediction loss (+ optional round trip) ---
            optim.zero_grad()
            logs = {}
            loss = []
            with torch.cuda.amp.autocast():
                pred = model(src)

                p_loss = crit(pred, dst) * args.p_loss_weight
                loss.append(p_loss)
                logs["p_loss"] = p_loss.data.item()

                if args.r_loss_weight:
                    # Round trip: src -> pred -> back should reconstruct src.
                    pred_back = model_back(pred)
                    r_loss = crit(pred_back, src) * args.r_loss_weight
                    loss.append(r_loss)
                    logs["r_loss"] = r_loss.data.item()

            loss = sum(loss)
            logs["main"] = loss.data.item()
            loss.backward()
            optim.step()

            for name, value in logs.items():
                writer.add_scalar(f"loss/{name}", value, pbar.n)

            # --- reverse model step (only when round-trip loss is enabled) ---
            if args.r_loss_weight:
                optim_back.zero_grad()
                logs = {}
                loss = []
                with torch.cuda.amp.autocast():
                    pred = model_back(dst)

                    p_loss = crit(pred, src) * args.b_loss_weight
                    loss.append(p_loss)
                    logs["p_loss"] = p_loss.data.item()

                    if args.h_loss_weight:
                        # Hybrid round trip through the forward model.
                        pred_back = model(pred)
                        r_loss = crit(pred_back, dst) * args.h_loss_weight
                        loss.append(r_loss)
                        logs["r_loss"] = r_loss.data.item()

                loss = sum(loss)
                logs["main"] = loss.data.item()
                loss.backward()
                optim_back.step()

                for name, value in logs.items():
                    writer.add_scalar(f"loss_aux/{name}", value, pbar.n)

            # Periodic eval-loss and eval-image logging.
            if args.eval_model and pbar.n % args.eval_model == 0:
                writer.add_scalar("loss/eval_loss", eval_model(model, evals), pbar.n)
            if args.save_image and pbar.n % args.save_image == 0:
                for name, image in eval_images(model, vae, evals).items():
                    writer.add_image(name, image, pbar.n)

            # LR logging: cosine decay kicks in after the constant warm phase;
            # the aux (reverse) optimizer always runs at the base LR.
            if scheduler is not None and pbar.n >= args.fconst:
                lr = scheduler.get_last_lr()[0]
                scheduler.step()
            else:
                lr = args.optim["lr"]
            writer.add_scalar("lr/model", lr, pbar.n)

            writer.add_scalar("lr/model_aux", args.optim["lr"], pbar.n)

            pbar.update()
            # BUG FIX: was `pbar.n > args.steps`, which let the loop train one
            # extra step past the requested total before breaking.
            if pbar.n >= args.steps:
                break

    pbar.close()
    writer.close()

    save_file(model.state_dict(), f"{base_name}.safetensors")
    torch.save(optim.state_dict(), f"{base_name}.optim.pth")
|