Spaces:
Runtime error
Runtime error
| import re | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| import collections | |
| import logging | |
def get_area(pos):
    """
    Compute the area of every box in a batch.

    Args
        pos: [B, N, 4]
            last axis is (x1, x2, y1, y2)
    Return
        area : [B, N]
    """
    # Pull the four coordinate planes out of the last axis; each is [B, N].
    x1, x2, y1, y2 = (pos[:, :, idx] for idx in range(4))
    return (y2 - y1) * (x2 - x1)
def get_relative_distance(pos):
    """
    Pairwise coordinate differences between all boxes of each batch item.

    Args
        pos: [B, N, 4]
            (x1, x2, y1, y2)
    Return
        out : [B, N, N, 4]
            out[b, i, j] = pos[b, j] - pos[b, i]
    """
    # [B, 1, N, 4]: varies along the j axis of the result.
    along_cols = pos.unsqueeze(1)
    # [B, N, 1, 4]: varies along the i axis of the result.
    along_rows = pos.unsqueeze(2)
    # Broadcasting expands both operands to [B, N, N, 4].
    return along_cols - along_rows
class LossMeter(object):
    """Running average over the most recently recorded values."""

    def __init__(self, maxlen=100):
        """Computes and stores the running average.

        Args:
            maxlen: size of the sliding window; once it is full, the oldest
                value is discarded for every new one appended.
        """
        self.vals = collections.deque([], maxlen=maxlen)

    def __len__(self):
        """Return how many values are currently held in the window."""
        return len(self.vals)

    def update(self, new_val):
        """Append ``new_val`` to the window."""
        self.vals.append(new_val)

    def val(self):
        """Return the mean of the values currently in the window.

        Raises:
            ZeroDivisionError: if nothing has been recorded yet.
        """
        return sum(self.vals) / len(self.vals)

    def __repr__(self):
        """Return the current average as a string ("nan" while empty)."""
        # repr() is used for debugging/logging and should not raise just
        # because nothing has been recorded yet.
        if not self.vals:
            return "nan"
        # `val` is a method, so it has to be *called*: str(self.val) would
        # stringify the bound method, whose repr calls back into this
        # __repr__ and recurses without end.
        return str(self.val())
def count_parameters(model):
    """Return the total number of elements across the trainable parameters.

    Only parameters yielded by ``model.parameters()`` whose ``requires_grad``
    is truthy are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def load_state_dict(state_dict_path, loc='cpu'):
    """Load a checkpoint and strip the multi-GPU ``"module."`` key prefix.

    Args:
        state_dict_path: path handed to ``torch.load``.
        loc: forwarded as ``map_location`` (default ``'cpu'``).

    Returns:
        The loaded mapping, modified in place: every key that started with
        ``"module."`` is re-inserted under the name without that prefix.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints that
    come from a trusted source.
    """
    prefix = "module."
    state_dict = torch.load(state_dict_path, map_location=loc)

    # Snapshot the affected keys first: the dict is mutated while renaming.
    wrapped_keys = [name for name in state_dict.keys() if name.startswith(prefix)]
    for wrapped in wrapped_keys:
        bare = wrapped[len(prefix):]
        state_dict[bare] = state_dict.pop(wrapped)

    return state_dict
def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.

    It needs to be invoked after the modules have been loaded so that their
    loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: a str, or an iterable of one or more str prefices to match
          (e.g. ["transformers", "torch"]). Optional.
          Default is `("",)` to match all active loggers; an empty iterable
          also matches all active loggers.
          The match is a case-sensitive `module_name.startswith(prefix)`
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(prefices, str):
        prefices = (prefices,)
    # Keep the historical behaviour of "no prefixes" meaning "every logger".
    wanted = tuple(prefices) or ("",)

    # Plain startswith() instead of a regex built from the prefixes: regex
    # metacharacters in a prefix ("torch.nn", "a(b") must be taken literally.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(wanted):
            logging.getLogger(name).setLevel(level)
def get_iou(anchors, gt_boxes):
    """
    Pairwise intersection-over-union between two sets of boxes.

    Columns 0/1 are used as the minimum corner and columns 2/3 as the maximum
    corner of each box; widths and heights use the inclusive "+ 1" convention.

    anchors: (N, 4) torch floattensor
    gt_boxes: (K, 4) torch floattensor; a single box of shape (4,) is treated
        as (1, 4)
    returns: (N, K) tensor where entry [n, k] is the overlap between
        anchors[n] and gt_boxes[k]
    """
    if tuple(gt_boxes.size()) == (4,):
        gt_boxes = gt_boxes.view(1, 4)
    num_anchors, num_gt = anchors.size(0), gt_boxes.size(0)

    def _area(box_set):
        # Inclusive-pixel area of every box in an (M, 4) tensor.
        return (box_set[:, 2] - box_set[:, 0] + 1) * (box_set[:, 3] - box_set[:, 1] + 1)

    area_gt = _area(gt_boxes).view(1, num_gt)
    area_anchor = _area(anchors).view(num_anchors, 1)

    # Lay both sets out on a common (N, K, 4) grid so they compare pairwise.
    lhs = anchors.view(num_anchors, 1, 4).expand(num_anchors, num_gt, 4)
    rhs = gt_boxes.view(1, num_gt, 4).expand(num_anchors, num_gt, 4)

    def _overlap_extent(lo, hi):
        # Length of the overlap along one axis; disjoint pairs are floored at 0.
        extent = (
            torch.min(lhs[:, :, hi], rhs[:, :, hi])
            - torch.max(lhs[:, :, lo], rhs[:, :, lo])
            + 1
        )
        return extent.masked_fill(extent < 0, 0)

    inter_w = _overlap_extent(0, 2)
    inter_h = _overlap_extent(1, 3)

    union = area_anchor + area_gt - (inter_w * inter_h)
    return inter_w * inter_h / union
def xywh_to_xyxy(boxes):
    """Convert [x y w h] box format to [x1 y1 x2 y2] format.

    The far corner is inclusive, i.e. ``x2 = x + w - 1`` and ``y2 = y + h - 1``.
    """
    near_corner = boxes[:, 0:2]
    extent = boxes[:, 2:4]
    far_corner = near_corner + extent - 1
    # Both halves keep their column axis, so this joins them side by side.
    return np.concatenate((near_corner, far_corner), axis=1)