import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torchvision
from torch import nn, Tensor

from .image_list import ImageList
from .roi_heads import paste_masks_in_image  # used by postprocess() below


def _get_shape_onnx(image: Tensor) -> Tensor:
    # Used only under ONNX tracing: keeps the spatial size as a tensor so it stays
    # dynamic in the exported graph instead of being baked in as a constant.
    from torch.onnx import operators

    return operators.shape_as_tensor(image)[-2:]


def _fake_cast_onnx(v: Tensor) -> float:
    # ONNX requires a tensor, but here we fake its type for JIT.
    return v


def _resize_image_and_masks(
    image: Tensor,
    self_min_size: int,
    self_max_size: int,
    target: Optional[Dict[str, Tensor]] = None,
    fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    if torchvision._is_tracing():
        im_shape = _get_shape_onnx(image)
    elif torch.jit.is_scripting():
        im_shape = torch.tensor(image.shape[-2:])
    else:
        im_shape = image.shape[-2:]

    size: Optional[List[int]] = None
    scale_factor: Optional[float] = None
    recompute_scale_factor: Optional[bool] = None
    if fixed_size is not None:
        size = [fixed_size[1], fixed_size[0]]
    else:
        if torch.jit.is_scripting() or torchvision._is_tracing():
            min_size = torch.min(im_shape).to(dtype=torch.float32)
            max_size = torch.max(im_shape).to(dtype=torch.float32)
            self_min_size_f = float(self_min_size)
            self_max_size_f = float(self_max_size)
            # Scale so the shorter side reaches self_min_size, unless that would push
            # the longer side past self_max_size.
            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

            if torchvision._is_tracing():
                scale_factor = _fake_cast_onnx(scale)
            else:
                scale_factor = scale.item()
        else:
            # Do it the normal way
            min_size = min(im_shape)
            max_size = max(im_shape)
            scale_factor = min(self_min_size / min_size, self_max_size / max_size)

        recompute_scale_factor = True

    image = torch.nn.functional.interpolate(
        image[None],
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        recompute_scale_factor=recompute_scale_factor,
        align_corners=False,
    )[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = torch.nn.functional.interpolate(
            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
        )[:, 0].byte()
        target["masks"] = mask
    return image, target

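# Worked example (illustrative numbers, not from the source): with self_min_size=800,
# self_max_size=1333 and a 600x1200 input, scale = min(800 / 600, 1333 / 1200) ~= 1.11,
# so the image is resized to roughly 666x1333 -- the shorter side is pushed towards
# self_min_size without letting the longer side exceed self_max_size.
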
class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs and a List[Dict[str, Tensor]] for the targets.
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        image_mean: List[float],
        image_std: List[float],
        size_divisible: int = 32,
        fixed_size: Optional[Tuple[int, int]] = None,
        **kwargs: Any,
    ):
        super().__init__()
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size
        self._skip_resize = kwargs.pop("_skip_resize", False)

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k, v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            torch._assert(
                len(image_size) == 2,
                f"Input tensors expected to have H and W as their last two elements, instead got {image_size}",
            )
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

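    # Note (added for illustration): image_mean / image_std are per-channel statistics
    # broadcast over H and W. The torchvision detection models commonly pass the ImageNet
    # values [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225]; any per-channel values
    # matching the channel count work here.
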
    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops, so that it can be compiled with
        TorchScript and uses PyTorch's RNG rather than Python's native RNG.
        """
        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
        return k[index]

    def resize(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        h, w = image.shape[-2:]
        if self.training:
            if self._skip_resize:
                return image, target
            size = self.torch_choice(self.min_size)
        else:
            size = self.min_size[-1]
        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # workaround for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

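    # Example (illustrative): max_by_axis([[3, 480, 640], [3, 512, 600]]) -> [3, 512, 640],
    # i.e. the element-wise maximum over the per-image [C, H, W] shapes.
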
    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

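    # Example (illustrative numbers): images of shape [3, 768, 1022] and [3, 799, 900]
    # give a per-axis maximum of [3, 799, 1022]; with size_divisible=32 the spatial dims
    # are rounded up to 800 and 1024, so the batch has shape [2, 3, 800, 1024] and each
    # image sits zero-padded in the top-left corner.
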
    def postprocess(
        self,
        result: List[Dict[str, Tensor]],
        image_shapes: List[Tuple[int, int]],
        original_image_sizes: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self) -> str:
        format_string = f"{self.__class__.__name__}("
        _indent = "\n    "
        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
        format_string += "\n)"
        return format_string


def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
    if torch._C._get_tracing_state():
        resized_data_0 = resized_data[:, :, 0] * ratio_w
        resized_data_1 = resized_data[:, :, 1] * ratio_h
        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
    else:
        resized_data[..., 0] *= ratio_w
        resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device)
        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)

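# Usage sketch (added for illustration, not part of the original module). The transform is
# normally constructed inside the torchvision detection models; the min_size / max_size and
# mean / std values below are the defaults those models commonly use and are assumptions here.
#
#     transform = GeneralizedRCNNTransform(
#         min_size=800,
#         max_size=1333,
#         image_mean=[0.485, 0.456, 0.406],
#         image_std=[0.229, 0.224, 0.225],
#     )
#     transform.eval()
#     images = [torch.rand(3, 600, 1200), torch.rand(3, 480, 640)]
#     image_list, _ = transform(images)
#     image_list.tensors.shape   # padded batch, approx. torch.Size([2, 3, 800, 1344])
#     image_list.image_sizes     # per-image (H, W) after resizing, before padding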