import jittor as jt
from jittor import nn
import argparse
import os
import numpy as np
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
import cv2
import time
from jittor.dataset.dataset import ImageFolder
jt.flags.use_cuda = 1
save_img_path = './images_celebA'
save_model_path = './save_model_celebA'
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of training epochs')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='Adam: decay of first-order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='Adam: decay of second-order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
parser.add_argument('--celebA_channels', type=int, default=3, help='number of celebA image channels')
parser.add_argument('--mnist_channels', type=int, default=1, help='number of MNIST image channels')
parser.add_argument('--n_critic', type=int, default=5, help='number of discriminator (critic) steps per generator step')
parser.add_argument('--clip_value', type=float, default=0.01, help='lower and upper clip value for discriminator weights (not used by WGAN-GP)')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image samples')
parser.add_argument('--task', type=str, default='celebA', help='training dataset type: MNIST or celebA')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='path to the training dataset')
opt = parser.parse_args()
print(opt)
# Image shape depends on the selected task: 3 channels for celebA, 1 channel for MNIST (Gray transform below)
channels = opt.celebA_channels if opt.task == 'celebA' else opt.mnist_channels
img_shape = (channels, opt.img_size, opt.img_size)
# Training set loader
def DataLoader(dataclass, img_size, batch_size, train_dir):
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        train_loader = MNIST(data_root=train_dir, train=True, transform=Transform) \
            .set_attrs(batch_size=batch_size, shuffle=True)
    elif dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        train_loader = ImageFolder(train_dir) \
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
    else:
        print("No loader is available for the %s dataset, please choose MNIST or celebA!" % dataclass)
        dataclass = input("Please enter MNIST or celebA: ")
        return DataLoader(dataclass, img_size, batch_size, train_dir)
    return train_loader

dataloader = DataLoader(opt.task, opt.img_size, opt.batch_size, opt.train_dir)
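# Note (assumption based on the transforms above): with ImageNormalize(mean=0.5, std=0.5) the loader
# is expected to yield image batches scaled to roughly [-1, 1], matching the Tanh output of the generator.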
# Save a grid of nrow x nrow images to a single file
def save_image(img, path, nrow=10):
    N, C, W, H = img.shape
    img2 = img.reshape([-1, W * nrow * nrow, H])
    img = img2[:, :W * nrow, :]
    # Stitch the remaining rows of the grid along the width axis
    for i in range(1, nrow):
        img = np.concatenate([img, img2[:, W * nrow * i:W * nrow * (i + 1), :]], axis=2)
    # Rescale to [0, 255] for cv2.imwrite
    min_ = img.min()
    max_ = img.max()
    img = (img - min_) / (max_ - min_) * 255
    img = img.transpose((1, 2, 0))
    cv2.imwrite(path, img)
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh())

    def execute(self, z):
        img = self.model(z)
        img = img.view((img.shape[0], *img_shape))
        return img
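# The generator is a 4-block MLP that maps a latent vector z of size opt.latent_dim to a flat vector
# of np.prod(img_shape) values and reshapes it into an image; Tanh keeps outputs in [-1, 1].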
# Discriminator (critic)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        )

    def execute(self, img):
        img_flat = img.reshape((img.shape[0], -1))
        validity = self.model(img_flat)
        return validity
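# Note: unlike a vanilla GAN discriminator, the critic ends with a plain Linear layer (no Sigmoid);
# WGAN-GP trains on unbounded scalar scores rather than probabilities.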
lambda_gp = 10  # weight of the gradient penalty term in the critic loss
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
# Optimizers
optimizer_G = jt.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# Loss function: gradient penalty term of the WGAN-GP objective
def compute_gradient_penalty(D, real_samples, fake_samples):
    # Random interpolation coefficient, one per sample
    alpha = jt.array(np.random.random((real_samples.shape[0], 1, 1, 1)).astype('float32'))
    # Random points on the line between real and fake samples
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    d_interpolates = D(interpolates)
    # Gradient of the critic output w.r.t. the interpolated images
    gradients = jt.grad(d_interpolates, interpolates)
    gradients = gradients.reshape((gradients.shape[0], -1))
    # Penalize deviation of the gradient norm from 1
    gp = ((jt.sqrt(gradients.sqr().sum(1)) - 1).sqr()).mean()
    return gp
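# For reference, the losses assembled in the training loop below implement the WGAN-GP objective:
#   L_D = E[D(G(z))] - E[D(x)] + lambda_gp * E[(||grad D(x_hat)||_2 - 1)^2]
#   L_G = -E[D(G(z))]
# where x_hat are the random interpolates produced by compute_gradient_penalty.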
batches_done = 0
warmup_times = -1
run_times = 3000
total_time = 0.
cnt = 0
# ----------
#  Training
# ----------
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = jt.array(imgs).float32()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Sample noise and generate a batch of fake images
        z = jt.array(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)).astype('float32'))
        fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
        d_loss = -real_validity.mean() + fake_validity.mean() + lambda_gp * gradient_penalty
        d_loss.sync()
        optimizer_D.step(d_loss)

        # -----------------
        #  Train Generator (once every opt.n_critic discriminator steps)
        # -----------------
        if i % opt.n_critic == 0:
            fake_img = generator(z)
            fake_validityg = discriminator(fake_img)
            g_loss = -fake_validityg.mean()
            g_loss.sync()
            optimizer_G.step(g_loss)
            if warmup_times == -1:
                print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]'
                      % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data))
                # if batches_done % opt.sample_interval == 0:
                if i == 1583:  # last batch index depends on opt.batch_size; saves one sample grid per epoch
                    save_image(fake_imgs.data[:25], '%s/%d.png' % (save_img_path, batches_done), nrow=5)
                batches_done += opt.n_critic

        # Benchmark mode: time run_times iterations after warmup_times warmup iterations
        if warmup_times != -1:
            jt.sync_all()
            cnt += 1
            print(cnt)
            if cnt == warmup_times:
                jt.sync_all(True)
                sta = time.time()
            if cnt > warmup_times + run_times:
                jt.sync_all(True)
                total_time = time.time() - sta
                print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
                exit(0)

    # Save model checkpoints every 10 epochs
    if epoch % 10 == 0:
        generator.save("%s/generator_%s.pkl" % (save_model_path, opt.task))
        discriminator.save("%s/discriminator_%s.pkl" % (save_model_path, opt.task))
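# Minimal usage sketch (assumption, not part of the training script): reload a saved generator
# checkpoint and write a 5x5 grid of samples using the helpers defined above.
#
#   generator = Generator()
#   generator.load('%s/generator_%s.pkl' % (save_model_path, opt.task))
#   z = jt.array(np.random.normal(0, 1, (25, opt.latent_dim)).astype('float32'))
#   save_image(generator(z).data, '%s/sample_grid.png' % save_img_path, nrow=5)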