Spaces:
Running on Zero
Running on Zero
"""by lyuwenyu
"""
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| torchvision.disable_beta_transforms_warning() | |
| try: | |
| from torchvision import datapoints as _datapoints | |
| _HAS_DATAPOINTS = True | |
| except Exception: | |
| from torchvision import tv_tensors as _datapoints | |
| _HAS_DATAPOINTS = False | |
| import torchvision.transforms.v2 as T | |
| import torchvision.transforms.v2.functional as F | |
| from PIL import Image | |
| from typing import Any, Dict, List, Optional | |
| from src.core import register, GLOBAL_CONFIG | |
# Names exported by `from <this module> import *`.
__all__ = ['Compose']

# Pass stock torchvision v2 transforms through the project's `register`
# helper and re-export the returned objects under the same names.
(
    RandomPhotometricDistort,
    RandomZoomOut,
    RandomHorizontalFlip,
    Resize,
) = map(
    register,
    (
        T.RandomPhotometricDistort,
        T.RandomZoomOut,
        T.RandomHorizontalFlip,
        T.Resize,
    ),
)
# T.RandomIoUCrop is not registered directly here; this module defines its
# own RandomIoUCrop subclass further down.
# RandomIoUCrop = register(T.RandomIoUCrop)
# Provide a `ToImageTensor` name regardless of which torchvision v2 API is
# installed: use the native class when it exists, otherwise subclass the
# first available stand-in.
if not hasattr(T, 'ToImageTensor'):
    # Candidates are probed lazily, in order of preference.
    _BaseToImageTensor = next(
        filter(
            None,
            (getattr(T, alt_name, None) for alt_name in ('ToImage', 'ToTensor', 'PILToTensor')),
        ),
        None,
    )
    if _BaseToImageTensor is None:
        raise AttributeError(
            'torchvision.transforms.v2 is missing ToImageTensor/ToImage/ToTensor/PILToTensor; please update torchvision.'
        )

    # NOTE(review): unlike the native branch below, this fallback class is
    # not passed through `register` in the visible code -- confirm whether
    # configs that reference 'ToImageTensor' by name rely on it.
    class ToImageTensor(_BaseToImageTensor):
        """Alias of the selected torchvision transform under the legacy name."""
else:
    ToImageTensor = register(T.ToImageTensor)
# Provide a `ConvertDtype` name regardless of which torchvision v2 API is
# installed: use the native class when it exists, otherwise wrap `T.ToDtype`.
if not hasattr(T, 'ConvertDtype'):
    _BaseConvertDtype = getattr(T, 'ToDtype', None)
    if _BaseConvertDtype is None:
        raise AttributeError('torchvision.transforms.v2 is missing ConvertDtype/ToDtype; please update torchvision.')

    # NOTE(review): this fallback class is not passed through `register` in
    # the visible code, whereas the native branch below is -- confirm.
    class ConvertDtype(_BaseConvertDtype):
        """`T.ToDtype` exposed under the legacy name, with `scale` on by default.

        Args:
            dtype: Target dtype forwarded to the base transform.
            scale: Forwarded to the base transform's `scale` argument.
        """

        def __init__(self, dtype: torch.dtype = torch.float32, scale: bool = True) -> None:
            super().__init__(dtype=dtype, scale=scale)
else:
    ConvertDtype = register(T.ConvertDtype)
# Resolve the sanitizer base class under either of its torchvision names,
# preferring the singular spelling when both exist.
_BaseSanitizeBoundingBox = getattr(T, 'SanitizeBoundingBox', None)
if _BaseSanitizeBoundingBox is None:
    _BaseSanitizeBoundingBox = getattr(T, 'SanitizeBoundingBoxes', None)
if _BaseSanitizeBoundingBox is None:
    raise AttributeError(
        'torchvision.transforms.v2 is missing SanitizeBoundingBox/SanitizeBoundingBoxes; please update torchvision.'
    )
class SanitizeBoundingBox(_BaseSanitizeBoundingBox):
    """Bounding-box sanitizer that keeps a target's "t_gt" entry out of the base transform.

    When the second positional input is a dict containing "t_gt", that entry
    is withheld while the base class runs and re-attached to the returned
    target afterwards. All other calls are forwarded unchanged.
    """

    def forward(self, *inputs):
        carries_t_gt = (
            len(inputs) >= 2
            and isinstance(inputs[1], dict)
            and "t_gt" in inputs[1]
        )
        if not carries_t_gt:
            return super().forward(*inputs)

        # Avoid indexing t_gt (full-image mask) with per-box valid mask.
        # The target is rebuilt as a new dict so the caller's dict is untouched.
        original_target = inputs[1]
        held_back = original_target["t_gt"]
        stripped_target = {
            key: value for key, value in original_target.items() if key != "t_gt"
        }
        result = super().forward(inputs[0], stripped_target, *inputs[2:])

        target_returned = (
            isinstance(result, tuple)
            and len(result) >= 2
            and isinstance(result[1], dict)
        )
        if target_returned:
            # Re-attach the withheld entry to the sanitized target.
            result[1]["t_gt"] = held_back
            return tuple(result)
        # Unexpected output shape: hand it back as-is (without "t_gt").
        return result
# Remaining stock transforms exposed through the project registry.
RandomCrop, Normalize = map(register, (T.RandomCrop, T.Normalize))

# Short aliases for the datapoint / tv_tensor classes used below.
_Image, _Video, _Mask = _datapoints.Image, _datapoints.Video, _datapoints.Mask
_BBoxFormat = _datapoints.BoundingBoxFormat
# The bounding-box class is spelled differently depending on which
# namespace was imported at the top of the file.
_BBoxType = getattr(_datapoints, 'BoundingBox' if _HAS_DATAPOINTS else 'BoundingBoxes')
def _make_bounding_box(data, format, spatial_size):
    """Build a bounding-box datapoint using whichever torchvision API was imported.

    Args:
        data: Box data handed to the bounding-box constructor.
        format: Box format passed through as the constructor's `format`.
        spatial_size: Size passed as `spatial_size` (datapoints API) or
            `canvas_size` (tv_tensors API).

    Returns:
        A `BoundingBox` or `BoundingBoxes` instance, depending on `_HAS_DATAPOINTS`.
    """
    if not _HAS_DATAPOINTS:
        return _datapoints.BoundingBoxes(data, format=format, canvas_size=spatial_size)
    return _datapoints.BoundingBox(data, format=format, spatial_size=spatial_size)
def _bbox_spatial_size(bbox):
    """Return the size attribute of a bounding-box object.

    `spatial_size` is preferred; `canvas_size` is used when the former is absent.

    Raises:
        AttributeError: If the object exposes neither attribute.
    """
    for attr_name in ("spatial_size", "canvas_size"):
        if hasattr(bbox, attr_name):
            return getattr(bbox, attr_name)
    raise AttributeError("Bounding box object has neither spatial_size nor canvas_size.")
class Compose(T.Compose):
    """Transform pipeline built from config dicts and/or ready-made modules.

    Args:
        ops: Iterable of pipeline entries, or ``None``. Each entry is either

            * a ``dict`` with a ``'type'`` key naming an entry of
              ``GLOBAL_CONFIG``; the remaining keys are passed to that
              transform's constructor as keyword arguments, or
            * an ``nn.Module`` instance, used as-is.

            ``None`` yields a pipeline holding a single ``EmptyTransform``.

    Raises:
        KeyError: If a dict entry has no ``'type'`` key or names a transform
            that is not present in ``GLOBAL_CONFIG``.
        ValueError: If an entry is neither a ``dict`` nor an ``nn.Module``.
    """

    def __init__(self, ops) -> None:
        transforms = []
        if ops is None:
            transforms.append(EmptyTransform())
        else:
            for op in ops:
                if isinstance(op, dict):
                    # Work on a copy: popping 'type' from the caller's dict
                    # would corrupt the config, so a second pipeline built
                    # from the same config object would fail.
                    kwargs = dict(op)
                    name = kwargs.pop('type')
                    transform_cls = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)
                    transforms.append(transform_cls(**kwargs))
                elif isinstance(op, nn.Module):
                    transforms.append(op)
                else:
                    raise ValueError(
                        f'Compose ops must be dicts or nn.Module instances, got {type(op).__name__}'
                    )

        super().__init__(transforms=transforms)
