Spaces:
Running on Zero
Running on Zero
File size: 7,594 Bytes
"""by lyuwenyu
"""
import torch
import torch.nn as nn
import torchvision
torchvision.disable_beta_transforms_warning()
try:
from torchvision import datapoints as _datapoints
_HAS_DATAPOINTS = True
except Exception:
from torchvision import tv_tensors as _datapoints
_HAS_DATAPOINTS = False
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
from PIL import Image
from typing import Any, Dict, List, Optional
from src.core import register, GLOBAL_CONFIG
# Names exported by `from <this module> import *`; only `Compose` is listed.
__all__ = ['Compose', ]
# Pass stock torchvision v2 transforms through the project's `register` helper
# and bind the result under the same name in this module.
# NOTE(review): presumably registration is what lets `Compose` resolve these
# classes by their 'type' name through GLOBAL_CONFIG -- confirm in src.core.
RandomPhotometricDistort = register(T.RandomPhotometricDistort)
RandomZoomOut = register(T.RandomZoomOut)
# The stock RandomIoUCrop is not registered directly; this file defines and
# registers its own RandomIoUCrop subclass instead.
# RandomIoUCrop = register(T.RandomIoUCrop)
RandomHorizontalFlip = register(T.RandomHorizontalFlip)
Resize = register(T.Resize)
# Expose a registered `ToImageTensor` whether or not this torchvision build
# still ships a transform with that exact name.
if not hasattr(T, 'ToImageTensor'):
    # No transform of that name: use the first available stand-in as a base.
    _BaseToImageTensor = (
        getattr(T, 'ToImage', None)
        or getattr(T, 'ToTensor', None)
        or getattr(T, 'PILToTensor', None)
    )
    if _BaseToImageTensor is None:
        raise AttributeError(
            'torchvision.transforms.v2 is missing ToImageTensor/ToImage/ToTensor/PILToTensor; please update torchvision.'
        )

    @register
    class ToImageTensor(_BaseToImageTensor):
        """Stand-in that behaves exactly like its base, under the name `ToImageTensor`."""
else:
    # The transform exists under its own name: register it as-is.
    ToImageTensor = register(T.ToImageTensor)
# Expose a registered `ConvertDtype` whether or not this torchvision build
# still ships a transform with that exact name.
if not hasattr(T, 'ConvertDtype'):
    # No transform of that name: build one on top of `ToDtype`.
    _BaseConvertDtype = getattr(T, 'ToDtype', None)
    if _BaseConvertDtype is None:
        raise AttributeError('torchvision.transforms.v2 is missing ConvertDtype/ToDtype; please update torchvision.')

    @register
    class ConvertDtype(_BaseConvertDtype):
        """`ToDtype` under the name `ConvertDtype`, with `scale` defaulting to True.

        NOTE(review): the `scale=True` default presumably mirrors the value
        rescaling that the older `ConvertDtype` performed -- confirm.
        """

        def __init__(self, dtype: torch.dtype = torch.float32, scale: bool = True) -> None:
            super().__init__(dtype=dtype, scale=scale)
else:
    # The transform exists under its own name: register it as-is.
    ConvertDtype = register(T.ConvertDtype)
# Pick whichever spelling of the sanitizing transform this torchvision has.
if not hasattr(T, 'SanitizeBoundingBox'):
    _BaseSanitizeBoundingBox = getattr(T, 'SanitizeBoundingBoxes', None)
    if _BaseSanitizeBoundingBox is None:
        raise AttributeError(
            'torchvision.transforms.v2 is missing SanitizeBoundingBox/SanitizeBoundingBoxes; please update torchvision.'
        )
else:
    _BaseSanitizeBoundingBox = T.SanitizeBoundingBox


@register
class SanitizeBoundingBox(_BaseSanitizeBoundingBox):
    """Box sanitizer that leaves a target's ``"t_gt"`` entry untouched.

    ``"t_gt"`` is a full-image mask, so it must not be indexed with the
    per-box validity mask that the base transform applies to the target.
    """

    def forward(self, *inputs):
        """Run the base transform with ``"t_gt"`` temporarily removed.

        When the second positional input is a dict holding ``"t_gt"``, that
        entry is set aside, the base ``forward`` runs on a copy of the dict,
        and the entry is put back into the returned target. Any other call is
        forwarded unchanged.
        """
        carries_t_gt = len(inputs) >= 2 and isinstance(inputs[1], dict) and "t_gt" in inputs[1]
        if not carries_t_gt:
            return super().forward(*inputs)

        # Work on copies so the caller's tuple and target dict stay intact.
        args = list(inputs)
        target = dict(args[1])
        held_back = target.pop("t_gt")
        args[1] = target

        result = super().forward(*args)
        restorable = isinstance(result, tuple) and len(result) >= 2 and isinstance(result[1], dict)
        if not restorable:
            return result

        result = list(result)
        result[1]["t_gt"] = held_back
        return tuple(result)
