| | from __future__ import print_function |
| | import argparse |
| | import os |
| | import random |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.parallel |
| | import torch.backends.cudnn as cudnn |
| | import torch.optim as optim |
| | import torch.utils.data |
| | import torchvision.datasets as dset |
| | import torchvision.transforms as transforms |
| | import torchvision.utils as vutils |
| |
|
| | try: |
| | from apex import amp |
| | except ImportError: |
| | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") |
| |
|
| |
|
# Command-line options. Numeric options are converted by argparse itself
# (type=int / type=float), so ``opt`` already holds typed values.
parser = argparse.ArgumentParser()
# ``choices`` lists exactly the dataset names the loading code handles, so a
# typo is rejected here with a usage message instead of crashing later.
parser.add_argument('--dataset', default='cifar10',
                    choices=['cifar10', 'lsun', 'mnist', 'imagenet', 'folder', 'lfw', 'fake'],
                    help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64, help='base number of generator feature maps')
parser.add_argument('--ndf', type=int, default=64, help='base number of discriminator feature maps')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)
| |
|
| |
|
# Create the folder that receives sample images and checkpoints. A folder
# that already exists is fine; any other failure (e.g. permission denied) is
# re-raised now instead of surfacing later as a failed save.
try:
    os.makedirs(opt.outf)
except OSError:
    if not os.path.isdir(opt.outf):
        raise

# Seed Python's and PyTorch's RNGs. Without --manualSeed the same fixed
# default seed is used on every run.
if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Let cuDNN benchmark and pick the fastest convolution algorithms.
# NOTE: this may make runs non-deterministic even with a fixed seed.
cudnn.benchmark = True
| |
|
| |
|
# Build the training dataset selected by --dataset. Every branch also sets
# ``nc``, the number of image channels the networks are built for.
# Normalize with mean 0.5 / std 0.5 maps ToTensor's [0, 1] pixels to [-1, 1],
# the same range as the generator's Tanh output.
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # A directory of per-class sub-folders, read with ImageFolder.
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    # Each comma-separated --classes entry is turned into its '_train' split name.
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    # Single-channel images, hence a one-channel normalisation and nc = 1.
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    # Randomly generated images (torchvision FakeData).
    # NOTE(review): no Normalize here, so pixels stay in [0, 1] rather than
    # [-1, 1] like the other datasets -- confirm this is intended.
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3
else:
    # Without this branch an unknown name would leave ``dataset`` unbound and
    # fail below with an unhelpful NameError.
    raise ValueError('Unknown dataset %r; expected one of: cifar10, lsun, mnist, '
                     'imagenet, folder, lfw, fake' % (opt.dataset,))

# Explicit check rather than ``assert`` so it still runs under ``python -O``.
if len(dataset) == 0:
    raise RuntimeError('Dataset %r at %r contains no samples' % (opt.dataset, opt.dataroot))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
| |
|
# All models and tensors live on the first CUDA device; there is no CPU fallback.
device = torch.device("cuda:0")

# Plain-int copies of the network-size options, read as globals by the model
# classes and the training code.
ngpu, nz, ngf, ndf = (int(value) for value in (opt.ngpu, opt.nz, opt.ngf, opt.ndf))
| |
|
| |
|
| | |
def weights_init(m):
    """Re-initialise a single module in place with the DCGAN scheme.

    Intended for ``nn.Module.apply``, which calls it once per submodule.
    Modules whose class name contains ``Conv`` (both ``Conv2d`` and
    ``ConvTranspose2d``) get their weight drawn from N(0, 0.02). Modules whose
    class name contains ``BatchNorm`` get their weight drawn from N(1, 0.02)
    and their bias set to zero. All other modules are left untouched. Conv
    biases are not touched.

    Args:
        m: the module to (possibly) re-initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
| |
|
| |
|
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent code into an ``nc``-channel image.

    Architecture:
    - A 4x4 transposed convolution (stride 1, no padding) lifts the
      ``nz``-channel input to ``ngf * 8`` maps.
    - Three 4x4 stride-2 transposed convolutions each double the spatial size
      while narrowing the width ngf*8 -> ngf*4 -> ngf*2 -> ngf. Every
      convolution so far is followed by BatchNorm and ReLU.
    - A last stride-2 transposed convolution emits ``nc`` channels, squashed
      into [-1, 1] by Tanh.

    Fed a (batch, nz, 1, 1) tensor, as this script does, the output is
    (batch, nc, 64, 64). ``nz``, ``ngf`` and ``nc`` are module-level globals
    read when the network is constructed.

    Args:
        ngpu: number of GPUs. With more than one and a CUDA input, the forward
            pass is split over devices ``0 .. ngpu-1`` via data_parallel.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]
        # Layer order fixes the ``main.<idx>`` keys used by saved checkpoints.
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Generate images, spreading the batch over GPUs when ngpu > 1."""
        spread = input.is_cuda and self.ngpu > 1
        if not spread:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
