Yash Nagraj commited on
Commit ·
275907d
1
Parent(s): 89e0ef4
Add the training scripts for cloud training
Browse files- Discriminators.py +0 -0
- Generators.py +0 -0
- models.py +121 -0
- train.py +102 -0
- utils.py +109 -0
Discriminators.py
DELETED
|
File without changes
|
Generators.py
DELETED
|
File without changes
|
models.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with instance norm and
    ReLU, plus a skip connection from the input to the output.

    Args:
        input_channels: number of channels of the input (and output) feature map.
    """

    def __init__(self, input_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.instanceNorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return x + F(x); input shape is preserved."""
        # Bug fix: torch.Tensor has no .copy() method (that crashed at runtime);
        # .clone() is the tensor equivalent.
        original = x.clone()
        x = self.conv1(x)
        x = self.instanceNorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instanceNorm(x)
        return original + x
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ContractingBlock(nn.Module):
    """Downsampling block: a stride-2 conv that doubles the channel count,
    optional instance norm, then ReLU or LeakyReLU(0.2).

    Args:
        input_channels: channels of the input; the output has 2x as many.
        use_bn: whether to apply instance normalization after the conv.
        kernel_size: conv kernel size (3 for the generator, 4 for the discriminator).
        activation: 'relu' for ReLU, anything else for LeakyReLU(0.2).
    """

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu') -> None:
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            # Bug fix: the conv doubles the channels, so the norm layer must be
            # built for input_channels * 2 (the original passed input_channels).
            self.normalization = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            # Bug fix: InstanceNorm2d is not in-place; the original called it
            # and discarded the result, so normalization never took effect.
            x = self.normalization(x)
        x = self.activation(x)
        return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ExpandingBlock(nn.Module):
    """Upsampling block: a stride-2 transposed conv that halves the channel
    count, optional instance norm, then ReLU.

    Args:
        input_channels: channels of the input; the output has half as many.
        use_bn: whether to apply instance normalization after the conv.
    """

    def __init__(self, input_channels, use_bn=True) -> None:
        super(ExpandingBlock, self).__init__()
        half_channels = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(input_channels, half_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        if use_bn:
            self.normalization = nn.InstanceNorm2d(half_channels)
        self.use_bn = use_bn
        self.activation = nn.ReLU()

    def forward(self, x):
        out = self.conv1(x)
        if self.use_bn:
            out = self.normalization(out)
        return self.activation(out)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class FeatureMapBlock(nn.Module):
    """Maps between channel counts with a single 7x7 reflection-padded conv;
    spatial dimensions are unchanged. Used as the first and last layer of the
    generator and the first layer of the discriminator."""

    def __init__(self, input_channels, output_channels) -> None:
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=7, padding=3, padding_mode='reflect')

    def forward(self, x):
        return self.conv(x)
