Spaces:
Paused
Paused
| print("started imports") | |
| import sys | |
| import argparse | |
| import time | |
| import cv2 | |
| import wandb | |
| from PIL import Image | |
| import os | |
| from torch.utils.data import DataLoader | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torch.optim.lr_scheduler as scheduler | |
| # custom imports | |
| sys.path.append('./apex/') | |
| from apex import amp | |
| from network.AEI_Net import * | |
| from network.MultiscaleDiscriminator import * | |
| from utils.training.Dataset import FaceEmbedVGG2, FaceEmbed | |
| from utils.training.image_processing import make_image_list, get_faceswap | |
| from utils.training.losses import hinge_loss, compute_discriminator_loss, compute_generator_losses | |
| from utils.training.detector import detect_landmarks, paint_eyes | |
| from AdaptiveWingLoss.core import models | |
| from arcface_model.iresnet import iresnet100 | |
| print("finished imports") | |
def _save_checkpoints(G, D, run_name, epoch, iteration):
    """Write G/D weights to the run's *_latest.pth files and to per-(epoch, iteration) snapshots."""
    torch.save(G.state_dict(), f'./saved_models_{run_name}/G_latest.pth')
    torch.save(D.state_dict(), f'./saved_models_{run_name}/D_latest.pth')
    torch.save(G.state_dict(), f'./current_models_{run_name}/G_{epoch}_{iteration:06}.pth')
    torch.save(D.state_dict(), f'./current_models_{run_name}/D_{epoch}_{iteration:06}.pth')


def _log_reference_swaps(G, netArc, device, epoch, iteration):
    """Swap six fixed example pairs with the current G and log them to wandb as a 3x2 grid.

    G is switched to eval mode for the swaps and restored to train mode even if a swap
    raises, so a failing preview can never leave the generator in eval mode.
    """
    import numpy as np  # local import: np is not imported explicitly at the top of this file

    G.eval()
    try:
        # Inference-only preview; no autograd graph is needed.
        with torch.no_grad():
            swaps = [get_faceswap(f'examples/images/training//source{i}.png',
                                  f'examples/images/training//target{i}.png',
                                  G, netArc, device)
                     for i in range(1, 7)]
    finally:
        G.train()

    # Pairs 1-3 stacked vertically on the left, pairs 4-6 on the right.
    left_column = np.concatenate(swaps[:3], axis=0)
    right_column = np.concatenate(swaps[3:], axis=0)
    output = np.concatenate((left_column, right_column), axis=1)
    wandb.log({"our_images": wandb.Image(output, caption=f"{epoch:03}_{iteration:06}")})


