vivek9chavan commited on
Commit
a75f067
·
verified ·
1 Parent(s): 8364cd2

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +816 -0
utils.py ADDED
@@ -0,0 +1,816 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Misc functions.
3
+
4
+ Mostly copy-paste from torchvision references or other public repos like DETR:
5
+ https://github.com/facebookresearch/detr/blob/master/util/misc.py
6
+ """
7
+ import os
8
+ import sys
9
+ import time
10
+ import math
11
+ import random
12
+ import datetime
13
+ import subprocess
14
+ from collections import defaultdict, deque
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch import nn
19
+ import torch.distributed as dist
20
+ from PIL import ImageFilter, ImageOps
21
+
22
+
23
+ class GaussianBlur(object):
24
+ """
25
+ Apply Gaussian Blur to the PIL image.
26
+ """
27
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
28
+ self.prob = p
29
+ self.radius_min = radius_min
30
+ self.radius_max = radius_max
31
+
32
+ def __call__(self, img):
33
+ do_it = random.random() <= self.prob
34
+ if not do_it:
35
+ return img
36
+
37
+ return img.filter(
38
+ ImageFilter.GaussianBlur(
39
+ radius=random.uniform(self.radius_min, self.radius_max)
40
+ )
41
+ )
42
+
43
+
44
+ class Solarization(object):
45
+ """
46
+ Apply Solarization to the PIL image.
47
+ """
48
+ def __init__(self, p):
49
+ self.p = p
50
+
51
+ def __call__(self, img):
52
+ if random.random() < self.p:
53
+ return ImageOps.solarize(img)
54
+ else:
55
+ return img
56
+
57
+
58
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
59
+ if os.path.isfile(pretrained_weights):
60
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
61
+ if checkpoint_key is not None and checkpoint_key in state_dict:
62
+ print(f"Take key {checkpoint_key} in provided checkpoint dict")
63
+ state_dict = state_dict[checkpoint_key]
64
+ # remove `module.` prefix
65
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
66
+ # remove `backbone.` prefix induced by multicrop wrapper
67
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
68
+ msg = model.load_state_dict(state_dict, strict=False)
69
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
70
+ else:
71
+ print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
72
+ url = None
73
+ if model_name == "vit_small" and patch_size == 16:
74
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
75
+ elif model_name == "vit_small" and patch_size == 8:
76
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
77
+ elif model_name == "vit_base" and patch_size == 16:
78
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
79
+ elif model_name == "vit_base" and patch_size == 8:
80
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
81
+ elif model_name == "xcit_small_12_p16":
82
+ url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
83
+ elif model_name == "xcit_small_12_p8":
84
+ url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
85
+ elif model_name == "xcit_medium_24_p16":
86
+ url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
87
+ elif model_name == "xcit_medium_24_p8":
88
+ url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
89
+ elif model_name == "resnet50":
90
+ url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
91
+ if url is not None:
92
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
93
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
94
+ model.load_state_dict(state_dict, strict=True)
95
+ else:
96
+ print("There is no reference weights available for this model => We use random weights.")
97
+
98
+
99
+ def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
100
+ url = None
101
+ if model_name == "vit_small" and patch_size == 16:
102
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
103
+ elif model_name == "vit_small" and patch_size == 8:
104
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
105
+ elif model_name == "vit_base" and patch_size == 16:
106
+ url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
107
+ elif model_name == "vit_base" and patch_size == 8:
108
+ url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
109
+ elif model_name == "resnet50":
110
+ url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
111
+ if url is not None:
112
+ print("We load the reference pretrained linear weights.")
113
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
114
+ linear_classifier.load_state_dict(state_dict, strict=True)
115
+ else:
116
+ print("We use random linear weights.")
117
+
118
+
119
+ def clip_gradients(model, clip):
120
+ norms = []
121
+ for name, p in model.named_parameters():
122
+ if p.grad is not None:
123
+ param_norm = p.grad.data.norm(2)
124
+ norms.append(param_norm.item())
125
+ clip_coef = clip / (param_norm + 1e-6)
126
+ if clip_coef < 1:
127
+ p.grad.data.mul_(clip_coef)
128
+ return norms
129
+
130
+
131
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
132
+ if epoch >= freeze_last_layer:
133
+ return
134
+ for n, p in model.named_parameters():
135
+ if "last_layer" in n:
136
+ p.grad = None
137
+
138
+
139
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
140
+ """
141
+ Re-start from checkpoint
142
+ """
143
+ if not os.path.isfile(ckp_path):
144
+ return
145
+ print("Found checkpoint at {}".format(ckp_path))
146
+
147
+ # open checkpoint file
148
+ checkpoint = torch.load(ckp_path, map_location="cpu")
149
+
150
+ # key is what to look for in the checkpoint file
151
+ # value is the object to load
152
+ # example: {'state_dict': model}
153
+ for key, value in kwargs.items():
154
+ if key in checkpoint and value is not None:
155
+ try:
156
+ msg = value.load_state_dict(checkpoint[key], strict=False)
157
+ print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
158
+ except TypeError:
159
+ try:
160
+ msg = value.load_state_dict(checkpoint[key])
161
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
162
+ except ValueError:
163
+ print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
164
+ else:
165
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
166
+
167
+ # re load variable important for the run
168
+ if run_variables is not None:
169
+ for var_name in run_variables:
170
+ if var_name in checkpoint:
171
+ run_variables[var_name] = checkpoint[var_name]
172
+
173
+
174
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
175
+ warmup_schedule = np.array([])
176
+ warmup_iters = warmup_epochs * niter_per_ep
177
+ if warmup_epochs > 0:
178
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
179
+
180
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
181
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
182
+
183
+ schedule = np.concatenate((warmup_schedule, schedule))
184
+ assert len(schedule) == epochs * niter_per_ep
185
+ return schedule
186
+
187
+
188
+ def bool_flag(s):
189
+ """
190
+ Parse boolean arguments from the command line.
191
+ """
192
+ FALSY_STRINGS = {"off", "false", "0"}
193
+ TRUTHY_STRINGS = {"on", "true", "1"}
194
+ if s.lower() in FALSY_STRINGS:
195
+ return False
196
+ elif s.lower() in TRUTHY_STRINGS:
197
+ return True
198
+ else:
199
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
200
+
201
+
202
+ def fix_random_seeds(seed=31):
203
+ """
204
+ Fix random seeds.
205
+ """
206
+ torch.manual_seed(seed)
207
+ torch.cuda.manual_seed_all(seed)
208
+ np.random.seed(seed)
209
+
210
+
211
+ class SmoothedValue(object):
212
+ """Track a series of values and provide access to smoothed values over a
213
+ window or the global series average.
214
+ """
215
+
216
+ def __init__(self, window_size=20, fmt=None):
217
+ if fmt is None:
218
+ fmt = "{median:.6f} ({global_avg:.6f})"
219
+ self.deque = deque(maxlen=window_size)
220
+ self.total = 0.0
221
+ self.count = 0
222
+ self.fmt = fmt
223
+
224
+ def update(self, value, n=1):
225
+ self.deque.append(value)
226
+ self.count += n
227
+ self.total += value * n
228
+
229
+ def synchronize_between_processes(self):
230
+ """
231
+ Warning: does not synchronize the deque!
232
+ """
233
+ if not is_dist_avail_and_initialized():
234
+ return
235
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
236
+ dist.barrier()
237
+ dist.all_reduce(t)
238
+ t = t.tolist()
239
+ self.count = int(t[0])
240
+ self.total = t[1]
241
+
242
+ @property
243
+ def median(self):
244
+ d = torch.tensor(list(self.deque))
245
+ return d.median().item()
246
+
247
+ @property
248
+ def avg(self):
249
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
250
+ return d.mean().item()
251
+
252
+ @property
253
+ def global_avg(self):
254
+ return self.total / self.count
255
+
256
+ @property
257
+ def max(self):
258
+ return max(self.deque)
259
+
260
+ @property
261
+ def value(self):
262
+ return self.deque[-1]
263
+
264
+ def __str__(self):
265
+ return self.fmt.format(
266
+ median=self.median,
267
+ avg=self.avg,
268
+ global_avg=self.global_avg,
269
+ max=self.max,
270
+ value=self.value)
271
+
272
+
273
+ def reduce_dict(input_dict, average=True):
274
+ """
275
+ Args:
276
+ input_dict (dict): all the values will be reduced
277
+ average (bool): whether to do average or sum
278
+ Reduce the values in the dictionary from all processes so that all processes
279
+ have the averaged results. Returns a dict with the same fields as
280
+ input_dict, after reduction.
281
+ """
282
+ world_size = get_world_size()
283
+ if world_size < 2:
284
+ return input_dict
285
+ with torch.no_grad():
286
+ names = []
287
+ values = []
288
+ # sort the keys so that they are consistent across processes
289
+ for k in sorted(input_dict.keys()):
290
+ names.append(k)
291
+ values.append(input_dict[k])
292
+ values = torch.stack(values, dim=0)
293
+ dist.all_reduce(values)
294
+ if average:
295
+ values /= world_size
296
+ reduced_dict = {k: v for k, v in zip(names, values)}
297
+ return reduced_dict
298
+
299
+
300
+ class MetricLogger(object):
301
+ def __init__(self, delimiter="\t"):
302
+ self.meters = defaultdict(SmoothedValue)
303
+ self.delimiter = delimiter
304
+
305
+ def update(self, **kwargs):
306
+ for k, v in kwargs.items():
307
+ if isinstance(v, torch.Tensor):
308
+ v = v.item()
309
+ assert isinstance(v, (float, int))
310
+ self.meters[k].update(v)
311
+
312
+ def __getattr__(self, attr):
313
+ if attr in self.meters:
314
+ return self.meters[attr]
315
+ if attr in self.__dict__:
316
+ return self.__dict__[attr]
317
+ raise AttributeError("'{}' object has no attribute '{}'".format(
318
+ type(self).__name__, attr))
319
+
320
+ def __str__(self):
321
+ loss_str = []
322
+ for name, meter in self.meters.items():
323
+ loss_str.append(
324
+ "{}: {}".format(name, str(meter))
325
+ )
326
+ return self.delimiter.join(loss_str)
327
+
328
+ def synchronize_between_processes(self):
329
+ for meter in self.meters.values():
330
+ meter.synchronize_between_processes()
331
+
332
+ def add_meter(self, name, meter):
333
+ self.meters[name] = meter
334
+
335
+ def log_every(self, iterable, print_freq, header=None):
336
+ i = 0
337
+ if not header:
338
+ header = ''
339
+ start_time = time.time()
340
+ end = time.time()
341
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
342
+ data_time = SmoothedValue(fmt='{avg:.6f}')
343
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
344
+ if torch.cuda.is_available():
345
+ log_msg = self.delimiter.join([
346
+ header,
347
+ '[{0' + space_fmt + '}/{1}]',
348
+ 'eta: {eta}',
349
+ '{meters}',
350
+ 'time: {time}',
351
+ 'data: {data}',
352
+ 'max mem: {memory:.0f}'
353
+ ])
354
+ else:
355
+ log_msg = self.delimiter.join([
356
+ header,
357
+ '[{0' + space_fmt + '}/{1}]',
358
+ 'eta: {eta}',
359
+ '{meters}',
360
+ 'time: {time}',
361
+ 'data: {data}'
362
+ ])
363
+ MB = 1024.0 * 1024.0
364
+ for obj in iterable:
365
+ data_time.update(time.time() - end)
366
+ yield obj
367
+ iter_time.update(time.time() - end)
368
+ if i % print_freq == 0 or i == len(iterable) - 1:
369
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
370
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
371
+ if torch.cuda.is_available():
372
+ print(log_msg.format(
373
+ i, len(iterable), eta=eta_string,
374
+ meters=str(self),
375
+ time=str(iter_time), data=str(data_time),
376
+ memory=torch.cuda.max_memory_allocated() / MB))
377
+ else:
378
+ print(log_msg.format(
379
+ i, len(iterable), eta=eta_string,
380
+ meters=str(self),
381
+ time=str(iter_time), data=str(data_time)))
382
+ i += 1
383
+ end = time.time()
384
+ total_time = time.time() - start_time
385
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
386
+ print('{} Total time: {} ({:.6f} s / it)'.format(
387
+ header, total_time_str, total_time / len(iterable)))
388
+
389
+
390
+ def get_sha():
391
+ cwd = os.path.dirname(os.path.abspath(__file__))
392
+
393
+ def _run(command):
394
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
395
+ sha = 'N/A'
396
+ diff = "clean"
397
+ branch = 'N/A'
398
+ try:
399
+ sha = _run(['git', 'rev-parse', 'HEAD'])
400
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
401
+ diff = _run(['git', 'diff-index', 'HEAD'])
402
+ diff = "has uncommited changes" if diff else "clean"
403
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
404
+ except Exception:
405
+ pass
406
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
407
+ return message
408
+
409
+
410
+ def is_dist_avail_and_initialized():
411
+ if not dist.is_available():
412
+ return False
413
+ if not dist.is_initialized():
414
+ return False
415
+ return True
416
+
417
+
418
+ def get_world_size():
419
+ if not is_dist_avail_and_initialized():
420
+ return 1
421
+ return dist.get_world_size()
422
+
423
+
424
+ def get_rank():
425
+ if not is_dist_avail_and_initialized():
426
+ return 0
427
+ return dist.get_rank()
428
+
429
+
430
+ def is_main_process():
431
+ return get_rank() == 0
432
+
433
+
434
+ def save_on_master(*args, **kwargs):
435
+ if is_main_process():
436
+ torch.save(*args, **kwargs)
437
+
438
+
439
+ def setup_for_distributed(is_master):
440
+ """
441
+ This function disables printing when not in master process
442
+ """
443
+ import builtins as __builtin__
444
+ builtin_print = __builtin__.print
445
+
446
+ def print(*args, **kwargs):
447
+ force = kwargs.pop('force', False)
448
+ if is_master or force:
449
+ builtin_print(*args, **kwargs)
450
+
451
+ __builtin__.print = print
452
+
453
+
454
+ def init_distributed_mode(args):
455
+ # launched with torch.distributed.launch
456
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
457
+ args.rank = int(os.environ["RANK"])
458
+ args.world_size = int(os.environ['WORLD_SIZE'])
459
+ args.gpu = int(os.environ['LOCAL_RANK'])
460
+ # launched with submitit on a slurm cluster
461
+ elif 'SLURM_PROCID' in os.environ:
462
+ args.rank = int(os.environ['SLURM_PROCID'])
463
+ args.gpu = args.rank % torch.cuda.device_count()
464
+ # launched naively with `python main_dino.py`
465
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
466
+ elif torch.cuda.is_available():
467
+ print('Will run the code on one GPU.')
468
+ args.rank, args.gpu, args.world_size = 0, 0, 1
469
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
470
+ os.environ['MASTER_PORT'] = '29500'
471
+ else:
472
+ print('Does not support training without GPU.')
473
+ sys.exit(1)
474
+
475
+ dist.init_process_group(
476
+ backend="nccl",
477
+ init_method=args.dist_url,
478
+ world_size=args.world_size,
479
+ rank=args.rank,
480
+ )
481
+
482
+ torch.cuda.set_device(args.gpu)
483
+ print('| distributed init (rank {}): {}'.format(
484
+ args.rank, args.dist_url), flush=True)
485
+ dist.barrier()
486
+ setup_for_distributed(args.rank == 0)
487
+
488
+
489
+ def accuracy(output, target, topk=(1,)):
490
+ """Computes the accuracy over the k top predictions for the specified values of k"""
491
+ maxk = max(topk)
492
+ batch_size = target.size(0)
493
+ _, pred = output.topk(maxk, 1, True, True)
494
+ pred = pred.t()
495
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
496
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
497
+
498
+
499
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
500
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
501
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
502
+ def norm_cdf(x):
503
+ # Computes standard normal cumulative distribution function
504
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
505
+
506
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
507
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
508
+ "The distribution of values may be incorrect.",
509
+ stacklevel=2)
510
+
511
+ with torch.no_grad():
512
+ # Values are generated by using a truncated uniform distribution and
513
+ # then using the inverse CDF for the normal distribution.
514
+ # Get upper and lower cdf values
515
+ l = norm_cdf((a - mean) / std)
516
+ u = norm_cdf((b - mean) / std)
517
+
518
+ # Uniformly fill tensor with values from [l, u], then translate to
519
+ # [2l-1, 2u-1].
520
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
521
+
522
+ # Use inverse cdf transform for normal distribution to get truncated
523
+ # standard normal
524
+ tensor.erfinv_()
525
+
526
+ # Transform to proper mean, std
527
+ tensor.mul_(std * math.sqrt(2.))
528
+ tensor.add_(mean)
529
+
530
+ # Clamp to ensure it's in the proper range
531
+ tensor.clamp_(min=a, max=b)
532
+ return tensor
533
+
534
+
535
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
536
+ # type: (Tensor, float, float, float, float) -> Tensor
537
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
538
+
539
+
540
+ class LARS(torch.optim.Optimizer):
541
+ """
542
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
543
+ """
544
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
545
+ weight_decay_filter=None, lars_adaptation_filter=None):
546
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
547
+ eta=eta, weight_decay_filter=weight_decay_filter,
548
+ lars_adaptation_filter=lars_adaptation_filter)
549
+ super().__init__(params, defaults)
550
+
551
+ @torch.no_grad()
552
+ def step(self):
553
+ for g in self.param_groups:
554
+ for p in g['params']:
555
+ dp = p.grad
556
+
557
+ if dp is None:
558
+ continue
559
+
560
+ if p.ndim != 1:
561
+ dp = dp.add(p, alpha=g['weight_decay'])
562
+
563
+ if p.ndim != 1:
564
+ param_norm = torch.norm(p)
565
+ update_norm = torch.norm(dp)
566
+ one = torch.ones_like(param_norm)
567
+ q = torch.where(param_norm > 0.,
568
+ torch.where(update_norm > 0,
569
+ (g['eta'] * param_norm / update_norm), one), one)
570
+ dp = dp.mul(q)
571
+
572
+ param_state = self.state[p]
573
+ if 'mu' not in param_state:
574
+ param_state['mu'] = torch.zeros_like(p)
575
+ mu = param_state['mu']
576
+ mu.mul_(g['momentum']).add_(dp)
577
+
578
+ p.add_(mu, alpha=-g['lr'])
579
+
580
+
581
+ class MultiCropWrapper(nn.Module):
582
+ """
583
+ Perform forward pass separately on each resolution input.
584
+ The inputs corresponding to a single resolution are clubbed and single
585
+ forward is run on the same resolution inputs. Hence we do several
586
+ forward passes = number of different resolutions used. We then
587
+ concatenate all the output features and run the head forward on these
588
+ concatenated features.
589
+ """
590
+ def __init__(self, backbone, head):
591
+ super(MultiCropWrapper, self).__init__()
592
+ # disable layers dedicated to ImageNet labels classification
593
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
594
+ self.backbone = backbone
595
+ self.head = head
596
+
597
+ def forward(self, x):
598
+ # convert to list
599
+ if not isinstance(x, list):
600
+ x = [x]
601
+ idx_crops = torch.cumsum(torch.unique_consecutive(
602
+ torch.tensor([inp.shape[-1] for inp in x]),
603
+ return_counts=True,
604
+ )[1], 0)
605
+ start_idx, output = 0, torch.empty(0).to(x[0].device)
606
+ for end_idx in idx_crops:
607
+ _out = self.backbone(torch.cat(x[start_idx: end_idx]))
608
+ # The output is a tuple with XCiT model. See:
609
+ # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
610
+ if isinstance(_out, tuple):
611
+ _out = _out[0]
612
+ # accumulate outputs
613
+ output = torch.cat((output, _out))
614
+ start_idx = end_idx
615
+ # Run the head forward on the concatenated features.
616
+ return self.head(output)
617
+
618
+
619
+ def get_params_groups(model):
620
+ regularized = []
621
+ not_regularized = []
622
+ for name, param in model.named_parameters():
623
+ if not param.requires_grad:
624
+ continue
625
+ # we do not regularize biases nor Norm parameters
626
+ if name.endswith(".bias") or len(param.shape) == 1:
627
+ not_regularized.append(param)
628
+ else:
629
+ regularized.append(param)
630
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
631
+
632
+
633
+ def has_batchnorms(model):
634
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
635
+ for name, module in model.named_modules():
636
+ if isinstance(module, bn_types):
637
+ return True
638
+ return False
639
+
640
+
641
+ class PCA():
642
+ """
643
+ Class to compute and apply PCA.
644
+ """
645
+ def __init__(self, dim=256, whit=0.5):
646
+ self.dim = dim
647
+ self.whit = whit
648
+ self.mean = None
649
+
650
+ def train_pca(self, cov):
651
+ """
652
+ Takes a covariance matrix (np.ndarray) as input.
653
+ """
654
+ d, v = np.linalg.eigh(cov)
655
+ eps = d.max() * 1e-5
656
+ n_0 = (d < eps).sum()
657
+ if n_0 > 0:
658
+ d[d < eps] = eps
659
+
660
+ # total energy
661
+ totenergy = d.sum()
662
+
663
+ # sort eigenvectors with eigenvalues order
664
+ idx = np.argsort(d)[::-1][:self.dim]
665
+ d = d[idx]
666
+ v = v[:, idx]
667
+
668
+ print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
669
+
670
+ # for the whitening
671
+ d = np.diag(1. / d**self.whit)
672
+
673
+ # principal components
674
+ self.dvt = np.dot(d, v.T)
675
+
676
+ def apply(self, x):
677
+ # input is from numpy
678
+ if isinstance(x, np.ndarray):
679
+ if self.mean is not None:
680
+ x -= self.mean
681
+ return np.dot(self.dvt, x.T).T
682
+
683
+ # input is from torch and is on GPU
684
+ if x.is_cuda:
685
+ if self.mean is not None:
686
+ x -= torch.cuda.FloatTensor(self.mean)
687
+ return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
688
+
689
+ # input if from torch, on CPU
690
+ if self.mean is not None:
691
+ x -= torch.FloatTensor(self.mean)
692
+ return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
693
+
694
+
695
+ def compute_ap(ranks, nres):
696
+ """
697
+ Computes average precision for given ranked indexes.
698
+ Arguments
699
+ ---------
700
+ ranks : zerro-based ranks of positive images
701
+ nres : number of positive images
702
+ Returns
703
+ -------
704
+ ap : average precision
705
+ """
706
+
707
+ # number of images ranked by the system
708
+ nimgranks = len(ranks)
709
+
710
+ # accumulate trapezoids in PR-plot
711
+ ap = 0
712
+
713
+ recall_step = 1. / nres
714
+
715
+ for j in np.arange(nimgranks):
716
+ rank = ranks[j]
717
+
718
+ if rank == 0:
719
+ precision_0 = 1.
720
+ else:
721
+ precision_0 = float(j) / rank
722
+
723
+ precision_1 = float(j + 1) / (rank + 1)
724
+
725
+ ap += (precision_0 + precision_1) * recall_step / 2.
726
+
727
+ return ap
728
+
729
+
730
+ def compute_map(ranks, gnd, kappas=[]):
731
+ """
732
+ Computes the mAP for a given set of returned results.
733
+ Usage:
734
+ map = compute_map (ranks, gnd)
735
+ computes mean average precsion (map) only
736
+ map, aps, pr, prs = compute_map (ranks, gnd, kappas)
737
+ computes mean average precision (map), average precision (aps) for each query
738
+ computes mean precision at kappas (pr), precision at kappas (prs) for each query
739
+ Notes:
740
+ 1) ranks starts from 0, ranks.shape = db_size X #queries
741
+ 2) The junk results (e.g., the query itself) should be declared in the gnd stuct array
742
+ 3) If there are no positive images for some query, that query is excluded from the evaluation
743
+ """
744
+
745
+ map = 0.
746
+ nq = len(gnd) # number of queries
747
+ aps = np.zeros(nq)
748
+ pr = np.zeros(len(kappas))
749
+ prs = np.zeros((nq, len(kappas)))
750
+ nempty = 0
751
+
752
+ for i in np.arange(nq):
753
+ qgnd = np.array(gnd[i]['ok'])
754
+
755
+ # no positive images, skip from the average
756
+ if qgnd.shape[0] == 0:
757
+ aps[i] = float('nan')
758
+ prs[i, :] = float('nan')
759
+ nempty += 1
760
+ continue
761
+
762
+ try:
763
+ qgndj = np.array(gnd[i]['junk'])
764
+ except:
765
+ qgndj = np.empty(0)
766
+
767
+ # sorted positions of positive and junk images (0 based)
768
+ pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
769
+ junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
770
+
771
+ k = 0;
772
+ ij = 0;
773
+ if len(junk):
774
+ # decrease positions of positives based on the number of
775
+ # junk images appearing before them
776
+ ip = 0
777
+ while (ip < len(pos)):
778
+ while (ij < len(junk) and pos[ip] > junk[ij]):
779
+ k += 1
780
+ ij += 1
781
+ pos[ip] = pos[ip] - k
782
+ ip += 1
783
+
784
+ # compute ap
785
+ ap = compute_ap(pos, len(qgnd))
786
+ map = map + ap
787
+ aps[i] = ap
788
+
789
+ # compute precision @ k
790
+ pos += 1 # get it to 1-based
791
+ for j in np.arange(len(kappas)):
792
+ kq = min(max(pos), kappas[j]);
793
+ prs[i, j] = (pos <= kq).sum() / kq
794
+ pr = pr + prs[i, :]
795
+
796
+ map = map / (nq - nempty)
797
+ pr = pr / (nq - nempty)
798
+
799
+ return map, aps, pr, prs
800
+
801
+
802
+ def multi_scale(samples, model):
803
+ v = None
804
+ for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales
805
+ if s == 1:
806
+ inp = samples.clone()
807
+ else:
808
+ inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
809
+ feats = model(inp).clone()
810
+ if v is None:
811
+ v = feats
812
+ else:
813
+ v += feats
814
+ v /= 3
815
+ v /= v.norm()
816
+ return v