# Jittor_WGAN_GP / wgan_gp.py
# Uploaded by isLandLZ (revision 7b2d6c8)
import jittor as jt
from jittor import nn
import argparse
import os
import numpy as np
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
import cv2
import time
from jittor.dataset.dataset import ImageFolder
# Run on the GPU when jittor was built with CUDA support, otherwise fall back
# to the CPU rather than unconditionally requesting CUDA.
jt.flags.use_cuda = 1 if jt.has_cuda else 0
# Output directories for sample grids and model checkpoints.
# NOTE(review): the directory names say "celebA" but they are used for every --task.
save_img_path = './images_celebA'
save_model_path = './save_model_celebA'
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)
# Command-line configuration.
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=128, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# b2 is Adam's second-moment decay rate (its help text used to say "first-order").
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
parser.add_argument('--celebA_channels', type=int, default=3, help='CelebA 图像通道数')
parser.add_argument('--mnist_channels', type=int, default=1, help='MNIST 图像通道数')
parser.add_argument('--n_critic', type=int, default=5, help='每次生成器更新对应的判别器训练步数')
parser.add_argument('--clip_value', type=float, default=0.01, help='判别器权重的上下裁剪值')
parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')
parser.add_argument('--task', type=str, default='celebA', help='训练数据集类型')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='训练数据集地址')
opt = parser.parse_args()
print(opt)
# Per-sample image shape (C, H, W) shared by the generator and discriminator.
# MNIST is single-channel; any other task uses the celebA channel count.
# (Using celebA_channels unconditionally made --task MNIST build 3-channel
# models for 1-channel data.)
img_channels = opt.mnist_channels if opt.task == 'MNIST' else opt.celebA_channels
img_shape = (img_channels, opt.img_size, opt.img_size)
# Training-set loader
def DataLoader(dataclass, img_size, batch_size, train_dir):
    """Build the shuffled training dataset for the requested task.

    Args:
        dataclass: 'MNIST' or 'celebA'. Any other value makes the function
            prompt on stdin until one of those two names is entered.
        img_size: size passed to ``transform.Resize``.
        batch_size: number of images per batch.
        train_dir: dataset root (MNIST ``data_root`` or the ImageFolder directory).

    Returns:
        The configured jittor dataset. The training loop iterates it as
        ``(images, labels)`` batches.
    """
    # Re-prompt in a loop. The previous recursive retry discarded the result
    # of the recursive call and then returned an unbound ``train_loader``.
    while dataclass not in ('MNIST', 'celebA'):
        print("没有加载%s数据集的程序,请选择MNIST或者celebA!" % dataclass)
        dataclass = input("请输入:MNIST或者celebA:").strip()

    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        return MNIST(data_root=train_dir, train=True, transform=Transform).set_attrs(
            batch_size=batch_size, shuffle=True)

    # celebA: a plain image folder of RGB images, normalized per channel.
    Transform = transform.Compose([
        transform.Resize(size=img_size),
        transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    return ImageFolder(train_dir).set_attrs(
        transform=Transform, batch_size=batch_size, shuffle=True)
# Build the training dataset for the selected task from the CLI options.
dataloader = DataLoader(
    dataclass=opt.task,
    img_size=opt.img_size,
    batch_size=opt.batch_size,
    train_dir=opt.train_dir,
)
# Save images
def save_image(img, path, nrow=10):
    """Tile a batch of images into an ``nrow`` x ``nrow`` grid and write it to ``path``.

    The grid is filled column by column: image ``k`` goes to grid row
    ``k % nrow`` and grid column ``k // nrow``. Pixel values are min-max
    rescaled to [0, 255] over all images together, and the grid is written
    with OpenCV.

    Args:
        img: array-like of shape (N, C, H, W), with C == 1 or C == 3.
        path: output file path. cv2 picks the encoder from the extension.
        nrow: images per grid side. Only the first ``nrow * nrow`` images are
            used; missing cells are left black.

    Returns:
        The boolean returned by ``cv2.imwrite``.
    """
    img = np.asarray(img, dtype=np.float64)
    N, C, H, W = img.shape
    cells = nrow * nrow
    img = img[:cells]

    # Rescale to [0, 255]. A constant batch would otherwise divide by zero.
    if img.size:
        lo, hi = img.min(), img.max()
        img = (img - lo) / (hi - lo) * 255.0 if hi > lo else np.zeros_like(img)

    # Pad with black cells so fewer than nrow*nrow images still form a grid.
    if img.shape[0] < cells:
        pad = np.zeros((cells - img.shape[0], C, H, W), dtype=img.dtype)
        img = np.concatenate([img, pad], axis=0)

    # (col, row, C, H, W) -> (C, row, H, col, W) -> (C, nrow*H, nrow*W).
    # Transposing before flattening keeps each channel's pixels together;
    # reshaping the raw (N, C, H, W) buffer directly mixes channels of
    # different images whenever C > 1.
    grid = (img.reshape(nrow, nrow, C, H, W)
               .transpose(2, 1, 3, 0, 4)
               .reshape(C, nrow * H, nrow * W)
               .transpose(1, 2, 0))
    if C == 3:
        # NOTE: assumes RGB channel order (PIL-loaded data / generator trained
        # on it); cv2.imwrite expects BGR.
        grid = grid[..., ::-1]
    return cv2.imwrite(path, np.clip(np.rint(grid), 0, 255).astype(np.uint8))
# Generator
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to an image tensor.

    Args:
        latent_dim: length of the input noise vector. Defaults to
            ``opt.latent_dim``, read when the model is constructed.
        out_shape: per-sample output shape, e.g. (C, H, W). Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, latent_dim=None, out_shape=None):
        super(Generator, self).__init__()
        if latent_dim is None:
            latent_dim = opt.latent_dim
        self.img_shape = tuple(img_shape if out_shape is None else out_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers Linear -> [BatchNorm1d] -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's second positional parameter is
                # `eps`, so the original `BatchNorm1d(out_feat, 0.8)` set
                # eps=0.8 (possibly meant as a momentum). Kept explicitly so
                # training behaviour is unchanged.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def execute(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to images of shape (batch, *img_shape)."""
        img = self.model(z)
        return img.view((img.shape[0], *self.img_shape))
# Discriminator (critic)
class Discriminator(nn.Module):
    """Fully connected critic that assigns one unbounded score to each image.

    There is no output activation. The training loop uses the raw scores in
    the Wasserstein loss.

    Args:
        in_shape: per-sample input shape, e.g. (C, H, W). Defaults to the
            module-level ``img_shape``.
    """

    def __init__(self, in_shape=None):
        super(Discriminator, self).__init__()
        in_features = int(np.prod(img_shape if in_shape is None else in_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )

    def execute(self, img):
        """Flatten each sample and return its critic score, shape (batch, 1)."""
        img_flat = img.reshape((img.shape[0], -1))
        return self.model(img_flat)
# Weight of the gradient-penalty term in the critic loss.
lambda_gp = 10

# Build both networks (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

# One Adam optimizer per network, sharing the learning rate and betas from the CLI.
adam_betas = (opt.b1, opt.b2)
optimizer_G = jt.optim.Adam(generator.parameters(), lr=opt.lr, betas=adam_betas)
optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=adam_betas)
# Loss: WGAN-GP gradient penalty
def compute_gradient_penalty(D, real_samples, fake_samples, eps=1e-12):
    """Return the gradient penalty mean((||dD(x_hat)/dx_hat||_2 - 1)^2).

    ``x_hat`` are random per-sample interpolations between the real and fake
    samples.

    Args:
        D: critic network, called as ``D(x)``.
        real_samples: batch of real samples, batch dimension first.
        fake_samples: batch of generated samples, same shape as ``real_samples``.
        eps: small constant added under the square root. Without it a zero
            gradient gives sqrt(0), whose derivative is infinite, and the
            update turns into NaNs.

    Returns:
        A jittor Var holding the batch-mean penalty.
    """
    batch = real_samples.shape[0]
    # One interpolation coefficient per sample, broadcast over the remaining
    # dimensions, so inputs of any rank are supported.
    alpha_shape = (batch,) + (1,) * (len(real_samples.shape) - 1)
    alpha = jt.array(np.random.random(alpha_shape).astype('float32'))
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    d_interpolates = D(interpolates)
    gradients = jt.grad(d_interpolates, interpolates)
    gradients = gradients.reshape((batch, -1))
    grad_norm = jt.sqrt(gradients.sqr().sum(1) + eps)
    return (grad_norm - 1).sqr().mean()
# Training-loop state.
batches_done = 0  # advanced by opt.n_critic after each generator update
# Benchmark mode: with warmup_times >= 0 the loop times `run_times` iterations
# after `warmup_times` warm-up iterations, then exits. -1 means normal training.
warmup_times = -1
run_times = 3000
total_time = 0.
cnt = 0
# ----------
# Training
# ----------
for epoch in range(opt.n_epochs):
    sample_imgs = None  # latest fake batch with at least 25 images (5x5 grid)
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = jt.array(imgs).float32()
        # ---------------------
        # Train the discriminator (critic): one step per batch
        # ---------------------
        z = jt.array((np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))).astype('float32'))
        # optimizer_D only updates critic parameters, so stop gradients here
        # instead of back-propagating into the generator for nothing.
        fake_imgs = generator(z).stop_grad()
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
        # Wasserstein critic loss plus the weighted gradient penalty.
        d_loss = (- real_validity.mean() + fake_validity.mean() + lambda_gp * gradient_penalty)
        d_loss.sync()
        optimizer_D.step(d_loss)
        if fake_imgs.shape[0] >= 25:
            sample_imgs = fake_imgs
        # ---------------------
        # Train the generator: once every opt.n_critic batches
        # ---------------------
        if i % opt.n_critic == 0:
            fake_img = generator(z)
            fake_validityg = discriminator(fake_img)
            g_loss = -fake_validityg.mean()
            g_loss.sync()
            optimizer_G.step(g_loss)
            if warmup_times == -1:
                # .item() turns the single-element loss Vars into Python floats for %f.
                print(('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())))
                batches_done += opt.n_critic
        if warmup_times != -1:
            jt.sync_all()
            cnt += 1
            print(cnt)
            if cnt == warmup_times:
                jt.sync_all(True)
                sta = time.time()
            if cnt > warmup_times + run_times:
                jt.sync_all(True)
                total_time = time.time() - sta
                print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
                raise SystemExit(0)
    # Save one 5x5 sample grid per epoch. This no longer depends on a
    # hard-coded batch index that varies with dataset size and batch size.
    if warmup_times == -1 and sample_imgs is not None:
        save_image(sample_imgs.data[:25], ('%s/%d.png' % (save_img_path, batches_done)), nrow=5)
    # Checkpoint every 10 epochs, and always after the final epoch.
    if epoch % 10 == 0 or epoch == opt.n_epochs - 1:
        generator.save("%s/generator_%s.pkl" % (save_model_path, opt.task))
        discriminator.save("%s/discriminator_%s.pkl" % (save_model_path, opt.task))