def train_one_epoch(G: 'generator model',
                    D: 'discriminator model',
                    opt_G: "generator opt",
                    opt_D: "discriminator opt",
                    scheduler_G: "scheduler G opt",
                    scheduler_D: "scheduler D opt",
                    netArc: 'ArcFace model',
                    model_ft: 'Landmark Detector',
                    args: 'Args Namespace',
                    dataloader: torch.utils.data.DataLoader,
                    device: 'torch device',
                    epoch: int,
                    loss_adv_accumulated: int):
    """Run one pass over ``dataloader``, alternating a generator and a discriminator step.

    Per batch: embed the source faces with the frozen ArcFace ``netArc``, generate the
    swap ``Y = G(Xt, embed)``, step G on ``compute_generator_losses`` and then D on
    ``compute_discriminator_loss`` (both through apex ``amp.scale_loss``). Previews,
    console/wandb logging and checkpoints are emitted periodically.

    Args:
        args: parsed CLI namespace. Reads diff_eq_same, eye_detector_loss, scheduler,
            discr_force, show_step, use_wandb and run_name.
        model_ft: landmark detector; only used when ``args.eye_detector_loss`` is set.
        scheduler_G / scheduler_D: stepped every iteration when ``args.scheduler`` is set.
        loss_adv_accumulated: running adversarial-loss statistic, updated by
            ``compute_generator_losses``. With ``args.discr_force`` the discriminator
            only steps while it is below 4.

    Returns:
        The updated ``loss_adv_accumulated``, so callers can carry it into the next epoch.
    """
    for iteration, data in enumerate(dataloader):
        start_time = time.time()

        Xs_orig, Xs, Xt, same_person = data
        Xs_orig = Xs_orig.to(device)
        Xs = Xs.to(device)
        Xt = Xt.to(device)
        same_person = same_person.to(device)

        # Identity embeddings of the source faces (ArcFace is frozen, so no gradients).
        with torch.no_grad():
            embed = netArc(F.interpolate(Xs_orig, [112, 112], mode='bilinear', align_corners=False))

        diff_person = torch.ones_like(same_person)
        if args.diff_eq_same:
            same_person = diff_person

        # ---- generator step ----
        opt_G.zero_grad()
        Y, Xt_attr = G(Xt, embed)
        Di = D(Y)
        ZY = netArc(F.interpolate(Y, [112, 112], mode='bilinear', align_corners=False))

        if args.eye_detector_loss:
            Xt_eyes, Xt_heatmap_left, Xt_heatmap_right = detect_landmarks(Xt, model_ft)
            Y_eyes, Y_heatmap_left, Y_heatmap_right = detect_landmarks(Y, model_ft)
            eye_heatmaps = [Xt_heatmap_left, Xt_heatmap_right, Y_heatmap_left, Y_heatmap_right]
        else:
            eye_heatmaps = None

        lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes = compute_generator_losses(
            G, Y, Xt, Xt_attr, Di, embed, ZY, eye_heatmaps, loss_adv_accumulated,
            diff_person, same_person, args)

        with amp.scale_loss(lossG, opt_G) as scaled_loss:
            scaled_loss.backward()
        opt_G.step()
        if args.scheduler:
            scheduler_G.step()

        # ---- discriminator step ----
        opt_D.zero_grad()
        lossD = compute_discriminator_loss(D, Y, Xs, diff_person)
        with amp.scale_loss(lossD, opt_D) as scaled_loss:
            scaled_loss.backward()
        # With --discr_force, D is frozen while the accumulated adversarial loss is high.
        if (not args.discr_force) or (loss_adv_accumulated < 4.):
            opt_D.step()
        if args.scheduler:
            scheduler_D.step()

        batch_time = time.time() - start_time

        # Visual snapshot of the current batch.
        if iteration % args.show_step == 0:
            images = [Xs, Xt, Y]
            if args.eye_detector_loss:
                Xt_eyes_img = paint_eyes(Xt, Xt_eyes)
                Yt_eyes_img = paint_eyes(Y, Y_eyes)
                images.extend([Xt_eyes_img, Yt_eyes_img])
            image = make_image_list(images)
            if args.use_wandb:
                wandb.log({"gen_images": wandb.Image(image, caption=f"{epoch:03}_{iteration:06}")})
            else:
                # Reverse the channel order for cv2.imwrite (presumably RGB -> BGR).
                cv2.imwrite('./images/generated_image.jpg', image[:, :, ::-1])

        if iteration % 10 == 0:
            print(f'epoch: {epoch} {iteration} / {len(dataloader)}')
            print(f'lossD: {lossD.item()} lossG: {lossG.item()} batch_time: {batch_time}s')
            print(f'L_adv: {L_adv.item()} L_id: {L_id.item()} L_attr: {L_attr.item()} L_rec: {L_rec.item()}')
            if args.eye_detector_loss:
                print(f'L_l2_eyes: {L_l2_eyes.item()}')
            print(f'loss_adv_accumulated: {loss_adv_accumulated}')
            if args.scheduler:
                print(f'scheduler_G lr: {scheduler_G.get_last_lr()} scheduler_D lr: {scheduler_D.get_last_lr()}')

        if args.use_wandb:
            if args.eye_detector_loss:
                # commit=False merges the eye loss into the wandb step committed just below.
                wandb.log({"loss_eyes": L_l2_eyes.item()}, commit=False)
            wandb.log({"loss_id": L_id.item(),
                       "lossD": lossD.item(),
                       "lossG": lossG.item(),
                       "loss_adv": L_adv.item(),
                       "loss_attr": L_attr.item(),
                       "loss_rec": L_rec.item()})

        if iteration % 5000 == 0:
            _save_checkpoints(G, D, args.run_name, epoch, iteration)

        # Track how the swap evolves on six fixed example pairs.
        if (iteration % 250 == 0) and args.use_wandb:
            _log_reference_swaps(G, netArc, device, epoch, iteration)

    return loss_adv_accumulated
