KC123hello committed on
Commit
8cab4cc
·
verified ·
1 Parent(s): 96debc0

Upload 39 files

Browse files
Files changed (40) hide show
  1. .gitattributes +1 -0
  2. vlm_eval/__init__.py +0 -0
  3. vlm_eval/__pycache__/__init__.cpython-311.pyc +0 -0
  4. vlm_eval/__pycache__/__init__.cpython-312.pyc +0 -0
  5. vlm_eval/__pycache__/coco_cf_loader.cpython-311.pyc +0 -0
  6. vlm_eval/__pycache__/datasets_classes_templates.cpython-311.pyc +0 -0
  7. vlm_eval/__pycache__/run_evaluation.cpython-311.pyc +3 -0
  8. vlm_eval/__pycache__/run_evaluation.cpython-312.pyc +0 -0
  9. vlm_eval/__pycache__/run_evaluation_qualitative.cpython-311.pyc +0 -0
  10. vlm_eval/attacks/__init__.py +0 -0
  11. vlm_eval/attacks/__pycache__/__init__.cpython-311.pyc +0 -0
  12. vlm_eval/attacks/__pycache__/afw.cpython-311.pyc +0 -0
  13. vlm_eval/attacks/__pycache__/apgd.cpython-311.pyc +0 -0
  14. vlm_eval/attacks/__pycache__/attack.cpython-311.pyc +0 -0
  15. vlm_eval/attacks/__pycache__/ead.cpython-311.pyc +0 -0
  16. vlm_eval/attacks/__pycache__/fwnucl.cpython-311.pyc +0 -0
  17. vlm_eval/attacks/__pycache__/gse.cpython-311.pyc +0 -0
  18. vlm_eval/attacks/__pycache__/iht.cpython-311.pyc +0 -0
  19. vlm_eval/attacks/__pycache__/pgd0.cpython-311.pyc +0 -0
  20. vlm_eval/attacks/__pycache__/saif.cpython-311.pyc +0 -0
  21. vlm_eval/attacks/__pycache__/strattack.cpython-311.pyc +0 -0
  22. vlm_eval/attacks/apgd.py +384 -0
  23. vlm_eval/attacks/attack.py +20 -0
  24. vlm_eval/attacks/ead.py +132 -0
  25. vlm_eval/attacks/fwnucl.py +170 -0
  26. vlm_eval/attacks/gse.py +313 -0
  27. vlm_eval/attacks/iht.py +97 -0
  28. vlm_eval/attacks/pgd.py +88 -0
  29. vlm_eval/attacks/pgd0.py +131 -0
  30. vlm_eval/attacks/saif.py +143 -0
  31. vlm_eval/attacks/sparsers.py +164 -0
  32. vlm_eval/attacks/strattack.py +229 -0
  33. vlm_eval/attacks/utils.py +52 -0
  34. vlm_eval/clip_classification.py +160 -0
  35. vlm_eval/clip_train.py +209 -0
  36. vlm_eval/coco_cf_loader.py +90 -0
  37. vlm_eval/create_clip_dataset.py +65 -0
  38. vlm_eval/datasets_classes_templates.py +822 -0
  39. vlm_eval/ms_coco_gen.py +76 -0
  40. vlm_eval/run_evaluation.py +0 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  MastersThesis_475703.pdf filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  MastersThesis_475703.pdf filter=lfs diff=lfs merge=lfs -text
37
+ vlm_eval/__pycache__/run_evaluation.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
vlm_eval/__init__.py ADDED
File without changes
vlm_eval/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (163 Bytes). View file
 
vlm_eval/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (149 Bytes). View file
 
vlm_eval/__pycache__/coco_cf_loader.cpython-311.pyc ADDED
Binary file (4.87 kB). View file
 
vlm_eval/__pycache__/datasets_classes_templates.cpython-311.pyc ADDED
Binary file (21.1 kB). View file
 
vlm_eval/__pycache__/run_evaluation.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6122c088ad6b90b0847802d1d0eaaefe8b2503bfa8c3c29a370d6c4406b59718
3
+ size 113082
vlm_eval/__pycache__/run_evaluation.cpython-312.pyc ADDED
Binary file (73.2 kB). View file
 
vlm_eval/__pycache__/run_evaluation_qualitative.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
vlm_eval/attacks/__init__.py ADDED
File without changes
vlm_eval/attacks/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
vlm_eval/attacks/__pycache__/afw.cpython-311.pyc ADDED
Binary file (6.67 kB). View file
 
vlm_eval/attacks/__pycache__/apgd.cpython-311.pyc ADDED
Binary file (23 kB). View file
 
vlm_eval/attacks/__pycache__/attack.cpython-311.pyc ADDED
Binary file (1.24 kB). View file
 
vlm_eval/attacks/__pycache__/ead.cpython-311.pyc ADDED
Binary file (7.72 kB). View file
 
vlm_eval/attacks/__pycache__/fwnucl.cpython-311.pyc ADDED
Binary file (11.8 kB). View file
 
vlm_eval/attacks/__pycache__/gse.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
vlm_eval/attacks/__pycache__/iht.cpython-311.pyc ADDED
Binary file (5.51 kB). View file
 
vlm_eval/attacks/__pycache__/pgd0.cpython-311.pyc ADDED
Binary file (9.86 kB). View file
 
vlm_eval/attacks/__pycache__/saif.cpython-311.pyc ADDED
Binary file (8.29 kB). View file
 
vlm_eval/attacks/__pycache__/strattack.cpython-311.pyc ADDED
Binary file (14 kB). View file
 
vlm_eval/attacks/apgd.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code adapted from https://github.com/chs20/RobustVLM/tree/main
2
+
3
+ import torch
4
+ import math
5
+
6
+
7
class APGD:
    """Thin configuration wrapper around the module-level `apgd` routine.

    The wrapped model is expected to return the loss summed over the batch,
    so only batch size 1 is supported (asserted inside `apgd`).
    """

    def __init__(self, model, norm, eps, mask_out='context', initial_stepsize=None, decrease_every=None, decrease_every_max=None, random_init=False):
        # initial_stepsize: in units of eps; called alpha in APGD.
        # decrease_every: fraction of total iterations between step-size
        # checks (default inside `apgd`: 0.22).
        self.model = model
        self.norm = norm
        self.eps = eps
        self.initial_stepsize = initial_stepsize
        self.decrease_every = decrease_every
        self.decrease_every_max = decrease_every_max
        self.random_init = random_init
        # 'none' means "perturb everything"
        self.mask_out = None if mask_out == 'none' else mask_out

    def perturb(self, data_clean, iterations, pert_init=None, verbose=False):
        """Run APGD for `iterations` steps and return the adversarial input."""
        pert_mask = self._set_mask(data_clean)
        data_adv, _, _ = apgd(
            self.model, data_clean, norm=self.norm, eps=self.eps,
            n_iter=iterations, use_rs=self.random_init, mask=pert_mask,
            alpha=self.initial_stepsize, n_iter_2=self.decrease_every,
            n_iter_min=self.decrease_every_max, pert_init=pert_init,
            verbose=verbose,
        )
        return data_adv

    def _set_mask(self, data):
        """Build a 0/1 mask selecting which media positions may be perturbed.

        Assumes dim 1 of `data` indexes media (context images first, query
        image last) — TODO confirm against the caller.
        """
        mask = torch.ones_like(data)
        if self.mask_out is None:
            return mask
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __str__(self):
        return 'APGD'
52
+
53
+
54
def L1_projection(x2, y2, eps1):
    '''
    Project a perturbation onto the intersection of the L1 ball of radius
    eps1 (centered at x2) and the box [0, 1].

    x2: center of the L1 ball (bs x input_dim)
    y2: current perturbation (x2 + y2 is the point to be projected)
    eps1: radius of the L1 ball

    output: delta s.th. ||y2 + delta||_1 = eps1
    and 0 <= x2 + y2 + delta <= 1
    '''

    x = x2.clone().float().view(x2.shape[0], -1)
    y = y2.clone().float().view(y2.shape[0], -1)
    sigma = y.clone().sign()
    # u <= 0: slack each coordinate has before it hits the [0, 1] box
    u = torch.min(1 - x - y, x + y)
    u = torch.min(torch.zeros_like(y), u)
    l = -torch.clone(y).abs()
    d = u.clone()

    # sort all breakpoints of the piecewise-linear norm-vs-threshold function
    bs, indbs = torch.sort(-torch.cat((u, l), 1), dim=1)
    bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1)

    inu = 2 * (indbs < u.shape[1]).float() - 1
    size1 = inu.cumsum(dim=1)

    s1 = -u.sum(dim=1)

    c = eps1 - y.clone().abs().sum(dim=1)
    c5 = s1 + c < 0  # rows whose perturbation actually violates the L1 budget
    c2 = c5.nonzero().squeeze(1)

    s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)

    # BUG FIX: the original tested `c2.nelement != 0`, comparing the *bound
    # method* to 0, which is always True. Call it so that batches already
    # inside the ball skip the binary search (matches the reference code).
    if c2.nelement() != 0:

        lb = torch.zeros_like(c2).float()
        ub = torch.ones_like(lb) * (bs.shape[1] - 1)

        # binary search over breakpoints: O(log d) iterations
        nitermax = torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float()))
        counter2 = torch.zeros_like(lb).long()
        counter = 0

        while counter < nitermax:
            counter4 = torch.floor((lb + ub) / 2.)
            # NOTE(review): .type(torch.LongTensor) forces CPU; assumes the
            # inputs live on CPU — confirm before running on GPU.
            counter2 = counter4.type(torch.LongTensor)

            c8 = s[c2, counter2] + c[c2] < 0
            ind3 = c8.nonzero().squeeze(1)
            ind32 = (~c8).nonzero().squeeze(1)
            # same bound-method bug fixed on both branches
            if ind3.nelement() != 0:
                lb[ind3] = counter4[ind3]
            if ind32.nelement() != 0:
                ub[ind32] = counter4[ind32]

            counter += 1

        lb2 = lb.long()
        # linear interpolation between the two bracketing breakpoints
        alpha = (-s[c2, lb2] - c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
        d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -l[c2])

    return (sigma * d).view(x2.shape)
122
+
123
def L0_projection(x_adv, x, eps, step_size, lam=0.01):
    """Sparsify the perturbation x_adv - x via a hard-thresholding rule.

    Each coordinate is first clipped to the Linf ball of radius eps and to
    the valid image range [0, 1]; it is then kept only if dropping it would
    cost more (in squared error) than the sparsity price 2 * step_size * lam.
    Returns the resulting adversarial point clamped to [0, 1].
    """
    delta = x_adv - x

    # clip to the Linf box, then to the image domain
    delta_box = torch.clamp(delta, -eps, eps)
    delta_box = torch.clamp(x + delta_box, 0., 1.) - x

    # hard-thresholding: keep a coordinate only when its contribution
    # outweighs the per-coordinate sparsity penalty
    keep = delta ** 2 - (delta_box - delta) ** 2 > 2 * step_size * lam
    delta = torch.where(keep, delta_box, 0)
    return torch.clamp(x + delta, 0.0, 1.0)
134
+
135
+
136
+
137
def L1_norm(x, keepdim=False):
    """Per-sample L1 norm over all non-batch dimensions of x."""
    flat = x.view(x.shape[0], -1)
    z = flat.abs().sum(dim=-1)
    if keepdim:
        # reshape so the result broadcasts against x, e.g. (B, 1, 1, 1)
        z = z.view(-1, *([1] * (x.dim() - 1)))
    return z
142
+
143
+
144
def L2_norm(x, keepdim=False):
    """Per-sample Euclidean norm over all non-batch dimensions of x."""
    z = x.view(x.shape[0], -1).pow(2).sum(dim=-1).sqrt()
    if keepdim:
        # reshape so the result broadcasts against x, e.g. (B, 1, 1, 1)
        z = z.view(-1, *([1] * (x.dim() - 1)))
    return z
149
+
150
+
151
def L0_norm(x):
    """Number of non-zero entries per sample (L0 'norm')."""
    return x.view(x.shape[0], -1).ne(0.).sum(dim=-1)
153
+
154
+
155
def dlr_loss(x, y, reduction='none'):
    """Untargeted Difference-of-Logits-Ratio loss.

    x: logits of shape (B, num_classes), needs at least 3 classes.
    y: true labels of shape (B,).
    Returns the per-sample loss; `reduction` is accepted for API
    compatibility but not used.
    """
    x_sorted, ind_sorted = x.sort(dim=1)
    # 1.0 where the true class is currently the top-1 prediction
    correct = (ind_sorted[:, -1] == y).float()
    rows = torch.arange(x.shape[0])

    # competitor logit: runner-up when correct, otherwise the top logit
    rival = x_sorted[:, -2] * correct + x_sorted[:, -1] * (1. - correct)
    return -(x[rows, y] - rival) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)
161
+
162
+
163
def dlr_loss_targeted(x, y, y_target):
    """Targeted Difference-of-Logits-Ratio loss.

    x: logits of shape (B, num_classes), needs at least 4 classes.
    y: true labels; y_target: target labels, both of shape (B,).
    """
    x_sorted, _ = x.sort(dim=1)
    rows = torch.arange(x.shape[0])

    # normalizer: gap between the top logit and the mean of the 3rd/4th
    denom = x_sorted[:, -1] - .5 * (x_sorted[:, -3] + x_sorted[:, -4]) + 1e-12
    return -(x[rows, y] - x[rows, y_target]) / denom
169
+
170
def check_oscillation(x, j, k, y5, k3=0.75):
    """Detect stagnating/oscillating loss columns.

    x: loss history of shape (steps, B); j: current step; k: window length.
    Returns 1.0 per column where the loss increased in at most k * k3 of
    the last k steps (i.e. progress stalled), else 0.0. `y5` is unused but
    kept for signature compatibility with the reference implementation.
    """
    increases = torch.zeros(x.shape[1]).to(x.device)
    for back in range(k):
        increases += (x[j - back] > x[j - back - 1]).float()

    return (increases <= k * k3 * torch.ones_like(increases)).float()