# Two more stock torchvision v2 transforms passed through `register` and bound
# under their own names in this module.
RandomCrop = register(T.RandomCrop)
Normalize = register(T.Normalize)
# Short aliases for the tensor wrapper types of whichever namespace was
# imported as `_datapoints`.
_Image = _datapoints.Image
_Video = _datapoints.Video
_Mask = _datapoints.Mask
_BBoxFormat = _datapoints.BoundingBoxFormat
# Choose the box class by checking which name the namespace actually exposes.
# Relying on which import succeeded breaks if a `datapoints` namespace imports
# fine but only provides the plural `BoundingBoxes`.
_BBoxType = (
    _datapoints.BoundingBox
    if hasattr(_datapoints, 'BoundingBox')
    else _datapoints.BoundingBoxes
)
def _make_bounding_box(data, format, spatial_size):
    """Wrap ``data`` in the bounding-box type of the installed torchvision.

    Args:
        data: Box coordinates accepted by the box class constructor.
        format: Box format passed through as the ``format`` keyword.
        spatial_size: Size of the image the boxes belong to. It is passed as
            ``spatial_size`` to the singular ``BoundingBox`` class and as
            ``canvas_size`` to the plural ``BoundingBoxes`` class.

    Returns:
        A ``BoundingBox`` or ``BoundingBoxes`` instance.
    """
    # Detect the class by name instead of by which import succeeded, so a
    # namespace that imports as `datapoints` but only has `BoundingBoxes`
    # is still handled.
    singular_cls = getattr(_datapoints, 'BoundingBox', None)
    if singular_cls is not None:
        return singular_cls(data, format=format, spatial_size=spatial_size)
    return _datapoints.BoundingBoxes(data, format=format, canvas_size=spatial_size)
def _bbox_spatial_size(bbox):
    """Return the image size stored on a bounding-box object.

    ``spatial_size`` is preferred; ``canvas_size`` is used when the former is
    absent.

    Raises:
        AttributeError: If ``bbox`` has neither attribute.
    """
    for attr_name in ("spatial_size", "canvas_size"):
        if hasattr(bbox, attr_name):
            return getattr(bbox, attr_name)
    raise AttributeError("Bounding box object has neither spatial_size nor canvas_size.")
@register
class Compose(T.Compose):
    """``T.Compose`` that can be built from a list of config entries.

    Args:
        ops: Iterable whose items are either ``nn.Module`` instances, used
            as-is, or dicts with a ``'type'`` key naming a registered class.
            The remaining dict items are passed to that class as keyword
            arguments. The class is looked up as an attribute of
            ``GLOBAL_CONFIG[name]['_pymodule']``. ``None`` or an empty
            iterable yields a pipeline holding a single ``EmptyTransform``.

    Raises:
        KeyError: If a dict entry has no ``'type'`` key.
        ValueError: If an entry is neither a dict nor an ``nn.Module``.
    """

    def __init__(self, ops) -> None:
        transforms = []
        for op in (ops if ops is not None else ()):
            if isinstance(op, dict):
                # Pop from a copy so the caller's config keeps its 'type' key
                # and the same config can build another pipeline later.
                kwargs = dict(op)
                name = kwargs.pop('type')
                transform = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)(**kwargs)
                transforms.append(transform)
            elif isinstance(op, nn.Module):
                transforms.append(op)
            else:
                raise ValueError(
                    f'Unsupported transform spec of type {type(op).__name__!r}; '
                    'expected a config dict with a "type" key or an nn.Module.'
                )

        # With nothing configured, fall back to a pass-through pipeline
        # instead of handing the base class an empty list.
        if not transforms:
            transforms = [EmptyTransform()]

        super().__init__(transforms=transforms)