def train(args, device):
    """Build the models, optimizers and data pipeline, then run ``args.max_epoch`` epochs.

    Args:
        args: parsed CLI namespace (model, optimizer, scheduler, dataset and checkpoint options).
        device: torch device for G, D, the ArcFace network and the optional landmark detector.
    """
    batch_size = args.batch_size
    max_epoch = args.max_epoch

    # Main models.
    G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512).to(device)
    D = MultiscaleDiscriminator(input_nc=3, n_layers=5, norm_layer=torch.nn.InstanceNorm2d).to(device)
    G.train()
    D.train()

    # Frozen identity extractor. Load onto CPU first and then move it to `device`, so the
    # CPU fallback chosen by main() works (a hard-coded .cuda() crashed there).
    netArc = iresnet100(fp16=False)
    netArc.load_state_dict(torch.load('arcface_model/backbone.pth', map_location='cpu'))
    netArc = netArc.to(device)
    netArc.eval()

    if args.eye_detector_loss:
        model_ft = models.FAN(4, "False", "False", 98)
        checkpoint = torch.load('./AdaptiveWingLoss/AWL_detector/WFLW_4HG.pth', map_location='cpu')
        if 'state_dict' not in checkpoint:
            model_ft.load_state_dict(checkpoint)
        else:
            # Keep only the checkpoint entries whose keys exist in this model.
            pretrained_weights = checkpoint['state_dict']
            model_weights = model_ft.state_dict()
            pretrained_weights = {k: v for k, v in pretrained_weights.items() if k in model_weights}
            model_weights.update(pretrained_weights)
            model_ft.load_state_dict(model_weights)
        model_ft = model_ft.to(device)
        model_ft.eval()
    else:
        model_ft = None

    opt_G = optim.Adam(G.parameters(), lr=args.lr_G, betas=(0, 0.999), weight_decay=1e-4)
    opt_D = optim.Adam(D.parameters(), lr=args.lr_D, betas=(0, 0.999), weight_decay=1e-4)

    G, opt_G = amp.initialize(G, opt_G, opt_level=args.optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=args.optim_level)

    if args.scheduler:
        scheduler_G = scheduler.StepLR(opt_G, step_size=args.scheduler_step, gamma=args.scheduler_gamma)
        scheduler_D = scheduler.StepLR(opt_D, step_size=args.scheduler_step, gamma=args.scheduler_gamma)
    else:
        scheduler_G = None
        scheduler_D = None

    if args.pretrained:
        try:
            G.load_state_dict(torch.load(args.G_path, map_location=torch.device('cpu')), strict=False)
            D.load_state_dict(torch.load(args.D_path, map_location=torch.device('cpu')), strict=False)
            print("Loaded pretrained weights for G and D")
        except FileNotFoundError:
            print("Not found pretrained weights. Continue without any pretrained weights.")

    if args.vgg:
        dataset = FaceEmbedVGG2(args.dataset_path, same_prob=args.same_person, same_identity=args.same_identity)
    else:
        dataset = FaceEmbed([args.dataset_path], same_prob=args.same_person)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)

    # Running adversarial loss: with discr_force=True the discriminator is only trained while it
    # is below the threshold checked in train_one_epoch.
    loss_adv_accumulated = 20.

    for epoch in range(0, max_epoch):
        carried = train_one_epoch(G,
                                  D,
                                  opt_G,
                                  opt_D,
                                  scheduler_G,
                                  scheduler_D,
                                  netArc,
                                  model_ft,
                                  args,
                                  dataloader,
                                  device,
                                  epoch,
                                  loss_adv_accumulated)
        # If train_one_epoch reports the updated running value, carry it into the next epoch
        # instead of restarting from 20. A None return leaves it unchanged.
        if carried is not None:
            loss_adv_accumulated = carried
