File size: 15,011 Bytes
197d4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import torch
import torch.nn as nn
import numpy as np
import random

from methods.gnn import GNN_nl
from methods import backbone_multiblock
from methods.tool_func import *
from methods.meta_template_StyleAdv_RN_GNN import MetaTemplate


class StyleAdvGNN(MetaTemplate):
  """GNN few-shot classifier trained with StyleAdv feature-style adversarial augmentation.

  The backbone (`self.feature`) is driven stage by stage (`forward_block1..4`,
  then `forward_rest`) so that the feature statistics produced by blocks 1-3
  can be swapped for adversarially perturbed ones. Two heads share the backbone:

  * `self.fc` + `self.gnn`: the episodic head, which scores each query node of
    an episode graph against the labelled support nodes.
  * `self.classifier`: a global linear classifier trained on `global_y` labels.
  """
  maml = False

  def __init__(self, model_func, n_way, n_support, tf_path=None, *, num_global_classes=64):
    """Build the episodic and global heads.

    Args:
      model_func: backbone factory, forwarded to MetaTemplate.
      n_way: number of classes per episode.
      n_support: number of labelled support samples per class.
      tf_path: forwarded to MetaTemplate.
      num_global_classes: output size of the global classifier (default 64,
        the value that used to be hard-coded).
    """
    super().__init__(model_func, n_way, n_support, tf_path=tf_path)

    # loss function, shared by the episodic and the global head
    self.loss_fn = nn.CrossEntropyLoss()

    # metric function: project backbone features to 128-d before the GNN
    if not self.maml:
      self.fc = nn.Sequential(nn.Linear(self.feat_dim, 128), nn.BatchNorm1d(128, track_running_stats=False))
    else:
      # NOTE(review): `backbone` is not imported in the visible part of this file
      # (only `backbone_multiblock` is); confirm it is provided before enabling maml.
      self.fc = nn.Sequential(backbone.Linear_fw(self.feat_dim, 128), backbone.BatchNorm1d_fw(128, track_running_stats=False))
    # GNN node features = 128-d embedding concatenated with an n_way label encoding
    self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)

    # global classifier
    self.method = 'GnnNet'
    self.classifier = nn.Linear(self.feature.final_feat_dim, num_global_classes)

    # Fixed node labels of one episode graph, shape (1, n_way * (n_support + 1), n_way):
    # per class, n_support one-hot rows followed by one all-zero row for the query node.
    # Built with torch ops so the dtype does not depend on NumPy's platform default int.
    support_onehot = torch.eye(self.n_way).repeat_interleave(self.n_support, dim=0)
    support_onehot = support_onehot.view(self.n_way, self.n_support, self.n_way)
    query_slot = torch.zeros(self.n_way, 1, self.n_way)
    self.support_label = torch.cat([support_onehot, query_slot], dim=1).view(1, -1, self.n_way)

  def cuda(self, device=None):
    """Move every sub-module and the cached support labels to the GPU; return self."""
    for module in (self.feature, self.fc, self.gnn, self.classifier):
      module.cuda(device)
    self.support_label = self.support_label.cuda(device)
    return self

  def _query_labels(self):
    """Episode query labels [0]*n_query + [1]*n_query + ..., on the GPU.

    The class-major order matches the row order produced by `forward_gnn`.
    """
    return torch.arange(self.n_way).repeat_interleave(self.n_query).cuda()

  def _fsl_scores(self, z):
    """Group flat fc embeddings into one graph per query index and score them.

    Args:
      z: (n_way * (n_support + n_query), d) embeddings in class-major order.

    Each graph holds every support embedding plus the i-th query of each class.
    Returns the (n_way * n_query, n_way) scores from `forward_gnn`.
    """
    z = z.view(self.n_way, -1, z.size(1))
    support = z[:, :self.n_support]
    z_stack = [
      torch.cat([support, z[:, self.n_support + i:self.n_support + i + 1]], dim=1).view(1, -1, z.size(2))
      for i in range(self.n_query)
    ]
    assert z_stack[0].size(1) == self.n_way * (self.n_support + 1)
    return self.forward_gnn(z_stack)

  def set_forward(self, x, is_feature=False):
    """Score every query of one episode with the GNN head.

    Args:
      x: episode tensor laid out as (n_way, n_support + n_query, ...). With
        is_feature=True the trailing dims are precomputed backbone features,
        otherwise they are backbone inputs.

    Returns:
      (n_way * n_query, n_way) scores, rows ordered class-major.
    """
    x = x.cuda()
    if is_feature:
      # previously hard-coded to 15 queries; the stacking below uses n_query
      assert x.size(1) == self.n_support + self.n_query
    flat = x.view(-1, *x.size()[2:])
    z = self.fc(flat) if is_feature else self.fc(self.feature(flat))
    return self._fsl_scores(z)

  def forward_gnn(self, zs):
    """Run the GNN on n_query episode graphs and return the query-node scores.

    Args:
      zs: list of n_query tensors, each (1, n_way * (n_support + 1), f).

    Returns:
      (n_way * n_query, n_way) scores, rows ordered class-major.
    """
    # append the fixed label encoding to every node, then batch the graphs
    nodes = torch.cat([torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
    scores = self.gnn(nodes)

    # keep the last (query) node of each class: (n_query, n_way, n_way) -> class-major rows
    scores = scores.view(self.n_query, self.n_way, self.n_support + 1, self.n_way)[:, :, -1]
    return scores.permute(1, 0, 2).contiguous().view(-1, self.n_way)

  def set_forward_loss(self, x):
    """Return (scores, episodic cross-entropy loss) for one episode."""
    y_query = self._query_labels()
    scores = self.set_forward(x)
    loss = self.loss_fn(scores, y_query)
    return scores, loss

  def _backbone_blocks(self):
    """The four backbone stages preceding `forward_rest`, in forward order."""
    return (self.feature.forward_block1, self.feature.forward_block2,
            self.feature.forward_block3, self.feature.forward_block4)

  def _attack_block_style(self, x, y, block_idx, adv_styles, epsilon_list):
    """FGSM-perturb the feature style (mean, std) produced by one backbone block.

    Args:
      x: flattened input batch, already on the GPU.
      y: flat global labels for x.
      block_idx: 0-based index of the attacked block (0, 1 or 2).
      adv_styles: per-block (mean, std) pairs; entries before block_idx are the
        adversarial styles already crafted, or the 'None' sentinel.
      epsilon_list: candidate step sizes; one is drawn uniformly at random.

    Returns:
      (adv_mean, adv_std) from `fgsm_attack`.
    """
    blocks = self._backbone_blocks()

    # earlier blocks carry their (already crafted) adversarial style
    feat = x
    for prev in range(block_idx):
      feat = blocks[prev](feat)
      feat = changeNewAdvStyle(feat, adv_styles[prev][0], adv_styles[prev][1], p_thred=0)
    feat = blocks[block_idx](feat)

    # style statistics become leaf parameters so the loss gradient reaches them
    feat_size = feat.size()
    style_mean, style_std = calc_mean_std(feat)
    style_mean = torch.nn.Parameter(style_mean)
    style_std = torch.nn.Parameter(style_std)
    # normalize with detached stats and re-apply the learnable ones:
    # the forward value is unchanged, but mean/std are now in the graph
    feat_normalized = (feat - style_mean.detach().expand(feat_size)) / style_std.detach().expand(feat_size)
    feat = feat_normalized * style_std.expand(feat_size) + style_mean.expand(feat_size)

    # pass the rest of the model
    for later in range(block_idx + 1, len(blocks)):
      feat = blocks[later](feat)
    logits = self.classifier.forward(self.feature.forward_rest(feat))
    loss = self.loss_fn(logits, y)

    self.feature.zero_grad()
    self.classifier.zero_grad()
    loss.backward()

    epsilon = epsilon_list[int(torch.randint(0, len(epsilon_list), (1, )).item())]
    adv_mean = fgsm_attack(style_mean, epsilon, style_mean.grad.detach())
    adv_std = fgsm_attack(style_std, epsilon, style_std.grad.detach())
    return adv_mean, adv_std

  def adversarial_attack_Incre(self, x_ori, y_ori, epsilon_list):
    """Incrementally craft adversarial feature styles for backbone blocks 1-3.

    Block k is attacked with blocks before it already carrying their adversarial
    style. A block whose epsilon_list[k] == 0 is skipped and keeps the string
    sentinel 'None' (presumably interpreted by changeNewAdvStyle as "no change";
    see tool_func). Feature/classifier gradients are zeroed after every block.

    Args:
      x_ori: input tensor whose first two dims are flattened together.
      y_ori: global labels with the same first two dims as x_ori.
      epsilon_list: candidate FGSM step sizes (at least 3 entries; entries
        0-2 also act as per-block on/off switches).

    Returns:
      adv_mean_1, adv_std_1, adv_mean_2, adv_std_2, adv_mean_3, adv_std_3.
    """
    x_ori = x_ori.cuda()
    y_ori = y_ori.cuda()
    x_size = x_ori.size()
    x_ori = x_ori.view(x_size[0] * x_size[1], *x_size[2:])
    y_ori = y_ori.view(x_size[0] * x_size[1])

    adv_styles = [('None', 'None')] * 3
    for block_idx in range(3):
      if epsilon_list[block_idx] != 0:
        adv_styles[block_idx] = self._attack_block_style(x_ori, y_ori, block_idx, adv_styles, epsilon_list)
      # do not leak attack gradients into the model parameters
      self.feature.zero_grad()
      self.classifier.zero_grad()

    (mean1, std1), (mean2, std2), (mean3, std3) = adv_styles
    return mean1, std1, mean2, std2, mean3, std3

  def set_statues_of_modules(self, flag):
    """Put feature, fc, gnn and classifier in 'eval' or 'train' mode; other flags are ignored."""
    if flag not in ('eval', 'train'):
      return
    for module in (self.feature, self.fc, self.gnn, self.classifier):
      module.train(flag == 'train')

  def _backbone_features(self, x, adv_styles=None):
    """Run the full backbone; if adv_styles is given, swap styles after blocks 1-3.

    adv_styles: three (mean, std) pairs passed to changeNewAdvStyle with P_THRED.
    """
    blocks = self._backbone_blocks()
    for idx, block in enumerate(blocks[:3]):
      x = block(x)
      if adv_styles is not None:
        x = changeNewAdvStyle(x, adv_styles[idx][0], adv_styles[idx][1], p_thred=P_THRED)
    x = blocks[3](x)
    return self.feature.forward_rest(x)

  def set_forward_loss_StyAdv(self, x_ori, global_y, epsilon_list):
    """StyleAdv forward pass: clean and style-attacked losses for both heads.

    1. With all modules in eval mode, craft adversarial styles for blocks 1-3
       (`adversarial_attack_Incre`), then zero all gradients.
    2. In train mode, forward the clean episode and a second pass whose block
       1-3 styles are swapped via changeNewAdvStyle.

    Args:
      x_ori: episode tensor laid out as (n_way, n_support + n_query, ...).
      global_y: global labels with the same first two dims as x_ori.
      epsilon_list: candidate FGSM step sizes (see adversarial_attack_Incre).

    Returns:
      scores_fsl_ori, loss_fsl_ori, scores_cls_ori, loss_cls_ori,
      scores_fsl_adv, loss_fsl_adv, scores_cls_adv, loss_cls_adv.
    """
    # 1. styleAdv
    self.set_statues_of_modules('eval')
    flat_styles = self.adversarial_attack_Incre(x_ori, global_y, epsilon_list)
    adv_styles = list(zip(flat_styles[0::2], flat_styles[1::2]))
    for module in (self.feature, self.fc, self.classifier, self.gnn):
      module.zero_grad()

    # 2. forward and get loss
    self.set_statues_of_modules('train')
    y_query = self._query_labels()

    x_size = x_ori.size()
    x = x_ori.cuda().view(x_size[0] * x_size[1], *x_size[2:])
    global_y = global_y.view(x_size[0] * x_size[1]).cuda()

    # clean pass: global classification and few-shot losses
    fea_ori = self._backbone_features(x)
    scores_cls_ori = self.classifier.forward(fea_ori)
    loss_cls_ori = self.loss_fn(scores_cls_ori, global_y)
    scores_fsl_ori = self._fsl_scores(self.fc(fea_ori))
    loss_fsl_ori = self.loss_fn(scores_fsl_ori, y_query)

    # adversarial-style pass on the same input
    fea_adv = self._backbone_features(x, adv_styles=adv_styles)
    scores_cls_adv = self.classifier.forward(fea_adv)
    loss_cls_adv = self.loss_fn(scores_cls_adv, global_y)
    scores_fsl_adv = self._fsl_scores(self.fc(fea_adv))
    loss_fsl_adv = self.loss_fn(scores_fsl_adv, y_query)

    return scores_fsl_ori, loss_fsl_ori, scores_cls_ori, loss_cls_ori, scores_fsl_adv, loss_fsl_adv, scores_cls_adv, loss_cls_adv