File size: 4,170 Bytes
92b9080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import os

import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import torch.backends.cudnn as cudnn

class Generator(nn.Module):
    """Encoder-decoder network mapping an image to an image with the same channels.

    Two 4x4 / stride-2 / pad-1 convolutions halve the spatial size twice, then
    two matching transposed convolutions double it twice. When the input height
    and width are multiples of 4, the output has exactly the input's shape. The
    final ``tanh`` bounds every output value to [-1, 1].

    Args:
        in_ch: Number of input channels (also the number of output channels).
    """

    def __init__(self, in_ch):
        super().__init__()
        # Encoder: each conv halves H and W.
        self.conv1 = nn.Conv2d(in_ch, 64, 4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # Decoder: each transposed conv doubles H and W back.
        self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(64, in_ch, 4, stride=2, padding=1)

    def forward(self, x):
        """Return the generator output for the batch ``x`` (values in [-1, 1])."""
        h = F.leaky_relu(self.bn1(self.conv1(x)))
        h = F.leaky_relu(self.bn2(self.conv2(h)))
        h = F.leaky_relu(self.bn3(self.deconv3(h)))
        # torch.tanh is the canonical form; F.tanh has been deprecated in
        # several torch releases.
        return torch.tanh(self.deconv4(h))

class Discriminator(nn.Module):
    """Convolutional binary classifier producing one probability per sample.

    Three unpadded 3x3 / stride-2 convolutions are followed by a linear layer
    and a sigmoid, giving an output of shape ``(N, 1)`` with values in [0, 1].

    Args:
        in_ch: Number of input channels.
        img_size: Optional side length of the square input images. It is used
            to size the final linear layer. When omitted, the original
            assumption is kept: 28x28 if ``in_ch == 1``, otherwise 32x32.
            These give 1024 and 2304 input features respectively.

    Raises:
        ValueError: If ``img_size`` is too small to survive the three convs
            (anything below 15).
    """

    def __init__(self, in_ch, img_size=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, 3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2)
        self.bn3 = nn.BatchNorm2d(256)

        if img_size is None:
            # Backward-compatible default: 1 channel -> 28x28, else 32x32.
            img_size = 28 if in_ch == 1 else 32
        side = img_size
        for _ in range(3):  # output size of an unpadded 3x3 / stride-2 conv
            side = (side - 3) // 2 + 1
        if side < 1:
            raise ValueError(
                "img_size={} is too small for the discriminator".format(img_size))
        self.fc4 = nn.Linear(256 * side * side, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(N, 1)`` for the batch ``x``."""
        h = F.leaky_relu(self.conv1(x))
        h = F.leaky_relu(self.bn2(self.conv2(h)))
        h = F.leaky_relu(self.bn3(self.conv3(h)))
        return torch.sigmoid(self.fc4(h.flatten(1)))


def main(args):
    """Train a generator that maps adversarial inputs back toward clean inputs.

    The generator ``G`` receives adversarial samples ``x_adv``. It is optimised
    on a weighted sum of two losses:

    * ``args.alpha`` times the MSE against the paired clean sample ``x``;
    * ``args.beta`` times the BCE of fooling the discriminator ``D``.

    ``D`` learns to separate clean samples from ``G``'s outputs. Each batch
    performs one ``D`` step followed by two ``G`` steps. After every epoch a
    checkpoint holding both state dicts is written to ``args.checkpoint``.

    Args:
        args: Parsed command-line namespace. Reads ``lr``, ``epochs``,
            ``alpha``, ``beta`` and ``checkpoint``, plus ``batch_size`` if
            present (defaults to 128 otherwise).

    Returns:
        The trained generator, switched to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cudnn.benchmark = True

    # Load the paired (clean, adversarial) tensors first so the channel count
    # for both networks comes from the data rather than an undefined constant.
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    data = torch.load("./adv_data.tar")
    normal, adv = data["normal"], data["adv"]
    in_ch = normal.size(1)  # assumes NCHW layout -- TODO confirm
    train_loader = torch.utils.data.DataLoader(
        TensorDataset(normal, adv),
        batch_size=getattr(args, "batch_size", 128),
        shuffle=True,
    )

    G = Generator(in_ch=in_ch).to(device)
    D = Discriminator(in_ch=in_ch).to(device)

    opt_G = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.999))
    loss_bce = nn.BCELoss()
    loss_mse = nn.MSELoss()

    # torch.save does not create missing directories.
    os.makedirs(args.checkpoint, exist_ok=True)

    for epoch in range(args.epochs):
        G.train()
        D.train()
        gen_loss, dis_loss, n = 0.0, 0.0, 0
        for x, x_adv in train_loader:
            x, x_adv = x.to(device), x_adv.to(device)
            batch = x.size(0)
            t_real = torch.ones(batch, device=device)
            t_fake = torch.zeros(batch, device=device)

            # --- Discriminator step ---
            # view(-1) (not squeeze) keeps shape (batch,) even when batch == 1.
            # detach() keeps loss_D from back-propagating into G.
            y_real = D(x).view(-1)
            y_fake = D(G(x_adv).detach()).view(-1)
            loss_D = loss_bce(y_real, t_real) + loss_bce(y_fake, t_fake)
            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

            # --- Generator steps: two per discriminator step ---
            for _ in range(2):
                x_fake = G(x_adv)
                y_fake = D(x_fake).view(-1)
                loss_G = (args.alpha * loss_mse(x_fake, x)
                          + args.beta * loss_bce(y_fake, t_real))
                opt_G.zero_grad()
                loss_G.backward()
                opt_G.step()

            # Sample-weighted running sums; loss_G is from the last G step.
            gen_loss += loss_G.item() * batch
            dis_loss += loss_D.item() * batch
            n += batch

        n = max(n, 1)  # avoid ZeroDivisionError on an empty dataset
        print("epoch:{}, LossG:{:.3f}, LossD:{:.3f}".format(epoch, gen_loss / n, dis_loss / n))
        torch.save({"generator": G.state_dict(), "discriminator": D.state_dict()},
                   os.path.join(args.checkpoint, "{}.tar".format(epoch + 1)))

    G.eval()
    return G

def get_args(argv=None):
    """Parse the command-line options for GAN training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``data``, ``lr``, ``epochs``,
        ``batch_size``, ``alpha``, ``beta`` and ``checkpoint``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--data", type=str, default="mnist",
                        help="dataset name")
    parser.add_argument("--lr", type=float, default=0.0002,
                        help="Adam learning rate for both networks")
    parser.add_argument("--epochs", type=int, default=2,
                        help="number of training epochs")
    # main() reads args.batch_size; the original parser never defined it.
    parser.add_argument("--batch_size", type=int, default=128,
                        help="mini-batch size")
    parser.add_argument("--alpha", type=float, default=0.7,
                        help="weight of the MSE reconstruction loss")
    parser.add_argument("--beta", type=float, default=0.3,
                        help="weight of the adversarial BCE loss")
    parser.add_argument("--checkpoint", type=str, default="./checkpoint/test",
                        help="directory for per-epoch checkpoints")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Bind the parsed namespace; the original discarded it and main(args)
    # raised NameError.
    args = get_args()
    main(args)