File size: 7,807 Bytes
7b2d6c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import jittor as jt
from jittor import nn
import argparse
import os
import numpy as np
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
import cv2
import time
from jittor.dataset.dataset import ImageFolder

# Enable CUDA execution in jittor (falls back per jittor's own handling if unavailable).
jt.flags.use_cuda = 1

# Output directories for sampled images and model checkpoints.
save_img_path = './images_celebA'
save_model_path = './save_model_celebA'
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=128, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# Fixed help text: b2 is the SECOND-moment decay (was mislabeled as first moment).
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
parser.add_argument('--celebA_channels', type=int, default=3, help='图像通道数')
parser.add_argument('--mnist_channels', type=int, default=1, help='图像通道数')
parser.add_argument('--n_critic', type=int, default=5, help='每个迭代器的鉴别器训练步骤数')
parser.add_argument('--clip_value', type=float, default=0.01, help='光盘的上下剪辑值。 权重')
parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')
parser.add_argument('--task', type=str, default='celebA', help='训练数据集类型')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='训练数据集地址')
opt = parser.parse_args()
print(opt)

# Bug fix: pick the channel count matching the selected task. The original
# always used celebA's 3 channels, so running with --task MNIST (1-channel
# images) produced a generator/discriminator shape mismatch against real data.
# Default task is celebA, so default behavior is unchanged.
channels = opt.celebA_channels if opt.task == 'celebA' else opt.mnist_channels
img_shape = (channels, opt.img_size, opt.img_size)

# Training-set loader factory
def DataLoader(dataclass, img_size, batch_size, train_dir):
    """Build a shuffled training loader for the requested dataset.

    Args:
        dataclass: dataset name, 'MNIST' or 'celebA'; anything else prompts
            the user interactively for a valid name and retries.
        img_size: target square image size for the Resize transform.
        batch_size: loader batch size.
        train_dir: dataset root directory.

    Returns:
        A jittor dataset configured with batch_size and shuffle=True.
    """
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        return MNIST(data_root=train_dir, train=True, transform=Transform) \
            .set_attrs(batch_size=batch_size, shuffle=True)
    if dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        return ImageFolder(train_dir) \
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
    # Unknown dataset: ask the user and retry.
    # Bug fix: the original discarded the recursive call's result and then hit
    # "return train_loader" with train_loader unbound (NameError); the retry's
    # loader must be returned.
    print("没有加载%s数据集的程序,请选择MNIST或者celebA!" % dataclass)
    dataclass = input("请输入:MNIST或者celebA:")
    return DataLoader(dataclass, img_size, batch_size, train_dir)

dataloader = DataLoader(opt.task,opt.img_size,opt.batch_size,opt.train_dir)

# Save an image grid to disk
def save_image(img, path, nrow=10):
    """Tile a batch of images into an nrow x nrow grid and write it to `path`.

    Args:
        img: array of shape (N, C, H, W); assumes N == nrow * nrow — the
            reshape below requires exactly that many images (TODO confirm
            callers always slice [:nrow*nrow]).
        path: output file path passed to cv2.imwrite.
        nrow: number of images per grid row/column.
    """
    N, C, W, H = img.shape
    # Rearrange the batch into nrow column-strips, then stitch them side by side.
    img2 = img.reshape([-1, W * nrow * nrow, H])
    img = img2[:, :W * nrow, :]
    for i in range(1, nrow):
        img = np.concatenate([img, img2[:, W * nrow * i:W * nrow * (i + 1), :]], axis=2)
    # Min-max normalize to [0, 255] for writing as an 8-bit image.
    min_ = img.min()
    max_ = img.max()
    # Bug fix: guard against a constant image (max_ == min_), which previously
    # divided by zero and wrote NaNs.
    scale = max_ - min_
    if scale == 0:
        scale = 1
    img = (img - min_) / scale * 255
    # CHW -> HWC for cv2.
    img = img.transpose((1, 2, 0))
    cv2.imwrite(path, img)

