Spaces:
No application file
No application file
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
import importlib.metadata
import re

from torch import Tensor


def _torchvision_version() -> tuple:
    """Return the installed torchvision version as an ``(major, minor, patch)`` int tuple.

    Any suffix after the numeric ``X.Y[.Z]`` prefix (e.g. ``"+cu118"``) is
    ignored, and a missing patch component is treated as 0.

    Raises:
        RuntimeError: if the version string does not start with ``X.Y``.
    """
    raw = importlib.metadata.version("torchvision")
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", raw)
    if match is None:
        raise RuntimeError(f"Cannot parse torchvision version {raw!r}")
    return tuple(int(part or 0) for part in match.groups())


# Compare versions numerically. Plain string comparison misorders versions
# (lexicographically "0.9.0" >= "0.17" is True), which would route very old
# torchvision releases into the >= 0.17 branch below.
_TV_VERSION = _torchvision_version()

if (0, 15, 2) <= _TV_VERSION < (0, 16):
    # 0.15.x: transforms v2 still lives under `datapoints`; the box size
    # keyword is `spatial_size`.
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes

    _boxes_keys = ["format", "spatial_size"]
elif (0, 16) <= _TV_VERSION < (0, 17):
    # 0.16.x: `tv_tensors` API; the box size keyword is `canvas_size`.
    import torchvision

    torchvision.disable_beta_transforms_warning()
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]
elif _TV_VERSION >= (0, 17):
    # >= 0.17: same imports as 0.16, without the beta-warning call.
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]
else:
    raise RuntimeError("Please make sure torchvision version >= 0.15.2")
def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor:
    """Wrap a plain tensor in the torchvision tv_tensor type selected by ``key``.

    Args:
        tensor (Tensor): input tensor to wrap.
        key (str): target kind, either ``"boxes"`` or ``"masks"``.
        box_format (str): bounding-box format name, upper-cased and looked up
            as an attribute of ``BoundingBoxFormat``. Only used for ``"boxes"``.
        spatial_size: image size forwarded to ``BoundingBoxes`` under the
            keyword named in ``_boxes_keys`` (``spatial_size`` on torchvision
            0.15, ``canvas_size`` on >= 0.16). Only used for ``"boxes"``.

    Returns:
        A ``BoundingBoxes`` for ``"boxes"``, or a ``Mask`` for ``"masks"``.

    Raises:
        AssertionError: if ``key`` is neither ``"boxes"`` nor ``"masks"``.
    """
    if key == "boxes":
        box_format = getattr(BoundingBoxFormat, box_format.upper())
        # `_boxes_keys` holds the constructor keyword names for the installed
        # torchvision version, so this builds e.g. {"format": ..., "canvas_size": ...}.
        _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
        return BoundingBoxes(tensor, **_kwargs)
    if key == "masks":
        return Mask(tensor)
    # Explicit raise instead of `assert`, so the check still runs under
    # `python -O` (which would otherwise silently return None). AssertionError
    # is kept so existing callers see the same exception type as before.
    raise AssertionError("Only support 'boxes' and 'masks'")