|
| 65 |
+
|
| 66 |
+
class Generator(nn.Module):
    """CycleGAN generator: a feature-map lift, two downsampling blocks, nine
    residual blocks, two upsampling blocks, a feature-map projection, and a
    final tanh squashing the output to [-1, 1].

    Args:
        input_channels: channels of the source-domain images.
        output_channels: channels of the target-domain images.
        hidden_dim: base channel width (doubled by each contracting block).
    """

    def __init__(self, input_channels, output_channels, hidden_dim=64) -> None:
        super(Generator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_dim)
        self.contract1 = ContractingBlock(hidden_dim)
        self.contract2 = ContractingBlock(hidden_dim * 2)
        # After two contractions the width is hidden_dim * 4.
        res_mult = 4
        self.res0 = ResidualBlock(hidden_dim * res_mult)
        self.res1 = ResidualBlock(hidden_dim * res_mult)
        self.res2 = ResidualBlock(hidden_dim * res_mult)
        self.res3 = ResidualBlock(hidden_dim * res_mult)
        self.res4 = ResidualBlock(hidden_dim * res_mult)
        self.res5 = ResidualBlock(hidden_dim * res_mult)
        self.res6 = ResidualBlock(hidden_dim * res_mult)
        self.res7 = ResidualBlock(hidden_dim * res_mult)
        self.res8 = ResidualBlock(hidden_dim * res_mult)
        self.expand1 = ExpandingBlock(hidden_dim * res_mult)
        self.expand2 = ExpandingBlock(hidden_dim * 2)
        self.downfeature = FeatureMapBlock(hidden_dim, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Translate a batch of images; output shape matches the input's
        spatial size with `output_channels` channels, values in [-1, 1]."""
        out = self.upfeature(x)
        out = self.contract1(out)
        out = self.contract2(out)
        # Residual bottleneck: nine blocks applied in sequence.
        for res_block in (self.res0, self.res1, self.res2, self.res3, self.res4,
                          self.res5, self.res6, self.res7, self.res8):
            out = res_block(out)
        out = self.expand1(out)
        out = self.expand2(out)
        return self.tanh(self.downfeature(out))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class Discriminator(nn.Module):
    """PatchGAN discriminator: a feature-map lift, three LeakyReLU contracting
    blocks (the first without normalization), and a final 1x1 conv producing a
    one-channel map of realness scores.

    Args:
        input_channels: channels of the images being judged.
        hidden_channels: base channel width (doubled by each contracting block).
    """

    def __init__(self, input_channels, hidden_channels=64) -> None:
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        self.conv = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        """Return the per-patch realness map for a batch of images."""
        out = self.upfeature(x)
        out = self.contract1(out)
        out = self.contract2(out)
        out = self.contract3(out)
        return self.conv(out)
|
train.py
CHANGED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from utils import *
|
| 4 |
+
from models import Generator , Discriminator
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
|
| 7 |
+
# Loss functions: MSE for the LSGAN adversarial objective, L1 for the
# identity and cycle-consistency terms.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Training hyperparameters.
n_epochs = 60
dim_A = 3            # channels of domain-A images
dim_B = 3            # channels of domain-B images
display_step = 200   # steps between logging / visualization / checkpointing
batch_size = 1
lr = 0.0002
load_shape = 286     # resize to this before random-cropping
target_shape = 256   # final training resolution
# Bug fix: the original hard-coded 'cuda', which fails on CPU-only machines;
# fall back to CPU when CUDA is unavailable.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = ImageDataset("horse2zebra", transform=transform)

# Both generators share a single optimizer; each discriminator gets its own.
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5, 0.999))
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))


# Initialize every conv / norm weight in all four networks.
gen_AB = gen_AB.apply(weights_init)
gen_BA = gen_BA.apply(weights_init)
disc_A = disc_A.apply(weights_init)
disc_B = disc_B.apply(weights_init)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def train():
    """Run the full CycleGAN training loop.

    Alternates discriminator-A, discriminator-B, and joint generator updates on
    each batch; every `display_step` steps it prints the running mean losses,
    shows real/fake image grids, and saves a full checkpoint.
    """
    import os  # local import: only needed here for the checkpoint directory

    mean_gen_loss = 0
    mean_disc_loss = 0
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    cur_step = 0

    for epoch in range(n_epochs):
        for real_A, real_B in tqdm(dataloader):
            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            cur_batch_size = len(real_A)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # --- Discriminator A: real A images vs. fakes from the B->A generator.
            disc_A_opt.zero_grad()
            with torch.no_grad():
                # Bug fix: fake A images are produced from *real B* (B -> A
                # direction); the original fed real_A into gen_BA.
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            disc_A_loss.backward(retain_graph=True)
            disc_A_opt.step()

            # --- Discriminator B: real B images vs. fakes from the A->B generator.
            disc_B_opt.zero_grad()
            with torch.no_grad():
                # Bug fix: fake B images come from *real A* (A -> B direction);
                # the original fed real_B into gen_AB.
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward(retain_graph=True)
            disc_B_opt.step()

            # --- Generators: both directions updated through one optimizer.
            gen_opt.zero_grad()
            # Bug fix: the original line was a syntax error (`adv_criterion=`
            # with no value) and would not even parse.
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_B, disc_A,
                adv_criterion=adv_criterion,
                identity_criterion=recon_criterion,
                cycle_criterion=recon_criterion,
            )
            gen_loss.backward()
            gen_opt.step()

            mean_gen_loss += gen_loss.item() / display_step
            mean_disc_loss += disc_A_loss.item() / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch: {epoch} | Step: {cur_step} | Gen_loss: {mean_gen_loss} | Disc_loss: {mean_disc_loss} |")
                show_tensor_images(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
                show_tensor_images(torch.cat([fake_A, fake_B]), size=(dim_B, target_shape, target_shape))
                mean_gen_loss = 0
                mean_disc_loss = 0
                # Robustness: create the checkpoint directory so torch.save
                # does not fail with FileNotFoundError on a fresh machine.
                os.makedirs("checkpoints", exist_ok=True)
                torch.save({
                    'gen_AB': gen_AB,
                    'gen_BA': gen_BA,
                    'gen_opt': gen_opt,
                    'disc_A': disc_A,
                    'disc_A_opt': disc_A_opt,
                    'disc_B': disc_B,
                    'disc_B_opt': disc_B_opt
                }, f"checkpoints/cycleGAN_{cur_step}.pth")

            cur_step += 1
|
| 100 |
+
|
| 101 |
+
# Entry point: run the full training loop only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    train()
|
utils.py
CHANGED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchvision.utils import make_grid
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import glob
|
| 8 |
+
import os
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Visualize a batch of images: given an image tensor, the number of images
    to show, and the per-image size, plot them in a uniform grid.
    '''
    # Map from the generator's tanh range [-1, 1] back to [0, 1] for display.
    rescaled = (image_tensor + 1) / 2
    flat = rescaled.detach().cpu().view(-1, *size)
    grid = make_grid(flat[:num_images], nrow=5)
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ImageDataset(Dataset):
    """Unaligned two-domain image dataset over <root>/<mode>A and
    <root>/<mode>B. Each item is a randomly paired (A, B) tuple of transformed
    images rescaled to [-1, 1]."""

    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform
        pattern_A = os.path.join(root, '%sA' % mode) + '/*.*'
        pattern_B = os.path.join(root, '%sB' % mode) + '/*.*'
        self.files_A = sorted(glob.glob(pattern_A))
        self.files_B = sorted(glob.glob(pattern_B))
        # Keep files_A as the smaller side so every A image is used per pass.
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        self.new_perm()
        assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!"

    def new_perm(self):
        # Draw a fresh random pairing of B images onto A indices.
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def __getitem__(self, index):
        path_A = self.files_A[index % len(self.files_A)]
        path_B = self.files_B[self.randperm[index]]
        item_A = self.transform(Image.open(path_A))
        item_B = self.transform(Image.open(path_B))
        # Grayscale inputs: replicate the single channel up to 3 channels.
        if item_A.shape[0] != 3:
            item_A = item_A.repeat(3, 1, 1)
        if item_B.shape[0] != 3:
            item_B = item_B.repeat(3, 1, 1)
        # Re-shuffle the pairing once a full pass has been served.
        if index == len(self) - 1:
            self.new_perm()
        # Old versions of PyTorch didn't support normalization for different-channeled images
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def weights_init(m):
    """Weight initializer for use with `nn.Module.apply`: draws conv /
    transposed-conv and batch-norm weights from N(0, 0.02) and zeroes
    batch-norm biases (the standard DCGAN/CycleGAN initialization).

    Args:
        m: a submodule, supplied by `apply`.
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # Bug fix: the original used normal_(m.weight, 1.0, 0.2) — mean and std
        # transposed/scaled; standard GAN init is mean 0.0, std 0.02, matching
        # the BatchNorm branch below.
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    """Discriminator loss: how well disc_X separates real X images from fakes.

    Args:
        real_X: batch of real images from domain X.
        fake_X: batch of generated images for domain X (detached here so the
            generator receives no gradient from this loss).
        disc_X: the discriminator for domain X.
        adv_criterion: adversarial criterion (MSELoss for LSGAN).

    Returns:
        The mean of the real-half and fake-half adversarial losses.
    """
    real_pred = disc_X(real_X.detach())
    disc_real_loss = adv_criterion(real_pred, torch.ones_like(real_pred))
    # Bug fix: `.deatch()` typo raised AttributeError at runtime; detach the
    # *input* fakes so gradient cannot flow back into the generator.
    fake_pred = disc_X(fake_X.detach())
    # Bug fix: the original detached fake_pred itself, which would have cut the
    # discriminator's own gradient on the fake half of the loss.
    disc_fake_loss = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
    disc_loss = (disc_real_loss + disc_fake_loss) / 2
    return disc_loss
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    """Adversarial loss for gen_XY: its fakes should be scored as real by
    disc_Y. Returns the loss and the generated fake_Y batch (the caller
    reuses fake_Y for the cycle-consistency term)."""
    fake_Y = gen_XY(real_X.detach())
    disc_pred = disc_Y(fake_Y)
    real_target = torch.ones_like(disc_pred)
    return adv_criterion(disc_pred, real_target), fake_Y
|
| 76 |
+
|
| 77 |
+
def get_identity_loss(real_X, gen_YX, identity_criterion):
    """Identity loss: gen_YX applied to an image already in domain X should
    change it as little as possible. Returns the loss and the identity-mapped
    batch."""
    identity_X = gen_YX(real_X)
    return identity_criterion(identity_X, real_X), identity_X
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    """Cycle-consistency loss: translating fake_Y back through gen_YX should
    recover the original real_X. Returns the loss and the reconstructed
    batch."""
    cycle_X = gen_YX(fake_Y)
    return cycle_criterion(cycle_X, real_X), cycle_X
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_B, disc_A, adv_criterion, cycle_criterion, identity_criterion, lambda_identity=0.2, lambda_cycle=10):
    """Total generator loss for both directions: adversarial terms plus
    weighted identity and cycle-consistency terms.

    Returns:
        (gen_loss, fake_A, fake_B): the scalar loss and both fake batches.
    """
    # Adversarial terms, one per translation direction.
    adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    gen_adversarial_loss = adv_loss_BA + adv_loss_AB

    # Identity terms: each generator should leave its target domain unchanged.
    identity_loss_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    identity_loss_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)
    gen_identity_loss = identity_loss_A + identity_loss_B

    # Cycle terms: the A -> B -> A and B -> A -> B round trips should recover
    # the originals.
    cycle_loss_BA, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_loss_AB, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    gen_cycle_loss = cycle_loss_BA + cycle_loss_AB

    gen_loss = lambda_identity * gen_identity_loss + lambda_cycle * gen_cycle_loss + gen_adversarial_loss
    return gen_loss, fake_A, fake_B
|