| """
|
| Modified from DETR https://github.com/facebookresearch/detr
|
| Misc functions.
|
| Mostly copy-paste from torchvision references.
|
| """
|
| import pickle
|
| from typing import Optional, List
|
| from collections import OrderedDict, defaultdict, deque
|
| import time
|
| import datetime
|
|
|
| import torch
|
| import torch.distributed as dist
|
| from torch import Tensor
|
| from torchvision.ops.boxes import box_area
|
| import numpy as np
|
|
|
| from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
| import matplotlib.pyplot as plt
|
|
|
|
|
| import torchvision
|
| if float(torchvision.__version__.split(".")[1]) < 7.0:
|
| from torchvision.ops import _new_empty_tensor
|
| from torchvision.ops.misc import _output_size
|
|
|
|
|
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize the object to a byte tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # pad to max_size because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
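
# Usage sketch (not executed here): combining per-rank, picklable statistics.
# Assumes torch.distributed has already been initialized with a CUDA-capable
# backend (e.g. NCCL); `stats` below is a hypothetical per-rank dict.
#
#     stats = {"num_boxes": 4, "rank": get_rank()}
#     gathered = all_gather(stats)              # list with one dict per rank
#     total_boxes = sum(s["num_boxes"] for s in gathered)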
|
|
|
|
|
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
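
# Usage sketch (not executed here): averaging a loss dict across processes for
# logging. Assumes the same keys exist on every rank and that the values are
# scalar tensors on the current device; the names below are illustrative only.
#
#     loss_dict = {"loss_ce": loss_ce.detach(), "loss_bbox": loss_bbox.detach()}
#     loss_dict_reduced = reduce_dict(loss_dict)        # averaged over ranks
#     if is_main_process():
#         print({k: v.item() for k, v in loss_dict_reduced.items()})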
|
|
|
|
|
def _max_by_axis(the_list):
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes
|
|
|
|
|
class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
|
|
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisibility=1, split=True):
    """
    This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their
    padding masks (true for padding areas, false otherwise).
    """
    if split:
        # split each stacked tensor along the channel dim into 3-channel frames
        tensor_list = [tensor.split(3, dim=0) for tensor in tensor_list]
        tensor_list = [item for sublist in tensor_list for item in sublist]

    if tensor_list[0].ndim == 3:
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])

        if size_divisibility > 1:
            stride = size_divisibility
            # the last two dims are [H, W]; round both up to a multiple of the stride
            max_size[-2] = (max_size[-2] + (stride - 1)) // stride * stride
            max_size[-1] = (max_size[-1] + (stride - 1)) // stride * stride

        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
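
# Usage sketch (not executed here): batching two images of different sizes.
# With split=False each element is treated as a single [C, H, W] image; both are
# padded to the per-batch max size, rounded up to a multiple of size_divisibility,
# and the mask marks padded pixels with True.
#
#     imgs = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]
#     nt = nested_tensor_from_tensor_list(imgs, size_divisibility=32, split=False)
#     padded, mask = nt.decompose()   # padded: [2, 3, 608, 800], mask: [2, 608, 800]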
|
|
|
|
|
def nested_tensor_from_videos_list(videos_list: List[Tensor], size_divisibility=1):
    """
    This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded
    videos (shape [B, T, C, PH, PW]) along with their padding masks (true for padding areas, false otherwise; shape
    [B, T, PH, PW]).
    """
    max_size = _max_by_axis([list(img.shape) for img in videos_list])

    if size_divisibility > 1:
        stride = size_divisibility
        max_size[-2] = (max_size[-2] + (stride - 1)) // stride * stride
        max_size[-1] = (max_size[-1] + (stride - 1)) // stride * stride

    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)

    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos, videos_pad_masks)
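
# Usage sketch (not executed here): two clips with different lengths and frame
# sizes are padded to [B, T_max, C, PH, PW]; padded frames and pixels are True
# in the returned mask.
#
#     clips = [torch.rand(4, 3, 240, 320), torch.rand(6, 3, 360, 480)]
#     nt = nested_tensor_from_videos_list(clips, size_divisibility=32)
#     vids, masks = nt.decompose()   # vids: [2, 6, 3, 384, 480], masks: [2, 6, 384, 480]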
|
|
|
|
|
def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
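
# Usage sketch (not executed here): typically called once after the process
# group is set up, so that only rank 0 prints unless force=True is passed.
#
#     setup_for_distributed(is_main_process())
#     print("visible on rank 0 only")
#     print("visible on every rank", force=True)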
|
|
|
|
|
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)
|
|
|
def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)
|
|
|
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    helper can go away.
    """
    if float(torchvision.__version__.split(".")[1]) < 7.0:
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
|
|
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
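
# Usage sketch (not executed here): smoothing a per-iteration loss value.
#
#     loss_meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
#     for step in range(100):
#         loss_meter.update(1.0 / (step + 1))
#     print(str(loss_meter))                       # windowed median and global average
#     loss_meter.synchronize_between_processes()   # sums count/total across ranks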
|
|
|
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
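
# Usage sketch (not executed here): a typical training loop. `data_loader`,
# `model` and `criterion` are hypothetical; log_every yields the batches and
# prints timing/ETA plus all registered meters every `print_freq` iterations.
#
#     metric_logger = MetricLogger(delimiter="  ")
#     for samples, targets in metric_logger.log_every(data_loader, 10, header="Epoch: [0]"):
#         loss_dict = criterion(model(samples), targets)
#         metric_logger.update(**loss_dict)
#     metric_logger.synchronize_between_processes()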
|
|
|
def clip_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    union = area1 + area2 - inter
    iou = (inter + 1e-6) / (union + 1e-6)
    return iou
|
|
|
def multi_iou(boxes1, boxes2):
    lt = torch.max(boxes1[..., :2], boxes2[..., :2])
    rb = torch.min(boxes1[..., 2:], boxes2[..., 2:])
    wh = (rb - lt).clamp(min=0)
    wh_1 = boxes1[..., 2:] - boxes1[..., :2]
    wh_2 = boxes2[..., 2:] - boxes2[..., :2]
    inter = wh[..., 0] * wh[..., 1]
    union = wh_1[..., 0] * wh_1[..., 1] + wh_2[..., 0] * wh_2[..., 1] - inter
    iou = (inter + 1e-6) / (union + 1e-6)
    return iou
|
|
|
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)
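
# Worked example (not executed here): box_cxcywh_to_xyxy and box_xyxy_to_cxcywh
# are exact inverses of each other.
#
#     b = torch.tensor([0.5, 0.5, 0.2, 0.4])        # (cx, cy, w, h)
#     box_cxcywh_to_xyxy(b)                          # -> [0.4, 0.3, 0.6, 0.7]
#     box_xyxy_to_cxcywh(box_cxcywh_to_xyxy(b))      # -> [0.5, 0.5, 0.2, 0.4]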
|
|
|
|
|
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N, M, 2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N, M, 2]

    wh = (rb - lt).clamp(min=0)  # [N, M, 2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N, M]

    union = area1[:, None] + area2 - inter

    iou = (inter + 1e-6) / (union + 1e-6)
    return iou, union
|
|
|
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results, so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N, M, 2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - ((area - union) + 1e-6) / (area + 1e-6)
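
# Worked example (not executed here), ignoring the 1e-6 terms: for the single
# pair below, intersection = 1, union = 7, IoU ≈ 0.143; the smallest enclosing
# box is [0, 0, 3, 3] with area 9, so GIoU ≈ 0.143 - 2/9 ≈ -0.079.
#
#     b1 = torch.tensor([[0., 0., 2., 2.]])
#     b2 = torch.tensor([[1., 1., 3., 3.]])
#     generalized_box_iou(b1, b2)    # -> tensor([[-0.0794]]) approximately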
|
|
|
def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
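
# Worked example (not executed here): inverse_sigmoid is the logit function,
# clamped for numerical stability, so torch.sigmoid(inverse_sigmoid(p)) ≈ p.
#
#     inverse_sigmoid(torch.tensor([0.5, 0.9]))   # -> [0.0000, 2.1972]  (log(9) ≈ 2.1972)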
|
|
|
| def masks_to_boxes(masks):
|
| """Compute the bounding boxes around the provided masks
|
|
|
| The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
|
|
| Returns a [N, 4] tensors, with the boxes in xyxy format
|
| """
|
| if masks.numel() == 0:
|
| return torch.zeros((0, 4), device=masks.device)
|
|
|
| h, w = masks.shape[-2:]
|
|
|
| y = torch.arange(0, h, dtype=torch.float)
|
| x = torch.arange(0, w, dtype=torch.float)
|
| y, x = torch.meshgrid(y, x)
|
|
|
| x_mask = (masks * x.unsqueeze(0))
|
| x_max = x_mask.flatten(1).max(-1)[0]
|
| x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
|
|
| y_mask = (masks * y.unsqueeze(0))
|
| y_max = y_mask.flatten(1).max(-1)[0]
|
| y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
|
|
| return torch.stack([x_min, y_min, x_max, y_max], 1)
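
# Usage sketch (not executed here): a single 10x10 mask that is True on rows
# 2..3 and columns 5..7 yields the (inclusive) pixel-coordinate box below.
#
#     m = torch.zeros(1, 10, 10)
#     m[0, 2:4, 5:8] = 1
#     masks_to_boxes(m)    # -> tensor([[5., 2., 7., 3.]])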
|
|
|
|
|
def clean_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            # strip the 'module.' prefix added by (Distributed)DataParallel
            k = k[7:]
        new_state_dict[k] = v
    return new_state_dict
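
# Usage sketch (not executed here): loading a checkpoint that was saved from a
# DistributedDataParallel-wrapped model into a plain (unwrapped) model. The
# path "checkpoint.pth", the "model" key and `model` itself are hypothetical.
#
#     state = torch.load("checkpoint.pth", map_location="cpu")
#     model.load_state_dict(clean_state_dict(state["model"]))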
|
|
|
def get_batch_observer(cur_epoch, total_epochs, batch_num, disable=False):
    batch_ob = Progress(
        TextColumn("[bold cyan]Epoch:{task.fields[epoch]}/{task.fields[total_epoch]}"),
        BarColumn(bar_width=40),
        "{task.completed}/{task.total}",
        "•", TimeRemainingColumn(),
        "•", TextColumn("[bold red]loss={task.fields[loss]:.4f}"),
        "•", TextColumn("[bold deep_sky_blue1]cls={task.fields[cls]:.4f}"),
        "•", TextColumn("[bold magenta]bbox={task.fields[bbox]:.4f}"),
        "•", TextColumn("[bold magenta]giou={task.fields[giou]:.4f}"),
        "•", TextColumn("[bold gold1]mask={task.fields[mask]:.4f}"),
        "•", TextColumn("[bold gold1]dice={task.fields[dice]:.4f}"),
        "•", TextColumn("[bold gold1]proj={task.fields[proj]:.4f}"),
        disable=disable,
    )
    pg = batch_ob.add_task(description="Training Observer", total=batch_num, epoch=cur_epoch, total_epoch=total_epochs,
                           loss=0, cls=0, bbox=0, giou=0, mask=0, dice=0, proj=0)
    return batch_ob, pg
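
# Usage sketch (not executed here): a rich progress bar updated once per batch.
# `loader` and the loss values passed to update() are illustrative only.
#
#     observer, task_id = get_batch_observer(cur_epoch=1, total_epochs=50, batch_num=len(loader))
#     with observer:
#         for batch in loader:
#             observer.update(task_id, advance=1, loss=0.9, cls=0.3, bbox=0.2,
#                             giou=0.2, mask=0.1, dice=0.1, proj=0.0)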
|
|
|
def colormap(rgb=False):
    cmap = plt.cm.Set1
    num_colors = 10
    color_list = cmap(np.linspace(0, 1, num_colors))[:, :3] * 255
    if not rgb:
        color_list = color_list[:, ::-1]
    return color_list