176
+
177
+
178
def apgd(model, x, norm, eps, n_iter=10, use_rs=False, mask=None, alpha=None, n_iter_2=None,
         n_iter_min=None, pert_init=None, verbose=False, is_train=True):
    # from https://github.com/fra31/robust-finetuning
    """Auto-PGD with adaptive step size, maximizing the loss returned by `model`.

    `model` is called directly on the perturbed input and must return a
    per-sample loss (it is summed before backprop). Supported norms:
    'Linf', 'L2', 'L1', 'L0' (lowercase accepted). `mask` zeroes out
    gradient entries that must not be perturbed. `alpha` is the initial
    step size in units of eps; `n_iter_2`/`n_iter_min` are fractions of
    n_iter controlling the step-size schedule.

    Returns (x_best, loss_best, x_best_adv); x_best holds the iterate with
    the highest loss seen, x_best_adv the final iterate.
    """
    assert x.shape[0] == 1  # only support batch size 1 for now
    norm = norm.replace('l', 'L')
    device = x.device
    ndims = len(x.shape) - 1

    # --- starting point: clean input, random start, or supplied perturbation
    if not use_rs:
        x_adv = x.clone()
    else:
        if norm == 'Linf':
            t = torch.zeros_like(x).uniform_(-eps, eps).detach()
            x_adv = x + t
        elif norm == 'L2':
            t = torch.randn(x.shape).to(device).detach()
            x_adv = x + eps * torch.ones_like(x).detach() * t / (L2_norm(t, keepdim=True) + 1e-12)
    if pert_init is not None:
        assert not use_rs
        assert pert_init.shape == x.shape, f'pert_init.shape: {pert_init.shape}, x.shape: {x.shape}'
        x_adv = x + pert_init

    x_adv = x_adv.clamp(0., 1.)
    x_best = x_adv.clone()
    x_best_adv = x_adv.clone()
    loss_steps = torch.zeros([n_iter, x.shape[0]], device=device)
    loss_best_steps = torch.zeros([n_iter + 1, x.shape[0]], device=device)

    # --- set schedule parameters per norm
    n_fts = math.prod(x.shape[1:])
    if norm in ['Linf', 'L2']:
        n_iter_2_frac = 0.22 if n_iter_2 is None else n_iter_2
        n_iter_min_frac = 0.06 if n_iter_min is None else n_iter_min
        n_iter_2 = max(int(n_iter_2_frac * n_iter), 1)
        n_iter_min = max(int(n_iter_min_frac * n_iter), 1)
        size_decr = max(int(0.03 * n_iter), 1)
        k = n_iter_2 + 0  # iterations until the first step-size check
        thr_decr = .75
        alpha = 2. if alpha is None else alpha
    elif norm in ['L1', 'L0']:
        k = max(int(.04 * n_iter), 1)
        init_topk = .05 if is_train else .2
        topk = init_topk * torch.ones([x.shape[0]], device=device)
        sp_old = n_fts * torch.ones_like(topk)
        adasp_redstep = 1.5
        adasp_minstep = 10.
        alpha = 1. if alpha is None else alpha

    step_size = alpha * eps * torch.ones([x.shape[0], *[1] * ndims],
                                         device=device)
    counter3 = 0  # counts iterations since the last step-size check

    # --- initial forward/backward pass
    x_adv.requires_grad_()
    with torch.enable_grad():
        loss_indiv = model(x_adv)
        loss = loss_indiv.sum()
    grad = torch.autograd.grad(loss, [x_adv])[0].detach()
    if mask is not None:
        grad *= mask
    grad_best = grad.clone()
    x_adv.detach_()
    loss_indiv = loss_indiv.detach()
    loss = loss.detach()

    loss_best = loss_indiv.detach().clone()
    loss_best_last_check = loss_best.clone()
    reduced_last_check = torch.ones_like(loss_best)
    n_reduced = 0

    u = torch.arange(x.shape[0], device=device)
    x_adv_old = x_adv.clone().detach()

    for i in range(n_iter):
        ### gradient step
        if True:  # with torch.no_grad()
            x_adv = x_adv.detach()
            grad2 = x_adv - x_adv_old  # momentum term
            x_adv_old = x_adv.clone()
            loss_curr = loss.detach().mean()

            a = 0.75 if i > 0 else 1.0  # momentum mixing factor

            if norm == 'Linf':
                # signed gradient step, projected to the Linf ball and [0, 1]
                x_adv_1 = x_adv + step_size * torch.sign(grad)
                x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1,
                    x - eps), x + eps), 0.0, 1.0)
                x_adv_1 = torch.clamp(torch.min(torch.max(
                    x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
                    x - eps), x + eps), 0.0, 1.0)

            elif norm == 'L2':
                # normalized gradient step, projected to the L2 ball twice
                # (before and after adding momentum)
                x_adv_1 = x_adv + step_size * grad / (L2_norm(grad,
                    keepdim=True) + 1e-12)
                x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x,
                    keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x),
                    L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0)
                x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
                x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x,
                    keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x),
                    L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0)

            elif norm == 'L1':
                # step only in the top-k largest gradient coordinates, then
                # project onto the L1 ball via L1_projection
                grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0]
                topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long()
                grad_topk = grad_topk[u, topk_curr].view(-1, *[1] * (len(x.shape) - 1))
                sparsegrad = grad * (grad.abs() >= grad_topk).float()
                x_adv_1 = x_adv + step_size * sparsegrad.sign() / (
                    sparsegrad.sign().abs().view(x.shape[0], -1).sum(dim=-1).view(
                    -1, 1, 1, 1) + 1e-10)

                delta_u = x_adv_1 - x
                delta_p = L1_projection(x, delta_u, eps)
                x_adv_1 = x + delta_u + delta_p

            elif norm == 'L0':
                # L1-normalized gradient step; NOTE(review): no projection
                # onto the L0 ball is applied here (see TODO below)
                L1normgrad = grad / (grad.abs().view(grad.shape[0], -1).sum(
                    dim=-1, keepdim=True) + 1e-12).view(grad.shape[0], *[1] * (
                    len(grad.shape) - 1))
                x_adv_1 = x_adv + step_size * L1normgrad * n_fts
                # TODO: add momentum

            x_adv = x_adv_1.to(dtype=x_adv.dtype) + 0.

        ### get gradient
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_indiv = model(x_adv)
            loss = loss_indiv.sum()

        if i < n_iter - 1:
            # save one backward pass on the final iteration
            grad = torch.autograd.grad(loss, [x_adv])[0].detach()
            if mask is not None:
                grad *= mask
        x_adv.detach_()
        loss_indiv = loss_indiv.detach()
        loss = loss.detach()

        x_best_adv = x_adv + 0.
        if verbose and (i % max(n_iter // 10, 1) == 0 or i == n_iter - 1):
            str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format(
                step_size.mean(), topk.mean() * n_fts) if norm in ['L1'] else ' - step size: {:.5f}'.format(
                step_size.mean())
            print('iteration: {} - best loss: {:.6f} curr loss {:.6f} {}'.format(
                i, loss_best.sum(), loss_curr, str_stats))

        ### check step size
        if True:  # with torch.no_grad()
            # track best-so-far iterate and loss
            y1 = loss_indiv.detach().clone()
            loss_steps[i] = y1 + 0
            ind = (y1 > loss_best).nonzero().squeeze()
            x_best[ind] = x_adv[ind].clone()
            grad_best[ind] = grad[ind].clone()
            loss_best[ind] = y1[ind] + 0
            loss_best_steps[i + 1] = loss_best + 0

            counter3 += 1

            if counter3 == k:
                if norm in ['Linf', 'L2']:
                    # halve the step size when the loss oscillates or has
                    # not improved since the last check; restart from the
                    # best point found so far
                    fl_oscillation = check_oscillation(loss_steps, i, k,
                        loss_best, k3=thr_decr)
                    fl_reduce_no_impr = (1. - reduced_last_check) * (
                        loss_best_last_check >= loss_best).float()
                    fl_oscillation = torch.max(fl_oscillation,
                        fl_reduce_no_impr)
                    reduced_last_check = fl_oscillation.clone()
                    loss_best_last_check = loss_best.clone()

                    if fl_oscillation.sum() > 0:
                        ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze()
                        step_size[ind_fl_osc] /= 2.0
                        n_reduced = fl_oscillation.sum()

                        x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone()
                        grad[ind_fl_osc] = grad_best[ind_fl_osc].clone()

                    counter3 = 0
                    k = max(k - size_decr, n_iter_min)  # shrink the check window

                elif norm in ['L1']:
                    # adjust sparsity: shrink topk toward the currently used
                    # sparsity and adapt the step size accordingly
                    sp_curr = L0_norm(x_best - x)
                    fl_redtopk = (sp_curr / sp_old) < .95
                    topk = sp_curr / n_fts / 1.5
                    step_size[fl_redtopk] = alpha * eps
                    step_size[~fl_redtopk] /= adasp_redstep
                    step_size.clamp_(alpha * eps / adasp_minstep, alpha * eps)
                    sp_old = sp_curr.clone()

                    x_adv[fl_redtopk] = x_best[fl_redtopk].clone()
                    grad[fl_redtopk] = grad_best[fl_redtopk].clone()

                    counter3 = 0

    return x_best, loss_best, x_best_adv
383
+
384
+
vlm_eval/attacks/attack.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class Attack(object):
    '''
    Root class for all adversarial attack classes.

    Attributes:
        model: the attacked classifier (callable).
        device: device string the attack runs on; defaults to 'cuda:0'.
        targeted: whether given labels are treated as target labels.
        img_range: (low, high) bounds of valid image entries.
    '''

    def __init__(self, model, targeted=False, img_range=(0, 1)):
        # record the attack configuration and keep a handle on the model
        self.model = model
        self.targeted = targeted
        self.img_range = img_range
        self.device = 'cuda:0'  # default; updated via .to()

    def __repr__(self):
        return str(self.__dict__)

    def to(self, device):
        # move the underlying model and remember the active device
        self.model.to(device)
        self.device = device
vlm_eval/attacks/ead.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+ import torch
3
+ from vlm_eval.attacks.attack import Attack
4
+
5
class EAD(Attack):
    """Elastic-net (EAD-style) attack: minimizes the model loss plus an
    elastic-net (L2 + beta * L1) distance penalty via ISTA steps with a
    Nesterov-style momentum term, binary-searching the trade-off constant c.
    """

    def __init__(self, model, targeted=False, img_range=(0, 1), steps=100, beta=5e-5, mask_out='none', ver=False, binary_steps=2, step_size=1e-2, decision_rule='L1'):
        # steps: ISTA iterations per binary-search step.
        # beta: L1 penalty weight, also the shrinkage threshold in project().
        # binary_steps: outer iterations of the binary search on c.
        # decision_rule: metric used to pick the best iterate ('L1' only).
        super().__init__(model=model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.ver = ver
        self.binary_steps = binary_steps
        self.beta = beta
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None
        self.decision_rule = decision_rule
        self.ver = ver  # NOTE(review): duplicate of the assignment above; harmless
        self.step_size = step_size

    def _set_mask(self, data):
        # 0/1 mask selecting which media positions (dim 1) may be perturbed;
        # assumes context images come first and the query image last —
        # TODO confirm against the caller.
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x_orig):
        """Run the attack on x_orig and return the best adversarial example."""
        # freeze the wrapped model; only the input is optimized
        for param in self.model.model.parameters():
            param.requires_grad = False

        mask_out = self._set_mask(x_orig)

        # trade-off constant c and its binary-search bracket
        c = 1e-1
        c_upper = 10e+10
        c_lower = 0

        overall_best_attack = x_orig.clone()
        overall_best_dist = torch.inf
        overall_best_loss = 1e10

        for binary_step in range(self.binary_steps):

            global_step = 0
            x = x_orig.clone().detach()
            y = x_orig.clone().detach()  # momentum iterate

            best_attack = x_orig.clone().detach()
            best_dist = torch.inf
            best_loss = 1e10

            step_size = 1e-2

            for step in range(self.steps):

                y.requires_grad = True
                _, loss = self.loss_fn(x=y, c=c, x_orig=x_orig)
                loss.backward()
                y_grad = y.grad.data * mask_out  # zero out masked positions

                with torch.no_grad():
                    # ISTA step: gradient descent followed by shrinkage
                    x_new = self.project(x=y-step_size*y_grad, x_orig=x_orig)

                # square-root decay of the step size over the run
                step_size = (self.step_size - 0) * (1 - global_step / self.steps) ** 0.5 + 0
                global_step += 1

                # Nesterov-style momentum update
                y = x_new + (step / (step + 3)) * (x_new - x)
                x = x_new

                loss_model, loss = self.loss_fn(x=x, c=c, x_orig=x_orig)

                if self.ver and step % 20 == 0:
                    print(f"Binary Step: {binary_step}, Iter: {step}, Loss: {loss.item()}, L0: {(x - x_orig).norm(p=0)}, Linf: {(x - x_orig).norm(p=torch.inf)}")

                # keep the iterate that improves both distance and model loss
                if self.decision_rule == 'L1':
                    if (x - x_orig).norm(p=1).item() < best_dist and loss_model < best_loss:
                        best_loss = loss_model
                        best_attack = x.clone()
                        best_dist = (x - x_orig).norm(p=1).item()
                else:
                    raise NotImplementedError

            # Updating c: success -> shrink c (bisect down); failure ->
            # raise c (bisect up, or grow by 10x while unbounded)
            if overall_best_dist > best_dist and best_loss < overall_best_loss:
                overall_best_loss = best_loss
                overall_best_dist = best_dist
                overall_best_attack = best_attack.clone()

                c_upper = min(c_upper, c)
                if c_upper < 1e9:
                    c = (c_upper + c_lower) / 2

            else:
                c_lower = max(c_lower, c)
                if c_upper < 1e9:
                    c = (c_lower + c_upper) / 2.0
                else:
                    c *= 10

        print(f"Final L0: {(overall_best_attack - x_orig).norm(p=0)}, Linf: {(overall_best_attack - x_orig).norm(p=torch.inf)}")
        return overall_best_attack.detach()

    def project(self, x, x_orig):
        """Elastic-net shrinkage: soft-threshold x - x_orig at self.beta,
        clamping the shifted values to the [0, 1] image range."""
        mask_1 = (x - x_orig > self.beta).float()
        mask_2 = ((x - x_orig).abs() <= self.beta).float()
        mask_3 = (x - x_orig < -self.beta).float()

        upper = torch.minimum(x - self.beta, torch.tensor(1.0))
        lower = torch.maximum(x + self.beta, torch.tensor(0.0))

        proj_x = mask_1 * upper + mask_2 * x_orig + mask_3 * lower
        return proj_x

    def loss_fn(self, x, c, x_orig):
        """Return (model loss, total objective); the total objective is
        c * model loss + L2 distance + beta * L1 distance."""
        # maximize the model loss in the untargeted case (hence the minus)
        out = -self.model(x).sum() if not self.targeted else self.model(x).sum()
        l2_dist = ((x - x_orig) ** 2).view(x.shape[0], -1).sum(dim=1)
        l1_dist = ((x - x_orig).abs()).view(x.shape[0], -1).sum(dim=1)

        return out, c * out + l2_dist.sum() + \
               self.beta * l1_dist.sum()
vlm_eval/attacks/fwnucl.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+ import torch
3
+ import math
4
+ from vlm_eval.attacks.attack import Attack
5
+
6
class FWnucl(Attack):
    def __init__(self, model, *args, iters=200, img_range=(-1, 1), ver=False,
                 targeted=False, eps=5, mask_out='none', **kwargs):
        '''
        Implementation of the nuclear group norm attack (Frank-Wolfe).

        args:
            model: Callable, PyTorch classifier.
            ver: Bool, print progress if True.
            img_range: Tuple of ints/floats, lower and upper bound of image
                       entries.
            targeted: Bool, given label is used as a target label if True.
            eps: Float, radius of the nuclear group norm ball.
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.iters = iters
        self.ver = ver
        self.eps = eps
        self.gr = (math.sqrt(5) + 1) / 2  # golden ratio for the line search
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def _set_mask(self, data):
        # 0/1 mask selecting which media positions (dim 1) may be perturbed;
        # assumes context images come first and the query image last —
        # TODO confirm against the caller.
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask


    def __loss_fn(self, x):
        '''
        Compute loss depending on self.targeted.
        '''
        if self.targeted:
            return -self.model(x).sum()
        else:
            return self.model(x).sum()


    def __call__(self, x, *args, **kwargs):
        '''
        Perform the nuclear group norm attack on a batch of images x.

        args:
            x: Tensor of shape [B, C, H, W], batch of images.
            y: Tensor of shape [B], batch of labels.

        Returns a tensor of the same shape as x containing adversarial examples
        '''
        # freeze the wrapped model; only the perturbation is optimized
        for param in self.model.model.parameters():
            param.requires_grad = False

        mask_out = self._set_mask(x)
        x = x.to(self.device)
        noise = torch.zeros_like(x)
        noise.requires_grad = True

        # Frank-Wolfe loop: LMO over the nuclear group norm ball, then a
        # golden-section line search for the convex-combination weight
        for t in range(self.iters):
            if self.ver:
                print(f'\rIteration {t+1}/{self.iters}', end='')

            loss = self.__loss_fn(x + noise * mask_out)
            loss.backward()
            noise.grad.data = noise.grad.data * mask_out
            s = self.__groupNuclearLMO(noise.grad.data, eps=self.eps)
            with torch.no_grad():
                gamma = self.__lineSearch(x=x, s=s, noise=noise)
                noise = (1 - gamma) * noise + gamma * s
            noise.requires_grad = True

            if self.ver and t % 20 == 0:
                print(f"Iteration: {t}, Loss: {loss.item()}")
        x = torch.clamp(x + noise, 0, 1)
        if self.ver:
            print("")
        return x.detach()


    def __lineSearch(self, x, s, noise, steps=25):
        '''
        Perform line search for the step size.

        Golden-section search over gamma in [0, 1], maximizing the attack
        loss along the segment from `noise` to `s`.
        '''
        a = torch.zeros(x.shape[1], device=self.device).view(-1, 1, 1, 1)
        b = torch.ones(x.shape[1], device=self.device).view(-1, 1, 1, 1)
        c = b - (b - a) / self.gr
        d = a + (b - a) / self.gr
        sx = s - noise

        for i in range(steps):
            loss1 = self.__loss_fn(x + noise + (c * sx).view(*x.shape))
            loss2 = self.__loss_fn(x + noise + (d * sx).view(*x.shape))
            mask = loss1 > loss2

            # shrink the bracket toward the better endpoint
            b[mask] = d[mask]
            mask = torch.logical_not(mask)
            a[mask] = c[mask]

            c = b - (b - a) / self.gr
            d = a + (b - a) / self.gr

        return (b + a) / 2


    def __groupNuclearLMO(self, x, eps=5):
        '''
        LMO for the nuclear group norm ball.

        Splits each channel into size-by-size pixel groups, finds the group
        with the largest nuclear norm (summed over channels), and returns a
        perturbation supported only on that group: eps times the rank-one
        outer product of its leading singular vectors.
        '''
        # NOTE(review): assumes a 6D input (batch, num_media, num_frames,
        # C, H, W) with singleton leading dims — confirm against the caller.
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        size = 32 if H > 64 else 4

        # turn batch of images into batch of size by size pixel groups per
        # color channel
        xrgb = [x.view(B, C, H, W)[:, c, :, :] for c in range(C)]
        xrgb = [xc.unfold(1, size, size).unfold(2, size, size) for xc in xrgb]
        xrgb = [xc.reshape(-1, size, size) for xc in xrgb]

        # compute nuclear norm of each patch (sum norms over color channels)
        norms = torch.linalg.svdvals(xrgb[0])
        for xc in xrgb[1:]:
            norms += torch.linalg.svdvals(xc)
        norms = norms.sum(-1).reshape(B, -1)

        # only keep the patch g* with the largest nuclear norm for each image
        idxs = norms.argmax(dim=1).view(-1, 1)
        xrgb = [xc.reshape(B, -1, size, size) for xc in xrgb]
        xrgb = [xc[torch.arange(B).view(-1, 1), idxs].view(B, size, size)
                for xc in xrgb]

        # build index tensor corr. to the position of the kept patches in x
        off = (idxs % (W / size)).long() * size
        off += torch.floor(idxs / (W / size)).long() * W * size
        idxs = torch.arange(0, size**2,
                            device=self.device).view(1, -1).repeat(B, 1) + off
        off = torch.arange(0, size,
                           device=self.device).view(-1, 1).repeat(1, size)
        off = off * W - off * size
        idxs += off.view(1, -1)

        # compute singular vector pairs corresponding to largest singular value
        # and final perturbation (LMO solution)
        pert = torch.zeros_like(x).view(B, C, H, W)
        for i, xc in enumerate(xrgb):
            U, _, V = torch.linalg.svd(xc)
            U = U[:, :, 0].view(B, size, 1)
            V = V.transpose(-2, -1)[:, :, 0].view(B, size, 1)
            pert_gr = torch.bmm(U, V.transpose(-2, -1)).reshape(B, size * size)
            idx = torch.arange(B).view(-1, 1)
            # pert_tmp is a view into pert, so this writes through to pert
            pert_tmp = pert[:, i, :, :].view(B, -1)
            pert_tmp[idx, idxs] = pert_gr * eps
            pert_clone = pert.clone()
            pert_clone[:, i, :, :] = pert_tmp.view(B, H, W)

        return pert_clone.view(*x.shape)
vlm_eval/attacks/gse.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+ import torch
3
+ import torchvision
4
+ import math
5
+ import torch.nn.functional as F
6
+
7
+ from vlm_eval.attacks.attack import Attack
8
+
9
+
10
+ # required input size : batch_size x num_media x num_frames x channels x height x width
11
class GSEAttack(Attack):
    # Required input shape: batch_size x num_media x num_frames x channels x height x width
    def __init__(self, model, *args, mask_out='none', ver=False, img_range=(-1, 1),
                 search_steps=4, targeted=False, sequential=False, search_factor=2,
                 gb_size=5, sgm=1.5, mu=1, sigma=0.0025, iters=200, k_hat=10,
                 q=0.25, **kwargs):
        '''
        Implementation of the GSE attack.

        args:
            model:         Callable, PyTorch classifier.
            mask_out:      Masks out context images if set to 'context', query
                           images if set to 'query', and nothing if 'none'.
            ver:           Bool, print progress if True.
            img_range:     Tuple of ints/floats, lower and upper bound of image
                           entries.
            search_steps:  Int, number of steps for the section search on the
                           trade-off parameter.
            targeted:      Bool, given label is used as a target label if True.
            sequential:    Bool, perturbations are computed sequentially for all
                           images in the batch if True. For fair comparison to
                           Homotopy attack.
            search_factor: Float, factor to increase/decrease the trade-off
                           parameter until an upper/lower bound for the section
                           search is found.
            gb_size:       Odd int, size of the Gaussian blur kernel.
            sgm:           Float, sigma of the Gaussian blur kernel.
            mu:            Float, trade-off parameter for 2-norm regularization.
            sigma:         Float, step size.
            iters:         Int, number of iterations.
            k_hat:         Int, number of iterations before transitioning to NAG.
            q:             Float, inverse of increase factor for adjust_lambda.
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.ver = ver
        self.search_steps = search_steps
        self.sequential = sequential
        self.search_factor = search_factor
        self.gb_size = gb_size
        self.sgm = sgm
        self.mu = mu
        self.sigma = sigma
        self.iters = iters
        self.k_hat = k_hat
        self.q = q
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def adjust_lambda(self, lam, noise):
        '''
        Adjust trade-off parameters (lambda) to update the search space.
        A blurred indicator of the current perturbation support makes lambda
        smaller near already-perturbed pixels (cheaper to extend a group) and
        divides untouched pixels by q < 1 (i.e. grows their lambda).
        '''
        x = noise.detach().clone().abs().mean(dim=1, keepdim=True).sign()
        gb = torchvision.transforms.GaussianBlur((self.gb_size, self.gb_size),
                                                 sigma=self.sgm)
        x = gb(x) + 1
        x = torch.where(x == 1, self.q, x)
        lam /= x[:, 0, :, :]
        return lam


    def section_search(self, x, steps=50):
        '''
        Section search for finding the maximal lambda such that the
        perturbation is non-zero after the first iteration.
        '''
        # 'x' has shape batch_size x num_media x num_frames x C x H x W
        noise = torch.zeros_like(x, requires_grad=True)
        loss = (-self.model(x + noise).sum() + self.mu
                * torch.norm(noise.view(x.size(1), x.size(3), x.size(4), x.size(5)), p=2, dim=(1,2,3)).sum())
        grad = torch.autograd.grad(loss, [noise])[0].detach()
        noise.detach_()
        ones = torch.ones_like(x.view(x.size(1), x.size(3), x.size(4), x.size(5)))[:, 0, :, :]

        # define upper and lower bound for the search: double ub until the
        # prox of the first step is identically zero for every image
        lb = torch.zeros((x.size(1),), dtype=torch.float,
                         device=self.device).view(-1, 1, 1)
        ub = lb.clone() + 0.001
        mask = torch.norm(self.prox(grad.clone().view(x.size(1), x.size(3), x.size(4), x.size(5)) * self.sigma,
                                    ones * ub * self.sigma),
                          p=0, dim=(1,2,3)) != 0
        while mask.any():
            ub[mask] *= 2
            mask = torch.norm(self.prox(grad.clone().view(x.size(1), x.size(3), x.size(4), x.size(5)) * self.sigma,
                                        ones * ub * self.sigma),
                              p=0, dim=(1,2,3)) != 0

        # bisection between lb and ub
        for _ in range(steps):
            cur = (ub + lb) / 2
            mask = torch.norm(self.prox(grad.clone().view(x.size(1), x.size(3), x.size(4), x.size(5)) * self.sigma,
                                        ones * cur * self.sigma),
                              p=0, dim=(1,2,3)) == 0
            ub[mask] = cur[mask]
            mask = torch.logical_not(mask)
            lb[mask] = cur[mask]
        cur = (lb + ub).view(-1) / 2
        return 0.01 * cur


    def __call__(self, x, y, *args, **kwargs):
        '''
        Call the attack for a batch of images x or sequentially for all images
        in x depending on self.sequential.

        args:
            x: Tensor, batch of images
               (batch x num_media x num_frames x C x H x W).
            y: Tensor, batch of labels.

        Returns a tensor of the same shape as x containing adversarial examples
        '''
        if self.sequential:
            result = x.clone()
            for i, (x_, y_) in enumerate(zip(x, y)):
                result[i] = self.perform_att(x_.unsqueeze(0),
                                             y_.unsqueeze(0),
                                             mu=self.mu, sigma=self.sigma,
                                             k_hat=self.k_hat).detach()
            return result
        else:
            return self.perform_att(x, y, mu=self.mu, sigma=self.sigma,
                                    k_hat=self.k_hat)


    def _set_mask(self, data):
        '''Return a 0/1 tensor marking which media positions may be perturbed.'''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask


    def perform_att(self, x, y=None, mu=None, sigma=None, k_hat=None):
        '''
        Perform the GSE attack on a batch of images x.

        BUGFIX: the original signature was (self, x, mu, sigma, k_hat), but
        __call__ invokes perform_att(x, y, mu=..., sigma=..., k_hat=...),
        which raised a TypeError (y bound to mu positionally, mu passed again
        as a keyword). y is now accepted; it is currently unused because the
        wrapped model returns the loss directly.
        '''
        x = x.to(self.device)
        # input shape: batch x num_media x num_frames x channels x height x width
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        lams = self.section_search(x)
        mask_out = self._set_mask(x).view(B, C, H, W)
        # save x and lams for resetting them at the beginning of every
        # section search step
        save_x = x.clone()
        save_lams = lams.clone()
        # upper and lower bounds for section search
        ub_lams = torch.full_like(lams, torch.inf)
        lb_lams = torch.zeros_like(lams)
        # tensor for saving succesful adversarial examples in inner loop
        result = x.clone()
        # tensor for saving best adversarial example so far
        result2 = x.clone()
        best_l0 = torch.full((B,), torch.inf, device=self.device).type(x.type())

        # section search
        for step in range(self.search_steps):
            x = save_x.clone()
            lams = save_lams.clone()
            lam = torch.ones_like(x.view(B, C, H, W))[:, 0, :, :] * lams.view(-1, 1, 1)
            # tracks for which images adv. examples have been found
            # NOTE(review): 'active' is never cleared anywhere, so the
            # else-branch of the update below is currently unreachable —
            # confirm against the upstream GSE implementation.
            active = torch.ones(B, dtype=bool, device=self.device)
            # set initial perturbation to zero
            noise = torch.zeros_like(x, requires_grad=True)
            noise_old = noise.clone()
            lr = 1

            # attack
            for j in range(self.iters):
                if self.ver:
                    print(f'\rSearch step {step + 1}/{self.search_steps}, ' +
                          f'Prox.Grad. Iteration {j + 1}/{self.iters}, ' +
                          f'Images left: {x.shape[1]}', end='')
                if len(x) == 0:
                    break

                self.model.model.zero_grad()
                loss = (-self.model(x + noise).sum() + mu
                        * (torch.norm(noise.view(B, C, H, W), p=2, dim=(1,2,3)) ** 2).sum())
                noise_grad_data = torch.autograd.grad(loss, [noise])[0].detach().view(B, C, H, W)
                with torch.no_grad():
                    noise_grad_data = noise_grad_data * mask_out  # mask_out shape B x C x H x W
                    lr_ = (1 + math.sqrt(1 + 4 * lr**2)) / 2
                    if j == k_hat:
                        # freeze the support: pixels whose lambda grew beyond
                        # the initial value stay unperturbed from now on
                        lammask = (lam > lams.view(-1, 1, 1))[:, None, :, :]
                        lammask = lammask.repeat(1, C, 1, 1)
                        noise_old = noise.clone()
                    if j < k_hat:
                        # proximal gradient phase with 1/2-norm prox
                        noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
                        noise = self.prox(noise.view(B, C, H, W), lam * sigma).view(1, B, 1, C, H, W)
                        noise_tmp = noise.clone()
                        noise = lr / lr_ * noise + (1 - (lr / lr_)) * noise_old
                        noise_old = noise_tmp.clone()
                        lam = self.adjust_lambda(lam, noise.view(B, C, H, W))
                    else:
                        # NAG phase on the fixed support
                        noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
                        noise_tmp = noise.clone()
                        noise = lr / lr_ * noise + (1 - (lr / lr_)) * noise_old
                        noise_old = noise_tmp.clone()
                        noise[lammask.view(1, B, 1, C, H, W)] = 0
                    # clamp adv. example to valid range
                    x_adv = torch.clamp(x + noise, *self.img_range)
                    noise = x_adv - x
                    lr = lr_

                noise.requires_grad = True

            # section search update:
            # no adv. example found => decrease upper bound and current lambda
            # adv. example found => keep it if its "0-norm" beats the previous
            # one, increase lower bound and current lambda
            for i in range(B):
                if active[i]:
                    ub_lams[i] = save_lams[i]
                    save_lams[i] = 0.95 * lb_lams[i] + 0.05 * save_lams[i]
                else:
                    l0 = self.l20((result[i] - save_x[i]).unsqueeze(0)).to(self.device)
                    if l0 < best_l0[i]:
                        best_l0[i] = l0
                        result2[i] = result[i].clone()
                    if torch.isinf(ub_lams[i]):
                        lb_lams[i] = save_lams[i]
                        save_lams[i] *= self.search_factor
                    else:
                        lb_lams[i] = save_lams[i]
                        save_lams[i] = (ub_lams[i] + save_lams[i]) / 2

        if self.ver:
            print('')

        return x_adv

    def extract_patches(self, x):
        '''
        Extracts and returns all overlapping size-by-size patches from
        the image batch x via a grouped identity convolution.
        '''
        B, C, _, _ = x.shape
        size = 8
        kernel = torch.zeros((size ** 2, size ** 2))
        kernel[range(size**2), range(size**2)] = 1.0
        kernel = kernel.view(size**2, 1, size, size)
        kernel = kernel.repeat(C, 1, 1, 1).to(x.device)
        out = F.conv2d(x, kernel, groups=C)
        out = out.view(B, C, size, size, -1)
        out = out.permute(0, 4, 1, 2, 3)
        return out.contiguous()

    def l20(self, x):
        '''
        Computes d_{2,0}(x[i]) for all perturbations x[i] in the batch x
        (the number of overlapping patches with non-zero 2-norm), as
        described in section 3.2 of the GSE paper.
        '''
        B, N, M, C, _, _ = x.shape
        l20s = []

        for b in range(B):
            for n in range(N):
                for m in range(M):
                    x_ = x[b, n, m]  # select the perturbation x[b, n, m]
                    patches = self.extract_patches(x_.unsqueeze(0))
                    l2s = torch.norm(patches, p=2, dim=(2,3,4))
                    l20s.append((l2s != 0).float().sum().item())

        return torch.tensor(l20s)


    def prox(self, grad_loss_noise, lam):
        '''
        Computes the proximal operator of the 1/2-norm of the gradient of the
        adversarial loss wrt the current noise. Note: modifies and returns
        its input tensor in place.
        '''

        lam = lam[:, None, :, :]
        sh = list(grad_loss_noise.shape)
        lam = lam.expand(*sh)

        # threshold below which entries are zeroed
        p_lam = (54 ** (1 / 3) / 4) * lam ** (2 / 3)

        mask1 = (grad_loss_noise > p_lam)
        mask2 = (torch.abs(grad_loss_noise) <= p_lam)
        mask3 = (grad_loss_noise < -p_lam)
        mask4 = mask1 + mask3

        phi_lam_x = torch.arccos((lam / 8) * (torch.abs(grad_loss_noise) / 3)
                                 ** (-1.5))

        grad_loss_noise[mask4] = ((2 / 3) * torch.abs(grad_loss_noise[mask4])
                                  * (1 + torch.cos((2 * math.pi) / 3
                                  - (2 * phi_lam_x[mask4]) / 3))).to(torch.float32)
        grad_loss_noise[mask3] = -grad_loss_noise[mask3]
        grad_loss_noise[mask2] = 0

        return grad_loss_noise
vlm_eval/attacks/iht.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+
3
+ import torch
4
+ from vlm_eval.attacks.attack import Attack
5
+ import math
6
+
7
class IHT(Attack):
    '''
    Iterative hard-thresholding sparse attack with FISTA acceleration.
    Optimizes an additive perturbation under an inf-norm bound `eps`,
    using a hard-thresholding proximal step to promote sparsity.
    '''

    def __init__(self, model, targeted=False, img_range=(0, 1), steps=100, prox='hard', ver=False, lam=5e-5, mask_out='none', stepsize=0.015, eps=4./255.):
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.stepsize = stepsize
        self.ver = ver
        self.lam = lam
        self.eps = eps
        # 'none' disables media masking entirely
        self.mask_out = None if mask_out == 'none' else mask_out
        # only the hard-thresholding prox is implemented
        if prox != 'hard':
            raise NotImplementedError
        self.Prox = self.hardprox



    def _set_mask(self, data):
        '''Return a 0/1 tensor marking which media positions may be perturbed.'''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, img):
        '''Attack `img`; returns (adversarial image, L0 norm of the perturbation).'''
        # the model is treated as a fixed function of the input
        for param in self.model.model.parameters():
            param.requires_grad = False

        img = img.to(self.device)
        mask = self._set_mask(img)
        pert = torch.zeros_like(img)  # perturbation being optimized
        extrap = pert.clone()         # FISTA extrapolation point
        t = 1
        if self.ver:
            print('')

        for step in range(self.steps):
            # gradient of the (signed) model loss w.r.t. the perturbation
            pert.requires_grad = True
            loss = self.model(img + pert).sum() if self.targeted else -self.model(img + pert).sum()
            loss.backward()
            grad = pert.grad.data * mask
            pert = pert.detach()

            if self.ver and step % 20 == 0:
                print(f'Iteration: {step+1}, Loss: {loss}\n', end='')

            # FISTA update with the hard-thresholding prox
            with torch.no_grad():
                t_next = .5 * (1 + math.sqrt(1 + 4 * t ** 2))
                alpha = (t - 1) / t_next
                t = t_next
                prox_pt = self.Prox(x=pert - self.stepsize * grad,
                                    lam=self.lam * self.stepsize,
                                    img=img,
                                    eps=self.eps
                                    )
                pert = prox_pt + alpha * (prox_pt - extrap)
                pert = torch.clamp(pert, -self.eps, self.eps)
                extrap = prox_pt.clone()
                pert = torch.clamp(img + pert, *self.img_range) - img

        if self.ver:
            print('')
            print(f"L0 pert norm: {pert.norm(p=0)}")

        return (img + pert * mask).detach(), pert.norm(p=0).item()

    def hardprox(self, x, lam, img, eps):
        '''
        Computes the hard thresholding proximal operator of the the
        perturbation x.

        :x: Perturbation after gradient descent step.
        :lam: Regularization parameter.
        '''
        # project onto the eps-box and the valid image range around img
        feasible = torch.clamp(x, -eps, eps)
        feasible = torch.clamp(img + feasible, *self.img_range) - img
        # keep entries whose value outweighs the projection cost + threshold
        return torch.where(x ** 2 - (feasible - x) ** 2 > 2 * lam, feasible, 0)
vlm_eval/attacks/pgd.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken from https://github.com/chs20/RobustVLM/tree/main
2
+ import torch
3
+ from vlm_eval.attacks.utils import project_perturbation, normalize_grad
4
+
5
+
6
class PGD:
    """
    Projected gradient descent with momentum that minimizes or maximizes
    a given loss under an eps-ball constraint in the given norm.
    """

    def __init__(self, forward, norm, eps, mode='min', mask_out='context', image_space=True):
        self.model = forward
        self.norm = norm
        self.eps = eps
        self.momentum = 0.9
        self.mode = mode
        self.mask_out = mask_out
        self.image_space = image_space

    def perturb(self, data_clean, iterations, stepsize, perturbation=None, verbose=False, return_loss=False):
        """Run `iterations` PGD steps; returns the adversarial data (and loss)."""
        if self.image_space:
            # make sure data is in image space
            assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6  # todo

        if perturbation is None:
            perturbation = torch.zeros_like(data_clean, requires_grad=True)
        mask = self._set_mask(data_clean)
        velocity = torch.zeros_like(data_clean)

        for it in range(iterations):
            perturbation.requires_grad_()
            with torch.enable_grad():
                loss = self.model(data_clean + perturbation)
                # print 10 times in total and last iteration
                if verbose and (it % (iterations // 10 + 1) == 0 or it == iterations - 1):
                    print(f'[iteration] {it} [loss] {loss.item()}')

            with torch.no_grad():
                gradient = torch.autograd.grad(loss, perturbation)[0]
                gradient = mask * gradient
                nan_entries = gradient.isnan()
                if nan_entries.any():  #
                    print(f'attention: nan in gradient ({nan_entries.sum()})')  #
                    gradient[nan_entries] = 0.
                # normalize, then apply momentum and renormalize
                gradient = normalize_grad(gradient, p=self.norm)
                velocity = self.momentum * velocity + gradient
                velocity = normalize_grad(velocity, p=self.norm)
                # descend or ascend depending on mode
                if self.mode == 'min':
                    perturbation = perturbation - stepsize * velocity
                elif self.mode == 'max':
                    perturbation = perturbation + stepsize * velocity
                else:
                    raise ValueError(f'Unknown mode: {self.mode}')
                # project back into the eps-ball
                perturbation = project_perturbation(perturbation, self.eps, self.norm)
                if self.image_space:
                    # clamp to image space
                    perturbation = torch.clamp(data_clean + perturbation, 0, 1) - data_clean
                    assert torch.max(data_clean + perturbation) < 1. + 1e-6 and torch.min(
                        data_clean + perturbation
                    ) > -1e-6
                assert not perturbation.isnan().any()

        # assert (ctorch.compute_norm(perturbation, p=self.norm) <= self.eps + 1e-6).all()
        # todo return best perturbation
        # problem is that model currently does not output expanded loss
        if return_loss:
            return data_clean + perturbation.detach(), loss
        return data_clean + perturbation.detach()

    def _set_mask(self, data):
        '''Return a 0/1 tensor marking which media positions may be perturbed.'''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask
vlm_eval/attacks/pgd0.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+
3
+ from vlm_eval.attacks.attack import Attack
4
+ import torch
5
+ import numpy as np
6
+
7
class PGD0(Attack):
    def __init__(self, model, *args, img_range=(0, 1), k=5000, n_restarts=1,
                 targeted=False, iters=200, stepsize=120000/255.0, eps=4./255., ver=False, mask_out='none', **kwargs):
        '''
        Implementation of the PGD0 attack https://arxiv.org/pdf/1909.05040
        Author's implementation: https://github.com/fra31/sparse-imperceivable-attacks/tree/master
        Addapted from: https://github.com/wagnermoritz/GSE/tree/main

        args:
            model:      Callable, PyTorch classifier.
            img_range:  Tuple of ints/floats, lower and upper bound of image
                        entries.
            targeted:   Bool, given label is used as a target label if True.
            k:          Int, sparsity parameter.
            n_restarts: Int, number of restarts from random perturbation.
            iters:      Int, number of gradient descent steps per restart.
            stepsize:   Float, step size for gradient descent.
            eps:        Float, inf-norm bound on the perturbation.
            ver:        Bool, print progress if True.
            mask_out:   'context', 'query', an int index, or 'none'.
        '''
        super().__init__(model, img_range=img_range, targeted=targeted)
        self.k = k
        self.n_restarts = n_restarts
        self.eps = eps
        self.iters = iters
        self.stepsize = stepsize
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None
        self.ver = ver

    def _set_mask(self, data):
        '''Return a 0/1 tensor marking which media positions may be perturbed.'''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask


    def __call__(self, x, *args, **kwargs):
        '''
        Perform the PGD_0 attack on a batch of images x.

        args:
            x: Tensor of shape [batch, num_media, num_frames, C, H, W].

        Returns a tensor of the same shape as x containing adversarial examples
        '''

        for param in self.model.model.parameters():
            param.requires_grad = False

        mask_out = self._set_mask(x)
        x = x.to(self.device)
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]

        for _ in range(self.n_restarts):
            if not len(x):
                break
            # random start inside the eps-box, projected to k-sparsity
            eps = torch.full_like(x, self.eps)
            lb, ub = torch.maximum(-eps, -x), torch.minimum(eps, 1.0 - x)  # self.img_range[0] - x, self.img_range[1] - x
            pert = (torch.clamp(x + (ub - lb) * torch.rand_like(x) + lb, *self.img_range) - x).view(B, C, H, W) * mask_out.view(B, C, H, W)
            pert = self.project_L0(pert, lb, ub)  # pert has shape (B, C, H, W)

            for it in range(self.iters):
                pert.requires_grad = True
                loss = self.lossfn(x=x, pert=pert.view(*x.shape), mask_out=mask_out)
                loss.backward()

                if self.ver and it % 20 == 0:
                    print(f"Loss: {loss}, Iter: {it}")

                grad = pert.grad.data.view(B, C, H, W) * mask_out.view(B, C, H, W)  # shape (B, C, H, W)
                with torch.no_grad():
                    # normalize by the 1-norm per image
                    grad /= grad.abs().sum(dim=(1,2,3), keepdim=True) + 1e-10
                    # tiny random jitter breaks ties in the top-k projection
                    pert += (torch.rand_like(x) - .5).view(B, C, H, W) * 1e-12 - self.stepsize * grad
                    pert = self.project_L0(pert, lb, ub)

        return (x + pert.view(*x.shape) * mask_out).detach()


    def project_L0_sigma(self, pert, sigma, kappa, x_orig):
        # NOTE(review): unfinished port of the sigma-projection from the
        # original repository — the lambda values are computed but unused and
        # the function deliberately returns 0. Kept for interface parity.
        # BUGFIX: torch.maximum requires tensor arguments; the original passed
        # the raw float 1e-12, which raises a TypeError.
        B, C, H, W = pert.shape
        x = torch.clone(pert)
        p1 = (1.0 / torch.maximum(torch.tensor(1e-12), sigma) * (x_orig > 0).float()) + \
             (1e12 * (x_orig == 0).float())
        p2 = (1.0 / torch.maximum(torch.tensor(1e-12), sigma)) * \
             (1.0 / torch.maximum(torch.tensor(1e-12), x_orig) - 1) * \
             (x_orig > 0).float() + 1e12 * (x_orig == 0).float() + 1e12 * (sigma == 0).float()
        lmbd_l = torch.maximum(-kappa, torch.amax(-p1, dim=1, keepdim=True))
        lmbd_u = torch.minimum(kappa, torch.amin(p2, dim=1, keepdim=True))

        lmbd_unconstr = torch.sum((pert - x_orig) * sigma * x_orig, dim=1, keepdim=True) / torch.clamp(torch.sum((sigma * x_orig) ** 2, dim=1, keepdim=True), min=1e-12)
        lmbd = torch.maximum(lmbd_l, torch.minimum(lmbd_unconstr, lmbd_u))
        return 0


    def project_L0(self, pert, lb, ub):
        '''
        Project a batch of perturbations such that at most self.k pixels
        are perturbed and componentwise there holds lb <= pert <= ub.
        '''

        B, C, H, W = pert.shape  # here, pert has shape (B, C, H, W)
        p1 = torch.sum(pert ** 2, dim=1)
        p2 = torch.clamp(torch.minimum(ub.view(B, C, H, W) - pert, pert - lb.view(B, C, H, W)), 0)
        p2 = torch.sum(p2 ** 2, dim=1)
        # indices of the H*W - k pixels with the smallest scores; these get zeroed
        p3 = torch.topk(-1 * (p1 - p2).view(p1.size(0), -1), k=H*W-self.k, dim=-1)[1]
        pert = torch.maximum(torch.minimum(pert, ub.view(B, C, H, W)), lb.view(B, C, H, W))
        # BUGFIX: a flat pixel index idx maps to (row, col) = (idx // W, idx % W);
        # the original used idx % H for the column, which is wrong when H != W.
        pert[torch.arange(0, B).view(-1, 1), :, p3 // W, p3 % W] = 0
        return pert

    def lossfn(self, x, pert, mask_out):
        '''
        Compute the loss at x (+1 sign when targeted, -1 when untargeted).
        '''
        return (2 * self.targeted - 1) * self.model(x + pert * mask_out).sum()
vlm_eval/attacks/saif.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code adapted from https://github.com/wagnermoritz/GSE
2
+
3
+ from vlm_eval.attacks.attack import Attack
4
+ import torch
5
+ import math
6
+ import time
7
+
8
class SAIF(Attack):
    '''
    Sparse Frank-Wolfe attack SAIF (https://arxiv.org/pdf/2212.07495.pdf),
    adapted from https://github.com/wagnermoritz/GSE/tree/main.
    Jointly optimizes a magnitude perturbation p (inf-norm <= eps) and a
    sparsity mask s (at most k active pixels) with Frank-Wolfe updates.
    '''

    def __init__(self, model, *args, targeted=False, img_range=(-1, 1), steps=200,
                 r0=1, ver=False, k=10000, eps=16./255., mask_out='none', **kwargs):
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.steps = steps
        self.r0 = r0
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.ver = ver
        self.k = k
        self.eps = eps
        # 'none' disables media masking entirely
        self.mask_out = None if mask_out == 'none' else mask_out

    def _set_mask(self, data):
        '''Return a 0/1 tensor marking which media positions may be perturbed.'''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, x):
        '''
        Attack a batch x of shape [batch, num_media, num_frames, C, H, W]
        (batch size must be 1).
        Returns (adversarial batch, L0 norm of the applied perturbation).
        '''
        assert x.shape[0] == 1, "Only support batch size 1 for now"

        # treat the model as a fixed function of its input
        for param in self.model.model.parameters():
            param.requires_grad = False

        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        x = x.to(self.device)
        batchidx = torch.arange(B).view(-1, 1)

        mask_out = self._set_mask(x)

        # initialize p_0 (magnitude) and s_0 (mask) from the clean-input gradient
        x0 = x.clone()
        x0.requires_grad = True
        out = self.model(x0)
        loss = out.sum() if self.targeted else -out.sum()
        clean_grad = torch.autograd.grad(loss, [x0])[0].detach() * mask_out
        p = (-self.eps * clean_grad.sign()).detach().half()
        topk_idx = torch.topk(-clean_grad.view(B, -1), self.k, dim=1)[1]
        sel = torch.zeros((B, C * H * W), device=self.device)
        sel[batchidx, topk_idx] = 1
        s = torch.logical_and(sel.view(*x.shape), clean_grad < 0).float().detach().half()

        r = self.r0

        for t in range(self.steps):
            if self.ver:
                print(f'\r Iteration {t+1}/{self.steps}', end='')
            p.requires_grad = True
            s.requires_grad = True

            D = self.Loss_fn(x, s, p, mask_out)
            D.backward()

            gp = p.grad * mask_out
            gs = s.grad * mask_out
            with torch.no_grad():
                # inf-norm LMO for the magnitude variable
                v = (-self.eps * gp.sign()).half()

                # 1-norm LMO for the sparsity variable
                topk_idx = torch.topk(-gs.view(B, -1), self.k, dim=1)[1]
                sel = torch.zeros((B, C * H * W), device=self.device)
                sel[batchidx, topk_idx] = 1
                sel = sel.view(*x.shape) * mask_out
                z = torch.logical_and(sel, gs < 0).float().half()

                # shrink the step size until primal progress is made
                mu = 1 / (2 ** r * math.sqrt(t + 1))
                while self.Loss_fn(x, s + mu * (z - s), p + mu * (v - p), mask_out) > D:
                    r += 1
                    if r >= 50:
                        break
                    mu = 1 / (2 ** r * math.sqrt(t + 1))

                # convex-combination Frank-Wolfe update
                p = p + mu * (v - p)
                s = s + mu * (z - s)

                x_adv = torch.clamp(x + p, *self.img_range)
                p = x_adv - x

            if self.ver and t % 10 == 0:
                print(f" Loss: {D}")
        if self.ver:
            print('')
        return (x + s * p * mask_out).detach(), torch.norm(s * p, p=0).item()

    def Loss_fn(self, x, s, p, mask_out):
        '''Adversarial surrogate loss on x + s*p (sign depends on targeted).'''
        out = self.model(x + s * p * mask_out).sum()
        return out if self.targeted else -out
vlm_eval/attacks/sparsers.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+ from vlm_eval.attacks.attack import Attack
3
+ import torch
4
+
5
class SparseRS(Attack):
    def __init__(self, model, *args, targeted=False, img_range=(-1, 1),
                 n_queries=10000, k=100, n_restarts=10, alpha_init=0.8, mask_out='none', **kwargs):
        '''
        Implementation of the L0 variant SparseRS https://arxiv.org/abs/2006.12834
        Authors' implementation: https://github.com/fra31/sparse-rs
        Adapted from: https://github.com/wagnermoritz/GSE/tree/main

        args:
            model:      Callable, PyTorch classifier.
            targeted:   Bool, given label is used as a target label if True.
            img_range:  Tuple of ints/floats, lower and upper bound of image
                        entries.
            n_queries:  Int, max number of queries to the model
            k:          Int, initial sparsity parameter
            n_restarts: Int, number of restarts with random initialization
            alpha_init: Float, inital value for alpha schedule
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.n_queries = n_queries
        self.k = k
        self.n_restarts = n_restarts
        self.alpha_init = alpha_init
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def _set_mask(self, data):
        '''Return a 0/1 tensor marking which media positions may be perturbed.'''
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask


    def __call__(self, x, *args, **kwargs):
        '''
        Perform SparseRS L0 on a batch of images x.

        args:
            x: Tensor of shape [batch, num_media, num_frames, C, H, W].

        Returns a tensor of the same shape as x containing adversarial examples
        '''

        for param in self.model.model.parameters():
            param.requires_grad = False

        # fixed seeds for reproducible random search
        torch.random.manual_seed(0)
        torch.cuda.random.manual_seed(0)
        x = x.to(self.device)

        with torch.no_grad():
            for _ in range(self.n_restarts):
                if len(x) == 0:
                    break

                x_adv = self.__perturb(x.clone())

        return x_adv.detach()


    def __perturb(self, x):
        '''
        Perform the attack from a random starting point.

        BUGFIXES w.r.t. the adapted code: the loss is compared as a scalar
        (the wrapped model returns one summed loss, so the per-image
        margin/active bookkeeping referenced undefined variables and was
        removed), __lossfn is called with its single argument, flat pixel
        indices are decoded as (idx // W, idx % W) instead of idx % H, and
        the result is reshaped back to the 6-D input shape.
        '''
        B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
        batchidx = torch.arange(B, device=self.device).view(-1, 1)
        result = x.clone().view(B, C, H, W)

        # M: set of perturbed pixel indices, U_M: set of unperturbed pixel indices
        batch_randperm = torch.rand(B, H * W, device=self.device).argsort(dim=1)
        M = batch_randperm[:, :self.k]
        U_M = batch_randperm[:, self.k:]
        result[batchidx, :, M // W, M % W] = self.__sampleDelta(B, C, self.k)

        best_loss = self.__lossfn(result.view(*x.shape))

        for i in range(1, self.n_queries):
            # reset k_i currently perturbed pixels and perturb k_i new pixels
            k_i = max(int(self.__alphaSchedule(i) * self.k), 1)
            A_idx = torch.randperm(self.k, device=self.device)[:k_i]
            B_idx = torch.randperm(H * W - self.k, device=self.device)[:k_i]
            A_set, B_set = M[:, A_idx], U_M[:, B_idx]

            z = result.clone()
            z[batchidx, :, A_set // W, A_set % W] = x.view(B, C, H, W)[batchidx, :, A_set // W, A_set % W]
            if k_i > 1:
                z[batchidx, :, B_set // W, B_set % W] = self.__sampleDelta(B, C, k_i)
            else:  # if only one pixel is changed, make sure it actually changes
                new_color = self.__sampleDelta(B, C, k_i)
                while (mask := (z[batchidx, :, B_set // W, B_set % W] == new_color).view(B, -1).all(dim=-1)).any():
                    new_color[mask] = self.__sampleDelta(mask.int().sum().item(), C, k_i)
                z[batchidx, :, B_set // W, B_set % W] = new_color

            # keep the candidate if it improves the (scalar) loss, and swap
            # the corresponding pixel indices between M and U_M
            loss = self.__lossfn(z.view(*x.shape))
            if loss < best_loss:
                best_loss = loss
                result = z.clone()
                M[:, A_idx] = B_set
                U_M[:, B_idx] = A_set

        return result.view(*x.shape)


    def __sampleDelta(self, B, C, k):
        '''
        Sample k-pixel perturbations for B images. Each pixel is assigned a
        random corner in the C-dimensional cube defined by self.img_range.
        BUGFIX: torch.randint's upper bound is exclusive, so the original
        (0, 1) always produced 0; (0, 2) actually samples random corners.
        '''
        fac = self.img_range[1] - self.img_range[0]
        return self.img_range[0] + fac * torch.randint(0, 2, [B, k, C],
                                                       dtype=torch.float,
                                                       device=self.device)


    def __alphaSchedule(self, iteration):
        '''
        Update the fraction of the k pixels to resample based on the current
        iteration, normalized to a 10000-query budget.
        '''
        import bisect  # local import: bisect is not imported at module level
        iteration = int(iteration / self.n_queries * 10000)
        factors = [1, 2, 4, 5, 6, 8, 10, 12, 15, 20]
        alpha_schedule = [10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000]
        idx = bisect.bisect_left(alpha_schedule, iteration)
        return self.alpha_init / factors[idx]


    def __lossfn(self, x):
        '''
        Compute the loss depending on self.targeted.
        '''
        return self.model(x).sum() if self.targeted else -self.model(x).sum()
vlm_eval/attacks/strattack.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/wagnermoritz/GSE
2
+
3
+ from vlm_eval.attacks.attack import Attack
4
+ import torch
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
class StrAttack(Attack):
    def __init__(self, model, *args, targeted=False, img_range=(0, 1), kappa=0,
                 max_iter=100, ver=False, search_steps=2, max_c=1e10, rho=1, mask_out='none',
                 c=2.5, retrain=False, **kwargs):
        '''
        Implementation of StrAttack: https://arxiv.org/abs/1808.01664
        Adapted from https://github.com/KaidiXu/StrAttack

        args:
            model: Callable, PyTorch classifier.
            targeted: Bool, given label is used as a target label if True.
            img_range: Tuple of ints/floats, lower and upper bound of image
                entries.
            kappa: accepted for interface compatibility; not stored or used
                in this implementation.
            max_iter: Int, number of iterations.
            ver: Bool, print progress if True.
            search_steps: Int, number of binary search steps.
            max_c: Float, upper bound for regularizaion parameter.
            rho: Float, ADMM parameter.
            mask_out: 'none', 'context', 'query', or an int index; selects
                which images along dim 1 are zero-masked (never perturbed).
            c: Float, initial regularization parameter.
            retrain: Bool, run the extra retrain/refinement stage afterwards.
        '''
        super().__init__(model, targeted=targeted, img_range=img_range)
        self.max_iter = max_iter
        self.ver = ver
        self.search_steps = search_steps
        self.max_c = max_c
        self.rho = rho
        self.c = c
        self.retrain = retrain
        # 'none' sentinel maps to None so _set_mask leaves everything active.
        if mask_out != 'none':
            self.mask_out = mask_out
        else:
            self.mask_out = None

    def _set_mask(self, data):
        # Build a 0/1 mask with the same shape as `data` selecting which
        # images (along dim 1) may be perturbed; masked-out entries are 0.
        mask = torch.ones_like(data)
        if self.mask_out == 'context':
            mask[:, :-1, ...] = 0
        elif self.mask_out == 'query':
            mask[:, -1, ...] = 0
        elif isinstance(self.mask_out, int):
            mask[:, self.mask_out, ...] = 0
        elif self.mask_out is None:
            pass
        else:
            raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
        return mask

    def __call__(self, imgs, *args, **kwargs):
        '''
        Perform StrAttack on a batch of images.

        args:
            imgs: Tensor of images; indexed here as [1, B, 1, C, H, W]
                (batch dimension is dim 1) — TODO confirm against caller.

        Returns a tensor of the same shape as imgs containing the (masked)
        best adversarial perturbation found.
        '''

        # Freeze the wrapped model; only the perturbation is optimized.
        for param in self.model.model.parameters():
            param.requires_grad = False

        c_ = self.c
        imgs = imgs.to(self.device)
        sh = imgs.shape
        batch_size = sh[1]
        mask_out = self._set_mask(imgs)

        # ADMM hyperparameters from the reference implementation.
        alpha, tau, gamma = 5, 2, 1
        eps = torch.full_like(imgs, 1.0) * mask_out
        # 16 for imagenet, 2 for CIFAR and MNIST
        filterSize = 8 if sh[-1] > 32 else 2
        stride = filterSize
        # convolution kernel used to compute norm of each group
        slidingM = torch.ones((1, sh[3], filterSize, filterSize), device=self.device)

        # Per-sample regularization weights and binary-search bounds.
        cs = torch.ones(batch_size, device=self.device) * c_
        lower_bound = torch.zeros(batch_size)
        upper_bound = torch.ones(batch_size) * self.max_c

        # Best-so-far trackers (l2, score, adversarial image, perturbation).
        o_bestl2 = torch.full_like(torch.randn(batch_size), 1e10, dtype=torch.float)
        o_bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)
        o_bestattack = imgs.clone()
        o_besty = torch.ones_like(imgs)

        for step in range(self.search_steps):

            bestl2 = torch.full_like(o_bestl2, 1e10, dtype=torch.float)
            bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)

            # ADMM variables: z is the perturbation, u/v/s the duals.
            z, v, u, s = (torch.zeros_like(imgs) for _ in range(4))

            for iter_ in range(self.max_iter):
                if (not iter_%10 or iter_ == self.max_iter - 1) and self.ver:
                    print(f'\rIteration: {iter_+1}/{self.max_iter}, ' +
                          f'Search Step: {step+1}/{self.search_steps}', end='')

                # first update step (7) / Proposition 1
                delta = self.rho / (self.rho + 2 * gamma) * (z - u / self.rho)

                # Box projection of (z - s/rho) onto the valid image range.
                b = (z - s / self.rho) * mask_out
                tmp = torch.minimum(self.img_range[1] - imgs, eps)
                w = torch.where(b.view(*sh) > tmp.view(*sh), tmp, b) # NOTE(review): known shape issue — produces 1x5x5x3x224x224 instead of 1x5x1x3x224x224; confirm upstream.
                tmp = torch.maximum(self.img_range[0] - imgs, -eps)
                w = torch.where(b.view(*sh) < tmp.view(*sh), tmp, w)

                # Group-sparsity proximal step: per-group l2 norms via conv.
                c = z - v / self.rho
                cNorm = torch.sqrt(F.conv2d(c.view(sh[1], sh[3], sh[4], sh[5]) ** 2, slidingM, stride=stride))
                cNorm = torch.where(cNorm == 0, torch.full_like(cNorm, 1e-12), cNorm)
                cNorm = F.interpolate(cNorm, scale_factor=filterSize)
                y = torch.clamp((1 - tau / (self.rho * cNorm.unsqueeze(0).unsqueeze(3))), 0) * c

                # second update step (8) / equation (15)
                z_grads = self.__get_z_grad(imgs, z.clone(), cs)
                eta = alpha * math.sqrt(iter_ + 1)
                coeff = (1 / (eta + 3 * self.rho))
                z = coeff * (eta * z + self.rho * (delta + w + y) + u + s + v - z_grads)

                # third update step (9)
                u = u + self.rho * (delta - z) * mask_out
                v = v + self.rho * (y - z) * mask_out
                s = s + self.rho * (w - z) * mask_out
                # get info for binary search
                x = imgs + y * mask_out
                l2s = torch.sum((z ** 2).reshape(z.size(1), -1), dim=-1)

                # Keep the lowest-l2 perturbation seen so far, per sample.
                for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
                    if l2 < bestl2[i]:
                        bestl2[i] = l2
                    if l2 < o_bestl2[i]:
                        o_bestl2[i] = l2
                        o_bestattack[:,i] = x_.detach().unsqueeze(0).clone()
                        o_besty[:,i] = y[:,i]
            # Binary-search update of the per-sample regularization weight.
            for i in range(batch_size):

                lower_bound[i] = max(lower_bound[i], cs[i])
                if upper_bound[i] < 1e9:
                    cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                else:
                    cs[i] *= 5

        # NOTE(review): z_grads/w/tmp are only bound if max_iter >= 1.
        del v, u, s, z_grads, w, tmp

        if self.retrain:
            # Refinement stage: re-optimize only the coordinates of the best
            # perturbation whose magnitude exceeds the 3% quantile e0.
            cs = torch.full_like(o_bestl2, 5.0, dtype=torch.float)
            zeros = torch.zeros_like(imgs)

            for step in range(8):
                bestl2 = torch.full_like(cs, 1e10, dtype=torch.float, device=self.device)
                bestscore = torch.full_like(cs, -1, dtype=torch.float, device=self.device)

                Nz = o_besty[o_besty != 0]
                e0 = torch.quantile(Nz.abs(), 0.03)
                # A2 masks out near-zero coordinates of the best perturbation.
                A2 = torch.where(o_besty.abs() <= e0, 0, 1)
                z1 = o_besty
                u1 = torch.zeros_like(imgs)
                tmpc = self.rho / (self.rho + gamma / 100)

                for j in range(100):
                    if self.ver and not j % 10:
                        print(f'\rRetrain iteration: {step+1}/8, ' +
                              f'Search Step: {j+1}/200', end='')

                    # Project the candidate update back into the image box,
                    # but only where the best perturbation is significant.
                    tmpA = (z1 - u1) * tmpc
                    tmpA1 = torch.where(o_besty.abs() <= e0, zeros, tmpA)
                    cond = torch.logical_and(tmpA >
                                             torch.minimum(self.img_range[1] - imgs, eps),
                                             o_besty.abs() > e0)
                    tmpA2 = torch.where(cond, torch.minimum(self.img_range[1] - imgs, eps),
                                        tmpA1)
                    cond = torch.logical_and(tmpA <
                                             torch.maximum(self.img_range[0] - imgs, -eps),
                                             o_besty.abs() > e0)
                    deltA = torch.where(cond, torch.maximum(self.img_range[0] - imgs, -eps),
                                        tmpA2)

                    x = imgs + deltA * mask_out
                    grad = self.__get_z_grad(imgs, deltA, cs)

                    stepsize = 1 / (alpha + 2 * self.rho)
                    z1 = stepsize * (alpha * z1 * self.rho
                                     * (deltA + u1) - grad * A2)
                    u1 = u1 + deltA - z1

                # Track bests using the last x/l2s of the inner loop.
                for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
                    if l2 < bestl2[i]:
                        bestl2[i] = l2
                        #bestscore[i] = asc
                    if l2 < o_bestl2[i]:
                        o_bestl2[i] = l2
                        #o_bestscore[i] = asc
                        o_bestattack[:,i] = x_.detach().unsqueeze(0).clone()
                        o_besty[i] = deltA[i]


                # Binary search on cs, as in the main loop.
                for i in range(batch_size):
                    if (bestscore[i] != -1 and bestl2[i] == o_bestl2[i]):
                        upper_bound[i] = min(upper_bound[i], cs[i])
                        if upper_bound[i] < 1e9:
                            cs[i] = (lower_bound[i] + upper_bound[i]) / 2

                    else:
                        lower_bound[i] = max(lower_bound[i], cs[i])
                        if upper_bound[i] < 1e9:
                            cs[i] = (lower_bound[i] + upper_bound[i]) / 2
                        else:
                            cs[i] *= 5

        if self.ver:
            print('')

        return (o_bestattack * mask_out).detach()


    def __get_z_grad(self, imgs, z, cs):
        '''
        Compute and return gradient of loss wrt. z.

        The loss is the (sign-adjusted) summed model score on imgs + z,
        weighted by the per-sample regularization coefficients cs.
        '''
        z.requires_grad = True
        tmp = self.model(z + imgs).sum() if self.targeted else -self.model(z + imgs).sum()
        loss = torch.mean(cs.to(self.device) * tmp)
        z_grad_data = torch.autograd.grad(loss, [z])[0].detach()
        z.detach_()
        return z_grad_data
vlm_eval/attacks/utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import collections.abc as container_abcs
4
+
5
+ # Code taken from https://github.com/chs20/RobustVLM/tree/main
6
+ # some parts of this code are adapted from
7
+ # https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py
8
+
9
def project_perturbation(perturbation, eps, norm):
    """Project *perturbation* onto the eps-ball of the requested norm.

    l-infinity ('inf'/'linf'/'Linf'): elementwise clamp to [-eps, eps].
    l2 (2/'l2'/'L2'/'2'): per-sample renorm along dim 0 to max norm eps.
    Anything else raises NotImplementedError.
    """
    if norm in ['inf', 'linf', 'Linf']:
        return torch.clamp(perturbation, -eps, eps)
    if norm in [2, 2.0, 'l2', 'L2', '2']:
        return torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)
    raise NotImplementedError(f'Norm {norm} not supported')
18
+
19
+
20
def normalize_grad(grad, p):
    """Normalize a batch of gradients for a p-norm steepest-ascent step.

    args:
        grad: Tensor of shape [B, ...], batch of gradients.
        p: Norm identifier ('inf'/'linf'/'Linf', or 2/2.0/'l2'/'L2'/'2').

    Returns the elementwise sign for l-infinity, or the per-sample unit
    vector (flattened over non-batch dims) for l2.

    Raises:
        NotImplementedError: for any other norm identifier. (Previously the
        function fell through and silently returned None, surfacing later
        as a confusing TypeError at the call site; this matches the
        explicit raise in project_perturbation above.)
    """
    if p in ['inf', 'linf', 'Linf']:
        return grad.sign()
    elif p in [2, 2.0, 'l2', 'L2', '2']:
        bs = grad.shape[0]
        grad_flat = grad.view(bs, -1)
        grad_normalized = F.normalize(grad_flat, p=2, dim=1)
        return grad_normalized.view_as(grad)
    else:
        raise NotImplementedError(f'Norm {p} not supported')
28
+
29
+
30
def L1_norm(x, keepdim=False):
    """Per-sample l1 norm of x, flattened over all non-batch dimensions.

    With keepdim=True the result is reshaped to broadcast against x.
    """
    norms = x.abs().reshape(x.shape[0], -1).sum(-1)
    if keepdim:
        norms = norms.reshape(-1, *([1] * (len(x.shape) - 1)))
    return norms
35
+
36
def L2_norm(x, keepdim=False):
    """Per-sample l2 norm of x, flattened over all non-batch dimensions.

    With keepdim=True the result is reshaped to broadcast against x.
    """
    squares = (x * x).reshape(x.shape[0], -1)
    norms = squares.sum(-1).sqrt()
    if keepdim:
        norms = norms.reshape(-1, *([1] * (len(x.shape) - 1)))
    return norms
41
+
42
def L0_norm(x):
    """Per-sample count of nonzero entries (l0 'norm')."""
    return x.ne(0.).reshape(x.shape[0], -1).sum(-1)
44
+
45
def zero_gradients(x):
    """Recursively zero out .grad on a tensor or any iterable of tensors.

    Tensors without an existing .grad are left untouched; non-tensor,
    non-iterable inputs are ignored.
    """
    if isinstance(x, torch.Tensor):
        grad = x.grad
        if grad is not None:
            grad.detach_()
            grad.zero_()
        return
    if isinstance(x, container_abcs.Iterable):
        for element in x:
            zero_gradients(element)
vlm_eval/clip_classification.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code adapted from https://github.com/openai/CLIP/blob/main/
2
+ from transformers import CLIPProcessor, CLIPModel
3
+ import argparse
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ from datasets_classes_templates import data_seeds
8
+ import numpy as np
9
+ from datetime import datetime
10
+
11
def zeroshot_classifier(classnames, templates, processor, model):
    """Build CLIP zero-shot classification weights via prompt ensembling.

    For each class name, every template is filled in, encoded with the CLIP
    text encoder, L2-normalized, averaged, and re-normalized into a single
    class embedding. Returns a tensor of shape [embed_dim, num_classes]
    moved to the default CUDA device.
    """
    with torch.no_grad():
        class_weights = []
        for name in tqdm(classnames):
            prompts = [template.format(name) for template in templates] #format with class
            tokens = processor(text=prompts, return_tensors="pt", padding=True, truncation=True).to('cuda')
            embeddings = model.get_text_features(tokens['input_ids']) #embed with text encoder
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            mean_embedding = embeddings.mean(dim=0)
            mean_embedding /= mean_embedding.norm()
            class_weights.append(mean_embedding)
        stacked = torch.stack(class_weights, dim=1).cuda()
    return stacked
24
+
25
def classification_collate_fn(batch):
    """Collate (image, label) pairs: images stay a tuple of PIL objects,
    labels become a 1-D tensor."""
    images = tuple(sample[0] for sample in batch)
    labels = torch.tensor([sample[1] for sample in batch])
    return images, labels
29
+
30
def accuracy(output, target, topk=(1,)):
    """Count top-k correct predictions for each k in *topk*.

    args:
        output: Tensor [B, num_classes] of scores/logits.
        target: Tensor [B] of ground-truth class indices.
        topk: Tuple of k values.

    Returns a list holding, for each k, the number of samples whose true
    label is among the k highest-scoring classes (a float count, not a %).
    """
    k_max = max(topk)
    _, top_indices = output.topk(k_max, 1, True, True)
    predictions = top_indices.t()
    hits = predictions.eq(target.view(1, -1).expand_as(predictions))
    return [float(hits[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
34
+
35
+
36
def main():
    """Zero-shot image classification benchmark for (fine-tuned) CLIP models.

    Selects an evaluation dataset from the CLI arguments, optionally loads a
    fine-tuned CLIP checkpoint per data seed, builds prompt-ensemble
    zero-shot weights, and writes top-1/top-5 accuracies (plus mean/std over
    seeds) to a timestamped results file.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default=None, choices=['non_fine_tuned','MS_COCO','medium','base','all'], help='Data on which clip was fine-tuned')
    parser.add_argument("--dataset", type=str, default="CIFAR10", choices=["CIFAR10", "CIFAR100", "ImageNet", "Caltech101", "Caltech256", "Food101"])
    parser.add_argument("--method",type=str, default="COCO_CF", choices=['COCO_CF','APGD_1','APGD_4','NONE'])
    args = parser.parse_args()

    # NOTE(review): assumes ./Results/fine_tuned_clip/ already exists —
    # open(..., 'w') does not create parent directories; confirm.
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    results_filename = f'./Results/fine_tuned_clip/zeroshot_image_classification_results_{args.dataset}_{args.data}_{args.method}_{current_time}.txt'
    with open(results_filename, 'w') as f:
        f.write(f'Arguments: {args}\n\n')

    # MS_COCO fine-tuning data has no attack variant, so method must be NONE.
    if args.data == 'MS_COCO':
        assert args.method == 'NONE' and args.data == 'MS_COCO', 'Use NONE for method for MS_COCO data'

    imagenet_path = '/software/ais2t/pytorch_datasets/imagenet/' # Fill the path for imagenet here

    # Dataset selection: each branch imports the matching classes/templates
    # and builds the evaluation split. Caltech datasets have no official
    # test split, so a seeded 80/20 split is taken and only the 20% kept.
    if args.dataset == "CIFAR10":
        from datasets_classes_templates import CIFAR10_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import CIFAR10
        data = CIFAR10(root='./image_classification_datasets/cifar10/', train=False, download=True)
    elif args.dataset == "CIFAR100":
        from datasets_classes_templates import CIFAR100_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import CIFAR100
        data = CIFAR100(root='./image_classification_datasets/cifar100/', train=False, download=True)
    elif args.dataset == "ImageNet":
        from datasets_classes_templates import ImageNet_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import ImageNet
        data = ImageNet(root=imagenet_path, split='val')
    elif args.dataset == "Caltech101":
        torch.manual_seed(42)
        from datasets_classes_templates import Caltech101_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Caltech101
        data = Caltech101(root='./image_classification_datasets/', download=False)
        train_size = int(0.8 * len(data)) # 80% for training
        val_size = len(data) - train_size
        _, data = torch.utils.data.random_split(data, [train_size, val_size])
    elif args.dataset == "Caltech256":
        torch.manual_seed(42)
        from datasets_classes_templates import Caltech256_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Caltech256
        data = Caltech256(root='./image_classification_datasets/', download=False)
        train_size = int(0.8 * len(data)) # 80% for training
        val_size = len(data) - train_size
        _, data = torch.utils.data.random_split(data, [train_size, val_size])
    elif args.dataset == "Food101":
        from datasets_classes_templates import Food101_CLASSES_TEMPLATES as classes_templates
        from torchvision.datasets import Food101
        data = Food101(root='./image_classification_datasets/food101/', download=True, split='test')
    else:
        raise NotImplementedError

    print(f'Conducting zero-shot image classification on {args.dataset}')

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model_base_path = './fine_tuned_clip_models'
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    top1_list = []
    for data_seed in data_seeds:
        print(f'Conducting zero-shot image classification on {args.data} with seed {data_seed} for the method {args.method}')
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        # Checkpoint selection: per-seed files for base/medium, a single
        # file for 'all' and for the MS_COCO/NONE combination; the plain
        # pre-trained model is used for 'non_fine_tuned'.
        if args.data != 'non_fine_tuned':
            if args.method != 'NONE':
                if args.data not in ['all']:
                    model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20_data_seed_{data_seed}.pt'))
                else:
                    model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt'))
            elif args.method == 'NONE' and args.data == 'MS_COCO':
                model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt'))

        model.eval()

        data_loader = DataLoader(data, batch_size=128, collate_fn=classification_collate_fn, shuffle=False)

        zeroshot_weights = zeroshot_classifier(classes_templates['classes'],
                                               classes_templates['templates'],
                                               processor,
                                               model
                                               )

        with torch.no_grad():
            top1, top5, n = 0., 0., 0.
            for i, (images, target) in enumerate(tqdm(data_loader)):
                target = target.to(device)
                images = list(images)

                images = processor(images=images, return_tensors="pt").to(device)

                # predict
                image_features = model.get_image_features(images['pixel_values']).to(device)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                logits = 100. * image_features @ zeroshot_weights

                # measure accuracy
                acc1, acc5 = accuracy(logits, target, topk=(1, 5))
                top1 += acc1
                top5 += acc5
                n += image_features.size(0)

            # Convert raw correct counts into percentages.
            top1 = (top1 / n) * 100
            top5 = (top5 / n) * 100

        with open(results_filename, 'a') as f:
            f.write(f'Seed {data_seed}: Top-1 Accuracy: {top1:.2f}, Top-5 Accuracy: {top5:.2f}\n')

        top1_list.append(top1)

        print(f"Top-1 accuracy: {top1:.2f}")
        print(f"Top-5 accuracy: {top5:.2f}")
        print('-'*40)

        # Configurations without per-seed checkpoints only need one run.
        if args.method == 'NONE' or args.data in ['MS_COCO','all'] or args.data == 'non_fine_tuned':
            break
    top1 = np.asarray(top1_list)
    print(f'Mean of the top 1 accuracy is {np.mean(top1)}')
    print(f'Standard deviation of the top 1 accuracy is {np.std(top1)}')

    with open(results_filename, 'a') as f:
        f.write(f'\nMean Top-1 Accuracy: {np.mean(top1):.2f}\n')
        f.write(f'Standard Deviation of Top-1 Accuracy: {np.std(top1):.2f}\n')

if __name__ == "__main__":
    main()
vlm_eval/clip_train.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code adapted from https://github.com/ylaxor/clip-like/blob/main/fine-tune-clip.ipynb
2
+
3
+ from random import seed, shuffle
4
+ from typing import Callable
5
+ import torch
6
+ from tqdm import tqdm
7
+ from transformers import CLIPProcessor, CLIPModel
8
+ from timm.scheduler import CosineLRScheduler
9
+
10
+
11
+
12
class ModelTrainer:
    """Fine-tunes a CLIP-style model on image/caption batches and evaluates
    the contrastive loss on a held-out loader.

    The model and processor are HuggingFace CLIP objects; batches are dicts
    with 'caption' and 'image' lists (see custom_collate_fn).
    """

    def __init__(self,
                 model: Callable,
                 processor: Callable,
                 data_name: str,
                 train_data_loader: torch.utils.data.DataLoader,
                 val_data_loader: torch.utils.data.DataLoader,
                 num_epochs: int,
                 learning_rate: float = 5e-5,
                 weight_decay: float = 1e-3,
                 device: str = "cuda:0",
                 save_model: bool = False,
                 save_model_path: str = "./fine_tuned_clip_models",
                 data_seed: int = 42,
                 method="COCO_CF",
                 ) -> None:

        self.model = model
        self.processor = processor
        self.data_name = data_name
        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device
        self.save_model = save_model
        self.save_model_path = save_model_path
        self.data_seed = data_seed
        self.method = method

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )


    def train(self):
        """Run the full fine-tuning loop and optionally save the weights.

        Uses a cosine LR schedule with warmup, stepped per batch; the loss
        is CLIP's built-in contrastive loss (return_loss=True).
        """
        self.model.train()
        lr_scheduler = CosineLRScheduler(
            self.optimizer,
            t_initial=self.num_epochs * len(self.train_data_loader),
            lr_min=2e-7,
            warmup_lr_init=1e-7,
            warmup_prefix=True,
            warmup_t=3,
            cycle_limit=1,
            t_in_epochs=False,
        )
        progress_bar = tqdm(range(self.num_epochs))
        for epoch in progress_bar:
            running_loss = 0.0
            for batch_idx, batch in enumerate(self.train_data_loader):
                self.optimizer.zero_grad()
                processed_input = self.processor(text=batch["caption"],
                                                 images=batch["image"],
                                                 return_tensors="pt",
                                                 padding=True,
                                                 max_length=128,
                                                 truncation=True
                                                 )
                outputs = self.model(input_ids=processed_input['input_ids'].squeeze().to(self.device),
                                     pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
                                     attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
                                     return_loss=True
                                     )
                loss = outputs.loss
                loss.backward()
                # Weight the running loss by batch size for a proper
                # dataset-level average below.
                running_loss += loss.item() * len(batch["caption"])
                self.optimizer.step()
                # Per-batch LR update (t_in_epochs=False above).
                lr_scheduler.step_update(batch_idx + epoch * len(self.train_data_loader))

            print(f"Epoch {epoch+1}/{self.num_epochs} Loss: {running_loss/len(self.train_data_loader.dataset):.4f}")
            progress_bar.set_postfix(
                epoch="{}/{}".format(epoch+1,self.num_epochs),
                loss=running_loss/len(self.train_data_loader.dataset),
                lr=self.optimizer.param_groups[0]["lr"]
            )

        if self.save_model:
            # NOTE(review): save_model_path is concatenated without a '/',
            # while clip_classification.py loads from
            # '<base>/<method>/clip_model...' — confirm the intended layout.
            if self.data_name not in ['MS_COCO','all']:
                torch.save(self.model.state_dict(), self.save_model_path + f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}_data_seed_{self.data_seed}.pt')
                print(f"Saving fine-tuned model as clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}_data_seed_{self.data_seed}.pt")
            else:
                torch.save(self.model.state_dict(), self.save_model_path + f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}.pt')
                print(f"Saving fine-tuned model as clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}.pt")

    def eval(self):
        """Report the mean contrastive loss over the validation loader.

        NOTE(review): runs without torch.no_grad(), so activations are kept
        for autograd — consider wrapping if memory becomes an issue.
        """
        self.model.eval()
        nb_batches = len(self.val_data_loader)
        tqdm_object = tqdm(self.val_data_loader, total=len(self.val_data_loader))
        epoch_loss = 0.0
        for i, batch in enumerate(tqdm_object):
            processed_input = self.processor(text=batch["caption"],
                                             images=batch["image"],
                                             return_tensors="pt",
                                             padding=True,
                                             max_length=128,
                                             truncation=True
                                             )
            outputs = self.model(
                input_ids=processed_input['input_ids'].squeeze().to(self.device),
                attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
                pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
                return_loss=True)
            # logits_per_image is unpacked but unused beyond this scope.
            loss, logits_per_image = outputs.loss, outputs.logits_per_image
            epoch_loss += loss.item()
            tqdm_object.set_postfix(
                batch="{}/{}".format(i+1,nb_batches),
                dev_loss=loss.item(),
            )
        epoch_loss = epoch_loss / nb_batches
        print(f"Eval loss: {epoch_loss}")
127
+
128
def main():
    """CLI entry point: fine-tune CLIP on caption datasets, per data seed.

    For each seed, builds the requested dataset variant (per-seed files for
    'base'/'medium', single files for 'MS_COCO'/'all'), splits it 80/20 into
    train/val, fine-tunes and evaluates. 'MS_COCO' and 'all' run once only.
    """
    import os
    #os.environ['HF_HOME'] = '' Add path for saved hugging face models

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--data_name', type=str, default="MS_COCO", choices=["MS_COCO","base","medium","all"])
    parser.add_argument('--learning_rate', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--save_model', action='store_true', default=False)
    parser.add_argument('--method', type=str, choices=['COCO_CF','APGD_1','APGD_4','NONE'])
    parser.add_argument('--save_model_path', type=str, default="./fine_tuned_clip_models")
    parser.add_argument(
        "--data_seeds",
        nargs="+",
        type=int,
        default=[107],
        help="Seeds to use for each trial for picking demonstrations and eval sets",
    )
    args = parser.parse_args()
    # Plain MS-COCO has no attack-augmented variant, so method must be NONE.
    if args.data_name == 'MS_COCO':
        assert args.data_name == 'MS_COCO' and args.method == 'NONE', "Only NONE method is allowed with MS_COCO dataset"

    from torch.utils.data import DataLoader
    from coco_cf_loader import MS_COCO_dataset, custom_collate_fn

    # Fixed seed so the 80/20 random_split below is reproducible.
    torch.manual_seed(42)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


    for data_seed in args.data_seeds:

        # Annotation file layout depends on the variant: per-seed JSON for
        # base/medium, a single JSON for 'all', and the raw captions file
        # for plain MS_COCO.
        if args.data_name not in ['MS_COCO', 'all']:
            print(f"Data Seed: {data_seed} | Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
                                      annotation_file=f'/json_files/data_name_{args.data_name}_data_seed_{data_seed}.json')
        elif args.data_name == 'all':
            print(f"Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
                                      annotation_file=f'/json_files/data_name_{args.data_name}.json')
        else:
            print(f"Data Name: {args.data_name} | Method: {args.method}")
            dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO',
                                      annotation_file=f'/ms_coco_captions.json')

        train_size = int(0.8 * len(dataset)) # 80% for training
        val_size = len(dataset) - train_size # 20% for validation

        # Randomly split into training and validation datasets
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Optional: Create DataLoaders for each subset
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=custom_collate_fn,drop_last=True)

        trainer = ModelTrainer(model=model,
                               processor=processor,
                               data_name=args.data_name,
                               train_data_loader=train_loader,
                               val_data_loader=val_loader,
                               num_epochs=args.num_epochs,
                               learning_rate=args.learning_rate,
                               weight_decay=1e-3,
                               device=device,
                               data_seed=data_seed,
                               save_model=args.save_model,
                               save_model_path=args.save_model_path,
                               method=args.method,
                               )

        trainer.train()
        trainer.eval()
        # Single-file variants do not vary by seed; one run suffices.
        if args.data_name in ['MS_COCO','all']:
            break


if __name__ == "__main__":
    main()
vlm_eval/coco_cf_loader.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from torch.utils.data import DataLoader
3
+ import os
4
+ import json
5
+ from PIL import Image
6
+
7
+
8
class MS_COCO_dataset(Dataset):
    """MS-COCO image/caption dataset backed by a JSON-lines annotation file.

    Each annotation line must provide 'image_name', 'caption' and
    'image_id'; images are loaded lazily from <base_dir>/images.
    """

    def __init__(self, base_dir, annotation_file=None):
        self.data = []
        self.img_dir = base_dir + '/images'
        # Note: annotation_file is expected to start with '/' and is simply
        # appended to base_dir (so None raises a TypeError here).
        self.annotation_file = base_dir + annotation_file

        with open(self.annotation_file, 'r') as file:
            for line in file:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Pull the annotation record and open its image lazily.
        record = self.data[idx]
        img_path = os.path.join(self.img_dir, f"{record['image_name']}")
        picture = Image.open(img_path)
        return {"id": record['image_id'],
                "image": picture,
                "caption": record['caption']
                }
36
+
37
class COCO_CF_dataset(Dataset):
    """COCO-Counterfactuals dataset: paired images with paired captions.

    Reads <base_dir>/examples.jsonl; each record names two images
    ('image_0'/'image_1', stored as .jpg under <base_dir>/images) and two
    captions ('caption_0'/'caption_1').
    """

    def __init__(self, base_dir):
        self.data = []
        self.img_dir = base_dir + '/images'
        self.annotation_file = base_dir + "/examples.jsonl"

        with open(self.annotation_file, 'r') as file:
            for line in file:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # One record describes a counterfactual pair; load both images.
        record = self.data[idx]
        path_0 = os.path.join(self.img_dir, f"{record['image_0']}.jpg")
        path_1 = os.path.join(self.img_dir, f"{record['image_1']}.jpg")

        picture_0 = Image.open(path_0)
        picture_1 = Image.open(path_1)

        return {"id": record['id'],
                "caption_0": record['caption_0'],
                "caption_1": record['caption_1'],
                "image_0": picture_0,
                "image_1": picture_1}
70
+
71
def custom_collate_fn(batch):
    """Collate a list of dict samples into one dict of lists (no stacking)."""
    return {key: [sample[key] for sample in batch] for key in batch[0].keys()}
76
+
77
if __name__ == "__main__":
    # Smoke test: load one batch of MS-COCO samples and print it.
    base_dir = '/home/htc/kchitranshi/SCRATCH/MS_COCO/'
    # Bug fix: MS_COCO_dataset's annotation_file defaults to None and
    # __init__ computes base_dir + annotation_file, so omitting the argument
    # raised a TypeError. Pass the captions file explicitly (same filename
    # clip_train.py uses for the plain MS_COCO variant).
    data = MS_COCO_dataset(base_dir=base_dir,
                           annotation_file='/ms_coco_captions.json')
    data_loader = DataLoader(data, batch_size=10, collate_fn=custom_collate_fn)

    for batch in data_loader:
        print(batch)
        break
vlm_eval/create_clip_dataset.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+
6
+
7
+
8
def main():
    """Compose CLIP fine-tuning datasets by mixing MS-COCO caption records
    with counterfactual/adversarial sample pairs.

    For each (data_name, data_seed) combination a shuffled JSON-lines file
    is written under ./clip_train_datasets/MS_COCO_APGD_4/json_files/.
    """

    # Intialising seeds for data (107..121 inclusive)
    data_seeds = [i for i in range(107,122)]

    ms_coco_base_anno_path = "./clip_train_datasets/MS_COCO/ms_coco_captions.json"
    attack_base_anno_path = "./clip_train_datasets/COCO_CF/examples.jsonl"

    data_names = ["base","medium","all"]

    ms_coco_array = []
    attack_array = []

    # Both annotation files are JSON-lines: one record per line.
    with open(ms_coco_base_anno_path, 'r') as file:
        for line in file:
            ms_coco_array.append(json.loads(line))


    with open(attack_base_anno_path, 'r') as file:
        for line in file:
            attack_array.append(json.loads(line))

    for data_name in data_names:
        for data_seed in data_seeds:
            # Mixture sizes per variant; each attack record expands into
            # two samples (original + counterfactual/adversarial) below.
            if data_name == "base":
                num_ms_coco_samples = 8705
                num_attacks_samples = 4353 # These many pairs of samples with their counterfactuals or adv attack samples. Effectively 8706 in total.
            elif data_name == "medium":
                num_ms_coco_samples = 17410
                num_attacks_samples = int(0.75 * 17410) # These many pairs of samples with their counterfactuals or adv attack samples. Effectively 26115 in total.
            elif data_name == "all":
                num_ms_coco_samples = 17410
                num_attacks_samples = 17410 # These many pairs of samples with their counterfactuals or adv attack samples. Effectively 34820 in total.

            # Seeded sampling without replacement from both pools.
            np.random.seed(data_seed)
            ms_coco_rand_indices = np.random.choice(len(ms_coco_array), num_ms_coco_samples, replace=False)
            attack_rand_indices = np.random.choice(len(attack_array), num_attacks_samples, replace=False)

            ms_coco_samples = [ms_coco_array[int(i)] for i in ms_coco_rand_indices]
            attack_samples = [attack_array[int(i)] for i in attack_rand_indices]
            # Flatten each pair record into two caption records (i in {0,1}).
            attack_samples = [{"image_id": batch["id"], "image_name": batch[f"image_{i}"] + ".jpg", "caption": batch[f"caption_{i}"]} for batch in attack_samples for i in range(2)]

            # Fixed shuffle seed so output ordering is reproducible.
            random.seed(42)
            combined_dataset = ms_coco_samples + attack_samples

            random.shuffle(combined_dataset)

            # 'all' does not vary by seed, so it gets a single file name.
            if data_name != 'all':
                with open(f"./clip_train_datasets/MS_COCO_APGD_4/json_files/data_name_{data_name}_data_seed_{data_seed}.json", 'w') as file:
                    for line in combined_dataset:
                        file.write(json.dumps(line) + '\n')
            else:
                with open(f"./clip_train_datasets/MS_COCO_APGD_4/json_files/data_name_{data_name}.json", 'w') as file:
                    for line in combined_dataset:
                        file.write(json.dumps(line) + '\n')

if __name__ == "__main__":
    main()
vlm_eval/datasets_classes_templates.py ADDED
@@ -0,0 +1,822 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code taken and adapted from https://github.com/openai/CLIP/blob/main/data/prompts.md
2
+
3
# Class names and prompt-ensembling templates for zero-shot CIFAR-10
# classification (adapted from the OpenAI CLIP prompt collection).
# Every template carries a single '{}' slot for the class name.
CIFAR10_CLASSES_TEMPLATES = {
    'classes': [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ],
    # 18 prompts: seven photo-quality variants plus small/big size variants,
    # emitted once with the article 'a' and once with 'the'.
    'templates': [
        prompt
        for article in ('a', 'the')
        for prompt in (
            [f'a {quality}photo of {article} {{}}.'
             for quality in ('', 'blurry ', 'black and white ',
                             'low contrast ', 'high contrast ',
                             'bad ', 'good ')]
            + [f'a photo of {article} small {{}}.',
               f'a photo of {article} big {{}}.']
        )
    ],
}
38
+
39
# Class names and prompt-ensembling templates for zero-shot CIFAR-100
# classification (adapted from the OpenAI CLIP prompt collection).
# The template set is identical to the CIFAR-10 one.
CIFAR100_CLASSES_TEMPLATES = {
    'classes': [
        'apple', 'aquarium fish', 'baby', 'bear', 'beaver', 'bed', 'bee',
        'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus',
        'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch',
        'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
        'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo',
        'keyboard', 'lamp', 'lawn mower', 'leopard', 'lion', 'lizard',
        'lobster', 'man', 'maple tree', 'motorcycle', 'mountain', 'mouse',
        'mushroom', 'oak tree', 'orange', 'orchid', 'otter', 'palm tree',
        'pear', 'pickup truck', 'pine tree', 'plain', 'plate', 'poppy',
        'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper',
        'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
        'sweet pepper', 'table', 'tank', 'telephone', 'television', 'tiger',
        'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
        'willow tree', 'wolf', 'woman', 'worm',
    ],
    # 18 prompts: seven photo-quality variants plus small/big size variants,
    # emitted once with the article 'a' and once with 'the'.
    'templates': [
        prompt
        for article in ('a', 'the')
        for prompt in (
            [f'a {quality}photo of {article} {{}}.'
             for quality in ('', 'blurry ', 'black and white ',
                             'low contrast ', 'high contrast ',
                             'bad ', 'good ')]
            + [f'a photo of {article} small {{}}.',
               f'a photo of {article} big {{}}.']
        )
    ],
}
164
+
165
+ ImageNet_CLASSES_TEMPLATES = {
166
+ 'classes' : ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", 
"black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature 
Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", 
"guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination 
lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights",
167
+ "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven",
168
+ "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped",
169
+ "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace",
170
+ "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask",
171
+ "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter",
172
+ "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier",
173
+ "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger",
174
+ "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector",
175
+ "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle",
176
+ "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick",
177
+ "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver",
178
+ "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski",
179
+ "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
180
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car",
181
+ "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing",
182
+ "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral 
fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
183
+ ,
184
+ 'templates' : [
185
+ 'a bad photo of a {}.',
186
+ 'a photo of many {}.',
187
+ 'a sculpture of a {}.',
188
+ 'a photo of the hard to see {}.',
189
+ 'a low resolution photo of the {}.',
190
+ 'a rendering of a {}.',
191
+ 'graffiti of a {}.',
192
+ 'a bad photo of the {}.',
193
+ 'a cropped photo of the {}.',
194
+ 'a tattoo of a {}.',
195
+ 'the embroidered {}.',
196
+ 'a photo of a hard to see {}.',
197
+ 'a bright photo of a {}.',
198
+ 'a photo of a clean {}.',
199
+ 'a photo of a dirty {}.',
200
+ 'a dark photo of the {}.',
201
+ 'a drawing of a {}.',
202
+ 'a photo of my {}.',
203
+ 'the plastic {}.',
204
+ 'a photo of the cool {}.',
205
+ 'a close-up photo of a {}.',
206
+ 'a black and white photo of the {}.',
207
+ 'a painting of the {}.',
208
+ 'a painting of a {}.',
209
+ 'a pixelated photo of the {}.',
210
+ 'a sculpture of the {}.',
211
+ 'a bright photo of the {}.',
212
+ 'a cropped photo of a {}.',
213
+ 'a plastic {}.',
214
+ 'a photo of the dirty {}.',
215
+ 'a jpeg corrupted photo of a {}.',
216
+ 'a blurry photo of the {}.',
217
+ 'a photo of the {}.',
218
+ 'a good photo of the {}.',
219
+ 'a rendering of the {}.',
220
+ 'a {} in a video game.',
221
+ 'a photo of one {}.',
222
+ 'a doodle of a {}.',
223
+ 'a close-up photo of the {}.',
224
+ 'a photo of a {}.',
225
+ 'the origami {}.',
226
+ 'the {} in a video game.',
227
+ 'a sketch of a {}.',
228
+ 'a doodle of the {}.',
229
+ 'a origami {}.',
230
+ 'a low resolution photo of a {}.',
231
+ 'the toy {}.',
232
+ 'a rendition of the {}.',
233
+ 'a photo of the clean {}.',
234
+ 'a photo of a large {}.',
235
+ 'a rendition of a {}.',
236
+ 'a photo of a nice {}.',
237
+ 'a photo of a weird {}.',
238
+ 'a blurry photo of a {}.',
239
+ 'a cartoon {}.',
240
+ 'art of a {}.',
241
+ 'a sketch of the {}.',
242
+ 'a embroidered {}.',
243
+ 'a pixelated photo of a {}.',
244
+ 'itap of the {}.',
245
+ 'a jpeg corrupted photo of the {}.',
246
+ 'a good photo of a {}.',
247
+ 'a plushie {}.',
248
+ 'a photo of the nice {}.',
249
+ 'a photo of the small {}.',
250
+ 'a photo of the weird {}.',
251
+ 'the cartoon {}.',
252
+ 'art of the {}.',
253
+ 'a drawing of the {}.',
254
+ 'a photo of the large {}.',
255
+ 'a black and white photo of a {}.',
256
+ 'the plushie {}.',
257
+ 'a dark photo of a {}.',
258
+ 'itap of a {}.',
259
+ 'graffiti of the {}.',
260
+ 'a toy {}.',
261
+ 'itap of my {}.',
262
+ 'a photo of a cool {}.',
263
+ 'a photo of a small {}.',
264
+ 'a tattoo of the {}.',
265
+ ]
266
+ }
267
+
268
# Class names and prompt-ensembling templates for zero-shot Caltech-101
# classification (adapted from the OpenAI CLIP prompt collection).
# Class names keep the dataset's original directory spelling (underscores,
# capitalised 'Faces'/'Leopards'/'Motorbikes').
Caltech101_CLASSES_TEMPLATES = {
    'classes': [
        'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion',
        'airplanes', 'anchor', 'ant', 'barrel', 'bass', 'beaver',
        'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha',
        'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan',
        'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face',
        'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup',
        'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly',
        'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry',
        'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone',
        'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter',
        'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp',
        'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
        'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi',
        'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid',
        'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors',
        'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler',
        'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower',
        'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly',
        'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang',
    ],
    # 34 prompts: 17 medium/style patterns, each instantiated with the
    # article 'a' and then 'the' in the {0} slot ({{}} stays a literal '{}'
    # placeholder for the class name). The ungrammatical 'a embroidered'/
    # 'a origami' forms are reproduced on purpose — they match the original
    # CLIP prompt list.
    'templates': [
        pattern.format(article)
        for article in ('a', 'the')
        for pattern in (
            'a photo of {0} {{}}.',
            'a painting of {0} {{}}.',
            '{0} plastic {{}}.',
            'a sculpture of {0} {{}}.',
            'a sketch of {0} {{}}.',
            'a tattoo of {0} {{}}.',
            '{0} toy {{}}.',
            'a rendition of {0} {{}}.',
            '{0} embroidered {{}}.',
            '{0} cartoon {{}}.',
            '{0} {{}} in a video game.',
            '{0} plushie {{}}.',
            '{0} origami {{}}.',
            'art of {0} {{}}.',
            'graffiti of {0} {{}}.',
            'a drawing of {0} {{}}.',
            'a doodle of {0} {{}}.',
        )
    ],
}
412
+
413
# Class names and prompt-ensembling templates for zero-shot Caltech-256
# classification (adapted from the OpenAI CLIP prompt collection).
# The list includes the dataset's duplicate/alias entries ('airplanes',
# 'car side', 'faces easy', ...) and the trailing 'clutter' class, as in
# the original prompt file.
Caltech256_CLASSES_TEMPLATES = {
    'classes': [
        'ak47', 'american flag', 'backpack', 'baseball bat',
        'baseball glove', 'basketball hoop', 'bat', 'bathtub', 'bear',
        'beer mug', 'billiards', 'binoculars', 'birdbath', 'blimp',
        'bonsai', 'boom box', 'bowling ball', 'bowling pin', 'boxing glove',
        'brain', 'breadmaker', 'buddha', 'bulldozer', 'butterfly', 'cactus',
        'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car tire',
        'cartman', 'cd', 'centipede', 'cereal box', 'chandelier',
        'chess board', 'chimp', 'chopsticks', 'cockroach', 'coffee mug',
        'coffin', 'coin', 'comet', 'computer keyboard', 'computer monitor',
        'computer mouse', 'conch', 'cormorant', 'covered wagon',
        'cowboy hat', 'crab', 'desk globe', 'diamond ring', 'dice', 'dog',
        'dolphin', 'doorknob', 'drinking straw', 'duck', 'dumb bell',
        'eiffel tower', 'electric guitar', 'elephant', 'elk', 'ewer',
        'eyeglasses', 'fern', 'fighter jet', 'fire extinguisher',
        'fire hydrant', 'fire truck', 'fireworks', 'flashlight',
        'floppy disk', 'football helmet', 'french horn', 'fried egg',
        'frisbee', 'frog', 'frying pan', 'galaxy', 'gas pump', 'giraffe',
        'goat', 'golden gate bridge', 'goldfish', 'golf ball', 'goose',
        'gorilla', 'grand piano', 'grapes', 'grasshopper', 'guitar pick',
        'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord',
        'hawksbill', 'head phones', 'helicopter', 'hibiscus',
        'homer simpson', 'horse', 'horseshoe crab', 'hot air balloon',
        'hot dog', 'hot tub', 'hourglass', 'house fly', 'human skeleton',
        'hummingbird', 'ibis', 'ice cream cone', 'iguana', 'ipod', 'iris',
        'jesus christ', 'joy stick', 'kangaroo', 'kayak', 'ketch',
        'killer whale', 'knife', 'ladder', 'laptop', 'lathe', 'leopards',
        'license plate', 'lightbulb', 'light house', 'lightning', 'llama',
        'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah',
        'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes',
        'mountain bike', 'mushroom', 'mussels', 'necktie', 'octopus',
        'ostrich', 'owl', 'palm pilot', 'palm tree', 'paperclip',
        'paper shredder', 'pci card', 'penguin', 'people', 'pez dispenser',
        'photocopier', 'picnic table', 'playing card', 'porcupine', 'pram',
        'praying mantis', 'pyramid', 'raccoon', 'radio telescope',
        'rainbow', 'refrigerator', 'revolver', 'rifle', 'rotary phone',
        'roulette wheel', 'saddle', 'saturn', 'school bus', 'scorpion',
        'screwdriver', 'segway', 'self propelled lawn mower', 'sextant',
        'sheet music', 'skateboard', 'skunk', 'skyscraper', 'smokestack',
        'snail', 'snake', 'sneaker', 'snowmobile', 'soccer ball', 'socks',
        'soda can', 'spaghetti', 'speed boat', 'spider', 'spoon',
        'stained glass', 'starfish', 'steering wheel', 'stirrups',
        'sunflower', 'superman', 'sushi', 'swan', 'swiss army knife',
        'sword', 'syringe', 'tambourine', 'teapot', 'teddy bear', 'teepee',
        'telephone box', 'tennis ball', 'tennis court', 'tennis racket',
        'theodolite', 'toaster', 'tomato', 'tombstone', 'top hat',
        'touring bike', 'tower pisa', 'traffic light', 'treadmill',
        'triceratops', 'tricycle', 'trilobite', 'tripod', 't shirt',
        'tuning fork', 'tweezer', 'umbrella', 'unicorn', 'vcr',
        'video projector', 'washing machine', 'watch', 'waterfall',
        'watermelon', 'welding mask', 'wheelbarrow', 'windmill',
        'wine bottle', 'xylophone', 'yarmulke', 'yo yo', 'zebra',
        'airplanes', 'car side', 'faces easy', 'greyhound', 'tennis shoes',
        'toad', 'clutter',
    ],
    # Same 34-prompt scheme as Caltech-101: 17 medium/style patterns,
    # instantiated with the article 'a' and then 'the' in the {0} slot
    # ({{}} stays a literal '{}' placeholder for the class name).
    'templates': [
        pattern.format(article)
        for article in ('a', 'the')
        for pattern in (
            'a photo of {0} {{}}.',
            'a painting of {0} {{}}.',
            '{0} plastic {{}}.',
            'a sculpture of {0} {{}}.',
            'a sketch of {0} {{}}.',
            'a tattoo of {0} {{}}.',
            '{0} toy {{}}.',
            'a rendition of {0} {{}}.',
            '{0} embroidered {{}}.',
            '{0} cartoon {{}}.',
            '{0} {{}} in a video game.',
            '{0} plushie {{}}.',
            '{0} origami {{}}.',
            'art of {0} {{}}.',
            'graffiti of {0} {{}}.',
            'a drawing of {0} {{}}.',
            'a doodle of {0} {{}}.',
        )
    ],
}
711
+
712
# Class names and the single prompt template for zero-shot Food-101
# classification (adapted from the OpenAI CLIP prompt collection).
Food101_CLASSES_TEMPLATES = {
    'classes': [
        'apple pie', 'baby back ribs', 'baklava', 'beef carpaccio',
        'beef tartare', 'beet salad', 'beignets', 'bibimbap',
        'bread pudding', 'breakfast burrito', 'bruschetta', 'caesar salad',
        'cannoli', 'caprese salad', 'carrot cake', 'ceviche',
        'cheese plate', 'cheesecake', 'chicken curry', 'chicken quesadilla',
        'chicken wings', 'chocolate cake', 'chocolate mousse', 'churros',
        'clam chowder', 'club sandwich', 'crab cakes', 'creme brulee',
        'croque madame', 'cup cakes', 'deviled eggs', 'donuts', 'dumplings',
        'edamame', 'eggs benedict', 'escargots', 'falafel', 'filet mignon',
        'fish and chips', 'foie gras', 'french fries', 'french onion soup',
        'french toast', 'fried calamari', 'fried rice', 'frozen yogurt',
        'garlic bread', 'gnocchi', 'greek salad', 'grilled cheese sandwich',
        'grilled salmon', 'guacamole', 'gyoza', 'hamburger',
        'hot and sour soup', 'hot dog', 'huevos rancheros', 'hummus',
        'ice cream', 'lasagna', 'lobster bisque', 'lobster roll sandwich',
        'macaroni and cheese', 'macarons', 'miso soup', 'mussels', 'nachos',
        'omelette', 'onion rings', 'oysters', 'pad thai', 'paella',
        'pancakes', 'panna cotta', 'peking duck', 'pho', 'pizza',
        'pork chop', 'poutine', 'prime rib', 'pulled pork sandwich',
        'ramen', 'ravioli', 'red velvet cake', 'risotto', 'samosa',
        'sashimi', 'scallops', 'seaweed salad', 'shrimp and grits',
        'spaghetti bolognese', 'spaghetti carbonara', 'spring rolls',
        'steak', 'strawberry shortcake', 'sushi', 'tacos', 'takoyaki',
        'tiramisu', 'tuna tartare', 'waffles',
    ],
    # Food-101 uses a single domain-specific prompt.
    'templates': [
        'a photo of {}, a type of food.',
    ],
}
821
+
822
# Seeds used for the data-subsampling runs: the 15 consecutive
# integers 107..121 inclusive.
data_seeds = list(range(107, 122))
vlm_eval/ms_coco_gen.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torchvision.datasets as dset
4
+ import torchvision.transforms as transforms
5
+ from coco_cf import COCO_CF_dataset
6
+ from torch.utils.data import DataLoader
7
+
8
def custom_collate_fn(batch):
    """Collate a list of sample dicts into a dict of lists.

    Unlike the default DataLoader collation, values are kept as plain
    Python lists (no tensor stacking), keyed by the fields of the first
    sample.

    Args:
        batch: list of dicts; all dicts are assumed to share the keys of
            the first one.

    Returns:
        dict mapping each key to the list of per-sample values in batch
        order. An empty batch yields an empty dict instead of raising
        IndexError on ``batch[0]``.
    """
    if not batch:  # guard: batch[0] below would raise IndexError
        return {}
    return {key: [sample[key] for sample in batch] for key in batch[0]}
13
+
14
# COCO 2017 validation split: images loaded from disk together with their
# reference captions, converted to tensors on access.
coco_2017 = dset.CocoCaptions(
    root='./open_flamingo_datasets/COCO_2017/val2017/',
    annFile='./open_flamingo_datasets/COCO_2017/captions_val2017.json',
    transform=transforms.ToTensor(),
)

# Counterfactual COCO variant, batched with the list-preserving collate so
# each field of a batch stays a plain Python list.
coco_cf = COCO_CF_dataset(base_dir='./open_flamingo_datasets/COCO_CF/')
dl_coco_cf = DataLoader(coco_cf, batch_size=100, collate_fn=custom_collate_fn)
20
+
21
+
22
# Normalized copies of each sample's first caption ('caption_0'); the same
# normalization is applied to the COCO 2017 captions below so the two sets
# can be compared by exact string match.
coco_cf_captions = []
for batch in dl_coco_cf:
    for raw in batch['caption_0']:
        normalized = (
            raw.replace('.', '')
               .replace(',', '')
               .replace('-', ' ')
               .replace("'s", '')
               .lower()
               .strip()
        )
        coco_cf_captions.append(normalized)
28
+
29
# Scan every COCO 2017 sample and keep (image_id, original_caption) pairs
# whose normalized caption also appears in the counterfactual caption set.
ms_coco_gen_indices = []
coco_cf_captions_set = set(coco_cf_captions)

for idx in range(len(coco_2017)):
    image_id = coco_2017.ids[idx]
    _, captions = coco_2017[idx]
    for caption in captions:
        # NOTE(review): the replacement order differs from the collection
        # loop above, but the four substitutions are independent of each
        # other, so the normalized result is the same.
        key = (
            caption.replace(".", "")
                   .replace(",", "")
                   .replace("'s", "")
                   .replace("-", " ")
                   .lower()
                   .strip()
        )
        if key in coco_cf_captions_set:
            # The un-normalized caption is what gets written out later.
            ms_coco_gen_indices.append((image_id, caption))
42
# Cap the number of matches so every run emits the same dataset size.
ms_coco_gen_indices = ms_coco_gen_indices[:17410]
# Debug-noise fix: report the count instead of dumping ~17k tuples to stdout.
print(f"kept {len(ms_coco_gen_indices)} (image_id, caption) matches")

# COCO val2017 files are named as the 12-digit zero-padded image id, so the
# on-disk file name can be derived from the id alone. (The original's unused
# `base_image_path`, duplicate `import os`, identity list copy, and a dead
# first `ms_coco` build — immediately overwritten below — are removed.)
ms_coco_gen_indices = [
    (image_id, f"{image_id:012d}.jpg", caption)
    for (image_id, caption) in ms_coco_gen_indices
]

ms_coco = [
    {'image_id': image_id, 'image_name': image_name, 'caption': caption}
    for (image_id, image_name, caption) in ms_coco_gen_indices
]

# Write the records as JSON Lines: one JSON object per line.
file_path = 'ms_coco_captions.json'
with open(file_path, 'w') as json_file:
    for row in ms_coco:
        json.dump(row, json_file)
        json_file.write('\n')
vlm_eval/run_evaluation.py ADDED
The diff for this file is too large to render. See raw diff