|
|
import jittor as jt |
|
|
from jittor import nn |
|
|
import argparse |
|
|
import os |
|
|
import numpy as np |
|
|
from jittor.dataset.mnist import MNIST |
|
|
import jittor.transform as transform |
|
|
import cv2 |
|
|
import time |
|
|
from jittor.dataset.dataset import ImageFolder |
|
|
|
|
|
# Run all Jittor ops on the GPU (requires a CUDA-enabled jittor build).
jt.flags.use_cuda = 1
|
|
|
|
|
# Output locations: generated sample images and model checkpoints.
save_img_path = './images_celebA'
save_model_path = './save_model_celebA'

# Create both directories up front; exist_ok makes reruns idempotent.
for _out_dir in (save_img_path, save_model_path):
    os.makedirs(_out_dir, exist_ok=True)
|
|
|
|
|
# Command-line configuration for the WGAN-GP training run.
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=128, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# Fixed help text: the original duplicated b1's description ("一阶"/first-order);
# b2 is Adam's second-order momentum decay.
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
parser.add_argument('--celebA_channels', type=int, default=3, help='图像通道数')
parser.add_argument('--mnist_channels', type=int, default=1, help='图像通道数')
parser.add_argument('--n_critic', type=int, default=5, help='每个迭代器的鉴别器训练步骤数')
# Fixed help text: the original was garbled machine translation ("光盘" =
# optical disc for "disc." = discriminator). This weight-clipping bound is a
# WGAN leftover — the WGAN-GP loop below never reads it; kept for
# interface compatibility.
parser.add_argument('--clip_value', type=float, default=0.01, help='判别器权重的上下剪辑值')
parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')
parser.add_argument('--task', type=str, default='celebA', help='训练数据集类型')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='训练数据集地址')
opt = parser.parse_args()
print(opt)

# (channels, height, width) of real/generated images. NOTE(review): the shape
# is keyed off celebA_channels even when --task selects MNIST — confirm intent.
img_shape = (opt.celebA_channels, opt.img_size, opt.img_size)
|
|
|
|
|
|
|
|
def DataLoader(dataclass, img_size, batch_size, train_dir):
    """Build a shuffled training loader for MNIST or celebA.

    Args:
        dataclass: dataset name, 'MNIST' or 'celebA'. Any other value prints
            an error, prompts interactively for a valid name, and retries.
        img_size: side length every image is resized to.
        batch_size: number of samples per batch.
        train_dir: dataset root directory.

    Returns:
        A jittor dataset iterator yielding (images, labels) batches.
    """
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        train_loader = MNIST(data_root=train_dir, train=True, transform=Transform).set_attrs(batch_size=batch_size, shuffle=True)
    elif dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        train_loader = ImageFolder(train_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
    else:
        print("没有加载%s数据集的程序,请选择MNIST或者celebA!" % dataclass)
        dataclass = input("请输入:MNIST或者celebA:")
        # Bug fix: the original discarded the recursive retry's result, so
        # `train_loader` was unbound here and the final return raised
        # NameError after any invalid dataset name.
        return DataLoader(dataclass, img_size, batch_size, train_dir)

    return train_loader
|
|
|
|
|
# Build the training loader once at import time from the CLI options.
dataloader = DataLoader(opt.task,opt.img_size,opt.batch_size,opt.train_dir)
|
|
|
|
|
|
|
|
def save_image(img, path, nrow=10):
    """Tile a batch of images into an nrow x nrow grid and write it via cv2.

    img: numpy array of shape (N, C, H, W). The reshape below only tiles
         correctly when N == nrow * nrow (the caller passes 25 images with
         nrow=5) — TODO confirm for other call sites.
    path: output file path handed to cv2.imwrite.
    nrow: number of images per grid row/column.
    """
    N,C,W,H = img.shape
    # Fold the batch dimension into one long strip: (C, W*nrow*nrow, H).
    img2=img.reshape([-1,W*nrow*nrow,H])
    # Start the grid with the first column of nrow images...
    img=img2[:,:W*nrow,:]
    for i in range(1,nrow):
        # ...and append each remaining group of nrow images along axis 2.
        img=np.concatenate([img,img2[:,W*nrow*i:W*nrow*(i+1),:]],axis=2)
    min_=img.min()
    max_=img.max()
    # Rescale the grid to the 0..255 range expected for image files
    # (NOTE(review): divides by zero if the batch is constant — max_ == min_).
    img=(img-min_)/(max_-min_)*255
    # Channels-first -> channels-last, as cv2.imwrite expects (H, W, C).
    img=img.transpose((1,2,0))
    cv2.imwrite(path,img)
|
|
|
|
|
|
|
|
class Generator(nn.Module):
    """MLP generator: latent vector z -> image of shape `img_shape`.

    Four Linear+LeakyReLU stages (BatchNorm on all but the first), then a
    Linear projection to the flattened image size and a Tanh squashing
    outputs into [-1, 1] to match the normalized training data.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # (in_features, out_features, apply batch-norm?)
        stage_specs = [
            (opt.latent_dim, 128, False),
            (128, 256, True),
            (256, 512, True),
            (512, 1024, True),
        ]
        layers = []
        for in_feat, out_feat, use_bn in stage_specs:
            layers.append(nn.Linear(in_feat, out_feat))
            if use_bn:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(1024, int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def execute(self, z):
        # Map the flat latent batch to images: (batch, *img_shape).
        flat = self.model(z)
        return flat.view((flat.shape[0], *img_shape))
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    """MLP critic: image -> unbounded scalar score (WGAN critic, no sigmoid)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        n_features = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )

    def execute(self, img):
        # Flatten each image to a vector before the MLP.
        flat = img.reshape((img.shape[0], -1))
        return self.model(flat)
