diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843c1acfedd471dd489ccce2ab7fd1b6f75f32dd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb712d4799bd5c4f50092c198a35d162391c8d05 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e53b97120f94fee11ee4d4806d83aa2a6fad9e1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d17c1dc8ec326caf168f79de3c026ebced2f82c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41edd4efd9d460fc857437eaf53872ab290dac55 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5490c23b82c1f221ff7e3586f1a9f77dd50e808c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8424d4ffd129df3547c16df5c6e6012b445cbffe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e01b081386ad220d0a62e0e119a91dca266859aa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe3c17064e43befb5b5dcfa5da6b50dc106c073 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef9b13596bbdf017ac323da223efca2589c2ef91 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d032e9e08ba218f31c9dca41b5ad710a29ec706c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..662f58d50b4af251b9aaa9bee6a4f2248ee995db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e11d1de9728014311f26177477d6dfeec3002a8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58b2d2abd936d885221174d194a633a8e413935f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py @@ -0,0 +1,3 @@ +from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler + +__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler") diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2acc8c72fda7f052d382a2ceae6d22d6a3855b09 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4481f12e6f037f98863cd4a83070ece5b3bca7e3 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..026c3d75d3b8acd5d0240d8e537e608a92cddfdb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py @@ -0,0 +1,172 @@ +import math +from typing import cast, Iterator, List, Optional, Sized, Union + +import torch +import torch.distributed as dist +from torch.utils.data import Sampler +from torchvision.datasets.video_utils import VideoClips + + +class DistributedSampler(Sampler): + """ + Extension of DistributedSampler, as discussed in + https://github.com/pytorch/pytorch/issues/23430 + + Example: + dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + num_replicas: 4 + shuffle: False + + when group_size = 1 + RANK | shard_dataset + ========================= + rank_0 | [0, 4, 8, 12] + rank_1 | [1, 5, 9, 13] + rank_2 | [2, 6, 10, 0] + rank_3 | [3, 7, 11, 1] + + when group_size = 2 + + RANK | shard_dataset + ========================= + rank_0 | [0, 1, 8, 9] + rank_1 | [2, 3, 10, 11] + rank_2 | [4, 5, 12, 13] + rank_3 | [6, 7, 0, 1] + + """ + + def __init__( + self, + dataset: Sized, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + group_size: int = 1, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if len(dataset) % group_size != 0: + raise ValueError( + f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}" + ) + self.dataset = dataset + self.group_size = group_size + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + dataset_group_length = len(dataset) // group_size + self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas)) + self.num_samples = self.num_group_samples * group_size + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + + def __iter__(self) -> Iterator[int]: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices: Union[torch.Tensor, List[int]] + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + total_group_size = self.total_size // self.group_size + indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size)) + + # subsample + indices = indices[self.rank : total_group_size : self.num_replicas, :] + indices = torch.reshape(indices, (-1,)).tolist() + assert len(indices) == self.num_samples + + if isinstance(self.dataset, Sampler): + orig_indices = list(iter(self.dataset)) + indices = [orig_indices[i] for i in indices] + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + +class UniformClipSampler(Sampler): + """ + Sample 
`num_video_clips_per_video` clips for each video, equally spaced. + When number of unique clips in the video is fewer than num_video_clips_per_video, + repeat the clips until `num_video_clips_per_video` clips are collected + + Args: + video_clips (VideoClips): video clips to sample from + num_clips_per_video (int): number of clips to be sampled per video + """ + + def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None: + if not isinstance(video_clips, VideoClips): + raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}") + self.video_clips = video_clips + self.num_clips_per_video = num_clips_per_video + + def __iter__(self) -> Iterator[int]: + idxs = [] + s = 0 + # select num_clips_per_video for each video, uniformly spaced + for c in self.video_clips.clips: + length = len(c) + if length == 0: + # corner case where video decoding fails + continue + + sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64) + s += length + idxs.append(sampled) + return iter(cast(List[int], torch.cat(idxs).tolist())) + + def __len__(self) -> int: + return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0) + + +class RandomClipSampler(Sampler): + """ + Samples at most `max_video_clips_per_video` clips for each video randomly + + Args: + video_clips (VideoClips): video clips to sample from + max_clips_per_video (int): maximum number of clips to be sampled per video + """ + + def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None: + if not isinstance(video_clips, VideoClips): + raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}") + self.video_clips = video_clips + self.max_clips_per_video = max_clips_per_video + + def __iter__(self) -> Iterator[int]: + idxs = [] + s = 0 + # select at most max_clips_per_video for each video, randomly + for c in self.video_clips.clips: + length = len(c) + size = min(length, self.max_clips_per_video) + sampled = torch.randperm(length)[:size] + s + s += length + idxs.append(sampled) + idxs_ = torch.cat(idxs) + # shuffle all clips randomly + perm = torch.randperm(len(idxs_)) + return iter(idxs_[perm].tolist()) + + def __len__(self) -> int: + return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77680a14f0d0599f4004a2ce5c299c0f5e13a0d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py @@ -0,0 +1,2 @@ +from .transforms import * +from .autoaugment import * diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d600f1a623b112ab7f71bc7a0d69c25a936a3087 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d9a40c78a8873886d618679fec568834a292d016 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d1ba24d671aa6307f49378748064cbf067f929 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..153534cfbab5ba7e7ae3d9063613b0f95091df0a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12fd3beda66b4069ca365b918a59c34eeddcb140 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01678ad5a72e08aa52d13813b89a378c0f38a4b9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f207ad0f94255c4ff0bde3cc054dc8b01d11ad6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40e771e700d21cf9cc150b3d75fea9cc78747276 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py new file mode 100644 index 0000000000000000000000000000000000000000..527879bb6f1b249e2c6208032f30c42139a81b99 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py @@ -0,0 +1,393 @@ +import numbers +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from PIL import Image, 
ImageEnhance, ImageOps + +try: + import accimage +except ImportError: + accimage = None + + +@torch.jit.unused +def _is_pil_image(img: Any) -> bool: + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + + +@torch.jit.unused +def get_dimensions(img: Any) -> List[int]: + if _is_pil_image(img): + if hasattr(img, "getbands"): + channels = len(img.getbands()) + else: + channels = img.channels + width, height = img.size + return [channels, height, width] + raise TypeError(f"Unexpected type {type(img)}") + + +@torch.jit.unused +def get_image_size(img: Any) -> List[int]: + if _is_pil_image(img): + return list(img.size) + raise TypeError(f"Unexpected type {type(img)}") + + +@torch.jit.unused +def get_image_num_channels(img: Any) -> int: + if _is_pil_image(img): + if hasattr(img, "getbands"): + return len(img.getbands()) + else: + return img.channels + raise TypeError(f"Unexpected type {type(img)}") + + +@torch.jit.unused +def hflip(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +@torch.jit.unused +def vflip(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +@torch.jit.unused +def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +@torch.jit.unused +def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +@torch.jit.unused +def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +@torch.jit.unused +def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + input_mode = img.mode + if input_mode in {"L", "1", "I", "F"}: + return img + + h, s, v = img.convert("HSV").split() + + np_h = np.array(h, dtype=np.uint8) + # This will over/underflow, as desired + np_h += np.array(hue_factor * 255).astype(np.uint8) + + h = Image.fromarray(np_h, "L") + + img = Image.merge("HSV", (h, s, v)).convert(input_mode) + return img + + +@torch.jit.unused +def adjust_gamma( + img: Image.Image, + gamma: float, + gain: float = 1.0, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. 
Got {type(img)}") + + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + input_mode = img.mode + img = img.convert("RGB") + gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part + + img = img.convert(input_mode) + return img + + +@torch.jit.unused +def pad( + img: Image.Image, + padding: Union[int, List[int], Tuple[int, ...]], + fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg") + if not isinstance(padding_mode, str): + raise TypeError("Got inappropriate padding_mode arg") + + if isinstance(padding, list): + padding = tuple(padding) + + if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + if isinstance(padding, tuple) and len(padding) == 1: + # Compatibility with `functional_tensor.pad` + padding = padding[0] + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if padding_mode == "constant": + opts = _parse_fill(fill, img, name="fill") + if img.mode == "P": + palette = img.getpalette() + image = ImageOps.expand(img, border=padding, **opts) + image.putpalette(palette) + return image + + return ImageOps.expand(img, border=padding, **opts) + else: + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, tuple) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, tuple) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + p = [pad_left, pad_top, pad_right, pad_bottom] + cropping = -np.minimum(p, 0) + + if cropping.any(): + crop_left, crop_top, crop_right, crop_bottom = cropping + img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom)) + + pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0) + + if img.mode == "P": + palette = img.getpalette() + img = np.asarray(img) + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode) + img = Image.fromarray(img) + img.putpalette(palette) + return img + + img = np.asarray(img) + # RGB image + if len(img.shape) == 3: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) + # Grayscale image + if len(img.shape) == 2: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + + return Image.fromarray(img) + + +@torch.jit.unused +def crop( + img: Image.Image, + top: int, + left: int, + height: int, + width: int, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. 
Got {type(img)}") + + return img.crop((left, top, left + width, top + height)) + + +@torch.jit.unused +def resize( + img: Image.Image, + size: Union[List[int], int], + interpolation: int = Image.BILINEAR, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + if not (isinstance(size, list) and len(size) == 2): + raise TypeError(f"Got inappropriate size arg: {size}") + + return img.resize(tuple(size[::-1]), interpolation) + + +@torch.jit.unused +def _parse_fill( + fill: Optional[Union[float, List[float], Tuple[float, ...]]], + img: Image.Image, + name: str = "fillcolor", +) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]: + + # Process fill color for affine transforms + num_channels = get_image_num_channels(img) + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_channels > 1: + fill = tuple([fill] * num_channels) + if isinstance(fill, (list, tuple)): + if len(fill) == 1: + fill = fill * num_channels + elif len(fill) != num_channels: + msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})" + raise ValueError(msg.format(len(fill), num_channels)) + + fill = tuple(fill) # type: ignore[arg-type] + + if img.mode != "F": + if isinstance(fill, (list, tuple)): + fill = tuple(int(x) for x in fill) + else: + fill = int(fill) + + return {name: fill} + + +@torch.jit.unused +def affine( + img: Image.Image, + matrix: List[float], + interpolation: int = Image.NEAREST, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + output_size = img.size + opts = _parse_fill(fill, img) + return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts) + + +@torch.jit.unused +def rotate( + img: Image.Image, + angle: float, + interpolation: int = Image.NEAREST, + expand: bool = False, + center: Optional[Tuple[int, int]] = None, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + opts = _parse_fill(fill, img) + return img.rotate(angle, interpolation, expand, center, **opts) + + +@torch.jit.unused +def perspective( + img: Image.Image, + perspective_coeffs: List[float], + interpolation: int = Image.BICUBIC, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> Image.Image: + + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + opts = _parse_fill(fill, img) + + return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts) + + +@torch.jit.unused +def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + if num_output_channels == 1: + img = img.convert("L") + elif num_output_channels == 3: + img = img.convert("L") + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, "RGB") + else: + raise ValueError("num_output_channels should be either 1 or 3") + + return img + + +@torch.jit.unused +def invert(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. 
Got {type(img)}") + return ImageOps.invert(img) + + +@torch.jit.unused +def posterize(img: Image.Image, bits: int) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.posterize(img, bits) + + +@torch.jit.unused +def solarize(img: Image.Image, threshold: int) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.solarize(img, threshold) + + +@torch.jit.unused +def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + + enhancer = ImageEnhance.Sharpness(img) + img = enhancer.enhance(sharpness_factor) + return img + + +@torch.jit.unused +def autocontrast(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.autocontrast(img) + + +@torch.jit.unused +def equalize(img: Image.Image) -> Image.Image: + if not _is_pil_image(img): + raise TypeError(f"img should be PIL Image. Got {type(img)}") + return ImageOps.equalize(img) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..618bbfbab7c8f40b34216c211a4a39f2a87ff72d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py @@ -0,0 +1,962 @@ +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad + + +def _is_tensor_a_torch_image(x: Tensor) -> bool: + return x.ndim >= 2 + + +def _assert_image_tensor(img: Tensor) -> None: + if not _is_tensor_a_torch_image(img): + raise TypeError("Tensor is not a torch image.") + + +def get_dimensions(img: Tensor) -> List[int]: + _assert_image_tensor(img) + channels = 1 if img.ndim == 2 else img.shape[-3] + height, width = img.shape[-2:] + return [channels, height, width] + + +def get_image_size(img: Tensor) -> List[int]: + # Returns (w, h) of tensor image + _assert_image_tensor(img) + return [img.shape[-1], img.shape[-2]] + + +def get_image_num_channels(img: Tensor) -> int: + _assert_image_tensor(img) + if img.ndim == 2: + return 1 + elif img.ndim > 2: + return img.shape[-3] + + raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}") + + +def _max_value(dtype: torch.dtype) -> int: + if dtype == torch.uint8: + return 255 + elif dtype == torch.int8: + return 127 + elif dtype == torch.int16: + return 32767 + elif dtype == torch.uint16: + return 65535 + elif dtype == torch.int32: + return 2147483647 + elif dtype == torch.int64: + return 9223372036854775807 + else: + # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not + # easy. 
+ return 1 + + +def _assert_channels(img: Tensor, permitted: List[int]) -> None: + c = get_dimensions(img)[0] + if c not in permitted: + raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}") + + +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + if image.dtype == dtype: + return image + + if image.is_floating_point(): + + # TODO: replace with dtype.is_floating_point when torchscript supports it + if torch.tensor(0, dtype=dtype).is_floating_point(): + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) + + # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # For data in the range 0-1, (float * 255).to(uint) is only 255 + # when float is exactly 1.0. + # `max + 1 - epsilon` provides more evenly distributed mapping of + # ranges of floats to ints. + eps = 1e-3 + max_val = float(_max_value(dtype)) + result = image.mul(max_val + 1.0 - eps) + return result.to(dtype) + else: + input_max = float(_max_value(image.dtype)) + + # int to float + # TODO: replace with dtype.is_floating_point when torchscript supports it + if torch.tensor(0, dtype=dtype).is_floating_point(): + image = image.to(dtype) + return image / input_max + + output_max = float(_max_value(dtype)) + + # int to int + if input_max > output_max: + # factor should be forced to int for torch jit script + # otherwise factor is a float and image // factor can produce different results + factor = int((input_max + 1) // (output_max + 1)) + image = torch.div(image, factor, rounding_mode="floor") + return image.to(dtype) + else: + # factor should be forced to int for torch jit script + # otherwise factor is a float and image * factor can produce different results + factor = int((output_max + 1) // (input_max + 1)) + image = image.to(dtype) + return image * factor + + +def vflip(img: Tensor) -> Tensor: + _assert_image_tensor(img) + + return img.flip(-2) + + +def hflip(img: Tensor) -> Tensor: + _assert_image_tensor(img) + + return img.flip(-1) + + +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: + _assert_image_tensor(img) + + _, h, w = get_dimensions(img) + right = left + width + bottom = top + height + + if left < 0 or top < 0 or right > w or bottom > h: + padding_ltrb = [ + max(-left + min(0, right), 0), + max(-top + min(0, bottom), 0), + max(right - max(w, left), 0), + max(bottom - max(h, top), 0), + ] + return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) + return img[..., top:bottom, left:right] + + +def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + _assert_channels(img, [1, 3]) + + if num_output_channels not in (1, 3): + raise ValueError("num_output_channels should be either 1 or 3") + + if img.shape[-3] == 3: + r, g, b = img.unbind(dim=-3) + # This implementation closely follows the TF one: + # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) + l_img = l_img.unsqueeze(dim=-3) + else: + l_img = img.clone() + + if num_output_channels == 3: + return l_img.expand(img.shape) + + return 
l_img + + +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: + if brightness_factor < 0: + raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + + return _blend(img, torch.zeros_like(img), brightness_factor) + + +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [3, 1]) + c = get_dimensions(img)[0] + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + if c == 3: + mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) + else: + mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True) + + return _blend(img, mean, contrast_factor) + + +def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + if not (isinstance(img, torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + if get_dimensions(img)[0] == 1: # Match PIL behaviour + return img + + orig_dtype = img.dtype + img = convert_image_dtype(img, torch.float32) + + img = _rgb2hsv(img) + h, s, v = img.unbind(dim=-3) + h = (h + hue_factor) % 1.0 + img = torch.stack((h, s, v), dim=-3) + img_hue_adj = _hsv2rgb(img) + + return convert_image_dtype(img_hue_adj, orig_dtype) + + +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + + if get_dimensions(img)[0] == 1: # Match PIL behaviour + return img + + return _blend(img, rgb_to_grayscale(img), saturation_factor) + + +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: + if not isinstance(img, torch.Tensor): + raise TypeError("Input img should be a Tensor.") + + _assert_channels(img, [1, 3]) + + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + result = img + dtype = img.dtype + if not torch.is_floating_point(img): + result = convert_image_dtype(result, torch.float32) + + result = (gain * result**gamma).clamp(0, 1) + + result = convert_image_dtype(result, dtype) + return result + + +def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: + ratio = float(ratio) + bound = _max_value(img1.dtype) + return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) + + +def _rgb2hsv(img: Tensor) -> Tensor: + r, g, b = img.unbind(dim=-3) + + # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ + # src/libImaging/Convert.c#L330 + maxc = torch.max(img, dim=-3).values + minc = torch.min(img, dim=-3).values + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occurring, so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. 
+ eqc = maxc == minc + + cr = maxc - minc + # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + ones = torch.ones_like(maxc) + s = cr / torch.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. + cr_divisor = torch.where(eqc, ones, cr) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = hr + hg + hb + h = torch.fmod((h / 6.0 + 1.0), 1.0) + return torch.stack((h, s, maxc), dim=-3) + + +def _hsv2rgb(img: Tensor) -> Tensor: + h, s, v = img.unbind(dim=-3) + i = torch.floor(h * 6.0) + f = (h * 6.0) - i + i = i.to(dtype=torch.int32) + + p = torch.clamp((v * (1.0 - s)), 0.0, 1.0) + q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0) + t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) + + a1 = torch.stack((v, q, p, p, t, v), dim=-3) + a2 = torch.stack((t, v, v, q, p, p), dim=-3) + a3 = torch.stack((p, p, t, v, v, q), dim=-3) + a4 = torch.stack((a1, a2, a3), dim=-4) + + return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4) + + +def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: + # padding is left, right, top, bottom + + # crop if needed + if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: + neg_min_padding = [-min(x, 0) for x in padding] + crop_left, crop_right, crop_top, crop_bottom = neg_min_padding + img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] + padding = [max(x, 0) for x in padding] + + in_sizes = img.size() + + _x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...] + left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0] + right_indices = [-(i + 1) for i in range(padding[1])] # e.g. 
[-1, -2, -3] + x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device) + + _y_indices = [i for i in range(in_sizes[-2])] + top_indices = [i for i in range(padding[2] - 1, -1, -1)] + bottom_indices = [-(i + 1) for i in range(padding[3])] + y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device) + + ndim = img.ndim + if ndim == 3: + return img[:, y_indices[:, None], x_indices[None, :]] + elif ndim == 4: + return img[:, :, y_indices[:, None], x_indices[None, :]] + else: + raise RuntimeError("Symmetric padding of N-D tensors are not supported yet") + + +def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: + if isinstance(padding, int): + if torch.jit.is_scripting(): + # This maybe unreachable + raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]") + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + return [pad_left, pad_right, pad_top, pad_bottom] + + +def pad( + img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant" +) -> Tensor: + _assert_image_tensor(img) + + if fill is None: + fill = 0 + + if not isinstance(padding, (int, tuple, list)): + raise TypeError("Got inappropriate padding arg") + if not isinstance(fill, (int, float)): + raise TypeError("Got inappropriate fill arg") + if not isinstance(padding_mode, str): + raise TypeError("Got inappropriate padding_mode arg") + + if isinstance(padding, tuple): + padding = list(padding) + + if isinstance(padding, list): + # TODO: Jit is failing on loading this op when scripted and saved + # https://github.com/pytorch/pytorch/issues/81100 + if len(padding) not in [1, 2, 4]: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + p = _parse_pad_padding(padding) + + if padding_mode == "edge": + # remap padding_mode str + padding_mode = "replicate" + elif padding_mode == "symmetric": + # route to another implementation + return _pad_symmetric(img, p) + + need_squeeze = False + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64): + # Here we temporarily cast input tensor to float + # until pytorch issue is resolved : + # https://github.com/pytorch/pytorch/issues/40763 + need_cast = True + img = img.to(torch.float32) + + if padding_mode in ("reflect", "replicate"): + img = torch_pad(img, p, mode=padding_mode) + else: + img = torch_pad(img, p, mode=padding_mode, value=float(fill)) + + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + img = img.to(out_dtype) + + return img + + +def resize( + img: Tensor, + size: List[int], + interpolation: str = "bilinear", + antialias: Optional[bool] = True, +) -> Tensor: + _assert_image_tensor(img) + + if isinstance(size, tuple): + size = list(size) + + if antialias is None: + antialias = False + + if antialias and interpolation not in ["bilinear", "bicubic"]: + # We 
manually set it to False to avoid an error downstream in interpolate() + # This behaviour is documented: the parameter is irrelevant for modes + # that are not bilinear or bicubic. We used to raise an error here, but + # now we don't as True is the default. + antialias = False + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) + + # Define align_corners to avoid warnings + align_corners = False if interpolation in ["bilinear", "bicubic"] else None + + img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias) + + if interpolation == "bicubic" and out_dtype == torch.uint8: + img = img.clamp(min=0, max=255) + + img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) + + return img + + +def _assert_grid_transform_inputs( + img: Tensor, + matrix: Optional[List[float]], + interpolation: str, + fill: Optional[Union[int, float, List[float]]], + supported_interpolation_modes: List[str], + coeffs: Optional[List[float]] = None, +) -> None: + + if not (isinstance(img, torch.Tensor)): + raise TypeError("Input img should be Tensor") + + _assert_image_tensor(img) + + if matrix is not None and not isinstance(matrix, list): + raise TypeError("Argument matrix should be a list") + + if matrix is not None and len(matrix) != 6: + raise ValueError("Argument matrix should have 6 float values") + + if coeffs is not None and len(coeffs) != 8: + raise ValueError("Argument coeffs should have 8 float values") + + if fill is not None and not isinstance(fill, (int, float, tuple, list)): + warnings.warn("Argument fill should be either int, float, tuple or list") + + # Check fill + num_channels = get_dimensions(img)[0] + if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels: + msg = ( + "The number of elements in 'fill' cannot broadcast to match the number of " + "channels of the image ({} != {})" + ) + raise ValueError(msg.format(len(fill), num_channels)) + + if interpolation not in supported_interpolation_modes: + raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input") + + +def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]: + need_squeeze = False + # make image NCHW + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if out_dtype not in req_dtypes: + need_cast = True + req_dtype = req_dtypes[0] + img = img.to(req_dtype) + return img, need_cast, need_squeeze, out_dtype + + +def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor: + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + # it is better to round before cast + img = torch.round(img) + img = img.to(out_dtype) + + return img + + +def _apply_grid_transform( + img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]] +) -> Tensor: + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype]) + + if img.shape[0] > 1: + # Apply same grid to a batch of images + grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]) + + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice + if fill is not None: + mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, 
device=img.device) + img = torch.cat((img, mask), dim=1) + + img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) + + # Fill with required color + if fill is not None: + mask = img[:, -1:, :, :] # N * 1 * H * W + img = img[:, :-1, :, :] # N * C * H * W + mask = mask.expand_as(img) + fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1) + fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) + if mode == "nearest": + mask = mask < 0.5 + img[mask] = fill_img[mask] + else: # 'bilinear' + img = img * mask + (1.0 - mask) * fill_img + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img + + +def _gen_affine_grid( + theta: Tensor, + w: int, + h: int, + ow: int, + oh: int, +) -> Tensor: + # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ + # AffineGridGenerator.cpp#L18 + # Difference with AffineGridGenerator is that: + # 1) we normalize grid values after applying theta + # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device) + x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device) + output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta) + return output_grid.view(1, oh, ow, 2) + + +def affine( + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) + + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + shape = img.shape + # grid will be generated on the same device as theta and img + grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: + + # Inspired of PIL implementation: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 + + # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + # Points are shifted due to affine matrix torch convention about + # the center point. 
Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5) + pts = torch.tensor( + [ + [-0.5 * w, -0.5 * h, 1.0], + [-0.5 * w, 0.5 * h, 1.0], + [0.5 * w, 0.5 * h, 1.0], + [0.5 * w, -0.5 * h, 1.0], + ] + ) + theta = torch.tensor(matrix, dtype=torch.float).view(2, 3) + new_pts = torch.matmul(pts, theta.T) + min_vals, _ = new_pts.min(dim=0) + max_vals, _ = new_pts.max(dim=0) + + # shift points to [0, w] and [0, h] interval to match PIL results + min_vals += torch.tensor((w * 0.5, h * 0.5)) + max_vals += torch.tensor((w * 0.5, h * 0.5)) + + # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 + tol = 1e-4 + cmax = torch.ceil((max_vals / tol).trunc_() * tol) + cmin = torch.floor((min_vals / tol).trunc_() * tol) + size = cmax - cmin + return int(size[0]), int(size[1]) # w, h + + +def rotate( + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + expand: bool = False, + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) + w, h = img.shape[-1], img.shape[-2] + ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h) + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) + # grid will be generated on the same device as theta and img + grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) + + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor: + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device) + output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1) + output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1 / output_grid2 - 1.0 + return output_grid.view(1, oh, ow, 2) + + +def perspective( + img: Tensor, + perspective_coeffs: List[float], + interpolation: str = "bilinear", + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + if not (isinstance(img, torch.Tensor)): + raise TypeError("Input img should be Tensor.") + + _assert_image_tensor(img) + + _assert_grid_transform_inputs( + img, + matrix=None, + interpolation=interpolation, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], + coeffs=perspective_coeffs, + ) + + ow, oh = img.shape[-1], img.shape[-2] + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + grid = _perspective_grid(perspective_coeffs, ow=ow, 
oh=oh, dtype=dtype, device=img.device) + return _apply_grid_transform(img, grid, interpolation, fill=fill) + + +def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor: + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + + return kernel1d + + +def _get_gaussian_kernel2d( + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device +) -> Tensor: + kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device) + kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device) + kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) + return kernel2d + + +def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: + if not (isinstance(img, torch.Tensor)): + raise TypeError(f"img should be Tensor. Got {type(img)}") + + _assert_image_tensor(img) + + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) + + # padding = (left, right, top, bottom) + padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] + img = torch_pad(img, padding, mode="reflect") + img = conv2d(img, kernel, groups=img.shape[-3]) + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img + + +def invert(img: Tensor) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + _assert_channels(img, [1, 3]) + + return _max_value(img.dtype) - img + + +def posterize(img: Tensor, bits: int) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + if img.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}") + + _assert_channels(img, [1, 3]) + mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) + return img & mask + + +def solarize(img: Tensor, threshold: float) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + _assert_channels(img, [1, 3]) + + if threshold > _max_value(img.dtype): + raise TypeError("Threshold should be less than bound of img.") + + inverted_img = invert(img) + return torch.where(img >= threshold, inverted_img, img) + + +def _blurred_degenerate_image(img: Tensor) -> Tensor: + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + kernel = torch.ones((3, 3), dtype=dtype, device=img.device) + kernel[1, 1] = 5.0 + kernel /= kernel.sum() + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype]) + result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) + result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) + + result = img.clone() + result[..., 1:-1, 1:-1] = result_tmp + + return result + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> 
Tensor: + if sharpness_factor < 0: + raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") + + _assert_image_tensor(img) + + _assert_channels(img, [1, 3]) + + if img.size(-1) <= 2 or img.size(-2) <= 2: + return img + + return _blend(img, _blurred_degenerate_image(img), sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + + _assert_image_tensor(img) + + if img.ndim < 3: + raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}") + + _assert_channels(img, [1, 3]) + + bound = _max_value(img.dtype) + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype) + maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype) + scale = bound / (maximum - minimum) + eq_idxs = torch.isfinite(scale).logical_not() + minimum[eq_idxs] = 0 + scale[eq_idxs] = 1 + + return ((img - minimum) * scale).clamp(0, bound).to(img.dtype) + + +def _scale_channel(img_chan: Tensor) -> Tensor: + # TODO: we should expect bincount to always be faster than histc, but this + # isn't always the case. Once + # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if + # block and only use bincount. + if img_chan.is_cuda: + hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) + else: + hist = torch.bincount(img_chan.reshape(-1), minlength=256) + + nonzero_hist = hist[hist != 0] + step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") + if step == 0: + return img_chan + + lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor") + lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) + + return lut[img_chan.to(torch.int64)].to(torch.uint8) + + +def _equalize_single_image(img: Tensor) -> Tensor: + return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))]) + + +def equalize(img: Tensor) -> Tensor: + + _assert_image_tensor(img) + + if not (3 <= img.ndim <= 4): + raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}") + if img.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}") + + _assert_channels(img, [1, 3]) + + if img.ndim == 3: + return _equalize_single_image(img) + + return torch.stack([_equalize_single_image(x) for x in img]) + + +def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: + _assert_image_tensor(tensor) + + if not tensor.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.") + + if tensor.ndim < 3: + raise ValueError( + f"Expected tensor to be a tensor image of size (..., C, H, W). 
Got tensor.size() = {tensor.size()}" + ) + + if not inplace: + tensor = tensor.clone() + + dtype = tensor.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=dtype, device=tensor.device) + if (std == 0).any(): + raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + return tensor.sub_(mean).div_(std) + + +def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: + _assert_image_tensor(img) + + if not inplace: + img = img.clone() + + img[..., i : i + h, j : j + w] = v + return img + + +def _create_identity_grid(size: List[int]) -> Tensor: + hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] + grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") + return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + + +def elastic_transform( + img: Tensor, + displacement: Tensor, + interpolation: str = "bilinear", + fill: Optional[Union[int, float, List[float]]] = None, +) -> Tensor: + + if not (isinstance(img, torch.Tensor)): + raise TypeError(f"img should be Tensor. Got {type(img)}") + + size = list(img.shape[-2:]) + displacement = displacement.to(img.device) + + identity_grid = _create_identity_grid(size) + grid = identity_grid.to(img.device) + displacement + return _apply_grid_transform(img, grid, interpolation, fill) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py new file mode 100644 index 0000000000000000000000000000000000000000..91df7d42cd71fc554aba51fcf5e90db30e3c3851 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py @@ -0,0 +1,114 @@ +import warnings + +import torch + + +warnings.warn( + "The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. " + "Please use the 'torchvision.transforms.functional' module instead." +) + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. 
Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) + Return: + clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + return clip.float().permute(3, 0, 1, 2) / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) + Returns: + flipped clip (torch.tensor): Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6f4ad5ca543ddaf41dbdaa4161cd89193f5506 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py @@ -0,0 +1,216 @@ +""" +This file is part of the private API. Please do not use directly these classes as they will be modified on +future versions without warning. The classes should be accessed only via the transforms argument of Weights. +""" +from typing import Optional, Tuple, Union + +import torch +from torch import nn, Tensor + +from . import functional as F, InterpolationMode + + +__all__ = [ + "ObjectDetection", + "ImageClassification", + "VideoClassification", + "SemanticSegmentation", + "OpticalFlow", +] + + +class ObjectDetection(nn.Module): + def forward(self, img: Tensor) -> Tensor: + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + return F.convert_image_dtype(img, torch.float) + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def describe(self) -> str: + return ( + "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. " + "The images are rescaled to ``[0.0, 1.0]``." + ) + + +class ImageClassification(nn.Module): + def __init__( + self, + *, + crop_size: int, + resize_size: int = 256, + mean: Tuple[float, ...] 
= (0.485, 0.456, 0.406), + std: Tuple[float, ...] = (0.229, 0.224, 0.225), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ) -> None: + super().__init__() + self.crop_size = [crop_size] + self.resize_size = [resize_size] + self.mean = list(mean) + self.std = list(std) + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img: Tensor) -> Tensor: + img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias) + img = F.center_crop(img, self.crop_size) + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + img = F.convert_image_dtype(img, torch.float) + img = F.normalize(img, mean=self.mean, std=self.std) + return img + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n crop_size={self.crop_size}" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. " + f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " + f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " + f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``." + ) + + +class VideoClassification(nn.Module): + def __init__( + self, + *, + crop_size: Tuple[int, int], + resize_size: Union[Tuple[int], Tuple[int, int]], + mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645), + std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.crop_size = list(crop_size) + self.resize_size = list(resize_size) + self.mean = list(mean) + self.std = list(std) + self.interpolation = interpolation + + def forward(self, vid: Tensor) -> Tensor: + need_squeeze = False + if vid.ndim < 5: + vid = vid.unsqueeze(dim=0) + need_squeeze = True + + N, T, C, H, W = vid.shape + vid = vid.view(-1, C, H, W) + # We hard-code antialias=False to preserve results after we changed + # its default from None to True (see + # https://github.com/pytorch/vision/pull/7160) + # TODO: we could re-train the video models with antialias=True? + vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False) + vid = F.center_crop(vid, self.crop_size) + vid = F.convert_image_dtype(vid, torch.float) + vid = F.normalize(vid, mean=self.mean, std=self.std) + H, W = self.crop_size + vid = vid.view(N, T, C, H, W) + vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) + + if need_squeeze: + vid = vid.squeeze(dim=0) + return vid + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n crop_size={self.crop_size}" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. 
" + f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " + f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " + f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output " + "dimensions are permuted to ``(..., C, T, H, W)`` tensors." + ) + + +class SemanticSegmentation(nn.Module): + def __init__( + self, + *, + resize_size: Optional[int], + mean: Tuple[float, ...] = (0.485, 0.456, 0.406), + std: Tuple[float, ...] = (0.229, 0.224, 0.225), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ) -> None: + super().__init__() + self.resize_size = [resize_size] if resize_size is not None else None + self.mean = list(mean) + self.std = list(std) + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img: Tensor) -> Tensor: + if isinstance(self.resize_size, list): + img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias) + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + img = F.convert_image_dtype(img, torch.float) + img = F.normalize(img, mean=self.mean, std=self.std) + return img + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n resize_size={self.resize_size}" + format_string += f"\n mean={self.mean}" + format_string += f"\n std={self.std}" + format_string += f"\n interpolation={self.interpolation}" + format_string += "\n)" + return format_string + + def describe(self) -> str: + return ( + "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. " + f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. " + f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and " + f"``std={self.std}``." + ) + + +class OpticalFlow(nn.Module): + def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: + if not isinstance(img1, Tensor): + img1 = F.pil_to_tensor(img1) + if not isinstance(img2, Tensor): + img2 = F.pil_to_tensor(img2) + + img1 = F.convert_image_dtype(img1, torch.float) + img2 = F.convert_image_dtype(img2, torch.float) + + # map [0, 1] into [-1, 1] + img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + img1 = img1.contiguous() + img2 = img2.contiguous() + + return img1, img2 + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def describe(self) -> str: + return ( + "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. " + "The images are rescaled to ``[-1.0, 1.0]``." + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py new file mode 100644 index 0000000000000000000000000000000000000000..a04da4f74849805641e4c470f6b6b8d5f7000e3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 + +import numbers +import random +import warnings + +from torchvision.transforms import RandomCrop, RandomResizedCrop + +from . 
import _functional_video as F + + +__all__ = [ + "RandomCropVideo", + "RandomResizedCropVideo", + "CenterCropVideo", + "NormalizeVideo", + "ToTensorVideo", + "RandomHorizontalFlipVideo", +] + + +warnings.warn( + "The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in the future. " + "Please use the 'torchvision.transforms' module instead." +) + + +class RandomCropVideo(RandomCrop): + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: randomly cropped/resized video clip. + size is (C, T, OH, OW) + """ + i, j, h, w = self.get_params(clip, self.size) + return F.crop(clip, i, j, h, w) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class RandomResizedCropVideo(RandomResizedCrop): + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + self.scale = scale + self.ratio = ratio + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: randomly cropped/resized video clip. + size is (C, T, H, W) + """ + i, j, h, w = self.get_params(clip, self.scale, self.ratio) + return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})" + + +class CenterCropVideo: + def __init__(self, crop_size): + if isinstance(crop_size, numbers.Number): + self.crop_size = (int(crop_size), int(crop_size)) + else: + self.crop_size = crop_size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: central cropping of video clip. Size is + (C, T, crop_size, crop_size) + """ + return F.center_crop(clip, self.crop_size) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(crop_size={self.crop_size})" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip to be normalized. 
Size is (C, T, H, W) + """ + return F.normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) + Return: + clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) + """ + return F.to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (C, T, H, W) + Return: + clip (torch.tensor): Size is (C, T, H, W) + """ + if random.random() < self.p: + clip = F.hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbbe91e741093b01ff8491ba8b39d9b6f578103 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py @@ -0,0 +1,615 @@ +import math +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch +from torch import Tensor + +from . import functional as F, InterpolationMode + +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"] + + +def _apply_op( + img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]] +): + if op_name == "ShearX": + # magnitude should be arctan(magnitude) + # official autoaug: (1, level, 0, 0, 1, 0) + # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290 + # compared to + # torchvision: (1, tan(level), 0, 0, 1, 0) + # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976 + img = F.affine( + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(math.atan(magnitude)), 0.0], + interpolation=interpolation, + fill=fill, + center=[0, 0], + ) + elif op_name == "ShearY": + # magnitude should be arctan(magnitude) + # See above + img = F.affine( + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(math.atan(magnitude))], + interpolation=interpolation, + fill=fill, + center=[0, 0], + ) + elif op_name == "TranslateX": + img = F.affine( + img, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, + ) + elif op_name == "TranslateY": + img = F.affine( + img, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, + ) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif 
op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + elif op_name == "Identity": + pass + else: + raise ValueError(f"The provided operator {op_name} is not recognized.") + return img + + +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + Available policies are IMAGENET, CIFAR10 and SVHN. + """ + + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" + + +# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class +class AutoAugment(torch.nn.Module): + r"""AutoAugment data augmentation method based on + `"AutoAugment: Learning Augmentation Strategies from Data" `_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + policy (AutoAugmentPolicy): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. 
+ """ + + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.policy = policy + self.interpolation = interpolation + self.fill = fill + self.policies = self._get_policies(policy) + + def _get_policies( + self, policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 
0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError(f"The provided policy {policy} is not recognized.") + + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + "Invert": (torch.tensor(0.0), False), + } + + @staticmethod + def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: + """Get parameters for autoaugment transformation + + Returns: + params required by the autoaugment transformation + """ + policy_id = int(torch.randint(transform_num, (1,)).item()) + probs = torch.rand((2,)) + signs = torch.randint(2, (2,)) + + return policy_id, probs, signs + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: AutoAugmented image. + """ + fill = self.fill + channels, height, width = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + elif fill is not None: + fill = [float(f) for f in fill] + + transform_id, probs, signs = self.get_params(len(self.policies)) + + op_meta = self._augmentation_space(10, (height, width)) + for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]): + if probs[i] <= p: + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 + if signed and signs[i] == 0: + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})" + + +class RandAugment(torch.nn.Module): + r"""RandAugment data augmentation method based on + `"RandAugment: Practical automated data augmentation with a reduced search space" + `_. 
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_ops (int): Number of augmentation transformations to apply sequentially. + magnitude (int): Magnitude for all the transformations. + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__( + self, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. 
+ """ + fill = self.fill + channels, height, width = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + elif fill is not None: + fill = [float(f) for f in fill] + + op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width)) + for _ in range(self.num_ops): + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_ops={self.num_ops}" + f", magnitude={self.magnitude}" + f", num_magnitude_bins={self.num_magnitude_bins}" + f", interpolation={self.interpolation}" + f", fill={self.fill}" + f")" + ) + return s + + +class TrivialAugmentWide(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__( + self, + num_magnitude_bins: int = 31, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), + "Color": (torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. 
+ """ + fill = self.fill + channels, height, width = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + elif fill is not None: + fill = [float(f) for f in fill] + + op_meta = self._augmentation_space(self.num_magnitude_bins) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = ( + float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) + if magnitudes.ndim > 0 + else 0.0 + ) + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + + return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"num_magnitude_bins={self.num_magnitude_bins}" + f", interpolation={self.interpolation}" + f", fill={self.fill}" + f")" + ) + return s + + +class AugMix(torch.nn.Module): + r"""AugMix data augmentation method based on + `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" `_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + severity (int): The severity of base augmentation operators. Default is ``3``. + mixture_width (int): The number of augmentation chains. Default is ``3``. + chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3]. + Default is ``-1``. + alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``. + all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__( + self, + severity: int = 3, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = True, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> None: + super().__init__() + self._PARAMETER_MAX = 10 + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. 
Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: + s = { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + if self.all_ops: + s.update( + { + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + } + ) + return s + + @torch.jit.unused + def _pil_to_tensor(self, img) -> Tensor: + return F.pil_to_tensor(img) + + @torch.jit.unused + def _tensor_to_pil(self, img: Tensor): + return F.to_pil_image(img) + + def _sample_dirichlet(self, params: Tensor) -> Tensor: + # Must be on a separate method so that we can overwrite it in tests. + return torch._sample_dirichlet(params) + + def forward(self, orig_img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + channels, height, width = F.get_dimensions(orig_img) + if isinstance(orig_img, Tensor): + img = orig_img + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + elif fill is not None: + fill = [float(f) for f in fill] + else: + img = self._pil_to_tensor(orig_img) + + op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width)) + + orig_dims = list(img.shape) + batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims) + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) + + # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet + # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image. + m = self._sample_dirichlet( + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + ) + + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images. 
+ combined_weights = self._sample_dirichlet( + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + ) * m[:, 1].view([batch_dims[0], -1]) + + mix = m[:, 0].view(batch_dims) * batch + for i in range(self.mixture_width): + aug = batch + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + for _ in range(depth): + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = ( + float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item()) + if magnitudes.ndim > 0 + else 0.0 + ) + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) + mix.add_(combined_weights[:, i].view(batch_dims) * aug) + mix = mix.view(orig_dims).to(dtype=img.dtype) + + if not isinstance(orig_img, Tensor): + return self._tensor_to_pil(mix) + return mix + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"severity={self.severity}" + f", mixture_width={self.mixture_width}" + f", chain_depth={self.chain_depth}" + f", alpha={self.alpha}" + f", all_ops={self.all_ops}" + f", interpolation={self.interpolation}" + f", fill={self.fill}" + f")" + ) + return s diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..8efe2a8878a06b37d7ede9496b213076c1f59c01 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py @@ -0,0 +1,1586 @@ +import math +import numbers +import sys +import warnings +from enum import Enum +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from PIL.Image import Image as PILImage +from torch import Tensor + +try: + import accimage +except ImportError: + accimage = None + +from ..utils import _log_api_usage_once +from . import _functional_pil as F_pil, _functional_tensor as F_t + + +class InterpolationMode(Enum): + """Interpolation modes + Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, + and ``lanczos``. + """ + + NEAREST = "nearest" + NEAREST_EXACT = "nearest-exact" + BILINEAR = "bilinear" + BICUBIC = "bicubic" + # For PIL compatibility + BOX = "box" + HAMMING = "hamming" + LANCZOS = "lanczos" + + +# TODO: Once torchscript supports Enums with staticmethod +# this can be put into InterpolationMode as staticmethod +def _interpolation_modes_from_int(i: int) -> InterpolationMode: + inverse_modes_mapping = { + 0: InterpolationMode.NEAREST, + 2: InterpolationMode.BILINEAR, + 3: InterpolationMode.BICUBIC, + 4: InterpolationMode.BOX, + 5: InterpolationMode.HAMMING, + 1: InterpolationMode.LANCZOS, + } + return inverse_modes_mapping[i] + + +pil_modes_mapping = { + InterpolationMode.NEAREST: 0, + InterpolationMode.BILINEAR: 2, + InterpolationMode.BICUBIC: 3, + InterpolationMode.NEAREST_EXACT: 0, + InterpolationMode.BOX: 4, + InterpolationMode.HAMMING: 5, + InterpolationMode.LANCZOS: 1, +} + +_is_pil_image = F_pil._is_pil_image + + +def get_dimensions(img: Tensor) -> List[int]: + """Returns the dimensions of an image as [channels, height, width]. + + Args: + img (PIL Image or Tensor): The image to be checked. + + Returns: + List[int]: The image dimensions. 
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(get_dimensions) + if isinstance(img, torch.Tensor): + return F_t.get_dimensions(img) + + return F_pil.get_dimensions(img) + + +def get_image_size(img: Tensor) -> List[int]: + """Returns the size of an image as [width, height]. + + Args: + img (PIL Image or Tensor): The image to be checked. + + Returns: + List[int]: The image size. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(get_image_size) + if isinstance(img, torch.Tensor): + return F_t.get_image_size(img) + + return F_pil.get_image_size(img) + + +def get_image_num_channels(img: Tensor) -> int: + """Returns the number of channels of an image. + + Args: + img (PIL Image or Tensor): The image to be checked. + + Returns: + int: The number of channels. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(get_image_num_channels) + if isinstance(img, torch.Tensor): + return F_t.get_image_num_channels(img) + + return F_pil.get_image_num_channels(img) + + +@torch.jit.unused +def _is_numpy(img: Any) -> bool: + return isinstance(img, np.ndarray) + + +@torch.jit.unused +def _is_numpy_image(img: Any) -> bool: + return img.ndim in {2, 3} + + +def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor: + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + This function does not support torchscript. + + See :class:`~torchvision.transforms.ToTensor` for more details. + + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(to_tensor) + if not (F_pil._is_pil_image(pic) or _is_numpy(pic)): + raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}") + + if _is_numpy(pic) and not _is_numpy_image(pic): + raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.") + + default_float_dtype = torch.get_default_dtype() + + if isinstance(pic, np.ndarray): + # handle numpy array + if pic.ndim == 2: + pic = pic[:, :, None] + + img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous() + # backward compatibility + if isinstance(img, torch.ByteTensor): + return img.to(dtype=default_float_dtype).div(255) + else: + return img + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.from_numpy(nppic).to(dtype=default_float_dtype) + + # handle PIL Image + mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32} + img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)) + + if pic.mode == "1": + img = 255 * img + img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic)) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)).contiguous() + if isinstance(img, torch.ByteTensor): + return img.to(dtype=default_float_dtype).div(255) + else: + return img + + +def pil_to_tensor(pic: Any) -> Tensor: + """Convert a ``PIL Image`` to a tensor of the same type. + This function does not support torchscript. + + See :class:`~torchvision.transforms.PILToTensor` for more details. + + .. note:: + + A deep copy of the underlying array is performed. + + Args: + pic (PIL Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image. 
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(pil_to_tensor) + if not F_pil._is_pil_image(pic): + raise TypeError(f"pic should be PIL Image. Got {type(pic)}") + + if accimage is not None and isinstance(pic, accimage.Image): + # accimage format is always uint8 internally, so always return uint8 here + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8) + pic.copyto(nppic) + return torch.as_tensor(nppic) + + # handle PIL Image + img = torch.as_tensor(np.array(pic, copy=True)) + img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic)) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)) + return img + + +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: + """Convert a tensor image to the given ``dtype`` and scale the values accordingly + This function does not support PIL Image. + + Args: + image (torch.Tensor): Image to be converted + dtype (torch.dtype): Desired data type of the output + + Returns: + Tensor: Converted image + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(convert_image_dtype) + if not isinstance(image, torch.Tensor): + raise TypeError("Input img should be Tensor Image") + + return F_t.convert_image_dtype(image, dtype) + + +def to_pil_image(pic, mode=None): + """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript. + + See :class:`~torchvision.transforms.ToPILImage` for more details. + + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + + Returns: + PIL Image: Image converted to PIL Image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(to_pil_image) + + if isinstance(pic, torch.Tensor): + if pic.ndim == 3: + pic = pic.permute((1, 2, 0)) + pic = pic.numpy(force=True) + elif not isinstance(pic, np.ndarray): + raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.") + + if pic.ndim == 2: + # if 2D image, add channel dimension (HWC) + pic = np.expand_dims(pic, 2) + if pic.ndim != 3: + raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.") + + if pic.shape[-1] > 4: + raise ValueError(f"pic should not have > 4 channels. 
Got {pic.shape[-1]} channels.") + + npimg = pic + + if np.issubdtype(npimg.dtype, np.floating) and mode != "F": + npimg = (npimg * 255).astype(np.uint8) + + if npimg.shape[2] == 1: + expected_mode = None + npimg = npimg[:, :, 0] + if npimg.dtype == np.uint8: + expected_mode = "L" + elif npimg.dtype == np.int16: + expected_mode = "I;16" if sys.byteorder == "little" else "I;16B" + elif npimg.dtype == np.int32: + expected_mode = "I" + elif npimg.dtype == np.float32: + expected_mode = "F" + if mode is not None and mode != expected_mode: + raise ValueError(f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}") + mode = expected_mode + + elif npimg.shape[2] == 2: + permitted_2_channel_modes = ["LA"] + if mode is not None and mode not in permitted_2_channel_modes: + raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs") + + if mode is None and npimg.dtype == np.uint8: + mode = "LA" + + elif npimg.shape[2] == 4: + permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"] + if mode is not None and mode not in permitted_4_channel_modes: + raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs") + + if mode is None and npimg.dtype == np.uint8: + mode = "RGBA" + else: + permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"] + if mode is not None and mode not in permitted_3_channel_modes: + raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs") + if mode is None and npimg.dtype == np.uint8: + mode = "RGB" + + if mode is None: + raise TypeError(f"Input type {npimg.dtype} is not supported") + + return Image.fromarray(npimg, mode=mode) + + +def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: + """Normalize a float tensor image with mean and standard deviation. + This transform does not support PIL Image. + + .. note:: + This transform acts out of place by default, i.e., it does not mutate the input tensor. + + See :class:`~torchvision.transforms.Normalize` for more details. + + Args: + tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized. + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation inplace. + + Returns: + Tensor: Normalized Tensor image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(normalize) + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"img should be Tensor Image. 
Got {type(tensor)}") + + return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace) + + +def _compute_resized_output_size( + image_size: Tuple[int, int], + size: Optional[List[int]], + max_size: Optional[int] = None, + allow_size_none: bool = False, # only True in v2 +) -> List[int]: + h, w = image_size + short, long = (w, h) if w <= h else (h, w) + if size is None: + if not allow_size_none: + raise ValueError("This should never happen!!") + if not isinstance(max_size, int): + raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.") + new_short, new_long = int(max_size * short / long), max_size + new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) + elif len(size) == 1: # specified size only for the smallest edge + requested_new_short = size if isinstance(size, int) else size[0] + new_short, new_long = requested_new_short, int(requested_new_short * long / short) + + if max_size is not None: + if max_size <= requested_new_short: + raise ValueError( + f"max_size = {max_size} must be strictly greater than the requested " + f"size for the smaller edge size = {size}" + ) + if new_long > max_size: + new_short, new_long = int(max_size * new_short / new_long), max_size + + new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) + else: # specified both h and w + new_w, new_h = size[1], size[0] + return [new_h, new_w] + + +def resize( + img: Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> Tensor: + r"""Resize the input image to the given size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + img (PIL Image or Tensor): Image to be resized. + size (sequence or int): Desired output size. If size is a sequence like + (h, w), the output size will be matched to this. If size is an int, + the smaller edge of the image will be matched to this number maintaining + the aspect ratio. i.e, if height > width, then image will be rescaled to + :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, + ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are + supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + max_size (int, optional): The maximum allowed for the longer edge of + the resized image. If the longer edge of the image is greater + than ``max_size`` after being resized according to ``size``, + ``size`` will be overruled so that the longer edge is equal to + ``max_size``. + As a result, the smaller edge may be shorter than ``size``. This + is only supported if ``size`` is an int (or a sequence of length + 1 in torchscript mode). + antialias (bool, optional): Whether to apply antialiasing. 
+ It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + + Returns: + PIL Image or Tensor: Resized image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(resize) + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" + ) + + if isinstance(size, (list, tuple)): + if len(size) not in [1, 2]: + raise ValueError( + f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list" + ) + if max_size is not None and len(size) != 1: + raise ValueError( + "max_size should only be passed if size specifies the length of the smaller edge, " + "i.e. size should be an int or a sequence of length 1 in torchscript mode." + ) + + _, image_height, image_width = get_dimensions(img) + if isinstance(size, int): + size = [size] + output_size = _compute_resized_output_size((image_height, image_width), size, max_size) + + if [image_height, image_width] == output_size: + return img + + if not isinstance(img, torch.Tensor): + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.resize(img, size=output_size, interpolation=pil_interpolation) + + return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) + + +def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor: + r"""Pad the given image on all sides with the given "pad" value. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, + at most 3 leading dimensions for mode edge, + and an arbitrary number of leading dimensions for mode constant + + Args: + img (PIL Image or Tensor): Image to be padded. + padding (int or sequence): Padding on each border. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + fill (number or tuple): Pixel fill value for constant fill. Default is 0. 
+ If a tuple of length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + PIL Image or Tensor: Padded image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(pad) + if not isinstance(img, torch.Tensor): + return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) + + return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) + + +def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: + """Crop the given image at specified location and output size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then cropped. + + Args: + img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + + Returns: + PIL Image or Tensor: Cropped image. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(crop) + if not isinstance(img, torch.Tensor): + return F_pil.crop(img, top, left, height, width) + + return F_t.crop(img, top, left, height, width) + + +def center_crop(img: Tensor, output_size: List[int]) -> Tensor: + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, + it is used for both directions. + + Returns: + PIL Image or Tensor: Cropped image. 
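[Editor's note] A short sketch of ``pad`` and ``crop`` from above (illustrative only; shapes and padding values are arbitrary):

import torch
from torchvision.transforms import functional as F

img = torch.arange(2 * 4 * 4, dtype=torch.float32).reshape(2, 4, 4)
# Length-2 padding: 1 on left/right, 2 on top/bottom -> (2, 8, 6)
padded = F.pad(img, [1, 2], fill=0, padding_mode="constant")
# Reflection padding of 1 on every border -> (2, 6, 6)
reflected = F.pad(img, [1], padding_mode="reflect")
# 2x3 crop whose top-left corner is at (top=1, left=0) -> (2, 2, 3)
patch = F.crop(img, top=1, left=0, height=2, width=3)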
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(center_crop) + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + _, image_height, image_width = get_dimensions(img) + crop_height, crop_width = output_size + + if crop_width > image_width or crop_height > image_height: + padding_ltrb = [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0 + _, image_height, image_width = get_dimensions(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop(img, crop_top, crop_left, crop_height, crop_width) + + +def resized_crop( + img: Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, +) -> Tensor: + """Crop the given image and resize it to desired size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. + + Args: + img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + size (sequence or int): Desired output size. Same semantics as ``resize``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, + ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are + supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. 
+ Returns: + PIL Image or Tensor: Cropped image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(resized_crop) + img = crop(img, top, left, height, width) + img = resize(img, size, interpolation, antialias=antialias) + return img + + +def hflip(img: Tensor) -> Tensor: + """Horizontally flip the given image. + + Args: + img (PIL Image or Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of leading + dimensions. + + Returns: + PIL Image or Tensor: Horizontally flipped image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(hflip) + if not isinstance(img, torch.Tensor): + return F_pil.hflip(img) + + return F_t.hflip(img) + + +def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]: + """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. + + In Perspective Transform each pixel (x, y) in the original image gets transformed as, + (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) ) + + Args: + startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the original image. + endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. + + Returns: + octuple (a, b, c, d, e, f, g, h) for transforming each pixel. + """ + if len(startpoints) != 4 or len(endpoints) != 4: + raise ValueError( + f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints." + ) + a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64) + + for i, (p1, p2) in enumerate(zip(endpoints, startpoints)): + a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + + b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8) + # do least squares in double precision to prevent numerical issues + res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32) + + output: List[float] = res.tolist() + return output + + +def perspective( + img: Tensor, + startpoints: List[List[int]], + endpoints: List[List[int]], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> Tensor: + """Perform perspective transform of the given image. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): Image to be transformed. + startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the original image. + endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners + ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 
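[Editor's note] ``resized_crop`` above is simply ``crop`` followed by ``resize``; a minimal sketch with arbitrary values:

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 100, 100)
# Crop a 50x60 box at (top=10, left=20) and resize it to 224x224 in one call.
patch = F.resized_crop(img, top=10, left=20, height=50, width=60, size=[224, 224], antialias=True)
mirrored = F.hflip(patch)  # horizontal flip of the result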
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + + .. note:: + In torchscript mode single int/float value is not supported, please use a sequence + of length 1: ``[value, ]``. + + Returns: + PIL Image or Tensor: transformed Image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(perspective) + + coeffs = _get_perspective_coeffs(startpoints, endpoints) + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" + ) + + if not isinstance(img, torch.Tensor): + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill) + + return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill) + + +def vflip(img: Tensor) -> Tensor: + """Vertically flip the given image. + + Args: + img (PIL Image or Tensor): Image to be flipped. If img + is a Tensor, it is expected to be in [..., H, W] format, + where ... means it can have an arbitrary number of leading + dimensions. + + Returns: + PIL Image or Tensor: Vertically flipped image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(vflip) + if not isinstance(img, torch.Tensor): + return F_pil.vflip(img) + + return F_t.vflip(img) + + +def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Crop the given image into four corners and the central crop. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + img (PIL Image or Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + + Returns: + tuple: tuple (tl, tr, bl, br, center) + Corresponding top left, top right, bottom left, bottom right and center crop. 
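[Editor's note] A sketch of ``perspective`` above; the corner lists follow the documented ``[top-left, top-right, bottom-right, bottom-left]`` order (values are arbitrary):

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 100, 100)
startpoints = [[0, 0], [99, 0], [99, 99], [0, 99]]
# Pull the two top corners inwards by 10 pixels.
endpoints = [[10, 0], [89, 0], [99, 99], [0, 99]]
warped = F.perspective(img, startpoints, endpoints)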
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(five_crop) + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + _, image_height, image_width = get_dimensions(img) + crop_height, crop_width = size + if crop_width > image_width or crop_height > image_height: + msg = "Requested crop size {} is bigger than input size {}" + raise ValueError(msg.format(size, (image_height, image_width))) + + tl = crop(img, 0, 0, crop_height, crop_width) + tr = crop(img, 0, image_width - crop_width, crop_height, crop_width) + bl = crop(img, image_height - crop_height, 0, crop_height, crop_width) + br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + + center = center_crop(img, [crop_height, crop_width]) + + return tl, tr, bl, br, center + + +def ten_crop( + img: Tensor, size: List[int], vertical_flip: bool = False +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Generate ten cropped images from the given image. + Crop the given image into four corners and the central crop plus the + flipped version of these (horizontal flipping is used by default). + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + img (PIL Image or Tensor): Image to be cropped. + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + vertical_flip (bool): Use vertical flipping instead of horizontal + + Returns: + tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + Corresponding top left, top right, bottom left, bottom right and + center crop and same for the flipped image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(ten_crop) + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + first_five = five_crop(img, size) + + if vertical_flip: + img = vflip(img) + else: + img = hflip(img) + + second_five = five_crop(img, size) + return first_five + second_five + + +def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: + """Adjust brightness of an image. + + Args: + img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + brightness_factor (float): How much to adjust the brightness. Can be + any non-negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image or Tensor: Brightness adjusted image. 
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(adjust_brightness) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_brightness(img, brightness_factor) + + return F_t.adjust_brightness(img, brightness_factor) + + +def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: + """Adjust contrast of an image. + + Args: + img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + contrast_factor (float): How much to adjust the contrast. Can be any + non-negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL Image or Tensor: Contrast adjusted image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(adjust_contrast) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_contrast(img, contrast_factor) + + return F_t.adjust_contrast(img, contrast_factor) + + +def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: + """Adjust color saturation of an image. + + Args: + img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL Image or Tensor: Saturation adjusted image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(adjust_saturation) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_saturation(img, saturation_factor) + + return F_t.adjust_saturation(img, saturation_factor) + + +def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args: + img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. + Note: the pixel values of the input image has to be non-negative for conversion to HSV space; + thus it does not work if you normalize your image to an interval with negative values, + or use an interpolation that generates negative values before using this function. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL Image or Tensor: Hue adjusted image. 
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(adjust_hue) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_hue(img, hue_factor) + + return F_t.adjust_hue(img, hue_factor) + + +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: + r"""Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args: + img (PIL Image or Tensor): PIL Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, modes with transparency (alpha channel) are not supported. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + gain (float): The constant multiplier. + Returns: + PIL Image or Tensor: Gamma correction adjusted image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(adjust_gamma) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_gamma(img, gamma, gain) + + return F_t.adjust_gamma(img, gamma, gain) + + +def _get_inverse_affine_matrix( + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True +) -> List[float]: + # Helper method to compute inverse matrix for affine transformation + + # Pillow requires inverse affine transformation matrix: + # Affine matrix is : M = T * C * RotateScaleShear * C^-1 + # + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RotateScaleShear is rotation with scale and shear matrix + # + # RotateScaleShear(a, s, (sx, sy)) = + # = R(a) * S(s) * SHy(sy) * SHx(sx) + # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] + # [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] + # [ 0 , 0 , 1 ] + # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: + # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] + # [0, 1 ] [-tan(s), 1] + # + # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1 + + rot = math.radians(angle) + sx = math.radians(shear[0]) + sy = math.radians(shear[1]) + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = math.cos(rot - sy) / math.cos(sy) + b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot) + c = math.sin(rot - sy) / math.cos(sy) + d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot) + + if inverted: + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + matrix = [d, -b, 0.0, -c, a, 0.0] + matrix = [x / scale for x in matrix] + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty) + matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty) + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx + matrix[5] += cy + else: + matrix = [a, b, 0.0, c, d, 0.0] + matrix = [x * scale for x in 
matrix] + # Apply inverse of center translation: RSS * C^-1 + matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy) + matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy) + # Apply translation and center : T * C * RSS * C^-1 + matrix[2] += cx + tx + matrix[5] += cy + ty + + return matrix + + +def rotate( + img: Tensor, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[int]] = None, + fill: Optional[List[float]] = None, +) -> Tensor: + """Rotate the image by angle. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): image to be rotated. + angle (number): rotation angle value in degrees, counter-clockwise. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (sequence, optional): Optional center of rotation. Origin is the upper left corner. + Default is the center of the image. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + + .. note:: + In torchscript mode single int/float value is not supported, please use a sequence + of length 1: ``[value, ]``. + Returns: + PIL Image or Tensor: Rotated image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(rotate) + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" + ) + + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + + if not isinstance(img, torch.Tensor): + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) + + center_f = [0.0, 0.0] + if center is not None: + _, height, width = get_dimensions(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] + + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. 
+ matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) + + +def affine( + img: Tensor, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[int]] = None, +) -> Tensor: + """Apply affine transformation on the image keeping image center invariant. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): image to transform. + angle (number): rotation angle in degrees between -180 and 180, clockwise direction. + translate (sequence of integers): horizontal and vertical translations (post-rotation translation) + scale (float): overall scale + shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction. + If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while + the second value corresponds to a shear parallel to the y-axis. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + + .. note:: + In torchscript mode single int/float value is not supported, please use a sequence + of length 1: ``[value, ]``. + center (sequence, optional): Optional center of rotation. Origin is the upper left corner. + Default is the center of the image. + + Returns: + PIL Image or Tensor: Transformed image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(affine) + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" + ) + + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError("Shear should be either a single value or a sequence of two values") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError(f"Shear should be a sequence containing two values. 
Got {shear}") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + + _, height, width = get_dimensions(img) + if not isinstance(img, torch.Tensor): + # center = (width * 0.5 + 0.5, height * 0.5 + 0.5) + # it is visually better to estimate the center without 0.5 offset + # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine + if center is None: + center = [width * 0.5, height * 0.5] + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) + + center_f = [0.0, 0.0] + if center is not None: + _, height, width = get_dimensions(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] + + translate_f = [1.0 * t for t in translate] + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) + return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) + + +# Looks like to_grayscale() is a stand-alone functional that is never called +# from the transform classes. Perhaps it's still here for BC? I can't be +# bothered to dig. +@torch.jit.unused +def to_grayscale(img, num_output_channels=1): + """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. + This transform does not support torch Tensor. + + Args: + img (PIL Image): PIL Image to be converted to grayscale. + num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1. + + Returns: + PIL Image: Grayscale version of the image. + + - if num_output_channels = 1 : returned image is single channel + - if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(to_grayscale) + if isinstance(img, Image.Image): + return F_pil.to_grayscale(img, num_output_channels) + + raise TypeError("Input should be PIL Image") + + +def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: + """Convert RGB image to grayscale version of image. + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions + + Note: + Please, note that this method supports only RGB images as input. For inputs in other color spaces, + please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image. + + Args: + img (PIL Image or Tensor): RGB Image to be converted to grayscale. + num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. + + Returns: + PIL Image or Tensor: Grayscale version of the image. + + - if num_output_channels = 1 : returned image is single channel + - if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(rgb_to_grayscale) + if not isinstance(img, torch.Tensor): + return F_pil.to_grayscale(img, num_output_channels) + + return F_t.rgb_to_grayscale(img, num_output_channels) + + +def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: + """Erase the input Tensor Image with given value. + This transform does not support PIL Image. 
+ + Args: + img (Tensor Image): Tensor image of size (C, H, W) to be erased + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the erased region. + w (int): Width of the erased region. + v: Erasing value. + inplace(bool, optional): For in-place operations. By default, is set False. + + Returns: + Tensor Image: Erased image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(erase) + if not isinstance(img, torch.Tensor): + raise TypeError(f"img should be Tensor Image. Got {type(img)}") + + return F_t.erase(img, i, j, h, w, v, inplace=inplace) + + +def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: + """Performs Gaussian blurring on the image by given kernel + + The convolution will be using reflection padding corresponding to the kernel size, to maintain the input shape. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means at most one leading dimension. + + Args: + img (PIL Image or Tensor): Image to be blurred + kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers + like ``(kx, ky)`` or a single integer for square kernels. + + .. note:: + In torchscript mode kernel_size as single int is not supported, use a sequence of + length 1: ``[ksize, ]``. + sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a + sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the + same sigma in both X/Y directions. If None, then it is computed using + ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``. + Default, None. + + .. note:: + In torchscript mode sigma as single float is + not supported, use a sequence of length 1: ``[sigma, ]``. + + Returns: + PIL Image or Tensor: Gaussian Blurred version of the image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(gaussian_blur) + if not isinstance(kernel_size, (int, list, tuple)): + raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}") + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + if len(kernel_size) != 2: + raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") + for ksize in kernel_size: + if ksize % 2 == 0 or ksize < 0: + raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}") + + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + + if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") + if isinstance(sigma, (int, float)): + sigma = [float(sigma), float(sigma)] + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: + sigma = [sigma[0], sigma[0]] + if len(sigma) != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + for s in sigma: + if s <= 0.0: + raise ValueError(f"sigma should have positive values. Got {sigma}") + + t_img = img + if not isinstance(img, torch.Tensor): + if not F_pil._is_pil_image(img): + raise TypeError(f"img should be PIL Image or Tensor. 
Got {type(img)}") + + t_img = pil_to_tensor(img) + + output = F_t.gaussian_blur(t_img, kernel_size, sigma) + + if not isinstance(img, torch.Tensor): + output = to_pil_image(output, mode=img.mode) + return output + + +def invert(img: Tensor) -> Tensor: + """Invert the colors of an RGB/grayscale image. + + Args: + img (PIL Image or Tensor): Image to have its colors inverted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Returns: + PIL Image or Tensor: Color inverted image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(invert) + if not isinstance(img, torch.Tensor): + return F_pil.invert(img) + + return F_t.invert(img) + + +def posterize(img: Tensor, bits: int) -> Tensor: + """Posterize an image by reducing the number of bits for each color channel. + + Args: + img (PIL Image or Tensor): Image to have its colors posterized. + If img is torch Tensor, it should be of type torch.uint8, and + it is expected to be in [..., 1 or 3, H, W] format, where ... means + it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + bits (int): The number of bits to keep for each channel (0-8). + Returns: + PIL Image or Tensor: Posterized image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(posterize) + if not (0 <= bits <= 8): + raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}") + + if not isinstance(img, torch.Tensor): + return F_pil.posterize(img, bits) + + return F_t.posterize(img, bits) + + +def solarize(img: Tensor, threshold: float) -> Tensor: + """Solarize an RGB/grayscale image by inverting all pixel values above a threshold. + + Args: + img (PIL Image or Tensor): Image to have its colors inverted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + threshold (float): All pixels equal or above this value are inverted. + Returns: + PIL Image or Tensor: Solarized image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(solarize) + if not isinstance(img, torch.Tensor): + return F_pil.solarize(img, threshold) + + return F_t.solarize(img, threshold) + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: + """Adjust the sharpness of an image. + + Args: + img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + sharpness_factor (float): How much to adjust the sharpness. Can be + any non-negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + + Returns: + PIL Image or Tensor: Sharpness adjusted image. 
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(adjust_sharpness) + if not isinstance(img, torch.Tensor): + return F_pil.adjust_sharpness(img, sharpness_factor) + + return F_t.adjust_sharpness(img, sharpness_factor) + + +def autocontrast(img: Tensor) -> Tensor: + """Maximize contrast of an image by remapping its + pixels per channel so that the lowest becomes black and the lightest + becomes white. + + Args: + img (PIL Image or Tensor): Image on which autocontrast is applied. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Returns: + PIL Image or Tensor: An image that was autocontrasted. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(autocontrast) + if not isinstance(img, torch.Tensor): + return F_pil.autocontrast(img) + + return F_t.autocontrast(img) + + +def equalize(img: Tensor) -> Tensor: + """Equalize the histogram of an image by applying + a non-linear mapping to the input in order to create a uniform + distribution of grayscale values in the output. + + Args: + img (PIL Image or Tensor): Image on which equalize is applied. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``. + If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". + + Returns: + PIL Image or Tensor: An image that was equalized. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(equalize) + if not isinstance(img, torch.Tensor): + return F_pil.equalize(img) + + return F_t.equalize(img) + + +def elastic_transform( + img: Tensor, + displacement: Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> Tensor: + """Transform a tensor image with elastic transformations. + Given alpha and sigma, it will generate displacement + vectors for all pixels based on random offsets. Alpha controls the strength + and sigma controls the smoothness of the displacements. + The displacements are added to an identity grid and the resulting grid is + used to grid_sample from the image. + + Applications: + Randomly transforms the morphology of objects in images and produces a + see-through-water-like effect. + + Args: + img (PIL Image or Tensor): Image on which elastic_transform is applied. + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". + displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2]. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. + Default is ``InterpolationMode.BILINEAR``. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. + If a tuple of length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. 
+ """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(elastic_transform) + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationMode instead of int. " + "Please, use InterpolationMode enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + + t_img = img + if not isinstance(img, torch.Tensor): + if not F_pil._is_pil_image(img): + raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}") + t_img = pil_to_tensor(img) + + shape = t_img.shape + shape = (1,) + shape[-2:] + (2,) + if shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}") + + # TODO: if image shape is [N1, N2, ..., C, H, W] and + # displacement is [1, H, W, 2] we need to reshape input image + # such grid_sampler takes internal code for 4D input + + output = F_t.elastic_transform( + t_img, + displacement, + interpolation=interpolation.value, + fill=fill, + ) + + if not isinstance(img, torch.Tensor): + output = to_pil_image(output, mode=img.mode) + return output diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..07932390efeceb8e0ae1d1aada6faddd56b17461 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py @@ -0,0 +1,2153 @@ +import math +import numbers +import random +import warnings +from collections.abc import Sequence +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor + +try: + import accimage +except ImportError: + accimage = None + +from ..utils import _log_api_usage_once +from . import functional as F +from .functional import _interpolation_modes_from_int, InterpolationMode + +__all__ = [ + "Compose", + "ToTensor", + "PILToTensor", + "ConvertImageDtype", + "ToPILImage", + "Normalize", + "Resize", + "CenterCrop", + "Pad", + "Lambda", + "RandomApply", + "RandomChoice", + "RandomOrder", + "RandomCrop", + "RandomHorizontalFlip", + "RandomVerticalFlip", + "RandomResizedCrop", + "FiveCrop", + "TenCrop", + "LinearTransformation", + "ColorJitter", + "RandomRotation", + "RandomAffine", + "Grayscale", + "RandomGrayscale", + "RandomPerspective", + "RandomErasing", + "GaussianBlur", + "InterpolationMode", + "RandomInvert", + "RandomPosterize", + "RandomSolarize", + "RandomAdjustSharpness", + "RandomAutocontrast", + "RandomEqualize", + "ElasticTransform", +] + + +class Compose: + """Composes several transforms together. This transform does not support torchscript. + Please, see the note below. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.PILToTensor(), + >>> transforms.ConvertImageDtype(torch.float), + >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. 
that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + """ + + def __init__(self, transforms): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(self) + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string + + +class ToTensor: + """Convert a PIL Image or ndarray to tensor and scale the values accordingly. + + This transform does not support torchscript. + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + + In the other cases, tensors are returned without scaling. + + .. note:: + Because the input image is scaled to [0.0, 1.0], this transformation should not be used when + transforming target image masks. See the `references`_ for implementing the transforms for image masks. + + .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation + """ + + def __init__(self) -> None: + _log_api_usage_once(self) + + def __call__(self, pic): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.to_tensor(pic) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class PILToTensor: + """Convert a PIL Image to a tensor of the same type - this does not scale values. + + This transform does not support torchscript. + + Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). + """ + + def __init__(self) -> None: + _log_api_usage_once(self) + + def __call__(self, pic): + """ + .. note:: + + A deep copy of the underlying array is performed. + + Args: + pic (PIL Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.pil_to_tensor(pic) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class ConvertImageDtype(torch.nn.Module): + """Convert a tensor image to the given ``dtype`` and scale the values accordingly. + + This function does not support PIL Image. + + Args: + dtype (torch.dtype): Desired data type of the output + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. + + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + + def __init__(self, dtype: torch.dtype) -> None: + super().__init__() + _log_api_usage_once(self) + self.dtype = dtype + + def forward(self, image): + return F.convert_image_dtype(image, self.dtype) + + +class ToPILImage: + """Convert a tensor or an ndarray to PIL Image + + This transform does not support torchscript. 
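[Editor's note] The scripting note above is easier to see side by side; a sketch of a plain ``Compose`` pipeline versus the scriptable ``torch.nn.Sequential`` alternative (editor's illustration):

import torch
from torchvision import transforms

# Works with PIL inputs, but cannot be scripted.
pipeline = transforms.Compose([
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
])
# Tensor-only transforms can instead be wrapped in nn.Sequential and scripted.
scriptable = torch.nn.Sequential(
    transforms.CenterCrop(10),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted = torch.jit.script(scriptable)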
+ + Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape + H x W x C to a PIL Image while adjusting the value range depending on the ``mode``. + + Args: + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + If ``mode`` is ``None`` (default) there are some assumptions made about the input data: + + - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. + - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. + - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. + - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, ``short``). + + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + """ + + def __init__(self, mode=None): + _log_api_usage_once(self) + self.mode = mode + + def __call__(self, pic): + """ + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + + Returns: + PIL Image: Image converted to PIL Image. + + """ + return F.to_pil_image(pic, self.mode) + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + if self.mode is not None: + format_string += f"mode={self.mode}" + format_string += ")" + return format_string + + +class Normalize(torch.nn.Module): + """Normalize a tensor image with mean and standard deviation. + This transform does not support PIL Image. + Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` + channels, this transform will normalize each channel of the input + ``torch.*Tensor`` i.e., + ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` + + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. + + """ + + def __init__(self, mean, std, inplace=False): + super().__init__() + _log_api_usage_once(self) + self.mean = mean + self.std = std + self.inplace = inplace + + def forward(self, tensor: Tensor) -> Tensor: + """ + Args: + tensor (Tensor): Tensor image to be normalized. + + Returns: + Tensor: Normalized Tensor image. + """ + return F.normalize(tensor, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})" + + +class Resize(torch.nn.Module): + """Resize the input image to the given size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means a maximum of two leading dimensions + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size). + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. 
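[Editor's note] ``Normalize`` applies the per-channel formula quoted above, ``output[channel] = (input[channel] - mean[channel]) / std[channel]``; a sketch with the usual ImageNet statistics (illustrative input):

import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img = torch.rand(3, 224, 224)
out = normalize(img)   # out of place: img itself is left unchanged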
+ max_size (int, optional): The maximum allowed for the longer edge of + the resized image. If the longer edge of the image is greater + than ``max_size`` after being resized according to ``size``, + ``size`` will be overruled so that the longer edge is equal to + ``max_size``. + As a result, the smaller edge may be shorter than ``size``. This + is only supported if ``size`` is an int (or a sequence of length + 1 in torchscript mode). + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True): + super().__init__() + _log_api_usage_once(self) + if not isinstance(size, (int, Sequence)): + raise TypeError(f"Size should be int or sequence. Got {type(size)}") + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError("If size is a sequence, it should have 1 or 2 values") + self.size = size + self.max_size = max_size + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + + def __repr__(self) -> str: + detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})" + return f"{self.__class__.__name__}{detail}" + + +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + def __init__(self, size): + super().__init__() + _log_api_usage_once(self) + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. 
+ """ + return F.center_crop(img, self.size) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class Pad(torch.nn.Module): + """Pad the given image on all sides with the given "pad" value. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, + at most 3 leading dimensions for mode edge, + and an arbitrary number of leading dimensions for mode constant + + Args: + padding (int or sequence): Padding on each border. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + def __init__(self, padding, fill=0, padding_mode="constant"): + super().__init__() + _log_api_usage_once(self) + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg") + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) + + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be padded. + + Returns: + PIL Image or Tensor: Padded image. + """ + return F.pad(img, self.padding, self.fill, self.padding_mode) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})" + + +class Lambda: + """Apply a user-defined lambda as a transform. This transform does not support torchscript. + + Args: + lambd (function): Lambda/function to be used for transform. 
+ """ + + def __init__(self, lambd): + _log_api_usage_once(self) + if not callable(lambd): + raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}") + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class RandomTransforms: + """Base class for a list of transformations with randomness + + Args: + transforms (sequence): list of transformations + """ + + def __init__(self, transforms): + _log_api_usage_once(self) + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence") + self.transforms = transforms + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string + + +class RandomApply(torch.nn.Module): + """Apply randomly a list of transformations with a given probability. + + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + + >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=0.3) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (float): probability + """ + + def __init__(self, transforms, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.transforms = transforms + self.p = p + + def forward(self, img): + if self.p < torch.rand(1): + return img + for t in self.transforms: + img = t(img) + return img + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + format_string += f"\n p={self.p}" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string + + +class RandomOrder(RandomTransforms): + """Apply a list of transformations in a random order. This transform does not support torchscript.""" + + def __call__(self, img): + order = list(range(len(self.transforms))) + random.shuffle(order) + for i in order: + img = self.transforms[i](img) + return img + + +class RandomChoice(RandomTransforms): + """Apply single transformation randomly picked from a list. This transform does not support torchscript.""" + + def __init__(self, transforms, p=None): + super().__init__(transforms) + if p is not None and not isinstance(p, Sequence): + raise TypeError("Argument p should be a sequence") + self.p = p + + def __call__(self, *args): + t = random.choices(self.transforms, weights=self.p)[0] + return t(*args) + + def __repr__(self) -> str: + return f"{super().__repr__()}(p={self.p})" + + +class RandomCrop(torch.nn.Module): + """Crop the given image at a random location. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions, + but if non-constant padding is used, the input is expected to have at most 2 leading dimensions + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + padding (int or sequence, optional): Optional padding on each border + of the image. Default is None. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + pad_if_needed (boolean): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + Only number is supported for torch Tensor. + Only int or tuple value is supported for PIL Image. + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2 + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + @staticmethod + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + _, h, w = F.get_dimensions(img) + th, tw = output_size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return i, j, th, tw + + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + super().__init__() + _log_api_usage_once(self) + + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. 
+ """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + _, height, width = F.get_dimensions(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + + return F.crop(img, i, j, h, w) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})" + + +class RandomHorizontalFlip(torch.nn.Module): + """Horizontally flip the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + PIL Image or Tensor: Randomly flipped image. + """ + if torch.rand(1) < self.p: + return F.hflip(img) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class RandomVerticalFlip(torch.nn.Module): + """Vertically flip the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + PIL Image or Tensor: Randomly flipped image. + """ + if torch.rand(1) < self.p: + return F.vflip(img) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class RandomPerspective(torch.nn.Module): + """Performs a random perspective transformation of the given image with a given probability. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. + Default is 0.5. + p (float): probability of the image being transformed. Default is 0.5. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (sequence or number): Pixel fill value for the area outside the transformed + image. Default is ``0``. If given a number, the value is used for all bands respectively. 
+ """ + + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0): + super().__init__() + _log_api_usage_once(self) + self.p = p + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.distortion_scale = distortion_scale + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be Perspectively transformed. + + Returns: + PIL Image or Tensor: Randomly transformed image. + """ + + fill = self.fill + channels, height, width = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + else: + fill = [float(f) for f in fill] + + if torch.rand(1) < self.p: + startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + return F.perspective(img, startpoints, endpoints, self.interpolation, fill) + return img + + @staticmethod + def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]: + """Get parameters for ``perspective`` for a random perspective transform. + + Args: + width (int): width of the image. + height (int): height of the image. + distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. + + Returns: + List containing [top-left, top-right, bottom-right, bottom-left] of the original image, + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. + """ + half_height = height // 2 + half_width = width // 2 + topleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), + ] + topright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), + ] + botright = [ + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), + ] + botleft = [ + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + return startpoints, endpoints + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class RandomResizedCrop(torch.nn.Module): + """Crop a random portion of image and resize it to a given size. + + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + + A crop of the original image is made: the crop has a random area (H * W) + and a random aspect ratio. This crop is finally resized to the given + size. This is popularly used to train the Inception networks. + + Args: + size (int or sequence): expected output size of the crop, for each edge. If size is an + int instead of sequence like (h, w), a square output size ``(size, size)`` is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + + .. 
note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, + before resizing. The scale is defined with respect to the area of the original image. + ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before + resizing. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ): + super().__init__() + _log_api_usage_once(self) + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]: + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image or Tensor): Input image. + scale (list): range of scale of the origin size cropped + ratio (list): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. 
+ """ + _, height, width = F.get_dimensions(img) + area = height * width + + log_ratio = torch.log(torch.tensor(ratio)) + for _ in range(10): + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped and resized. + + Returns: + PIL Image or Tensor: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) + + def __repr__(self) -> str: + interpolate_str = self.interpolation.value + format_string = self.__class__.__name__ + f"(size={self.size}" + format_string += f", scale={tuple(round(s, 4) for s in self.scale)}" + format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}" + format_string += f", interpolation={interpolate_str}" + format_string += f", antialias={self.antialias})" + return format_string + + +class FiveCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an ``int`` + instead of sequence like (h, w), a square crop of size (size, size) is made. + If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + + Example: + >>> transform = Compose([ + >>> FiveCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size): + super().__init__() + _log_api_usage_once(self) + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 5 images. Image can be PIL Image or Tensor + """ + return F.five_crop(img, self.size) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class TenCrop(torch.nn.Module): + """Crop the given image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default). 
+ If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + vertical_flip (bool): Use vertical flipping instead of horizontal + + Example: + >>> transform = Compose([ + >>> TenCrop(size), # this is a tuple of PIL Images + >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size, vertical_flip=False): + super().__init__() + _log_api_usage_once(self) + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + tuple of 10 images. Image can be PIL Image or Tensor + """ + return F.ten_crop(img, self.size, self.vertical_flip) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})" + + +class LinearTransformation(torch.nn.Module): + """Transform a tensor image with a square transformation matrix and a mean_vector computed + offline. + This transform does not support PIL Image. + Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and + subtract mean_vector from it which is then followed by computing the dot + product with the transformation matrix and then reshaping the tensor to its + original shape. + + Applications: + whitening transformation: Suppose X is a column vector zero-centered data. + Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), + perform SVD on this matrix and pass it as transformation_matrix. + + Args: + transformation_matrix (Tensor): tensor [D x D], D = C x H x W + mean_vector (Tensor): tensor [D], D = C x H x W + """ + + def __init__(self, transformation_matrix, mean_vector): + super().__init__() + _log_api_usage_once(self) + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError( + "transformation_matrix should be square. Got " + f"{tuple(transformation_matrix.size())} rectangular matrix." + ) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError( + f"mean_vector should have the same length {mean_vector.size(0)}" + f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]" + ) + + if transformation_matrix.device != mean_vector.device: + raise ValueError( + f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" + ) + + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. 
Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def forward(self, tensor: Tensor) -> Tensor: + """ + Args: + tensor (Tensor): Tensor image to be whitened. + + Returns: + Tensor: Transformed image. + """ + shape = tensor.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != " + + f"{self.transformation_matrix.shape[0]}" + ) + + if tensor.device.type != self.mean_vector.device.type: + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + f"Got {tensor.device} vs {self.mean_vector.device}" + ) + + flat_tensor = tensor.view(-1, n) - self.mean_vector + transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype) + transformed_tensor = torch.mm(flat_tensor, transformation_matrix) + tensor = transformed_tensor.view(shape) + return tensor + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(transformation_matrix=" + f"{self.transformation_matrix.tolist()}" + f", mean_vector={self.mean_vector.tolist()})" + ) + return s + + +class ColorJitter(torch.nn.Module): + """Randomly change the brightness, contrast, saturation and hue of an image. + If the image is torch Tensor, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. + + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non-negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; + thus it does not work if you normalize your image to an interval with negative values, + or use an interpolation that generates negative values before using this function. 
+ """ + + def __init__( + self, + brightness: Union[float, Tuple[float, float]] = 0, + contrast: Union[float, Tuple[float, float]] = 0, + saturation: Union[float, Tuple[float, float]] = 0, + hue: Union[float, Tuple[float, float]] = 0, + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + @torch.jit.unused + def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError(f"If {name} is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + value = [float(value[0]), float(value[1])] + else: + raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}, but got {value}.") + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + return None + else: + return tuple(value) + + @staticmethod + def get_params( + brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]], + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + """Get the parameters for the randomized transform to be applied on image. + + Args: + brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen + uniformly. Pass None to turn off the transformation. + contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen + uniformly. Pass None to turn off the transformation. + saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen + uniformly. Pass None to turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. + Pass None to turn off the transformation. + + Returns: + tuple: The parameters used to apply the randomized transform + along with their random order. + """ + fn_idx = torch.randperm(4) + + b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + return fn_idx, b, c, s, h + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Input image. + + Returns: + PIL Image or Tensor: Color jittered image. 
+ """ + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + + return img + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"brightness={self.brightness}" + f", contrast={self.contrast}" + f", saturation={self.saturation}" + f", hue={self.hue})" + ) + return s + + +class RandomRotation(torch.nn.Module): + """Rotate the image by angle. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + degrees (sequence or number): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + expand (bool, optional): Optional expansion flag. + If true, expands the output to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. + fill (sequence or number): Pixel fill value for the area outside the rotated + image. Default is ``0``. If given a number, the value is used for all bands respectively. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0): + super().__init__() + _log_api_usage_once(self) + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + self.interpolation = interpolation + self.expand = expand + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + @staticmethod + def get_params(degrees: List[float]) -> float: + """Get parameters for ``rotate`` for a random rotation. + + Returns: + float: angle parameter to be passed to ``rotate`` for random rotation. + """ + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) + return angle + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be rotated. + + Returns: + PIL Image or Tensor: Rotated image. 
+ """ + fill = self.fill + channels, _, _ = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + else: + fill = [float(f) for f in fill] + angle = self.get_params(self.degrees) + + return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill) + + def __repr__(self) -> str: + interpolate_str = self.interpolation.value + format_string = self.__class__.__name__ + f"(degrees={self.degrees}" + format_string += f", interpolation={interpolate_str}" + format_string += f", expand={self.expand}" + if self.center is not None: + format_string += f", center={self.center}" + if self.fill is not None: + format_string += f", fill={self.fill}" + format_string += ")" + return format_string + + +class RandomAffine(torch.nn.Module): + """Random affine transformation of the image keeping center invariant. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + degrees (sequence or number): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Set to 0 to deactivate rotations. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or number, optional): Range of degrees to select from. + If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear) + will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the + range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, + an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. + Will not apply shear by default. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (sequence or number): Pixel fill value for the area outside the transformed + image. Default is ``0``. If given a number, the value is used for all bands respectively. + center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. + + .. 
_filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__( + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + center=None, + ): + super().__init__() + _log_api_usage_once(self) + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2,)) + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2,)) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.interpolation = interpolation + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + @staticmethod + def get_params( + degrees: List[float], + translate: Optional[List[float]], + scale_ranges: Optional[List[float]], + shears: Optional[List[float]], + img_size: List[int], + ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: + """Get parameters for affine transformation + + Returns: + params to be passed to the affine transformation + """ + angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) + if translate is not None: + max_dx = float(translate[0] * img_size[0]) + max_dy = float(translate[1] * img_size[1]) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + translations = (tx, ty) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if shears is not None: + shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item()) + if len(shears) == 4: + shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item()) + + shear = (shear_x, shear_y) + + return angle, translations, scale, shear + + def forward(self, img): + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Affine transformed image. 
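A small usage sketch offered as editorial illustration; the parameter values below are examples, not defaults from the file.
>>> import torch
>>> from torchvision import transforms
>>> from torchvision.transforms import InterpolationMode
>>> affine = transforms.RandomAffine(
>>>     degrees=15,                 # rotation angle sampled from (-15, 15)
>>>     translate=(0.1, 0.1),       # up to 10% horizontal and vertical shift
>>>     scale=(0.9, 1.1),           # uniform scaling factor
>>>     shear=5,                    # x-axis shear sampled from (-5, 5)
>>>     interpolation=InterpolationMode.BILINEAR,
>>>     fill=0,
>>> )
>>> img = torch.rand(3, 224, 224)   # example tensor; a PIL Image works as well
>>> transformed = affine(img)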
+ """ + fill = self.fill + channels, height, width = F.get_dimensions(img) + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * channels + else: + fill = [float(f) for f in fill] + + img_size = [width, height] # flip for keeping BC on get_params call + + ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) + + return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(degrees={self.degrees}" + s += f", translate={self.translate}" if self.translate is not None else "" + s += f", scale={self.scale}" if self.scale is not None else "" + s += f", shear={self.shear}" if self.shear is not None else "" + s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else "" + s += f", fill={self.fill}" if self.fill != 0 else "" + s += f", center={self.center}" if self.center is not None else "" + s += ")" + + return s + + +class Grayscale(torch.nn.Module): + """Convert image to grayscale. + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + num_output_channels (int): (1 or 3) number of channels desired for output image + + Returns: + PIL Image: Grayscale version of the input. + + - If ``num_output_channels == 1`` : returned image is single channel + - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b + + """ + + def __init__(self, num_output_channels=1): + super().__init__() + _log_api_usage_once(self) + self.num_output_channels = num_output_channels + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be converted to grayscale. + + Returns: + PIL Image or Tensor: Grayscaled image. + """ + return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})" + + +class RandomGrayscale(torch.nn.Module): + """Randomly convert image to grayscale with a probability of p (default 0.1). + If the image is torch Tensor, it is expected + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions + + Args: + p (float): probability that image should be converted to grayscale. + + Returns: + PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged + with probability (1-p). + - If input image is 1 channel: grayscale version is 1 channel + - If input image is 3 channel: grayscale version is 3 channel with r == g == b + + """ + + def __init__(self, p=0.1): + super().__init__() + _log_api_usage_once(self) + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be converted to grayscale. + + Returns: + PIL Image or Tensor: Randomly grayscaled image. + """ + num_output_channels, _, _ = F.get_dimensions(img) + if torch.rand(1) < self.p: + return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class RandomErasing(torch.nn.Module): + """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels. + This transform does not support PIL Image. + 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 + + Args: + p: probability that the random erasing operation will be performed. 
+ scale: range of proportion of erased area against input image. + ratio: range of aspect ratio of erased area. + value: erasing value. Default is 0. If a single int, it is used to + erase all pixels. If a tuple of length 3, it is used to erase + R, G, B channels respectively. + If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace. Default set to False. + + Returns: + Erased Image. + + Example: + >>> transform = transforms.Compose([ + >>> transforms.RandomHorizontalFlip(), + >>> transforms.PILToTensor(), + >>> transforms.ConvertImageDtype(torch.float), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> transforms.RandomErasing(), + >>> ]) + """ + + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): + super().__init__() + _log_api_usage_once(self) + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + if p < 0 or p > 1: + raise ValueError("Random erasing probability should be between 0 and 1") + + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value + self.inplace = inplace + + @staticmethod + def get_params( + img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + ) -> Tuple[int, int, int, int, Tensor]: + """Get parameters for ``erase`` for a random erasing. + + Args: + img (Tensor): Tensor image to be erased. + scale (sequence): range of proportion of erased area against input image. + ratio (sequence): range of aspect ratio of erased area. + value (list, optional): erasing value. If None, it is interpreted as "random" + (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number, + i.e. ``value[0]``. + + Returns: + tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. + """ + img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1] + area = img_h * img_w + + log_ratio = torch.log(torch.tensor(ratio)) + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() + return i, j, h, w, v + + # Return original image + return 0, 0, img_h, img_w, img + + def forward(self, img): + """ + Args: + img (Tensor): Tensor image to be erased. + + Returns: + img (Tensor): Erased Tensor image. 
+ """ + if torch.rand(1) < self.p: + + # cast self.value to script acceptable type + if isinstance(self.value, (int, float)): + value = [float(self.value)] + elif isinstance(self.value, str): + value = None + elif isinstance(self.value, (list, tuple)): + value = [float(v) for v in self.value] + else: + value = self.value + + if value is not None and not (len(value) in (1, img.shape[-3])): + raise ValueError( + "If value is a sequence, it should have either a single value or " + f"{img.shape[-3]} (number of input channels)" + ) + + x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) + return F.erase(img, x, y, h, w, v, self.inplace) + return img + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}" + f"(p={self.p}, " + f"scale={self.scale}, " + f"ratio={self.ratio}, " + f"value={self.value}, " + f"inplace={self.inplace})" + ) + return s + + +class GaussianBlur(torch.nn.Module): + """Blurs image with randomly chosen Gaussian blur. + If the image is torch Tensor, it is expected + to have [..., C, H, W] shape, where ... means at most one leading dimension. + + Args: + kernel_size (int or sequence): Size of the Gaussian kernel. + sigma (float or tuple of float (min, max)): Standard deviation to be used for + creating kernel to perform blurring. If float, sigma is fixed. If it is tuple + of float (min, max), sigma is chosen uniformly at random to lie in the + given range. + + Returns: + PIL Image or Tensor: Gaussian blurred version of the input image. + + """ + + def __init__(self, kernel_size, sigma=(0.1, 2.0)): + super().__init__() + _log_api_usage_once(self) + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, numbers.Number): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = (sigma, sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0.0 < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise ValueError("sigma should be a single number or a list/tuple with length 2.") + + self.sigma = sigma + + @staticmethod + def get_params(sigma_min: float, sigma_max: float) -> float: + """Choose sigma for random gaussian blurring. + + Args: + sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. + sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. + + Returns: + float: Standard deviation to be passed to calculate kernel for gaussian blurring. + """ + return torch.empty(1).uniform_(sigma_min, sigma_max).item() + + def forward(self, img: Tensor) -> Tensor: + """ + Args: + img (PIL Image or Tensor): image to be blurred. 
+ + Returns: + PIL Image or Tensor: Gaussian blurred image + """ + sigma = self.get_params(self.sigma[0], self.sigma[1]) + return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})" + return s + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +def _check_sequence_input(x, name, req_sizes): + msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) + if not isinstance(x, Sequence): + raise TypeError(f"{name} should be a sequence of length {msg}.") + if len(x) not in req_sizes: + raise ValueError(f"{name} should be a sequence of length {msg}.") + + +def _setup_angle(x, name, req_sizes=(2,)): + if isinstance(x, numbers.Number): + if x < 0: + raise ValueError(f"If {name} is a single number, it must be positive.") + x = [-x, x] + else: + _check_sequence_input(x, name, req_sizes) + + return [float(d) for d in x] + + +class RandomInvert(torch.nn.Module): + """Inverts the colors of the given image randomly with a given probability. + If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be inverted. + + Returns: + PIL Image or Tensor: Randomly color inverted image. + """ + if torch.rand(1).item() < self.p: + return F.invert(img) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class RandomPosterize(torch.nn.Module): + """Posterize the image randomly with a given probability by reducing the + number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8, + and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + bits (int): number of bits to keep for each channel (0-8) + p (float): probability of the image being posterized. Default value is 0.5 + """ + + def __init__(self, bits, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.bits = bits + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be posterized. + + Returns: + PIL Image or Tensor: Randomly posterized image. + """ + if torch.rand(1).item() < self.p: + return F.posterize(img, self.bits) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(bits={self.bits},p={self.p})" + + +class RandomSolarize(torch.nn.Module): + """Solarize the image randomly with a given probability by inverting all pixel + values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + threshold (float): all pixels equal or above this value are inverted. + p (float): probability of the image being solarized. 
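Editor's illustrative sketch (example values, not part of the diff): the probabilistic photometric transforms defined above can be mixed with RandomApply; note that RandomPosterize expects a torch.uint8 tensor (or an "L"/"RGB" PIL image).
>>> import torch
>>> from torchvision import transforms
>>> photometric = transforms.Compose([
>>>     transforms.RandomPosterize(bits=4, p=0.3),    # keep 4 bits per channel, 30% of the time
>>>     transforms.RandomInvert(p=0.2),
>>>     transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.5),
>>> ])
>>> img = (torch.rand(3, 224, 224) * 255).to(torch.uint8)   # example uint8 image tensor
>>> out = photometric(img)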
Default value is 0.5 + """ + + def __init__(self, threshold, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.threshold = threshold + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be solarized. + + Returns: + PIL Image or Tensor: Randomly solarized image. + """ + if torch.rand(1).item() < self.p: + return F.solarize(img, self.threshold) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})" + + +class RandomAdjustSharpness(torch.nn.Module): + """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non-negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being sharpened. Default value is 0.5 + """ + + def __init__(self, sharpness_factor, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.sharpness_factor = sharpness_factor + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be sharpened. + + Returns: + PIL Image or Tensor: Randomly sharpened image. + """ + if torch.rand(1).item() < self.p: + return F.adjust_sharpness(img, self.sharpness_factor) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})" + + +class RandomAutocontrast(torch.nn.Module): + """Autocontrast the pixels of the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + p (float): probability of the image being autocontrasted. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be autocontrasted. + + Returns: + PIL Image or Tensor: Randomly autocontrasted image. + """ + if torch.rand(1).item() < self.p: + return F.autocontrast(img) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class RandomEqualize(torch.nn.Module): + """Equalize the histogram of the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". + + Args: + p (float): probability of the image being equalized. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__() + _log_api_usage_once(self) + self.p = p + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be equalized. + + Returns: + PIL Image or Tensor: Randomly equalized image. + """ + if torch.rand(1).item() < self.p: + return F.equalize(img) + return img + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +class ElasticTransform(torch.nn.Module): + """Transform a tensor image with elastic transformations. + Given alpha and sigma, it will generate displacement + vectors for all pixels based on random offsets. 
Alpha controls the strength + and sigma controls the smoothness of the displacements. + The displacements are added to an identity grid and the resulting grid is + used to grid_sample from the image. + + Applications: + Randomly transforms the morphology of objects in images and produces a + see-through-water-like effect. + + Args: + alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0. + sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (sequence or number): Pixel fill value for the area outside the transformed + image. Default is ``0``. If given a number, the value is used for all bands respectively. + + """ + + def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0): + super().__init__() + _log_api_usage_once(self) + if not isinstance(alpha, (float, Sequence)): + raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}") + if isinstance(alpha, Sequence) and len(alpha) != 2: + raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}") + if isinstance(alpha, Sequence): + for element in alpha: + if not isinstance(element, float): + raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}") + + if isinstance(alpha, float): + alpha = [float(alpha), float(alpha)] + if isinstance(alpha, (list, tuple)) and len(alpha) == 1: + alpha = [alpha[0], alpha[0]] + + self.alpha = alpha + + if not isinstance(sigma, (float, Sequence)): + raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}") + if isinstance(sigma, Sequence) and len(sigma) != 2: + raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}") + if isinstance(sigma, Sequence): + for element in sigma: + if not isinstance(element, float): + raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}") + + if isinstance(sigma, float): + sigma = [float(sigma), float(sigma)] + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: + sigma = [sigma[0], sigma[0]] + + self.sigma = sigma + + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation + + if isinstance(fill, (int, float)): + fill = [float(fill)] + elif isinstance(fill, (list, tuple)): + fill = [float(f) for f in fill] + else: + raise TypeError(f"fill should be int or float or a list or tuple of them. 
Got {type(fill)}") + self.fill = fill + + @staticmethod + def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor: + dx = torch.rand([1, 1] + size) * 2 - 1 + if sigma[0] > 0.0: + kx = int(8 * sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = F.gaussian_blur(dx, [kx, kx], sigma) + dx = dx * alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if sigma[1] > 0.0: + ky = int(8 * sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = F.gaussian_blur(dy, [ky, ky], sigma) + dy = dy * alpha[1] / size[1] + return torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + + def forward(self, tensor: Tensor) -> Tensor: + """ + Args: + tensor (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + _, height, width = F.get_dimensions(tensor) + displacement = self.get_params(self.alpha, self.sigma, [height, width]) + return F.elastic_transform(tensor, displacement, self.interpolation, self.fill) + + def __repr__(self): + format_string = self.__class__.__name__ + format_string += f"(alpha={self.alpha}" + format_string += f", sigma={self.sigma}" + format_string += f", interpolation={self.interpolation}" + format_string += f", fill={self.fill})" + return format_string diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d66917b6eaceb6281b7d809cbb22462a511a91d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py @@ -0,0 +1,60 @@ +from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip + +from . 
import functional # usort: skip + +from ._transform import Transform # usort: skip + +from ._augment import CutMix, JPEG, MixUp, RandomErasing +from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide +from ._color import ( + ColorJitter, + Grayscale, + RandomAdjustSharpness, + RandomAutocontrast, + RandomChannelPermutation, + RandomEqualize, + RandomGrayscale, + RandomInvert, + RandomPhotometricDistort, + RandomPosterize, + RandomSolarize, + RGB, +) +from ._container import Compose, RandomApply, RandomChoice, RandomOrder +from ._geometry import ( + CenterCrop, + ElasticTransform, + FiveCrop, + Pad, + RandomAffine, + RandomCrop, + RandomHorizontalFlip, + RandomIoUCrop, + RandomPerspective, + RandomResize, + RandomResizedCrop, + RandomRotation, + RandomShortestSize, + RandomVerticalFlip, + RandomZoomOut, + Resize, + ScaleJitter, + TenCrop, +) +from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat +from ._misc import ( + ConvertImageDtype, + GaussianBlur, + GaussianNoise, + Identity, + Lambda, + LinearTransformation, + Normalize, + SanitizeBoundingBoxes, + ToDtype, +) +from ._temporal import UniformTemporalSubsample +from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor +from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size + +from ._deprecated import ToTensor # usort: skip diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b989fb0ada3d748b93157859bad56367ae5d5230 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61689126a12fd50dd04f25f5c3066ae241436207 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..206bb8517a668784cc5a928da2d2546e938e8e66 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4a2c96c38b3295c8958426eb90c6e92d2e39910 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..277e8d3b4b64f9157bcb1bc2a62ed270d15ac5cf Binary files 
/dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..353baedb999e9555d84641ac860575d4bfa29ae6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0af3646914ffed6d360b5fda9dbaf67b0042f8ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0e1d51057aac8f25b3879f10295cb2bda3b8086 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca52044a777d9c31997d139cf8b55fe941d3b8d4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..b1dd508340848183d3330ecfadb91269410069ed --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py @@ -0,0 +1,369 @@ +import math +import numbers +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import PIL.Image +import torch +from torch.nn.functional import one_hot +from torch.utils._pytree import tree_flatten, tree_unflatten +from torchvision import transforms as _transforms, tv_tensors +from torchvision.transforms.v2 import functional as F + +from ._transform import _RandomApplyTransform, Transform +from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size + + +class RandomErasing(_RandomApplyTransform): + """Randomly select a rectangle region in the input image or video and erase its pixels. + + This transform does not support PIL Image. + 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 + + Args: + p (float, optional): probability that the random erasing operation will be performed. + scale (tuple of float, optional): range of proportion of erased area against input image. + ratio (tuple of float, optional): range of aspect ratio of erased area. + value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to + erase all pixels. 
If a tuple of length 3, it is used to erase + R, G, B channels respectively. + If a str of 'random', erasing each pixel with random values. + inplace (bool, optional): boolean to make this transform inplace. Default set to False. + + Returns: + Erased input. + + Example: + >>> from torchvision.transforms import v2 as transforms + >>> + >>> transform = transforms.Compose([ + >>> transforms.RandomHorizontalFlip(), + >>> transforms.PILToTensor(), + >>> transforms.ConvertImageDtype(torch.float), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> transforms.RandomErasing(), + >>> ]) + """ + + _v1_transform_cls = _transforms.RandomErasing + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return dict( + super()._extract_params_for_v1_transform(), + value="random" if self.value is None else self.value, + ) + + def __init__( + self, + p: float = 0.5, + scale: Sequence[float] = (0.02, 0.33), + ratio: Sequence[float] = (0.3, 3.3), + value: float = 0.0, + inplace: bool = False, + ): + super().__init__(p=p) + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + self.scale = scale + self.ratio = ratio + if isinstance(value, (int, float)): + self.value = [float(value)] + elif isinstance(value, str): + self.value = None + elif isinstance(value, (list, tuple)): + self.value = [float(v) for v in value] + else: + self.value = value + self.inplace = inplace + + self._log_ratio = torch.log(torch.tensor(self.ratio)) + + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." 
+ ) + return super()._call_kernel(functional, inpt, *args, **kwargs) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + img_c, img_h, img_w = query_chw(flat_inputs) + + if self.value is not None and not (len(self.value) in (1, img_c)): + raise ValueError( + f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" + ) + + area = img_h * img_w + + log_ratio = self._log_ratio + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if self.value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(self.value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() + break + else: + i, j, h, w, v = 0, 0, img_h, img_w, None + + return dict(i=i, j=j, h=h, w=w, v=v) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if params["v"] is not None: + inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace) + + return inpt + + +class _BaseMixUpCutMix(Transform): + def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None: + super().__init__() + self.alpha = float(alpha) + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + self.num_classes = num_classes + + self._labels_getter = _parse_labels_getter(labels_getter) + + def forward(self, *inputs): + inputs = inputs if len(inputs) > 1 else inputs[0] + flat_inputs, spec = tree_flatten(inputs) + needs_transform_list = self._needs_transform_list(flat_inputs) + + if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask): + raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.") + + labels = self._labels_getter(inputs) + if not isinstance(labels, torch.Tensor): + raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.") + if labels.ndim not in (1, 2): + raise ValueError( + f"labels should be index based with shape (batch_size,) " + f"or probability based with shape (batch_size, num_classes), " + f"but got a tensor of shape {labels.shape} instead." + ) + if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes: + raise ValueError( + f"When passing 2D labels, " + f"the number of elements in last dimension must match num_classes: " + f"{labels.shape[-1]} != {self.num_classes}. " + f"You can Leave num_classes to None." + ) + if labels.ndim == 1 and self.num_classes is None: + raise ValueError("num_classes must be passed if the labels are index-based (1D)") + + params = { + "labels": labels, + "batch_size": labels.shape[0], + **self._get_params( + [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] + ), + } + + # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming + # after an image or video. 
However, we need to handle them in _transform, so we make sure to set them to True + needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True + flat_outputs = [ + self._transform(inpt, params) if needs_transform else inpt + for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) + ] + + return tree_unflatten(flat_outputs, spec) + + def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int): + expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4 + if inpt.ndim != expected_num_dims: + raise ValueError( + f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead." + ) + if inpt.shape[0] != batch_size: + raise ValueError( + f"The batch size of the image or video does not match the batch size of the labels: " + f"{inpt.shape[0]} != {batch_size}." + ) + + def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: + if label.ndim == 1: + label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type] + if not label.dtype.is_floating_point: + label = label.float() + return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam)) + + +class MixUp(_BaseMixUpCutMix): + """Apply MixUp to the provided batch of images and labels. + + Paper: `mixup: Beyond Empirical Risk Minimization `_. + + .. note:: + This transform is meant to be used on **batches** of samples, not + individual images. See + :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage + examples. + The sample pairing is deterministic and done by matching consecutive + samples in the batch, so the batch needs to be shuffled (this is an + implementation detail, not a guaranteed convention.) + + In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed + into a tensor of shape ``(batch_size, num_classes)``. + + Args: + alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. + num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding. + Can be None only if the labels are already one-hot-encoded. + labels_getter (callable or "default", optional): indicates how to identify the labels in the input. + By default, this will pick the second parameter as the labels if it's a tensor. This covers the most + common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``. + It can also be a callable that takes the same input as the transform, and returns the labels. + """ + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + lam = params["lam"] + + if inpt is params["labels"]: + return self._mixup_label(inpt, lam=lam) + elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt): + self._check_image_or_video(inpt, batch_size=params["batch_size"]) + + output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) + + if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): + output = tv_tensors.wrap(output, like=inpt) + + return output + else: + return inpt + + +class CutMix(_BaseMixUpCutMix): + """Apply CutMix to the provided batch of images and labels. + + Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features + `_. + + .. note:: + This transform is meant to be used on **batches** of samples, not + individual images. 
See + :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage + examples. + The sample pairing is deterministic and done by matching consecutive + samples in the batch, so the batch needs to be shuffled (this is an + implementation detail, not a guaranteed convention.) + + In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed + into a tensor of shape ``(batch_size, num_classes)``. + + Args: + alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. + num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding. + Can be None only if the labels are already one-hot-encoded. + labels_getter (callable or "default", optional): indicates how to identify the labels in the input. + By default, this will pick the second parameter as the labels if it's a tensor. This covers the most + common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``. + It can also be a callable that takes the same input as the transform, and returns the labels. + """ + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + lam = float(self._dist.sample(())) # type: ignore[arg-type] + + H, W = query_size(flat_inputs) + + r_x = torch.randint(W, size=(1,)) + r_y = torch.randint(H, size=(1,)) + + r = 0.5 * math.sqrt(1.0 - lam) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + box = (x1, y1, x2, y2) + + lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + return dict(box=box, lam_adjusted=lam_adjusted) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if inpt is params["labels"]: + return self._mixup_label(inpt, lam=params["lam_adjusted"]) + elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt): + self._check_image_or_video(inpt, batch_size=params["batch_size"]) + + x1, y1, x2, y2 = params["box"] + rolled = inpt.roll(1, 0) + output = inpt.clone() + output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] + + if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): + output = tv_tensors.wrap(output, like=inpt) + + return output + else: + return inpt + + +class JPEG(Transform): + """Apply JPEG compression and decompression to the given images. + + If the input is a :class:`torch.Tensor`, it is expected + to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape, + where ... means an arbitrary number of leading dimensions. + + Args: + quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression. + If quality is a sequence like (min, max), it specifies the range of JPEG quality to + randomly select from (inclusive of both ends). + + Returns: + image with JPEG compression. 
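+
+    A minimal usage sketch (illustrative only; ``img`` below is assumed to be a ``uint8``
+    CHW tensor you already have on CPU):
+
+        >>> from torchvision.transforms import v2 as transforms
+        >>> jpeg = transforms.JPEG(quality=(50, 90))  # quality is re-sampled from [50, 90] on every call
+        >>> compressed = jpeg(img)  # same shape and dtype as img, with JPEG compression artifacts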
+ """ + + def __init__(self, quality: Union[int, Sequence[int]]): + super().__init__() + if isinstance(quality, int): + quality = [quality, quality] + else: + _check_sequence_input(quality, "quality", req_sizes=(2,)) + + if not (1 <= quality[0] <= quality[1] <= 100 and isinstance(quality[0], int) and isinstance(quality[1], int)): + raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}") + + self.quality = quality + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item() + return dict(quality=quality) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.jpeg, inpt, quality=params["quality"]) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd7ba343aa360c38192d2c8ac88481093ed8c93 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py @@ -0,0 +1,627 @@ +import math +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union + +import PIL.Image +import torch + +from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec +from torchvision import transforms as _transforms, tv_tensors +from torchvision.transforms import _functional_tensor as _FT +from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform +from torchvision.transforms.v2.functional._geometry import _check_interpolation +from torchvision.transforms.v2.functional._meta import get_size +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT + +from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor + + +ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video] + + +class _AutoAugmentBase(Transform): + def __init__( + self, + *, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + ) -> None: + super().__init__() + self.interpolation = _check_interpolation(interpolation) + self.fill = fill + self._fill = _setup_fill_arg(fill) + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if isinstance(params["fill"], dict): + raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.") + + return params + + def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]: + keys = tuple(dct.keys()) + key = keys[int(torch.randint(len(keys), ()))] + return key, dct[key] + + def _flatten_and_extract_image_or_video( + self, + inputs: Any, + unsupported_types: Tuple[Type, ...] 
= (tv_tensors.BoundingBoxes, tv_tensors.Mask), + ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]: + flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) + needs_transform_list = self._needs_transform_list(flat_inputs) + + image_or_videos = [] + for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)): + if needs_transform and check_type( + inpt, + ( + tv_tensors.Image, + PIL.Image.Image, + is_pure_tensor, + tv_tensors.Video, + ), + ): + image_or_videos.append((idx, inpt)) + elif isinstance(inpt, unsupported_types): + raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()") + + if not image_or_videos: + raise TypeError("Found no image in the sample.") + if len(image_or_videos) > 1: + raise TypeError( + f"Auto augment transformations are only properly defined for a single image or video, " + f"but found {len(image_or_videos)}." + ) + + idx, image_or_video = image_or_videos[0] + return (flat_inputs, spec, idx), image_or_video + + def _unflatten_and_insert_image_or_video( + self, + flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int], + image_or_video: ImageOrVideo, + ) -> Any: + flat_inputs, spec, idx = flat_inputs_with_spec + flat_inputs[idx] = image_or_video + return tree_unflatten(flat_inputs, spec) + + def _apply_image_or_video_transform( + self, + image: ImageOrVideo, + transform_id: str, + magnitude: float, + interpolation: Union[InterpolationMode, int], + fill: Dict[Union[Type, str], _FillTypeJIT], + ) -> ImageOrVideo: + # Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript) + image = cast(torch.Tensor, image) + fill_ = _get_fill(fill, type(image)) + + if transform_id == "Identity": + return image + elif transform_id == "ShearX": + # magnitude should be arctan(magnitude) + # official autoaug: (1, level, 0, 0, 1, 0) + # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290 + # compared to + # torchvision: (1, tan(level), 0, 0, 1, 0) + # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976 + return F.affine( + image, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(math.atan(magnitude)), 0.0], + interpolation=interpolation, + fill=fill_, + center=[0, 0], + ) + elif transform_id == "ShearY": + # magnitude should be arctan(magnitude) + # See above + return F.affine( + image, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(math.atan(magnitude))], + interpolation=interpolation, + fill=fill_, + center=[0, 0], + ) + elif transform_id == "TranslateX": + return F.affine( + image, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill_, + ) + elif transform_id == "TranslateY": + return F.affine( + image, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill_, + ) + elif transform_id == "Rotate": + return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_) + elif transform_id == "Brightness": + return F.adjust_brightness(image, brightness_factor=1.0 + magnitude) + elif transform_id == "Color": + return F.adjust_saturation(image, saturation_factor=1.0 + magnitude) + elif transform_id == "Contrast": + return F.adjust_contrast(image, contrast_factor=1.0 + magnitude) + elif transform_id == "Sharpness": + return 
F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude) + elif transform_id == "Posterize": + return F.posterize(image, bits=int(magnitude)) + elif transform_id == "Solarize": + bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0 + return F.solarize(image, threshold=bound * magnitude) + elif transform_id == "AutoContrast": + return F.autocontrast(image) + elif transform_id == "Equalize": + return F.equalize(image) + elif transform_id == "Invert": + return F.invert(image) + else: + raise ValueError(f"No transform available for {transform_id}") + + +class AutoAugment(_AutoAugmentBase): + r"""AutoAugment data augmentation method based on + `"AutoAugment: Learning Augmentation Strategies from Data" `_. + + This transformation works on images and videos only. + + If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + policy (AutoAugmentPolicy, optional): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + _v1_transform_cls = _transforms.AutoAugment + + _AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), + True, + ), + "TranslateY": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), + True, + ), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + "Invert": (lambda num_bins, height, width: None, False), + } + + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + ) -> None: + super().__init__(interpolation=interpolation, fill=fill) + self.policy = policy + self._policies = self._get_policies(policy) + + def _get_policies( + self, policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, 
Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + 
(("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError(f"The provided policy {policy} is not recognized.") + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_size(image_or_video) # type: ignore[arg-type] + + policy = self._policies[int(torch.randint(len(self._policies), ()))] + + for transform_id, probability, magnitude_idx in policy: + if not torch.rand(()) <= probability: + continue + + magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] + + magnitudes = magnitudes_fn(10, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[magnitude_idx]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill + ) + + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) + + +class RandAugment(_AutoAugmentBase): + r"""RandAugment data augmentation method based on + `"RandAugment: Practical automated data augmentation with a reduced search space" + `_. + + This transformation works on images and videos only. + + If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_ops (int, optional): Number of augmentation transformations to apply sequentially. + magnitude (int, optional): Magnitude for all the transformations. + num_magnitude_bins (int, optional): The number of different magnitude values. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. 
+ """ + + _v1_transform_cls = _transforms.RandAugment + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, height, width: None, False), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins), + True, + ), + "TranslateY": ( + lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins), + True, + ), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + } + + def __init__( + self, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + ) -> None: + super().__init__(interpolation=interpolation, fill=fill) + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_size(image_or_video) # type: ignore[arg-type] + + for _ in range(self.num_ops): + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[self.magnitude]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill + ) + + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) + + +class TrivialAugmentWide(_AutoAugmentBase): + r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `_. + + This transformation works on images and videos only. + + If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_magnitude_bins (int, optional): The number of different magnitude values. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. 
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + _v1_transform_cls = _transforms.TrivialAugmentWide + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, height, width: None, False), + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + } + + def __init__( + self, + num_magnitude_bins: int = 31, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + ): + super().__init__(interpolation=interpolation, fill=fill) + self.num_magnitude_bins = num_magnitude_bins + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_size(image_or_video) # type: ignore[arg-type] + + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) + + magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + image_or_video = self._apply_image_or_video_transform( + image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill + ) + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video) + + +class AugMix(_AutoAugmentBase): + r"""AugMix data augmentation method based on + `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" `_. + + This transformation works on images and videos only. + + If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + severity (int, optional): The severity of base augmentation operators. Default is ``3``. + mixture_width (int, optional): The number of augmentation chains. Default is ``3``. + chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3]. + Default is ``-1``. + alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``. 
+ all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + _v1_transform_cls = _transforms.AugMix + + _PARTIAL_AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True), + "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True), + "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": ( + lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(), + False, + ), + "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, height, width: None, False), + "Equalize": (lambda num_bins, height, width: None, False), + } + _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = { + **_PARTIAL_AUGMENTATION_SPACE, + "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True), + } + + def __init__( + self, + severity: int = 3, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = True, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, + ) -> None: + super().__init__(interpolation=interpolation, fill=fill) + self._PARAMETER_MAX = 10 + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + + def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + # Must be on a separate method so that we can overwrite it in tests. 
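+        # Given a 2D tensor of concentration parameters (one row per batch element),
+        # torch._sample_dirichlet returns one probability vector per row. forward() below
+        # uses this both for the Beta weight between the original and the augmented mix
+        # (two concentrations of alpha) and for the per-chain mixing weights.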
+ return torch._sample_dirichlet(params) + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs) + height, width = get_size(orig_image_or_video) # type: ignore[arg-type] + + if isinstance(orig_image_or_video, torch.Tensor): + image_or_video = orig_image_or_video + else: # isinstance(inpt, PIL.Image.Image): + image_or_video = F.pil_to_tensor(orig_image_or_video) + + augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE + + orig_dims = list(image_or_video.shape) + expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4 + batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims) + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) + + # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a + # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of + # augmented image or video. + m = self._sample_dirichlet( + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + ) + + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos. + combined_weights = self._sample_dirichlet( + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + ) * m[:, 1].reshape([batch_dims[0], -1]) + + mix = m[:, 0].reshape(batch_dims) * batch + for i in range(self.mixture_width): + aug = batch + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + for _ in range(depth): + transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) + + magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + aug = self._apply_image_or_video_transform(aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill) # type: ignore[assignment] + mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) + mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) + + if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)): + mix = tv_tensors.wrap(mix, like=orig_image_or_video) + elif isinstance(orig_image_or_video, PIL.Image.Image): + mix = F.to_pil_image(mix) + + return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py new file mode 100644 index 0000000000000000000000000000000000000000..49b4a8d8b10b3dc43c00586fff4a4db2bcfd0245 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py @@ -0,0 +1,376 @@ +import collections.abc +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torchvision import transforms as _transforms +from torchvision.transforms.v2 import functional as F, Transform + +from ._transform import _RandomApplyTransform +from ._utils import query_chw + + +class Grayscale(Transform): + """Convert images or videos to grayscale. + + If the input is a :class:`torch.Tensor`, it is expected + to have [..., 3 or 1, H, W] shape, where ... 
means an arbitrary number of leading dimensions + + Args: + num_output_channels (int): (1 or 3) number of channels desired for output image + """ + + _v1_transform_cls = _transforms.Grayscale + + def __init__(self, num_output_channels: int = 1): + super().__init__() + self.num_output_channels = num_output_channels + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) + + +class RandomGrayscale(_RandomApplyTransform): + """Randomly convert image or videos to grayscale with a probability of p (default 0.1). + + If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape, + where ... means an arbitrary number of leading dimensions + + The output has the same number of channels as the input. + + Args: + p (float): probability that image should be converted to grayscale. + """ + + _v1_transform_cls = _transforms.RandomGrayscale + + def __init__(self, p: float = 0.1) -> None: + super().__init__(p=p) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_input_channels, *_ = query_chw(flat_inputs) + return dict(num_input_channels=num_input_channels) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) + + +class RGB(Transform): + """Convert images or videos to RGB (if they are already not RGB). + + If the input is a :class:`torch.Tensor`, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions + """ + + def __init__(self): + super().__init__() + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.grayscale_to_rgb, inpt) + + +class ColorJitter(Transform): + """Randomly change the brightness, contrast, saturation and hue of an image or video. + + If the input is a :class:`torch.Tensor`, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. + + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non-negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; + thus it does not work if you normalize your image to an interval with negative values, + or use an interpolation that generates negative values before using this function. 
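+
+    A minimal usage sketch (illustrative only; ``img`` is assumed to be a PIL image or an
+    image tensor with non-negative values):
+
+        >>> from torchvision.transforms import v2 as transforms
+        >>> jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
+        >>> jittered = jitter(img)  # each factor is sampled once per call, then applied in random order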
+ """ + + _v1_transform_cls = _transforms.ColorJitter + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()} + + def __init__( + self, + brightness: Optional[Union[float, Sequence[float]]] = None, + contrast: Optional[Union[float, Sequence[float]]] = None, + saturation: Optional[Union[float, Sequence[float]]] = None, + hue: Optional[Union[float, Sequence[float]]] = None, + ) -> None: + super().__init__() + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + def _check_input( + self, + value: Optional[Union[float, Sequence[float]]], + name: str, + center: float = 1.0, + bound: Tuple[float, float] = (0, float("inf")), + clip_first_on_zero: bool = True, + ) -> Optional[Tuple[float, float]]: + if value is None: + return None + + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If {name} is a single number, it must be non negative.") + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, collections.abc.Sequence) and len(value) == 2: + value = [float(v) for v in value] + else: + raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}, but got {value}.") + + return None if value[0] == value[1] == center else (float(value[0]), float(value[1])) + + @staticmethod + def _generate_value(left: float, right: float) -> float: + return torch.empty(1).uniform_(left, right).item() + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + fn_idx = torch.randperm(4) + + b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) + c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) + s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + + return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + output = inpt + brightness_factor = params["brightness_factor"] + contrast_factor = params["contrast_factor"] + saturation_factor = params["saturation_factor"] + hue_factor = params["hue_factor"] + for fn_id in params["fn_idx"]: + if fn_id == 0 and brightness_factor is not None: + output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor) + elif fn_id == 3 and hue_factor is not None: + output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor) + return output + + +class RandomChannelPermutation(Transform): + """Randomly permute the channels of an image or video""" + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_channels, *_ = 
query_chw(flat_inputs) + return dict(permutation=torch.randperm(num_channels)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.permute_channels, inpt, params["permutation"]) + + +class RandomPhotometricDistort(Transform): + """Randomly distorts the image or video as used in `SSD: Single Shot + MultiBox Detector `_. + + This transform relies on :class:`~torchvision.transforms.v2.ColorJitter` + under the hood to adjust the contrast, saturation, hue, brightness, and also + randomly permutes channels. + + Args: + brightness (tuple of float (min, max), optional): How much to jitter brightness. + brightness_factor is chosen uniformly from [min, max]. Should be non negative numbers. + contrast (tuple of float (min, max), optional): How much to jitter contrast. + contrast_factor is chosen uniformly from [min, max]. Should be non-negative numbers. + saturation (tuple of float (min, max), optional): How much to jitter saturation. + saturation_factor is chosen uniformly from [min, max]. Should be non negative numbers. + hue (tuple of float (min, max), optional): How much to jitter hue. + hue_factor is chosen uniformly from [min, max]. Should have -0.5 <= min <= max <= 0.5. + To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; + thus it does not work if you normalize your image to an interval with negative values, + or use an interpolation that generates negative values before using this function. + p (float, optional) probability each distortion operation (contrast, saturation, ...) to be applied. + Default is 0.5. + """ + + def __init__( + self, + brightness: Tuple[float, float] = (0.875, 1.125), + contrast: Tuple[float, float] = (0.5, 1.5), + saturation: Tuple[float, float] = (0.5, 1.5), + hue: Tuple[float, float] = (-0.05, 0.05), + p: float = 0.5, + ): + super().__init__() + self.brightness = brightness + self.contrast = contrast + self.hue = hue + self.saturation = saturation + self.p = p + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_channels, *_ = query_chw(flat_inputs) + params: Dict[str, Any] = { + key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None + for key, range in [ + ("brightness_factor", self.brightness), + ("contrast_factor", self.contrast), + ("saturation_factor", self.saturation), + ("hue_factor", self.hue), + ] + } + params["contrast_before"] = bool(torch.rand(()) < 0.5) + params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None + return params + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if params["brightness_factor"] is not None: + inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"]) + if params["contrast_factor"] is not None and params["contrast_before"]: + inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) + if params["saturation_factor"] is not None: + inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"]) + if params["hue_factor"] is not None: + inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"]) + if params["contrast_factor"] is not None and not params["contrast_before"]: + inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) + if params["channel_permutation"] is not None: + inpt = self._call_kernel(F.permute_channels, inpt, 
permutation=params["channel_permutation"]) + return inpt + + +class RandomEqualize(_RandomApplyTransform): + """Equalize the histogram of the given image or video with a given probability. + + If the input is a :class:`torch.Tensor`, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". + + Args: + p (float): probability of the image being equalized. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomEqualize + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.equalize, inpt) + + +class RandomInvert(_RandomApplyTransform): + """Inverts the colors of the given image or video with a given probability. + + If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + p (float): probability of the image being color inverted. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomInvert + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.invert, inpt) + + +class RandomPosterize(_RandomApplyTransform): + """Posterize the image or video with a given probability by reducing the + number of bits for each color channel. + + If the input is a :class:`torch.Tensor`, it should be of type torch.uint8, + and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + bits (int): number of bits to keep for each channel (0-8) + p (float): probability of the image being posterized. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomPosterize + + def __init__(self, bits: int, p: float = 0.5) -> None: + super().__init__(p=p) + self.bits = bits + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.posterize, inpt, bits=self.bits) + + +class RandomSolarize(_RandomApplyTransform): + """Solarize the image or video with a given probability by inverting all pixel + values above a threshold. + + If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + threshold (float): all pixels equal or above this value are inverted. + p (float): probability of the image being solarized. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomSolarize + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + params["threshold"] = float(params["threshold"]) + return params + + def __init__(self, threshold: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.threshold = threshold + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.solarize, inpt, threshold=self.threshold) + + +class RandomAutocontrast(_RandomApplyTransform): + """Autocontrast the pixels of the given image or video with a given probability. + + If the input is a :class:`torch.Tensor`, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". 
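A hedged sketch (not part of the diff) of chaining the probabilistic color ops defined in this file; the bit depth, threshold, and probabilities are example values only. RandomPosterize expects a torch.uint8 tensor, and 128 is a mid-range solarize threshold for uint8 inputs.

import torch
from torchvision.transforms import v2

augment = v2.Compose([
    v2.RandomEqualize(p=0.5),
    v2.RandomPosterize(bits=4, p=0.5),       # keep 4 of 8 bits per channel
    v2.RandomSolarize(threshold=128, p=0.5),  # invert pixels at or above 128
    v2.RandomAutocontrast(p=0.5),
])
img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
out = augment(img)  # each op is skipped independently with probability 1 - p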
+ + Args: + p (float): probability of the image being autocontrasted. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomAutocontrast + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.autocontrast, inpt) + + +class RandomAdjustSharpness(_RandomApplyTransform): + """Adjust the sharpness of the image or video with a given probability. + + If the input is a :class:`torch.Tensor`, + it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + sharpness_factor (float): How much to adjust the sharpness. Can be + any non-negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + p (float): probability of the image being sharpened. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomAdjustSharpness + + def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: + super().__init__(p=p) + self.sharpness_factor = sharpness_factor + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py new file mode 100644 index 0000000000000000000000000000000000000000..54de601c6967001f4f7b5951d84b36b7353dbf23 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py @@ -0,0 +1,174 @@ +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import torch + +from torch import nn +from torchvision import transforms as _transforms +from torchvision.transforms.v2 import Transform + + +class Compose(Transform): + """Composes several transforms together. + + This transform does not support torchscript. + Please, see the note below. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.PILToTensor(), + >>> transforms.ConvertImageDtype(torch.float), + >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + """ + + def __init__(self, transforms: Sequence[Callable]) -> None: + super().__init__() + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") + self.transforms = transforms + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + for transform in self.transforms: + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + return outputs + + def extra_repr(self) -> str: + format_string = [] + for t in self.transforms: + format_string.append(f" {t}") + return "\n".join(format_string) + + +class RandomApply(Transform): + """Apply randomly a list of transformations with a given probability. + + .. 
note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + + >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=0.3) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (float): probability of applying the list of transforms + """ + + _v1_transform_cls = _transforms.RandomApply + + def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None: + super().__init__() + + if not isinstance(transforms, (Sequence, nn.ModuleList)): + raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`") + self.transforms = transforms + + if not (0.0 <= p <= 1.0): + raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") + self.p = p + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + return {"transforms": self.transforms, "p": self.p} + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + + if torch.rand(1) >= self.p: + return inputs if needs_unpacking else inputs[0] + + for transform in self.transforms: + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + return outputs + + def extra_repr(self) -> str: + format_string = [] + for t in self.transforms: + format_string.append(f" {t}") + return "\n".join(format_string) + + +class RandomChoice(Transform): + """Apply single transformation randomly picked from a list. + + This transform does not support torchscript. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (list of floats or None, optional): probability of each transform being picked. + If ``p`` doesn't sum to 1, it is automatically normalized. If ``None`` + (default), all transforms have the same probability. + """ + + def __init__( + self, + transforms: Sequence[Callable], + p: Optional[List[float]] = None, + ) -> None: + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + + if p is None: + p = [1] * len(transforms) + elif len(p) != len(transforms): + raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}") + + super().__init__() + + self.transforms = transforms + total = sum(p) + self.p = [prob / total for prob in p] + + def forward(self, *inputs: Any) -> Any: + idx = int(torch.multinomial(torch.tensor(self.p), 1)) + transform = self.transforms[idx] + return transform(*inputs) + + +class RandomOrder(Transform): + """Apply a list of transformations in a random order. + + This transform does not support torchscript. 
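An illustrative sketch (not part of the diff) contrasting the three probabilistic containers defined in this file: RandomApply runs its whole list with probability p, RandomChoice picks a single transform according to the weights in p, and RandomOrder shuffles the execution order on every call. The concrete transforms and probabilities are example choices.

import torch
from torchvision.transforms import v2

pipeline = v2.Compose([
    v2.RandomApply([v2.ColorJitter(brightness=0.3)], p=0.3),
    v2.RandomChoice([v2.RandomInvert(p=1.0), v2.RandomEqualize(p=1.0)], p=[0.7, 0.3]),
    v2.RandomOrder([v2.RandomAutocontrast(p=1.0), v2.RandomAdjustSharpness(2.0, p=1.0)]),
])
img = torch.randint(0, 256, (3, 96, 96), dtype=torch.uint8)
out = pipeline(img)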
+ + Args: + transforms (sequence or torch.nn.Module): list of transformations + """ + + def __init__(self, transforms: Sequence[Callable]) -> None: + if not isinstance(transforms, Sequence): + raise TypeError("Argument transforms should be a sequence of callables") + super().__init__() + self.transforms = transforms + + def forward(self, *inputs: Any) -> Any: + needs_unpacking = len(inputs) > 1 + for idx in torch.randperm(len(self.transforms)): + transform = self.transforms[idx] + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + return outputs diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..a664cb3fbbdb0d55cea3763458707be5d5671e38 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py @@ -0,0 +1,50 @@ +import warnings +from typing import Any, Dict, Union + +import numpy as np +import PIL.Image +import torch +from torchvision.transforms import functional as _F + +from torchvision.transforms.v2 import Transform + + +class ToTensor(Transform): + """[DEPRECATED] Use ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`` instead. + + Convert a PIL Image or ndarray to tensor and scale the values accordingly. + + .. warning:: + :class:`v2.ToTensor` is deprecated and will be removed in a future release. + Please use instead ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``. + Output is equivalent up to float precision. + + This transform does not support torchscript. + + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + + In the other cases, tensors are returned without scaling. + + .. note:: + Because the input image is scaled to [0.0, 1.0], this transformation should not be used when + transforming target image masks. See the `references`_ for implementing the transforms for image masks. + + .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation + """ + + _transformed_types = (PIL.Image.Image, np.ndarray) + + def __init__(self) -> None: + warnings.warn( + "The transform `ToTensor()` is deprecated and will be removed in a future release. " + "Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`." + "Output is equivalent up to float precision." 
+ ) + super().__init__() + + def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: + return _F.to_tensor(inpt) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6b1841d7fe91bc8dec221bee9baab7b6371861 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py @@ -0,0 +1,1416 @@ +import math +import numbers +import warnings +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union + +import PIL.Image +import torch + +from torchvision import transforms as _transforms, tv_tensors +from torchvision.ops.boxes import box_iou +from torchvision.transforms.functional import _get_perspective_coeffs +from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform +from torchvision.transforms.v2.functional._utils import _FillType + +from ._transform import _RandomApplyTransform +from ._utils import ( + _check_padding_arg, + _check_padding_mode_arg, + _check_sequence_input, + _get_fill, + _setup_angle, + _setup_fill_arg, + _setup_number_or_seq, + _setup_size, + get_bounding_boxes, + has_all, + has_any, + is_pure_tensor, + query_size, +) + + +class RandomHorizontalFlip(_RandomApplyTransform): + """Horizontally flip the input with a given probability. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + p (float, optional): probability of the input being flipped. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomHorizontalFlip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.horizontal_flip, inpt) + + +class RandomVerticalFlip(_RandomApplyTransform): + """Vertically flip the input with a given probability. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + p (float, optional): probability of the input being flipped. Default value is 0.5 + """ + + _v1_transform_cls = _transforms.RandomVerticalFlip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.vertical_flip, inpt) + + +class Resize(Transform): + """Resize the input to the given size. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + size (sequence, int, or None): Desired + output size. + + - If size is a sequence like (h, w), output size will be matched to this. + - If size is an int, smaller edge of the image will be matched to this + number. 
i.e, if height > width, then image will be rescaled to + (size * height / width, size). + - If size is None, the output shape is determined by the ``max_size`` + parameter. + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + max_size (int, optional): The maximum allowed for the longer edge of + the resized image. + + - If ``size`` is an int: if the longer edge of the image is greater + than ``max_size`` after being resized according to ``size``, + ``size`` will be overruled so that the longer edge is equal to + ``max_size``. As a result, the smaller edge may be shorter than + ``size``. This is only supported if ``size`` is an int (or a + sequence of length 1 in torchscript mode). + - If ``size`` is None: the longer edge of the image will be matched + to max_size. i.e, if height > width, then image will be rescaled + to (max_size, max_size * width / height). + + This should be left to ``None`` (default) when ``size`` is a + sequence. + + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + _v1_transform_cls = _transforms.Resize + + def __init__( + self, + size: Union[int, Sequence[int], None], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, + ) -> None: + super().__init__() + + if isinstance(size, int): + size = [size] + elif isinstance(size, Sequence) and len(size) in {1, 2}: + size = list(size) + elif size is None: + if not isinstance(max_size, int): + raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.") + else: + raise ValueError( + f"size can be an integer, a sequence of one or two integers, or None, but got {size} instead." 
+ ) + self.size = size + + self.interpolation = interpolation + self.max_size = max_size + self.antialias = antialias + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel( + F.resize, + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) + + +class CenterCrop(Transform): + """Crop the input at the center. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + _v1_transform_cls = _transforms.CenterCrop + + def __init__(self, size: Union[int, Sequence[int]]): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.center_crop, inpt, output_size=self.size) + + +class RandomResizedCrop(Transform): + """Crop a random portion of the input and resize it to a given size. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + A crop of the original input is made: the crop has a random area (H * W) + and a random aspect ratio. This crop is finally resized to the given + size. This is popularly used to train the Inception networks. + + Args: + size (int or sequence): expected output size of the crop, for each edge. If size is an + int instead of sequence like (h, w), a square output size ``(size, size)`` is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + scale (tuple of float, optional): Specifies the lower and upper bounds for the random area of the crop, + before resizing. The scale is defined with respect to the area of the original image. + ratio (tuple of float, optional): lower and upper bounds for the random aspect ratio of the crop, before + resizing. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. 
+ It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + _v1_transform_cls = _transforms.RandomResizedCrop + + def __init__( + self, + size: Union[int, Sequence[int]], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + + self.scale = scale + self.ratio = ratio + self.interpolation = interpolation + self.antialias = antialias + + self._log_ratio = torch.log(torch.tensor(self.ratio)) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_size(flat_inputs) + area = height * width + + log_ratio = self._log_ratio + for _ in range(10): + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + break + else: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + + return dict(top=i, left=j, height=h, width=w) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel( + F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + ) + + +class FiveCrop(Transform): + """Crop the image or video into four corners and the central crop. + + If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a + :class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions. + For example, the image can have ``[..., C, H, W]`` shape. + + .. 
Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an ``int`` + instead of sequence like (h, w), a square crop of size (size, size) is made. + If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + + Example: + >>> class BatchMultiCrop(transforms.Transform): + ... def forward(self, sample: Tuple[Tuple[Union[tv_tensors.Image, tv_tensors.Video], ...], int]): + ... images_or_videos, labels = sample + ... batch_size = len(images_or_videos) + ... image_or_video = images_or_videos[0] + ... images_or_videos = tv_tensors.wrap(torch.stack(images_or_videos), like=image_or_video) + ... labels = torch.full((batch_size,), label, device=images_or_videos.device) + ... return images_or_videos, labels + ... + >>> image = tv_tensors.Image(torch.rand(3, 256, 256)) + >>> label = 3 + >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()]) + >>> images, labels = transform(image, label) + >>> images.shape + torch.Size([5, 3, 224, 224]) + >>> labels + tensor([3, 3, 3, 3, 3]) + """ + + _v1_transform_cls = _transforms.FiveCrop + + def __init__(self, size: Union[int, Sequence[int]]) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." + ) + return super()._call_kernel(functional, inpt, *args, **kwargs) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.five_crop, inpt, self.size) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask): + raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") + + +class TenCrop(Transform): + """Crop the image or video into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default). + + If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a + :class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions. + For example, the image can have ``[..., C, H, W]`` shape. + + See :class:`~torchvision.transforms.v2.FiveCrop` for an example. + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
+ vertical_flip (bool, optional): Use vertical flipping instead of horizontal + """ + + _v1_transform_cls = _transforms.TenCrop + + def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." + ) + return super()._call_kernel(functional, inpt, *args, **kwargs) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask): + raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip) + + +class Pad(Transform): + """Pad the input on all sides with the given "pad" value. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + padding (int or sequence): Padding on each border. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. + ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. + padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is "constant". + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. 
+ For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + _v1_transform_cls = _transforms.Pad + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") + + return params + + def __init__( + self, + padding: Union[int, Sequence[int]], + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + _check_padding_arg(padding) + _check_padding_mode_arg(padding_mode) + + # This cast does Sequence[int] -> List[int] and is required to make mypy happy + if not isinstance(padding, int): + padding = list(padding) + self.padding = padding + self.fill = fill + self._fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = _get_fill(self._fill, type(inpt)) + return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + + +class RandomZoomOut(_RandomApplyTransform): + """ "Zoom out" transformation from + `"SSD: Single Shot MultiBox Detector" `_. + + This transformation randomly pads images, videos, bounding boxes and masks creating a zoom out effect. + Output spatial size is randomly sampled from original size up to a maximum size configured + with ``side_range`` parameter: + + .. code-block:: python + + r = uniform_sample(side_range[0], side_range[1]) + output_width = input_width * r + output_height = input_height * r + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. + ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. + side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to + scale the input size. + p (float, optional): probability that the zoom operation will be performed. 
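A minimal sketch (not part of the diff) of the zoom-out padding on a detection-style sample; the image size, box coordinates, XYXY format, and scalar fill are assumptions made only for this example. A per-type fill dictionary, as described in the Args above, could be passed instead of the scalar.

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

img = tv_tensors.Image(torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes([[10, 10, 50, 50]], format="XYXY", canvas_size=(100, 100))

# p=1.0 forces the zoom-out; side_range caps the canvas at twice the input size.
zoom_out = v2.RandomZoomOut(fill=127, side_range=(1.0, 2.0), p=1.0)
out_img, out_boxes = zoom_out(img, boxes)  # image is padded, box coordinates are shifted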
+ """ + + def __init__( + self, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + side_range: Sequence[float] = (1.0, 4.0), + p: float = 0.5, + ) -> None: + super().__init__(p=p) + + self.fill = fill + self._fill = _setup_fill_arg(fill) + + _check_sequence_input(side_range, "side_range", req_sizes=(2,)) + + self.side_range = side_range + if side_range[0] < 1.0 or side_range[0] > side_range[1]: + raise ValueError(f"Invalid side range provided {side_range}.") + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_h, orig_w = query_size(flat_inputs) + + r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + canvas_width = int(orig_w * r) + canvas_height = int(orig_h * r) + + r = torch.rand(2) + left = int((canvas_width - orig_w) * r[0]) + top = int((canvas_height - orig_h) * r[1]) + right = canvas_width - (left + orig_w) + bottom = canvas_height - (top + orig_h) + padding = [left, top, right, bottom] + + return dict(padding=padding) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = _get_fill(self._fill, type(inpt)) + return self._call_kernel(F.pad, inpt, **params, fill=fill) + + +class RandomRotation(Transform): + """Rotate the input by angle. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + degrees (sequence or number): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + expand (bool, optional): Optional expansion flag. + If true, expands the output to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center (see note below) and no translation. + center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. + + .. note:: + + In theory, setting ``center`` has no effect if ``expand=True``, since the image center will become the + center of rotation. In practice however, due to numerical precision, this can lead to off-by-one + differences of the resulting image size compared to using the image center in the first place. Thus, when + setting ``expand=True``, it's best to leave ``center=None`` (default). + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. + ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. + + .. 
_filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + _v1_transform_cls = _transforms.RandomRotation + + def __init__( + self, + degrees: Union[numbers.Number, Sequence], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.interpolation = interpolation + self.expand = expand + + self.fill = fill + self._fill = _setup_fill_arg(fill) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + return dict(angle=angle) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = _get_fill(self._fill, type(inpt)) + return self._call_kernel( + F.rotate, + inpt, + **params, + interpolation=self.interpolation, + expand=self.expand, + center=self.center, + fill=fill, + ) + + +class RandomAffine(Transform): + """Random affine transformation the input keeping center invariant. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + degrees (sequence or number): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Set to 0 to deactivate rotations. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or number, optional): Range of degrees to select from. + If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear) + will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the + range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values, + an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. + Will not apply shear by default. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. 
+ ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. + center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. + Default is the center of the image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + _v1_transform_cls = _transforms.RandomAffine + + def __init__( + self, + degrees: Union[numbers.Number, Sequence], + translate: Optional[Sequence[float]] = None, + scale: Optional[Sequence[float]] = None, + shear: Optional[Union[int, float, Sequence[float]]] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + center: Optional[List[float]] = None, + ) -> None: + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2,)) + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2,)) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.interpolation = interpolation + self.fill = fill + self._fill = _setup_fill_arg(fill) + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_size(flat_inputs) + + angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() + if self.translate is not None: + max_dx = float(self.translate[0] * width) + max_dy = float(self.translate[1] * height) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + translate = (tx, ty) + else: + translate = (0, 0) + + if self.scale is not None: + scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if self.shear is not None: + shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item() + if len(self.shear) == 4: + shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item() + + shear = (shear_x, shear_y) + return dict(angle=angle, translate=translate, scale=scale, shear=shear) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = _get_fill(self._fill, type(inpt)) + return self._call_kernel( + F.affine, + inpt, + **params, + interpolation=self.interpolation, + fill=fill, + center=self.center, + ) + + +class RandomCrop(Transform): + """Crop the input at a random location. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + padding (int or sequence, optional): Optional padding on each border + of the image. Default is None. If a single int is provided this + is used to pad all borders. If sequence of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a sequence of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. + + .. note:: + In torchscript mode padding as single int is not supported, use a sequence of + length 1: ``[padding, ]``. + pad_if_needed (boolean, optional): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. + ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. + padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image. + + - reflect: pads with reflection of image without repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge. + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + _v1_transform_cls = _transforms.RandomCrop + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + params = super()._extract_params_for_v1_transform() + + if not (params["fill"] is None or isinstance(params["fill"], (int, float))): + raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.") + + padding = self.padding + if padding is not None: + pad_left, pad_right, pad_top, pad_bottom = padding + padding = [pad_left, pad_top, pad_right, pad_bottom] + params["padding"] = padding + + return params + + def __init__( + self, + size: Union[int, Sequence[int]], + padding: Optional[Union[int, Sequence[int]]] = None, + pad_if_needed: bool = False, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ) -> None: + super().__init__() + + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if pad_if_needed or padding is not None: + if padding is not None: + _check_padding_arg(padding) + _check_padding_mode_arg(padding_mode) + + self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] + self.pad_if_needed = pad_if_needed + self.fill = fill + self._fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + padded_height, padded_width = query_size(flat_inputs) + + if self.padding is not None: + pad_left, pad_right, pad_top, pad_bottom = self.padding + padded_height += pad_top + pad_bottom + padded_width += 
pad_left + pad_right + else: + pad_left = pad_right = pad_top = pad_bottom = 0 + + cropped_height, cropped_width = self.size + + if self.pad_if_needed: + if padded_height < cropped_height: + diff = cropped_height - padded_height + + pad_top += diff + pad_bottom += diff + padded_height += 2 * diff + + if padded_width < cropped_width: + diff = cropped_width - padded_width + + pad_left += diff + pad_right += diff + padded_width += 2 * diff + + if padded_height < cropped_height or padded_width < cropped_width: + raise ValueError( + f"Required crop size {(cropped_height, cropped_width)} is larger than " + f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}." + ) + + # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad` + padding = [pad_left, pad_top, pad_right, pad_bottom] + needs_pad = any(padding) + + needs_vert_crop, top = ( + (True, int(torch.randint(0, padded_height - cropped_height + 1, size=()))) + if padded_height > cropped_height + else (False, 0) + ) + needs_horz_crop, left = ( + (True, int(torch.randint(0, padded_width - cropped_width + 1, size=()))) + if padded_width > cropped_width + else (False, 0) + ) + + return dict( + needs_crop=needs_vert_crop or needs_horz_crop, + top=top, + left=left, + height=cropped_height, + width=cropped_width, + needs_pad=needs_pad, + padding=padding, + ) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if params["needs_pad"]: + fill = _get_fill(self._fill, type(inpt)) + inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) + + if params["needs_crop"]: + inpt = self._call_kernel( + F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ) + + return inpt + + +class RandomPerspective(_RandomApplyTransform): + """Perform a random perspective transformation of the input with a given probability. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + distortion_scale (float, optional): argument to control the degree of distortion and ranges from 0 to 1. + Default is 0.5. + p (float, optional): probability of the input being transformed. Default is 0.5. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. + ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. 
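A short sketch (not part of the diff) of the pad-then-crop behaviour computed in RandomCrop._get_params above: with pad_if_needed=True an input smaller than the requested crop is first padded so the crop always fits, then cropped at a random offset. The 32x32 crop size and 28x28 input are example values.

import torch
from torchvision.transforms import v2

crop = v2.RandomCrop(size=(32, 32), pad_if_needed=True, fill=0)
img = torch.randint(0, 256, (3, 28, 28), dtype=torch.uint8)  # smaller than the crop
out = crop(img)  # padded up to at least 32x32, then randomly cropped
print(out.shape)  # torch.Size([3, 32, 32])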
+ """ + + _v1_transform_cls = _transforms.RandomPerspective + + def __init__( + self, + distortion_scale: float = 0.5, + p: float = 0.5, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + ) -> None: + super().__init__(p=p) + + if not (0 <= distortion_scale <= 1): + raise ValueError("Argument distortion_scale value should be between 0 and 1") + + self.distortion_scale = distortion_scale + self.interpolation = interpolation + self.fill = fill + self._fill = _setup_fill_arg(fill) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + height, width = query_size(flat_inputs) + + distortion_scale = self.distortion_scale + + half_height = height // 2 + half_width = width // 2 + bound_height = int(distortion_scale * half_height) + 1 + bound_width = int(distortion_scale * half_width) + 1 + topleft = [ + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), + ] + topright = [ + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(0, bound_height, size=(1,))), + ] + botright = [ + int(torch.randint(width - bound_width, width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), + ] + botleft = [ + int(torch.randint(0, bound_width, size=(1,))), + int(torch.randint(height - bound_height, height, size=(1,))), + ] + startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] + endpoints = [topleft, topright, botright, botleft] + perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) + return dict(coefficients=perspective_coeffs) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = _get_fill(self._fill, type(inpt)) + return self._call_kernel( + F.perspective, + inpt, + startpoints=None, + endpoints=None, + fill=fill, + interpolation=self.interpolation, + **params, + ) + + +class ElasticTransform(Transform): + """Transform the input with elastic transformations. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Given alpha and sigma, it will generate displacement + vectors for all pixels based on random offsets. Alpha controls the strength + and sigma controls the smoothness of the displacements. + The displacements are added to an identity grid and the resulting grid is + used to transform the input. + + .. note:: + Implementation to transform bounding boxes is approximative (not exact). + We construct an approximation of the inverse grid as ``inverse_grid = identity - displacement``. + This is not an exact inverse of the grid used to transform images, i.e. ``grid = identity + displacement``. + Our assumption is that ``displacement * displacement`` is small and can be ignored. + Large displacements would lead to large errors in the approximation. + + Applications: + Randomly transforms the morphology of objects in images and produces a + see-through-water-like effect. + + Args: + alpha (float or sequence of floats, optional): Magnitude of displacements. Default is 50.0. + sigma (float or sequence of floats, optional): Smoothness of displacements. Default is 5.0. 
+ interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant. + Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. + Fill value can be also a dictionary mapping data type to the fill value, e.g. + ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and + ``Mask`` will be filled with 0. + """ + + _v1_transform_cls = _transforms.ElasticTransform + + def __init__( + self, + alpha: Union[float, Sequence[float]] = 50.0, + sigma: Union[float, Sequence[float]] = 5.0, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, + ) -> None: + super().__init__() + self.alpha = _setup_number_or_seq(alpha, "alpha") + self.sigma = _setup_number_or_seq(sigma, "sigma") + + self.interpolation = interpolation + self.fill = fill + self._fill = _setup_fill_arg(fill) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + size = list(query_size(flat_inputs)) + + dx = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[0] > 0.0: + kx = int(8 * self.sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) + dx = dx * self.alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[1] > 0.0: + ky = int(8 * self.sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma)) + dy = dy * self.alpha[1] / size[1] + displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + return dict(displacement=displacement) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = _get_fill(self._fill, type(inpt)) + return self._call_kernel( + F.elastic, + inpt, + **params, + fill=fill, + interpolation=self.interpolation, + ) + + +class RandomIoUCrop(Transform): + """Random IoU crop transformation from + `"SSD: Single Shot MultiBox Detector" `_. + + This transformation requires an image or video data and ``tv_tensors.BoundingBoxes`` in the input. + + .. warning:: + In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop` + must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately + after or later in the transforms pipeline. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + min_scale (float, optional): Minimum factors to scale the input size. + max_scale (float, optional): Maximum factors to scale the input size. + min_aspect_ratio (float, optional): Minimum aspect ratio for the cropped image or video. 
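# --- Illustrative usage sketch (not part of the diffed torchvision source above) ---
# A minimal ElasticTransform example; the image size is a placeholder. With the default
# sigma=5.0, the smoothing kernel computed in _get_params above is int(8 * 5.0 + 1) = 41 (already odd).
import torch
from torchvision.transforms import v2

elastic = v2.ElasticTransform(alpha=50.0, sigma=5.0)
img = torch.rand(3, 128, 128)   # float image in [0, 1]
out = elastic(img)              # same shape, warped by the smoothed displacement field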
+ max_aspect_ratio (float, optional): Maximum aspect ratio for the cropped image or video. + sampler_options (list of float, optional): List of minimal IoU (Jaccard) overlap between all the boxes and + a cropped image or video. Default, ``None`` which corresponds to ``[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]`` + trials (int, optional): Number of trials to find a crop for a given value of minimal IoU (Jaccard) overlap. + Default, 40. + """ + + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): + super().__init__() + # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 + self.min_scale = min_scale + self.max_scale = max_scale + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + if sampler_options is None: + sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] + self.options = sampler_options + self.trials = trials + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + if not ( + has_all(flat_inputs, tv_tensors.BoundingBoxes) + and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor) + ): + raise TypeError( + f"{type(self).__name__}() requires input sample to contain tensor or PIL images " + "and bounding boxes. Sample can also contain masks." + ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_h, orig_w = query_size(flat_inputs) + bboxes = get_bounding_boxes(flat_inputs) + + while True: + # sample an option + idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + min_jaccard_overlap = self.options[idx] + if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option + return dict() + + for _ in range(self.trials): + # check the aspect ratio limitations + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + new_w = int(orig_w * r[0]) + new_h = int(orig_h * r[1]) + aspect_ratio = new_w / new_h + if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): + continue + + # check for 0 area crops + r = torch.rand(2) + left = int((orig_w - new_w) * r[0]) + top = int((orig_h - new_h) * r[1]) + right = left + new_w + bottom = top + new_h + if left == right or top == bottom: + continue + + # check for any valid boxes with centers within the crop area + xyxy_bboxes = F.convert_bounding_box_format( + bboxes.as_subclass(torch.Tensor), + bboxes.format, + tv_tensors.BoundingBoxFormat.XYXY, + ) + cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) + cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) + is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) + if not is_within_crop_area.any(): + continue + + # check at least 1 box with jaccard limitations + xyxy_bboxes = xyxy_bboxes[is_within_crop_area] + ious = box_iou( + xyxy_bboxes, + torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device), + ) + if ious.max() < min_jaccard_overlap: + continue + + return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + + if len(params) < 1: + return inpt + + output = self._call_kernel( + F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ) + + if isinstance(output, tv_tensors.BoundingBoxes): + # We "mark" the invalid boxes as 
degenreate, and they can be + # removed by a later call to SanitizeBoundingBoxes() + output[~params["is_within_crop_area"]] = 0 + + return output + + +class ScaleJitter(Transform): + """Perform Large Scale Jitter on the input according to + `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + target_size (tuple of int): Target size. This parameter defines base scale for jittering, + e.g. ``min(target_size[0] / width, target_size[1] / height)``. + scale_range (tuple of float, optional): Minimum and maximum of the scale range. Default, ``(0.1, 2.0)``. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__( + self, + target_size: Tuple[int, int], + scale_range: Tuple[float, float] = (0.1, 2.0), + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ): + super().__init__() + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = interpolation + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_height, orig_width = query_size(flat_inputs) + + scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + return dict(size=(new_height, new_width)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel( + F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias + ) + + +class RandomShortestSize(Transform): + """Randomly resize the input. 
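# --- Illustrative sketch (not part of the diffed torchvision source above) ---
# How ScaleJitter._get_params above picks the output size, for a hypothetical 480x640 (H x W)
# input and target_size=(1024, 1024): the base ratio is min(1024 / 480, 1024 / 640) = 1.6, it is
# multiplied by a scale drawn from scale_range, and the result is applied to both sides.
from torchvision.transforms import v2

jitter = v2.ScaleJitter(target_size=(1024, 1024), scale_range=(0.1, 2.0))
# e.g. scale = 0.5 -> r = 1.6 * 0.5 = 0.8 -> output size (int(480 * 0.8), int(640 * 0.8)) = (384, 512)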
+ + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + min_size (int or sequence of int): Minimum spatial size. Single integer value or a sequence of integer values. + max_size (int, optional): Maximum spatial size. Default, None. + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__( + self, + min_size: Union[List[int], Tuple[int], int], + max_size: Optional[int] = None, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ): + super().__init__() + self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) + self.max_size = max_size + self.interpolation = interpolation + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + orig_height, orig_width = query_size(flat_inputs) + + min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] + r = min_size / min(orig_height, orig_width) + if self.max_size is not None: + r = min(r, self.max_size / max(orig_height, orig_width)) + + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + return dict(size=(new_height, new_width)) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel( + F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias + ) + + +class RandomResize(Transform): + """Randomly resize the input. + + This transformation can be used together with ``RandomCrop`` as data augmentations to train + models on image segmentation task. + + Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: + + .. 
code-block:: python + + size = uniform_sample(min_size, max_size) + output_width = size + output_height = size + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + min_size (int): Minimum output size for random sampling + max_size (int): Maximum output size for random sampling + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ + + def __init__( + self, + min_size: int, + max_size: int, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, + ) -> None: + super().__init__() + self.min_size = min_size + self.max_size = max_size + self.interpolation = interpolation + self.antialias = antialias + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + size = int(torch.randint(self.min_size, self.max_size, ())) + return dict(size=[size]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel( + F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_meta.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..01a356f46f5301a37cde800a8c0d1b9568a19704 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_meta.py @@ -0,0 +1,36 @@ +from typing import Any, Dict, Union + +from torchvision import tv_tensors +from torchvision.transforms.v2 import functional as F, Transform + + +class ConvertBoundingBoxFormat(Transform): + """Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY". + + Args: + format (str or tv_tensors.BoundingBoxFormat): output bounding box format. 
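# --- Illustrative usage sketch (not part of the diffed torchvision source above) ---
# The two resizing transforms defined above, with hypothetical detection-style sizes.
# RandomShortestSize draws one value from min_size, computes r = min_size / min(H, W), and caps r
# so that max(H, W) * r does not exceed max_size; RandomResize samples one output size from the
# documented [min_size, max_size] interval.
from torchvision.transforms import v2

resize_short = v2.RandomShortestSize(min_size=(640, 672, 704), max_size=1333)
resize_rand = v2.RandomResize(min_size=256, max_size=512)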
+ Possible values are defined by :class:`~torchvision.tv_tensors.BoundingBoxFormat` and + string values match the enums, e.g. "XYXY" or "XYWH" etc. + """ + + _transformed_types = (tv_tensors.BoundingBoxes,) + + def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None: + super().__init__() + self.format = format + + def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: + return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value, arg-type] + + +class ClampBoundingBoxes(Transform): + """Clamp bounding boxes to their corresponding image dimensions. + + The clamping is done according to the bounding boxes' ``canvas_size`` meta-data. + + """ + + _transformed_types = (tv_tensors.BoundingBoxes,) + + def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: + return F.clamp_bounding_boxes(inpt) # type: ignore[return-value] diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_misc.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..93198f0009dccad73fcdac680baf945e341750e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_misc.py @@ -0,0 +1,451 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union + +import PIL.Image + +import torch +from torch.utils._pytree import tree_flatten, tree_unflatten + +from torchvision import transforms as _transforms, tv_tensors +from torchvision.transforms.v2 import functional as F, Transform + +from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor + + +# TODO: do we want/need to expose this? +class Identity(Transform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return inpt + + +class Lambda(Transform): + """Apply a user-defined function as a transform. + + This transform does not support torchscript. + + Args: + lambd (function): Lambda/function to be used for transform. + """ + + _transformed_types = (object,) + + def __init__(self, lambd: Callable[[Any], Any], *types: Type): + super().__init__() + self.lambd = lambd + self.types = types or self._transformed_types + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, self.types): + return self.lambd(inpt) + else: + return inpt + + def extra_repr(self) -> str: + extras = [] + name = getattr(self.lambd, "__name__", None) + if name: + extras.append(name) + extras.append(f"types={[type.__name__ for type in self.types]}") + return ", ".join(extras) + + +class LinearTransformation(Transform): + """Transform a tensor image or video with a square transformation matrix and a mean_vector computed offline. + + This transform does not support PIL Image. + Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and + subtract mean_vector from it which is then followed by computing the dot + product with the transformation matrix and then reshaping the tensor to its + original shape. + + Applications: + whitening transformation: Suppose X is a column vector zero-centered data. + Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), + perform SVD on this matrix and pass it as transformation_matrix. 
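# --- Illustrative sketch (not part of the diffed torchvision source above) ---
# One way to build the whitening matrix described in the LinearTransformation docstring above
# (a ZCA-style construction); the data shape and the 1e-5 regularizer are placeholder choices,
# not values mandated by the transform.
import torch
from torchvision.transforms import v2

X = torch.randn(1000, 3 * 8 * 8)                        # hypothetical flattened training data, [N, D]
mean_vector = X.mean(dim=0)
X_centered = X - mean_vector
cov = X_centered.t() @ X_centered / X_centered.size(0)  # [D, D] covariance, as in the docstring
U, S, _ = torch.linalg.svd(cov)
transformation_matrix = U @ torch.diag(1.0 / torch.sqrt(S + 1e-5)) @ U.t()
whiten = v2.LinearTransformation(transformation_matrix, mean_vector)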
+ + Args: + transformation_matrix (Tensor): tensor [D x D], D = C x H x W + mean_vector (Tensor): tensor [D], D = C x H x W + """ + + _v1_transform_cls = _transforms.LinearTransformation + + _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video) + + def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): + super().__init__() + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError( + "transformation_matrix should be square. Got " + f"{tuple(transformation_matrix.size())} rectangular matrix." + ) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError( + f"mean_vector should have the same length {mean_vector.size(0)}" + f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]" + ) + + if transformation_matrix.device != mean_vector.device: + raise ValueError( + f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" + ) + + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def _check_inputs(self, sample: Any) -> Any: + if has_any(sample, PIL.Image.Image): + raise TypeError(f"{type(self).__name__}() does not support PIL images.") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + shape = inpt.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != " + + f"{self.transformation_matrix.shape[0]}" + ) + + if inpt.device.type != self.mean_vector.device.type: + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. " + f"Got {inpt.device} vs {self.mean_vector.device}" + ) + + flat_inpt = inpt.reshape(-1, n) - self.mean_vector + + transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype) + output = torch.mm(flat_inpt, transformation_matrix) + output = output.reshape(shape) + + if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): + output = tv_tensors.wrap(output, like=inpt) + return output + + +class Normalize(Transform): + """Normalize a tensor image or video with mean and standard deviation. + + This transform does not support PIL Image. + Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` + channels, this transform will normalize each channel of the input + ``torch.*Tensor`` i.e., + ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` + + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. 
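# --- Illustrative usage sketch (not part of the diffed torchvision source above) ---
# Per-channel normalization as documented above; the mean/std values are the commonly used
# ImageNet statistics and the input size is a placeholder.
import torch
from torchvision.transforms import v2

normalize = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img = torch.rand(3, 224, 224)   # float tensor, not a PIL image
out = normalize(img)            # out[c] = (img[c] - mean[c]) / std[c]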
+ + """ + + _v1_transform_cls = _transforms.Normalize + + def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): + super().__init__() + self.mean = list(mean) + self.std = list(std) + self.inplace = inplace + + def _check_inputs(self, sample: Any) -> Any: + if has_any(sample, PIL.Image.Image): + raise TypeError(f"{type(self).__name__}() does not support PIL images.") + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace) + + +class GaussianBlur(Transform): + """Blurs image with randomly chosen Gaussian blur kernel. + + The convolution will be using reflection padding corresponding to the kernel size, to maintain the input shape. + + If the input is a Tensor, it is expected + to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + kernel_size (int or sequence): Size of the Gaussian kernel. + sigma (float or tuple of float (min, max)): Standard deviation to be used for + creating kernel to perform blurring. If float, sigma is fixed. If it is tuple + of float (min, max), sigma is chosen uniformly at random to lie in the + given range. + """ + + _v1_transform_cls = _transforms.GaussianBlur + + def __init__( + self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0) + ) -> None: + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + self.sigma = _setup_number_or_seq(sigma, "sigma") + + if not 0.0 < self.sigma[0] <= self.sigma[1]: + raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}") + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() + return dict(sigma=[sigma, sigma]) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params) + + +class GaussianNoise(Transform): + """Add gaussian noise to images or videos. + + The input tensor is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + Each image or frame in a batch will be transformed independently i.e. the + noise added to each image will be different. + + The input tensor is also expected to be of float dtype in ``[0, 1]``. + This transform does not support PIL images. + + Args: + mean (float): Mean of the sampled normal distribution. Default is 0. + sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1. + clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True. + """ + + def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: + super().__init__() + self.mean = mean + self.sigma = sigma + self.clip = clip + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip) + + +class ToDtype(Transform): + """Converts the input to a specific dtype, optionally scaling the values for images or videos. + + .. note:: + ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``. 
+ + Args: + dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to. + If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted + to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`. + A dict can be passed to specify per-tv_tensor conversions, e.g. + ``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others" + key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion. + scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`. + Default: ``False``. + """ + + _transformed_types = (torch.Tensor,) + + def __init__( + self, dtype: Union[torch.dtype, Dict[Union[Type, str], Optional[torch.dtype]]], scale: bool = False + ) -> None: + super().__init__() + + if not isinstance(dtype, (dict, torch.dtype)): + raise ValueError(f"dtype must be a dict or a torch.dtype, got {type(dtype)} instead") + + if ( + isinstance(dtype, dict) + and torch.Tensor in dtype + and any(cls in dtype for cls in [tv_tensors.Image, tv_tensors.Video]) + ): + warnings.warn( + "Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input." + ) + self.dtype = dtype + self.scale = scale + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(self.dtype, torch.dtype): + # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype + # is a simple torch.dtype + if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)): + return inpt + + dtype: Optional[torch.dtype] = self.dtype + elif type(inpt) in self.dtype: + dtype = self.dtype[type(inpt)] + elif "others" in self.dtype: + dtype = self.dtype["others"] + else: + raise ValueError( + f"No dtype was specified for type {type(inpt)}. " + "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. " + "If you're passing a dict as dtype, " + 'you can use "others" as a catch-all key ' + 'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.' + ) + + supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) + if dtype is None: + if self.scale and supports_scaling: + warnings.warn( + "scale was set to True but no dtype was specified for images or videos: no scaling will be done." + ) + return inpt + + return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale) + + +class ConvertImageDtype(Transform): + """[DEPRECATED] Use ``v2.ToDtype(dtype, scale=True)`` instead. + + Convert input image to the given ``dtype`` and scale the values accordingly. + + .. warning:: + Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`. + + This function does not support PIL Image. + + Args: + dtype (torch.dtype): Desired data type of the output + + .. note:: + + When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly. + If converted back and forth, this mismatch has no effect. 
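# --- Illustrative usage sketch (not part of the diffed torchvision source above) ---
# The two ways of configuring ToDtype described above (the recommended replacement for
# ConvertImageDtype); the dict variant is a sketch with placeholder per-type choices.
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

to_float = v2.ToDtype(torch.float32, scale=True)   # uint8 [0, 255] images/videos -> float32 [0, 1]
per_type = v2.ToDtype(
    dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None},
    scale=True,                                    # scaling only applies to the image entries
)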
+ + Raises: + RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as + well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to + overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range + of the integer ``dtype``. + """ + + _v1_transform_cls = _transforms.ConvertImageDtype + + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + super().__init__() + self.dtype = dtype + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True) + + +class SanitizeBoundingBoxes(Transform): + """Remove degenerate/invalid bounding boxes and their corresponding labels and masks. + + This transform removes bounding boxes and their associated labels/masks that: + + - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. + - have any coordinate outside of their corresponding image. You may want to + call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals. + + It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO + (see ``labels_getter`` parameter). + + It is recommended to call it at the end of a pipeline, before passing the + input to the models. It is critical to call this transform if + :class:`~torchvision.transforms.v2.RandomIoUCrop` was called. + If you want to be extra careful, you may call it after all transforms that + may modify bounding boxes but once at the end should be enough in most + cases. + + Args: + min_size (float, optional): The size below which bounding boxes are removed. Default is 1. + min_area (float, optional): The area below which bounding boxes are removed. Default is 1. + labels_getter (callable or str or None, optional): indicates how to identify the labels in the input + (or anything else that needs to be sanitized along with the bounding boxes). + By default, this will try to find a "labels" key in the input (case-insensitive), if + the input is a dict or it is a tuple whose second element is a dict. + This heuristic should work well with a lot of datasets, including the built-in torchvision datasets. + + It can also be a callable that takes the same input as the transform, and returns either: + + - A single tensor (the labels) + - A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes. + This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties + from COCO. + + If ``labels_getter`` is None then only bounding boxes are sanitized. + """ + + def __init__( + self, + min_size: float = 1.0, + min_area: float = 1.0, + labels_getter: Union[Callable[[Any], Any], str, None] = "default", + ) -> None: + super().__init__() + + if min_size < 1: + raise ValueError(f"min_size must be >= 1, got {min_size}.") + self.min_size = min_size + + if min_area < 1: + raise ValueError(f"min_area must be >= 1, got {min_area}.") + self.min_area = min_area + + self.labels_getter = labels_getter + self._labels_getter = _parse_labels_getter(labels_getter) + + def forward(self, *inputs: Any) -> Any: + inputs = inputs if len(inputs) > 1 else inputs[0] + + labels = self._labels_getter(inputs) + if labels is not None: + msg = "The labels in the input to forward() must be a tensor or None, got {type} instead." 
+ if isinstance(labels, torch.Tensor): + labels = (labels,) + elif isinstance(labels, (tuple, list)): + for entry in labels: + if not isinstance(entry, torch.Tensor): + # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask] + raise ValueError(msg.format(type=type(entry))) + else: + raise ValueError(msg.format(type=type(labels))) + + flat_inputs, spec = tree_flatten(inputs) + boxes = get_bounding_boxes(flat_inputs) + + if labels is not None: + for label in labels: + if boxes.shape[0] != label.shape[0]: + raise ValueError( + f"Number of boxes (shape={boxes.shape}) and must match the number of labels." + f"Found labels with shape={label.shape})." + ) + + valid = F._misc._get_sanitize_bounding_boxes_mask( + boxes, + format=boxes.format, + canvas_size=boxes.canvas_size, + min_size=self.min_size, + min_area=self.min_area, + ) + + params = dict(valid=valid, labels=labels) + flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs] + + return tree_unflatten(flat_outputs, spec) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) + is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) + + if not (is_label or is_bounding_boxes_or_mask): + return inpt + + output = inpt[params["valid"]] + + if is_label: + return output + else: + return tv_tensors.wrap(output, like=inpt) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_temporal.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..c59d5078d46b4a7061c02aa3c82e0cbb09fb4e89 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_temporal.py @@ -0,0 +1,26 @@ +from typing import Any, Dict + +import torch +from torchvision.transforms.v2 import functional as F, Transform + + +class UniformTemporalSubsample(Transform): + """Uniformly subsample ``num_samples`` indices from the temporal dimension of the video. + + Videos are expected to be of shape ``[..., T, C, H, W]`` where ``T`` denotes the temporal dimension. + + When ``num_samples`` is larger than the size of temporal dimension of the video, it + will sample frames based on nearest neighbor interpolation. 
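# --- Illustrative usage sketch (not part of the diffed torchvision source above) ---
# A detection-style pipeline combining RandomIoUCrop with SanitizeBoundingBoxes, as required by
# the warning in the RandomIoUCrop docstring above; sizes, boxes, and labels are placeholders.
# The sample uses a "labels" key so the default labels_getter heuristic can find it.
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

pipeline = v2.Compose([
    v2.RandomIoUCrop(),
    v2.SanitizeBoundingBoxes(),   # removes the boxes zeroed out (marked degenerate) by the crop
])
sample = {
    "image": tv_tensors.Image(torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8)),
    "boxes": tv_tensors.BoundingBoxes(
        torch.tensor([[10.0, 10.0, 60.0, 60.0]]), format="XYXY", canvas_size=(100, 100)
    ),
    "labels": torch.tensor([1]),
}
out = pipeline(sample)            # boxes and labels are filtered together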
+ + Args: + num_samples (int): The number of equispaced samples to be selected + """ + + _transformed_types = (torch.Tensor,) + + def __init__(self, num_samples: int): + super().__init__() + self.num_samples = num_samples + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b7eced5a2874d5d5a5d24f6580ee4f824b174285 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import PIL.Image +import torch +from torch import nn +from torch.utils._pytree import tree_flatten, tree_unflatten +from torchvision import tv_tensors +from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor +from torchvision.utils import _log_api_usage_once + +from .functional._utils import _get_kernel + + +class Transform(nn.Module): + + # Class attribute defining transformed types. Other types are passed-through without any transformation + # We support both Types and callables that are able to do further checks on the type of the input. + _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image) + + def __init__(self) -> None: + super().__init__() + _log_api_usage_once(self) + + def _check_inputs(self, flat_inputs: List[Any]) -> None: + pass + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + return dict() + + def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + kernel = _get_kernel(functional, type(inpt), allow_passthrough=True) + return kernel(inpt, *args, **kwargs) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + raise NotImplementedError + + def forward(self, *inputs: Any) -> Any: + flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) + + self._check_inputs(flat_inputs) + + needs_transform_list = self._needs_transform_list(flat_inputs) + params = self._get_params( + [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] + ) + + flat_outputs = [ + self._transform(inpt, params) if needs_transform else inpt + for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) + ] + + return tree_unflatten(flat_outputs, spec) + + def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]: + # Below is a heuristic on how to deal with pure tensor inputs: + # 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image + # (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample. + # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is + # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs` + # of `tree_flatten`, which recurses depth-first through the input. + # + # This heuristic stems from two requirements: + # 1. We need to keep BC for single input pure tensors and treat them as images. + # 2. 
We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `Widerface` + # return supplemental numerical data as tensors that cannot be transformed as images. + # + # The heuristic should work well for most people in practice. The only case where it doesn't is if someone + # tries to transform multiple pure tensors at the same time, expecting them all to be treated as images. + # However, this case wasn't supported by transforms v1 either, so there is no BC concern. + + needs_transform_list = [] + transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image) + for inpt in flat_inputs: + needs_transform = True + + if not check_type(inpt, self._transformed_types): + needs_transform = False + elif is_pure_tensor(inpt): + if transform_pure_tensor: + transform_pure_tensor = False + else: + needs_transform = False + needs_transform_list.append(needs_transform) + return needs_transform_list + + def extra_repr(self) -> str: + extra = [] + for name, value in self.__dict__.items(): + if name.startswith("_") or name == "training": + continue + + if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)): + continue + + extra.append(f"{name}={value}") + + return ", ".join(extra) + + # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things: + # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on + # the v2 transform. See `__init_subclass__` for details. + # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__` + # for details. + _v1_transform_cls: Optional[Type[nn.Module]] = None + + def __init_subclass__(cls) -> None: + # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance. + # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`. + if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"): + cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined] + + def _extract_params_for_v1_transform(self) -> Dict[str, Any]: + # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current + # v2 transform instance. It extracts all available public attributes that are specific to that transform and + # not `nn.Module` in general. + # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen + # if the v2 transform introduced new parameters that are not support by the v1 transform. + common_attrs = nn.Module().__dict__.keys() + return { + attr: value + for attr, value in self.__dict__.items() + if not attr.startswith("_") and attr not in common_attrs + } + + def __prepare_scriptable__(self) -> nn.Module: + # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return + # value is used for scripting over the original object that should have been scripted. Since the v1 transforms + # are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the + # equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1 + # is around. + if self._v1_transform_cls is None: + raise RuntimeError( + f"Transform {type(self).__name__} cannot be JIT scripted. 
" + "torchscript is only supported for backward compatibility with transforms " + "which are already in torchvision.transforms. " + "For torchscript support (on tensors only), you can use the functional API instead." + ) + + return self._v1_transform_cls(**self._extract_params_for_v1_transform()) + + +class _RandomApplyTransform(Transform): + def __init__(self, p: float = 0.5) -> None: + if not (0.0 <= p <= 1.0): + raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") + + super().__init__() + self.p = p + + def forward(self, *inputs: Any) -> Any: + # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return + # early afterwards in case the random check triggers. The same result could be achieved by calling + # `super().forward()` after the random check, but that would call `self._check_inputs` twice. + + inputs = inputs if len(inputs) > 1 else inputs[0] + flat_inputs, spec = tree_flatten(inputs) + + self._check_inputs(flat_inputs) + + if torch.rand(1) >= self.p: + return inputs + + needs_transform_list = self._needs_transform_list(flat_inputs) + params = self._get_params( + [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] + ) + + flat_outputs = [ + self._transform(inpt, params) if needs_transform else inpt + for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) + ] + + return tree_unflatten(flat_outputs, spec) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_type_conversion.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_type_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7439b1d02a560b3be4261b347c70ff0b3aeb1c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_type_conversion.py @@ -0,0 +1,84 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import PIL.Image +import torch + +from torchvision import tv_tensors +from torchvision.transforms.v2 import functional as F, Transform + +from torchvision.transforms.v2._utils import is_pure_tensor + + +class PILToTensor(Transform): + """Convert a PIL Image to a tensor of the same type - this does not scale values. + + This transform does not support torchscript. + + Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). + """ + + _transformed_types = (PIL.Image.Image,) + + def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: + return F.pil_to_tensor(inpt) + + +class ToImage(Transform): + """Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.tv_tensors.Image` + ; this does not scale values. + + This transform does not support torchscript. + """ + + _transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray) + + def _transform( + self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] + ) -> tv_tensors.Image: + return F.to_image(inpt) + + +class ToPILImage(Transform): + """Convert a tensor or an ndarray to PIL Image + + This transform does not support torchscript. + + Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape + H x W x C to a PIL Image while adjusting the value range depending on the ``mode``. + + Args: + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + If ``mode`` is ``None`` (default) there are some assumptions made about the input data: + + - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 
+ - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. + - If the input has 2 channels, the ``mode`` is assumed to be ``LA``. + - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, + ``short``). + + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + """ + + _transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray) + + def __init__(self, mode: Optional[str] = None) -> None: + super().__init__() + self.mode = mode + + def _transform( + self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] + ) -> PIL.Image.Image: + return F.to_pil_image(inpt, mode=self.mode) + + +class ToPureTensor(Transform): + """Convert all TVTensors to pure tensors, removing associated metadata (if any). + + This doesn't scale or change the values, only the type. + """ + + _transformed_types = (tv_tensors.TVTensor,) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + return inpt.as_subclass(torch.Tensor) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_utils.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cde4c5c33b2f7229a10efacc47b070e0ac96a1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_utils.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import collections.abc +import numbers +from contextlib import suppress + +from typing import Any, Callable, Dict, List, Literal, Sequence, Tuple, Type, Union + +import PIL.Image +import torch + +from torchvision import tv_tensors + +from torchvision._utils import sequence_to_str + +from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT + + +def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]: + if not isinstance(arg, (int, float, Sequence)): + raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) not in (1, 2): + raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}") + if isinstance(arg, Sequence): + for element in arg: + if not isinstance(element, (int, float)): + raise ValueError(f"{name} should be a sequence of numbers. 
Got {type(element)}") + + if isinstance(arg, (int, float)): + arg = [float(arg), float(arg)] + elif isinstance(arg, Sequence): + if len(arg) == 1: + arg = [float(arg[0]), float(arg[0])] + else: + arg = [float(arg[0]), float(arg[1])] + return arg + + +def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None: + if isinstance(fill, dict): + for value in fill.values(): + _check_fill_arg(value) + else: + if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") + + +def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT: + # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 + # So, we can't reassign fill to 0 + # if fill is None: + # fill = 0 + if fill is None: + return fill + + if not isinstance(fill, (int, float)): + fill = [float(v) for v in list(fill)] + return fill # type: ignore[return-value] + + +def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]: + _check_fill_arg(fill) + + if isinstance(fill, dict): + for k, v in fill.items(): + fill[k] = _convert_fill_arg(v) + return fill # type: ignore[return-value] + else: + return {"others": _convert_fill_arg(fill)} + + +def _get_fill(fill_dict, inpt_type): + if inpt_type in fill_dict: + return fill_dict[inpt_type] + elif "others" in fill_dict: + return fill_dict["others"] + else: + RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.") + + +def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: + raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + + +# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums) +# https://github.com/pytorch/vision/issues/6250 +def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None: + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + +def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: + """ + This heuristic covers three cases: + + 1. The input is tuple or list whose second item is a labels tensor. This happens for already batched + classification inputs for MixUp and CutMix (typically after the Dataloder). + 2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor + under a label-like (see below) key. This happens for the inputs of detection models. + 3. The input is a dictionary that is structured as the one from 2. + + What is "label-like" key? We first search for an case-insensitive match of 'labels' inside the keys of the + dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive + match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary + contains no "label-like" key. 
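# --- Illustrative sketch (not part of the diffed torchvision source above) ---
# What the "label-like" key heuristic described above resolves to, using the private helper from
# this file; the dictionaries and key names are placeholders.
import torch

target = {"boxes": "...", "labels": torch.tensor([1, 2])}        # exact, case-insensitive "labels" match
target_alt = {"boxes": "...", "FooLaBeLBar": torch.tensor([1])}  # falls back to a substring "label" match
labels = _find_labels_default_heuristic(("img", target))         # -> target["labels"]
labels_alt = _find_labels_default_heuristic(("img", target_alt)) # -> target_alt["FooLaBeLBar"]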
+ """ + + if isinstance(inputs, (tuple, list)): + inputs = inputs[1] + + # MixUp, CutMix + if is_pure_tensor(inputs): + return inputs + + if not isinstance(inputs, collections.abc.Mapping): + raise ValueError( + f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple " + f"whose second item is a dictionary or a tensor, but got {inputs} instead." + ) + + candidate_key = None + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") + if candidate_key is None: + with suppress(StopIteration): + candidate_key = next(key for key in inputs.keys() if "label" in key.lower()) + if candidate_key is None: + raise ValueError( + "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?" + "If there are no labels in the sample by design, pass labels_getter=None." + ) + + return inputs[candidate_key] + + +def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None]) -> Callable[[Any], Any]: + if labels_getter == "default": + return _find_labels_default_heuristic + elif callable(labels_getter): + return labels_getter + elif labels_getter is None: + return lambda _: None + else: + raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.") + + +def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes: + # This assumes there is only one bbox per sample as per the general convention + try: + return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes)) + except StopIteration: + raise ValueError("No bounding boxes were found in the sample") + + +def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: + chws = { + tuple(get_dimensions(inpt)) + for inpt in flat_inputs + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + } + if not chws: + raise TypeError("No image or video was found in the sample") + elif len(chws) > 1: + raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") + c, h, w = chws.pop() + return c, h, w + + +def query_size(flat_inputs: List[Any]) -> Tuple[int, int]: + sizes = { + tuple(get_size(inpt)) + for inpt in flat_inputs + if check_type( + inpt, + ( + is_pure_tensor, + tv_tensors.Image, + PIL.Image.Image, + tv_tensors.Video, + tv_tensors.Mask, + tv_tensors.BoundingBoxes, + ), + ) + } + if not sizes: + raise TypeError("No image, video, mask or bounding box was found in the sample") + elif len(sizes) > 1: + raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}") + h, w = sizes.pop() + return h, w + + +def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool: + for type_or_check in types_or_checks: + if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): + return True + return False + + +def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: + for inpt in flat_inputs: + if check_type(inpt, types_or_checks): + return True + return False + + +def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: + for type_or_check in types_or_checks: + for inpt in flat_inputs: + if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt): + break + else: + return False + return True diff --git 
a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5705d55c4b2d429e450dc0033c99f7fe16a8ddb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__init__.py @@ -0,0 +1,153 @@ +from torchvision.transforms import InterpolationMode # usort: skip + +from ._utils import is_pure_tensor, register_kernel # usort: skip + +from ._meta import ( + clamp_bounding_boxes, + convert_bounding_box_format, + get_dimensions_image, + get_dimensions_video, + get_dimensions, + get_num_frames_video, + get_num_frames, + get_image_num_channels, + get_num_channels_image, + get_num_channels_video, + get_num_channels, + get_size_bounding_boxes, + get_size_image, + get_size_mask, + get_size_video, + get_size, +) # usort: skip + +from ._augment import erase, erase_image, erase_video, jpeg, jpeg_image, jpeg_video +from ._color import ( + adjust_brightness, + adjust_brightness_image, + adjust_brightness_video, + adjust_contrast, + adjust_contrast_image, + adjust_contrast_video, + adjust_gamma, + adjust_gamma_image, + adjust_gamma_video, + adjust_hue, + adjust_hue_image, + adjust_hue_video, + adjust_saturation, + adjust_saturation_image, + adjust_saturation_video, + adjust_sharpness, + adjust_sharpness_image, + adjust_sharpness_video, + autocontrast, + autocontrast_image, + autocontrast_video, + equalize, + equalize_image, + equalize_video, + grayscale_to_rgb, + grayscale_to_rgb_image, + invert, + invert_image, + invert_video, + permute_channels, + permute_channels_image, + permute_channels_video, + posterize, + posterize_image, + posterize_video, + rgb_to_grayscale, + rgb_to_grayscale_image, + solarize, + solarize_image, + solarize_video, + to_grayscale, +) +from ._geometry import ( + affine, + affine_bounding_boxes, + affine_image, + affine_mask, + affine_video, + center_crop, + center_crop_bounding_boxes, + center_crop_image, + center_crop_mask, + center_crop_video, + crop, + crop_bounding_boxes, + crop_image, + crop_mask, + crop_video, + elastic, + elastic_bounding_boxes, + elastic_image, + elastic_mask, + elastic_transform, + elastic_video, + five_crop, + five_crop_image, + five_crop_video, + hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file + horizontal_flip, + horizontal_flip_bounding_boxes, + horizontal_flip_image, + horizontal_flip_mask, + horizontal_flip_video, + pad, + pad_bounding_boxes, + pad_image, + pad_mask, + pad_video, + perspective, + perspective_bounding_boxes, + perspective_image, + perspective_mask, + perspective_video, + resize, + resize_bounding_boxes, + resize_image, + resize_mask, + resize_video, + resized_crop, + resized_crop_bounding_boxes, + resized_crop_image, + resized_crop_mask, + resized_crop_video, + rotate, + rotate_bounding_boxes, + rotate_image, + rotate_mask, + rotate_video, + ten_crop, + ten_crop_image, + ten_crop_video, + vertical_flip, + vertical_flip_bounding_boxes, + vertical_flip_image, + vertical_flip_mask, + vertical_flip_video, + vflip, +) +from ._misc import ( + convert_image_dtype, + gaussian_blur, + gaussian_blur_image, + gaussian_blur_video, + gaussian_noise, + gaussian_noise_image, + gaussian_noise_video, + normalize, + normalize_image, + normalize_video, + sanitize_bounding_boxes, + to_dtype, + to_dtype_image, + to_dtype_video, +) +from ._temporal import uniform_temporal_subsample, 
uniform_temporal_subsample_video +from ._type_conversion import pil_to_tensor, to_image, to_pil_image + +from ._deprecated import get_image_size, to_tensor # usort: skip diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc35cf82931abf3fa10f5bab41ad47a74c2b5851 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_augment.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_augment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21e3fee522d1aa56beee86943f98407df04ff2dd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_augment.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_color.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_color.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06b12292a803b2c6d6ebf1ca005f65604effd204 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_color.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48b065033593e571a2b22c6dbae2ead18904ce7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_deprecated.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_meta.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_meta.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1c611424acf17e4082f9b23405fa03efd00c823 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_meta.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_misc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa810246853e0a6bca5adf3b3ea71698b25773c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_misc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a31926fa8ea6b23b6e2e57e624f3ab73dd03b3c Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_temporal.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fcf17b906119ddbabfeff868d9f094de928a6bc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_type_conversion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4730aa348fb7a71d23f012add01c7d72db3e3f28 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/__pycache__/_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_augment.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..a904d8d7cbdfeb78588abbf43c8bca37b3431735 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_augment.py @@ -0,0 +1,106 @@ +import io + +import PIL.Image + +import torch +from torchvision import tv_tensors +from torchvision.io import decode_jpeg, encode_jpeg +from torchvision.transforms.functional import pil_to_tensor, to_pil_image +from torchvision.utils import _log_api_usage_once + +from ._utils import _get_kernel, _register_kernel_internal + + +def erase( + inpt: torch.Tensor, + i: int, + j: int, + h: int, + w: int, + v: torch.Tensor, + inplace: bool = False, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomErase` for details.""" + if torch.jit.is_scripting(): + return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + + _log_api_usage_once(erase) + + kernel = _get_kernel(erase, type(inpt)) + return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + + +@_register_kernel_internal(erase, torch.Tensor) +@_register_kernel_internal(erase, tv_tensors.Image) +def erase_image( + image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> torch.Tensor: + if not inplace: + image = image.clone() + + image[..., i : i + h, j : j + w] = v + return image + + +@_register_kernel_internal(erase, PIL.Image.Image) +def _erase_image_pil( + image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> PIL.Image.Image: + t_img = pil_to_tensor(image) + output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return to_pil_image(output, mode=image.mode) + + +@_register_kernel_internal(erase, tv_tensors.Video) +def erase_video( + video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> torch.Tensor: + return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + + +def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.JPEG` for details.""" + if torch.jit.is_scripting(): + return jpeg_image(image, quality=quality) + + _log_api_usage_once(jpeg) + + kernel = _get_kernel(jpeg, type(image)) + return kernel(image, 
quality=quality) + + +@_register_kernel_internal(jpeg, torch.Tensor) +@_register_kernel_internal(jpeg, tv_tensors.Image) +def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor: + original_shape = image.shape + image = image.view((-1,) + image.shape[-3:]) + + if image.shape[0] == 0: # degenerate + return image.reshape(original_shape).clone() + + images = [] + for i in range(image.shape[0]): + # isinstance checks are needed for torchscript. + encoded_image = encode_jpeg(image[i], quality=quality) + assert isinstance(encoded_image, torch.Tensor) + decoded_image = decode_jpeg(encoded_image) + assert isinstance(decoded_image, torch.Tensor) + images.append(decoded_image) + + images = torch.stack(images, dim=0).view(original_shape) + return images + + +@_register_kernel_internal(jpeg, tv_tensors.Video) +def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor: + return jpeg_image(video, quality=quality) + + +@_register_kernel_internal(jpeg, PIL.Image.Image) +def _jpeg_image_pil(image: PIL.Image.Image, quality: int) -> PIL.Image.Image: + raw_jpeg = io.BytesIO() + image.save(raw_jpeg, format="JPEG", quality=quality) + + # we need to copy since PIL.Image.open() will return PIL.JpegImagePlugin.JpegImageFile + # which is a sub-class of PIL.Image.Image. this will fail check_transform() test. + return PIL.Image.open(raw_jpeg).copy() diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_color.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_color.py new file mode 100644 index 0000000000000000000000000000000000000000..eb75f58cb7ac85a9478c26c4e86d9ab73b93d331 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_color.py @@ -0,0 +1,739 @@ +from typing import List + +import PIL.Image +import torch +from torch.nn.functional import conv2d +from torchvision import tv_tensors +from torchvision.transforms import _functional_pil as _FP +from torchvision.transforms._functional_tensor import _max_value + +from torchvision.utils import _log_api_usage_once + +from ._misc import _num_value_bits, to_dtype_image +from ._type_conversion import pil_to_tensor, to_pil_image +from ._utils import _get_kernel, _register_kernel_internal + + +def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.Grayscale` for details.""" + if torch.jit.is_scripting(): + return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels) + + _log_api_usage_once(rgb_to_grayscale) + + kernel = _get_kernel(rgb_to_grayscale, type(inpt)) + return kernel(inpt, num_output_channels=num_output_channels) + + +# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a +# superset in terms of functionality and has the same signature, we alias here to avoid disruption. +to_grayscale = rgb_to_grayscale + + +def _rgb_to_grayscale_image( + image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True +) -> torch.Tensor: + # TODO: Maybe move the validation that num_output_channels is 1 or 3 to this function instead of callers. 
+ if image.shape[-3] == 1 and num_output_channels == 1: + return image.clone() + if image.shape[-3] == 1 and num_output_channels == 3: + s = [1] * len(image.shape) + s[-3] = 3 + return image.repeat(s) + r, g, b = image.unbind(dim=-3) + l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.unsqueeze(dim=-3) + if preserve_dtype: + l_img = l_img.to(image.dtype) + if num_output_channels == 3: + l_img = l_img.expand(image.shape) + return l_img + + +@_register_kernel_internal(rgb_to_grayscale, torch.Tensor) +@_register_kernel_internal(rgb_to_grayscale, tv_tensors.Image) +def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + return _rgb_to_grayscale_image(image, num_output_channels=num_output_channels, preserve_dtype=True) + + +@_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image) +def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + return _FP.to_grayscale(image, num_output_channels=num_output_channels) + + +def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RGB` for details.""" + if torch.jit.is_scripting(): + return grayscale_to_rgb_image(inpt) + + _log_api_usage_once(grayscale_to_rgb) + + kernel = _get_kernel(grayscale_to_rgb, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(grayscale_to_rgb, torch.Tensor) +@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image) +def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor: + if image.shape[-3] >= 3: + # Image already has RGB channels. We don't need to do anything. + return image + # rgb_to_grayscale can be used to add channels so we reuse that function. 
+ return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True) + + +@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image) +def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: + return image.convert(mode="RGB") + + +def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: + ratio = float(ratio) + fp = image1.is_floating_point() + bound = _max_value(image1.dtype) + output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound) + return output if fp else output.to(image1.dtype) + + +def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor: + """Adjust brightness.""" + + if torch.jit.is_scripting(): + return adjust_brightness_image(inpt, brightness_factor=brightness_factor) + + _log_api_usage_once(adjust_brightness) + + kernel = _get_kernel(adjust_brightness, type(inpt)) + return kernel(inpt, brightness_factor=brightness_factor) + + +@_register_kernel_internal(adjust_brightness, torch.Tensor) +@_register_kernel_internal(adjust_brightness, tv_tensors.Image) +def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: + if brightness_factor < 0: + raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + fp = image.is_floating_point() + bound = _max_value(image.dtype) + output = image.mul(brightness_factor).clamp_(0, bound) + return output if fp else output.to(image.dtype) + + +@_register_kernel_internal(adjust_brightness, PIL.Image.Image) +def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: + return _FP.adjust_brightness(image, brightness_factor=brightness_factor) + + +@_register_kernel_internal(adjust_brightness, tv_tensors.Video) +def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: + return adjust_brightness_image(video, brightness_factor=brightness_factor) + + +def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: + """Adjust saturation.""" + if torch.jit.is_scripting(): + return adjust_saturation_image(inpt, saturation_factor=saturation_factor) + + _log_api_usage_once(adjust_saturation) + + kernel = _get_kernel(adjust_saturation, type(inpt)) + return kernel(inpt, saturation_factor=saturation_factor) + + +@_register_kernel_internal(adjust_saturation, torch.Tensor) +@_register_kernel_internal(adjust_saturation, tv_tensors.Image) +def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: + if saturation_factor < 0: + raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False) + if not image.is_floating_point(): + grayscale_image = grayscale_image.floor_() + + return _blend(image, grayscale_image, saturation_factor) + + +_adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) + + +@_register_kernel_internal(adjust_saturation, tv_tensors.Video) +def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: + return 
adjust_saturation_image(video, saturation_factor=saturation_factor) + + +def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: + """See :class:`~torchvision.transforms.RandomAutocontrast`""" + if torch.jit.is_scripting(): + return adjust_contrast_image(inpt, contrast_factor=contrast_factor) + + _log_api_usage_once(adjust_contrast) + + kernel = _get_kernel(adjust_contrast, type(inpt)) + return kernel(inpt, contrast_factor=contrast_factor) + + +@_register_kernel_internal(adjust_contrast, torch.Tensor) +@_register_kernel_internal(adjust_contrast, tv_tensors.Image) +def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: + if contrast_factor < 0: + raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + fp = image.is_floating_point() + if c == 3: + grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False) + if not fp: + grayscale_image = grayscale_image.floor_() + else: + grayscale_image = image if fp else image.to(torch.float32) + mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True) + return _blend(image, mean, contrast_factor) + + +_adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) + + +@_register_kernel_internal(adjust_contrast, tv_tensors.Video) +def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: + return adjust_contrast_image(video, contrast_factor=contrast_factor) + + +def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + """See :class:`~torchvision.transforms.RandomAdjustSharpness`""" + if torch.jit.is_scripting(): + return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor) + + _log_api_usage_once(adjust_sharpness) + + kernel = _get_kernel(adjust_sharpness, type(inpt)) + return kernel(inpt, sharpness_factor=sharpness_factor) + + +@_register_kernel_internal(adjust_sharpness, torch.Tensor) +@_register_kernel_internal(adjust_sharpness, tv_tensors.Image) +def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + num_channels, height, width = image.shape[-3:] + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + if sharpness_factor < 0: + raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.") + + if image.numel() == 0 or height <= 2 or width <= 2: + return image + + bound = _max_value(image.dtype) + fp = image.is_floating_point() + shape = image.shape + + if image.ndim > 4: + image = image.reshape(-1, num_channels, height, width) + needs_unsquash = True + else: + needs_unsquash = False + + # The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle. 
+ kernel_dtype = image.dtype if fp else torch.float32 + a, b = 1.0 / 13.0, 5.0 / 13.0 + kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device) + kernel = kernel.expand(num_channels, 1, 3, 3) + + # We copy and cast at the same time to avoid modifications on the original data + output = image.to(dtype=kernel_dtype, copy=True) + blurred_degenerate = conv2d(output, kernel, groups=num_channels) + if not fp: + # it is better to round before cast + blurred_degenerate = blurred_degenerate.round_() + + # Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice. + view = output[..., 1:-1, 1:-1] + + # We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent: + # x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r) + view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor)) + + # The actual data of output have been modified by the above. We only need to clamp and cast now. + output = output.clamp_(0, bound) + if not fp: + output = output.to(image.dtype) + + if needs_unsquash: + output = output.reshape(shape) + + return output + + +_adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) + + +@_register_kernel_internal(adjust_sharpness, tv_tensors.Video) +def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: + return adjust_sharpness_image(video, sharpness_factor=sharpness_factor) + + +def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor: + """Adjust hue""" + if torch.jit.is_scripting(): + return adjust_hue_image(inpt, hue_factor=hue_factor) + + _log_api_usage_once(adjust_hue) + + kernel = _get_kernel(adjust_hue, type(inpt)) + return kernel(inpt, hue_factor=hue_factor) + + +def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: + r, g, _ = image.unbind(dim=-3) + + # Implementation is based on + # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330 + minc, maxc = torch.aminmax(image, dim=-3) + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occurring so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + channels_range = maxc - minc + # Since `eqc => channels_range = 0`, replacing denominator with 1 when `eqc` is fine. + ones = torch.ones_like(maxc) + s = channels_range / torch.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. 
+ channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3) + rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3) + + mask_maxc_neq_r = maxc != r + mask_maxc_eq_g = maxc == g + + hg = rc.add(2.0).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r) + hr = bc.sub_(gc).mul_(~mask_maxc_neq_r) + hb = gc.add_(4.0).sub_(rc).mul_(mask_maxc_neq_r.logical_and_(mask_maxc_eq_g.logical_not_())) + + h = hr.add_(hg).add_(hb) + h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0) + return torch.stack((h, s, maxc), dim=-3) + + +def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: + h, s, v = img.unbind(dim=-3) + h6 = h.mul(6) + i = torch.floor(h6) + f = h6.sub_(i) + i = i.to(dtype=torch.int32) + + sxf = s * f + one_minus_s = 1.0 - s + q = (1.0 - sxf).mul_(v).clamp_(0.0, 1.0) + t = sxf.add_(one_minus_s).mul_(v).clamp_(0.0, 1.0) + p = one_minus_s.mul_(v).clamp_(0.0, 1.0) + i.remainder_(6) + + vpqt = torch.stack((v, p, q, t), dim=-3) + + # vpqt -> rgb mapping based on i + select = torch.tensor([[0, 2, 1, 1, 3, 0], [3, 0, 0, 2, 1, 1], [1, 1, 3, 0, 0, 2]], dtype=torch.long) + select = select.to(device=img.device, non_blocking=True) + + select = select[:, i] + if select.ndim > 3: + # if input.shape is (B, ..., C, H, W) then + # select.shape is (C, B, ..., H, W) + # thus we move C axis to get (B, ..., C, H, W) + select = select.moveaxis(0, -3) + + return vpqt.gather(-3, select) + + +@_register_kernel_internal(adjust_hue, torch.Tensor) +@_register_kernel_internal(adjust_hue, tv_tensors.Image) +def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + if image.numel() == 0: + # exit earlier on empty images + return image + + orig_dtype = image.dtype + image = to_dtype_image(image, torch.float32, scale=True) + + image = _rgb_to_hsv(image) + h, s, v = image.unbind(dim=-3) + h.add_(hue_factor).remainder_(1.0) + image = torch.stack((h, s, v), dim=-3) + image_hue_adj = _hsv_to_rgb(image) + + return to_dtype_image(image_hue_adj, orig_dtype, scale=True) + + +_adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) + + +@_register_kernel_internal(adjust_hue, tv_tensors.Video) +def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: + return adjust_hue_image(video, hue_factor=hue_factor) + + +def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: + """Adjust gamma.""" + if torch.jit.is_scripting(): + return adjust_gamma_image(inpt, gamma=gamma, gain=gain) + + _log_api_usage_once(adjust_gamma) + + kernel = _get_kernel(adjust_gamma, type(inpt)) + return kernel(inpt, gamma=gamma, gain=gain) + + +@_register_kernel_internal(adjust_gamma, torch.Tensor) +@_register_kernel_internal(adjust_gamma, tv_tensors.Image) +def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). + # Since the gamma is non-negative, the output remains at [0, 1] scale. 
+ if not torch.is_floating_point(image): + output = to_dtype_image(image, torch.float32, scale=True).pow_(gamma) + else: + output = image.pow(gamma) + + if gain != 1.0: + # The clamp operation is needed only if multiplication is performed. It's only when gain != 1, that the scale + # of the output can go beyond [0, 1]. + output = output.mul_(gain).clamp_(0.0, 1.0) + + return to_dtype_image(output, image.dtype, scale=True) + + +_adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) + + +@_register_kernel_internal(adjust_gamma, tv_tensors.Video) +def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: + return adjust_gamma_image(video, gamma=gamma, gain=gain) + + +def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomPosterize` for details.""" + if torch.jit.is_scripting(): + return posterize_image(inpt, bits=bits) + + _log_api_usage_once(posterize) + + kernel = _get_kernel(posterize, type(inpt)) + return kernel(inpt, bits=bits) + + +@_register_kernel_internal(posterize, torch.Tensor) +@_register_kernel_internal(posterize, tv_tensors.Image) +def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor: + if image.is_floating_point(): + levels = 1 << bits + return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels) + else: + num_value_bits = _num_value_bits(image.dtype) + if bits >= num_value_bits: + return image + + mask = ((1 << bits) - 1) << (num_value_bits - bits) + return image & mask + + +_posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) + + +@_register_kernel_internal(posterize, tv_tensors.Video) +def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: + return posterize_image(video, bits=bits) + + +def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomSolarize` for details.""" + if torch.jit.is_scripting(): + return solarize_image(inpt, threshold=threshold) + + _log_api_usage_once(solarize) + + kernel = _get_kernel(solarize, type(inpt)) + return kernel(inpt, threshold=threshold) + + +@_register_kernel_internal(solarize, torch.Tensor) +@_register_kernel_internal(solarize, tv_tensors.Image) +def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor: + if threshold > _max_value(image.dtype): + raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") + + return torch.where(image >= threshold, invert_image(image), image) + + +_solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) + + +@_register_kernel_internal(solarize, tv_tensors.Video) +def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: + return solarize_image(video, threshold=threshold) + + +def autocontrast(inpt: torch.Tensor) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomAutocontrast` for details.""" + if torch.jit.is_scripting(): + return autocontrast_image(inpt) + + _log_api_usage_once(autocontrast) + + kernel = _get_kernel(autocontrast, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(autocontrast, torch.Tensor) +@_register_kernel_internal(autocontrast, tv_tensors.Image) +def autocontrast_image(image: torch.Tensor) -> torch.Tensor: + c = image.shape[-3] + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") + + if image.numel() == 0: + # 
exit earlier on empty images + return image + + bound = _max_value(image.dtype) + fp = image.is_floating_point() + float_image = image if fp else image.to(torch.float32) + + minimum = float_image.amin(dim=(-2, -1), keepdim=True) + maximum = float_image.amax(dim=(-2, -1), keepdim=True) + + eq_idxs = maximum == minimum + inv_scale = maximum.sub_(minimum).mul_(1.0 / bound) + minimum[eq_idxs] = 0.0 + inv_scale[eq_idxs] = 1.0 + + if fp: + diff = float_image.sub(minimum) + else: + diff = float_image.sub_(minimum) + + return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) + + +_autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) + + +@_register_kernel_internal(autocontrast, tv_tensors.Video) +def autocontrast_video(video: torch.Tensor) -> torch.Tensor: + return autocontrast_image(video) + + +def equalize(inpt: torch.Tensor) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomEqualize` for details.""" + if torch.jit.is_scripting(): + return equalize_image(inpt) + + _log_api_usage_once(equalize) + + kernel = _get_kernel(equalize, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(equalize, torch.Tensor) +@_register_kernel_internal(equalize, tv_tensors.Image) +def equalize_image(image: torch.Tensor) -> torch.Tensor: + if image.numel() == 0: + return image + + # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that + # would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for + # `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely + # unfeasible for `torch.int64`. + # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we + # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition + # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower + # and more complicated to implement than a simple conversion and a fast histogram implementation for integers. + # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is + # by far the most common, we choose it as base. + output_dtype = image.dtype + image = to_dtype_image(image, torch.uint8, scale=True) + + # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image + # corresponds to adding 1 to index 127 in the histogram. + batch_shape = image.shape[:-2] + flat_image = image.flatten(start_dim=-2).to(torch.long) + hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32) + hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image)) + cum_hist = hist.cumsum(dim=-1) + + # The simplest form of lookup-table (LUT) that also achieves histogram equalization is + # `lut = cum_hist / flat_image.shape[-1] * 255` + # However, PIL uses a more elaborate scheme: + # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385 + # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255` + + # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum + # value. 
Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but + # rather the maximum value in the image, which might be or not be 255. + index = cum_hist.argmax(dim=-1) + num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1)) + + # This is performance optimization that saves us one multiplication later. With this, the LUT computation simplifies + # to `lut = (cum_hist + step // 2) // step` and thus saving the final multiplication by 255 while keeping the + # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison. + step = num_non_max_pixels.div_(255, rounding_mode="floor") + + # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as + # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't, + # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to + # pay the runtime cost for checking it every time. + valid_equalization = step.ne(0).unsqueeze_(-1) + + # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the + # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards. + cum_hist = cum_hist[..., :-1] + ( + cum_hist.add_(step // 2) + # We need the `clamp_`(min=1) call here to avoid zero division since they fail for integer dtypes. This has no + # effect on the returned result of this kernel since images inside the batch with `step == 0` are returned as is + # instead of equalized version. + .div_(step.clamp_(min=1), rounding_mode="floor") + # We need the `clamp_` call here since PILs LUT computation scheme can produce values outside the valid value + # range of uint8 images + .clamp_(0, 255) + ) + lut = cum_hist.to(torch.uint8) + lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1) + equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) + + output = torch.where(valid_equalization, equalized_image, image) + return to_dtype_image(output, output_dtype, scale=True) + + +_equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) + + +@_register_kernel_internal(equalize, tv_tensors.Video) +def equalize_video(video: torch.Tensor) -> torch.Tensor: + return equalize_image(video) + + +def invert(inpt: torch.Tensor) -> torch.Tensor: + """See :func:`~torchvision.transforms.v2.RandomInvert`.""" + if torch.jit.is_scripting(): + return invert_image(inpt) + + _log_api_usage_once(invert) + + kernel = _get_kernel(invert, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(invert, torch.Tensor) +@_register_kernel_internal(invert, tv_tensors.Image) +def invert_image(image: torch.Tensor) -> torch.Tensor: + if image.is_floating_point(): + return 1.0 - image + elif image.dtype == torch.uint8: + return image.bitwise_not() + else: # signed integer dtypes + # We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign + return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1) + + +_invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) + + +@_register_kernel_internal(invert, tv_tensors.Video) +def invert_video(video: torch.Tensor) -> torch.Tensor: + return invert_image(video) + + +def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor: + """Permute 
the channels of the input according to the given permutation. + + This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and + :class:`torchvision.tv_tensors.Image` and :class:`torchvision.tv_tensors.Video`. + + Example: + >>> rgb_image = torch.rand(3, 256, 256) + >>> bgr_image = F.permute_channels(rgb_image, permutation=[2, 1, 0]) + + Args: + permutation (List[int]): Valid permutation of the input channel indices. The index of the element determines the + channel index in the input and the value determines the channel index in the output. For example, + ``permutation=[2, 0 , 1]`` + + - takes ``ìnpt[..., 0, :, :]`` and puts it at ``output[..., 2, :, :]``, + - takes ``ìnpt[..., 1, :, :]`` and puts it at ``output[..., 0, :, :]``, and + - takes ``ìnpt[..., 2, :, :]`` and puts it at ``output[..., 1, :, :]``. + + Raises: + ValueError: If ``len(permutation)`` doesn't match the number of channels in the input. + """ + if torch.jit.is_scripting(): + return permute_channels_image(inpt, permutation=permutation) + + _log_api_usage_once(permute_channels) + + kernel = _get_kernel(permute_channels, type(inpt)) + return kernel(inpt, permutation=permutation) + + +@_register_kernel_internal(permute_channels, torch.Tensor) +@_register_kernel_internal(permute_channels, tv_tensors.Image) +def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: + shape = image.shape + num_channels, height, width = shape[-3:] + + if len(permutation) != num_channels: + raise ValueError( + f"Length of permutation does not match number of channels: " f"{len(permutation)} != {num_channels}" + ) + + if image.numel() == 0: + return image + + image = image.reshape(-1, num_channels, height, width) + image = image[:, permutation, :, :] + return image.reshape(shape) + + +@_register_kernel_internal(permute_channels, PIL.Image.Image) +def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image: + return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation)) + + +@_register_kernel_internal(permute_channels, tv_tensors.Video) +def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: + return permute_channels_image(video, permutation=permutation) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_deprecated.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..116ea31587a1d7f7172267898152e0167531f303 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_deprecated.py @@ -0,0 +1,24 @@ +import warnings +from typing import Any, List + +import torch + +from torchvision.transforms import functional as _F + + +@torch.jit.unused +def to_tensor(inpt: Any) -> torch.Tensor: + """[DEPREACTED] Use to_image() and to_dtype() instead.""" + warnings.warn( + "The function `to_tensor(...)` is deprecated and will be removed in a future release. " + "Instead, please use `to_image(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`." + ) + return _F.to_tensor(inpt) + + +def get_image_size(inpt: torch.Tensor) -> List[int]: + warnings.warn( + "The function `get_image_size(...)` is deprecated and will be removed in a future release. " + "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`." 
+ ) + return _F.get_image_size(inpt) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_geometry.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..da080e437c99e0b173c496a9097bbae6ae020d7c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_geometry.py @@ -0,0 +1,2377 @@ +import math +import numbers +import warnings +from typing import Any, List, Optional, Sequence, Tuple, Union + +import PIL.Image +import torch +from torch.nn.functional import grid_sample, interpolate, pad as torch_pad + +from torchvision import tv_tensors +from torchvision.transforms import _functional_pil as _FP +from torchvision.transforms._functional_tensor import _pad_symmetric +from torchvision.transforms.functional import ( + _compute_resized_output_size as __compute_resized_output_size, + _get_perspective_coeffs, + _interpolation_modes_from_int, + InterpolationMode, + pil_modes_mapping, + pil_to_tensor, + to_pil_image, +) + +from torchvision.utils import _log_api_usage_once + +from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format + +from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal + + +def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise ValueError( + f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, " + f"but got {interpolation}." + ) + return interpolation + + +def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details.""" + if torch.jit.is_scripting(): + return horizontal_flip_image(inpt) + + _log_api_usage_once(horizontal_flip) + + kernel = _get_kernel(horizontal_flip, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(horizontal_flip, torch.Tensor) +@_register_kernel_internal(horizontal_flip, tv_tensors.Image) +def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: + return image.flip(-1) + + +@_register_kernel_internal(horizontal_flip, PIL.Image.Image) +def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: + return _FP.hflip(image) + + +@_register_kernel_internal(horizontal_flip, tv_tensors.Mask) +def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: + return horizontal_flip_image(mask) + + +def horizontal_flip_bounding_boxes( + bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_boxes.shape + + bounding_boxes = bounding_boxes.clone().reshape(-1, 4) + + if format == tv_tensors.BoundingBoxFormat.XYXY: + bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_() + elif format == tv_tensors.BoundingBoxFormat.XYWH: + bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_() + else: # format == tv_tensors.BoundingBoxFormat.CXCYWH: + bounding_boxes[:, 0].sub_(canvas_size[1]).neg_() + + return bounding_boxes.reshape(shape) + + +@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes: + output 
= horizontal_flip_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size + ) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(horizontal_flip, tv_tensors.Video) +def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: + return horizontal_flip_image(video) + + +def vertical_flip(inpt: torch.Tensor) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details.""" + if torch.jit.is_scripting(): + return vertical_flip_image(inpt) + + _log_api_usage_once(vertical_flip) + + kernel = _get_kernel(vertical_flip, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(vertical_flip, torch.Tensor) +@_register_kernel_internal(vertical_flip, tv_tensors.Image) +def vertical_flip_image(image: torch.Tensor) -> torch.Tensor: + return image.flip(-2) + + +@_register_kernel_internal(vertical_flip, PIL.Image.Image) +def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: + return _FP.vflip(image) + + +@_register_kernel_internal(vertical_flip, tv_tensors.Mask) +def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: + return vertical_flip_image(mask) + + +def vertical_flip_bounding_boxes( + bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_boxes.shape + + bounding_boxes = bounding_boxes.clone().reshape(-1, 4) + + if format == tv_tensors.BoundingBoxFormat.XYXY: + bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_() + elif format == tv_tensors.BoundingBoxFormat.XYWH: + bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_() + else: # format == tv_tensors.BoundingBoxFormat.CXCYWH: + bounding_boxes[:, 1].sub_(canvas_size[0]).neg_() + + return bounding_boxes.reshape(shape) + + +@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes: + output = vertical_flip_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size + ) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(vertical_flip, tv_tensors.Video) +def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: + return vertical_flip_image(video) + + +# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are +# prevalent and well understood. Thus, we just alias them without deprecating the old names. +hflip = horizontal_flip +vflip = vertical_flip + + +def _compute_resized_output_size( + canvas_size: Tuple[int, int], size: Optional[List[int]], max_size: Optional[int] = None +) -> List[int]: + if isinstance(size, int): + size = [size] + elif max_size is not None and size is not None and len(size) != 1: + raise ValueError( + "max_size should only be passed if size is None or specifies the length of the smaller edge, " + "i.e. size should be an int or a sequence of length 1 in torchscript mode." 
+ ) + return __compute_resized_output_size(canvas_size, size=size, max_size=max_size, allow_size_none=True) + + +def resize( + inpt: torch.Tensor, + size: Optional[List[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.Resize` for details.""" + if torch.jit.is_scripting(): + return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + _log_api_usage_once(resize) + + kernel = _get_kernel(resize, type(inpt)) + return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + +# This is an internal helper method for resize_image. We should put it here instead of keeping it +# inside resize_image due to torchscript. +# uint8 dtype support for bilinear and bicubic is limited to cpu and +# according to our benchmarks on eager, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear +def _do_native_uint8_resize_on_cpu(interpolation: InterpolationMode) -> bool: + if interpolation == InterpolationMode.BILINEAR: + if torch.compiler.is_compiling(): + return True + else: + return "AVX2" in torch.backends.cpu.get_cpu_capability() + + return interpolation == InterpolationMode.BICUBIC + + +@_register_kernel_internal(resize, torch.Tensor) +@_register_kernel_internal(resize, tv_tensors.Image) +def resize_image( + image: torch.Tensor, + size: Optional[List[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> torch.Tensor: + interpolation = _check_interpolation(interpolation) + antialias = False if antialias is None else antialias + align_corners: Optional[bool] = None + if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC: + align_corners = False + else: + # The default of antialias is True from 0.17, so we don't warn or + # error if other interpolation modes are used. This is documented. + antialias = False + + shape = image.shape + numel = image.numel() + num_channels, old_height, old_width = shape[-3:] + new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) + + if (new_height, new_width) == (old_height, old_width): + return image + elif numel > 0: + dtype = image.dtype + acceptable_dtypes = [torch.float32, torch.float64] + if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT: + # uint8 dtype can be included for cpu and cuda input if nearest mode + acceptable_dtypes.append(torch.uint8) + elif image.device.type == "cpu": + if _do_native_uint8_resize_on_cpu(interpolation): + acceptable_dtypes.append(torch.uint8) + + image = image.reshape(-1, num_channels, old_height, old_width) + strides = image.stride() + if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]: + # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as + # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430). + # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim + # to become 1CHW. 
Below, we restride those tensors to trick torch core into properly allocating the output as + # channels_last, thus preserving the memory format of the input. This is not just for format consistency: + # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time. + # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373), + # we should be able to remove this hack. + new_strides = list(strides) + new_strides[0] = numel + image = image.as_strided((1, num_channels, old_height, old_width), new_strides) + + need_cast = dtype not in acceptable_dtypes + if need_cast: + image = image.to(dtype=torch.float32) + + image = interpolate( + image, + size=[new_height, new_width], + mode=interpolation.value, + align_corners=align_corners, + antialias=antialias, + ) + + if need_cast: + if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8: + # This path is hit on non-AVX archs, or on GPU. + image = image.clamp_(min=0, max=255) + if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + image = image.round_() + image = image.to(dtype=dtype) + + return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) + + +def _resize_image_pil( + image: PIL.Image.Image, + size: Union[Sequence[int], int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, +) -> PIL.Image.Image: + old_height, old_width = image.height, image.width + new_height, new_width = _compute_resized_output_size( + (old_height, old_width), + size=size, # type: ignore[arg-type] + max_size=max_size, + ) + + interpolation = _check_interpolation(interpolation) + + if (new_height, new_width) == (old_height, old_width): + return image + + return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation]) + + +@_register_kernel_internal(resize, PIL.Image.Image) +def __resize_image_pil_dispatch( + image: PIL.Image.Image, + size: Union[Sequence[int], int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> PIL.Image.Image: + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") + return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) + + +def resize_mask(mask: torch.Tensor, size: Optional[List[int]], max_size: Optional[int] = None) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False) +def _resize_mask_dispatch( + inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any +) -> tv_tensors.Mask: + output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size) + return tv_tensors.wrap(output, like=inpt) + + +def resize_bounding_boxes( + bounding_boxes: torch.Tensor, + canvas_size: Tuple[int, int], + size: Optional[List[int]], + max_size: Optional[int] = None, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + old_height, old_width = canvas_size + new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size) + + if (new_height, new_width) == (old_height, old_width): + return bounding_boxes, canvas_size + + w_ratio = new_width / old_width + h_ratio = new_height / old_height + ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device) + return ( + bounding_boxes.mul(ratios).to(bounding_boxes.dtype), + (new_height, new_width), + ) + + +@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _resize_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, size: Optional[List[int]], max_size: Optional[int] = None, **kwargs: Any +) -> tv_tensors.BoundingBoxes: + output, canvas_size = resize_bounding_boxes( + inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + +@_register_kernel_internal(resize, tv_tensors.Video) +def resize_video( + video: torch.Tensor, + size: Optional[List[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> torch.Tensor: + return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + +def affine( + inpt: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: _FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomAffine` for details.""" + if torch.jit.is_scripting(): + return affine_image( + inpt, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + + _log_api_usage_once(affine) + + kernel = _get_kernel(affine, type(inpt)) + return kernel( + inpt, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + + +def _affine_parse_args( + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + center: Optional[List[float]] = None, +) -> Tuple[float, List[float], List[float], Optional[List[float]]]: + if not isinstance(angle, (int, 
float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError("Shear should be either a single value or a sequence of two values") + + if not isinstance(interpolation, InterpolationMode): + raise TypeError("Argument interpolation should be a InterpolationMode") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") + + if center is not None: + if not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + else: + center = [float(c) for c in center] + + return angle, translate, shear, center + + +def _get_inverse_affine_matrix( + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True +) -> List[float]: + # Helper method to compute inverse matrix for affine transformation + + # Pillow requires inverse affine transformation matrix: + # Affine matrix is : M = T * C * RotateScaleShear * C^-1 + # + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RotateScaleShear is rotation with scale and shear matrix + # + # RotateScaleShear(a, s, (sx, sy)) = + # = R(a) * S(s) * SHy(sy) * SHx(sx) + # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ] + # [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ] + # [ 0 , 0 , 1 ] + # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: + # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] + # [0, 1 ] [-tan(s), 1] + # + # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1 + + rot = math.radians(angle) + sx = math.radians(shear[0]) + sy = math.radians(shear[1]) + + cx, cy = center + tx, ty = translate + + # Cached results + cos_sy = math.cos(sy) + tan_sx = math.tan(sx) + rot_minus_sy = rot - sy + cx_plus_tx = cx + tx + cy_plus_ty = cy + ty + + # Rotate Scale Shear (RSS) without scaling + a = math.cos(rot_minus_sy) / cos_sy + b = -(a * tan_sx + math.sin(rot)) + c = math.sin(rot_minus_sy) / cos_sy + d = math.cos(rot) - c * tan_sx + + if inverted: + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0] + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + # and then apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty + matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty + else: + matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0] + # Apply inverse of center translation: RSS * C^-1 + # and then apply translation and center : T * C * RSS * C^-1 + matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy + matrix[5] += cy_plus_ty - 
matrix[3] * cx - matrix[4] * cy + + return matrix + + +def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: + if torch.compiler.is_compiling() and not torch.jit.is_scripting(): + return _compute_affine_output_size_python(matrix, w, h) + else: + return _compute_affine_output_size_tensor(matrix, w, h) + + +def _compute_affine_output_size_tensor(matrix: List[float], w: int, h: int) -> Tuple[int, int]: + # Inspired of PIL implementation: + # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 + + # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + # Points are shifted due to affine matrix torch convention about + # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5) + half_w = 0.5 * w + half_h = 0.5 * h + pts = torch.tensor( + [ + [-half_w, -half_h, 1.0], + [-half_w, half_h, 1.0], + [half_w, half_h, 1.0], + [half_w, -half_h, 1.0], + ] + ) + theta = torch.tensor(matrix, dtype=torch.float).view(2, 3) + new_pts = torch.matmul(pts, theta.T) + min_vals, max_vals = new_pts.aminmax(dim=0) + + # shift points to [0, w] and [0, h] interval to match PIL results + halfs = torch.tensor((half_w, half_h)) + min_vals.add_(halfs) + max_vals.add_(halfs) + + # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0 + tol = 1e-4 + inv_tol = 1.0 / tol + cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_() + cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_() + size = cmax.sub_(cmin) + return int(size[0]), int(size[1]) # w, h + + +def _compute_affine_output_size_python(matrix: List[float], w: int, h: int) -> Tuple[int, int]: + # Mostly copied from PIL implementation: + # The only difference is with transformed points as input matrix has zero translation part here and + # PIL has a centered translation part. 
+ # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 + + a, b, c, d, e, f = matrix + xx = [] + yy = [] + + half_w = 0.5 * w + half_h = 0.5 * h + for x, y in ((-half_w, -half_h), (half_w, -half_h), (half_w, half_h), (-half_w, half_h)): + nx = a * x + b * y + c + ny = d * x + e * y + f + xx.append(nx + half_w) + yy.append(ny + half_h) + + nw = math.ceil(max(xx)) - math.floor(min(xx)) + nh = math.ceil(max(yy)) - math.floor(min(yy)) + return int(nw), int(nh) # w, h + + +def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor: + input_shape = img.shape + output_height, output_width = grid.shape[1], grid.shape[2] + num_channels, input_height, input_width = input_shape[-3:] + output_shape = input_shape[:-3] + (num_channels, output_height, output_width) + + if img.numel() == 0: + return img.reshape(output_shape) + + img = img.reshape(-1, num_channels, input_height, input_width) + squashed_batch_size = img.shape[0] + + # We are using context knowledge that grid should have float dtype + fp = img.dtype == grid.dtype + float_img = img if fp else img.to(grid.dtype) + + if squashed_batch_size > 1: + # Apply same grid to a batch of images + grid = grid.expand(squashed_batch_size, -1, -1, -1) + + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice + if fill is not None: + mask = torch.ones( + (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device + ) + float_img = torch.cat((float_img, mask), dim=1) + + float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False) + + # Fill with required color + if fill is not None: + float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3) + mask = mask.expand_as(float_img) + fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type] + fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1) + if mode == "nearest": + float_img = torch.where(mask < 0.5, fill_img.expand_as(float_img), float_img) + else: # 'bilinear' + # The following is mathematically equivalent to: + # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill + float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img) + + img = float_img.round_().to(img.dtype) if not fp else float_img + + return img.reshape(output_shape) + + +def _assert_grid_transform_inputs( + image: torch.Tensor, + matrix: Optional[List[float]], + interpolation: str, + fill: _FillTypeJIT, + supported_interpolation_modes: List[str], + coeffs: Optional[List[float]] = None, +) -> None: + if matrix is not None: + if not isinstance(matrix, list): + raise TypeError("Argument matrix should be a list") + elif len(matrix) != 6: + raise ValueError("Argument matrix should have 6 float values") + + if coeffs is not None and len(coeffs) != 8: + raise ValueError("Argument coeffs should have 8 float values") + + if fill is not None: + if isinstance(fill, (tuple, list)): + length = len(fill) + num_channels = image.shape[-3] + if length > 1 and length != num_channels: + raise ValueError( + "The number of elements in 'fill' cannot broadcast to match the number of " + f"channels of the image ({length} != {num_channels})" + ) + elif not isinstance(fill, (int, float)): + raise ValueError("Argument fill should be either int, float, tuple or list") + + if interpolation not in 
supported_interpolation_modes: + raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input") + + +def _affine_grid( + theta: torch.Tensor, + w: int, + h: int, + ow: int, + oh: int, +) -> torch.Tensor: + # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ + # AffineGridGenerator.cpp#L18 + # Difference with AffineGridGenerator is that: + # 1) we normalize grid values after applying theta + # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate + dtype = theta.dtype + device = theta.device + + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device)) + output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta) + return output_grid.view(1, oh, ow, 2) + + +@_register_kernel_internal(affine, torch.Tensor) +@_register_kernel_internal(affine, tv_tensors.Image) +def affine_image( + image: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: _FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + interpolation = _check_interpolation(interpolation) + + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + height, width = image.shape[-2:] + + center_f = [0.0, 0.0] + if center is not None: + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
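+        # Illustrative example (assumed values): for a 100x50 (width x height) image,
+        # center=(50.0, 25.0) yields center_f=[0.0, 0.0], while center=(60.0, 25.0) yields center_f=[10.0, 0.0].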
+ center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])] + + translate_f = [float(t) for t in translate] + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) + + _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) + + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) + grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height) + return _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + +@_register_kernel_internal(affine, PIL.Image.Image) +def _affine_image_pil( + image: PIL.Image.Image, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: _FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> PIL.Image.Image: + interpolation = _check_interpolation(interpolation) + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # it is visually better to estimate the center without 0.5 offset + # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine + if center is None: + height, width = _get_size_image_pil(image) + center = [width * 0.5, height * 0.5] + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + + return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) + + +def _affine_bounding_boxes_with_expand( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, + expand: bool = False, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + if bounding_boxes.numel() == 0: + return bounding_boxes, canvas_size + + original_shape = bounding_boxes.shape + original_dtype = bounding_boxes.dtype + bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + dtype = bounding_boxes.dtype + device = bounding_boxes.device + bounding_boxes = ( + convert_bounding_box_format( + bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True + ) + ).reshape(-1, 4) + + angle, translate, shear, center = _affine_parse_args( + angle, translate, scale, shear, InterpolationMode.NEAREST, center + ) + + if center is None: + height, width = canvas_size + center = [width * 0.5, height * 0.5] + + affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False) + transposed_affine_matrix = ( + torch.tensor( + affine_vector, + dtype=dtype, + device=device, + ) + .reshape(2, 3) + .T + ) + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). 
+ # Tensor of points has shape (N * 4, 3), where N is the number of bboxes + # Single point structure is similar to + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] + points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1) + # 2) Now let's transform the points using affine matrix + transformed_points = torch.matmul(points, transposed_affine_matrix) + # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] + # and compute bounding box from 4 transformed points: + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + + if expand: + # Compute minimum point for transformed image frame: + # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. + height, width = canvas_size + points = torch.tensor( + [ + [0.0, 0.0, 1.0], + [0.0, float(height), 1.0], + [float(width), float(height), 1.0], + [float(width), 0.0, 1.0], + ], + dtype=dtype, + device=device, + ) + new_points = torch.matmul(points, transposed_affine_matrix) + tr = torch.amin(new_points, dim=0, keepdim=True) + # Translate bounding boxes + out_bboxes.sub_(tr.repeat((1, 2))) + # Estimate meta-data for image with inverted=True + affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + new_width, new_height = _compute_affine_output_size(affine_vector, width, height) + canvas_size = (new_height, new_width) + + out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size) + out_bboxes = convert_bounding_box_format( + out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True + ).reshape(original_shape) + + out_bboxes = out_bboxes.to(original_dtype) + return out_bboxes, canvas_size + + +def affine_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, +) -> torch.Tensor: + out_box, _ = _affine_bounding_boxes_with_expand( + bounding_boxes, + format=format, + canvas_size=canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + expand=False, + ) + return out_box + + +@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _affine_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, + **kwargs, +) -> tv_tensors.BoundingBoxes: + output = affine_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return tv_tensors.wrap(output, like=inpt) + + +def affine_mask( + mask: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + fill: _FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = affine_image( + mask, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + 
interpolation=InterpolationMode.NEAREST, + fill=fill, + center=center, + ) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False) +def _affine_mask_dispatch( + inpt: tv_tensors.Mask, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + fill: _FillTypeJIT = None, + center: Optional[List[float]] = None, + **kwargs, +) -> tv_tensors.Mask: + output = affine_mask( + inpt.as_subclass(torch.Tensor), + angle=angle, + translate=translate, + scale=scale, + shear=shear, + fill=fill, + center=center, + ) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(affine, tv_tensors.Video) +def affine_video( + video: torch.Tensor, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: _FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + return affine_image( + video, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + + +def rotate( + inpt: torch.Tensor, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomRotation` for details.""" + if torch.jit.is_scripting(): + return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + _log_api_usage_once(rotate) + + kernel = _get_kernel(rotate, type(inpt)) + return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +@_register_kernel_internal(rotate, torch.Tensor) +@_register_kernel_internal(rotate, tv_tensors.Image) +def rotate_image( + image: torch.Tensor, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + angle = angle % 360 # shift angle to [0, 360) range + + # fast path: transpose without affine transform + if center is None: + if angle == 0: + return image.clone() + if angle == 180: + return torch.rot90(image, k=2, dims=(-2, -1)) + + if expand or image.shape[-1] == image.shape[-2]: + if angle == 90: + return torch.rot90(image, k=1, dims=(-2, -1)) + if angle == 270: + return torch.rot90(image, k=3, dims=(-2, -1)) + + interpolation = _check_interpolation(interpolation) + + input_height, input_width = image.shape[-2:] + + center_f = [0.0, 0.0] + if center is not None: + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])] + + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. 
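+    # In other words, -angle is combined below with zero translation, unit scale and zero shear.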
+ matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + + _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"]) + + output_width, output_height = ( + _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height) + ) + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3) + grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height) + return _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + +@_register_kernel_internal(rotate, PIL.Image.Image) +def _rotate_image_pil( + image: PIL.Image.Image, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: _FillTypeJIT = None, +) -> PIL.Image.Image: + interpolation = _check_interpolation(interpolation) + + return _FP.rotate( + image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center + ) + + +def rotate_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + return _affine_bounding_boxes_with_expand( + bounding_boxes, + format=format, + canvas_size=canvas_size, + angle=-angle, + translate=[0.0, 0.0], + scale=1.0, + shear=[0.0, 0.0], + center=center, + expand=expand, + ) + + +@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _rotate_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs +) -> tv_tensors.BoundingBoxes: + output, canvas_size = rotate_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + angle=angle, + expand=expand, + center=center, + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + +def rotate_mask( + mask: torch.Tensor, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = rotate_image( + mask, + angle=angle, + expand=expand, + interpolation=InterpolationMode.NEAREST, + fill=fill, + center=center, + ) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False) +def _rotate_mask_dispatch( + inpt: tv_tensors.Mask, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, + fill: _FillTypeJIT = None, + **kwargs, +) -> tv_tensors.Mask: + output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(rotate, tv_tensors.Video) +def rotate_video( + video: torch.Tensor, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[float]] = None, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +def pad( + inpt: torch.Tensor, + padding: 
List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.Pad` for details.""" + if torch.jit.is_scripting(): + return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode) + + _log_api_usage_once(pad) + + kernel = _get_kernel(pad, type(inpt)) + return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode) + + +def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif isinstance(padding, (tuple, list)): + if len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + elif len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + else: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) + else: + raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}") + + return [pad_left, pad_right, pad_top, pad_bottom] + + +@_register_kernel_internal(pad, torch.Tensor) +@_register_kernel_internal(pad, tv_tensors.Image) +def pad_image( + image: torch.Tensor, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses + # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad` + # internally. + torch_padding = _parse_pad_padding(padding) + + if padding_mode not in ("constant", "edge", "reflect", "symmetric"): + raise ValueError( + f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, " + f"but got `'{padding_mode}'`." + ) + + if fill is None: + fill = 0 + + if isinstance(fill, (int, float)): + return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) + elif len(fill) == 1: + return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode) + else: + return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) + + +def _pad_with_scalar_fill( + image: torch.Tensor, + torch_padding: List[int], + fill: Union[int, float], + padding_mode: str, +) -> torch.Tensor: + shape = image.shape + num_channels, height, width = shape[-3:] + + batch_size = 1 + for s in shape[:-3]: + batch_size *= s + + image = image.reshape(batch_size, num_channels, height, width) + + if padding_mode == "edge": + # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map + # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad` + # name. + padding_mode = "replicate" + + if padding_mode == "constant": + image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill)) + elif padding_mode in ("reflect", "replicate"): + # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs. 
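+        # (Integer images are therefore temporarily cast to float32 below and cast back after padding.)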
+ # TODO: See https://github.com/pytorch/pytorch/issues/40763 + dtype = image.dtype + if not image.is_floating_point(): + needs_cast = True + image = image.to(torch.float32) + else: + needs_cast = False + + image = torch_pad(image, torch_padding, mode=padding_mode) + + if needs_cast: + image = image.to(dtype) + else: # padding_mode == "symmetric" + image = _pad_symmetric(image, torch_padding) + + new_height, new_width = image.shape[-2:] + + return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) + + +# TODO: This should be removed once torch_pad supports non-scalar padding values +def _pad_with_vector_fill( + image: torch.Tensor, + torch_padding: List[int], + fill: List[float], + padding_mode: str, +) -> torch.Tensor: + if padding_mode != "constant": + raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") + + output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + left, right, top, bottom = torch_padding + + # We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an implicit + # float -> int conversion. That happens for example for the valid input of a uint8 image with floating point fill + # value. + fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1) + + if top > 0: + output[..., :top, :] = fill + if left > 0: + output[..., :, :left] = fill + if bottom > 0: + output[..., -bottom:, :] = fill + if right > 0: + output[..., :, -right:] = fill + return output + + +_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad) + + +@_register_kernel_internal(pad, tv_tensors.Mask) +def pad_mask( + mask: torch.Tensor, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + if fill is None: + fill = 0 + + if isinstance(fill, (tuple, list)): + raise ValueError("Non-scalar fill value is not supported") + + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +def pad_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + padding: List[int], + padding_mode: str = "constant", +) -> Tuple[torch.Tensor, Tuple[int, int]]: + if padding_mode not in ["constant"]: + # TODO: add support of other padding modes + raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") + + left, right, top, bottom = _parse_pad_padding(padding) + + if format == tv_tensors.BoundingBoxFormat.XYXY: + pad = [left, top, left, top] + else: + pad = [left, top, 0, 0] + bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device) + + height, width = canvas_size + height += top + bottom + width += left + right + canvas_size = (height, width) + + return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size + + +@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _pad_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs +) -> tv_tensors.BoundingBoxes: + output, canvas_size = pad_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + padding=padding, + 
padding_mode=padding_mode, + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + +@_register_kernel_internal(pad, tv_tensors.Video) +def pad_video( + video: torch.Tensor, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + return pad_image(video, padding, fill=fill, padding_mode=padding_mode) + + +def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomCrop` for details.""" + if torch.jit.is_scripting(): + return crop_image(inpt, top=top, left=left, height=height, width=width) + + _log_api_usage_once(crop) + + kernel = _get_kernel(crop, type(inpt)) + return kernel(inpt, top=top, left=left, height=height, width=width) + + +@_register_kernel_internal(crop, torch.Tensor) +@_register_kernel_internal(crop, tv_tensors.Image) +def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + h, w = image.shape[-2:] + + right = left + width + bottom = top + height + + if left < 0 or top < 0 or right > w or bottom > h: + image = image[..., max(top, 0) : bottom, max(left, 0) : right] + torch_padding = [ + max(min(right, 0) - left, 0), + max(right - max(w, left), 0), + max(min(bottom, 0) - top, 0), + max(bottom - max(h, top), 0), + ] + return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") + return image[..., top:bottom, left:right] + + +_crop_image_pil = _FP.crop +_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil) + + +def crop_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + top: int, + left: int, + height: int, + width: int, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + + # Crop or implicit pad if left and/or top have negative values: + if format == tv_tensors.BoundingBoxFormat.XYXY: + sub = [left, top, left, top] + else: + sub = [left, top, 0, 0] + + bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device) + canvas_size = (height, width) + + return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size + + +@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _crop_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int +) -> tv_tensors.BoundingBoxes: + output, canvas_size = crop_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + +@_register_kernel_internal(crop, tv_tensors.Mask) +def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = crop_image(mask, top, left, height, width) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(crop, tv_tensors.Video) +def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: + return crop_image(video, top, left, height, width) + + +def perspective( + inpt: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> 
torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomPerspective` for details.""" + if torch.jit.is_scripting(): + return perspective_image( + inpt, + startpoints=startpoints, + endpoints=endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, + ) + + _log_api_usage_once(perspective) + + kernel = _get_kernel(perspective, type(inpt)) + return kernel( + inpt, + startpoints=startpoints, + endpoints=endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, + ) + + +def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ + # src/libImaging/Geometry.c#L394 + + # + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) + + d = 0.5 + base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) + x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype) + base_grid[..., 0].copy_(x_grid) + y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + base_grid[..., 2].fill_(1) + + rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)) + shape = (1, oh * ow, 3) + output_grid1 = base_grid.view(shape).bmm(rescaled_theta1) + output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2)) + + output_grid = output_grid1.div_(output_grid2).sub_(1.0) + return output_grid.view(1, oh, ow, 2) + + +def _perspective_coefficients( + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + coefficients: Optional[List[float]], +) -> List[float]: + if coefficients is not None: + if startpoints is not None and endpoints is not None: + raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.") + elif len(coefficients) != 8: + raise ValueError("Argument coefficients should have 8 float values") + return coefficients + elif startpoints is not None and endpoints is not None: + return _get_perspective_coeffs(startpoints, endpoints) + else: + raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.") + + +@_register_kernel_internal(perspective, torch.Tensor) +@_register_kernel_internal(perspective, tv_tensors.Image) +def perspective_image( + image: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + interpolation = _check_interpolation(interpolation) + + _assert_grid_transform_inputs( + image, + matrix=None, + interpolation=interpolation.value, + fill=fill, + supported_interpolation_modes=["nearest", "bilinear"], + coeffs=perspective_coeffs, + ) + + oh, ow = image.shape[-2:] + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + grid = 
_perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device) + return _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + +@_register_kernel_internal(perspective, PIL.Image.Image) +def _perspective_image_pil( + image: PIL.Image.Image, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> PIL.Image.Image: + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + interpolation = _check_interpolation(interpolation) + return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) + + +def perspective_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + if bounding_boxes.numel() == 0: + return bounding_boxes + + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + + original_shape = bounding_boxes.shape + # TODO: first cast to float if bbox is int64 before convert_bounding_box_format + bounding_boxes = ( + convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY) + ).reshape(-1, 4) + + dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 + device = bounding_boxes.device + + # perspective_coeffs are computed as endpoint -> start point + # We have to invert perspective_coeffs for bboxes: + # (x, y) - end point and (x_out, y_out) - start point + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + # and we would like to get: + # x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2]) + # / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1) + # y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5]) + # / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1) + # and compute inv_coeffs in terms of coeffs + + denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3] + if denom == 0: + raise RuntimeError( + f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. 
" + f"Denominator is zero, denom={denom}" + ) + + inv_coeffs = [ + (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom, + (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom, + (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom, + (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom, + (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, + ] + + theta1 = torch.tensor( + [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], + dtype=dtype, + device=device, + ) + + theta2 = torch.tensor( + [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device + ) + + # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). + # Tensor of points has shape (N * 4, 3), where N is the number of bboxes + # Single point structure is similar to + # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)] + points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1) + # 2) Now let's transform the points using perspective matrices + # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) + # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) + + numer_points = torch.matmul(points, theta1.T) + denom_points = torch.matmul(points, theta2.T) + transformed_points = numer_points.div_(denom_points) + + # 3) Reshape transformed points to [N boxes, 4 points, x/y coords] + # and compute bounding box from 4 transformed points: + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + + out_bboxes = clamp_bounding_boxes( + torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=canvas_size, + ) + + # out_bboxes should be of shape [N boxes, 4] + + return convert_bounding_box_format( + out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True + ).reshape(original_shape) + + +@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _perspective_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + coefficients: Optional[List[float]] = None, + **kwargs, +) -> tv_tensors.BoundingBoxes: + output = perspective_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + startpoints=startpoints, + endpoints=endpoints, + coefficients=coefficients, + ) + return tv_tensors.wrap(output, like=inpt) + + +def perspective_mask( + mask: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + fill: _FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + 
needs_squeeze = True + else: + needs_squeeze = False + + output = perspective_image( + mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients + ) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False) +def _perspective_mask_dispatch( + inpt: tv_tensors.Mask, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + fill: _FillTypeJIT = None, + coefficients: Optional[List[float]] = None, + **kwargs, +) -> tv_tensors.Mask: + output = perspective_mask( + inpt.as_subclass(torch.Tensor), + startpoints=startpoints, + endpoints=endpoints, + fill=fill, + coefficients=coefficients, + ) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(perspective, tv_tensors.Video) +def perspective_video( + video: torch.Tensor, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, + coefficients: Optional[List[float]] = None, +) -> torch.Tensor: + return perspective_image( + video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) + + +def elastic( + inpt: torch.Tensor, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.ElasticTransform` for details.""" + if torch.jit.is_scripting(): + return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + + _log_api_usage_once(elastic) + + kernel = _get_kernel(elastic, type(inpt)) + return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + + +elastic_transform = elastic + + +@_register_kernel_internal(elastic, torch.Tensor) +@_register_kernel_internal(elastic, tv_tensors.Image) +def elastic_image( + image: torch.Tensor, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + + interpolation = _check_interpolation(interpolation) + + height, width = image.shape[-2:] + device = image.device + dtype = image.dtype if torch.is_floating_point(image) else torch.float32 + + # Patch: elastic transform should support (cpu,f16) input + is_cpu_half = device.type == "cpu" and dtype == torch.float16 + if is_cpu_half: + image = image.to(torch.float32) + dtype = torch.float32 + + # We are aware that if input image dtype is uint8 and displacement is float64 then + # displacement will be cast to float32 and all computations will be done with float32 + # We can fix this later if needed + + expected_shape = (1, height, width, 2) + if expected_shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + + grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_( + displacement.to(dtype=dtype, device=device) + ) + output = _apply_grid_transform(image, grid, interpolation.value, fill=fill) + + if is_cpu_half: + output = output.to(torch.float16) + + return output + + +@_register_kernel_internal(elastic, PIL.Image.Image) +def _elastic_image_pil( + image: PIL.Image.Image, + displacement: torch.Tensor, + 
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, +) -> PIL.Image.Image: + t_img = pil_to_tensor(image) + output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill) + return to_pil_image(output, mode=image.mode) + + +def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor: + sy, sx = size + base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype) + x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype) + base_grid[..., 0].copy_(x_grid) + + y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1) + base_grid[..., 1].copy_(y_grid) + + return base_grid + + +def elastic_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + displacement: torch.Tensor, +) -> torch.Tensor: + expected_shape = (1, canvas_size[0], canvas_size[1], 2) + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + elif displacement.shape != expected_shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + + if bounding_boxes.numel() == 0: + return bounding_boxes + + # TODO: add in docstring about approximation we are doing for grid inversion + device = bounding_boxes.device + dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32 + + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) + + original_shape = bounding_boxes.shape + # TODO: first cast to float if bbox is int64 before convert_bounding_box_format + bounding_boxes = ( + convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY) + ).reshape(-1, 4) + + id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) + # We construct an approximation of inverse grid as inv_grid = id_grid - displacement + # This is not an exact inverse of the grid + inv_grid = id_grid.sub_(displacement) + + # Get points from bboxes + points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2) + if points.is_floating_point(): + points = points.ceil_() + index_xy = points.to(dtype=torch.long) + index_x, index_y = index_xy[:, 0], index_xy[:, 1] + + # Transform points: + t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) + + transformed_points = transformed_points.reshape(-1, 4, 2) + out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) + out_bboxes = clamp_bounding_boxes( + torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype), + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=canvas_size, + ) + + return convert_bounding_box_format( + out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True + ).reshape(original_shape) + + +@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _elastic_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs +) -> tv_tensors.BoundingBoxes: + output = elastic_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement + ) + return 
tv_tensors.wrap(output, like=inpt) + + +def elastic_mask( + mask: torch.Tensor, + displacement: torch.Tensor, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False) +def _elastic_mask_dispatch( + inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs +) -> tv_tensors.Mask: + output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(elastic, tv_tensors.Video) +def elastic_video( + video: torch.Tensor, + displacement: torch.Tensor, + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, +) -> torch.Tensor: + return elastic_image(video, displacement, interpolation=interpolation, fill=fill) + + +def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomCrop` for details.""" + if torch.jit.is_scripting(): + return center_crop_image(inpt, output_size=output_size) + + _log_api_usage_once(center_crop) + + kernel = _get_kernel(center_crop, type(inpt)) + return kernel(inpt, output_size=output_size) + + +def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: + if isinstance(output_size, numbers.Number): + s = int(output_size) + return [s, s] + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + return [output_size[0], output_size[0]] + else: + return list(output_size) + + +def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]: + return [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + + +def _center_crop_compute_crop_anchor( + crop_height: int, crop_width: int, image_height: int, image_width: int +) -> Tuple[int, int]: + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop_top, crop_left + + +@_register_kernel_internal(center_crop, torch.Tensor) +@_register_kernel_internal(center_crop, tv_tensors.Image) +def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + shape = image.shape + if image.numel() == 0: + return image.reshape(shape[:-2] + (crop_height, crop_width)) + image_height, image_width = shape[-2:] + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0) + + image_height, image_width = image.shape[-2:] + if crop_width == image_width and crop_height == image_height: + return image + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)] + + 
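+# Illustrative usage sketch for the dispatchers defined in this module (assumed example inputs):
+#
+#     import torch
+#     from torchvision.transforms.v2 import functional as F
+#
+#     img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # CHW uint8 image
+#     out = F.center_crop(img, output_size=[224, 224])   # -> shape (3, 224, 224)
+#     padded = F.pad(img, padding=[8, 16], fill=0)       # 8 px left/right, 16 px top/bottom -> (3, 512, 656)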
+@_register_kernel_internal(center_crop, PIL.Image.Image) +def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + image_height, image_width = _get_size_image_pil(image) + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + image = _pad_image_pil(image, padding_ltrb, fill=0) + + image_height, image_width = _get_size_image_pil(image) + if crop_width == image_width and crop_height == image_height: + return image + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) + + +def center_crop_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + output_size: List[int], +) -> Tuple[torch.Tensor, Tuple[int, int]]: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) + return crop_bounding_boxes( + bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width + ) + + +@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _center_crop_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, output_size: List[int] +) -> tv_tensors.BoundingBoxes: + output, canvas_size = center_crop_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + +@_register_kernel_internal(center_crop, tv_tensors.Mask) +def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: + if mask.ndim < 3: + mask = mask.unsqueeze(0) + needs_squeeze = True + else: + needs_squeeze = False + + output = center_crop_image(image=mask, output_size=output_size) + + if needs_squeeze: + output = output.squeeze(0) + + return output + + +@_register_kernel_internal(center_crop, tv_tensors.Video) +def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor: + return center_crop_image(video, output_size) + + +def resized_crop( + inpt: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details.""" + if torch.jit.is_scripting(): + return resized_crop_image( + inpt, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + antialias=antialias, + ) + + _log_api_usage_once(resized_crop) + + kernel = _get_kernel(resized_crop, type(inpt)) + return kernel( + inpt, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + antialias=antialias, + ) + + +@_register_kernel_internal(resized_crop, torch.Tensor) +@_register_kernel_internal(resized_crop, tv_tensors.Image) +def resized_crop_image( + image: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, +) -> torch.Tensor: + image = crop_image(image, top, left, 
height, width) + return resize_image(image, size, interpolation=interpolation, antialias=antialias) + + +def _resized_crop_image_pil( + image: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, +) -> PIL.Image.Image: + image = _crop_image_pil(image, top, left, height, width) + return _resize_image_pil(image, size, interpolation=interpolation) + + +@_register_kernel_internal(resized_crop, PIL.Image.Image) +def _resized_crop_image_pil_dispatch( + image: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, +) -> PIL.Image.Image: + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") + return _resized_crop_image_pil( + image, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + ) + + +def resized_crop_bounding_boxes( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + top: int, + left: int, + height: int, + width: int, + size: List[int], +) -> Tuple[torch.Tensor, Tuple[int, int]]: + bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width) + return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size) + + +@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def _resized_crop_bounding_boxes_dispatch( + inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs +) -> tv_tensors.BoundingBoxes: + output, canvas_size = resized_crop_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + +def resized_crop_mask( + mask: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], +) -> torch.Tensor: + mask = crop_mask(mask, top, left, height, width) + return resize_mask(mask, size) + + +@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False) +def _resized_crop_mask_dispatch( + inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs +) -> tv_tensors.Mask: + output = resized_crop_mask( + inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size + ) + return tv_tensors.wrap(output, like=inpt) + + +@_register_kernel_internal(resized_crop, tv_tensors.Video) +def resized_crop_video( + video: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[bool] = True, +) -> torch.Tensor: + return resized_crop_image( + video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) + + +def five_crop( + inpt: torch.Tensor, size: List[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """See :class:`~torchvision.transforms.v2.FiveCrop` for details.""" + if torch.jit.is_scripting(): + return five_crop_image(inpt, size=size) + + _log_api_usage_once(five_crop) + + kernel = _get_kernel(five_crop, type(inpt)) + return kernel(inpt, size=size) + + +def _parse_five_crop_size(size: List[int]) -> List[int]: + 
if isinstance(size, numbers.Number): + s = int(size) + size = [s, s] + elif isinstance(size, (tuple, list)) and len(size) == 1: + s = size[0] + size = [s, s] + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + return size + + +@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor) +@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image) +def five_crop_image( + image: torch.Tensor, size: List[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + crop_height, crop_width = _parse_five_crop_size(size) + image_height, image_width = image.shape[-2:] + + if crop_width > image_width or crop_height > image_height: + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") + + tl = crop_image(image, 0, 0, crop_height, crop_width) + tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width) + bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width) + br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = center_crop_image(image, [crop_height, crop_width]) + + return tl, tr, bl, br, center + + +@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image) +def _five_crop_image_pil( + image: PIL.Image.Image, size: List[int] +) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: + crop_height, crop_width = _parse_five_crop_size(size) + image_height, image_width = _get_size_image_pil(image) + + if crop_width > image_width or crop_height > image_height: + raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") + + tl = _crop_image_pil(image, 0, 0, crop_height, crop_width) + tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width) + bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width) + br = _crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = _center_crop_image_pil(image, [crop_height, crop_width]) + + return tl, tr, bl, br, center + + +@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video) +def five_crop_video( + video: torch.Tensor, size: List[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return five_crop_image(video, size) + + +def ten_crop( + inpt: torch.Tensor, size: List[int], vertical_flip: bool = False +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """See :class:`~torchvision.transforms.v2.TenCrop` for details.""" + if torch.jit.is_scripting(): + return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip) + + _log_api_usage_once(ten_crop) + + kernel = _get_kernel(ten_crop, type(inpt)) + return kernel(inpt, size=size, vertical_flip=vertical_flip) + + +@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor) +@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image) +def ten_crop_image( + image: torch.Tensor, size: List[int], vertical_flip: bool = False +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + non_flipped = five_crop_image(image, size) + + if vertical_flip: + image = 
vertical_flip_image(image) + else: + image = horizontal_flip_image(image) + + flipped = five_crop_image(image, size) + + return non_flipped + flipped + + +@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image) +def _ten_crop_image_pil( + image: PIL.Image.Image, size: List[int], vertical_flip: bool = False +) -> Tuple[ + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, + PIL.Image.Image, +]: + non_flipped = _five_crop_image_pil(image, size) + + if vertical_flip: + image = _vertical_flip_image_pil(image) + else: + image = _horizontal_flip_image_pil(image) + + flipped = _five_crop_image_pil(image, size) + + return non_flipped + flipped + + +@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video) +def ten_crop_video( + video: torch.Tensor, size: List[int], vertical_flip: bool = False +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + return ten_crop_image(video, size, vertical_flip=vertical_flip) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_meta.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..b90e5fb7b5be887d24bb42725d59bb056ac126c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_meta.py @@ -0,0 +1,279 @@ +from typing import List, Optional, Tuple + +import PIL.Image +import torch +from torchvision import tv_tensors +from torchvision.transforms import _functional_pil as _FP +from torchvision.tv_tensors import BoundingBoxFormat + +from torchvision.utils import _log_api_usage_once + +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor + + +def get_dimensions(inpt: torch.Tensor) -> List[int]: + if torch.jit.is_scripting(): + return get_dimensions_image(inpt) + + _log_api_usage_once(get_dimensions) + + kernel = _get_kernel(get_dimensions, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(get_dimensions, torch.Tensor) +@_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False) +def get_dimensions_image(image: torch.Tensor) -> List[int]: + chw = list(image.shape[-3:]) + ndims = len(chw) + if ndims == 3: + return chw + elif ndims == 2: + chw.insert(0, 1) + return chw + else: + raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") + + +_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions) + + +@_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False) +def get_dimensions_video(video: torch.Tensor) -> List[int]: + return get_dimensions_image(video) + + +def get_num_channels(inpt: torch.Tensor) -> int: + if torch.jit.is_scripting(): + return get_num_channels_image(inpt) + + _log_api_usage_once(get_num_channels) + + kernel = _get_kernel(get_num_channels, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(get_num_channels, torch.Tensor) +@_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False) +def get_num_channels_image(image: torch.Tensor) -> int: + chw = image.shape[-3:] + ndims = len(chw) + if ndims == 3: + return chw[0] + elif ndims == 2: + return 1 + else: + raise TypeError(f"Input tensor should have at least two 
dimensions, but got {ndims}") + + +_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels) + + +@_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False) +def get_num_channels_video(video: torch.Tensor) -> int: + return get_num_channels_image(video) + + +# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without +# deprecating the old names. +get_image_num_channels = get_num_channels + + +def get_size(inpt: torch.Tensor) -> List[int]: + if torch.jit.is_scripting(): + return get_size_image(inpt) + + _log_api_usage_once(get_size) + + kernel = _get_kernel(get_size, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(get_size, torch.Tensor) +@_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False) +def get_size_image(image: torch.Tensor) -> List[int]: + hw = list(image.shape[-2:]) + ndims = len(hw) + if ndims == 2: + return hw + else: + raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") + + +@_register_kernel_internal(get_size, PIL.Image.Image) +def _get_size_image_pil(image: PIL.Image.Image) -> List[int]: + width, height = _FP.get_image_size(image) + return [height, width] + + +@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) +def get_size_video(video: torch.Tensor) -> List[int]: + return get_size_image(video) + + +@_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False) +def get_size_mask(mask: torch.Tensor) -> List[int]: + return get_size_image(mask) + + +@_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]: + return list(bounding_box.canvas_size) + + +def get_num_frames(inpt: torch.Tensor) -> int: + if torch.jit.is_scripting(): + return get_num_frames_video(inpt) + + _log_api_usage_once(get_num_frames) + + kernel = _get_kernel(get_num_frames, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(get_num_frames, torch.Tensor) +@_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False) +def get_num_frames_video(video: torch.Tensor) -> int: + return video.shape[-4] + + +def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor: + xyxy = xywh if inplace else xywh.clone() + xyxy[..., 2:] += xyxy[..., :2] + return xyxy + + +def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: + xywh = xyxy if inplace else xyxy.clone() + xywh[..., 2:] -= xywh[..., :2] + return xywh + + +def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor: + if not inplace: + cxcywh = cxcywh.clone() + + # Trick to do fast division by 2 and ceil, without casting. It produces the same result as + # `torchvision.ops._box_convert._box_cxcywh_to_xyxy`. 
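+    # Illustrative note (not in the upstream file): for an integer width w = 5,
+    # w.div(-2, rounding_mode="floor") yields -3 and abs_() yields 3 == ceil(5 / 2),
+    # so the ceil-division needs no float cast; for floating point boxes,
+    # rounding_mode=None keeps the exact value w / 2.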
+ half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_() + # (cx - width / 2) = x1, same for y1 + cxcywh[..., :2].sub_(half_wh) + # (x1 + width) = x2, same for y2 + cxcywh[..., 2:].add_(cxcywh[..., :2]) + + return cxcywh + + +def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: + if not inplace: + xyxy = xyxy.clone() + + # (x2 - x1) = width, same for height + xyxy[..., 2:].sub_(xyxy[..., :2]) + # (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy + xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor") + + return xyxy + + +def _convert_bounding_box_format( + bounding_boxes: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False +) -> torch.Tensor: + + if new_format == old_format: + return bounding_boxes + + # TODO: Add _xywh_to_cxcywh and _cxcywh_to_xywh to improve performance + if old_format == BoundingBoxFormat.XYWH: + bounding_boxes = _xywh_to_xyxy(bounding_boxes, inplace) + elif old_format == BoundingBoxFormat.CXCYWH: + bounding_boxes = _cxcywh_to_xyxy(bounding_boxes, inplace) + + if new_format == BoundingBoxFormat.XYWH: + bounding_boxes = _xyxy_to_xywh(bounding_boxes, inplace) + elif new_format == BoundingBoxFormat.CXCYWH: + bounding_boxes = _xyxy_to_cxcywh(bounding_boxes, inplace) + + return bounding_boxes + + +def convert_bounding_box_format( + inpt: torch.Tensor, + old_format: Optional[BoundingBoxFormat] = None, + new_format: Optional[BoundingBoxFormat] = None, + inplace: bool = False, +) -> torch.Tensor: + """See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details.""" + # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor + # inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on + # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the + # default error that would be thrown if `new_format` had no default value. + if new_format is None: + raise TypeError("convert_bounding_box_format() missing 1 required argument: 'new_format'") + + if not torch.jit.is_scripting(): + _log_api_usage_once(convert_bounding_box_format) + + if isinstance(old_format, str): + old_format = BoundingBoxFormat[old_format.upper()] + if isinstance(new_format, str): + new_format = BoundingBoxFormat[new_format.upper()] + + if torch.jit.is_scripting() or is_pure_tensor(inpt): + if old_format is None: + raise ValueError("For pure tensor inputs, `old_format` has to be passed.") + return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace) + elif isinstance(inpt, tv_tensors.BoundingBoxes): + if old_format is not None: + raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.") + output = _convert_bounding_box_format( + inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace + ) + return tv_tensors.wrap(output, like=inpt, format=new_format) + else: + raise TypeError( + f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead." 
+ ) + + +def _clamp_bounding_boxes( + bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int] +) -> torch.Tensor: + # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every + # BoundingBoxFormat instead of converting back and forth + in_dtype = bounding_boxes.dtype + bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float() + xyxy_boxes = convert_bounding_box_format( + bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True + ) + xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1]) + xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0]) + out_boxes = convert_bounding_box_format( + xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True + ) + return out_boxes.to(in_dtype) + + +def clamp_bounding_boxes( + inpt: torch.Tensor, + format: Optional[BoundingBoxFormat] = None, + canvas_size: Optional[Tuple[int, int]] = None, +) -> torch.Tensor: + """See :func:`~torchvision.transforms.v2.ClampBoundingBoxes` for details.""" + if not torch.jit.is_scripting(): + _log_api_usage_once(clamp_bounding_boxes) + + if torch.jit.is_scripting() or is_pure_tensor(inpt): + + if format is None or canvas_size is None: + raise ValueError("For pure tensor inputs, `format` and `canvas_size` have to be passed.") + return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size) + elif isinstance(inpt, tv_tensors.BoundingBoxes): + if format is not None or canvas_size is not None: + raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.") + output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size) + return tv_tensors.wrap(output, like=inpt) + else: + raise TypeError( + f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead." 
+ ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_misc.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f40bf117753ef2af57d3e1f0f74df2274a8a05cd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_misc.py @@ -0,0 +1,420 @@ +import math +from typing import List, Optional, Tuple + +import PIL.Image +import torch +from torch.nn.functional import conv2d, pad as torch_pad + +from torchvision import tv_tensors +from torchvision.transforms._functional_tensor import _max_value +from torchvision.transforms.functional import pil_to_tensor, to_pil_image + +from torchvision.utils import _log_api_usage_once + +from ._meta import _convert_bounding_box_format + +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor + + +def normalize( + inpt: torch.Tensor, + mean: List[float], + std: List[float], + inplace: bool = False, +) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.Normalize` for details.""" + if torch.jit.is_scripting(): + return normalize_image(inpt, mean=mean, std=std, inplace=inplace) + + _log_api_usage_once(normalize) + + kernel = _get_kernel(normalize, type(inpt)) + return kernel(inpt, mean=mean, std=std, inplace=inplace) + + +@_register_kernel_internal(normalize, torch.Tensor) +@_register_kernel_internal(normalize, tv_tensors.Image) +def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: + if not image.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.") + + if image.ndim < 3: + raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). 
Got {image.shape}.") + + if isinstance(std, (tuple, list)): + divzero = not all(std) + elif isinstance(std, (int, float)): + divzero = std == 0 + else: + divzero = False + if divzero: + raise ValueError("std evaluated to zero, leading to division by zero.") + + dtype = image.dtype + device = image.device + mean = torch.as_tensor(mean, dtype=dtype, device=device) + std = torch.as_tensor(std, dtype=dtype, device=device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + + if inplace: + image = image.sub_(mean) + else: + image = image.sub(mean) + + return image.div_(std) + + +@_register_kernel_internal(normalize, tv_tensors.Video) +def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: + return normalize_image(video, mean, std, inplace=inplace) + + +def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.GaussianBlur` for details.""" + if torch.jit.is_scripting(): + return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma) + + _log_api_usage_once(gaussian_blur) + + kernel = _get_kernel(gaussian_blur, type(inpt)) + return kernel(inpt, kernel_size=kernel_size, sigma=sigma) + + +def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0)) + x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) + kernel1d = torch.softmax(x.div(sigma).pow(2).neg(), dim=0) + return kernel1d + + +def _get_gaussian_kernel2d( + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device) + kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device) + kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x + return kernel2d + + +@_register_kernel_internal(gaussian_blur, torch.Tensor) +@_register_kernel_internal(gaussian_blur, tv_tensors.Image) +def gaussian_blur_image( + image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + # TODO: consider deprecating integers from sigma on the future + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + elif len(kernel_size) != 2: + raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") + for ksize in kernel_size: + if ksize % 2 == 0 or ksize < 0: + raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}") + + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + else: + if isinstance(sigma, (list, tuple)): + length = len(sigma) + if length == 1: + s = sigma[0] + sigma = [s, s] + elif length != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}") + elif isinstance(sigma, (int, float)): + s = float(sigma) + sigma = [s, s] + else: + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") + for s in sigma: + if s <= 0.0: + raise ValueError(f"sigma should have positive values. 
Got {sigma}") + + if image.numel() == 0: + return image + + dtype = image.dtype + shape = image.shape + ndim = image.ndim + if ndim == 3: + image = image.unsqueeze(dim=0) + elif ndim > 4: + image = image.reshape((-1,) + shape[-3:]) + + fp = torch.is_floating_point(image) + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device) + kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + output = image if fp else image.to(dtype=torch.float32) + + # padding = (left, right, top, bottom) + padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] + output = torch_pad(output, padding, mode="reflect") + output = conv2d(output, kernel, groups=shape[-3]) + + if ndim == 3: + output = output.squeeze(dim=0) + elif ndim > 4: + output = output.reshape(shape) + + if not fp: + output = output.round_().to(dtype=dtype) + + return output + + +@_register_kernel_internal(gaussian_blur, PIL.Image.Image) +def _gaussian_blur_image_pil( + image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> PIL.Image.Image: + t_img = pil_to_tensor(image) + output = gaussian_blur_image(t_img, kernel_size=kernel_size, sigma=sigma) + return to_pil_image(output, mode=image.mode) + + +@_register_kernel_internal(gaussian_blur, tv_tensors.Video) +def gaussian_blur_video( + video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + return gaussian_blur_image(video, kernel_size, sigma) + + +def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.GaussianNoise`""" + if torch.jit.is_scripting(): + return gaussian_noise_image(inpt, mean=mean, sigma=sigma) + + _log_api_usage_once(gaussian_noise) + + kernel = _get_kernel(gaussian_noise, type(inpt)) + return kernel(inpt, mean=mean, sigma=sigma, clip=clip) + + +@_register_kernel_internal(gaussian_noise, torch.Tensor) +@_register_kernel_internal(gaussian_noise, tv_tensors.Image) +def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: + if not image.is_floating_point(): + raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}") + if sigma < 0: + raise ValueError(f"sigma shouldn't be negative. 
Got {sigma}") + + noise = mean + torch.randn_like(image) * sigma + out = image + noise + if clip: + out = torch.clamp(out, 0, 1) + return out + + +@_register_kernel_internal(gaussian_noise, tv_tensors.Video) +def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor: + return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip) + + +@_register_kernel_internal(gaussian_noise, PIL.Image.Image) +def _gaussian_noise_pil( + video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True +) -> PIL.Image.Image: + raise ValueError("Gaussian Noise is not implemented for PIL images.") + + +def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + """See :func:`~torchvision.transforms.v2.ToDtype` for details.""" + if torch.jit.is_scripting(): + return to_dtype_image(inpt, dtype=dtype, scale=scale) + + _log_api_usage_once(to_dtype) + + kernel = _get_kernel(to_dtype, type(inpt)) + return kernel(inpt, dtype=dtype, scale=scale) + + +def _num_value_bits(dtype: torch.dtype) -> int: + if dtype == torch.uint8: + return 8 + elif dtype == torch.int8: + return 7 + elif dtype == torch.int16: + return 15 + elif dtype == torch.uint16: + return 16 + elif dtype == torch.int32: + return 31 + elif dtype == torch.int64: + return 63 + else: + raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") + + +@_register_kernel_internal(to_dtype, torch.Tensor) +@_register_kernel_internal(to_dtype, tv_tensors.Image) +def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + + if image.dtype == dtype: + return image + elif not scale: + return image.to(dtype) + + float_input = image.is_floating_point() + if torch.jit.is_scripting(): + # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT + float_output = torch.tensor(0, dtype=dtype).is_floating_point() + else: + float_output = dtype.is_floating_point + + if float_input: + # float to float + if float_output: + return image.to(dtype) + + # float to int + if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or ( + image.dtype == torch.float64 and dtype == torch.int64 + ): + raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.") + + # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting + # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only + # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # for a detailed analysis. + # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation. + # Instead, we can also multiply by the maximum value plus something close to `1`. See + # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details. + eps = 1e-3 + max_value = float(_max_value(dtype)) + # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the + # discrete set `{0, 1}`. 
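+        # Illustrative note (not in the upstream file): for uint8 the factor is
+        # 255 + 1.0 - 1e-3 = 255.999, so every input in roughly [0.9961, 1.0]
+        # truncates to 255, whereas a plain mul(255) would map only an input of
+        # exactly 1.0 to 255.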
+ return image.mul(max_value + 1.0 - eps).to(dtype) + else: + # int to float + if float_output: + return image.to(dtype).mul_(1.0 / _max_value(image.dtype)) + + # int to int + num_value_bits_input = _num_value_bits(image.dtype) + num_value_bits_output = _num_value_bits(dtype) + + # TODO: Remove if/else inner blocks once uint16 dtype supports bitwise shift operations. + shift_by = abs(num_value_bits_input - num_value_bits_output) + if num_value_bits_input > num_value_bits_output: + if image.dtype == torch.uint16: + return (image / 2 ** (shift_by)).to(dtype) + else: + return image.bitwise_right_shift(shift_by).to(dtype) + else: + if dtype == torch.uint16: + return image.to(dtype) * 2 ** (shift_by) + else: + return image.to(dtype).bitwise_left_shift_(shift_by) + + +# We encourage users to use to_dtype() instead but we keep this for BC +def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """[DEPRECATED] Use to_dtype() instead.""" + return to_dtype_image(image, dtype=dtype, scale=True) + + +@_register_kernel_internal(to_dtype, tv_tensors.Video) +def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: + return to_dtype_image(video, dtype, scale=scale) + + +@_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) +@_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False) +def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor: + # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type + return inpt.to(dtype) + + +def sanitize_bounding_boxes( + bounding_boxes: torch.Tensor, + format: Optional[tv_tensors.BoundingBoxFormat] = None, + canvas_size: Optional[Tuple[int, int]] = None, + min_size: float = 1.0, + min_area: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Remove degenerate/invalid bounding boxes and return the corresponding indexing mask. + + This removes bounding boxes that: + + - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. + - have any coordinate outside of their corresponding image. You may want to + call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals. + + It is recommended to call it at the end of a pipeline, before passing the + input to the models. It is critical to call this transform if + :class:`~torchvision.transforms.v2.RandomIoUCrop` was called. + If you want to be extra careful, you may call it after all transforms that + may modify bounding boxes but once at the end should be enough in most + cases. + + Args: + bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized. + format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes. + Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object. + canvas_size (tuple of int, optional): The canvas_size of the bounding boxes + (size of the corresponding image/video). + Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object. + min_size (float, optional) The size below which bounding boxes are removed. Default is 1. + min_area (float, optional) The area below which bounding boxes are removed. Default is 1. 
+ + Returns: + out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask. + The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes. + """ + if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes): + if format is None or canvas_size is None: + raise ValueError( + "format and canvas_size cannot be None if bounding_boxes is a pure tensor. " + f"Got format={format} and canvas_size={canvas_size}." + "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object." + ) + if isinstance(format, str): + format = tv_tensors.BoundingBoxFormat[format.upper()] + valid = _get_sanitize_bounding_boxes_mask( + bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size, min_area=min_area + ) + bounding_boxes = bounding_boxes[valid] + else: + if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes): + raise ValueError("bounding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.") + if format is not None or canvas_size is not None: + raise ValueError( + "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. " + f"Got format={format} and canvas_size={canvas_size}. " + "Leave those to None or pass bounding_boxes as a pure tensor." + ) + valid = _get_sanitize_bounding_boxes_mask( + bounding_boxes, + format=bounding_boxes.format, + canvas_size=bounding_boxes.canvas_size, + min_size=min_size, + min_area=min_area, + ) + bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes) + + return bounding_boxes, valid + + +def _get_sanitize_bounding_boxes_mask( + bounding_boxes: torch.Tensor, + format: tv_tensors.BoundingBoxFormat, + canvas_size: Tuple[int, int], + min_size: float = 1.0, + min_area: float = 1.0, +) -> torch.Tensor: + + bounding_boxes = _convert_bounding_box_format( + bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format + ) + + image_h, image_w = canvas_size + ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1] + valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area) + # TODO: Do we really need to check for out of bounds here? All + # transforms should be clamping anyway, so this should never happen? 
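+    # Illustrative note (not in the upstream file): with canvas_size=(480, 640) and
+    # the defaults min_size=1, min_area=1, a degenerate XYXY box such as
+    # [10, 10, 10, 50] has ws == 0 and is marked invalid, as is any box with a
+    # negative coordinate or with a corner beyond x=640 / y=480.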
+ image_h, image_w = canvas_size + valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w) + valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h) + return valid diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_temporal.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..f932b06a295fd10316fba3e796ec4649053e92db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_temporal.py @@ -0,0 +1,27 @@ +import torch + +from torchvision import tv_tensors + +from torchvision.utils import _log_api_usage_once + +from ._utils import _get_kernel, _register_kernel_internal + + +def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor: + """See :class:`~torchvision.transforms.v2.UniformTemporalSubsample` for details.""" + if torch.jit.is_scripting(): + return uniform_temporal_subsample_video(inpt, num_samples=num_samples) + + _log_api_usage_once(uniform_temporal_subsample) + + kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) + return kernel(inpt, num_samples=num_samples) + + +@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor) +@_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video) +def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: + # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 + t_max = video.shape[-4] - 1 + indices = torch.linspace(0, t_max, num_samples, device=video.device).long() + return torch.index_select(video, -4, indices) diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_type_conversion.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_type_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a731fe143c365400d5905db8370c538097583a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_type_conversion.py @@ -0,0 +1,27 @@ +from typing import Union + +import numpy as np +import PIL.Image +import torch +from torchvision import tv_tensors +from torchvision.transforms import functional as _F + + +@torch.jit.unused +def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image: + """See :class:`~torchvision.transforms.v2.ToImage` for details.""" + if isinstance(inpt, np.ndarray): + output = torch.from_numpy(np.atleast_3d(inpt)).permute((2, 0, 1)).contiguous() + elif isinstance(inpt, PIL.Image.Image): + output = pil_to_tensor(inpt) + elif isinstance(inpt, torch.Tensor): + output = inpt + else: + raise TypeError( + f"Input can either be a pure Tensor, a numpy array, or a PIL image, but got {type(inpt)} instead." 
+ ) + return tv_tensors.Image(output) + + +to_pil_image = _F.to_pil_image +pil_to_tensor = _F.pil_to_tensor diff --git a/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_utils.py b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0faeddc1b9b56596783e4e4872b604edcd1555 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_utils.py @@ -0,0 +1,141 @@ +import functools +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union + +import torch +from torchvision import tv_tensors + +_FillType = Union[int, float, Sequence[int], Sequence[float], None] +_FillTypeJIT = Optional[List[float]] + + +def is_pure_tensor(inpt: Any) -> bool: + return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor) + + +# {functional: {input_type: type_specific_kernel}} +_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} + + +def _kernel_tv_tensor_wrapper(kernel): + @functools.wraps(kernel) + def wrapper(inpt, *args, **kwargs): + # If you're wondering whether we could / should get rid of this wrapper, + # the answer is no: we want to pass pure Tensors to avoid the overhead + # of the __torch_function__ machinery. Note that this is always valid, + # regardless of whether we override __torch_function__ in our base class + # or not. + # Also, even if we didn't call `as_subclass` here, we would still need + # this wrapper to call wrap(), because the TVTensor type would be + # lost after the first operation due to our own __torch_function__ + # logic. + output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) + return tv_tensors.wrap(output, like=inpt) + + return wrapper + + +def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True): + registry = _KERNEL_REGISTRY.setdefault(functional, {}) + if input_type in registry: + raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.") + + def decorator(kernel): + registry[input_type] = ( + _kernel_tv_tensor_wrapper(kernel) + if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper + else kernel + ) + return kernel + + return decorator + + +def _name_to_functional(name): + import torchvision.transforms.v2.functional # noqa + + try: + return getattr(torchvision.transforms.v2.functional, name) + except AttributeError: + raise ValueError( + f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional." + ) from None + + +_BUILTIN_DATAPOINT_TYPES = { + obj for obj in tv_tensors.__dict__.values() if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor) +} + + +def register_kernel(functional, tv_tensor_cls): + """Decorate a kernel to register it for a functional and a (custom) tv_tensor type. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage + details. + """ + if isinstance(functional, str): + functional = _name_to_functional(name=functional) + elif not ( + callable(functional) + and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional") + ): + raise ValueError( + f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, " + f"but got {functional}." 
+ ) + + if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)): + raise ValueError( + f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, " + f"but got {tv_tensor_cls}." + ) + + if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES: + raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}") + + return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False) + + +def _get_kernel(functional, input_type, *, allow_passthrough=False): + registry = _KERNEL_REGISTRY.get(functional) + if not registry: + raise ValueError(f"No kernel registered for functional {functional.__name__}.") + + for cls in input_type.__mro__: + if cls in registry: + return registry[cls] + elif cls is tv_tensors.TVTensor: + # We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicit stop the + # MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we don't + # allow kernels to be registered for tv_tensors.TVTensor anyway. + break + + if allow_passthrough: + return lambda inpt, *args, **kwargs: inpt + + raise TypeError( + f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, " + f"but got {input_type} instead." + ) + + +# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop +# We could get rid of this by letting _register_kernel_internal take arbitrary functionals rather than wrap_kernel: bool +def _register_five_ten_crop_kernel_internal(functional, input_type): + registry = _KERNEL_REGISTRY.setdefault(functional, {}) + if input_type in registry: + raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.") + + def wrap(kernel): + @functools.wraps(kernel) + def wrapper(inpt, *args, **kwargs): + output = kernel(inpt, *args, **kwargs) + container_type = type(output) + return container_type(tv_tensors.wrap(o, like=inpt) for o in output) + + return wrapper + + def decorator(kernel): + registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel + return kernel + + return decorator
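Editor's note: the register_kernel helper defined in _utils.py above is the public extension point for dispatching v2 functionals on custom tv_tensor types. The sketch below is illustrative only and is not part of the diff; it assumes a standard torchvision >= 0.16 install, and the names MyTVTensor and hflip_my_tv_tensor are hypothetical.

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


class MyTVTensor(tv_tensors.TVTensor):
    # A do-nothing TVTensor subclass, used only to demonstrate kernel registration.
    pass


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(inpt, *args, **kwargs):
    # Custom kernels receive the tv_tensor itself and are responsible for re-wrapping.
    out = inpt.as_subclass(torch.Tensor).flip(-1)
    return tv_tensors.wrap(out, like=inpt)


my_dp = torch.arange(6).reshape(1, 2, 3).as_subclass(MyTVTensor)
flipped = F.hflip(my_dp)  # dispatched to hflip_my_tv_tensor via the registry above
print(type(flipped).__name__, flipped.shape)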