Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import argparse | |
| import os | |
| import time | |
| from cp_dataset import CPDataset, CPDataLoader | |
| from networks import GicLoss, GMM, UnetGenerator, VGGLoss, load_checkpoint, save_checkpoint | |
| from tensorboardX import SummaryWriter | |
| from visualization import board_add_image, board_add_images | |
def get_opt(argv=None):
    """Parse the options used to train either the GMM or the TOM stage.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the real command line (``sys.argv[1:]``),
            exactly as before; passing a list allows programmatic use.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    # Run name and stage default to GMM; pass "--name TOM --stage TOM" to
    # train the try-on stage instead.
    parser.add_argument("--name", default="GMM")
    parser.add_argument("--gpu_ids", default="")
    parser.add_argument('-j', '--workers', type=int, default=1)
    parser.add_argument('-b', '--batch-size', type=int, default=4)
    parser.add_argument("--dataroot", default="data")
    parser.add_argument("--datamode", default="train")
    parser.add_argument("--stage", default="GMM")
    parser.add_argument("--data_list", default="train_pairs.txt")
    parser.add_argument("--fine_width", type=int, default=192)
    parser.add_argument("--fine_height", type=int, default=256)
    parser.add_argument("--radius", type=int, default=5)
    parser.add_argument("--grid_size", type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='initial learning rate for adam')
    parser.add_argument('--tensorboard_dir', type=str,
                        default='tensorboard', help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str,
                        default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--checkpoint', type=str, default='',
                        help='model checkpoint for initialization')
    parser.add_argument("--display_count", type=int, default=20)
    parser.add_argument("--save_count", type=int, default=5000)
    # The training loops run keep_step + decay_step iterations; the LR
    # schedule holds the LR for keep_step iterations, then decays it over
    # decay_step iterations.
    parser.add_argument("--keep_step", type=int, default=100000)
    parser.add_argument("--decay_step", type=int, default=100000)
    parser.add_argument("--shuffle", action='store_true',
                        help='shuffle input data')
    return parser.parse_args(argv)
def train_gmm(opt, train_loader, model, board):
    """Train the GMM (cloth-warping) stage.

    Runs ``opt.keep_step + opt.decay_step`` iterations. The LambdaLR
    schedule keeps the learning rate at ``opt.lr`` for the first
    ``opt.keep_step`` iterations and then lowers it linearly over the
    remaining ``opt.decay_step`` iterations.

    Args:
        opt: Parsed options (see ``get_opt``). Reads ``lr``, ``keep_step``,
            ``decay_step``, ``display_count``, ``save_count``,
            ``checkpoint_dir`` and ``name``; also passed to ``GicLoss``.
        train_loader: Provides ``next_batch()``, returning a dict of tensors
            keyed by 'image', 'pose_image', 'head', 'shape', 'agnostic',
            'cloth', 'cloth_mask', 'parse_cloth' and 'grid_image'.
        model: GMM network, called as ``model(agnostic, cloth_mask)`` and
            returning ``(grid, theta)``.
        board: TensorBoard writer used for images and scalars.

    Side effects:
        Moves ``model`` to CUDA. Every ``opt.display_count`` iterations it
        logs to ``board`` and stdout. Every ``opt.save_count`` iterations it
        saves a checkpoint under ``opt.checkpoint_dir/opt.name``.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    gicloss = GicLoss(opt)

    # optimizer and LR schedule
    optimizer = torch.optim.Adam(
        model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    # Multiplier is 1.0 until keep_step, then decreases linearly. LambdaLR
    # only applies it when scheduler.step() is called, which the loop below
    # does once per iteration.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()

        # Model input is (agnostic, cloth mask); the cloth image c could also
        # be added for a new training run.
        grid, theta = model(agnostic, cm)
        # NOTE(review): align_corners is not passed, so PyTorch's default
        # applies (False since 1.3.0); confirm this matches how the GMM is
        # used at test time.
        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        # Only needed by the paper's mask loss (see the Lwarp note below).
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im]]

        # Loss for the warped cloth. This matches the working code; to train
        # as described in the paper, use the warped-mask loss instead:
        #     Lwarp = criterionL1(warped_mask, cm)
        Lwarp = criterionL1(warped_cloth, im_c)

        # Grid regularization loss, normalised by N * H * W of the sampling
        # grid (grid_sample grids are shaped (N, H, W, 2)).
        # Original note on the weight of 40: "200x200 = 40.000 * 0.001".
        Lgic = gicloss(grid)
        Lgic = Lgic / (grid.shape[0] * grid.shape[1] * grid.shape[2])
        weighted_gic = 40 * Lgic
        loss = Lwarp + weighted_gic  # total GMM loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Advance the LR schedule; without this the decay phase never runs.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('loss', loss.item(), step + 1)
            board.add_scalar('40*Lgic', weighted_gic.item(), step + 1)
            board.add_scalar('Lwarp', Lwarp.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, (40*Lgic): %.8f, Lwarp: %.6f' %
                  (step + 1, t, loss.item(), weighted_gic.item(), Lwarp.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(
                opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
def train_tom(opt, train_loader, model, board):
    """Train the TOM (try-on) stage.

    Runs ``opt.keep_step + opt.decay_step`` iterations. The LambdaLR
    schedule keeps the learning rate at ``opt.lr`` for the first
    ``opt.keep_step`` iterations and then lowers it linearly over the
    remaining ``opt.decay_step`` iterations.

    Args:
        opt: Parsed options (see ``get_opt``). Reads ``lr``, ``keep_step``,
            ``decay_step``, ``display_count``, ``save_count``,
            ``checkpoint_dir`` and ``name``.
        train_loader: Provides ``next_batch()``, returning a dict of tensors
            keyed by 'image', 'pose_image', 'head', 'shape', 'agnostic',
            'cloth', 'cloth_mask' and 'parse_cloth_mask'.
        model: Generator called on the channel-wise concatenation of
            (agnostic, cloth, cloth_mask).
        board: TensorBoard writer used for images and scalars.

    Side effects:
        Moves ``model`` to CUDA. Every ``opt.display_count`` iterations it
        logs to ``board`` and stdout. Every ``opt.save_count`` iterations it
        saves a checkpoint under ``opt.checkpoint_dir/opt.name``.
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer and LR schedule
    optimizer = torch.optim.Adam(
        model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    # Multiplier is 1.0 until keep_step, then decreases linearly. LambdaLR
    # only applies it when scheduler.step() is called, which the loop below
    # does once per iteration.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        # These three are only used for visualisation and are not moved to
        # the GPU.
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        pcm = inputs['parse_cloth_mask'].cuda()

        # CP-VTON+ adds the cloth mask as an extra input
        # (CP-VTON used torch.cat([agnostic, c], 1)).
        outputs = model(torch.cat([agnostic, c, cm], 1))
        # First 3 channels: rendered person; remaining channel(s): mask.
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        # Blend the cloth and the rendered person with the predicted mask.
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        # CP-VTON+ visualises the parsed cloth mask (pcm); CP-VTON showed
        # the cloth mask (cm) in that slot instead.
        visuals = [[im_h, shape, im_pose],
                   [c, pcm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        # CP-VTON+ supervises the mask with pcm (CP-VTON used cm).
        loss_mask = criterionMask(m_composite, pcm)
        loss = loss_l1 + loss_vgg + loss_mask

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Advance the LR schedule; without this the decay phase never runs.
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(),
                     loss_vgg.item(), loss_mask.item()), flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(
                opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
def main():
    """Parse options, build the data pipeline and model, and train one stage.

    The stage is selected by ``--stage`` ('GMM' or 'TOM'). After training,
    the final weights are saved as ``gmm_final.pth`` or ``tom_final.pth``
    under ``opt.checkpoint_dir/opt.name``.

    Raises:
        NotImplementedError: if ``opt.stage`` is neither 'GMM' nor 'TOM'.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset and dataloader
    train_dataset = CPDataset(opt)
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization; exist_ok avoids a check-then-create race
    os.makedirs(opt.tensorboard_dir, exist_ok=True)
    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))

    try:
        # create model & train & save the final checkpoint
        if opt.stage == 'GMM':
            model = GMM(opt)
            train_fn, final_name = train_gmm, 'gmm_final.pth'
        elif opt.stage == 'TOM':
            # CP-VTON+ uses 26 input channels (CP-VTON used 25); the extra
            # channel is the cloth mask concatenated in train_tom.
            model = UnetGenerator(
                26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
            train_fn, final_name = train_tom, 'tom_final.pth'
        else:
            raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

        if opt.checkpoint:
            if os.path.exists(opt.checkpoint):
                load_checkpoint(model, opt.checkpoint)
            else:
                # Previously ignored silently; keep training but say so.
                print("Warning: checkpoint '%s' not found; training from scratch."
                      % opt.checkpoint, flush=True)

        train_fn(opt, train_loader, model, board)
        save_checkpoint(model, os.path.join(
            opt.checkpoint_dir, opt.name, final_name))
    finally:
        # Flush pending TensorBoard events even if training fails.
        board.close()

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))


if __name__ == "__main__":
    main()