| |
|
| |
|
# Build the generator on the training device and apply the DCGAN init
# (Module.to and Module.apply both return the module itself). If --netG names
# a checkpoint, its weights then replace the fresh ones.
netG = Generator(ngpu).to(device).apply(weights_init)
resume_path_g = opt.netG
if resume_path_g:
    netG.load_state_dict(torch.load(resume_path_g))
print(netG)
| |
|
| |
|
class Discriminator(nn.Module):
    """DCGAN discriminator: scores ``nc``-channel images with one raw logit each.

    Architecture:
    - Four 4x4 stride-2 convolutions halve the spatial size while widening
      ndf -> ndf*2 -> ndf*4 -> ndf*8. Each uses LeakyReLU(0.2), and every one
      except the first also has BatchNorm.
    - A final 4x4 stride-1 unpadded convolution reduces to a single channel.

    There is no Sigmoid, so the scores are unnormalised logits. For a 64x64
    input the last map is 1x1, giving one score per sample.
    NOTE(review): other --imageSize values yield more than one score per sample.

    ``nc`` and ``ndf`` are module-level globals read when the network is built.

    Args:
        ngpu: number of GPUs. With more than one and a CUDA input, the forward
            pass is split over devices ``0 .. ngpu-1`` via data_parallel.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        widths = [ndf, ndf * 2, ndf * 4, ndf * 8]
        # Layer order fixes the ``main.<idx>`` keys used by saved checkpoints.
        layers = [
            nn.Conv2d(nc, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the discriminator logits flattened to a 1-D tensor."""
        if not (input.is_cuda and self.ngpu > 1):
            scores = self.main(input)
        else:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return scores.view(-1)
| |
|
| |
|
# Build the discriminator on the training device and apply the DCGAN init
# (Module.to and Module.apply both return the module itself). If --netD names
# a checkpoint, its weights then replace the fresh ones.
netD = Discriminator(ngpu).to(device).apply(weights_init)
resume_path_d = opt.netD
if resume_path_d:
    netD.load_state_dict(torch.load(resume_path_d))
print(netD)
| |
|
# The discriminator ends in a Conv2d with no Sigmoid, so it emits logits and
# the sigmoid is folded into the loss (numerically more stable).
criterion = nn.BCEWithLogitsLoss()

# One fixed batch of latent vectors, reused for every sample snapshot so the
# saved images are comparable across training.
fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)

# BCE targets. These must be floats: recent PyTorch releases infer the dtype
# of torch.full from its fill value, and an integer target tensor is rejected
# by BCEWithLogitsLoss.
real_label = 1.0
fake_label = 0.0

# Adam with the DCGAN momentum setting (beta1 from --beta1, beta2 = 0.999).
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# num_losses=3 gives each backward pass of the training loop (D on real,
# D on fake, G) its own loss scaler, selected there via loss_id 0/1/2.
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
| |
|
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        # ---------------------------------------------------------------
        # (1) Discriminator step: push D(x) towards "real" and D(G(z))
        #     towards "fake".
        # ---------------------------------------------------------------
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # Explicit float dtype: BCEWithLogitsLoss needs floating-point targets,
        # and torch.full would otherwise infer an integer dtype from an int label.
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        # netD returns logits; apply sigmoid so the logged D(x) is a probability.
        D_x = output.detach().sigmoid().mean().item()

        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # Detach so this loss does not backpropagate into the generator.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.detach().sigmoid().mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        # ---------------------------------------------------------------
        # (2) Generator step: score the same fakes against "real" labels,
        #     i.e. maximise log(D(G(z))).
        # ---------------------------------------------------------------
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.detach().sigmoid().mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            # Sampling only: no autograd graph is needed for the snapshot.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # Checkpoint both networks at the end of every epoch.
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
| |
|
| |
|
| |
|