|
|
|
|
|
# Gradient-penalty weight (lambda in the WGAN-GP objective).
lambda_gp = 10


generator = Generator()
discriminator = Discriminator()


# Separate Adam optimizers for the two networks, sharing the CLI learning
# rate and beta settings.
optimizer_G = jt.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
|
|
|
|
|
|
|
|
def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP gradient penalty: mean squared distance of the critic's
    gradient norm from 1, measured on random real/fake interpolations.

    D: critic network; real_samples / fake_samples: batches of equal shape.
    Returns a scalar jittor Var.
    """
    batch = real_samples.shape[0]
    # One interpolation coefficient per sample, broadcast over C/H/W.
    alpha = jt.array(np.random.random((batch, 1, 1, 1)).astype('float32'))
    mixed = alpha * real_samples + (1 - alpha) * fake_samples
    critic_score = D(mixed)
    # Gradient of the critic output w.r.t. the interpolated inputs,
    # flattened per sample.
    grads = jt.grad(critic_score, mixed).reshape((batch, -1))
    grad_norm = jt.sqrt(grads.sqr().sum(1))
    return ((grad_norm - 1).sqr()).mean()
|
|
|
|
|
# Counter of generator updates, used only to name saved sample images.
batches_done = 0
# Benchmark controls: warmup_times == -1 disables timing mode entirely;
# otherwise the training loop times `run_times` iterations after
# `warmup_times` warm-up iterations and then exits.
warmup_times = -1
run_times = 3000
total_time = 0.
# Iteration counter for the benchmark mode above.
cnt = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Main WGAN-GP training loop: the critic is updated every batch, the
# generator once every opt.n_critic batches.
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = jt.array(imgs).float32()

        # ---- Critic (discriminator) step, every batch ----
        # Sample one latent vector per real image in the batch.
        z = jt.array((np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))).astype('float32'))
        fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
        # WGAN-GP critic loss: raise scores on real, lower on fake, plus the
        # weighted gradient penalty.
        d_loss = (- real_validity.mean() + fake_validity.mean() + lambda_gp * gradient_penalty)
        d_loss.sync()
        optimizer_D.step(d_loss)

        # ---- Generator step, every n_critic batches ----
        if ((i % opt.n_critic) == 0):
            # Reuse the same z; regenerate so the graph is fresh after the
            # critic update.
            fake_img = generator(z)
            fake_validityg = discriminator(fake_img)
            g_loss = -fake_validityg.mean()
            g_loss.sync()
            optimizer_G.step(g_loss)

            if warmup_times==-1:
                # i == 0 satisfies i % n_critic == 0, so g_loss is always
                # bound before this first print.
                print(('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)))

                # Magic batch index — presumably the last batch of an epoch
                # for this dataset/batch size; TODO confirm and derive from
                # len(dataloader) instead.
                if ( i == 1583 ):
                    save_image(fake_imgs.data[:25], ('%s/%d.png' % (save_img_path, batches_done)), nrow=5)
                batches_done += opt.n_critic

        # ---- Benchmark mode (only when warmup_times >= 0) ----
        if warmup_times!=-1:
            jt.sync_all()
            cnt += 1
            print(cnt)
            if cnt == warmup_times:
                # Warm-up finished: hard-sync, then start the clock.
                jt.sync_all(True)
                sta = time.time()
            if cnt > warmup_times + run_times:
                jt.sync_all(True)
                total_time = time.time() - sta
                print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
                exit(0)

    # Checkpoint both networks every 10 epochs.
    if epoch % 10 == 0:
        generator.save("%s/generator_%s.pkl"%(save_model_path, opt.task))
        discriminator.save("%s/discriminator_%s.pkl"%(save_model_path, opt.task))