sposhiy committed on
Commit
f92b56e
·
verified ·
1 Parent(s): ecb8db5

Upload util/misc.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. util/misc.py +520 -0
util/misc.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import os
8
+ import subprocess
9
+ import time
10
+ from collections import defaultdict, deque
11
+ import datetime
12
+ import pickle
13
+ from typing import Optional, List
14
+ from packaging.version import Version
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from torch import Tensor
19
+
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch.autograd import Variable
23
+
24
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
25
+ import torchvision
26
+ if Version(torchvision.__version__) < Version("0.7.0"):
27
+ from torchvision.ops import _new_empty_tensor
28
+ from torchvision.ops.misc import _output_size
29
+
30
+
31
class SmoothedValue(object):
    """Track a series of values and expose smoothed statistics over a
    sliding window, plus averages over the whole series.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        # The deque keeps only the last `window_size` raw values, while
        # total/count accumulate over the whole run.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Median of the windowed values.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean of the windowed values.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Mean over every value ever recorded.
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
91
+
92
+
93
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # pickle the payload into a CUDA byte tensor
    encoded = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(encoded)
    payload = torch.ByteTensor(storage).to("cuda")

    # exchange payload sizes so every rank knows the common padded length
    local_size = torch.tensor([payload.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(s.item()) for s in size_list]
    max_size = max(size_list)

    # torch all_gather requires equal shapes, so pad each payload up to
    # max_size before the collective call
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda")
        for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        payload = torch.cat((payload, padding), dim=0)
    dist.all_gather(tensor_list, payload)

    # strip padding and unpickle each rank's payload
    return [
        pickle.loads(received.cpu().numpy().tobytes()[:size])
        for size, received in zip(size_list, tensor_list)
    ]
134
+
135
+
136
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sorted keys keep the reduction order identical on every process
        names = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        return dict(zip(names, stacked))
161
+
162
+
163
class MetricLogger(object):
    """Collect named SmoothedValue meters and periodically print them
    while iterating over a data loader (see log_every)."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are created lazily with default window/format.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        # Feed one scalar per keyword into the matching meter; scalar
        # tensors are unwrapped with .item() first.
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Fallback lookup: expose meters as attributes (e.g. logger.loss).
        # Only invoked when normal attribute resolution fails.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        # "name: <meter>" entries joined by the configured delimiter.
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Reduce every meter's count/total across ranks (deques stay local).
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        # Register a meter with a custom window size / format string.
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress, ETA, meter values
        and iteration/data timing every `print_freq` iterations and on the
        last iteration.

        NOTE(review): calls len(iterable), so a bare generator will raise
        TypeError — confirm callers only pass sized iterables.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # pad the iteration counter to the width of len(iterable)
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # data_time: time spent waiting on the loader;
            # iter_time: full per-iteration wall time (loader + consumer).
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
251
+
252
+
253
def get_sha():
    """Return a one-line summary of the git state of this source tree:
    commit sha, dirty/clean status and branch name ('N/A' on failure)."""
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha, diff, branch = 'N/A', "clean", 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        # refresh the index so diff-index reports accurately
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        # best effort: fall back to the defaults above
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
271
+
272
+
273
def collate_fn(batch):
    # Give the lead sample a batch dimension, transpose the batch from
    # sample-major to field-major, then pad the images into one tensor.
    batch[0] = batch[0].unsqueeze(0)
    columns = list(zip(*batch))
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)
278
+
279
def collate_fn_crowd(batch):
    # Flatten (possibly multi-crop) samples into (image, points) pairs,
    # then batch field-major with the images padded into one tensor.
    flattened = []
    for imgs, points in batch:
        if imgs.ndim == 3:
            # single image -> pretend it is a crop batch of one
            imgs = imgs.unsqueeze(0)
        for idx in range(len(imgs)):
            flattened.append((imgs[idx, :, :, :], points[idx]))
    columns = list(zip(*flattened))
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)
292
+
293
+
294
+ def _max_by_axis(the_list):
295
+ # type: (List[List[int]]) -> List[int]
296
+ maxes = the_list[0]
297
+ for sublist in the_list[1:]:
298
+ for index, item in enumerate(sublist):
299
+ maxes[index] = max(maxes[index], item)
300
+ return maxes
301
+
302
+ def _max_by_axis_pad(the_list):
303
+ # type: (List[List[int]]) -> List[int]
304
+ maxes = the_list[0]
305
+ for sublist in the_list[1:]:
306
+ for index, item in enumerate(sublist):
307
+ maxes[index] = max(maxes[index], item)
308
+
309
+ block = 128
310
+
311
+ for i in range(2):
312
+ maxes[i+1] = ((maxes[i+1] - 1) // block + 1) * block
313
+ return maxes
314
+
315
+
316
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    """Zero-pad a list of CHW image tensors into one batched tensor whose
    spatial size is the block-padded per-axis maximum."""
    if tensor_list[0].ndim != 3:
        raise ValueError('not supported')

    # TODO make it support different-sized images
    padded = _max_by_axis_pad([list(img.shape) for img in tensor_list])
    first = tensor_list[0]
    batch = torch.zeros([len(tensor_list)] + padded, dtype=first.dtype, device=first.device)
    for src, dst in zip(tensor_list, batch):
        # copy each image into the top-left corner of its padded slot
        dst[: src.shape[0], : src.shape[1], : src.shape[2]].copy_(src)
    return batch
333
+
334
class NestedTensor(object):
    """Container pairing a batched tensor with an optional padding mask."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(self.tensors.to(device), moved_mask)

    def decompose(self):
        # Return the (tensors, mask) pair.
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
355
+
356
+
357
def setup_for_distributed(is_master):
    """Replace builtins.print so only the master process prints.

    Non-master processes can still print by passing ``force=True``.
    """
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, **kwargs):
        # always pop 'force' so it never reaches the real print
        force = kwargs.pop('force', False)
        if is_master or force:
            original_print(*args, **kwargs)

    __builtin__.print = print
370
+
371
+
372
def is_dist_avail_and_initialized():
    """True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
378
+
379
+
380
def get_world_size():
    """Number of processes in the default group (1 when not distributed)."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
384
+
385
+
386
def get_rank():
    """Rank of this process in the default group (0 when not distributed)."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
390
+
391
+
392
def is_main_process():
    """True on the rank-0 process (and in non-distributed runs)."""
    return not get_rank()
394
+
395
+
396
def save_on_master(*args, **kwargs):
    """torch.save, executed only on the main (rank-0) process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
399
+
400
+
401
def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables.

    Supports torchrun/torch.distributed.launch (RANK/WORLD_SIZE/LOCAL_RANK)
    and SLURM (SLURM_PROCID). Mutates `args` in place: sets args.rank,
    args.world_size (launch case only), args.gpu, args.distributed and
    args.dist_backend; reads args.dist_url. Falls back to non-distributed
    mode when neither set of variables is present.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # launched via torchrun / torch.distributed.launch
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM: derive the local GPU index from the global rank.
        # NOTE(review): this branch does not set args.world_size, which
        # init_process_group below requires — presumably the caller sets
        # it; confirm.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    # bind this process to its GPU before creating the process group
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # barrier ensures all ranks are up before muting non-master printing
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
424
+
425
+
426
+ @torch.no_grad()
427
+ def accuracy(output, target, topk=(1,)):
428
+ """Computes the precision@k for the specified values of k"""
429
+ if target.numel() == 0:
430
+ return [torch.zeros([], device=output.device)]
431
+ maxk = max(topk)
432
+ batch_size = target.size(0)
433
+
434
+ _, pred = output.topk(maxk, 1, True, True)
435
+ pred = pred.t()
436
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
437
+
438
+ res = []
439
+ for k in topk:
440
+ correct_k = correct[:k].view(-1).float().sum(0)
441
+ res.append(correct_k.mul_(100.0 / batch_size))
442
+ return res
443
+
444
+
445
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    # Compare (major, minor) numerically: the old float(version[:3]) trick
    # breaks on two-digit minors ("0.10" -> 0.1) and picks the wrong branch.
    parts = torchvision.__version__.split('.')
    major = int(''.join(ch for ch in parts[0] if ch.isdigit()) or 0)
    minor = int(''.join(ch for ch in parts[1] if ch.isdigit()) or 0) if len(parts) > 1 else 0
    if (major, minor) < (0, 7):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        # empty batch: compute the output shape ourselves and return an
        # empty tensor of that shape (torchvision < 0.7 private helpers)
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
463
+
464
+
465
class FocalLoss(nn.Module):
    r"""
    This criterion is a implemenation of Focal Loss, which is proposed in
    Focal Loss for Dense Object Detection.

        Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha(1D Tensor, Variable) : the scalar factor for this criterion
        gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                               putting more focus on hard, misclassified examples
        size_average(bool): By default, the losses are averaged over observations for each minibatch.
                            However, if the field size_average is set to False, the losses are
                            instead summed for each minibatch.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        # Variable is a deprecated no-op wrapper; plain tensors suffice and
        # torch.Tensor covers legacy Variable inputs (Variable is a Tensor).
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        elif isinstance(alpha, torch.Tensor):
            self.alpha = alpha
        else:
            self.alpha = torch.tensor(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """inputs: (N, C) raw logits; targets: (N,) class indices.
        Returns a scalar loss (mean or sum over the batch)."""
        # explicit dim avoids the deprecated implicit-dim softmax warning
        P = F.softmax(inputs, dim=1)

        # one-hot mask selecting each sample's target class
        ids = targets.view(-1, 1)
        class_mask = torch.zeros_like(inputs)
        class_mask.scatter_(1, ids, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.view(-1)]

        # probability assigned to the true class, shape (N, 1)
        probs = (P * class_mask).sum(1).view(-1, 1)

        log_p = probs.log()
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss