Ali Mohsin
feat: Add virtual try-on system components including DensePose, SMPL, and pix2pixHD models, rendering, and utilities.
5db43ff
import numpy as np
import torch
import torch.nn as nn
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from .ASAPNet_model import ASAPNetModel
import random
class ASAPNet_RGBA(ASAPNetModel):
def name(self):
return 'ASAPNet_RGBA'
def forward(self, heatmaps, image, infer=False):
# Encode Inputs
#input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
real_input = image
real_image = image[:, [0, 1, 2], :, :]
real_mask = image[:, [3], :, :]
# Fake Generation
fake_output = self.netG.forward(heatmaps)
fake_image = fake_output[:, [0, 1, 2], :, :]
fake_mask = fake_output[:, [3], :, :]
# Fake Detection and Loss
pred_fake_pool = self.discriminate(heatmaps, fake_output, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(heatmaps, real_input)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((heatmaps, fake_output), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
# GAN feature matching loss
loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i]) - 1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j],
pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
normalized = True
if not self.opt.no_vgg_loss:
if normalized:
loss_G_VGG = self.criterionVGG((fake_image * 0.5 + 1.0), #* fake_mask,
(real_image * 0.5 + 1.0) * real_mask) * self.opt.lambda_feat
else:
loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
loss_G_VGG += self.criterionVGG(real_mask, fake_mask) * self.opt.lambda_feat
local_loss=False
c=1.5
if local_loss:
local_loss_G_GAN, local_loss_G_GAN_Feat, local_loss_G_VGG, local_loss_D_real, local_loss_D_fake=self.localLoss(heatmaps,image,fake_output)
loss_G_GAN+=local_loss_G_GAN*c
loss_G_GAN_Feat+=local_loss_G_GAN_Feat*c
loss_G_VGG+=local_loss_G_VGG*c
loss_D_real+=local_loss_D_real*c
loss_D_fake+=local_loss_D_fake*c
# Only return the fake_B image if necessary to save BW
return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),
None if not infer else fake_output]
def localLoss(self,heatmaps, gt_output, fake_output):
heatmaps, gt_output, fake_output = self.RandomCrop([heatmaps, gt_output, fake_output],size=360)
heatmaps=nn.functional.interpolate(heatmaps, scale_factor=None,size=(1024,1024), mode='bilinear',align_corners=True)
gt_output=nn.functional.interpolate(gt_output, scale_factor=None,size=(1024,1024), mode='bilinear',align_corners=True)
fake_output=nn.functional.interpolate(fake_output, scale_factor=None,size=(1024,1024), mode='bilinear',align_corners=True)
real_input = gt_output
real_image = gt_output[:, [0, 1, 2], :, :]
real_mask = gt_output[:, [3], :, :]
fake_image = fake_output[:, [0, 1, 2], :, :]
fake_mask = fake_output[:, [3], :, :]
# Fake Detection and Loss
pred_fake_pool = self.discriminate(heatmaps, fake_output, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(heatmaps, real_input)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((heatmaps, fake_output), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
# GAN feature matching loss
loss_G_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i]) - 1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j],
pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
normalized = True
if not self.opt.no_vgg_loss:
if normalized:
loss_G_VGG = self.criterionVGG((fake_image * 0.5 + 1.0) * fake_mask,
(real_image * 0.5 + 1.0) * real_mask) * self.opt.lambda_feat
else:
loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
return [loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake]
def localVGG(self, img1, img2):
cropped1, cropped2 = self.RandomCrop([img1, img2],360)
return self.criterionVGG(cropped1, cropped2)
def RandomCrop(self, img_list, size):
_, _, h, w = img_list[0].shape
assert size<h and size<w
h_start = random.randint(0, h - size)
w_start = random.randint(0, w - size)
cropped_list=[]
for img in img_list:
cropped_img = img[:,:,h_start:h_start + size, w_start:w_start + size]
cropped_list.append(cropped_img)
return cropped_list
def inference(self, image=None):
# Encode Inputs
image = Variable(image) if image is not None else None
input_concat = image
if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
fake_image = self.netG.forward(input_concat)
return fake_image
class InferenceModel(ASAPNet_RGBA):
def forward(self, inp):
#label, inst = inp
#return self.inference(label, inst)
return self.inference(inp)