# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import os
import subprocess
from typing import Any, Dict, List, Optional

import torch
import torchvision
from torch import Tensor
def get_sha():
    """Return the current git commit hash, working-tree status, and branch as a single string."""
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
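

# Illustrative sketch (not part of the original file): get_sha() is typically called once
# at startup so the exact code revision ends up in the training logs.
def _log_git_state_example():
    """Print the current git state, e.g. 'sha: <hash>, status: clean, branch: main'."""
    print("git:\n  {}\n".format(get_sha()))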
def collate_fn(do_round, batch):
    """Collate a list of (image, target) pairs into a batch dict with a padded NestedTensor of samples."""
    batch = list(zip(*batch))
    final_batch = {}
    final_batch["samples"] = NestedTensor.from_tensor_list(batch[0], do_round)
    final_batch["targets"] = batch[1]
    if "positive_map" in batch[1][0]:
        # We batch the positive maps here.
        # Since in general each batch element will have a different number of boxes,
        # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes.
        max_len = max([v["positive_map"].shape[1] for v in batch[1]])
        nb_boxes = sum([v["positive_map"].shape[0] for v in batch[1]])
        batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool)
        cur_count = 0
        for v in batch[1]:
            cur_pos = v["positive_map"]
            batched_pos_map[cur_count : cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos
            cur_count += len(cur_pos)

        assert cur_count == len(batched_pos_map)
        # assert batched_pos_map.sum().item() == sum([v["positive_map"].sum().item() for v in batch[1]])
        final_batch["positive_map"] = batched_pos_map.float()

    if "positive_map_eval" in batch[1][0]:
        # We batch the eval-time positive maps the same way.
        # Since in general each batch element will have a different number of boxes,
        # we collapse a single batch dimension to avoid padding. This is sufficient for our purposes.
        max_len = max([v["positive_map_eval"].shape[1] for v in batch[1]])
        nb_boxes = sum([v["positive_map_eval"].shape[0] for v in batch[1]])
        batched_pos_map = torch.zeros((nb_boxes, max_len), dtype=torch.bool)
        cur_count = 0
        for v in batch[1]:
            cur_pos = v["positive_map_eval"]
            batched_pos_map[cur_count : cur_count + len(cur_pos), : cur_pos.shape[1]] = cur_pos
            cur_count += len(cur_pos)

        assert cur_count == len(batched_pos_map)
        # assert batched_pos_map.sum().item() == sum([v["positive_map_eval"].sum().item() for v in batch[1]])
        final_batch["positive_map_eval"] = batched_pos_map.float()

    if "answer" in batch[1][0] or "answer_type" in batch[1][0]:
        answers = {}
        for f in batch[1][0].keys():
            if "answer" not in f:
                continue
            answers[f] = torch.stack([b[f] for b in batch[1]])
        final_batch["answers"] = answers

    return final_batch
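

# Illustrative sketch (not part of the original file): collate_fn takes `do_round` as its
# first argument, so it is usually bound with functools.partial before being handed to a
# DataLoader. `my_dataset` is a placeholder for any dataset yielding (image_tensor, target_dict) pairs.
def _build_loader_example(my_dataset, batch_size=2, do_round=True):
    import functools

    from torch.utils.data import DataLoader

    loader = DataLoader(
        my_dataset,
        batch_size=batch_size,
        collate_fn=functools.partial(collate_fn, do_round),
    )
    return loader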
class NestedTensor(object):
    """A batch of tensors padded to a common size, together with a mask marking the padded pixels."""

    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask

    def to(self, *args, **kwargs):
        cast_tensor = self.tensors.to(*args, **kwargs)
        cast_mask = self.mask.to(*args, **kwargs) if self.mask is not None else None
        return type(self)(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    @classmethod
    def from_tensor_list(cls, tensor_list, do_round=False):
        # TODO make this more general
        if tensor_list[0].ndim == 3:
            # TODO make it support different-sized images
            max_size = tuple(max(s) for s in zip(*[img.shape for img in tensor_list]))
            # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
            batch_shape = (len(tensor_list),) + max_size
            b, c, h, w = batch_shape
            if do_round:
                # Round up to a multiple of p to avoid rounding issues in the FPN
                p = 128
                h = h if h % p == 0 else (h // p + 1) * p
                w = w if w % p == 0 else (w // p + 1) * p
                batch_shape = b, c, h, w
            dtype = tensor_list[0].dtype
            device = tensor_list[0].device
            tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
            mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
            for img, pad_img, m in zip(tensor_list, tensor, mask):
                pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
                m[: img.shape[1], : img.shape[2]] = False
        else:
            raise ValueError("not supported")
        return cls(tensor, mask)

    def __repr__(self):
        return repr(self.tensors)
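

# Illustrative sketch (not part of the original file): NestedTensor pads a list of
# differently-sized C x H x W images up to a common size and keeps a boolean mask of the
# padded pixels (True = padding).
def _nested_tensor_example():
    imgs = [torch.rand(3, 200, 300), torch.rand(3, 250, 180)]
    nt = NestedTensor.from_tensor_list(imgs, do_round=False)
    padded, mask = nt.decompose()
    # padded: (2, 3, 250, 300); mask: (2, 250, 300), True where a pixel is padding
    return padded.shape, mask.shape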
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty channel sizes.
    """
    if input.numel() > 0:
        return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)

    assert input.shape[0] != 0 or input.shape[1] != 0, "At least one of the first two dimensions must be non-zero"

    if input.shape[1] == 0:
        # Pytorch doesn't support a null channel dimension, so we transpose to fake a null batch dim
        return torch.nn.functional.interpolate(input.transpose(0, 1), size, scale_factor, mode, align_corners).transpose(0, 1)

    # empty batch dimension is now supported in pytorch
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
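

# Illustrative sketch (not part of the original file): unlike a direct call to
# nn.functional.interpolate, the wrapper above also accepts an input whose channel
# dimension is empty (assuming a recent PyTorch that handles zero-sized batches).
def _interpolate_empty_example():
    empty = torch.zeros(2, 0, 16, 16)  # batch of 2, zero channels
    out = interpolate(empty, size=(32, 32), mode="nearest")
    # out has shape (2, 0, 32, 32)
    return out.shape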
def targets_to(targets: List[Dict[str, Any]], device):
    """Moves the target dicts to the given device."""
    excluded_keys = [
        "questionId",
        "tokens_positive",
        "tokens",
        "dataset_name",
        "sentence_id",
        "original_img_id",
        "nb_eval",
        "task_id",
        "original_id",
    ]
    return [{k: v.to(device) if k not in excluded_keys else v for k, v in t.items() if k != "caption"} for t in targets]
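

# Illustrative sketch (not part of the original file): targets_to is typically called right
# after fetching a batch, moving tensor-valued target fields to the compute device while
# string/metadata fields (and the raw caption) stay on the CPU. `batch` is assumed to be
# the dict produced by collate_fn above.
def _move_batch_example(batch, device):
    samples = batch["samples"].to(device)
    targets = targets_to(batch["targets"], device)
    return samples, targets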