def main(args):
    """Pick CUDA when present (warning loudly on CPU fallback) and start training."""
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if has_cuda else 'cpu')
    if not has_cuda:
        print('cuda is not available. using cpu. check if it\'s ok')

    print("Starting traing")
    train(args, device=device)
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including "False")
    into True, so options such as ``--vgg False`` could never be switched off.
    Bools pass through unchanged. Strings are matched case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # dataset params
    parser.add_argument('--dataset_path', default='/VggFace2-crop/', help='Path to the dataset. If not VGG2 dataset is used, param --vgg should be set False')
    parser.add_argument('--G_path', default='./saved_models/G.pth', help='Path to pretrained weights for G. Only used if pretrained=True')
    parser.add_argument('--D_path', default='./saved_models/D.pth', help='Path to pretrained weights for D. Only used if pretrained=True')
    parser.add_argument('--vgg', default=True, type=str2bool, help='When using VGG2 dataset (or any other dataset with several photos for one identity)')

    # weights for loss
    parser.add_argument('--weight_adv', default=1, type=float, help='Adversarial Loss weight')
    parser.add_argument('--weight_attr', default=10, type=float, help='Attributes weight')
    parser.add_argument('--weight_id', default=20, type=float, help='Identity Loss weight')
    parser.add_argument('--weight_rec', default=10, type=float, help='Reconstruction Loss weight')
    parser.add_argument('--weight_eyes', default=0., type=float, help='Eyes Loss weight')

    # training params you may want to change
    parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
    parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')
    parser.add_argument('--same_person', default=0.2, type=float, help='Probability of using same person identity during training')
    parser.add_argument('--same_identity', default=True, type=str2bool, help='Using simswap approach, when source_id = target_id. Only possible with vgg=True')
    parser.add_argument('--diff_eq_same', default=False, type=str2bool, help='Don\'t use info about where is defferent identities')
    parser.add_argument('--pretrained', default=True, type=str2bool, help='If using the pretrained weights for training or not')
    parser.add_argument('--discr_force', default=False, type=str2bool, help='If True Discriminator would not train when adversarial loss is high')
    parser.add_argument('--scheduler', default=False, type=str2bool, help='If True decreasing LR is used for learning of generator and discriminator')
    parser.add_argument('--scheduler_step', default=5000, type=int)
    parser.add_argument('--scheduler_gamma', default=0.2, type=float, help='It is value, which shows how many times to decrease LR')
    parser.add_argument('--eye_detector_loss', default=False, type=str2bool, help='If True eye loss with using AdaptiveWingLoss detector is applied to generator')

    # info about this run
    parser.add_argument('--use_wandb', default=False, type=str2bool, help='Use wandb to track your experiments or not')
    parser.add_argument('--run_name', required=True, type=str, help='Name of this run. Used to create folders where to save the weights.')
    parser.add_argument('--wandb_project', default='your-project-name', type=str)
    parser.add_argument('--wandb_entity', default='your-login', type=str)

    # training params you probably don't want to change
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--lr_G', default=4e-4, type=float)
    parser.add_argument('--lr_D', default=4e-4, type=float)
    parser.add_argument('--max_epoch', default=2000, type=int)
    parser.add_argument('--show_step', default=500, type=int)
    parser.add_argument('--save_epoch', default=1, type=int)  # NOTE(review): not read anywhere in this file
    parser.add_argument('--optim_level', default='O2', type=str)

    args = parser.parse_args()

    if not args.vgg and args.same_identity:
        raise ValueError("Sorry, you can't use some other dataset than VGG2 Faces with param same_identity=True")

    if args.use_wandb:
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, settings=wandb.Settings(start_method='fork'))

        config = wandb.config
        config.dataset_path = args.dataset_path
        config.weight_adv = args.weight_adv
        config.weight_attr = args.weight_attr
        config.weight_id = args.weight_id
        config.weight_rec = args.weight_rec
        config.weight_eyes = args.weight_eyes
        config.same_person = args.same_person
        config.Vgg2Face = args.vgg
        config.same_identity = args.same_identity
        config.diff_eq_same = args.diff_eq_same
        config.discr_force = args.discr_force
        config.scheduler = args.scheduler
        config.scheduler_step = args.scheduler_step
        config.scheduler_gamma = args.scheduler_gamma
        config.eye_detector_loss = args.eye_detector_loss
        config.pretrained = args.pretrained
        config.run_name = args.run_name
        config.G_path = args.G_path
        config.D_path = args.D_path
        config.batch_size = args.batch_size
        config.lr_G = args.lr_G
        config.lr_D = args.lr_D
    else:
        # Local preview images are written here when wandb is not used.
        os.makedirs('./images', exist_ok=True)

    # Folders for the latest weights and for per-epoch/iteration snapshots. Create each one
    # independently, so a pre-existing saved_models_* folder no longer leaves current_models_* missing.
    os.makedirs(f'./saved_models_{args.run_name}', exist_ok=True)
    os.makedirs(f'./current_models_{args.run_name}', exist_ok=True)

    main(args)