@register
class EmptyTransform(T.Transform):
    """Pass-through transform that returns its inputs unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *inputs):
        """Return the tuple of inputs, or the bare item when only one was given."""
        if len(inputs) > 1:
            return inputs
        return inputs[0]
@register
class PadToSize(T.Pad):
    """Pad a sample up to a fixed ``spatial_size``.

    The padding is computed per call as ``[0, 0, w, h]``, where ``h`` and
    ``w`` are the differences between ``spatial_size`` and the size of the
    first flattened input. In ``F.pad``'s (left, top, right, bottom) order
    this pads the right and bottom edges only. When called with a sample
    whose second element is a dict, the applied padding is also stored there
    under ``'padding'``.

    Args:
        spatial_size: Target ``(height, width)``; an int means a square.
        fill: Fill value forwarded to ``T.Pad``.
        padding_mode: Padding mode forwarded to ``T.Pad``.
    """

    _transformed_types = (
        Image.Image,
        _Image,
        _Video,
        _Mask,
        _BBoxType,
    )

    def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
        if isinstance(spatial_size, int):
            spatial_size = (spatial_size, spatial_size)
        # The real padding is computed per sample, so the base class starts at 0.
        super().__init__(0, fill, padding_mode)
        self.spatial_size = spatial_size

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Compute the padding for this sample from its first input.

        The result is also kept on ``self.padding`` so ``__call__`` can report it.
        """
        # The size query is `get_spatial_size` in older torchvision and
        # `get_size` in newer releases; use whichever is present.
        query_size = getattr(F, 'get_spatial_size', None) or F.get_size
        current = query_size(flat_inputs[0])
        pad_h = self.spatial_size[0] - current[0]
        pad_w = self.spatial_size[1] - current[1]
        self.padding = [0, 0, pad_w, pad_h]
        return dict(padding=self.padding)

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Alias of ``_get_params`` under the public hook name."""
        return self._get_params(flat_inputs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Pad one input with the padding computed for this sample."""
        fill_table = self._fill
        try:
            fill = fill_table[type(inpt)]
        except KeyError:
            # A plain dict has no per-type default; newer torchvision stores
            # the default under 'others' (NOTE(review): confirm for the
            # installed version).
            fill = fill_table.get('others')
        return F.pad(inpt, padding=params['padding'], fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Alias of ``_transform`` under the public hook name."""
        return self._transform(inpt, params)

    def __call__(self, *inputs: Any) -> Any:
        """Pad the sample and record the padding in the target dict, if any."""
        outputs = super().forward(*inputs)
        # Only a (sample, target-dict, ...) sequence can carry the padding.
        # A lone image or tensor is returned as-is without calling len() on it.
        if isinstance(outputs, (tuple, list)) and len(outputs) > 1 and isinstance(outputs[1], dict):
            outputs[1]['padding'] = torch.tensor(self.padding)
        return outputs
@register
class RandomIoUCrop(T.RandomIoUCrop):
    """``T.RandomIoUCrop`` that is applied only with probability ``p``.

    All arguments except ``p`` are forwarded to the base class unchanged.
    """

    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
        p: float = 1.0,
    ):
        super().__init__(
            min_scale,
            max_scale,
            min_aspect_ratio,
            max_aspect_ratio,
            sampler_options,
            trials,
        )
        self.p = p

    def __call__(self, *inputs: Any) -> Any:
        """Crop, or return the inputs untouched when the random draw is >= ``p``."""
        skip_crop = bool(torch.rand(1) >= self.p)
        if not skip_crop:
            return super().forward(*inputs)
        if len(inputs) > 1:
            return inputs
        return inputs[0]
@register
class ConvertBox(T.Transform):
    """Convert bounding boxes to another format and optionally normalize them.

    Args:
        out_fmt: Target format name: ``'xyxy'``, ``'xywh'`` or ``'cxcywh'``.
            An empty string skips the conversion.
        normalize: When true, divide the coordinates by the image width and
            height, repeated as ``(w, h, w, h)``.
    """

    _transformed_types = (
        _BBoxType,
    )

    def __init__(self, out_fmt='', normalize=False) -> None:
        super().__init__()
        self.out_fmt = out_fmt
        self.normalize = normalize
        # Maps the `box_convert` format names to the box-format enum members.
        self.data_fmt = {
            'xyxy': _BBoxFormat.XYXY,
            'xywh': _BBoxFormat.XYWH,
            'cxcywh': _BBoxFormat.CXCYWH,
        }

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the configured conversion and/or normalization to one box input."""
        if self.out_fmt:
            spatial_size = _bbox_spatial_size(inpt)
            in_fmt = inpt.format.value.lower()
            converted = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.out_fmt)
            # Re-wrap so the result keeps its format and image-size metadata.
            inpt = _make_bounding_box(converted, format=self.data_fmt[self.out_fmt], spatial_size=spatial_size)

        if self.normalize:
            spatial_size = _bbox_spatial_size(inpt)
            # The stored size is reversed to line up with (x, y) coordinate
            # pairs. Build the divisor on the boxes' device so the division
            # also works for inputs that are not on the CPU.
            divisor = torch.tensor(spatial_size[::-1], device=inpt.device).tile(2)[None]
            inpt = inpt / divisor

        return inpt

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Alias of ``_transform`` under the public hook name."""
        return self._transform(inpt, params)
|