dikdimon's picture
Upload extensions using SD-Hub extension
c336648 verified
import os
import yaml
import torch
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
from safetensors.torch import save_file, load_file
from interposer import InterposerModel
from dataset import LatentDataset, FileLatentDataset, load_evals
from vae import load_vae
torch.backends.cudnn.benchmark = True
torch.manual_seed(0)
def parse_args():
parser = argparse.ArgumentParser(description="Train latent interposer model")
parser.add_argument("--config", help="Config for training")
args = parser.parse_args()
with open(args.config) as f:
conf = yaml.safe_load(f)
args.dataset = argparse.Namespace(**conf.pop("dataset"))
args.model = argparse.Namespace(**conf.pop("model"))
return argparse.Namespace(**vars(args), **conf)
def eval_images(model, vae, evals):
preds = eval_model(model, evals, loss=False)
out = {}
for name, pred in preds.items():
images = vae.decode(pred).cpu().float()
# for image in images: # eval isn't batched
out[f"eval/{name}"] = images[0]
return out
def eval_model(model, evals, loss=True):
model.eval()
preds = {}
losses = []
for name, data in evals.items():
src = data["src"].to(args.device)
dst = data["dst"].to(args.device)
with torch.no_grad():
pred = model(src)
if loss:
loss = torch.nn.functional.l1_loss(dst, pred)
losses.append(loss)
else:
preds[name] = pred
model.train()
if loss:
return (sum(losses) / len(losses)).data.item()
else:
return preds
# from pytorch GAN tutorial
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0)
if __name__ == "__main__":
args = parse_args()
base_name = f"models/{args.model.src}-to-{args.model.dst}_interposer-{args.model.rev}"
# dataset
if os.path.isfile(args.dataset.src):
dataset = FileLatentDataset(
args.dataset.src,
args.dataset.dst,
)
elif os.path.isdir(args.dataset.src):
dataset = LatentDataset(
args.dataset.src,
args.dataset.dst,
args.dataset.preload
)
else:
raise OSError(f"Missing dataset source {args.dataset.src}")
loader = DataLoader(
dataset,
batch_size = args.batch,
shuffle = True,
drop_last = True,
pin_memory = False,
num_workers = 0,
# num_workers = 6,
# persistent_workers=True,
)
# evals
try:
evals = load_evals(args.dataset.evals)
except:
print(f"No evals, fallback to dataset.")
evals = dataset[0]
# defaults
crit = torch.nn.L1Loss()
optim_args = {
"lr": args.optim["lr"],
"betas": (args.optim["beta1"], args.optim["beta2"])
}
# model
model = InterposerModel(**args.model.args)
model.apply(weights_init)
model.to(args.device)
optim = torch.optim.AdamW(model.parameters(), **optim_args)
# aux model for reverse pass
model_back = InterposerModel(
ch_in = args.model.args["ch_out"],
ch_mid = args.model.args["ch_mid"],
ch_out = args.model.args["ch_in"],
scale = 1.0 / args.model.args["scale"],
blocks = args.model.args["blocks"],
)
model_back.apply(weights_init)
model_back.to(args.device)
optim_back = torch.optim.AdamW(model_back.parameters(), **optim_args)
# scheduler
scheduler = None
if args.cosine:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optim,
T_max = (args.steps - args.fconst),
eta_min = 1e-8,
)
# vae
vae = None
if args.save_image:
vae = load_vae(args.model.dst, device=args.device, dtype=torch.float16, dec_only=True)
# main loop
import time
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=f"{base_name}_{int(time.time())}")
pbar = tqdm(total=args.steps)
while pbar.n < args.steps:
for batch in loader:
# get training data
src = batch.get("src").to(args.device)
dst = batch.get("dst").to(args.device)
### Train main model ###
optim.zero_grad()
logs = {}
loss = []
with torch.cuda.amp.autocast():
# pass first model
pred = model(src)
p_loss = crit(pred, dst) * args.p_loss_weight
loss.append(p_loss)
logs["p_loss"] = p_loss.data.item()
# pass second model
if args.r_loss_weight:
pred_back = model_back(pred)
r_loss = crit(pred_back, src) * args.r_loss_weight
loss.append(r_loss)
logs["r_loss"] = r_loss.data.item()
# loss logic
loss = sum(loss)
logs["main"] = loss.data.item()
loss.backward()
optim.step()
# logging
for name, value in logs.items():
writer.add_scalar(f"loss/{name}", value, pbar.n)
### Train backwards model ###
if args.r_loss_weight:
optim_back.zero_grad()
logs = {}
loss = []
with torch.cuda.amp.autocast():
# pass second model
pred = model_back(dst)
p_loss = crit(pred, src) * args.b_loss_weight
loss.append(p_loss)
logs["p_loss"] = p_loss.data.item()
# pass first model
if args.h_loss_weight: # better w/o this?
pred_back = model(pred)
r_loss = crit(pred_back, dst) * args.h_loss_weight
loss.append(r_loss)
logs["r_loss"] = r_loss.data.item()
# loss logic
loss = sum(loss)
logs["main"] = loss.data.item()
loss.backward()
optim_back.step()
# logging
for name, value in logs.items():
writer.add_scalar(f"loss_aux/{name}", value, pbar.n)
# run eval/save eval image
if args.eval_model and pbar.n % args.eval_model == 0:
writer.add_scalar("loss/eval_loss", eval_model(model, evals), pbar.n)
if args.save_image and pbar.n % args.save_image == 0:
for name, image in eval_images(model, vae, evals).items():
writer.add_image(name, image, pbar.n)
# scheduler logic main
if scheduler is not None and pbar.n >= args.fconst:
lr = scheduler.get_last_lr()[0]
scheduler.step()
else:
lr = args.optim["lr"]
writer.add_scalar("lr/model", lr, pbar.n)
# aux model doesn't have a scheduler
writer.add_scalar("lr/model_aux", args.optim["lr"], pbar.n)
# step
pbar.update()
if pbar.n > args.steps:
break
# hacky workaround when the colors are off.
# Save the last n versions and just pick the best one later.
# if pbar.n > (args.steps-2500) and pbar.n%500==0:
# from torchvision.utils import save_image
# save_file(model.state_dict(), f"{base_name}_{pbar.n:07}.safetensors")
# for name, image in eval_images(model, vae, evals).items():
# name = f"models/{name.replace('/', '_')}_{pbar.n:07}.png"
# save_image(image, name)
# final save/cleanup
pbar.close()
writer.close()
save_file(model.state_dict(), f"{base_name}.safetensors")
torch.save(optim.state_dict(), f"{base_name}.optim.pth")