class EmptyTransform(T.Transform):
    """Pass-through transform that returns its inputs unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *inputs):
        # Several inputs come back as the argument tuple; a single input is
        # returned unwrapped.
        if len(inputs) > 1:
            return inputs
        return inputs[0]
class PadToSize(T.Pad):
    """Pad inputs on the right and bottom up to a fixed spatial size.

    The amount of padding is derived per call from the size of the first
    flattened input, and is recorded in ``self.padding``. When the transform
    is called with a sequence whose second item is a dict, that dict receives
    a ``'padding'`` tensor holding the values that were applied.

    Args:
        spatial_size: Target size. An ``int`` is expanded to a square
            ``(spatial_size, spatial_size)``; otherwise index 0 is compared
            against the first component of the input size and index 1
            against the second.
        fill: Fill value forwarded to ``T.Pad``.
        padding_mode: Padding mode forwarded to ``T.Pad``.
    """

    _transformed_types = (
        Image.Image,
        _Image,
        _Video,
        _Mask,
        _BBoxType,
    )

    def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
        if isinstance(spatial_size, int):
            spatial_size = (spatial_size, spatial_size)
        self.spatial_size = spatial_size
        # The real padding is computed per sample; 0 is only a placeholder.
        super().__init__(0, fill, padding_mode)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Compute the padding that brings the first input up to ``spatial_size``."""
        # The size helper is looked up under both names, matching the
        # old/new torchvision handling used elsewhere in this file.
        size_of = getattr(F, 'get_size', None) or F.get_spatial_size
        current = size_of(flat_inputs[0])
        pad_dim0 = self.spatial_size[0] - current[0]
        pad_dim1 = self.spatial_size[1] - current[1]
        # Only the last two slots are non-zero, i.e. padding goes on the
        # right/bottom per torchvision's (left, top, right, bottom) order.
        # NOTE(review): inputs larger than spatial_size yield negative values.
        self.padding = [0, 0, pad_dim1, pad_dim0]
        return dict(padding=self.padding)

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Alias of ``_get_params`` under the alternative hook name."""
        return self._get_params(flat_inputs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Pad a single input with the padding computed for this call."""
        try:
            fill = self._fill[type(inpt)]
        except KeyError:
            # NOTE(review): newer torchvision appears to keep a plain dict
            # with an 'others' catch-all instead of a per-type default table;
            # fall back to it rather than failing -- verify against the
            # installed version.
            fill = self._fill.get('others')
        padding = params['padding']
        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Alias of ``_transform`` under the alternative hook name."""
        return self._transform(inpt, params)

    def __call__(self, *inputs: Any) -> Any:
        """Run the padding and record it on the target dict when one is present.

        ``forward`` is invoked directly, so ``nn.Module`` call hooks do not run.
        """
        outputs = super().forward(*inputs)
        # Only sequence outputs can carry a target dict at index 1; guarding
        # on the type avoids len()/indexing errors for a lone image or dict.
        if (
            isinstance(outputs, (tuple, list))
            and len(outputs) > 1
            and isinstance(outputs[1], dict)
        ):
            outputs[1]['padding'] = torch.tensor(self.padding)
        return outputs
class RandomIoUCrop(T.RandomIoUCrop):
    """`T.RandomIoUCrop` that is only applied with probability `p`.

    Args:
        min_scale, max_scale, min_aspect_ratio, max_aspect_ratio,
        sampler_options, trials: Forwarded unchanged to `T.RandomIoUCrop`.
        p: Threshold compared against one uniform random draw per call; the
            crop runs when the draw is below `p`.
    """

    def __init__(self, min_scale: float = 0.3, max_scale: float = 1, min_aspect_ratio: float = 0.5, max_aspect_ratio: float = 2, sampler_options: Optional[List[float]] = None, trials: int = 40, p: float = 1.0):
        super().__init__(
            min_scale=min_scale,
            max_scale=max_scale,
            min_aspect_ratio=min_aspect_ratio,
            max_aspect_ratio=max_aspect_ratio,
            sampler_options=sampler_options,
            trials=trials,
        )
        self.p = p

    def __call__(self, *inputs: Any) -> Any:
        skip_crop = torch.rand(1) >= self.p
        if not skip_crop:
            return super().forward(*inputs)
        # Skipped: return the inputs untouched, unwrapping a lone argument.
        if len(inputs) > 1:
            return inputs
        return inputs[0]
class ConvertBox(T.Transform):
    """Convert bounding boxes to another format and optionally normalize them.

    Args:
        out_fmt: Target format key ('xyxy' or 'cxcywh'); an empty string
            leaves the format unchanged.
        normalize: When true, divide the boxes by their spatial size.
    """

    _transformed_types = (_BBoxType,)

    def __init__(self, out_fmt='', normalize=False) -> None:
        super().__init__()
        self.out_fmt = out_fmt
        self.normalize = normalize
        # Maps the string format key to the bounding-box format enum.
        self.data_fmt = dict(xyxy=_BBoxFormat.XYXY, cxcywh=_BBoxFormat.CXCYWH)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        boxes = inpt

        if self.out_fmt:
            size = _bbox_spatial_size(boxes)
            source_fmt = boxes.format.value.lower()
            converted = torchvision.ops.box_convert(boxes, in_fmt=source_fmt, out_fmt=self.out_fmt)
            # Re-wrap the result so it carries the new format and the same size.
            boxes = _make_bounding_box(converted, format=self.data_fmt[self.out_fmt], spatial_size=size)

        if self.normalize:
            size = _bbox_spatial_size(boxes)
            # Reverse the size pair and repeat it twice to get one divisor per
            # box coordinate (presumably (H, W) -> (W, H, W, H) -- confirm).
            divisor = torch.tensor(size[::-1]).tile(2)[None]
            boxes = boxes / divisor

        return boxes

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Alias of `_transform` under the alternative hook name."""
        return self._transform(inpt, params)