Spaces:
Configuration error
Configuration error
Upload 39 files
Browse files- .gitattributes +1 -0
- vlm_eval/__init__.py +0 -0
- vlm_eval/__pycache__/__init__.cpython-311.pyc +0 -0
- vlm_eval/__pycache__/__init__.cpython-312.pyc +0 -0
- vlm_eval/__pycache__/coco_cf_loader.cpython-311.pyc +0 -0
- vlm_eval/__pycache__/datasets_classes_templates.cpython-311.pyc +0 -0
- vlm_eval/__pycache__/run_evaluation.cpython-311.pyc +3 -0
- vlm_eval/__pycache__/run_evaluation.cpython-312.pyc +0 -0
- vlm_eval/__pycache__/run_evaluation_qualitative.cpython-311.pyc +0 -0
- vlm_eval/attacks/__init__.py +0 -0
- vlm_eval/attacks/__pycache__/__init__.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/afw.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/apgd.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/attack.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/ead.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/fwnucl.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/gse.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/iht.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/pgd0.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/saif.cpython-311.pyc +0 -0
- vlm_eval/attacks/__pycache__/strattack.cpython-311.pyc +0 -0
- vlm_eval/attacks/apgd.py +384 -0
- vlm_eval/attacks/attack.py +20 -0
- vlm_eval/attacks/ead.py +132 -0
- vlm_eval/attacks/fwnucl.py +170 -0
- vlm_eval/attacks/gse.py +313 -0
- vlm_eval/attacks/iht.py +97 -0
- vlm_eval/attacks/pgd.py +88 -0
- vlm_eval/attacks/pgd0.py +131 -0
- vlm_eval/attacks/saif.py +143 -0
- vlm_eval/attacks/sparsers.py +164 -0
- vlm_eval/attacks/strattack.py +229 -0
- vlm_eval/attacks/utils.py +52 -0
- vlm_eval/clip_classification.py +160 -0
- vlm_eval/clip_train.py +209 -0
- vlm_eval/coco_cf_loader.py +90 -0
- vlm_eval/create_clip_dataset.py +65 -0
- vlm_eval/datasets_classes_templates.py +822 -0
- vlm_eval/ms_coco_gen.py +76 -0
- vlm_eval/run_evaluation.py +0 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
MastersThesis_475703.pdf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
MastersThesis_475703.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
vlm_eval/__pycache__/run_evaluation.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
vlm_eval/__init__.py
ADDED
|
File without changes
|
vlm_eval/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (163 Bytes). View file
|
|
|
vlm_eval/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (149 Bytes). View file
|
|
|
vlm_eval/__pycache__/coco_cf_loader.cpython-311.pyc
ADDED
|
Binary file (4.87 kB). View file
|
|
|
vlm_eval/__pycache__/datasets_classes_templates.cpython-311.pyc
ADDED
|
Binary file (21.1 kB). View file
|
|
|
vlm_eval/__pycache__/run_evaluation.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6122c088ad6b90b0847802d1d0eaaefe8b2503bfa8c3c29a370d6c4406b59718
|
| 3 |
+
size 113082
|
vlm_eval/__pycache__/run_evaluation.cpython-312.pyc
ADDED
|
Binary file (73.2 kB). View file
|
|
|
vlm_eval/__pycache__/run_evaluation_qualitative.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
vlm_eval/attacks/__init__.py
ADDED
|
File without changes
|
vlm_eval/attacks/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
vlm_eval/attacks/__pycache__/afw.cpython-311.pyc
ADDED
|
Binary file (6.67 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/apgd.cpython-311.pyc
ADDED
|
Binary file (23 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/attack.cpython-311.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/ead.cpython-311.pyc
ADDED
|
Binary file (7.72 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/fwnucl.cpython-311.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/gse.cpython-311.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/iht.cpython-311.pyc
ADDED
|
Binary file (5.51 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/pgd0.cpython-311.pyc
ADDED
|
Binary file (9.86 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/saif.cpython-311.pyc
ADDED
|
Binary file (8.29 kB). View file
|
|
|
vlm_eval/attacks/__pycache__/strattack.cpython-311.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
vlm_eval/attacks/apgd.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code adapted from https://github.com/chs20/RobustVLM/tree/main
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class APGD:
|
| 8 |
+
def __init__(self, model, norm, eps, mask_out='context', initial_stepsize=None, decrease_every=None, decrease_every_max=None, random_init=False):
|
| 9 |
+
# model returns loss sum over batch
|
| 10 |
+
# thus currently only works with batch size 1
|
| 11 |
+
# initial_stepsize: in terms of eps. called alpha in apgd
|
| 12 |
+
# decrease_every: potentially decrease stepsize every x fraction of total iterations. default: 0.22
|
| 13 |
+
self.model = model
|
| 14 |
+
self.norm = norm
|
| 15 |
+
self.eps = eps
|
| 16 |
+
self.initial_stepsize = initial_stepsize
|
| 17 |
+
self.decrease_every = decrease_every
|
| 18 |
+
self.decrease_every_max = decrease_every_max
|
| 19 |
+
self.random_init = random_init
|
| 20 |
+
if mask_out != 'none':
|
| 21 |
+
self.mask_out = mask_out
|
| 22 |
+
else:
|
| 23 |
+
self.mask_out = None
|
| 24 |
+
|
| 25 |
+
def perturb(self, data_clean, iterations, pert_init=None, verbose=False):
|
| 26 |
+
mask = self._set_mask(data_clean)
|
| 27 |
+
data_adv, _, _ = apgd(
|
| 28 |
+
self.model, data_clean, norm=self.norm, eps=self.eps, n_iter=iterations,
|
| 29 |
+
use_rs=self.random_init, mask=mask, alpha=self.initial_stepsize,
|
| 30 |
+
n_iter_2=self.decrease_every, n_iter_min=self.decrease_every_max, pert_init=pert_init,
|
| 31 |
+
verbose=verbose
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
return data_adv
|
| 35 |
+
|
| 36 |
+
def _set_mask(self, data):
|
| 37 |
+
mask = torch.ones_like(data)
|
| 38 |
+
if self.mask_out == 'context':
|
| 39 |
+
mask[:, :-1, ...] = 0
|
| 40 |
+
elif self.mask_out == 'query':
|
| 41 |
+
mask[:, -1, ...] = 0
|
| 42 |
+
elif isinstance(self.mask_out, int):
|
| 43 |
+
mask[:, self.mask_out, ...] = 0
|
| 44 |
+
elif self.mask_out is None:
|
| 45 |
+
pass
|
| 46 |
+
else:
|
| 47 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 48 |
+
return mask
|
| 49 |
+
|
| 50 |
+
def __str__(self):
|
| 51 |
+
return 'APGD'
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def L1_projection(x2, y2, eps1):
|
| 55 |
+
'''
|
| 56 |
+
x2: center of the L1 ball (bs x input_dim)
|
| 57 |
+
y2: current perturbation (x2 + y2 is the point to be projected)
|
| 58 |
+
eps1: radius of the L1 ball
|
| 59 |
+
|
| 60 |
+
output: delta s.th. ||y2 + delta||_1 = eps1
|
| 61 |
+
and 0 <= x2 + y2 + delta <= 1
|
| 62 |
+
'''
|
| 63 |
+
|
| 64 |
+
x = x2.clone().float().view(x2.shape[0], -1)
|
| 65 |
+
y = y2.clone().float().view(y2.shape[0], -1)
|
| 66 |
+
sigma = y.clone().sign()
|
| 67 |
+
u = torch.min(1 - x - y, x + y)
|
| 68 |
+
# u = torch.min(u, epsinf - torch.clone(y).abs())
|
| 69 |
+
u = torch.min(torch.zeros_like(y), u)
|
| 70 |
+
l = -torch.clone(y).abs()
|
| 71 |
+
d = u.clone()
|
| 72 |
+
|
| 73 |
+
bs, indbs = torch.sort(-torch.cat((u, l), 1), dim=1)
|
| 74 |
+
bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1)
|
| 75 |
+
|
| 76 |
+
inu = 2 * (indbs < u.shape[1]).float() - 1
|
| 77 |
+
size1 = inu.cumsum(dim=1)
|
| 78 |
+
|
| 79 |
+
s1 = -u.sum(dim=1)
|
| 80 |
+
|
| 81 |
+
c = eps1 - y.clone().abs().sum(dim=1)
|
| 82 |
+
c5 = s1 + c < 0
|
| 83 |
+
c2 = c5.nonzero().squeeze(1)
|
| 84 |
+
|
| 85 |
+
s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)
|
| 86 |
+
# print(s[0])
|
| 87 |
+
|
| 88 |
+
# print(c5.shape, c2)
|
| 89 |
+
|
| 90 |
+
if c2.nelement != 0:
|
| 91 |
+
|
| 92 |
+
lb = torch.zeros_like(c2).float()
|
| 93 |
+
ub = torch.ones_like(lb) * (bs.shape[1] - 1)
|
| 94 |
+
|
| 95 |
+
# print(c2.shape, lb.shape)
|
| 96 |
+
|
| 97 |
+
nitermax = torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float()))
|
| 98 |
+
counter2 = torch.zeros_like(lb).long()
|
| 99 |
+
counter = 0
|
| 100 |
+
|
| 101 |
+
while counter < nitermax:
|
| 102 |
+
counter4 = torch.floor((lb + ub) / 2.)
|
| 103 |
+
counter2 = counter4.type(torch.LongTensor)
|
| 104 |
+
|
| 105 |
+
c8 = s[c2, counter2] + c[c2] < 0
|
| 106 |
+
ind3 = c8.nonzero().squeeze(1)
|
| 107 |
+
ind32 = (~c8).nonzero().squeeze(1)
|
| 108 |
+
# print(ind3.shape)
|
| 109 |
+
if ind3.nelement != 0:
|
| 110 |
+
lb[ind3] = counter4[ind3]
|
| 111 |
+
if ind32.nelement != 0:
|
| 112 |
+
ub[ind32] = counter4[ind32]
|
| 113 |
+
|
| 114 |
+
# print(lb, ub)
|
| 115 |
+
counter += 1
|
| 116 |
+
|
| 117 |
+
lb2 = lb.long()
|
| 118 |
+
alpha = (-s[c2, lb2] - c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
|
| 119 |
+
d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -l[c2])
|
| 120 |
+
|
| 121 |
+
return (sigma * d).view(x2.shape)
|
| 122 |
+
|
| 123 |
+
def L0_projection(x_adv, x, eps, step_size, lam=0.01):
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
pert = x_adv - x
|
| 127 |
+
|
| 128 |
+
pert_proj = torch.clamp(pert,-eps,eps)
|
| 129 |
+
x_adv_temp = torch.clamp(x + pert_proj,0.,1.)
|
| 130 |
+
pert_proj = x_adv_temp - x
|
| 131 |
+
pert = torch.where(pert ** 2 - (pert_proj - pert) ** 2 > 2 * step_size * lam, pert_proj, 0)
|
| 132 |
+
#pert = torch.where(pert > (2 * lam * step_size) ** 0.5, pert, 0)
|
| 133 |
+
return torch.clamp(x+pert,0.0,1.0)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def L1_norm(x, keepdim=False):
|
| 138 |
+
z = x.abs().view(x.shape[0], -1).sum(-1)
|
| 139 |
+
if keepdim:
|
| 140 |
+
z = z.view(-1, *[1] * (len(x.shape) - 1))
|
| 141 |
+
return z
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def L2_norm(x, keepdim=False):
|
| 145 |
+
z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
|
| 146 |
+
if keepdim:
|
| 147 |
+
z = z.view(-1, *[1] * (len(x.shape) - 1))
|
| 148 |
+
return z
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def L0_norm(x):
|
| 152 |
+
return (x != 0.).view(x.shape[0], -1).sum(-1)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def dlr_loss(x, y, reduction='none'):
|
| 156 |
+
x_sorted, ind_sorted = x.sort(dim=1)
|
| 157 |
+
ind = (ind_sorted[:, -1] == y).float()
|
| 158 |
+
|
| 159 |
+
return -(x[torch.arange(x.shape[0]), y] - x_sorted[:, -2] * ind - \
|
| 160 |
+
x_sorted[:, -1] * (1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def dlr_loss_targeted(x, y, y_target):
|
| 164 |
+
x_sorted, ind_sorted = x.sort(dim=1)
|
| 165 |
+
u = torch.arange(x.shape[0])
|
| 166 |
+
|
| 167 |
+
return -(x[u, y] - x[u, y_target]) / (x_sorted[:, -1] - .5 * (
|
| 168 |
+
x_sorted[:, -3] + x_sorted[:, -4]) + 1e-12)
|
| 169 |
+
|
| 170 |
+
def check_oscillation(x, j, k, y5, k3=0.75):
|
| 171 |
+
t = torch.zeros(x.shape[1]).to(x.device)
|
| 172 |
+
for counter5 in range(k):
|
| 173 |
+
t += (x[j - counter5] > x[j - counter5 - 1]).float()
|
| 174 |
+
|
| 175 |
+
return (t <= k * k3 * torch.ones_like(t)).float()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def apgd(model, x, norm, eps, n_iter=10, use_rs=False, mask=None, alpha=None, n_iter_2=None,
|
| 179 |
+
n_iter_min=None, pert_init=None, verbose=False, is_train=True):
|
| 180 |
+
# from https://github.com/fra31/robust-finetuning
|
| 181 |
+
assert x.shape[0] == 1 # only support batch size 1 for now
|
| 182 |
+
norm = norm.replace('l', 'L')
|
| 183 |
+
device = x.device
|
| 184 |
+
ndims = len(x.shape) - 1
|
| 185 |
+
|
| 186 |
+
if not use_rs:
|
| 187 |
+
x_adv = x.clone()
|
| 188 |
+
else:
|
| 189 |
+
if norm == 'Linf':
|
| 190 |
+
t = torch.zeros_like(x).uniform_(-eps, eps).detach()
|
| 191 |
+
x_adv = x + t
|
| 192 |
+
elif norm == 'L2':
|
| 193 |
+
t = torch.randn(x.shape).to(device).detach()
|
| 194 |
+
x_adv = x + eps * torch.ones_like(x).detach() * t / (L2_norm(t, keepdim=True) + 1e-12)
|
| 195 |
+
if pert_init is not None:
|
| 196 |
+
assert not use_rs
|
| 197 |
+
assert pert_init.shape == x.shape, f'pert_init.shape: {pert_init.shape}, x.shape: {x.shape}'
|
| 198 |
+
x_adv = x + pert_init
|
| 199 |
+
|
| 200 |
+
x_adv = x_adv.clamp(0., 1.)
|
| 201 |
+
x_best = x_adv.clone()
|
| 202 |
+
x_best_adv = x_adv.clone()
|
| 203 |
+
loss_steps = torch.zeros([n_iter, x.shape[0]], device=device)
|
| 204 |
+
loss_best_steps = torch.zeros([n_iter + 1, x.shape[0]], device=device)
|
| 205 |
+
|
| 206 |
+
# set params
|
| 207 |
+
n_fts = math.prod(x.shape[1:])
|
| 208 |
+
if norm in ['Linf', 'L2']:
|
| 209 |
+
n_iter_2_frac = 0.22 if n_iter_2 is None else n_iter_2
|
| 210 |
+
n_iter_min_frac = 0.06 if n_iter_min is None else n_iter_min
|
| 211 |
+
n_iter_2 = max(int(n_iter_2_frac * n_iter), 1)
|
| 212 |
+
n_iter_min = max(int(n_iter_min_frac * n_iter), 1)
|
| 213 |
+
size_decr = max(int(0.03 * n_iter), 1)
|
| 214 |
+
k = n_iter_2 + 0
|
| 215 |
+
thr_decr = .75
|
| 216 |
+
alpha = 2. if alpha is None else alpha
|
| 217 |
+
elif norm in ['L1','L0']:
|
| 218 |
+
k = max(int(.04 * n_iter), 1)
|
| 219 |
+
init_topk = .05 if is_train else .2
|
| 220 |
+
topk = init_topk * torch.ones([x.shape[0]], device=device)
|
| 221 |
+
sp_old = n_fts * torch.ones_like(topk)
|
| 222 |
+
adasp_redstep = 1.5
|
| 223 |
+
adasp_minstep = 10.
|
| 224 |
+
alpha = 1. if alpha is None else alpha
|
| 225 |
+
|
| 226 |
+
step_size = alpha * eps * torch.ones([x.shape[0], *[1] * ndims],
|
| 227 |
+
device=device)
|
| 228 |
+
counter3 = 0
|
| 229 |
+
|
| 230 |
+
x_adv.requires_grad_()
|
| 231 |
+
# grad = torch.zeros_like(x)
|
| 232 |
+
# for _ in range(self.eot_iter)
|
| 233 |
+
with torch.enable_grad():
|
| 234 |
+
loss_indiv = model(x_adv)#.unsqueeze(0)
|
| 235 |
+
loss = loss_indiv.sum()
|
| 236 |
+
# grad += torch.autograd.grad(loss, [x_adv])[0].detach()
|
| 237 |
+
grad = torch.autograd.grad(loss, [x_adv])[0].detach()
|
| 238 |
+
if mask is not None:
|
| 239 |
+
grad *= mask
|
| 240 |
+
# grad /= float(self.eot_iter)
|
| 241 |
+
grad_best = grad.clone()
|
| 242 |
+
x_adv.detach_()
|
| 243 |
+
loss_indiv = loss_indiv.detach()
|
| 244 |
+
loss = loss.detach()
|
| 245 |
+
|
| 246 |
+
loss_best = loss_indiv.detach().clone()
|
| 247 |
+
loss_best_last_check = loss_best.clone()
|
| 248 |
+
reduced_last_check = torch.ones_like(loss_best)
|
| 249 |
+
n_reduced = 0
|
| 250 |
+
|
| 251 |
+
u = torch.arange(x.shape[0], device=device)
|
| 252 |
+
x_adv_old = x_adv.clone().detach()
|
| 253 |
+
|
| 254 |
+
for i in range(n_iter):
|
| 255 |
+
### gradient step
|
| 256 |
+
if True: # with torch.no_grad()
|
| 257 |
+
x_adv = x_adv.detach()
|
| 258 |
+
grad2 = x_adv - x_adv_old
|
| 259 |
+
x_adv_old = x_adv.clone()
|
| 260 |
+
loss_curr = loss.detach().mean()
|
| 261 |
+
|
| 262 |
+
a = 0.75 if i > 0 else 1.0
|
| 263 |
+
|
| 264 |
+
if norm == 'Linf':
|
| 265 |
+
x_adv_1 = x_adv + step_size * torch.sign(grad)
|
| 266 |
+
x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1,
|
| 267 |
+
x - eps), x + eps), 0.0, 1.0)
|
| 268 |
+
x_adv_1 = torch.clamp(torch.min(torch.max(
|
| 269 |
+
x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
|
| 270 |
+
x - eps), x + eps), 0.0, 1.0)
|
| 271 |
+
|
| 272 |
+
elif norm == 'L2':
|
| 273 |
+
x_adv_1 = x_adv + step_size * grad / (L2_norm(grad,
|
| 274 |
+
keepdim=True) + 1e-12)
|
| 275 |
+
x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x,
|
| 276 |
+
keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x),
|
| 277 |
+
L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0)
|
| 278 |
+
x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
|
| 279 |
+
x_adv_1 = torch.clamp(x + (x_adv_1 - x) / (L2_norm(x_adv_1 - x,
|
| 280 |
+
keepdim=True) + 1e-12) * torch.min(eps * torch.ones_like(x),
|
| 281 |
+
L2_norm(x_adv_1 - x, keepdim=True)), 0.0, 1.0)
|
| 282 |
+
|
| 283 |
+
elif norm == 'L1':
|
| 284 |
+
grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0]
|
| 285 |
+
topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long()
|
| 286 |
+
grad_topk = grad_topk[u, topk_curr].view(-1, *[1] * (len(x.shape) - 1))
|
| 287 |
+
sparsegrad = grad * (grad.abs() >= grad_topk).float()
|
| 288 |
+
x_adv_1 = x_adv + step_size * sparsegrad.sign() / (
|
| 289 |
+
sparsegrad.sign().abs().view(x.shape[0], -1).sum(dim=-1).view(
|
| 290 |
+
-1, 1, 1, 1) + 1e-10)
|
| 291 |
+
|
| 292 |
+
delta_u = x_adv_1 - x
|
| 293 |
+
delta_p = L1_projection(x, delta_u, eps)
|
| 294 |
+
x_adv_1 = x + delta_u + delta_p
|
| 295 |
+
|
| 296 |
+
elif norm == 'L0':
|
| 297 |
+
L1normgrad = grad / (grad.abs().view(grad.shape[0], -1).sum(
|
| 298 |
+
dim=-1, keepdim=True) + 1e-12).view(grad.shape[0], *[1] * (
|
| 299 |
+
len(grad.shape) - 1))
|
| 300 |
+
x_adv_1 = x_adv + step_size * L1normgrad * n_fts
|
| 301 |
+
# TODO: add momentum
|
| 302 |
+
|
| 303 |
+
x_adv = x_adv_1.to(dtype=x_adv.dtype) + 0.
|
| 304 |
+
|
| 305 |
+
### get gradient
|
| 306 |
+
x_adv.requires_grad_()
|
| 307 |
+
# grad = torch.zeros_like(x)
|
| 308 |
+
# for _ in range(self.eot_iter)
|
| 309 |
+
with torch.enable_grad():
|
| 310 |
+
loss_indiv = model(x_adv)#.unsqueeze(0)
|
| 311 |
+
loss = loss_indiv.sum()
|
| 312 |
+
|
| 313 |
+
# grad += torch.autograd.grad(loss, [x_adv])[0].detach()
|
| 314 |
+
if i < n_iter - 1:
|
| 315 |
+
# save one backward pass
|
| 316 |
+
grad = torch.autograd.grad(loss, [x_adv])[0].detach()
|
| 317 |
+
if mask is not None:
|
| 318 |
+
grad *= mask
|
| 319 |
+
# grad /= float(self.eot_iter)
|
| 320 |
+
x_adv.detach_()
|
| 321 |
+
loss_indiv = loss_indiv.detach()
|
| 322 |
+
loss = loss.detach()
|
| 323 |
+
|
| 324 |
+
x_best_adv = x_adv + 0.
|
| 325 |
+
if verbose and (i % max(n_iter // 10, 1) == 0 or i == n_iter - 1):
|
| 326 |
+
str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format(
|
| 327 |
+
step_size.mean(), topk.mean() * n_fts) if norm in ['L1'] else ' - step size: {:.5f}'.format(
|
| 328 |
+
step_size.mean())
|
| 329 |
+
print('iteration: {} - best loss: {:.6f} curr loss {:.6f} {}'.format(
|
| 330 |
+
i, loss_best.sum(), loss_curr, str_stats))
|
| 331 |
+
# print('pert {}'.format((x - x_best_adv).abs().view(x.shape[0], -1).sum(-1).max()))
|
| 332 |
+
|
| 333 |
+
### check step size
|
| 334 |
+
if True: # with torch.no_grad()
|
| 335 |
+
y1 = loss_indiv.detach().clone()
|
| 336 |
+
loss_steps[i] = y1 + 0
|
| 337 |
+
ind = (y1 > loss_best).nonzero().squeeze()
|
| 338 |
+
x_best[ind] = x_adv[ind].clone()
|
| 339 |
+
grad_best[ind] = grad[ind].clone()
|
| 340 |
+
loss_best[ind] = y1[ind] + 0
|
| 341 |
+
loss_best_steps[i + 1] = loss_best + 0
|
| 342 |
+
|
| 343 |
+
counter3 += 1
|
| 344 |
+
|
| 345 |
+
if counter3 == k:
|
| 346 |
+
if norm in ['Linf', 'L2']:
|
| 347 |
+
fl_oscillation = check_oscillation(loss_steps, i, k,
|
| 348 |
+
loss_best, k3=thr_decr)
|
| 349 |
+
fl_reduce_no_impr = (1. - reduced_last_check) * (
|
| 350 |
+
loss_best_last_check >= loss_best).float()
|
| 351 |
+
fl_oscillation = torch.max(fl_oscillation,
|
| 352 |
+
fl_reduce_no_impr)
|
| 353 |
+
reduced_last_check = fl_oscillation.clone()
|
| 354 |
+
loss_best_last_check = loss_best.clone()
|
| 355 |
+
|
| 356 |
+
if fl_oscillation.sum() > 0:
|
| 357 |
+
ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze()
|
| 358 |
+
step_size[ind_fl_osc] /= 2.0
|
| 359 |
+
n_reduced = fl_oscillation.sum()
|
| 360 |
+
|
| 361 |
+
x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone()
|
| 362 |
+
grad[ind_fl_osc] = grad_best[ind_fl_osc].clone()
|
| 363 |
+
|
| 364 |
+
counter3 = 0
|
| 365 |
+
k = max(k - size_decr, n_iter_min)
|
| 366 |
+
|
| 367 |
+
elif norm in ['L1']:
|
| 368 |
+
# adjust sparsity
|
| 369 |
+
sp_curr = L0_norm(x_best - x)
|
| 370 |
+
fl_redtopk = (sp_curr / sp_old) < .95
|
| 371 |
+
topk = sp_curr / n_fts / 1.5
|
| 372 |
+
step_size[fl_redtopk] = alpha * eps
|
| 373 |
+
step_size[~fl_redtopk] /= adasp_redstep
|
| 374 |
+
step_size.clamp_(alpha * eps / adasp_minstep, alpha * eps)
|
| 375 |
+
sp_old = sp_curr.clone()
|
| 376 |
+
|
| 377 |
+
x_adv[fl_redtopk] = x_best[fl_redtopk].clone()
|
| 378 |
+
grad[fl_redtopk] = grad_best[fl_redtopk].clone()
|
| 379 |
+
|
| 380 |
+
counter3 = 0
|
| 381 |
+
|
| 382 |
+
return x_best, loss_best, x_best_adv
|
| 383 |
+
|
| 384 |
+
|
vlm_eval/attacks/attack.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Attack(object):
|
| 5 |
+
'''
|
| 6 |
+
Root class for all adversarial attack classes.
|
| 7 |
+
'''
|
| 8 |
+
|
| 9 |
+
def __init__(self, model, targeted=False, img_range=(0, 1)):
|
| 10 |
+
self.model = model
|
| 11 |
+
self.device = 'cuda:0'
|
| 12 |
+
self.targeted = targeted
|
| 13 |
+
self.img_range = img_range
|
| 14 |
+
|
| 15 |
+
def __repr__(self):
|
| 16 |
+
return str(self.__dict__)
|
| 17 |
+
|
| 18 |
+
def to(self, device):
|
| 19 |
+
self.model.to(device)
|
| 20 |
+
self.device = device
|
vlm_eval/attacks/ead.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
import torch
|
| 3 |
+
from vlm_eval.attacks.attack import Attack
|
| 4 |
+
|
| 5 |
+
class EAD(Attack):
|
| 6 |
+
|
| 7 |
+
def __init__(self,model, targeted=False, img_range=(0,1), steps=100, beta=5e-5, mask_out='none', ver=False, binary_steps=2, step_size=1e-2, decision_rule='L1'):
|
| 8 |
+
|
| 9 |
+
super().__init__(model=model, targeted=targeted, img_range=img_range)
|
| 10 |
+
self.steps = steps
|
| 11 |
+
self.ver = ver
|
| 12 |
+
self.binary_steps = binary_steps
|
| 13 |
+
self.beta = beta
|
| 14 |
+
if mask_out != 'none':
|
| 15 |
+
self.mask_out = mask_out
|
| 16 |
+
else:
|
| 17 |
+
self.mask_out = None
|
| 18 |
+
self.decision_rule = decision_rule
|
| 19 |
+
self.ver = ver
|
| 20 |
+
self.step_size = step_size
|
| 21 |
+
|
| 22 |
+
def _set_mask(self, data):
|
| 23 |
+
mask = torch.ones_like(data)
|
| 24 |
+
if self.mask_out == 'context':
|
| 25 |
+
mask[:, :-1, ...] = 0
|
| 26 |
+
elif self.mask_out == 'query':
|
| 27 |
+
mask[:, -1, ...] = 0
|
| 28 |
+
elif isinstance(self.mask_out, int):
|
| 29 |
+
mask[:, self.mask_out, ...] = 0
|
| 30 |
+
elif self.mask_out is None:
|
| 31 |
+
pass
|
| 32 |
+
else:
|
| 33 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 34 |
+
return mask
|
| 35 |
+
|
| 36 |
+
def __call__(self, x_orig):
|
| 37 |
+
|
| 38 |
+
for param in self.model.model.parameters():
|
| 39 |
+
param.requires_grad = False
|
| 40 |
+
|
| 41 |
+
mask_out = self._set_mask(x_orig)
|
| 42 |
+
|
| 43 |
+
c = 1e-1
|
| 44 |
+
c_upper = 10e+10
|
| 45 |
+
c_lower = 0
|
| 46 |
+
|
| 47 |
+
overall_best_attack = x_orig.clone()
|
| 48 |
+
overall_best_dist = torch.inf
|
| 49 |
+
overall_best_loss = 1e10
|
| 50 |
+
|
| 51 |
+
for binary_step in range(self.binary_steps):
|
| 52 |
+
|
| 53 |
+
global_step = 0
|
| 54 |
+
x = x_orig.clone().detach()
|
| 55 |
+
y = x_orig.clone().detach()
|
| 56 |
+
|
| 57 |
+
best_attack = x_orig.clone().detach()
|
| 58 |
+
best_dist = torch.inf
|
| 59 |
+
best_loss = 1e10
|
| 60 |
+
|
| 61 |
+
step_size = 1e-2
|
| 62 |
+
|
| 63 |
+
for step in range(self.steps):
|
| 64 |
+
|
| 65 |
+
y.requires_grad = True
|
| 66 |
+
_, loss = self.loss_fn(x=y, c=c, x_orig=x_orig)
|
| 67 |
+
loss.backward()
|
| 68 |
+
y_grad = y.grad.data * mask_out
|
| 69 |
+
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
x_new = self.project(x=y-step_size*y_grad, x_orig=x_orig)
|
| 72 |
+
|
| 73 |
+
step_size = (self.step_size - 0) * (1 - global_step / self.steps) ** 0.5 + 0
|
| 74 |
+
global_step += 1
|
| 75 |
+
|
| 76 |
+
y = x_new + (step / (step + 3)) * (x_new - x)
|
| 77 |
+
x = x_new
|
| 78 |
+
|
| 79 |
+
loss_model, loss = self.loss_fn(x=x, c=c, x_orig=x_orig)
|
| 80 |
+
|
| 81 |
+
if self.ver and step % 20 == 0:
|
| 82 |
+
print(f"Binary Step: {binary_step}, Iter: {step}, Loss: {loss.item()}, L0: {(x - x_orig).norm(p=0)}, Linf: {(x - x_orig).norm(p=torch.inf)}")
|
| 83 |
+
|
| 84 |
+
if self.decision_rule == 'L1':
|
| 85 |
+
if (x - x_orig).norm(p=1).item() < best_dist and loss_model < best_loss:
|
| 86 |
+
best_loss = loss_model
|
| 87 |
+
best_attack = x.clone()
|
| 88 |
+
best_dist = (x - x_orig).norm(p=1).item()
|
| 89 |
+
else:
|
| 90 |
+
raise NotImplementedError
|
| 91 |
+
|
| 92 |
+
# Updating c
|
| 93 |
+
if overall_best_dist > best_dist and best_loss < overall_best_loss:
|
| 94 |
+
overall_best_loss = best_loss
|
| 95 |
+
overall_best_dist = best_dist
|
| 96 |
+
overall_best_attack = best_attack.clone()
|
| 97 |
+
|
| 98 |
+
c_upper = min(c_upper, c)
|
| 99 |
+
if c_upper < 1e9:
|
| 100 |
+
c = (c_upper + c_lower) / 2
|
| 101 |
+
|
| 102 |
+
else:
|
| 103 |
+
c_lower = max(c_lower, c)
|
| 104 |
+
if c_upper < 1e9:
|
| 105 |
+
c = (c_lower + c_upper) / 2.0
|
| 106 |
+
else:
|
| 107 |
+
c *= 10
|
| 108 |
+
|
| 109 |
+
print(f"Final L0: {(overall_best_attack - x_orig).norm(p=0)}, Linf: {(overall_best_attack - x_orig).norm(p=torch.inf)}")
|
| 110 |
+
return overall_best_attack.detach()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def project(self, x, x_orig):
|
| 114 |
+
|
| 115 |
+
mask_1 = (x - x_orig > self.beta).float()
|
| 116 |
+
mask_2 = ((x - x_orig).abs() <= self.beta).float()
|
| 117 |
+
mask_3 = (x - x_orig < -self.beta).float()
|
| 118 |
+
|
| 119 |
+
upper = torch.minimum(x - self.beta, torch.tensor(1.0))
|
| 120 |
+
lower = torch.maximum(x + self.beta, torch.tensor(0.0))
|
| 121 |
+
|
| 122 |
+
proj_x = mask_1 * upper + mask_2 * x_orig + mask_3 * lower
|
| 123 |
+
return proj_x
|
| 124 |
+
|
| 125 |
+
def loss_fn(self, x, c, x_orig):
|
| 126 |
+
|
| 127 |
+
out = -self.model(x).sum() if not self.targeted else self.model(x).sum()
|
| 128 |
+
l2_dist = ((x - x_orig) ** 2).view(x.shape[0], -1).sum(dim=1)
|
| 129 |
+
l1_dist = ((x - x_orig).abs()).view(x.shape[0], -1).sum(dim=1)
|
| 130 |
+
|
| 131 |
+
return out, c * out + l2_dist.sum() + \
|
| 132 |
+
self.beta * l1_dist.sum()
|
vlm_eval/attacks/fwnucl.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from vlm_eval.attacks.attack import Attack
|
| 5 |
+
|
| 6 |
+
class FWnucl(Attack):
|
| 7 |
+
def __init__(self, model, *args, iters=200, img_range=(-1, 1), ver=False,
|
| 8 |
+
targeted=False, eps=5, mask_out='none',**kwargs):
|
| 9 |
+
'''
|
| 10 |
+
Implementation of the nuclear group norm attack.
|
| 11 |
+
|
| 12 |
+
args:
|
| 13 |
+
model: Callable, PyTorch classifier.
|
| 14 |
+
ver: Bool, print progress if True.
|
| 15 |
+
img_range: Tuple of ints/floats, lower and upper bound of image
|
| 16 |
+
entries.
|
| 17 |
+
targeted: Bool, given label is used as a target label if True.
|
| 18 |
+
eps: Float, radius of the nuclear group norm ball.
|
| 19 |
+
'''
|
| 20 |
+
super().__init__(model, img_range=img_range, targeted=targeted)
|
| 21 |
+
self.iters = iters
|
| 22 |
+
self.ver = ver
|
| 23 |
+
self.eps = eps
|
| 24 |
+
self.gr = (math.sqrt(5) + 1) / 2
|
| 25 |
+
if mask_out != 'none':
|
| 26 |
+
self.mask_out = mask_out
|
| 27 |
+
else:
|
| 28 |
+
self.mask_out = None
|
| 29 |
+
|
| 30 |
+
def _set_mask(self, data):
|
| 31 |
+
mask = torch.ones_like(data)
|
| 32 |
+
if self.mask_out == 'context':
|
| 33 |
+
mask[:, :-1, ...] = 0
|
| 34 |
+
elif self.mask_out == 'query':
|
| 35 |
+
mask[:, -1, ...] = 0
|
| 36 |
+
elif isinstance(self.mask_out, int):
|
| 37 |
+
mask[:, self.mask_out, ...] = 0
|
| 38 |
+
elif self.mask_out is None:
|
| 39 |
+
pass
|
| 40 |
+
else:
|
| 41 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 42 |
+
return mask
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def __loss_fn(self, x):
|
| 46 |
+
'''
|
| 47 |
+
Compute loss depending on self.targeted.
|
| 48 |
+
'''
|
| 49 |
+
if self.targeted:
|
| 50 |
+
return -self.model(x).sum()
|
| 51 |
+
else:
|
| 52 |
+
return self.model(x).sum()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def __call__(self, x, *args, **kwargs):
|
| 56 |
+
'''
|
| 57 |
+
Perform the nuclear group norm attack on a batch of images x.
|
| 58 |
+
|
| 59 |
+
args:
|
| 60 |
+
x: Tensor of shape [B, C, H, W], batch of images.
|
| 61 |
+
y: Tensor of shape [B], batch of labels.
|
| 62 |
+
|
| 63 |
+
Returns a tensor of the same shape as x containing adversarial examples
|
| 64 |
+
'''
|
| 65 |
+
|
| 66 |
+
for param in self.model.model.parameters():
|
| 67 |
+
param.requires_grad = False
|
| 68 |
+
|
| 69 |
+
mask_out = self._set_mask(x)
|
| 70 |
+
x = x.to(self.device)
|
| 71 |
+
noise = torch.zeros_like(x)
|
| 72 |
+
noise.requires_grad = True
|
| 73 |
+
|
| 74 |
+
for t in range(self.iters):
|
| 75 |
+
if self.ver:
|
| 76 |
+
print(f'\rIteration {t+1}/{self.iters}', end='')
|
| 77 |
+
|
| 78 |
+
loss = self.__loss_fn(x + noise * mask_out)
|
| 79 |
+
loss.backward()
|
| 80 |
+
noise.grad.data = noise.grad.data * mask_out
|
| 81 |
+
s = self.__groupNuclearLMO(noise.grad.data, eps=self.eps)
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
gamma = self.__lineSearch(x=x, s=s, noise=noise)
|
| 84 |
+
noise = (1 - gamma) * noise + gamma * s
|
| 85 |
+
noise.requires_grad = True
|
| 86 |
+
|
| 87 |
+
if self.ver and t % 20 == 0:
|
| 88 |
+
print(f"Iteration: {t}, Loss: {loss.item()}")
|
| 89 |
+
x = torch.clamp(x + noise, 0, 1)
|
| 90 |
+
if self.ver:
|
| 91 |
+
print("")
|
| 92 |
+
return x.detach()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def __lineSearch(self, x, s, noise, steps=25):
|
| 96 |
+
'''
|
| 97 |
+
Perform line search for the step size.
|
| 98 |
+
'''
|
| 99 |
+
a = torch.zeros(x.shape[1], device=self.device).view(-1, 1, 1, 1)
|
| 100 |
+
b = torch.ones(x.shape[1], device=self.device).view(-1, 1, 1, 1)
|
| 101 |
+
c = b - (b - a) / self.gr
|
| 102 |
+
d = a + (b - a) / self.gr
|
| 103 |
+
sx = s - noise
|
| 104 |
+
|
| 105 |
+
for i in range(steps):
|
| 106 |
+
loss1 = self.__loss_fn(x + noise + (c * sx).view(*x.shape))
|
| 107 |
+
loss2 = self.__loss_fn(x + noise + (d * sx).view(*x.shape))
|
| 108 |
+
mask = loss1 > loss2
|
| 109 |
+
|
| 110 |
+
b[mask] = d[mask]
|
| 111 |
+
mask = torch.logical_not(mask)
|
| 112 |
+
a[mask] = c[mask]
|
| 113 |
+
|
| 114 |
+
c = b - (b - a) / self.gr
|
| 115 |
+
d = a + (b - a) / self.gr
|
| 116 |
+
|
| 117 |
+
return (b + a) / 2
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def __groupNuclearLMO(self, x, eps=5):
|
| 121 |
+
'''
|
| 122 |
+
LMO for the nuclear group norm ball.
|
| 123 |
+
'''
|
| 124 |
+
|
| 125 |
+
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
|
| 126 |
+
size = 32 if H > 64 else 4
|
| 127 |
+
|
| 128 |
+
# turn batch of images into batch of size by size pixel groups per
|
| 129 |
+
# color channel
|
| 130 |
+
xrgb = [x.view(B, C, H, W)[:, c, :, :] for c in range(C)]
|
| 131 |
+
xrgb = [xc.unfold(1, size, size).unfold(2, size, size) for xc in xrgb]
|
| 132 |
+
xrgb = [xc.reshape(-1, size, size) for xc in xrgb]
|
| 133 |
+
|
| 134 |
+
# compute nuclear norm of each patch (sum norms over color channels)
|
| 135 |
+
norms = torch.linalg.svdvals(xrgb[0])
|
| 136 |
+
for xc in xrgb[1:]:
|
| 137 |
+
norms += torch.linalg.svdvals(xc)
|
| 138 |
+
norms = norms.sum(-1).reshape(B, -1)
|
| 139 |
+
|
| 140 |
+
# only keep the patch g* with the largest nuclear norm for each image
|
| 141 |
+
idxs = norms.argmax(dim=1).view(-1, 1)
|
| 142 |
+
xrgb = [xc.reshape(B, -1, size, size) for xc in xrgb]
|
| 143 |
+
xrgb = [xc[torch.arange(B).view(-1, 1), idxs].view(B, size, size)
|
| 144 |
+
for xc in xrgb]
|
| 145 |
+
|
| 146 |
+
# build index tensor corr. to the position of the kept patches in x
|
| 147 |
+
off = (idxs % (W / size)).long() * size
|
| 148 |
+
off += torch.floor(idxs / (W / size)).long() * W * size
|
| 149 |
+
idxs = torch.arange(0, size**2,
|
| 150 |
+
device=self.device).view(1, -1).repeat(B, 1) + off
|
| 151 |
+
off = torch.arange(0, size,
|
| 152 |
+
device=self.device).view(-1, 1).repeat(1, size)
|
| 153 |
+
off = off * W - off * size
|
| 154 |
+
idxs += off.view(1, -1)
|
| 155 |
+
|
| 156 |
+
# compute singular vector pairs corresponding to largest singular value
|
| 157 |
+
# and final perturbation (LMO solution)
|
| 158 |
+
pert = torch.zeros_like(x).view(B, C, H, W)
|
| 159 |
+
for i, xc in enumerate(xrgb):
|
| 160 |
+
U, _, V = torch.linalg.svd(xc)
|
| 161 |
+
U = U[:, :, 0].view(B, size, 1)
|
| 162 |
+
V = V.transpose(-2, -1)[:, :, 0].view(B, size, 1)
|
| 163 |
+
pert_gr = torch.bmm(U, V.transpose(-2, -1)).reshape(B, size * size)
|
| 164 |
+
idx = torch.arange(B).view(-1, 1)
|
| 165 |
+
pert_tmp = pert[:, i, :, :].view(B, -1)
|
| 166 |
+
pert_tmp[idx, idxs] = pert_gr * eps
|
| 167 |
+
pert_clone = pert.clone()
|
| 168 |
+
pert_clone[:, i, :, :] = pert_tmp.view(B, H, W)
|
| 169 |
+
|
| 170 |
+
return pert_clone.view(*x.shape)
|
vlm_eval/attacks/gse.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision
|
| 4 |
+
import math
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from vlm_eval.attacks.attack import Attack
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# required input size : batch_size x num_media x num_frames x channels x height x width
|
| 11 |
+
class GSEAttack(Attack):
|
| 12 |
+
def __init__(self, model, *args, mask_out='none',ver=False, img_range=(-1, 1), search_steps=4,
|
| 13 |
+
targeted=False, sequential=False, search_factor=2,
|
| 14 |
+
gb_size=5, sgm=1.5, mu=1, sigma=0.0025, iters=200, k_hat=10,
|
| 15 |
+
q=0.25, **kwargs):
|
| 16 |
+
'''
|
| 17 |
+
Implementation of the GSE attack.
|
| 18 |
+
|
| 19 |
+
args:
|
| 20 |
+
model: Callable, PyTorch classifier.
|
| 21 |
+
mask_out: Masks out context images if set to context, query images if set to query and none if set to none.
|
| 22 |
+
ver: Bool, print progress if True.
|
| 23 |
+
img_range: Tuple of ints/floats, lower and upper bound of image
|
| 24 |
+
entries.
|
| 25 |
+
search_steps: Int, number of steps for line search on the trade-off
|
| 26 |
+
parameter.
|
| 27 |
+
targeted: Bool, given label is used as a target label if True.
|
| 28 |
+
sequential: Bool, perturbations are computed sequentially for all
|
| 29 |
+
images in the batch if True. For fair comparison to
|
| 30 |
+
Homotopy attack.
|
| 31 |
+
search_factor: Float, factor to increase/decrease the trade-off
|
| 32 |
+
parameter until an upper/lower bound for the line search
|
| 33 |
+
is found.
|
| 34 |
+
gb_size: Odd int, size of the Gaussian blur kernel.
|
| 35 |
+
sgm: Float, sigma of the gaussian blur kernel
|
| 36 |
+
mu: Float, trade-off parameter for 2-norm regularization.
|
| 37 |
+
sigma: Float, step size
|
| 38 |
+
iters: Int, number of iterations.
|
| 39 |
+
k_hat: Int, number of iterations before transitioning to NAG.
|
| 40 |
+
q: Float, inverse of increase factor for adjust_lambda.
|
| 41 |
+
'''
|
| 42 |
+
super().__init__(model, img_range=img_range, targeted=targeted)
|
| 43 |
+
self.ver = ver
|
| 44 |
+
self.search_steps = search_steps
|
| 45 |
+
self.sequential = sequential
|
| 46 |
+
self.search_factor = search_factor
|
| 47 |
+
self.gb_size = gb_size
|
| 48 |
+
self.sgm = sgm
|
| 49 |
+
self.mu = mu
|
| 50 |
+
self.sigma = sigma
|
| 51 |
+
self.iters = iters
|
| 52 |
+
self.k_hat = k_hat
|
| 53 |
+
self.q = q
|
| 54 |
+
if mask_out != 'none':
|
| 55 |
+
self.mask_out = mask_out
|
| 56 |
+
else:
|
| 57 |
+
self.mask_out = None
|
| 58 |
+
|
| 59 |
+
def adjust_lambda(self, lam, noise):
|
| 60 |
+
'''
|
| 61 |
+
Adjust trade-off parameters (lambda) to update search space.
|
| 62 |
+
'''
|
| 63 |
+
x = noise.detach().clone().abs().mean(dim=1, keepdim=True).sign()
|
| 64 |
+
gb = torchvision.transforms.GaussianBlur((self.gb_size, self.gb_size),
|
| 65 |
+
sigma=self.sgm)
|
| 66 |
+
x = gb(x) + 1
|
| 67 |
+
x = torch.where(x == 1, self.q, x)
|
| 68 |
+
lam /= x[:, 0, :, :]
|
| 69 |
+
return lam
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def section_search(self, x, steps=50):
|
| 73 |
+
'''
|
| 74 |
+
Section search for finding the maximal lambda such that the
|
| 75 |
+
perturbation is non-zero after the first iteration.
|
| 76 |
+
'''
|
| 77 |
+
|
| 78 |
+
noise = torch.zeros_like(x, requires_grad=True) # the shape of 'x' is batch_size x num_media x num_frames x Color x height x width
|
| 79 |
+
loss = (-self.model(x + noise).sum() + self.mu
|
| 80 |
+
* torch.norm(noise.view(x.size(1), x.size(3), x.size(4), x.size(5)), p=2, dim=(1,2,3)).sum())
|
| 81 |
+
grad = torch.autograd.grad(loss, [noise])[0].detach()
|
| 82 |
+
noise.detach_()
|
| 83 |
+
ones = torch.ones_like(x.view(x.size(1), x.size(3), x.size(4), x.size(5)))[:, 0, :, :]
|
| 84 |
+
|
| 85 |
+
# define upper and lower bound for line search
|
| 86 |
+
lb = torch.zeros((x.size(1),), dtype=torch.float,
|
| 87 |
+
device=self.device).view(-1, 1, 1)
|
| 88 |
+
ub = lb.clone() + 0.001
|
| 89 |
+
mask = torch.norm(self.prox(grad.clone().view(x.size(1),x.size(3),x.size(4),x.size(5)) * self.sigma,
|
| 90 |
+
ones * ub * self.sigma),
|
| 91 |
+
p=0, dim=(1,2,3)) != 0
|
| 92 |
+
while mask.any():
|
| 93 |
+
ub[mask] *= 2
|
| 94 |
+
mask = torch.norm(self.prox(grad.clone().view(x.size(1),x.size(3),x.size(4),x.size(5)) * self.sigma,
|
| 95 |
+
ones * ub * self.sigma),
|
| 96 |
+
p=0, dim=(1,2,3)) != 0
|
| 97 |
+
|
| 98 |
+
# perform search
|
| 99 |
+
for _ in range(steps):
|
| 100 |
+
cur = (ub + lb) / 2
|
| 101 |
+
mask = torch.norm(self.prox(grad.clone().view(x.size(1),x.size(3),x.size(4),x.size(5)) * self.sigma,
|
| 102 |
+
ones * cur * self.sigma),
|
| 103 |
+
p=0, dim=(1,2,3)) == 0
|
| 104 |
+
ub[mask] = cur[mask]
|
| 105 |
+
mask = torch.logical_not(mask)
|
| 106 |
+
lb[mask] = cur[mask]
|
| 107 |
+
cur = (lb + ub).view(-1) / 2
|
| 108 |
+
return 0.01 * cur
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def __call__(self, x, y, *args, **kwargs):
|
| 112 |
+
'''
|
| 113 |
+
Call the attack for a batch of images x or sequentially for all images
|
| 114 |
+
in x depending on self.sequential.
|
| 115 |
+
|
| 116 |
+
args:
|
| 117 |
+
x: Tensor of shape [B, C, H, W], batch of images.
|
| 118 |
+
y: Tensor of shape [B], batch of labels.
|
| 119 |
+
|
| 120 |
+
Returns a tensor of the same shape as x containing adversarial examples
|
| 121 |
+
'''
|
| 122 |
+
if self.sequential:
|
| 123 |
+
result = x.clone()
|
| 124 |
+
for i, (x_, y_) in enumerate(zip(x, y)):
|
| 125 |
+
result[i] = self.perform_att(x_.unsqueeze(0),
|
| 126 |
+
y_.unsqueeze(0),
|
| 127 |
+
mu=self.mu, sigma=self.sigma,
|
| 128 |
+
k_hat=self.k_hat).detach()
|
| 129 |
+
return result
|
| 130 |
+
else:
|
| 131 |
+
return self.perform_att(x, y, mu=self.mu, sigma=self.sigma,
|
| 132 |
+
k_hat=self.k_hat)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _set_mask(self, data):
|
| 136 |
+
mask = torch.ones_like(data)
|
| 137 |
+
if self.mask_out == 'context':
|
| 138 |
+
mask[:, :-1, ...] = 0
|
| 139 |
+
elif self.mask_out == 'query':
|
| 140 |
+
mask[:, -1, ...] = 0
|
| 141 |
+
elif isinstance(self.mask_out, int):
|
| 142 |
+
mask[:, self.mask_out, ...] = 0
|
| 143 |
+
elif self.mask_out is None:
|
| 144 |
+
pass
|
| 145 |
+
else:
|
| 146 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 147 |
+
return mask
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def perform_att(self, x, mu, sigma, k_hat):
|
| 151 |
+
'''
|
| 152 |
+
Perform GSE attack on a batch of images x with corresponding labels y.
|
| 153 |
+
'''
|
| 154 |
+
x = x.to(self.device)
|
| 155 |
+
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5] # Input is of the shape Batch x Num_media x num_frames x colors x height x width
|
| 156 |
+
lams = self.section_search(x)
|
| 157 |
+
mask_out = self._set_mask(x).view(B,C,H,W)
|
| 158 |
+
# save x, y, and lams for resetting them at the beginning of every
|
| 159 |
+
# section search step
|
| 160 |
+
save_x = x.clone()
|
| 161 |
+
save_lams = lams.clone()
|
| 162 |
+
# upper and lower bounds for section learch
|
| 163 |
+
ub_lams = torch.full_like(lams, torch.inf)
|
| 164 |
+
lb_lams = torch.full_like(lams, 0.0)
|
| 165 |
+
# tensor for saving succesful adversarial examples in inner loop
|
| 166 |
+
result = x.clone()
|
| 167 |
+
# tensor for saving best adversarial example so far
|
| 168 |
+
result2 = x.clone()
|
| 169 |
+
best_l0 = torch.full((B,), torch.inf, device=self.device).type(x.type())
|
| 170 |
+
|
| 171 |
+
# section search
|
| 172 |
+
for step in range(self.search_steps):
|
| 173 |
+
x = save_x.clone()
|
| 174 |
+
lams = save_lams.clone()
|
| 175 |
+
lam = torch.ones_like(x.view(B, C, H, W))[:, 0, :, :] * lams.view(-1, 1, 1)
|
| 176 |
+
# tensor for tracking for which images adv. examples have been found
|
| 177 |
+
active = torch.ones(B, dtype=bool, device=self.device)
|
| 178 |
+
# set initial perturbation to zero
|
| 179 |
+
noise = torch.zeros_like(x, requires_grad = True)
|
| 180 |
+
noise_old = noise.clone()
|
| 181 |
+
lr = 1
|
| 182 |
+
|
| 183 |
+
# attack
|
| 184 |
+
for j in range(self.iters):
|
| 185 |
+
if self.ver:
|
| 186 |
+
print(f'\rSearch step {step + 1}/{self.search_steps}, ' +
|
| 187 |
+
f'Prox.Grad. Iteration {j + 1}/{self.iters}, ' +
|
| 188 |
+
f'Images left: {x.shape[1]}', end='')
|
| 189 |
+
if len(x) == 0:
|
| 190 |
+
break
|
| 191 |
+
|
| 192 |
+
self.model.model.zero_grad()
|
| 193 |
+
loss = (-self.model(x + noise).sum() + mu
|
| 194 |
+
* (torch.norm(noise.view(B, C, H, W), p=2, dim=(1,2,3)) ** 2).sum())
|
| 195 |
+
noise_grad_data = torch.autograd.grad(loss, [noise])[0].detach().view(B, C, H, W)
|
| 196 |
+
#print(f"{loss} {(torch.norm(noise.view(B, C, H, W), p=2, dim=(1,2,3)) ** 2).sum()}")
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
|
| 199 |
+
noise_grad_data = noise_grad_data * mask_out # Mask_out shape B x C x H x W
|
| 200 |
+
lr_ = (1 + math.sqrt(1 + 4 * lr**2)) / 2
|
| 201 |
+
if j == k_hat:
|
| 202 |
+
lammask = (lam > lams.view(-1, 1, 1))[:, None, :, :]
|
| 203 |
+
lammask = lammask.repeat(1, C, 1, 1)
|
| 204 |
+
noise_old = noise.clone()
|
| 205 |
+
if j < k_hat:
|
| 206 |
+
noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
|
| 207 |
+
noise = self.prox(noise.view(B, C, H, W), lam * sigma).view(1, B, 1, C, H, W)
|
| 208 |
+
noise_tmp = noise.clone()
|
| 209 |
+
noise = lr / lr_ * noise + (1 - (lr/ lr_)) * noise_old
|
| 210 |
+
noise_old = noise_tmp.clone()
|
| 211 |
+
lam = self.adjust_lambda(lam, noise.view(B, C, H, W))
|
| 212 |
+
else:
|
| 213 |
+
noise = noise - sigma * noise_grad_data.view(1, B, 1, C, H, W)
|
| 214 |
+
noise_tmp = noise.clone()
|
| 215 |
+
noise = lr / lr_ * noise + (1 - (lr/ lr_)) * noise_old
|
| 216 |
+
noise_old = noise_tmp.clone()
|
| 217 |
+
noise[lammask.view(1, B, 1, C, H, W)] = 0
|
| 218 |
+
# clamp adv. example to valid range
|
| 219 |
+
x_adv = torch.clamp(x + noise, *self.img_range)
|
| 220 |
+
noise = x_adv - x
|
| 221 |
+
lr = lr_
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
noise.requires_grad = True
|
| 225 |
+
|
| 226 |
+
# section search
|
| 227 |
+
# no adv. example found => decrease upper bound and current lambda
|
| 228 |
+
# adv. example found => save it if the "0-norm" is better than of the
|
| 229 |
+
# previous adv. example, increase lower bound and current lambda
|
| 230 |
+
for i in range(B):
|
| 231 |
+
if active[i]:
|
| 232 |
+
ub_lams[i] = save_lams[i]
|
| 233 |
+
save_lams[i] = 0.95 * lb_lams[i] + 0.05 * save_lams[i]
|
| 234 |
+
else:
|
| 235 |
+
print("here")
|
| 236 |
+
l0 = self.l20((result[i] - save_x[i]).unsqueeze(0)).to(self.device)
|
| 237 |
+
if l0 < best_l0[i]:
|
| 238 |
+
best_l0[i] = l0
|
| 239 |
+
result2[i] = result[i].clone()
|
| 240 |
+
if torch.isinf(ub_lams[i]):
|
| 241 |
+
lb_lams[i] = save_lams[i]
|
| 242 |
+
save_lams[i] *= self.search_factor
|
| 243 |
+
else:
|
| 244 |
+
lb_lams[i] = save_lams[i]
|
| 245 |
+
save_lams[i] = (ub_lams[i] + save_lams[i]) / 2
|
| 246 |
+
|
| 247 |
+
if self.ver:
|
| 248 |
+
print('')
|
| 249 |
+
|
| 250 |
+
return x_adv
|
| 251 |
+
|
| 252 |
+
def extract_patches(self, x):
|
| 253 |
+
'''
|
| 254 |
+
Extracts and returns all overlapping size by size patches from
|
| 255 |
+
the image batch x.
|
| 256 |
+
'''
|
| 257 |
+
B, C, _, _ = x.shape
|
| 258 |
+
size = 8
|
| 259 |
+
kernel = torch.zeros((size ** 2, size ** 2))
|
| 260 |
+
kernel[range(size**2), range(size**2)] = 1.0
|
| 261 |
+
kernel = kernel.view(size**2, 1, size, size)
|
| 262 |
+
kernel = kernel.repeat(C, 1, 1, 1).to(x.device)
|
| 263 |
+
out = F.conv2d(x, kernel, groups=C)
|
| 264 |
+
out = out.view(B, C, size, size, -1)
|
| 265 |
+
out = out.permute(0, 4, 1, 2, 3)
|
| 266 |
+
return out.contiguous()
|
| 267 |
+
|
| 268 |
+
def l20(self, x):
|
| 269 |
+
'''
|
| 270 |
+
Computes d_{2,0}(x[i]) for all perturbations x[i] in the batch x
|
| 271 |
+
as described in section 3.2.
|
| 272 |
+
'''
|
| 273 |
+
B, N, M, C, _, _ = x.shape
|
| 274 |
+
l20s = []
|
| 275 |
+
|
| 276 |
+
for b in range(B):
|
| 277 |
+
for n in range(N):
|
| 278 |
+
for m in range(M):
|
| 279 |
+
x_ = x[b, n, m] # Select the specific perturbation x[b, n, m]
|
| 280 |
+
patches = self.extract_patches(x_.unsqueeze(0)) # Add unsqueeze to match 6D input
|
| 281 |
+
l2s = torch.norm(patches, p=2, dim=(2,3,4))
|
| 282 |
+
l20s.append((l2s != 0).float().sum().item())
|
| 283 |
+
|
| 284 |
+
return torch.tensor(l20s)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def prox(self, grad_loss_noise, lam):
|
| 288 |
+
'''
|
| 289 |
+
Computes the proximal operator of the 1/2-norm of the gradient of the
|
| 290 |
+
adversarial loss wrt current noise.
|
| 291 |
+
'''
|
| 292 |
+
|
| 293 |
+
lam = lam[:, None, :, :]
|
| 294 |
+
sh = list(grad_loss_noise.shape)
|
| 295 |
+
lam = lam.expand(*sh)
|
| 296 |
+
|
| 297 |
+
p_lam = (54 ** (1 / 3) / 4) * lam ** (2 / 3)
|
| 298 |
+
|
| 299 |
+
mask1 = (grad_loss_noise > p_lam)
|
| 300 |
+
mask2 = (torch.abs(grad_loss_noise) <= p_lam)
|
| 301 |
+
mask3 = (grad_loss_noise < -p_lam)
|
| 302 |
+
mask4 = mask1 + mask3
|
| 303 |
+
|
| 304 |
+
phi_lam_x = torch.arccos((lam / 8) * (torch.abs(grad_loss_noise) / 3)
|
| 305 |
+
** (-1.5))
|
| 306 |
+
|
| 307 |
+
grad_loss_noise[mask4] = ((2 / 3) * torch.abs(grad_loss_noise[mask4])
|
| 308 |
+
* (1 + torch.cos((2 * math.pi) / 3
|
| 309 |
+
- (2 * phi_lam_x[mask4]) / 3))).to(torch.float32)
|
| 310 |
+
grad_loss_noise[mask3] = -grad_loss_noise[mask3]
|
| 311 |
+
grad_loss_noise[mask2] = 0
|
| 312 |
+
|
| 313 |
+
return grad_loss_noise
|
vlm_eval/attacks/iht.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from vlm_eval.attacks.attack import Attack
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
class IHT(Attack):
|
| 8 |
+
|
| 9 |
+
def __init__(self, model, targeted=False, img_range=(0, 1), steps=100, prox='hard',ver=False, lam=5e-5, mask_out='none',stepsize=0.015,eps=4./255.):
|
| 10 |
+
super().__init__(model, targeted=targeted, img_range=img_range)
|
| 11 |
+
self.steps = steps
|
| 12 |
+
self.stepsize = stepsize
|
| 13 |
+
self.ver = ver
|
| 14 |
+
self.lam = lam
|
| 15 |
+
self.eps = eps
|
| 16 |
+
if mask_out != 'none':
|
| 17 |
+
self.mask_out = mask_out
|
| 18 |
+
else:
|
| 19 |
+
self.mask_out = None
|
| 20 |
+
if prox == 'hard':
|
| 21 |
+
self.Prox = self.hardprox
|
| 22 |
+
else:
|
| 23 |
+
raise NotImplementedError
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _set_mask(self, data):
|
| 28 |
+
mask = torch.ones_like(data)
|
| 29 |
+
if self.mask_out == 'context':
|
| 30 |
+
mask[:, :-1, ...] = 0
|
| 31 |
+
elif self.mask_out == 'query':
|
| 32 |
+
mask[:, -1, ...] = 0
|
| 33 |
+
elif isinstance(self.mask_out, int):
|
| 34 |
+
mask[:, self.mask_out, ...] = 0
|
| 35 |
+
elif self.mask_out is None:
|
| 36 |
+
pass
|
| 37 |
+
else:
|
| 38 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 39 |
+
return mask
|
| 40 |
+
|
| 41 |
+
def __call__(self, img):
|
| 42 |
+
|
| 43 |
+
for param in self.model.model.parameters():
|
| 44 |
+
param.requires_grad = False
|
| 45 |
+
|
| 46 |
+
img = img.to(self.device)
|
| 47 |
+
mask_out = self._set_mask(img)
|
| 48 |
+
x = torch.zeros_like(img) # perturbation to optimize
|
| 49 |
+
z = x.clone() # used for FISTA extrapolation
|
| 50 |
+
t = 1
|
| 51 |
+
if self.ver:
|
| 52 |
+
print('')
|
| 53 |
+
|
| 54 |
+
for i in range(self.steps):
|
| 55 |
+
# compue gradient
|
| 56 |
+
x.requires_grad = True
|
| 57 |
+
loss = self.model(img + x).sum() if self.targeted else -self.model(img + x).sum()
|
| 58 |
+
loss.backward()
|
| 59 |
+
x_grad = x.grad.data * mask_out
|
| 60 |
+
x = x.detach()
|
| 61 |
+
|
| 62 |
+
if self.ver and i % 20 == 0:
|
| 63 |
+
print(f'Iteration: {i+1}, Loss: {loss}\n', end='')
|
| 64 |
+
|
| 65 |
+
# FISTA update
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
t_ = .5 * (1 + math.sqrt(1 + 4 * t ** 2))
|
| 68 |
+
alpha = (t - 1) / t_
|
| 69 |
+
t = t_
|
| 70 |
+
z_ = self.Prox(x=x - self.stepsize * x_grad,
|
| 71 |
+
lam=self.lam * self.stepsize,
|
| 72 |
+
img=img,
|
| 73 |
+
eps=self.eps
|
| 74 |
+
)
|
| 75 |
+
x = z_ + alpha * (z_ - z)
|
| 76 |
+
x = torch.clamp(x,-self.eps,self.eps)
|
| 77 |
+
z = z_.clone()
|
| 78 |
+
x = torch.clamp(img + x, *self.img_range) - img
|
| 79 |
+
|
| 80 |
+
if self.ver:
|
| 81 |
+
print('')
|
| 82 |
+
print(f"L0 pert norm: {x.norm(p=0)}")
|
| 83 |
+
|
| 84 |
+
return (img + x * mask_out).detach(), x.norm(p=0).item()
|
| 85 |
+
|
| 86 |
+
def hardprox(self, x, lam, img, eps):
|
| 87 |
+
'''
|
| 88 |
+
Computes the hard thresholding proximal operator of the the
|
| 89 |
+
perturbation x.
|
| 90 |
+
|
| 91 |
+
:x: Perturbation after gradient descent step.
|
| 92 |
+
:lam: Regularization parameter.
|
| 93 |
+
'''
|
| 94 |
+
x_proj = torch.clamp(x,-eps,eps)
|
| 95 |
+
x_temp = torch.clamp(img + x_proj,*self.img_range)
|
| 96 |
+
x_proj = x_temp - img
|
| 97 |
+
return torch.where(x ** 2 - (x_proj - x) ** 2 > 2 * lam, x_proj, 0)
|
vlm_eval/attacks/pgd.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken from https://github.com/chs20/RobustVLM/tree/main
|
| 2 |
+
import torch
|
| 3 |
+
from vlm_eval.attacks.utils import project_perturbation, normalize_grad
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PGD:
|
| 7 |
+
"""
|
| 8 |
+
Minimize or maximize given loss
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, forward, norm, eps, mode='min', mask_out='context', image_space=True):
|
| 12 |
+
self.model = forward
|
| 13 |
+
|
| 14 |
+
self.norm = norm
|
| 15 |
+
self.eps = eps
|
| 16 |
+
self.momentum = 0.9
|
| 17 |
+
|
| 18 |
+
self.mode = mode
|
| 19 |
+
self.mask_out = mask_out
|
| 20 |
+
self.image_space = image_space
|
| 21 |
+
|
| 22 |
+
def perturb(self, data_clean, iterations, stepsize, perturbation=None, verbose=False, return_loss=False):
|
| 23 |
+
if self.image_space:
|
| 24 |
+
# make sure data is in image space
|
| 25 |
+
assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6 # todo
|
| 26 |
+
|
| 27 |
+
if perturbation is None:
|
| 28 |
+
perturbation = torch.zeros_like(data_clean, requires_grad=True)
|
| 29 |
+
mask = self._set_mask(data_clean)
|
| 30 |
+
velocity = torch.zeros_like(data_clean)
|
| 31 |
+
for i in range(iterations):
|
| 32 |
+
perturbation.requires_grad_()
|
| 33 |
+
with torch.enable_grad():
|
| 34 |
+
loss = self.model(data_clean + perturbation)
|
| 35 |
+
# print 10 times in total and last iteration
|
| 36 |
+
if verbose and (i % (iterations // 10 + 1) == 0 or i == iterations - 1):
|
| 37 |
+
print(f'[iteration] {i} [loss] {loss.item()}')
|
| 38 |
+
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
gradient = torch.autograd.grad(loss, perturbation)[0]
|
| 41 |
+
gradient = mask * gradient
|
| 42 |
+
if gradient.isnan().any(): #
|
| 43 |
+
print(f'attention: nan in gradient ({gradient.isnan().sum()})') #
|
| 44 |
+
gradient[gradient.isnan()] = 0.
|
| 45 |
+
# normalize
|
| 46 |
+
gradient = normalize_grad(gradient, p=self.norm)
|
| 47 |
+
# momentum
|
| 48 |
+
velocity = self.momentum * velocity + gradient
|
| 49 |
+
velocity = normalize_grad(velocity, p=self.norm)
|
| 50 |
+
# update
|
| 51 |
+
if self.mode == 'min':
|
| 52 |
+
perturbation = perturbation - stepsize * velocity
|
| 53 |
+
elif self.mode == 'max':
|
| 54 |
+
perturbation = perturbation + stepsize * velocity
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError(f'Unknown mode: {self.mode}')
|
| 57 |
+
# project
|
| 58 |
+
perturbation = project_perturbation(perturbation, self.eps, self.norm)
|
| 59 |
+
if self.image_space:
|
| 60 |
+
perturbation = torch.clamp(
|
| 61 |
+
data_clean + perturbation, 0, 1
|
| 62 |
+
) - data_clean # clamp to image space
|
| 63 |
+
assert torch.max(data_clean + perturbation) < 1. + 1e-6 and torch.min(
|
| 64 |
+
data_clean + perturbation
|
| 65 |
+
) > -1e-6
|
| 66 |
+
assert not perturbation.isnan().any()
|
| 67 |
+
|
| 68 |
+
# assert (ctorch.compute_norm(perturbation, p=self.norm) <= self.eps + 1e-6).all()
|
| 69 |
+
# todo return best perturbation
|
| 70 |
+
# problem is that model currently does not output expanded loss
|
| 71 |
+
if return_loss:
|
| 72 |
+
return data_clean + perturbation.detach(), loss
|
| 73 |
+
else:
|
| 74 |
+
return data_clean + perturbation.detach()
|
| 75 |
+
|
| 76 |
+
def _set_mask(self, data):
|
| 77 |
+
mask = torch.ones_like(data)
|
| 78 |
+
if self.mask_out == 'context':
|
| 79 |
+
mask[:, :-1, ...] = 0
|
| 80 |
+
elif self.mask_out == 'query':
|
| 81 |
+
mask[:, -1, ...] = 0
|
| 82 |
+
elif isinstance(self.mask_out, int):
|
| 83 |
+
mask[:, self.mask_out, ...] = 0
|
| 84 |
+
elif self.mask_out is None:
|
| 85 |
+
pass
|
| 86 |
+
else:
|
| 87 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 88 |
+
return mask
|
vlm_eval/attacks/pgd0.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
|
| 3 |
+
from vlm_eval.attacks.attack import Attack
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
class PGD0(Attack):
|
| 8 |
+
def __init__(self, model, *args, img_range=(0, 1), k=5000, n_restarts=1,
|
| 9 |
+
targeted=False, iters=200, stepsize=120000/255.0, eps=4./255.,ver=False,mask_out='none',**kwargs):
|
| 10 |
+
'''
|
| 11 |
+
Implementation of the PGD0 attack https://arxiv.org/pdf/1909.05040
|
| 12 |
+
Author's implementation: https://github.com/fra31/sparse-imperceivable-attacks/tree/master
|
| 13 |
+
Addapted from: https://github.com/wagnermoritz/GSE/tree/main
|
| 14 |
+
|
| 15 |
+
args:
|
| 16 |
+
model: Callable, PyTorch classifier.
|
| 17 |
+
img_range: Tuple of ints/floats, lower and upper bound of image
|
| 18 |
+
entries.
|
| 19 |
+
targeted: Bool, given label is used as a target label if True.
|
| 20 |
+
k: Int, sparsity parameter.
|
| 21 |
+
n_restarts: Int, number of restarts from random perturbation.
|
| 22 |
+
iters: Int, number of gradient descent steps per restart.
|
| 23 |
+
stepsize: Float, step size for gradient descent.
|
| 24 |
+
'''
|
| 25 |
+
super().__init__(model, img_range=img_range, targeted=targeted)
|
| 26 |
+
self.k = k
|
| 27 |
+
self.n_restarts = n_restarts
|
| 28 |
+
self.eps = eps
|
| 29 |
+
self.iters = iters
|
| 30 |
+
self.stepsize = stepsize
|
| 31 |
+
if mask_out != 'none':
|
| 32 |
+
self.mask_out = mask_out
|
| 33 |
+
else:
|
| 34 |
+
self.mask_out = None
|
| 35 |
+
self.ver = ver
|
| 36 |
+
|
| 37 |
+
def _set_mask(self, data):
|
| 38 |
+
mask = torch.ones_like(data)
|
| 39 |
+
if self.mask_out == 'context':
|
| 40 |
+
mask[:, :-1, ...] = 0
|
| 41 |
+
elif self.mask_out == 'query':
|
| 42 |
+
mask[:, -1, ...] = 0
|
| 43 |
+
elif isinstance(self.mask_out, int):
|
| 44 |
+
mask[:, self.mask_out, ...] = 0
|
| 45 |
+
elif self.mask_out is None:
|
| 46 |
+
pass
|
| 47 |
+
else:
|
| 48 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 49 |
+
return mask
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def __call__(self, x, *args, **kwargs):
|
| 53 |
+
'''
|
| 54 |
+
Perform the PGD_0 attack on a batch of images x.
|
| 55 |
+
|
| 56 |
+
args:
|
| 57 |
+
x: Tensor of shape [B, C, H, W], batch of images.
|
| 58 |
+
y: Tensor of shape [B], batch of labels.
|
| 59 |
+
|
| 60 |
+
Returns a tensor of the same shape as x containing adversarial examples
|
| 61 |
+
'''
|
| 62 |
+
|
| 63 |
+
for param in self.model.model.parameters():
|
| 64 |
+
param.requires_grad = False
|
| 65 |
+
|
| 66 |
+
mask_out = self._set_mask(x)
|
| 67 |
+
x = x.to(self.device)
|
| 68 |
+
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
|
| 69 |
+
|
| 70 |
+
for _ in range(self.n_restarts):
|
| 71 |
+
if not len(x):
|
| 72 |
+
break
|
| 73 |
+
eps = torch.full_like(x, self.eps)
|
| 74 |
+
lb, ub = torch.maximum(-eps, -x),torch.minimum(eps, 1.0 - x) #self.img_range[0] - x, self.img_range[1] - x
|
| 75 |
+
pert = (torch.clamp(x + (ub - lb) * torch.rand_like(x) + lb, *self.img_range) - x).view(B, C, H, W) * mask_out.view(B, C, H, W)
|
| 76 |
+
pert = self.project_L0(pert, lb, ub) # pert is of the shape (B, C, H, W)
|
| 77 |
+
|
| 78 |
+
for _ in range(self.iters):
|
| 79 |
+
pert.requires_grad = True
|
| 80 |
+
loss = self.lossfn(x=x, pert=pert.view(*x.shape), mask_out=mask_out)
|
| 81 |
+
loss.backward()
|
| 82 |
+
|
| 83 |
+
if self.ver and _ % 20 == 0:
|
| 84 |
+
print(f"Loss: {loss}, Iter: {_}")
|
| 85 |
+
|
| 86 |
+
grad = pert.grad.data.view(B,C,H,W) * mask_out.view(B, C, H, W) # shape (B, C, H, W)
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
grad /= grad.abs().sum(dim=(1,2,3), keepdim=True) + 1e-10
|
| 89 |
+
pert += (torch.rand_like(x) - .5).view(B, C, H, W) * 1e-12 - self.stepsize * grad
|
| 90 |
+
pert = self.project_L0(pert, lb, ub)
|
| 91 |
+
|
| 92 |
+
return (x + pert.view(*x.shape) * mask_out).detach()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def project_L0_sigma(self, pert, sigma, kappa, x_orig):
|
| 96 |
+
|
| 97 |
+
B, C, H, W = pert.shape
|
| 98 |
+
x = torch.clone(pert)
|
| 99 |
+
p1 = (1.0 / torch.maximum(1e-12, sigma) * (x_orig > 0).float()) + \
|
| 100 |
+
(1e12 * (x_orig == 0).float())
|
| 101 |
+
p2 = (1.0 / torch.maximum(torch.tensor(1e-12), sigma)) * \
|
| 102 |
+
(1.0 / torch.maximum(torch.tensor(1e-12), x_orig) - 1) * \
|
| 103 |
+
(x_orig > 0).float() + 1e12 * (x_orig == 0).float() + 1e12 * (sigma == 0).float()
|
| 104 |
+
lmbd_l = torch.maximum(-kappa, torch.amax(-p1, dim=1, keepdim=True))
|
| 105 |
+
lmbd_u = torch.minimum(kappa, torch.amin(p2, dim=1, keepdim=True))
|
| 106 |
+
|
| 107 |
+
lmbd_unconstr = torch.sum((pert - x_orig) * sigma * x_orig, dim=1, keepdim=True) / torch.clamp(torch.sum((sigma * x_orig) ** 2, dim=1, keepdim=True), min=1e-12)
|
| 108 |
+
lmbd = torch.maximum(lmbd_l, torch.minimum(lmbd_unconstr, lmbd_u))
|
| 109 |
+
return 0
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def project_L0(self, pert, lb, ub):
|
| 113 |
+
'''
|
| 114 |
+
Project a batch of perturbations such that at most self.k pixels
|
| 115 |
+
are perturbed and componentwise there holds lb <= pert <= ub.
|
| 116 |
+
'''
|
| 117 |
+
|
| 118 |
+
B, C, H, W = pert.shape # Here, pert is of the shape B, C, H, W
|
| 119 |
+
p1 = torch.sum(pert ** 2, dim=1)
|
| 120 |
+
p2 = torch.clamp(torch.minimum(ub.view(B, C, H, W) - pert, pert - lb.view(B, C, H, W)), 0)
|
| 121 |
+
p2 = torch.sum(p2 ** 2, dim=1)
|
| 122 |
+
p3 = torch.topk(-1 * (p1 - p2).view(p1.size(0), -1), k=H*W-self.k, dim=-1)[1]
|
| 123 |
+
pert = torch.maximum(torch.minimum(pert, ub.view(B, C, H, W)), lb.view(B, C, H, W))
|
| 124 |
+
pert[torch.arange(0, B).view(-1, 1), :, p3//W, p3%H] = 0
|
| 125 |
+
return pert
|
| 126 |
+
|
| 127 |
+
def lossfn(self, x, pert, mask_out):
|
| 128 |
+
'''
|
| 129 |
+
Compute the loss at x.
|
| 130 |
+
'''
|
| 131 |
+
return (2 * self.targeted - 1) * self.model(x + pert * mask_out).sum()
|
vlm_eval/attacks/saif.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
|
| 3 |
+
from vlm_eval.attacks.attack import Attack
|
| 4 |
+
import torch
|
| 5 |
+
import math
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
class SAIF(Attack):
|
| 9 |
+
def __init__(self, model, *args, targeted=False, img_range=(-1, 1), steps=200,
|
| 10 |
+
r0=1, ver=False, k=10000, eps=16./255., mask_out='none', **kwargs):
|
| 11 |
+
'''
|
| 12 |
+
Adapted from: https://github.com/wagnermoritz/GSE/tree/main
|
| 13 |
+
Implementation of the sparse Frank-Wolfe attack SAIF
|
| 14 |
+
https://arxiv.org/pdf/2212.07495.pdf
|
| 15 |
+
|
| 16 |
+
args:
|
| 17 |
+
model: Callable, PyTorch classifier.
|
| 18 |
+
img_range: Tuple of ints/floats, lower and upper bound of image
|
| 19 |
+
entries.
|
| 20 |
+
targeted: Bool, given label is used as a target label if True.
|
| 21 |
+
steps: Int, number of FW iterations.
|
| 22 |
+
r0: Int, parameter for step size computation.
|
| 23 |
+
ver: Bool, print progress if True.
|
| 24 |
+
'''
|
| 25 |
+
super().__init__(model, targeted=targeted, img_range=img_range)
|
| 26 |
+
self.steps = steps
|
| 27 |
+
self.r0 = r0
|
| 28 |
+
self.loss_fn = torch.nn.CrossEntropyLoss()
|
| 29 |
+
self.ver = ver
|
| 30 |
+
self.k = k
|
| 31 |
+
self.eps = eps
|
| 32 |
+
if mask_out != 'none':
|
| 33 |
+
self.mask_out = mask_out
|
| 34 |
+
else:
|
| 35 |
+
self.mask_out = None
|
| 36 |
+
|
| 37 |
+
def _set_mask(self, data):
|
| 38 |
+
mask = torch.ones_like(data)
|
| 39 |
+
if self.mask_out == 'context':
|
| 40 |
+
mask[:, :-1, ...] = 0
|
| 41 |
+
elif self.mask_out == 'query':
|
| 42 |
+
mask[:, -1, ...] = 0
|
| 43 |
+
elif isinstance(self.mask_out, int):
|
| 44 |
+
mask[:, self.mask_out, ...] = 0
|
| 45 |
+
elif self.mask_out is None:
|
| 46 |
+
pass
|
| 47 |
+
else:
|
| 48 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 49 |
+
return mask
|
| 50 |
+
|
| 51 |
+
def __call__(self, x):
|
| 52 |
+
'''
|
| 53 |
+
Perform the attack on a batch of images x.
|
| 54 |
+
|
| 55 |
+
args:
|
| 56 |
+
x: Tensor of shape [B, C, H, W], batch of images.
|
| 57 |
+
k: Int, sparsity parameter,
|
| 58 |
+
eps: Float, perturbation magnitude parameter.
|
| 59 |
+
|
| 60 |
+
Returns a tensor of the same shape as x containing adversarial examples.
|
| 61 |
+
'''
|
| 62 |
+
assert x.shape[0] == 1, "Only support batch size 1 for now"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
for param in self.model.model.parameters():
|
| 67 |
+
param.requires_grad = False
|
| 68 |
+
|
| 69 |
+
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
|
| 70 |
+
x = x.to(self.device)
|
| 71 |
+
batchidx = torch.arange(B).view(-1, 1)
|
| 72 |
+
|
| 73 |
+
mask_out = self._set_mask(x)
|
| 74 |
+
# compute p_0 and s_0
|
| 75 |
+
x_ = x.clone()
|
| 76 |
+
x_.requires_grad = True
|
| 77 |
+
out = self.model(x_)
|
| 78 |
+
loss = -out.sum() if not self.targeted else out.sum()
|
| 79 |
+
x__grad = torch.autograd.grad(loss, [x_])[0].detach() * mask_out
|
| 80 |
+
p = -self.eps * x__grad.sign()
|
| 81 |
+
p = p.detach().half()
|
| 82 |
+
ksmallest = torch.topk(-x__grad.view(B, -1), self.k, dim=1)[1]
|
| 83 |
+
ksmask = torch.zeros((B, C * H * W), device=self.device)
|
| 84 |
+
ksmask[batchidx, ksmallest] = 1
|
| 85 |
+
s = torch.logical_and(ksmask.view(*x.shape), x__grad < 0).float()
|
| 86 |
+
s = s.detach().half()
|
| 87 |
+
|
| 88 |
+
r = self.r0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
for t in range(self.steps):
|
| 92 |
+
if self.ver:
|
| 93 |
+
print(f'\r Iteration {t+1}/{self.steps}', end='')
|
| 94 |
+
p.requires_grad = True
|
| 95 |
+
s.requires_grad = True
|
| 96 |
+
|
| 97 |
+
D = self.Loss_fn(x, s, p, mask_out)
|
| 98 |
+
D.backward()
|
| 99 |
+
|
| 100 |
+
mp = p.grad * mask_out
|
| 101 |
+
ms = s.grad * mask_out
|
| 102 |
+
with torch.no_grad():
|
| 103 |
+
# inf-norm LMO
|
| 104 |
+
v = (-self.eps * mp.sign()).half()
|
| 105 |
+
|
| 106 |
+
# 1-norm LMO
|
| 107 |
+
ksmallest = torch.topk(-ms.view(B, -1), self.k, dim=1)[1]
|
| 108 |
+
ksmask = torch.zeros((B, C * H * W), device=self.device)
|
| 109 |
+
ksmask[batchidx, ksmallest] = 1
|
| 110 |
+
ksmask = ksmask.view(*x.shape) * mask_out
|
| 111 |
+
z = torch.logical_and(ksmask, ms < 0).float().half()
|
| 112 |
+
# update stepsize until primal progress is made
|
| 113 |
+
mu = 1 / (2 ** r * math.sqrt(t + 1))
|
| 114 |
+
progress_condition = (self.Loss_fn(x, s + mu * (z - s), p + mu * (v - p), mask_out)
|
| 115 |
+
> D)
|
| 116 |
+
|
| 117 |
+
while progress_condition:
|
| 118 |
+
r += 1
|
| 119 |
+
if r >= 50:
|
| 120 |
+
break
|
| 121 |
+
mu = 1 / (2 ** r * math.sqrt(t + 1))
|
| 122 |
+
progress_condition = (self.Loss_fn(x, s + mu * (z - s), p + mu * (v - p), mask_out)
|
| 123 |
+
> D)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
p = p + mu * (v - p)
|
| 127 |
+
s = s + mu * (z - s)
|
| 128 |
+
|
| 129 |
+
x_adv = torch.clamp(x + p, *self.img_range)
|
| 130 |
+
p = x_adv - x
|
| 131 |
+
|
| 132 |
+
if self.ver and t % 10 == 0:
|
| 133 |
+
print(f" Loss: {D}")
|
| 134 |
+
if self.ver:
|
| 135 |
+
print('')
|
| 136 |
+
return (x + s * p * mask_out).detach(), torch.norm(s*p,p=0).item()
|
| 137 |
+
|
| 138 |
+
def Loss_fn(self, x, s, p, mask_out):
|
| 139 |
+
out = self.model(x + s * p * mask_out).sum()
|
| 140 |
+
if self.targeted:
|
| 141 |
+
return out
|
| 142 |
+
else:
|
| 143 |
+
return -out
|
vlm_eval/attacks/sparsers.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
from vlm_eval.attacks.attack import Attack
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class SparseRS(Attack):
|
| 6 |
+
def __init__(self, model, *args, targeted=False, img_range=(-1, 1),
|
| 7 |
+
n_queries=10000, k=100, n_restarts=10, alpha_init=0.8, mask_out='none',**kwargs):
|
| 8 |
+
'''
|
| 9 |
+
Implementation of the L0 variant SparseRS https://arxiv.org/abs/2006.12834
|
| 10 |
+
Authors' implementation: https://github.com/fra31/sparse-rs
|
| 11 |
+
Adapted from: https://github.com/wagnermoritz/GSE/tree/main
|
| 12 |
+
|
| 13 |
+
args:
|
| 14 |
+
model: Callable, PyTorch classifier.
|
| 15 |
+
targeted: Bool, given label is used as a target label if True.
|
| 16 |
+
img_range: Tuple of ints/floats, lower and upper bound of image
|
| 17 |
+
entries.
|
| 18 |
+
n_queries: Int, max number of queries to the model
|
| 19 |
+
k: Int, initial sparsity parameter
|
| 20 |
+
n_restarts: Int, number of restarts with random initialization
|
| 21 |
+
alpha_init: Float, inital value for alpha schedule
|
| 22 |
+
'''
|
| 23 |
+
super().__init__(model, targeted=targeted, img_range=img_range)
|
| 24 |
+
self.n_queries = n_queries
|
| 25 |
+
self.k = k
|
| 26 |
+
self.n_restarts = n_restarts
|
| 27 |
+
self.alpha_init = alpha_init
|
| 28 |
+
if mask_out != 'none':
|
| 29 |
+
self.mask_out = mask_out
|
| 30 |
+
else:
|
| 31 |
+
self.mask_out = None
|
| 32 |
+
|
| 33 |
+
def _set_mask(self, data):
|
| 34 |
+
mask = torch.ones_like(data)
|
| 35 |
+
if self.mask_out == 'context':
|
| 36 |
+
mask[:, :-1, ...] = 0
|
| 37 |
+
elif self.mask_out == 'query':
|
| 38 |
+
mask[:, -1, ...] = 0
|
| 39 |
+
elif isinstance(self.mask_out, int):
|
| 40 |
+
mask[:, self.mask_out, ...] = 0
|
| 41 |
+
elif self.mask_out is None:
|
| 42 |
+
pass
|
| 43 |
+
else:
|
| 44 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 45 |
+
return mask
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def __call__(self, x, *args, **kwargs):
|
| 49 |
+
'''
|
| 50 |
+
Perform SparseRS L0 on a batch of images x with corresponding labels y.
|
| 51 |
+
|
| 52 |
+
args:
|
| 53 |
+
x: Tensor of shape [B, C, H, W], batch of images.
|
| 54 |
+
y: Tensor of shape [B], batch of labels.
|
| 55 |
+
|
| 56 |
+
Returns a tensor of the same shape as x containing adversarial examples
|
| 57 |
+
'''
|
| 58 |
+
|
| 59 |
+
for param in self.model.model.parameters():
|
| 60 |
+
param.requires_grad = False
|
| 61 |
+
|
| 62 |
+
torch.random.manual_seed(0)
|
| 63 |
+
torch.cuda.random.manual_seed(0)
|
| 64 |
+
x = x.to(self.device)
|
| 65 |
+
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
for _ in range(self.n_restarts):
|
| 68 |
+
if len(x) == 0:
|
| 69 |
+
break
|
| 70 |
+
|
| 71 |
+
x_adv = self.__perturb(x.clone())
|
| 72 |
+
|
| 73 |
+
return x_adv.detach()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def __perturb(self, x):
|
| 77 |
+
'''
|
| 78 |
+
Perform the attack from a random starting point.
|
| 79 |
+
'''
|
| 80 |
+
mask_out = self._set_mask(x)
|
| 81 |
+
B, C, H, W = x.shape[1], x.shape[3], x.shape[4], x.shape[5]
|
| 82 |
+
batchidx = torch.arange(B, device=self.device).view(-1, 1)
|
| 83 |
+
result = x.clone().view(B, C, H, W)
|
| 84 |
+
|
| 85 |
+
# M: set of perturbed pixel indices, U_M: set of unperturbed pixel indices
|
| 86 |
+
batch_randperm = torch.rand(B, H * W, device=self.device).argsort(dim=1)
|
| 87 |
+
M = batch_randperm[:, :self.k]
|
| 88 |
+
U_M = batch_randperm[:, self.k:]
|
| 89 |
+
result[batchidx, :, M//W, M%H] = self.__sampleDelta(B, C, self.k)
|
| 90 |
+
|
| 91 |
+
best_loss = self.__lossfn(result.view(*x.shape))
|
| 92 |
+
|
| 93 |
+
for i in range(1, self.n_queries):
|
| 94 |
+
if B == 0:
|
| 95 |
+
break
|
| 96 |
+
# reset k_i currently perturbed pixels and perturb k_i new pixels
|
| 97 |
+
k_i = max(int(self.__alphaSchedule(i) * self.k), 1)
|
| 98 |
+
A_idx = torch.randperm(self.k, device=self.device)[:k_i]
|
| 99 |
+
B_idx = torch.randperm(H * W - self.k, device=self.device)[:k_i]
|
| 100 |
+
A_set, B_set = M[:, A_idx], U_M[:, B_idx]
|
| 101 |
+
|
| 102 |
+
z = result.clone()
|
| 103 |
+
z[batchidx, :, A_set//W, A_set%H] = x.view(B, C, H, W)[batchidx, :, A_set//W, A_set%H]
|
| 104 |
+
if k_i > 1:
|
| 105 |
+
z[batchidx, :, B_set//W, B_set%H] = self.__sampleDelta(B, C, k_i)
|
| 106 |
+
else: # if only one pixel is changed, make sure it actually changes
|
| 107 |
+
new_color = self.__sampleDelta(B, C, k_i)
|
| 108 |
+
while (mask := (z[batchidx, :, B_set//W, B_set%H] == new_color).view(B, -1).all(dim=-1)).any():
|
| 109 |
+
new_color[mask] = self.__sampleDelta(mask.int().sum().item(), C, k_i)
|
| 110 |
+
z[batchidx, :, B_set//W, B_set%H] = new_color
|
| 111 |
+
|
| 112 |
+
# save perturbations that improved the loss/margin
|
| 113 |
+
loss = self.__lossfn(z, y)
|
| 114 |
+
mask = loss < best_loss
|
| 115 |
+
best_loss[mask] = loss[mask]
|
| 116 |
+
mask = torch.logical_or(mask, margin < -1e-6)
|
| 117 |
+
if mask.any():
|
| 118 |
+
#best_margin[mask] = margin[mask]
|
| 119 |
+
tmp = result[active]
|
| 120 |
+
tmp[mask] = z[mask]
|
| 121 |
+
result[active] = tmp
|
| 122 |
+
U_M[mask.nonzero().view(-1, 1), B_idx] = A_set[mask]
|
| 123 |
+
M[mask.nonzero().view(-1, 1), A_idx] = B_set[mask]
|
| 124 |
+
|
| 125 |
+
# stop working on successful adv examples
|
| 126 |
+
mask = best_margin < 0
|
| 127 |
+
if mask.any():
|
| 128 |
+
mask = torch.logical_not(mask)
|
| 129 |
+
active[active.clone()] = mask
|
| 130 |
+
x, y, z, M, U_M = x[mask], y[mask], z[mask], M[mask], U_M[mask]
|
| 131 |
+
best_margin, best_loss = best_margin[mask], best_loss[mask]
|
| 132 |
+
B = len(y)
|
| 133 |
+
batchidx = torch.arange(B, device=self.device).view(-1, 1)
|
| 134 |
+
|
| 135 |
+
return result
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def __sampleDelta(self, B, C, k):
|
| 139 |
+
'''
|
| 140 |
+
Sample k-pixel perturbations for B images. Each pixel is assigned a
|
| 141 |
+
random corner in the C-dimensional cube defined by self.img_range.
|
| 142 |
+
'''
|
| 143 |
+
fac = self.img_range[1] - self.img_range[0]
|
| 144 |
+
return self.img_range[0] + fac * torch.randint(0, 1, [B, k, C],
|
| 145 |
+
dtype=torch.float,
|
| 146 |
+
device=self.device)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def __alphaSchedule(self, iteration):
|
| 150 |
+
'''
|
| 151 |
+
Update number of pixels to perturb based in the current iteration.
|
| 152 |
+
'''
|
| 153 |
+
iteration = int(iteration / self.n_queries * 10000)
|
| 154 |
+
factors = [1, 2, 4, 5, 6, 8, 10, 12, 15, 20]
|
| 155 |
+
alpha_schedule = [10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000]
|
| 156 |
+
idx = bisect.bisect_left(alpha_schedule, iteration)
|
| 157 |
+
return self.alpha_init / factors[idx]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def __lossfn(self, x):
|
| 161 |
+
'''
|
| 162 |
+
Compute the loss depending on self.targeted.
|
| 163 |
+
'''
|
| 164 |
+
return self.model(x).sum() if self.targeted else -self.model(x).sum()
|
vlm_eval/attacks/strattack.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/wagnermoritz/GSE
|
| 2 |
+
|
| 3 |
+
from vlm_eval.attacks.attack import Attack
|
| 4 |
+
import torch
|
| 5 |
+
import math
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
class StrAttack(Attack):
|
| 9 |
+
def __init__(self, model, *args, targeted=False, img_range=(0, 1), kappa=0,
|
| 10 |
+
max_iter=100, ver=False, search_steps=2, max_c=1e10, rho=1, mask_out='none',
|
| 11 |
+
c=2.5, retrain=False, **kwargs):
|
| 12 |
+
'''
|
| 13 |
+
Implementation of StrAttack: https://arxiv.org/abs/1808.01664
|
| 14 |
+
Adapted from https://github.com/KaidiXu/StrAttack
|
| 15 |
+
|
| 16 |
+
args:
|
| 17 |
+
model: Callable, PyTorch classifier.
|
| 18 |
+
targeted: Bool, given label is used as a target label if True.
|
| 19 |
+
img_range: Tuple of ints/floats, lower and upper bound of image
|
| 20 |
+
entries.
|
| 21 |
+
max_iter: Int, number of iterations.
|
| 22 |
+
ver: Bool, print progress if True.
|
| 23 |
+
search_steps: Int, number of binary search steps.
|
| 24 |
+
max_c: Float, upper bound for regularizaion parameter.
|
| 25 |
+
rho: Float, ADMM parameter.
|
| 26 |
+
c: Float, initial regularization parameter.
|
| 27 |
+
'''
|
| 28 |
+
super().__init__(model, targeted=targeted, img_range=img_range)
|
| 29 |
+
self.max_iter = max_iter
|
| 30 |
+
self.ver = ver
|
| 31 |
+
self.search_steps = search_steps
|
| 32 |
+
self.max_c = max_c
|
| 33 |
+
self.rho = rho
|
| 34 |
+
self.c = c
|
| 35 |
+
self.retrain = retrain
|
| 36 |
+
if mask_out != 'none':
|
| 37 |
+
self.mask_out = mask_out
|
| 38 |
+
else:
|
| 39 |
+
self.mask_out = None
|
| 40 |
+
|
| 41 |
+
def _set_mask(self, data):
|
| 42 |
+
mask = torch.ones_like(data)
|
| 43 |
+
if self.mask_out == 'context':
|
| 44 |
+
mask[:, :-1, ...] = 0
|
| 45 |
+
elif self.mask_out == 'query':
|
| 46 |
+
mask[:, -1, ...] = 0
|
| 47 |
+
elif isinstance(self.mask_out, int):
|
| 48 |
+
mask[:, self.mask_out, ...] = 0
|
| 49 |
+
elif self.mask_out is None:
|
| 50 |
+
pass
|
| 51 |
+
else:
|
| 52 |
+
raise NotImplementedError(f'Unknown mask_out: {self.mask_out}')
|
| 53 |
+
return mask
|
| 54 |
+
|
| 55 |
+
def __call__(self, imgs, *args, **kwargs):
|
| 56 |
+
'''
|
| 57 |
+
Perform StrAttack on a batch of images x with corresponding labels y.
|
| 58 |
+
|
| 59 |
+
args:
|
| 60 |
+
x: Tensor of shape [B, C, H, W], batch of images.
|
| 61 |
+
|
| 62 |
+
Returns a tensor of the same shape as x containing adversarial examples
|
| 63 |
+
'''
|
| 64 |
+
|
| 65 |
+
for param in self.model.model.parameters():
|
| 66 |
+
param.requires_grad = False
|
| 67 |
+
|
| 68 |
+
c_ = self.c
|
| 69 |
+
imgs = imgs.to(self.device)
|
| 70 |
+
sh = imgs.shape
|
| 71 |
+
batch_size = sh[1]
|
| 72 |
+
mask_out = self._set_mask(imgs)
|
| 73 |
+
|
| 74 |
+
alpha, tau, gamma = 5, 2, 1
|
| 75 |
+
eps = torch.full_like(imgs, 1.0) * mask_out
|
| 76 |
+
# 16 for imagenet, 2 for CIFAR and MNIST
|
| 77 |
+
filterSize = 8 if sh[-1] > 32 else 2
|
| 78 |
+
stride = filterSize
|
| 79 |
+
# convolution kernel used to compute norm of each group
|
| 80 |
+
slidingM = torch.ones((1, sh[3], filterSize, filterSize), device=self.device)
|
| 81 |
+
|
| 82 |
+
cs = torch.ones(batch_size, device=self.device) * c_
|
| 83 |
+
lower_bound = torch.zeros(batch_size)
|
| 84 |
+
upper_bound = torch.ones(batch_size) * self.max_c
|
| 85 |
+
|
| 86 |
+
o_bestl2 = torch.full_like(torch.randn(batch_size), 1e10, dtype=torch.float)
|
| 87 |
+
o_bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)
|
| 88 |
+
o_bestattack = imgs.clone()
|
| 89 |
+
o_besty = torch.ones_like(imgs)
|
| 90 |
+
|
| 91 |
+
for step in range(self.search_steps):
|
| 92 |
+
|
| 93 |
+
bestl2 = torch.full_like(o_bestl2, 1e10, dtype=torch.float)
|
| 94 |
+
bestscore = torch.full_like(o_bestl2, -1, dtype=torch.float)
|
| 95 |
+
|
| 96 |
+
z, v, u, s = (torch.zeros_like(imgs) for _ in range(4))
|
| 97 |
+
|
| 98 |
+
for iter_ in range(self.max_iter):
|
| 99 |
+
if (not iter_%10 or iter_ == self.max_iter - 1) and self.ver:
|
| 100 |
+
print(f'\rIteration: {iter_+1}/{self.max_iter}, ' +
|
| 101 |
+
f'Search Step: {step+1}/{self.search_steps}', end='')
|
| 102 |
+
|
| 103 |
+
# first update step (7) / Proposition 1
|
| 104 |
+
delta = self.rho / (self.rho + 2 * gamma) * (z - u / self.rho)
|
| 105 |
+
|
| 106 |
+
b = (z - s / self.rho) * mask_out
|
| 107 |
+
tmp = torch.minimum(self.img_range[1] - imgs, eps)
|
| 108 |
+
w = torch.where(b.view(*sh) > tmp.view(*sh), tmp, b) # creating issue (1x5x'5'x3x224x224 instead of 1x5x1x3x224x224)
|
| 109 |
+
tmp = torch.maximum(self.img_range[0] - imgs, -eps)
|
| 110 |
+
w = torch.where(b.view(*sh) < tmp.view(*sh), tmp, w)
|
| 111 |
+
|
| 112 |
+
c = z - v / self.rho
|
| 113 |
+
cNorm = torch.sqrt(F.conv2d(c.view(sh[1], sh[3], sh[4], sh[5]) ** 2, slidingM, stride=stride))
|
| 114 |
+
cNorm = torch.where(cNorm == 0, torch.full_like(cNorm, 1e-12), cNorm)
|
| 115 |
+
cNorm = F.interpolate(cNorm, scale_factor=filterSize)
|
| 116 |
+
y = torch.clamp((1 - tau / (self.rho * cNorm.unsqueeze(0).unsqueeze(3))), 0) * c
|
| 117 |
+
|
| 118 |
+
# second update step (8) / equation (15)
|
| 119 |
+
z_grads = self.__get_z_grad(imgs, z.clone(), cs)
|
| 120 |
+
eta = alpha * math.sqrt(iter_ + 1)
|
| 121 |
+
coeff = (1 / (eta + 3 * self.rho))
|
| 122 |
+
z = coeff * (eta * z + self.rho * (delta + w + y) + u + s + v - z_grads)
|
| 123 |
+
|
| 124 |
+
# third update step (9)
|
| 125 |
+
u = u + self.rho * (delta - z) * mask_out
|
| 126 |
+
v = v + self.rho * (y - z) * mask_out
|
| 127 |
+
s = s + self.rho * (w - z) * mask_out
|
| 128 |
+
# get info for binary search
|
| 129 |
+
x = imgs + y * mask_out
|
| 130 |
+
l2s = torch.sum((z ** 2).reshape(z.size(1), -1), dim=-1)
|
| 131 |
+
|
| 132 |
+
for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
|
| 133 |
+
if l2 < bestl2[i]:
|
| 134 |
+
bestl2[i] = l2
|
| 135 |
+
if l2 < o_bestl2[i]:
|
| 136 |
+
o_bestl2[i] = l2
|
| 137 |
+
o_bestattack[:,i] = x_.detach().unsqueeze(0).clone()
|
| 138 |
+
o_besty[:,i] = y[:,i]
|
| 139 |
+
for i in range(batch_size):
|
| 140 |
+
|
| 141 |
+
lower_bound[i] = max(lower_bound[i], cs[i])
|
| 142 |
+
if upper_bound[i] < 1e9:
|
| 143 |
+
cs[i] = (lower_bound[i] + upper_bound[i]) / 2
|
| 144 |
+
else:
|
| 145 |
+
cs[i] *= 5
|
| 146 |
+
|
| 147 |
+
del v, u, s, z_grads, w, tmp
|
| 148 |
+
|
| 149 |
+
if self.retrain:
|
| 150 |
+
cs = torch.full_like(o_bestl2, 5.0, dtype=torch.float)
|
| 151 |
+
zeros = torch.zeros_like(imgs)
|
| 152 |
+
|
| 153 |
+
for step in range(8):
|
| 154 |
+
bestl2 = torch.full_like(cs, 1e10, dtype=torch.float, device=self.device)
|
| 155 |
+
bestscore = torch.full_like(cs, -1, dtype=torch.float, device=self.device)
|
| 156 |
+
|
| 157 |
+
Nz = o_besty[o_besty != 0]
|
| 158 |
+
e0 = torch.quantile(Nz.abs(), 0.03)
|
| 159 |
+
A2 = torch.where(o_besty.abs() <= e0, 0, 1)
|
| 160 |
+
z1 = o_besty
|
| 161 |
+
u1 = torch.zeros_like(imgs)
|
| 162 |
+
tmpc = self.rho / (self.rho + gamma / 100)
|
| 163 |
+
|
| 164 |
+
for j in range(100):
|
| 165 |
+
if self.ver and not j % 10:
|
| 166 |
+
print(f'\rRetrain iteration: {step+1}/8, ' +
|
| 167 |
+
f'Search Step: {j+1}/200', end='')
|
| 168 |
+
|
| 169 |
+
tmpA = (z1 - u1) * tmpc
|
| 170 |
+
tmpA1 = torch.where(o_besty.abs() <= e0, zeros, tmpA)
|
| 171 |
+
cond = torch.logical_and(tmpA >
|
| 172 |
+
torch.minimum(self.img_range[1] - imgs, eps),
|
| 173 |
+
o_besty.abs() > e0)
|
| 174 |
+
tmpA2 = torch.where(cond, torch.minimum(self.img_range[1] - imgs, eps),
|
| 175 |
+
tmpA1)
|
| 176 |
+
cond = torch.logical_and(tmpA <
|
| 177 |
+
torch.maximum(self.img_range[0] - imgs, -eps),
|
| 178 |
+
o_besty.abs() > e0)
|
| 179 |
+
deltA = torch.where(cond, torch.maximum(self.img_range[0] - imgs, -eps),
|
| 180 |
+
tmpA2)
|
| 181 |
+
|
| 182 |
+
x = imgs + deltA * mask_out
|
| 183 |
+
grad = self.__get_z_grad(imgs, deltA, cs)
|
| 184 |
+
|
| 185 |
+
stepsize = 1 / (alpha + 2 * self.rho)
|
| 186 |
+
z1 = stepsize * (alpha * z1 * self.rho
|
| 187 |
+
* (deltA + u1) - grad * A2)
|
| 188 |
+
u1 = u1 + deltA - z1
|
| 189 |
+
|
| 190 |
+
for i, (l2, x_) in enumerate(zip(l2s, x.squeeze(0))):
|
| 191 |
+
if l2 < bestl2[i]:
|
| 192 |
+
bestl2[i] = l2
|
| 193 |
+
#bestscore[i] = asc
|
| 194 |
+
if l2 < o_bestl2[i]:
|
| 195 |
+
o_bestl2[i] = l2
|
| 196 |
+
#o_bestscore[i] = asc
|
| 197 |
+
o_bestattack[:,i] = x_.detach().unsqueeze(0).clone()
|
| 198 |
+
o_besty[i] = deltA[i]
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
for i in range(batch_size):
|
| 202 |
+
if (bestscore[i] != -1 and bestl2[i] == o_bestl2[i]):
|
| 203 |
+
upper_bound[i] = min(upper_bound[i], cs[i])
|
| 204 |
+
if upper_bound[i] < 1e9:
|
| 205 |
+
cs[i] = (lower_bound[i] + upper_bound[i]) / 2
|
| 206 |
+
|
| 207 |
+
else:
|
| 208 |
+
lower_bound[i] = max(lower_bound[i], cs[i])
|
| 209 |
+
if upper_bound[i] < 1e9:
|
| 210 |
+
cs[i] = (lower_bound[i] + upper_bound[i]) / 2
|
| 211 |
+
else:
|
| 212 |
+
cs[i] *= 5
|
| 213 |
+
|
| 214 |
+
if self.ver:
|
| 215 |
+
print('')
|
| 216 |
+
|
| 217 |
+
return (o_bestattack * mask_out).detach()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def __get_z_grad(self, imgs, z, cs):
|
| 221 |
+
'''
|
| 222 |
+
Compute and return gradient of loss wrt. z.
|
| 223 |
+
'''
|
| 224 |
+
z.requires_grad = True
|
| 225 |
+
tmp = self.model(z + imgs).sum() if self.targeted else -self.model(z + imgs).sum()
|
| 226 |
+
loss = torch.mean(cs.to(self.device) * tmp)
|
| 227 |
+
z_grad_data = torch.autograd.grad(loss, [z])[0].detach()
|
| 228 |
+
z.detach_()
|
| 229 |
+
return z_grad_data
|
vlm_eval/attacks/utils.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import collections.abc as container_abcs
|
| 4 |
+
|
| 5 |
+
# Code taken from https://github.com/chs20/RobustVLM/tree/main
|
| 6 |
+
# some parts of this code are adapted from
|
| 7 |
+
# https://github.com/M4xim4l/InNOutRobustness/blob/main/utils/adversarial_attacks/utils.py
|
| 8 |
+
|
| 9 |
+
def project_perturbation(perturbation, eps, norm):
|
| 10 |
+
if norm in ['inf', 'linf', 'Linf']:
|
| 11 |
+
pert_normalized = torch.clamp(perturbation, -eps, eps)
|
| 12 |
+
return pert_normalized
|
| 13 |
+
elif norm in [2, 2.0, 'l2', 'L2', '2']:
|
| 14 |
+
pert_normalized = torch.renorm(perturbation, p=2, dim=0, maxnorm=eps)
|
| 15 |
+
return pert_normalized
|
| 16 |
+
else:
|
| 17 |
+
raise NotImplementedError(f'Norm {norm} not supported')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def normalize_grad(grad, p):
|
| 21 |
+
if p in ['inf', 'linf', 'Linf']:
|
| 22 |
+
return grad.sign()
|
| 23 |
+
elif p in [2, 2.0, 'l2', 'L2', '2']:
|
| 24 |
+
bs = grad.shape[0]
|
| 25 |
+
grad_flat = grad.view(bs, -1)
|
| 26 |
+
grad_normalized = F.normalize(grad_flat, p=2, dim=1)
|
| 27 |
+
return grad_normalized.view_as(grad)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def L1_norm(x, keepdim=False):
|
| 31 |
+
z = x.abs().view(x.shape[0], -1).sum(-1)
|
| 32 |
+
if keepdim:
|
| 33 |
+
z = z.view(-1, *[1]*(len(x.shape) - 1))
|
| 34 |
+
return z
|
| 35 |
+
|
| 36 |
+
def L2_norm(x, keepdim=False):
|
| 37 |
+
z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
|
| 38 |
+
if keepdim:
|
| 39 |
+
z = z.view(-1, *[1]*(len(x.shape) - 1))
|
| 40 |
+
return z
|
| 41 |
+
|
| 42 |
+
def L0_norm(x):
|
| 43 |
+
return (x != 0.).view(x.shape[0], -1).sum(-1)
|
| 44 |
+
|
| 45 |
+
def zero_gradients(x):
|
| 46 |
+
if isinstance(x, torch.Tensor):
|
| 47 |
+
if x.grad is not None:
|
| 48 |
+
x.grad.detach_()
|
| 49 |
+
x.grad.zero_()
|
| 50 |
+
elif isinstance(x, container_abcs.Iterable):
|
| 51 |
+
for elem in x:
|
| 52 |
+
zero_gradients(elem)
|
vlm_eval/clip_classification.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code adapted from https://github.com/openai/CLIP/blob/main/
|
| 2 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 3 |
+
import argparse
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from datasets_classes_templates import data_seeds
|
| 8 |
+
import numpy as np
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
def zeroshot_classifier(classnames, templates, processor, model):
|
| 12 |
+
with torch.no_grad():
|
| 13 |
+
zeroshot_weights = []
|
| 14 |
+
for classname in tqdm(classnames):
|
| 15 |
+
texts = [template.format(classname) for template in templates] #format with class
|
| 16 |
+
text_inputs = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to('cuda')
|
| 17 |
+
class_embeddings = model.get_text_features(text_inputs['input_ids']) #embed with text encoder
|
| 18 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
| 19 |
+
class_embedding = class_embeddings.mean(dim=0)
|
| 20 |
+
class_embedding /= class_embedding.norm()
|
| 21 |
+
zeroshot_weights.append(class_embedding)
|
| 22 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
|
| 23 |
+
return zeroshot_weights
|
| 24 |
+
|
| 25 |
+
def classification_collate_fn(batch):
|
| 26 |
+
images, labels = zip(*batch)
|
| 27 |
+
labels = torch.tensor(labels)
|
| 28 |
+
return images, labels
|
| 29 |
+
|
| 30 |
+
def accuracy(output, target, topk=(1,)):
|
| 31 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
| 32 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 33 |
+
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main():
|
| 37 |
+
|
| 38 |
+
parser = argparse.ArgumentParser()
|
| 39 |
+
parser.add_argument("--data", type=str, default=None, choices=['non_fine_tuned','MS_COCO','medium','base','all'], help='Data on which clip was fine-tuned')
|
| 40 |
+
parser.add_argument("--dataset", type=str, default="CIFAR10", choices=["CIFAR10", "CIFAR100", "ImageNet", "Caltech101", "Caltech256", "Food101"])
|
| 41 |
+
parser.add_argument("--method",type=str, default="COCO_CF", choices=['COCO_CF','APGD_1','APGD_4','NONE'])
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 45 |
+
results_filename = f'./Results/fine_tuned_clip/zeroshot_image_classification_results_{args.dataset}_{args.data}_{args.method}_{current_time}.txt'
|
| 46 |
+
with open(results_filename, 'w') as f:
|
| 47 |
+
f.write(f'Arguments: {args}\n\n')
|
| 48 |
+
|
| 49 |
+
if args.data == 'MS_COCO':
|
| 50 |
+
assert args.method == 'NONE' and args.data == 'MS_COCO', 'Use NONE for method for MS_COCO data'
|
| 51 |
+
|
| 52 |
+
imagenet_path = '/software/ais2t/pytorch_datasets/imagenet/' # Fill the path for imagenet here
|
| 53 |
+
|
| 54 |
+
if args.dataset == "CIFAR10":
|
| 55 |
+
from datasets_classes_templates import CIFAR10_CLASSES_TEMPLATES as classes_templates
|
| 56 |
+
from torchvision.datasets import CIFAR10
|
| 57 |
+
data = CIFAR10(root='./image_classification_datasets/cifar10/', train=False, download=True)
|
| 58 |
+
elif args.dataset == "CIFAR100":
|
| 59 |
+
from datasets_classes_templates import CIFAR100_CLASSES_TEMPLATES as classes_templates
|
| 60 |
+
from torchvision.datasets import CIFAR100
|
| 61 |
+
data = CIFAR100(root='./image_classification_datasets/cifar100/', train=False, download=True)
|
| 62 |
+
elif args.dataset == "ImageNet":
|
| 63 |
+
from datasets_classes_templates import ImageNet_CLASSES_TEMPLATES as classes_templates
|
| 64 |
+
from torchvision.datasets import ImageNet
|
| 65 |
+
data = ImageNet(root=imagenet_path, split='val')
|
| 66 |
+
elif args.dataset == "Caltech101":
|
| 67 |
+
torch.manual_seed(42)
|
| 68 |
+
from datasets_classes_templates import Caltech101_CLASSES_TEMPLATES as classes_templates
|
| 69 |
+
from torchvision.datasets import Caltech101
|
| 70 |
+
data = Caltech101(root='./image_classification_datasets/', download=False)
|
| 71 |
+
train_size = int(0.8 * len(data)) # 80% for training
|
| 72 |
+
val_size = len(data) - train_size
|
| 73 |
+
_, data = torch.utils.data.random_split(data, [train_size, val_size])
|
| 74 |
+
elif args.dataset == "Caltech256":
|
| 75 |
+
torch.manual_seed(42)
|
| 76 |
+
from datasets_classes_templates import Caltech256_CLASSES_TEMPLATES as classes_templates
|
| 77 |
+
from torchvision.datasets import Caltech256
|
| 78 |
+
data = Caltech256(root='./image_classification_datasets/', download=False)
|
| 79 |
+
train_size = int(0.8 * len(data)) # 80% for training
|
| 80 |
+
val_size = len(data) - train_size
|
| 81 |
+
_, data = torch.utils.data.random_split(data, [train_size, val_size])
|
| 82 |
+
elif args.dataset == "Food101":
|
| 83 |
+
from datasets_classes_templates import Food101_CLASSES_TEMPLATES as classes_templates
|
| 84 |
+
from torchvision.datasets import Food101
|
| 85 |
+
data = Food101(root='./image_classification_datasets/food101/', download=True, split='test')
|
| 86 |
+
else:
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
print(f'Conducting zero-shot image classification on {args.dataset}')
|
| 90 |
+
|
| 91 |
+
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
| 92 |
+
model_base_path = './fine_tuned_clip_models'
|
| 93 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 94 |
+
|
| 95 |
+
top1_list = []
|
| 96 |
+
for data_seed in data_seeds:
|
| 97 |
+
print(f'Conducting zero-shot image classification on {args.data} with seed {data_seed} for the method {args.method}')
|
| 98 |
+
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 99 |
+
if args.data != 'non_fine_tuned':
|
| 100 |
+
if args.method != 'NONE':
|
| 101 |
+
if args.data not in ['all']:
|
| 102 |
+
model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20_data_seed_{data_seed}.pt'))
|
| 103 |
+
else:
|
| 104 |
+
model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt'))
|
| 105 |
+
elif args.method == 'NONE' and args.data == 'MS_COCO':
|
| 106 |
+
model.load_state_dict(torch.load(f'{model_base_path}/{args.method}/clip_model_dataset_{args.data}_method_{args.method}_num_epochs_20.pt'))
|
| 107 |
+
|
| 108 |
+
model.eval()
|
| 109 |
+
|
| 110 |
+
data_loader = DataLoader(data, batch_size=128, collate_fn=classification_collate_fn, shuffle=False)
|
| 111 |
+
|
| 112 |
+
zeroshot_weights = zeroshot_classifier(classes_templates['classes'],
|
| 113 |
+
classes_templates['templates'],
|
| 114 |
+
processor,
|
| 115 |
+
model
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
with torch.no_grad():
|
| 119 |
+
top1, top5, n = 0., 0., 0.
|
| 120 |
+
for i, (images, target) in enumerate(tqdm(data_loader)):
|
| 121 |
+
target = target.to(device)
|
| 122 |
+
images = list(images)
|
| 123 |
+
|
| 124 |
+
images = processor(images=images, return_tensors="pt").to(device)
|
| 125 |
+
|
| 126 |
+
# predict
|
| 127 |
+
image_features = model.get_image_features(images['pixel_values']).to(device)
|
| 128 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 129 |
+
logits = 100. * image_features @ zeroshot_weights
|
| 130 |
+
|
| 131 |
+
# measure accuracy
|
| 132 |
+
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
|
| 133 |
+
top1 += acc1
|
| 134 |
+
top5 += acc5
|
| 135 |
+
n += image_features.size(0)
|
| 136 |
+
|
| 137 |
+
top1 = (top1 / n) * 100
|
| 138 |
+
top5 = (top5 / n) * 100
|
| 139 |
+
|
| 140 |
+
with open(results_filename, 'a') as f:
|
| 141 |
+
f.write(f'Seed {data_seed}: Top-1 Accuracy: {top1:.2f}, Top-5 Accuracy: {top5:.2f}\n')
|
| 142 |
+
|
| 143 |
+
top1_list.append(top1)
|
| 144 |
+
|
| 145 |
+
print(f"Top-1 accuracy: {top1:.2f}")
|
| 146 |
+
print(f"Top-5 accuracy: {top5:.2f}")
|
| 147 |
+
print('-'*40)
|
| 148 |
+
|
| 149 |
+
if args.method == 'NONE' or args.data in ['MS_COCO','all'] or args.data == 'non_fine_tuned':
|
| 150 |
+
break
|
| 151 |
+
top1 = np.asarray(top1_list)
|
| 152 |
+
print(f'Mean of the top 1 accuracy is {np.mean(top1)}')
|
| 153 |
+
print(f'Standard deviation of the top 1 accuracy is {np.std(top1)}')
|
| 154 |
+
|
| 155 |
+
with open(results_filename, 'a') as f:
|
| 156 |
+
f.write(f'\nMean Top-1 Accuracy: {np.mean(top1):.2f}\n')
|
| 157 |
+
f.write(f'Standard Deviation of Top-1 Accuracy: {np.std(top1):.2f}\n')
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
main()
|
vlm_eval/clip_train.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code adapted from https://github.com/ylaxor/clip-like/blob/main/fine-tune-clip.ipynb
|
| 2 |
+
|
| 3 |
+
from random import seed, shuffle
|
| 4 |
+
from typing import Callable
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 8 |
+
from timm.scheduler import CosineLRScheduler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ModelTrainer:
|
| 13 |
+
|
| 14 |
+
def __init__(self,
|
| 15 |
+
model: Callable,
|
| 16 |
+
processor: Callable,
|
| 17 |
+
data_name: str,
|
| 18 |
+
train_data_loader: torch.utils.data.DataLoader,
|
| 19 |
+
val_data_loader: torch.utils.data.DataLoader,
|
| 20 |
+
num_epochs: int,
|
| 21 |
+
learning_rate: float = 5e-5,
|
| 22 |
+
weight_decay: float = 1e-3,
|
| 23 |
+
device: str = "cuda:0",
|
| 24 |
+
save_model: bool = False,
|
| 25 |
+
save_model_path: str = "./fine_tuned_clip_models",
|
| 26 |
+
data_seed: int = 42,
|
| 27 |
+
method="COCO_CF",
|
| 28 |
+
) -> None:
|
| 29 |
+
|
| 30 |
+
self.model = model
|
| 31 |
+
self.processor = processor
|
| 32 |
+
self.data_name = data_name
|
| 33 |
+
self.train_data_loader = train_data_loader
|
| 34 |
+
self.val_data_loader = val_data_loader
|
| 35 |
+
self.num_epochs = num_epochs
|
| 36 |
+
self.learning_rate = learning_rate
|
| 37 |
+
self.weight_decay = weight_decay
|
| 38 |
+
self.device = device
|
| 39 |
+
self.save_model = save_model
|
| 40 |
+
self.save_model_path = save_model_path
|
| 41 |
+
self.data_seed = data_seed
|
| 42 |
+
self.method = method
|
| 43 |
+
|
| 44 |
+
self.optimizer = torch.optim.AdamW(
|
| 45 |
+
self.model.parameters(),
|
| 46 |
+
lr=self.learning_rate,
|
| 47 |
+
weight_decay=self.weight_decay
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def train(self):
|
| 52 |
+
self.model.train()
|
| 53 |
+
lr_scheduler = CosineLRScheduler(
|
| 54 |
+
self.optimizer,
|
| 55 |
+
t_initial=self.num_epochs * len(self.train_data_loader),
|
| 56 |
+
lr_min=2e-7,
|
| 57 |
+
warmup_lr_init=1e-7,
|
| 58 |
+
warmup_prefix=True,
|
| 59 |
+
warmup_t=3,
|
| 60 |
+
cycle_limit=1,
|
| 61 |
+
t_in_epochs=False,
|
| 62 |
+
)
|
| 63 |
+
progress_bar = tqdm(range(self.num_epochs))
|
| 64 |
+
for epoch in progress_bar:
|
| 65 |
+
running_loss = 0.0
|
| 66 |
+
for batch_idx, batch in enumerate(self.train_data_loader):
|
| 67 |
+
self.optimizer.zero_grad()
|
| 68 |
+
processed_input = self.processor(text=batch["caption"],
|
| 69 |
+
images=batch["image"],
|
| 70 |
+
return_tensors="pt",
|
| 71 |
+
padding=True,
|
| 72 |
+
max_length=128,
|
| 73 |
+
truncation=True
|
| 74 |
+
)
|
| 75 |
+
outputs = self.model(input_ids=processed_input['input_ids'].squeeze().to(self.device),
|
| 76 |
+
pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
|
| 77 |
+
attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
|
| 78 |
+
return_loss=True
|
| 79 |
+
)
|
| 80 |
+
loss = outputs.loss
|
| 81 |
+
loss.backward()
|
| 82 |
+
running_loss += loss.item() * len(batch["caption"])
|
| 83 |
+
self.optimizer.step()
|
| 84 |
+
lr_scheduler.step_update(batch_idx + epoch * len(self.train_data_loader))
|
| 85 |
+
|
| 86 |
+
print(f"Epoch {epoch+1}/{self.num_epochs} Loss: {running_loss/len(self.train_data_loader.dataset):.4f}")
|
| 87 |
+
progress_bar.set_postfix(
|
| 88 |
+
epoch="{}/{}".format(epoch+1,self.num_epochs),
|
| 89 |
+
loss=running_loss/len(self.train_data_loader.dataset),
|
| 90 |
+
lr=self.optimizer.param_groups[0]["lr"]
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
if self.save_model:
|
| 94 |
+
if self.data_name not in ['MS_COCO','all']:
|
| 95 |
+
torch.save(self.model.state_dict(), self.save_model_path + f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}_data_seed_{self.data_seed}.pt')
|
| 96 |
+
print(f"Saving fine-tuned model as clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}_data_seed_{self.data_seed}.pt")
|
| 97 |
+
else:
|
| 98 |
+
torch.save(self.model.state_dict(), self.save_model_path + f'clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}.pt')
|
| 99 |
+
print(f"Saving fine-tuned model as clip_model_dataset_{self.data_name}_method_{self.method}_num_epochs_{self.num_epochs}.pt")
|
| 100 |
+
|
| 101 |
+
def eval(self):
|
| 102 |
+
self.model.eval()
|
| 103 |
+
nb_batches = len(self.val_data_loader)
|
| 104 |
+
tqdm_object = tqdm(self.val_data_loader, total=len(self.val_data_loader))
|
| 105 |
+
epoch_loss = 0.0
|
| 106 |
+
for i, batch in enumerate(tqdm_object):
|
| 107 |
+
processed_input = self.processor(text=batch["caption"],
|
| 108 |
+
images=batch["image"],
|
| 109 |
+
return_tensors="pt",
|
| 110 |
+
padding=True,
|
| 111 |
+
max_length=128,
|
| 112 |
+
truncation=True
|
| 113 |
+
)
|
| 114 |
+
outputs = self.model(
|
| 115 |
+
input_ids=processed_input['input_ids'].squeeze().to(self.device),
|
| 116 |
+
attention_mask=processed_input['attention_mask'].squeeze().to(self.device),
|
| 117 |
+
pixel_values=processed_input['pixel_values'].squeeze().to(self.device),
|
| 118 |
+
return_loss=True)
|
| 119 |
+
loss, logits_per_image = outputs.loss, outputs.logits_per_image
|
| 120 |
+
epoch_loss += loss.item()
|
| 121 |
+
tqdm_object.set_postfix(
|
| 122 |
+
batch="{}/{}".format(i+1,nb_batches),
|
| 123 |
+
dev_loss=loss.item(),
|
| 124 |
+
)
|
| 125 |
+
epoch_loss = epoch_loss / nb_batches
|
| 126 |
+
print(f"Eval loss: {epoch_loss}")
|
| 127 |
+
|
| 128 |
+
def main():
|
| 129 |
+
import os
|
| 130 |
+
#os.environ['HF_HOME'] = '' Add path for saved hugging face models
|
| 131 |
+
|
| 132 |
+
import argparse
|
| 133 |
+
parser = argparse.ArgumentParser()
|
| 134 |
+
parser.add_argument('--num_epochs', type=int, default=20)
|
| 135 |
+
parser.add_argument('--data_name', type=str, default="MS_COCO", choices=["MS_COCO","base","medium","all"])
|
| 136 |
+
parser.add_argument('--learning_rate', type=float, default=1e-5)
|
| 137 |
+
parser.add_argument('--batch_size', type=int, default=32)
|
| 138 |
+
parser.add_argument('--save_model', action='store_true', default=False)
|
| 139 |
+
parser.add_argument('--method', type=str, choices=['COCO_CF','APGD_1','APGD_4','NONE'])
|
| 140 |
+
parser.add_argument('--save_model_path', type=str, default="./fine_tuned_clip_models")
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--data_seeds",
|
| 143 |
+
nargs="+",
|
| 144 |
+
type=int,
|
| 145 |
+
default=[107],
|
| 146 |
+
help="Seeds to use for each trial for picking demonstrations and eval sets",
|
| 147 |
+
)
|
| 148 |
+
args = parser.parse_args()
|
| 149 |
+
if args.data_name == 'MS_COCO':
|
| 150 |
+
assert args.data_name == 'MS_COCO' and args.method == 'NONE', "Only NONE method is allowed with MS_COCO dataset"
|
| 151 |
+
|
| 152 |
+
from torch.utils.data import DataLoader
|
| 153 |
+
from coco_cf_loader import MS_COCO_dataset, custom_collate_fn
|
| 154 |
+
|
| 155 |
+
torch.manual_seed(42)
|
| 156 |
+
|
| 157 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 158 |
+
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 159 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
for data_seed in args.data_seeds:
|
| 163 |
+
|
| 164 |
+
if args.data_name not in ['MS_COCO', 'all']:
|
| 165 |
+
print(f"Data Seed: {data_seed} | Data Name: {args.data_name} | Method: {args.method}")
|
| 166 |
+
dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
|
| 167 |
+
annotation_file=f'/json_files/data_name_{args.data_name}_data_seed_{data_seed}.json')
|
| 168 |
+
elif args.data_name == 'all':
|
| 169 |
+
print(f"Data Name: {args.data_name} | Method: {args.method}")
|
| 170 |
+
dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO_{args.method}',
|
| 171 |
+
annotation_file=f'/json_files/data_name_{args.data_name}.json')
|
| 172 |
+
else:
|
| 173 |
+
print(f"Data Name: {args.data_name} | Method: {args.method}")
|
| 174 |
+
dataset = MS_COCO_dataset(base_dir=f'./clip_train_datasets/MS_COCO',
|
| 175 |
+
annotation_file=f'/ms_coco_captions.json')
|
| 176 |
+
|
| 177 |
+
train_size = int(0.8 * len(dataset)) # 80% for training
|
| 178 |
+
val_size = len(dataset) - train_size # 20% for validation
|
| 179 |
+
|
| 180 |
+
# Randomly split into training and validation datasets
|
| 181 |
+
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
| 182 |
+
|
| 183 |
+
# Optional: Create DataLoaders for each subset
|
| 184 |
+
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
|
| 185 |
+
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=custom_collate_fn,drop_last=True)
|
| 186 |
+
|
| 187 |
+
trainer = ModelTrainer(model=model,
|
| 188 |
+
processor=processor,
|
| 189 |
+
data_name=args.data_name,
|
| 190 |
+
train_data_loader=train_loader,
|
| 191 |
+
val_data_loader=val_loader,
|
| 192 |
+
num_epochs=args.num_epochs,
|
| 193 |
+
learning_rate=args.learning_rate,
|
| 194 |
+
weight_decay=1e-3,
|
| 195 |
+
device=device,
|
| 196 |
+
data_seed=data_seed,
|
| 197 |
+
save_model=args.save_model,
|
| 198 |
+
save_model_path=args.save_model_path,
|
| 199 |
+
method=args.method,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
trainer.train()
|
| 203 |
+
trainer.eval()
|
| 204 |
+
if args.data_name in ['MS_COCO','all']:
|
| 205 |
+
break
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
|
| 209 |
+
main()
|
vlm_eval/coco_cf_loader.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MS_COCO_dataset(Dataset):
|
| 9 |
+
|
| 10 |
+
def __init__(self, base_dir, annotation_file=None):
|
| 11 |
+
|
| 12 |
+
self.data= []
|
| 13 |
+
self.img_dir = base_dir + '/images'
|
| 14 |
+
self.annotation_file = base_dir + annotation_file
|
| 15 |
+
|
| 16 |
+
with open(self.annotation_file, 'r') as file:
|
| 17 |
+
for line in file:
|
| 18 |
+
self.data.append(json.loads(line))
|
| 19 |
+
|
| 20 |
+
def __len__(self):
|
| 21 |
+
return len(self.data)
|
| 22 |
+
|
| 23 |
+
def __getitem__(self, idx):
|
| 24 |
+
# Extract the relevant info from the JSONL entry
|
| 25 |
+
img_name = os.path.join(self.img_dir, f"{self.data[idx]['image_name']}")
|
| 26 |
+
caption = self.data[idx]['caption']
|
| 27 |
+
sample_id = self.data[idx]['image_id']
|
| 28 |
+
|
| 29 |
+
# Load the image using PIL
|
| 30 |
+
img = Image.open(img_name)
|
| 31 |
+
|
| 32 |
+
return {"id": sample_id,
|
| 33 |
+
"image": img,
|
| 34 |
+
"caption": caption
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
class COCO_CF_dataset(Dataset):
|
| 38 |
+
|
| 39 |
+
def __init__(self, base_dir):
|
| 40 |
+
|
| 41 |
+
self.data= []
|
| 42 |
+
self.img_dir = base_dir + '/images'
|
| 43 |
+
self.annotation_file = base_dir + "/examples.jsonl"
|
| 44 |
+
|
| 45 |
+
with open(self.annotation_file, 'r') as file:
|
| 46 |
+
for line in file:
|
| 47 |
+
self.data.append(json.loads(line))
|
| 48 |
+
|
| 49 |
+
def __len__(self):
|
| 50 |
+
return len(self.data)
|
| 51 |
+
|
| 52 |
+
def __getitem__(self, idx):
|
| 53 |
+
# Extract the relevant info from the JSONL entry
|
| 54 |
+
img_0_name = os.path.join(self.img_dir, f"{self.data[idx]['image_0']}.jpg")
|
| 55 |
+
img_1_name = os.path.join(self.img_dir, f"{self.data[idx]['image_1']}.jpg")
|
| 56 |
+
|
| 57 |
+
caption_0 = self.data[idx]['caption_0']
|
| 58 |
+
caption_1 = self.data[idx]['caption_1']
|
| 59 |
+
sample_id = self.data[idx]['id']
|
| 60 |
+
|
| 61 |
+
# Load the image using PIL
|
| 62 |
+
img_0 = Image.open(img_0_name)
|
| 63 |
+
img_1 = Image.open(img_1_name)
|
| 64 |
+
|
| 65 |
+
return {"id": sample_id,
|
| 66 |
+
"caption_0": caption_0,
|
| 67 |
+
"caption_1": caption_1,
|
| 68 |
+
"image_0": img_0,
|
| 69 |
+
"image_1": img_1}
|
| 70 |
+
|
| 71 |
+
def custom_collate_fn(batch):
|
| 72 |
+
collated_batch = {}
|
| 73 |
+
for key in batch[0].keys():
|
| 74 |
+
collated_batch[key] = [item[key] for item in batch]
|
| 75 |
+
return collated_batch
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
|
| 79 |
+
base_dir = '/home/htc/kchitranshi/SCRATCH/MS_COCO/'
|
| 80 |
+
data = MS_COCO_dataset(base_dir=base_dir)
|
| 81 |
+
data_loader = DataLoader(data, batch_size=10,collate_fn=custom_collate_fn)
|
| 82 |
+
|
| 83 |
+
for batch in data_loader:
|
| 84 |
+
print(batch)
|
| 85 |
+
break
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
vlm_eval/create_clip_dataset.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
|
| 10 |
+
# Intialising seeds for data
|
| 11 |
+
data_seeds = [i for i in range(107,122)]
|
| 12 |
+
|
| 13 |
+
ms_coco_base_anno_path = "./clip_train_datasets/MS_COCO/ms_coco_captions.json"
|
| 14 |
+
attack_base_anno_path = "./clip_train_datasets/COCO_CF/examples.jsonl"
|
| 15 |
+
|
| 16 |
+
data_names = ["base","medium","all"]
|
| 17 |
+
|
| 18 |
+
ms_coco_array = []
|
| 19 |
+
attack_array = []
|
| 20 |
+
|
| 21 |
+
with open(ms_coco_base_anno_path, 'r') as file:
|
| 22 |
+
for line in file:
|
| 23 |
+
ms_coco_array.append(json.loads(line))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
with open(attack_base_anno_path, 'r') as file:
|
| 27 |
+
for line in file:
|
| 28 |
+
attack_array.append(json.loads(line))
|
| 29 |
+
|
| 30 |
+
for data_name in data_names:
|
| 31 |
+
for data_seed in data_seeds:
|
| 32 |
+
if data_name == "base":
|
| 33 |
+
num_ms_coco_samples = 8705
|
| 34 |
+
num_attacks_samples = 4353 # These many pairs of samples with their counterfactuals or adv attack samples. Effectively 8706 in total.
|
| 35 |
+
elif data_name == "medium":
|
| 36 |
+
num_ms_coco_samples = 17410
|
| 37 |
+
num_attacks_samples = int(0.75 * 17410) # These many pairs of samples with their counterfactuals or adv attack samples. Effectively 26115 in total.
|
| 38 |
+
elif data_name == "all":
|
| 39 |
+
num_ms_coco_samples = 17410
|
| 40 |
+
num_attacks_samples = 17410 # These many pairs of samples with their counterfactuals or adv attack samples. Effectively 34820 in total.
|
| 41 |
+
|
| 42 |
+
np.random.seed(data_seed)
|
| 43 |
+
ms_coco_rand_indices = np.random.choice(len(ms_coco_array), num_ms_coco_samples, replace=False)
|
| 44 |
+
attack_rand_indices = np.random.choice(len(attack_array), num_attacks_samples, replace=False)
|
| 45 |
+
|
| 46 |
+
ms_coco_samples = [ms_coco_array[int(i)] for i in ms_coco_rand_indices]
|
| 47 |
+
attack_samples = [attack_array[int(i)] for i in attack_rand_indices]
|
| 48 |
+
attack_samples = [{"image_id": batch["id"], "image_name": batch[f"image_{i}"] + ".jpg", "caption": batch[f"caption_{i}"]} for batch in attack_samples for i in range(2)]
|
| 49 |
+
|
| 50 |
+
random.seed(42)
|
| 51 |
+
combined_dataset = ms_coco_samples + attack_samples
|
| 52 |
+
|
| 53 |
+
random.shuffle(combined_dataset)
|
| 54 |
+
|
| 55 |
+
if data_name != 'all':
|
| 56 |
+
with open(f"./clip_train_datasets/MS_COCO_APGD_4/json_files/data_name_{data_name}_data_seed_{data_seed}.json", 'w') as file:
|
| 57 |
+
for line in combined_dataset:
|
| 58 |
+
file.write(json.dumps(line) + '\n')
|
| 59 |
+
else:
|
| 60 |
+
with open(f"./clip_train_datasets/MS_COCO_APGD_4/json_files/data_name_{data_name}.json", 'w') as file:
|
| 61 |
+
for line in combined_dataset:
|
| 62 |
+
file.write(json.dumps(line) + '\n')
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
main()
|
vlm_eval/datasets_classes_templates.py
ADDED
|
@@ -0,0 +1,822 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code taken and adapted from https://github.com/openai/CLIP/blob/main/data/prompts.md
|
| 2 |
+
|
| 3 |
+
CIFAR10_CLASSES_TEMPLATES = {
|
| 4 |
+
'classes' : [
|
| 5 |
+
'airplane',
|
| 6 |
+
'automobile',
|
| 7 |
+
'bird',
|
| 8 |
+
'cat',
|
| 9 |
+
'deer',
|
| 10 |
+
'dog',
|
| 11 |
+
'frog',
|
| 12 |
+
'horse',
|
| 13 |
+
'ship',
|
| 14 |
+
'truck',
|
| 15 |
+
],
|
| 16 |
+
|
| 17 |
+
'templates' : [
|
| 18 |
+
'a photo of a {}.',
|
| 19 |
+
'a blurry photo of a {}.',
|
| 20 |
+
'a black and white photo of a {}.',
|
| 21 |
+
'a low contrast photo of a {}.',
|
| 22 |
+
'a high contrast photo of a {}.',
|
| 23 |
+
'a bad photo of a {}.',
|
| 24 |
+
'a good photo of a {}.',
|
| 25 |
+
'a photo of a small {}.',
|
| 26 |
+
'a photo of a big {}.',
|
| 27 |
+
'a photo of the {}.',
|
| 28 |
+
'a blurry photo of the {}.',
|
| 29 |
+
'a black and white photo of the {}.',
|
| 30 |
+
'a low contrast photo of the {}.',
|
| 31 |
+
'a high contrast photo of the {}.',
|
| 32 |
+
'a bad photo of the {}.',
|
| 33 |
+
'a good photo of the {}.',
|
| 34 |
+
'a photo of the small {}.',
|
| 35 |
+
'a photo of the big {}.',
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
CIFAR100_CLASSES_TEMPLATES = {
|
| 40 |
+
'classes' : [
|
| 41 |
+
'apple',
|
| 42 |
+
'aquarium fish',
|
| 43 |
+
'baby',
|
| 44 |
+
'bear',
|
| 45 |
+
'beaver',
|
| 46 |
+
'bed',
|
| 47 |
+
'bee',
|
| 48 |
+
'beetle',
|
| 49 |
+
'bicycle',
|
| 50 |
+
'bottle',
|
| 51 |
+
'bowl',
|
| 52 |
+
'boy',
|
| 53 |
+
'bridge',
|
| 54 |
+
'bus',
|
| 55 |
+
'butterfly',
|
| 56 |
+
'camel',
|
| 57 |
+
'can',
|
| 58 |
+
'castle',
|
| 59 |
+
'caterpillar',
|
| 60 |
+
'cattle',
|
| 61 |
+
'chair',
|
| 62 |
+
'chimpanzee',
|
| 63 |
+
'clock',
|
| 64 |
+
'cloud',
|
| 65 |
+
'cockroach',
|
| 66 |
+
'couch',
|
| 67 |
+
'crab',
|
| 68 |
+
'crocodile',
|
| 69 |
+
'cup',
|
| 70 |
+
'dinosaur',
|
| 71 |
+
'dolphin',
|
| 72 |
+
'elephant',
|
| 73 |
+
'flatfish',
|
| 74 |
+
'forest',
|
| 75 |
+
'fox',
|
| 76 |
+
'girl',
|
| 77 |
+
'hamster',
|
| 78 |
+
'house',
|
| 79 |
+
'kangaroo',
|
| 80 |
+
'keyboard',
|
| 81 |
+
'lamp',
|
| 82 |
+
'lawn mower',
|
| 83 |
+
'leopard',
|
| 84 |
+
'lion',
|
| 85 |
+
'lizard',
|
| 86 |
+
'lobster',
|
| 87 |
+
'man',
|
| 88 |
+
'maple tree',
|
| 89 |
+
'motorcycle',
|
| 90 |
+
'mountain',
|
| 91 |
+
'mouse',
|
| 92 |
+
'mushroom',
|
| 93 |
+
'oak tree',
|
| 94 |
+
'orange',
|
| 95 |
+
'orchid',
|
| 96 |
+
'otter',
|
| 97 |
+
'palm tree',
|
| 98 |
+
'pear',
|
| 99 |
+
'pickup truck',
|
| 100 |
+
'pine tree',
|
| 101 |
+
'plain',
|
| 102 |
+
'plate',
|
| 103 |
+
'poppy',
|
| 104 |
+
'porcupine',
|
| 105 |
+
'possum',
|
| 106 |
+
'rabbit',
|
| 107 |
+
'raccoon',
|
| 108 |
+
'ray',
|
| 109 |
+
'road',
|
| 110 |
+
'rocket',
|
| 111 |
+
'rose',
|
| 112 |
+
'sea',
|
| 113 |
+
'seal',
|
| 114 |
+
'shark',
|
| 115 |
+
'shrew',
|
| 116 |
+
'skunk',
|
| 117 |
+
'skyscraper',
|
| 118 |
+
'snail',
|
| 119 |
+
'snake',
|
| 120 |
+
'spider',
|
| 121 |
+
'squirrel',
|
| 122 |
+
'streetcar',
|
| 123 |
+
'sunflower',
|
| 124 |
+
'sweet pepper',
|
| 125 |
+
'table',
|
| 126 |
+
'tank',
|
| 127 |
+
'telephone',
|
| 128 |
+
'television',
|
| 129 |
+
'tiger',
|
| 130 |
+
'tractor',
|
| 131 |
+
'train',
|
| 132 |
+
'trout',
|
| 133 |
+
'tulip',
|
| 134 |
+
'turtle',
|
| 135 |
+
'wardrobe',
|
| 136 |
+
'whale',
|
| 137 |
+
'willow tree',
|
| 138 |
+
'wolf',
|
| 139 |
+
'woman',
|
| 140 |
+
'worm',
|
| 141 |
+
],
|
| 142 |
+
|
| 143 |
+
'templates' : [
|
| 144 |
+
'a photo of a {}.',
|
| 145 |
+
'a blurry photo of a {}.',
|
| 146 |
+
'a black and white photo of a {}.',
|
| 147 |
+
'a low contrast photo of a {}.',
|
| 148 |
+
'a high contrast photo of a {}.',
|
| 149 |
+
'a bad photo of a {}.',
|
| 150 |
+
'a good photo of a {}.',
|
| 151 |
+
'a photo of a small {}.',
|
| 152 |
+
'a photo of a big {}.',
|
| 153 |
+
'a photo of the {}.',
|
| 154 |
+
'a blurry photo of the {}.',
|
| 155 |
+
'a black and white photo of the {}.',
|
| 156 |
+
'a low contrast photo of the {}.',
|
| 157 |
+
'a high contrast photo of the {}.',
|
| 158 |
+
'a bad photo of the {}.',
|
| 159 |
+
'a good photo of the {}.',
|
| 160 |
+
'a photo of the small {}.',
|
| 161 |
+
'a photo of the big {}.',
|
| 162 |
+
]
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
ImageNet_CLASSES_TEMPLATES = {
|
| 166 |
+
'classes' : ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights",
|
| 167 |
+
"one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven",
|
| 168 |
+
"military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped",
|
| 169 |
+
"mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace",
|
| 170 |
+
"necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask",
|
| 171 |
+
"product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter",
|
| 172 |
+
"railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier",
|
| 173 |
+
"piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger",
|
| 174 |
+
"Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector",
|
| 175 |
+
"hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle",
|
| 176 |
+
"fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick",
|
| 177 |
+
"sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver",
|
| 178 |
+
"seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski",
|
| 179 |
+
"balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
|
| 180 |
+
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car",
|
| 181 |
+
"spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing",
|
| 182 |
+
"electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
|
| 183 |
+
,
|
| 184 |
+
'templates' : [
|
| 185 |
+
'a bad photo of a {}.',
|
| 186 |
+
'a photo of many {}.',
|
| 187 |
+
'a sculpture of a {}.',
|
| 188 |
+
'a photo of the hard to see {}.',
|
| 189 |
+
'a low resolution photo of the {}.',
|
| 190 |
+
'a rendering of a {}.',
|
| 191 |
+
'graffiti of a {}.',
|
| 192 |
+
'a bad photo of the {}.',
|
| 193 |
+
'a cropped photo of the {}.',
|
| 194 |
+
'a tattoo of a {}.',
|
| 195 |
+
'the embroidered {}.',
|
| 196 |
+
'a photo of a hard to see {}.',
|
| 197 |
+
'a bright photo of a {}.',
|
| 198 |
+
'a photo of a clean {}.',
|
| 199 |
+
'a photo of a dirty {}.',
|
| 200 |
+
'a dark photo of the {}.',
|
| 201 |
+
'a drawing of a {}.',
|
| 202 |
+
'a photo of my {}.',
|
| 203 |
+
'the plastic {}.',
|
| 204 |
+
'a photo of the cool {}.',
|
| 205 |
+
'a close-up photo of a {}.',
|
| 206 |
+
'a black and white photo of the {}.',
|
| 207 |
+
'a painting of the {}.',
|
| 208 |
+
'a painting of a {}.',
|
| 209 |
+
'a pixelated photo of the {}.',
|
| 210 |
+
'a sculpture of the {}.',
|
| 211 |
+
'a bright photo of the {}.',
|
| 212 |
+
'a cropped photo of a {}.',
|
| 213 |
+
'a plastic {}.',
|
| 214 |
+
'a photo of the dirty {}.',
|
| 215 |
+
'a jpeg corrupted photo of a {}.',
|
| 216 |
+
'a blurry photo of the {}.',
|
| 217 |
+
'a photo of the {}.',
|
| 218 |
+
'a good photo of the {}.',
|
| 219 |
+
'a rendering of the {}.',
|
| 220 |
+
'a {} in a video game.',
|
| 221 |
+
'a photo of one {}.',
|
| 222 |
+
'a doodle of a {}.',
|
| 223 |
+
'a close-up photo of the {}.',
|
| 224 |
+
'a photo of a {}.',
|
| 225 |
+
'the origami {}.',
|
| 226 |
+
'the {} in a video game.',
|
| 227 |
+
'a sketch of a {}.',
|
| 228 |
+
'a doodle of the {}.',
|
| 229 |
+
'a origami {}.',
|
| 230 |
+
'a low resolution photo of a {}.',
|
| 231 |
+
'the toy {}.',
|
| 232 |
+
'a rendition of the {}.',
|
| 233 |
+
'a photo of the clean {}.',
|
| 234 |
+
'a photo of a large {}.',
|
| 235 |
+
'a rendition of a {}.',
|
| 236 |
+
'a photo of a nice {}.',
|
| 237 |
+
'a photo of a weird {}.',
|
| 238 |
+
'a blurry photo of a {}.',
|
| 239 |
+
'a cartoon {}.',
|
| 240 |
+
'art of a {}.',
|
| 241 |
+
'a sketch of the {}.',
|
| 242 |
+
'a embroidered {}.',
|
| 243 |
+
'a pixelated photo of a {}.',
|
| 244 |
+
'itap of the {}.',
|
| 245 |
+
'a jpeg corrupted photo of the {}.',
|
| 246 |
+
'a good photo of a {}.',
|
| 247 |
+
'a plushie {}.',
|
| 248 |
+
'a photo of the nice {}.',
|
| 249 |
+
'a photo of the small {}.',
|
| 250 |
+
'a photo of the weird {}.',
|
| 251 |
+
'the cartoon {}.',
|
| 252 |
+
'art of the {}.',
|
| 253 |
+
'a drawing of the {}.',
|
| 254 |
+
'a photo of the large {}.',
|
| 255 |
+
'a black and white photo of a {}.',
|
| 256 |
+
'the plushie {}.',
|
| 257 |
+
'a dark photo of a {}.',
|
| 258 |
+
'itap of a {}.',
|
| 259 |
+
'graffiti of the {}.',
|
| 260 |
+
'a toy {}.',
|
| 261 |
+
'itap of my {}.',
|
| 262 |
+
'a photo of a cool {}.',
|
| 263 |
+
'a photo of a small {}.',
|
| 264 |
+
'a tattoo of the {}.',
|
| 265 |
+
]
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
Caltech101_CLASSES_TEMPLATES = {
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
'classes' : ['Faces',
|
| 272 |
+
'Faces_easy',
|
| 273 |
+
'Leopards',
|
| 274 |
+
'Motorbikes',
|
| 275 |
+
'accordion',
|
| 276 |
+
'airplanes',
|
| 277 |
+
'anchor',
|
| 278 |
+
'ant',
|
| 279 |
+
'barrel',
|
| 280 |
+
'bass',
|
| 281 |
+
'beaver',
|
| 282 |
+
'binocular',
|
| 283 |
+
'bonsai',
|
| 284 |
+
'brain',
|
| 285 |
+
'brontosaurus',
|
| 286 |
+
'buddha',
|
| 287 |
+
'butterfly',
|
| 288 |
+
'camera',
|
| 289 |
+
'cannon',
|
| 290 |
+
'car_side',
|
| 291 |
+
'ceiling_fan',
|
| 292 |
+
'cellphone',
|
| 293 |
+
'chair',
|
| 294 |
+
'chandelier',
|
| 295 |
+
'cougar_body',
|
| 296 |
+
'cougar_face',
|
| 297 |
+
'crab',
|
| 298 |
+
'crayfish',
|
| 299 |
+
'crocodile',
|
| 300 |
+
'crocodile_head',
|
| 301 |
+
'cup',
|
| 302 |
+
'dalmatian',
|
| 303 |
+
'dollar_bill',
|
| 304 |
+
'dolphin',
|
| 305 |
+
'dragonfly',
|
| 306 |
+
'electric_guitar',
|
| 307 |
+
'elephant',
|
| 308 |
+
'emu',
|
| 309 |
+
'euphonium',
|
| 310 |
+
'ewer',
|
| 311 |
+
'ferry',
|
| 312 |
+
'flamingo',
|
| 313 |
+
'flamingo_head',
|
| 314 |
+
'garfield',
|
| 315 |
+
'gerenuk',
|
| 316 |
+
'gramophone',
|
| 317 |
+
'grand_piano',
|
| 318 |
+
'hawksbill',
|
| 319 |
+
'headphone',
|
| 320 |
+
'hedgehog',
|
| 321 |
+
'helicopter',
|
| 322 |
+
'ibis',
|
| 323 |
+
'inline_skate',
|
| 324 |
+
'joshua_tree',
|
| 325 |
+
'kangaroo',
|
| 326 |
+
'ketch',
|
| 327 |
+
'lamp',
|
| 328 |
+
'laptop',
|
| 329 |
+
'llama',
|
| 330 |
+
'lobster',
|
| 331 |
+
'lotus',
|
| 332 |
+
'mandolin',
|
| 333 |
+
'mayfly',
|
| 334 |
+
'menorah',
|
| 335 |
+
'metronome',
|
| 336 |
+
'minaret',
|
| 337 |
+
'nautilus',
|
| 338 |
+
'octopus',
|
| 339 |
+
'okapi',
|
| 340 |
+
'pagoda',
|
| 341 |
+
'panda',
|
| 342 |
+
'pigeon',
|
| 343 |
+
'pizza',
|
| 344 |
+
'platypus',
|
| 345 |
+
'pyramid',
|
| 346 |
+
'revolver',
|
| 347 |
+
'rhino',
|
| 348 |
+
'rooster',
|
| 349 |
+
'saxophone',
|
| 350 |
+
'schooner',
|
| 351 |
+
'scissors',
|
| 352 |
+
'scorpion',
|
| 353 |
+
'sea_horse',
|
| 354 |
+
'snoopy',
|
| 355 |
+
'soccer_ball',
|
| 356 |
+
'stapler',
|
| 357 |
+
'starfish',
|
| 358 |
+
'stegosaurus',
|
| 359 |
+
'stop_sign',
|
| 360 |
+
'strawberry',
|
| 361 |
+
'sunflower',
|
| 362 |
+
'tick',
|
| 363 |
+
'trilobite',
|
| 364 |
+
'umbrella',
|
| 365 |
+
'watch',
|
| 366 |
+
'water_lilly',
|
| 367 |
+
'wheelchair',
|
| 368 |
+
'wild_cat',
|
| 369 |
+
'windsor_chair',
|
| 370 |
+
'wrench',
|
| 371 |
+
'yin_yang']
|
| 372 |
+
,
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
'templates' : [
|
| 376 |
+
'a photo of a {}.',
|
| 377 |
+
'a painting of a {}.',
|
| 378 |
+
'a plastic {}.',
|
| 379 |
+
'a sculpture of a {}.',
|
| 380 |
+
'a sketch of a {}.',
|
| 381 |
+
'a tattoo of a {}.',
|
| 382 |
+
'a toy {}.',
|
| 383 |
+
'a rendition of a {}.',
|
| 384 |
+
'a embroidered {}.',
|
| 385 |
+
'a cartoon {}.',
|
| 386 |
+
'a {} in a video game.',
|
| 387 |
+
'a plushie {}.',
|
| 388 |
+
'a origami {}.',
|
| 389 |
+
'art of a {}.',
|
| 390 |
+
'graffiti of a {}.',
|
| 391 |
+
'a drawing of a {}.',
|
| 392 |
+
'a doodle of a {}.',
|
| 393 |
+
'a photo of the {}.',
|
| 394 |
+
'a painting of the {}.',
|
| 395 |
+
'the plastic {}.',
|
| 396 |
+
'a sculpture of the {}.',
|
| 397 |
+
'a sketch of the {}.',
|
| 398 |
+
'a tattoo of the {}.',
|
| 399 |
+
'the toy {}.',
|
| 400 |
+
'a rendition of the {}.',
|
| 401 |
+
'the embroidered {}.',
|
| 402 |
+
'the cartoon {}.',
|
| 403 |
+
'the {} in a video game.',
|
| 404 |
+
'the plushie {}.',
|
| 405 |
+
'the origami {}.',
|
| 406 |
+
'art of the {}.',
|
| 407 |
+
'graffiti of the {}.',
|
| 408 |
+
'a drawing of the {}.',
|
| 409 |
+
'a doodle of the {}.',
|
| 410 |
+
]
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
Caltech256_CLASSES_TEMPLATES = {
|
| 414 |
+
'classes' : [
|
| 415 |
+
'ak47',
|
| 416 |
+
'american flag',
|
| 417 |
+
'backpack',
|
| 418 |
+
'baseball bat',
|
| 419 |
+
'baseball glove',
|
| 420 |
+
'basketball hoop',
|
| 421 |
+
'bat',
|
| 422 |
+
'bathtub',
|
| 423 |
+
'bear',
|
| 424 |
+
'beer mug',
|
| 425 |
+
'billiards',
|
| 426 |
+
'binoculars',
|
| 427 |
+
'birdbath',
|
| 428 |
+
'blimp',
|
| 429 |
+
'bonsai',
|
| 430 |
+
'boom box',
|
| 431 |
+
'bowling ball',
|
| 432 |
+
'bowling pin',
|
| 433 |
+
'boxing glove',
|
| 434 |
+
'brain',
|
| 435 |
+
'breadmaker',
|
| 436 |
+
'buddha',
|
| 437 |
+
'bulldozer',
|
| 438 |
+
'butterfly',
|
| 439 |
+
'cactus',
|
| 440 |
+
'cake',
|
| 441 |
+
'calculator',
|
| 442 |
+
'camel',
|
| 443 |
+
'cannon',
|
| 444 |
+
'canoe',
|
| 445 |
+
'car tire',
|
| 446 |
+
'cartman',
|
| 447 |
+
'cd',
|
| 448 |
+
'centipede',
|
| 449 |
+
'cereal box',
|
| 450 |
+
'chandelier',
|
| 451 |
+
'chess board',
|
| 452 |
+
'chimp',
|
| 453 |
+
'chopsticks',
|
| 454 |
+
'cockroach',
|
| 455 |
+
'coffee mug',
|
| 456 |
+
'coffin',
|
| 457 |
+
'coin',
|
| 458 |
+
'comet',
|
| 459 |
+
'computer keyboard',
|
| 460 |
+
'computer monitor',
|
| 461 |
+
'computer mouse',
|
| 462 |
+
'conch',
|
| 463 |
+
'cormorant',
|
| 464 |
+
'covered wagon',
|
| 465 |
+
'cowboy hat',
|
| 466 |
+
'crab',
|
| 467 |
+
'desk globe',
|
| 468 |
+
'diamond ring',
|
| 469 |
+
'dice',
|
| 470 |
+
'dog',
|
| 471 |
+
'dolphin',
|
| 472 |
+
'doorknob',
|
| 473 |
+
'drinking straw',
|
| 474 |
+
'duck',
|
| 475 |
+
'dumb bell',
|
| 476 |
+
'eiffel tower',
|
| 477 |
+
'electric guitar',
|
| 478 |
+
'elephant',
|
| 479 |
+
'elk',
|
| 480 |
+
'ewer',
|
| 481 |
+
'eyeglasses',
|
| 482 |
+
'fern',
|
| 483 |
+
'fighter jet',
|
| 484 |
+
'fire extinguisher',
|
| 485 |
+
'fire hydrant',
|
| 486 |
+
'fire truck',
|
| 487 |
+
'fireworks',
|
| 488 |
+
'flashlight',
|
| 489 |
+
'floppy disk',
|
| 490 |
+
'football helmet',
|
| 491 |
+
'french horn',
|
| 492 |
+
'fried egg',
|
| 493 |
+
'frisbee',
|
| 494 |
+
'frog',
|
| 495 |
+
'frying pan',
|
| 496 |
+
'galaxy',
|
| 497 |
+
'gas pump',
|
| 498 |
+
'giraffe',
|
| 499 |
+
'goat',
|
| 500 |
+
'golden gate bridge',
|
| 501 |
+
'goldfish',
|
| 502 |
+
'golf ball',
|
| 503 |
+
'goose',
|
| 504 |
+
'gorilla',
|
| 505 |
+
'grand piano',
|
| 506 |
+
'grapes',
|
| 507 |
+
'grasshopper',
|
| 508 |
+
'guitar pick',
|
| 509 |
+
'hamburger',
|
| 510 |
+
'hammock',
|
| 511 |
+
'harmonica',
|
| 512 |
+
'harp',
|
| 513 |
+
'harpsichord',
|
| 514 |
+
'hawksbill',
|
| 515 |
+
'head phones',
|
| 516 |
+
'helicopter',
|
| 517 |
+
'hibiscus',
|
| 518 |
+
'homer simpson',
|
| 519 |
+
'horse',
|
| 520 |
+
'horseshoe crab',
|
| 521 |
+
'hot air balloon',
|
| 522 |
+
'hot dog',
|
| 523 |
+
'hot tub',
|
| 524 |
+
'hourglass',
|
| 525 |
+
'house fly',
|
| 526 |
+
'human skeleton',
|
| 527 |
+
'hummingbird',
|
| 528 |
+
'ibis',
|
| 529 |
+
'ice cream cone',
|
| 530 |
+
'iguana',
|
| 531 |
+
'ipod',
|
| 532 |
+
'iris',
|
| 533 |
+
'jesus christ',
|
| 534 |
+
'joy stick',
|
| 535 |
+
'kangaroo',
|
| 536 |
+
'kayak',
|
| 537 |
+
'ketch',
|
| 538 |
+
'killer whale',
|
| 539 |
+
'knife',
|
| 540 |
+
'ladder',
|
| 541 |
+
'laptop',
|
| 542 |
+
'lathe',
|
| 543 |
+
'leopards',
|
| 544 |
+
'license plate',
|
| 545 |
+
'lightbulb',
|
| 546 |
+
'light house',
|
| 547 |
+
'lightning',
|
| 548 |
+
'llama',
|
| 549 |
+
'mailbox',
|
| 550 |
+
'mandolin',
|
| 551 |
+
'mars',
|
| 552 |
+
'mattress',
|
| 553 |
+
'megaphone',
|
| 554 |
+
'menorah',
|
| 555 |
+
'microscope',
|
| 556 |
+
'microwave',
|
| 557 |
+
'minaret',
|
| 558 |
+
'minotaur',
|
| 559 |
+
'motorbikes',
|
| 560 |
+
'mountain bike',
|
| 561 |
+
'mushroom',
|
| 562 |
+
'mussels',
|
| 563 |
+
'necktie',
|
| 564 |
+
'octopus',
|
| 565 |
+
'ostrich',
|
| 566 |
+
'owl',
|
| 567 |
+
'palm pilot',
|
| 568 |
+
'palm tree',
|
| 569 |
+
'paperclip',
|
| 570 |
+
'paper shredder',
|
| 571 |
+
'pci card',
|
| 572 |
+
'penguin',
|
| 573 |
+
'people',
|
| 574 |
+
'pez dispenser',
|
| 575 |
+
'photocopier',
|
| 576 |
+
'picnic table',
|
| 577 |
+
'playing card',
|
| 578 |
+
'porcupine',
|
| 579 |
+
'pram',
|
| 580 |
+
'praying mantis',
|
| 581 |
+
'pyramid',
|
| 582 |
+
'raccoon',
|
| 583 |
+
'radio telescope',
|
| 584 |
+
'rainbow',
|
| 585 |
+
'refrigerator',
|
| 586 |
+
'revolver',
|
| 587 |
+
'rifle',
|
| 588 |
+
'rotary phone',
|
| 589 |
+
'roulette wheel',
|
| 590 |
+
'saddle',
|
| 591 |
+
'saturn',
|
| 592 |
+
'school bus',
|
| 593 |
+
'scorpion',
|
| 594 |
+
'screwdriver',
|
| 595 |
+
'segway',
|
| 596 |
+
'self propelled lawn mower',
|
| 597 |
+
'sextant',
|
| 598 |
+
'sheet music',
|
| 599 |
+
'skateboard',
|
| 600 |
+
'skunk',
|
| 601 |
+
'skyscraper',
|
| 602 |
+
'smokestack',
|
| 603 |
+
'snail',
|
| 604 |
+
'snake',
|
| 605 |
+
'sneaker',
|
| 606 |
+
'snowmobile',
|
| 607 |
+
'soccer ball',
|
| 608 |
+
'socks',
|
| 609 |
+
'soda can',
|
| 610 |
+
'spaghetti',
|
| 611 |
+
'speed boat',
|
| 612 |
+
'spider',
|
| 613 |
+
'spoon',
|
| 614 |
+
'stained glass',
|
| 615 |
+
'starfish',
|
| 616 |
+
'steering wheel',
|
| 617 |
+
'stirrups',
|
| 618 |
+
'sunflower',
|
| 619 |
+
'superman',
|
| 620 |
+
'sushi',
|
| 621 |
+
'swan',
|
| 622 |
+
'swiss army knife',
|
| 623 |
+
'sword',
|
| 624 |
+
'syringe',
|
| 625 |
+
'tambourine',
|
| 626 |
+
'teapot',
|
| 627 |
+
'teddy bear',
|
| 628 |
+
'teepee',
|
| 629 |
+
'telephone box',
|
| 630 |
+
'tennis ball',
|
| 631 |
+
'tennis court',
|
| 632 |
+
'tennis racket',
|
| 633 |
+
'theodolite',
|
| 634 |
+
'toaster',
|
| 635 |
+
'tomato',
|
| 636 |
+
'tombstone',
|
| 637 |
+
'top hat',
|
| 638 |
+
'touring bike',
|
| 639 |
+
'tower pisa',
|
| 640 |
+
'traffic light',
|
| 641 |
+
'treadmill',
|
| 642 |
+
'triceratops',
|
| 643 |
+
'tricycle',
|
| 644 |
+
'trilobite',
|
| 645 |
+
'tripod',
|
| 646 |
+
't shirt',
|
| 647 |
+
'tuning fork',
|
| 648 |
+
'tweezer',
|
| 649 |
+
'umbrella',
|
| 650 |
+
'unicorn',
|
| 651 |
+
'vcr',
|
| 652 |
+
'video projector',
|
| 653 |
+
'washing machine',
|
| 654 |
+
'watch',
|
| 655 |
+
'waterfall',
|
| 656 |
+
'watermelon',
|
| 657 |
+
'welding mask',
|
| 658 |
+
'wheelbarrow',
|
| 659 |
+
'windmill',
|
| 660 |
+
'wine bottle',
|
| 661 |
+
'xylophone',
|
| 662 |
+
'yarmulke',
|
| 663 |
+
'yo yo',
|
| 664 |
+
'zebra',
|
| 665 |
+
'airplanes',
|
| 666 |
+
'car side',
|
| 667 |
+
'faces easy',
|
| 668 |
+
'greyhound',
|
| 669 |
+
'tennis shoes',
|
| 670 |
+
'toad',
|
| 671 |
+
'clutter'
|
| 672 |
+
],
|
| 673 |
+
|
| 674 |
+
'templates' : [
|
| 675 |
+
'a photo of a {}.',
|
| 676 |
+
'a painting of a {}.',
|
| 677 |
+
'a plastic {}.',
|
| 678 |
+
'a sculpture of a {}.',
|
| 679 |
+
'a sketch of a {}.',
|
| 680 |
+
'a tattoo of a {}.',
|
| 681 |
+
'a toy {}.',
|
| 682 |
+
'a rendition of a {}.',
|
| 683 |
+
'a embroidered {}.',
|
| 684 |
+
'a cartoon {}.',
|
| 685 |
+
'a {} in a video game.',
|
| 686 |
+
'a plushie {}.',
|
| 687 |
+
'a origami {}.',
|
| 688 |
+
'art of a {}.',
|
| 689 |
+
'graffiti of a {}.',
|
| 690 |
+
'a drawing of a {}.',
|
| 691 |
+
'a doodle of a {}.',
|
| 692 |
+
'a photo of the {}.',
|
| 693 |
+
'a painting of the {}.',
|
| 694 |
+
'the plastic {}.',
|
| 695 |
+
'a sculpture of the {}.',
|
| 696 |
+
'a sketch of the {}.',
|
| 697 |
+
'a tattoo of the {}.',
|
| 698 |
+
'the toy {}.',
|
| 699 |
+
'a rendition of the {}.',
|
| 700 |
+
'the embroidered {}.',
|
| 701 |
+
'the cartoon {}.',
|
| 702 |
+
'the {} in a video game.',
|
| 703 |
+
'the plushie {}.',
|
| 704 |
+
'the origami {}.',
|
| 705 |
+
'art of the {}.',
|
| 706 |
+
'graffiti of the {}.',
|
| 707 |
+
'a drawing of the {}.',
|
| 708 |
+
'a doodle of the {}.',
|
| 709 |
+
]
|
| 710 |
+
}
|
| 711 |
+
|
| 712 |
+
Food101_CLASSES_TEMPLATES = {
|
| 713 |
+
'classes' : [
|
| 714 |
+
'apple pie',
|
| 715 |
+
'baby back ribs',
|
| 716 |
+
'baklava',
|
| 717 |
+
'beef carpaccio',
|
| 718 |
+
'beef tartare',
|
| 719 |
+
'beet salad',
|
| 720 |
+
'beignets',
|
| 721 |
+
'bibimbap',
|
| 722 |
+
'bread pudding',
|
| 723 |
+
'breakfast burrito',
|
| 724 |
+
'bruschetta',
|
| 725 |
+
'caesar salad',
|
| 726 |
+
'cannoli',
|
| 727 |
+
'caprese salad',
|
| 728 |
+
'carrot cake',
|
| 729 |
+
'ceviche',
|
| 730 |
+
'cheese plate',
|
| 731 |
+
'cheesecake',
|
| 732 |
+
'chicken curry',
|
| 733 |
+
'chicken quesadilla',
|
| 734 |
+
'chicken wings',
|
| 735 |
+
'chocolate cake',
|
| 736 |
+
'chocolate mousse',
|
| 737 |
+
'churros',
|
| 738 |
+
'clam chowder',
|
| 739 |
+
'club sandwich',
|
| 740 |
+
'crab cakes',
|
| 741 |
+
'creme brulee',
|
| 742 |
+
'croque madame',
|
| 743 |
+
'cup cakes',
|
| 744 |
+
'deviled eggs',
|
| 745 |
+
'donuts',
|
| 746 |
+
'dumplings',
|
| 747 |
+
'edamame',
|
| 748 |
+
'eggs benedict',
|
| 749 |
+
'escargots',
|
| 750 |
+
'falafel',
|
| 751 |
+
'filet mignon',
|
| 752 |
+
'fish and chips',
|
| 753 |
+
'foie gras',
|
| 754 |
+
'french fries',
|
| 755 |
+
'french onion soup',
|
| 756 |
+
'french toast',
|
| 757 |
+
'fried calamari',
|
| 758 |
+
'fried rice',
|
| 759 |
+
'frozen yogurt',
|
| 760 |
+
'garlic bread',
|
| 761 |
+
'gnocchi',
|
| 762 |
+
'greek salad',
|
| 763 |
+
'grilled cheese sandwich',
|
| 764 |
+
'grilled salmon',
|
| 765 |
+
'guacamole',
|
| 766 |
+
'gyoza',
|
| 767 |
+
'hamburger',
|
| 768 |
+
'hot and sour soup',
|
| 769 |
+
'hot dog',
|
| 770 |
+
'huevos rancheros',
|
| 771 |
+
'hummus',
|
| 772 |
+
'ice cream',
|
| 773 |
+
'lasagna',
|
| 774 |
+
'lobster bisque',
|
| 775 |
+
'lobster roll sandwich',
|
| 776 |
+
'macaroni and cheese',
|
| 777 |
+
'macarons',
|
| 778 |
+
'miso soup',
|
| 779 |
+
'mussels',
|
| 780 |
+
'nachos',
|
| 781 |
+
'omelette',
|
| 782 |
+
'onion rings',
|
| 783 |
+
'oysters',
|
| 784 |
+
'pad thai',
|
| 785 |
+
'paella',
|
| 786 |
+
'pancakes',
|
| 787 |
+
'panna cotta',
|
| 788 |
+
'peking duck',
|
| 789 |
+
'pho',
|
| 790 |
+
'pizza',
|
| 791 |
+
'pork chop',
|
| 792 |
+
'poutine',
|
| 793 |
+
'prime rib',
|
| 794 |
+
'pulled pork sandwich',
|
| 795 |
+
'ramen',
|
| 796 |
+
'ravioli',
|
| 797 |
+
'red velvet cake',
|
| 798 |
+
'risotto',
|
| 799 |
+
'samosa',
|
| 800 |
+
'sashimi',
|
| 801 |
+
'scallops',
|
| 802 |
+
'seaweed salad',
|
| 803 |
+
'shrimp and grits',
|
| 804 |
+
'spaghetti bolognese',
|
| 805 |
+
'spaghetti carbonara',
|
| 806 |
+
'spring rolls',
|
| 807 |
+
'steak',
|
| 808 |
+
'strawberry shortcake',
|
| 809 |
+
'sushi',
|
| 810 |
+
'tacos',
|
| 811 |
+
'takoyaki',
|
| 812 |
+
'tiramisu',
|
| 813 |
+
'tuna tartare',
|
| 814 |
+
'waffles',
|
| 815 |
+
],
|
| 816 |
+
|
| 817 |
+
'templates' : [
|
| 818 |
+
'a photo of {}, a type of food.',
|
| 819 |
+
]
|
| 820 |
+
}
|
| 821 |
+
|
| 822 |
+
data_seeds = [107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121]
|
vlm_eval/ms_coco_gen.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torchvision.datasets as dset
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
from coco_cf import COCO_CF_dataset
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
|
| 8 |
+
def custom_collate_fn(batch):
|
| 9 |
+
collated_batch = {}
|
| 10 |
+
for key in batch[0].keys():
|
| 11 |
+
collated_batch[key] = [item[key] for item in batch]
|
| 12 |
+
return collated_batch
|
| 13 |
+
|
| 14 |
+
coco_2017 = dset.CocoCaptions(root='./open_flamingo_datasets/COCO_2017/val2017/',
|
| 15 |
+
annFile='./open_flamingo_datasets/COCO_2017/captions_val2017.json',
|
| 16 |
+
transform=transforms.ToTensor())
|
| 17 |
+
|
| 18 |
+
coco_cf = COCO_CF_dataset(base_dir='./open_flamingo_datasets/COCO_CF/')
|
| 19 |
+
dl_coco_cf = DataLoader(coco_cf, batch_size=100,collate_fn=custom_collate_fn)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Collect both captions from each batch in one step
|
| 23 |
+
coco_cf_captions = []
|
| 24 |
+
|
| 25 |
+
for batch in dl_coco_cf:
|
| 26 |
+
# Extend the list with both captions at once without list comprehension
|
| 27 |
+
coco_cf_captions.extend([caption.replace('.','').replace(",","").replace("-"," ").replace("'s","").lower().strip() for caption in batch['caption_0']])
|
| 28 |
+
|
| 29 |
+
ms_coco_gen_indices = []
|
| 30 |
+
coco_cf_captions_set = set(coco_cf_captions)
|
| 31 |
+
|
| 32 |
+
for index in range(len(coco_2017)):
|
| 33 |
+
image_id = coco_2017.ids[index]
|
| 34 |
+
_,captions = coco_2017[index]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
matches = [s for s in captions if s.replace(".","").replace(",","").replace("'s","").replace("-"," ").lower().strip() in coco_cf_captions_set]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
for match in matches:
|
| 41 |
+
ms_coco_gen_indices.append((image_id,match))
|
| 42 |
+
ms_coco_gen_indices = ms_coco_gen_indices[:17410]
|
| 43 |
+
print(ms_coco_gen_indices)
|
| 44 |
+
ms_coco = [{'image_id': image_index,'caption': caption} for (image_index, caption) in ms_coco_gen_indices]
|
| 45 |
+
|
| 46 |
+
file_path = 'ms_coco_captions.json'
|
| 47 |
+
|
| 48 |
+
# Save the dictionary to a JSON file
|
| 49 |
+
|
| 50 |
+
import os
|
| 51 |
+
|
| 52 |
+
# Base path where the images are located
|
| 53 |
+
base_image_path = '/home/kc/Downloads/val2017/'
|
| 54 |
+
|
| 55 |
+
# Assuming ms_coco_gen_indices is a list of (image_index, caption) tuples
|
| 56 |
+
ms_coco_gen_indices = [(image_index, caption) for (image_index, caption) in ms_coco_gen_indices]
|
| 57 |
+
|
| 58 |
+
# List to store the updated entries with pathtoimage included
|
| 59 |
+
updated_ms_coco_gen_indices = []
|
| 60 |
+
|
| 61 |
+
# Process each (image_index, caption) in ms_coco_gen_indices
|
| 62 |
+
for image_index, caption in ms_coco_gen_indices:
|
| 63 |
+
# Construct the full path to the image file based on the image_index
|
| 64 |
+
pathtoimage = f"{image_index:012d}.jpg" # Ensure image_index is 12 digits with padding
|
| 65 |
+
|
| 66 |
+
# Append the new entry as (image_index, pathtoimage, caption)
|
| 67 |
+
updated_ms_coco_gen_indices.append((image_index, pathtoimage, caption))
|
| 68 |
+
|
| 69 |
+
# Now ms_coco_gen_indices includes (image_index, pathtoimage, caption)
|
| 70 |
+
ms_coco_gen_indices = updated_ms_coco_gen_indices
|
| 71 |
+
ms_coco = [{'image_id': image_index,'image_name': image_name,'caption': caption} for (image_index,image_name ,caption) in ms_coco_gen_indices]
|
| 72 |
+
|
| 73 |
+
with open(file_path, 'w') as json_file:
|
| 74 |
+
for row in ms_coco:
|
| 75 |
+
json.dump(row, json_file)
|
| 76 |
+
json_file.write('\n')
|
vlm_eval/run_evaluation.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|