# Generator network
class Generator(nn.Module):
    """MLP generator: maps a latent vector of size opt.latent_dim to an
    image of shape img_shape via a stack of widening fully-connected layers."""

    def __init__(self):
        super(Generator, self).__init__()

        def fc_stack(n_in, n_out, normalize=True):
            # Linear -> (optional BatchNorm1d) -> LeakyReLU
            stack = [nn.Linear(n_in, n_out)]
            if normalize:
                stack.append(nn.BatchNorm1d(n_out, 0.8))
            stack.append(nn.LeakyReLU(0.2))
            return stack

        layers = []
        layers += fc_stack(opt.latent_dim, 128, normalize=False)
        layers += fc_stack(128, 256)
        layers += fc_stack(256, 512)
        layers += fc_stack(512, 1024)
        layers.append(nn.Linear(1024, int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def execute(self, z):
        # Produce a flat image, then fold it back to (batch, C, H, W).
        flat = self.model(z)
        return flat.view((flat.shape[0], *img_shape))

# Discriminator (critic) network
class Discriminator(nn.Module):
    """MLP critic: flattens an image and outputs a single unbounded score
    (no sigmoid — this is a Wasserstein critic, not a probability)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
        ]
        self.model = nn.Sequential(*layers)

    def execute(self, img):
        # Flatten (batch, C, H, W) -> (batch, C*H*W) before the MLP.
        flat = img.reshape((img.shape[0], -1))
        return self.model(flat)

# Weight of the gradient-penalty term (lambda in the WGAN-GP objective).
lambda_gp = 10

# Instantiate the generator and the discriminator (critic)
generator = Generator()
discriminator = Discriminator()

# Adam optimizers for both networks, with betas (b1, b2) taken from the CLI.
optimizer_G = jt.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = jt.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# WGAN-GP gradient-penalty loss
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return E[(||grad_x D(x_hat)||_2 - 1)^2] over random interpolates x_hat
    between real and fake samples (the WGAN-GP penalty term)."""
    batch = real_samples.shape[0]
    # One random mixing coefficient per sample, broadcast over C/H/W.
    mix = np.random.random((batch, 1, 1, 1)).astype('float32')
    alpha = jt.array(mix)
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    critic_scores = D(interpolates)
    # Gradient of the critic output w.r.t. the interpolated inputs.
    grads = jt.grad(critic_scores, interpolates)
    grads = grads.reshape((grads.shape[0], -1))
    grad_norm = jt.sqrt(grads.sqr().sum(1))
    return ((grad_norm - 1).sqr()).mean()

batches_done = 0
# warmup_times == -1 selects normal training; a value >= 0 switches this
# script into a throughput benchmark: after `warmup_times` warm-up iterations
# it times `run_times` iterations and exits.
warmup_times = -1
run_times = 3000
total_time = 0.
cnt = 0

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):  # default: 200 epochs
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = jt.array(imgs).float32()

        # ---------------------
        #  Train the discriminator (critic)
        #  (NOTE: the original comment here said "generator" — it was swapped
        #  with the section below; this step updates optimizer_D.)
        # ---------------------

        z = jt.array((np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))).astype('float32'))
        fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
        # WGAN-GP critic loss: maximize real score, minimize fake score,
        # plus the lambda-weighted gradient penalty.
        d_loss = (- real_validity.mean() + fake_validity.mean() + lambda_gp * gradient_penalty)
        d_loss.sync()
        optimizer_D.step(d_loss)
        
        # -----------------
        #  Train the generator
        #  (runs only every n_critic batches, per the WGAN training schedule;
        #  this step updates optimizer_G.)
        # -----------------

        if ((i % opt.n_critic) == 0):
            fake_img = generator(z)
            fake_validityg = discriminator(fake_img)
            # Generator loss: push critic scores of fakes up.
            g_loss = -fake_validityg.mean()
            g_loss.sync()
            optimizer_G.step(g_loss)
            
            if warmup_times==-1:
                print(('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]' % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)))
                #if ((batches_done % opt.sample_interval) == 0):
                # Save a 5x5 sample grid once per epoch, at the last batch.
                # NOTE(review): 1583 is hard-coded for this dataset size and
                # batch_size=128 — confirm it matches your data, otherwise no
                # samples are ever written.
                if ( i == 1583 ):
                    save_image(fake_imgs.data[:25], ('%s/%d.png' % (save_img_path, batches_done)), nrow=5)
                batches_done += opt.n_critic

        # Benchmark mode: count iterations, start the clock after warm-up,
        # report average per-iteration time, then exit.
        if warmup_times!=-1:
            jt.sync_all()
            cnt += 1
            print(cnt)
            if cnt == warmup_times:
                jt.sync_all(True)
                sta = time.time()
            if cnt > warmup_times + run_times:
                jt.sync_all(True)
                total_time = time.time() - sta
                print(f"run {run_times} iters cost {total_time} seconds, and avg {total_time / run_times} one iter.")
                exit(0)
    
    # Checkpoint both networks every 10 epochs.
    if epoch % 10 == 0:  # epochs 0-199
        generator.save("%s/generator_%s.pkl"%(save_model_path, opt.task))
        discriminator.save("%s/discriminator_%s.pkl"%(save_model_path, opt.task))