Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because the commit contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py +3 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py +172 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py +2 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py +393 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py +962 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py +114 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py +216 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py +174 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py +615 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/functional.py +1586 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py +2153 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py +60 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py +369 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py +627 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py +376 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py +174 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py +50 -0
- .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py +1416 -0
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc
ADDED
|
Binary file (7.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc
ADDED
|
Binary file (9.37 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc
ADDED
|
Binary file (8.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc
ADDED
|
Binary file (33.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc
ADDED
|
Binary file (22.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler
|
| 2 |
+
|
| 3 |
+
__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")
|
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (373 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import cast, Iterator, List, Optional, Sized, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
from torch.utils.data import Sampler
|
| 7 |
+
from torchvision.datasets.video_utils import VideoClips
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DistributedSampler(Sampler):
    """
    Extension of DistributedSampler, as discussed in
    https://github.com/pytorch/pytorch/issues/23430

    Example:
        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        num_replicas: 4
        shuffle: False

        when group_size = 1
            RANK    |  shard_dataset
            =========================
            rank_0  |  [0, 4, 8, 12]
            rank_1  |  [1, 5, 9, 13]
            rank_2  |  [2, 6, 10, 0]
            rank_3  |  [3, 7, 11, 1]

        when group_size = 2

            RANK    |  shard_dataset
            =========================
            rank_0  |  [0, 1, 8, 9]
            rank_1  |  [2, 3, 10, 11]
            rank_2  |  [4, 5, 12, 13]
            rank_3  |  [6, 7, 0, 1]

    """

    def __init__(
        self,
        dataset: Sized,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        group_size: int = 1,
    ) -> None:
        # Fall back to the current process group's world size / rank when not given.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if len(dataset) % group_size != 0:
            raise ValueError(
                f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
            )
        self.dataset = dataset
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Number of groups in the dataset; groups are padded up (see __iter__)
        # so that every replica draws the same number of samples.
        dataset_group_length = len(dataset) // group_size
        self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
        self.num_samples = self.num_group_samples * group_size
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices: Union[torch.Tensor, List[int]]
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        total_group_size = self.total_size // self.group_size
        indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))

        # subsample: this rank takes every `num_replicas`-th group, keeping
        # group members contiguous.
        indices = indices[self.rank : total_group_size : self.num_replicas, :]
        indices = torch.reshape(indices, (-1,)).tolist()
        assert len(indices) == self.num_samples

        # When wrapping another Sampler, the positions selected above are
        # mapped through that sampler's own ordering.
        if isinstance(self.dataset, Sampler):
            orig_indices = list(iter(self.dataset))
            indices = [orig_indices[i] for i in indices]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        # Reseeds the per-epoch shuffle so each epoch yields a new permutation.
        self.epoch = epoch
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class UniformClipSampler(Sampler):
    """
    Sample `num_video_clips_per_video` clips for each video, equally spaced.
    When number of unique clips in the video is fewer than num_video_clips_per_video,
    repeat the clips until `num_video_clips_per_video` clips are collected

    Args:
        video_clips (VideoClips): video clips to sample from
        num_clips_per_video (int): number of clips to be sampled per video
    """

    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
        if not isinstance(video_clips, VideoClips):
            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
        self.video_clips = video_clips
        self.num_clips_per_video = num_clips_per_video

    def __iter__(self) -> Iterator[int]:
        idxs = []
        # `s` is the running offset of the current video's clips in the
        # flattened clip-index space shared by all videos.
        s = 0
        # select num_clips_per_video for each video, uniformly spaced
        for c in self.video_clips.clips:
            length = len(c)
            if length == 0:
                # corner case where video decoding fails
                continue

            sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
            s += length
            idxs.append(sampled)
        return iter(cast(List[int], torch.cat(idxs).tolist()))

    def __len__(self) -> int:
        # Videos with zero clips (failed decoding) contribute no samples.
        return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class RandomClipSampler(Sampler):
    """
    Samples at most `max_video_clips_per_video` clips for each video randomly

    Args:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """

    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
        if not isinstance(video_clips, VideoClips):
            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self) -> Iterator[int]:
        idxs = []
        # `s` is the running offset of the current video's clips in the
        # flattened clip-index space shared by all videos.
        s = 0
        # select at most max_clips_per_video for each video, randomly
        for c in self.video_clips.clips:
            length = len(c)
            size = min(length, self.max_clips_per_video)
            sampled = torch.randperm(length)[:size] + s
            s += length
            idxs.append(sampled)
        idxs_ = torch.cat(idxs)
        # shuffle all clips randomly (across videos, not just within each video)
        perm = torch.randperm(len(idxs_))
        return iter(idxs_[perm].tolist())

    def __len__(self) -> int:
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .transforms import *
|
| 2 |
+
from .autoaugment import *
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (263 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc
ADDED
|
Binary file (21.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc
ADDED
|
Binary file (52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc
ADDED
|
Binary file (6.23 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc
ADDED
|
Binary file (9.19 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc
ADDED
|
Binary file (34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc
ADDED
|
Binary file (85.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numbers
|
| 2 |
+
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image, ImageEnhance, ImageOps
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import accimage
|
| 10 |
+
except ImportError:
|
| 11 |
+
accimage = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """Return True if *img* is a PIL image (or an accimage image, when accimage is installed)."""
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
    """Return [channels, height, width] of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    # accimage images expose `.channels` instead of `getbands()`.
    channels = len(img.getbands()) if hasattr(img, "getbands") else img.channels
    width, height = img.size
    return [channels, height, width]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
    """Return [width, height] of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    return list(img.size)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
    """Return the number of channels (bands) of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    # accimage images expose `.channels` instead of `getbands()`.
    return len(img.getbands()) if hasattr(img, "getbands") else img.channels
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Mirror a PIL image horizontally (left/right flip)."""
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Mirror a PIL image vertically (top/bottom flip)."""
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
    """Scale image brightness by *brightness_factor* (1.0 leaves the image unchanged)."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Brightness(img).enhance(brightness_factor)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
    """Scale image contrast by *contrast_factor* (1.0 leaves the image unchanged)."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Contrast(img).enhance(contrast_factor)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
    """Scale color saturation by *saturation_factor* (0 gives grayscale, 1.0 is unchanged)."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Color(img).enhance(saturation_factor)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
    """Shift the image hue by *hue_factor*, a fraction of the HSV hue circle in [-0.5, 0.5]."""
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    input_mode = img.mode
    # Single-channel / non-color modes carry no hue information: return unchanged.
    if input_mode in {"L", "1", "I", "F"}:
        return img

    h, s, v = img.convert("HSV").split()

    np_h = np.array(h, dtype=np.uint8)
    # This will over/underflow, as desired: hue is cyclic, and uint8
    # addition wraps around modulo 256.
    np_h += np.array(hue_factor * 255).astype(np.uint8)

    h = Image.fromarray(np_h, "L")

    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
    return img
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@torch.jit.unused
def adjust_gamma(
    img: Image.Image,
    gamma: float,
    gain: float = 1.0,
) -> Image.Image:
    """Apply gamma correction: out = gain * (in / 255) ** gamma, mapped back to [0, 255]."""

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    input_mode = img.mode
    img = img.convert("RGB")
    # One 256-entry lookup table repeated for the three RGB channels; the
    # (255 + 1 - 1e-3) factor makes input 255 map to exactly 255 after
    # int() truncation when gain == 1.
    gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
    img = img.point(gamma_map)  # use PIL's point-function to accelerate this part

    img = img.convert(input_mode)
    return img
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    """Pad a PIL image on its borders.

    *padding* may be a single int (all sides), a 2-tuple (left/right, top/bottom)
    or a 4-tuple (left, top, right, bottom). Negative values crop instead of pad.
    """

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            # Palette images lose their palette through ImageOps.expand; restore it.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, **opts)
    else:
        # Non-constant modes are implemented via numpy's np.pad.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        p = [pad_left, pad_top, pad_right, pad_bottom]
        # Negative padding entries mean cropping on that side.
        cropping = -np.minimum(p, 0)

        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))

        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

        if img.mode == "P":
            # Palette images: pad the index array, then restore the palette.
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop the rectangle whose top-left corner is (top, left) and whose size is height x width."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    box = (left, top, left + width, top + height)
    return img.crop(box)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@torch.jit.unused
def resize(
    img: Image.Image,
    size: Union[List[int], int],
    interpolation: int = Image.BILINEAR,
) -> Image.Image:
    """Resize to `size` given as [h, w]; PIL's resize expects (w, h), hence the swap."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not (isinstance(size, list) and len(size) == 2):
        raise TypeError(f"Got inappropriate size arg: {size}")
    h, w = size
    return img.resize((w, h), interpolation)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@torch.jit.unused
def _parse_fill(
    fill: Optional[Union[float, List[float], Tuple[float, ...]]],
    img: Image.Image,
    name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
    """Normalize *fill* into the single-key keyword dict that PIL transform calls expect."""

    # Process fill color for affine transforms
    num_channels = get_image_num_channels(img)
    if fill is None:
        fill = 0
    # Broadcast a scalar fill over all channels of a multi-channel image.
    if isinstance(fill, (int, float)) and num_channels > 1:
        fill = tuple([fill] * num_channels)
    if isinstance(fill, (list, tuple)):
        if len(fill) == 1:
            fill = fill * num_channels
        elif len(fill) != num_channels:
            msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
            raise ValueError(msg.format(len(fill), num_channels))

        fill = tuple(fill)  # type: ignore[arg-type]

    # Non-float ("F") image modes take integer fill values.
    if img.mode != "F":
        if isinstance(fill, (list, tuple)):
            fill = tuple(int(x) for x in fill)
        else:
            fill = int(fill)

    return {name: fill}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@torch.jit.unused
def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = Image.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Apply the affine transform given by *matrix*, keeping the original image size."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    fill_opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.AFFINE, matrix, interpolation, **fill_opts)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@torch.jit.unused
def rotate(
    img: Image.Image,
    angle: float,
    interpolation: int = Image.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Rotate a PIL image by *angle* degrees, optionally expanding the canvas to fit."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    fill_opts = _parse_fill(fill, img)
    return img.rotate(angle, interpolation, expand, center, **fill_opts)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@torch.jit.unused
def perspective(
    img: Image.Image,
    perspective_coeffs: List[float],
    interpolation: int = Image.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Apply the perspective transform given by *perspective_coeffs*, keeping the image size."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    fill_opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **fill_opts)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
    """Convert a PIL image to grayscale with 1 or 3 identical channels."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if num_output_channels == 1:
        return img.convert("L")
    if num_output_channels == 3:
        # Replicate the luminance channel across three channels to keep an
        # RGB-shaped image that looks gray.
        luma = np.array(img.convert("L"), dtype=np.uint8)
        return Image.fromarray(np.dstack([luma, luma, luma]), "RGB")
    raise ValueError("num_output_channels should be either 1 or 3")
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Return the color negative of a PIL image."""
    if _is_pil_image(img):
        return ImageOps.invert(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Reduce each color channel of a PIL image to ``bits`` bits."""
    if _is_pil_image(img):
        return ImageOps.posterize(img, bits)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Invert all pixel values of a PIL image above ``threshold``."""
    if _is_pil_image(img):
        return ImageOps.solarize(img, threshold)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
    """Adjust sharpness: 0 blurs, 1 keeps the original, >1 sharpens."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return ImageEnhance.Sharpness(img).enhance(sharpness_factor)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Stretch a PIL image's histogram to span the full intensity range."""
    if _is_pil_image(img):
        return ImageOps.autocontrast(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the histogram of a PIL image."""
    if _is_pil_image(img):
        return ImageOps.equalize(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
|
.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py
ADDED
|
@@ -0,0 +1,962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _is_tensor_a_torch_image(x: Tensor) -> bool:
|
| 10 |
+
return x.ndim >= 2
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _assert_image_tensor(img: Tensor) -> None:
    """Raise ``TypeError`` unless ``img`` is shaped like a tensor image."""
    if _is_tensor_a_torch_image(img):
        return
    raise TypeError("Tensor is not a torch image.")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_dimensions(img: Tensor) -> List[int]:
    """Return ``[channels, height, width]`` of a tensor image.

    A 2-D tensor is treated as a single-channel (H, W) image.
    """
    _assert_image_tensor(img)
    if img.ndim == 2:
        channels = 1
    else:
        channels = img.shape[-3]
    return [channels, img.shape[-2], img.shape[-1]]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_image_size(img: Tensor) -> List[int]:
    """Return ``[width, height]`` of a tensor image (PIL-style ordering)."""
    _assert_image_tensor(img)
    width, height = img.shape[-1], img.shape[-2]
    return [width, height]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_image_num_channels(img: Tensor) -> int:
    """Return the channel count of a tensor image.

    2-D inputs are implicitly single-channel; higher-rank inputs use dim -3.
    """
    _assert_image_tensor(img)
    ndim = img.ndim
    if ndim == 2:
        return 1
    if ndim > 2:
        return img.shape[-3]
    raise TypeError(f"Input ndim should be 2 or more. Got {ndim}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _max_value(dtype: torch.dtype) -> int:
|
| 42 |
+
if dtype == torch.uint8:
|
| 43 |
+
return 255
|
| 44 |
+
elif dtype == torch.int8:
|
| 45 |
+
return 127
|
| 46 |
+
elif dtype == torch.int16:
|
| 47 |
+
return 32767
|
| 48 |
+
elif dtype == torch.uint16:
|
| 49 |
+
return 65535
|
| 50 |
+
elif dtype == torch.int32:
|
| 51 |
+
return 2147483647
|
| 52 |
+
elif dtype == torch.int64:
|
| 53 |
+
return 9223372036854775807
|
| 54 |
+
else:
|
| 55 |
+
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
|
| 56 |
+
# easy.
|
| 57 |
+
return 1
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    """Raise ``TypeError`` if the image's channel count is not in ``permitted``."""
    channels = get_dimensions(img)[0]
    if channels not in permitted:
        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {channels}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype``, rescaling values between dtype ranges.

    Float -> float is a plain cast; float -> int multiplies by the target
    range; int -> float divides by the source range; int -> int rescales by
    the integer ratio of the two ranges.

    Raises:
        RuntimeError: for float -> int casts that cannot be performed safely
            (the float type cannot represent every target integer exactly).
    """
    if image.dtype == dtype:
        return image

    if image.is_floating_point():

        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = float(_max_value(image.dtype))

        # int to float
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            image = image.to(dtype)
            return image / input_max

        output_max = float(_max_value(dtype))

        # int to int
        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = torch.div(image, factor, rounding_mode="floor")
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def vflip(img: Tensor) -> Tensor:
    """Flip a tensor image vertically (top-bottom) along the H axis."""
    _assert_image_tensor(img)
    return img.flip(-2)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def hflip(img: Tensor) -> Tensor:
    """Flip a tensor image horizontally (left-right) along the W axis."""
    _assert_image_tensor(img)
    return img.flip(-1)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop a (..., H, W) tensor image to the given box.

    If the requested box extends beyond the image, the missing region is
    zero-padded so the output is always ``height`` x ``width``.
    """
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Zero padding needed per side, in (left, top, right, bottom) order.
        padding_ltrb = [
            max(-left + min(0, right), 0),
            max(-top + min(0, bottom), 0),
            max(right - max(w, left), 0),
            max(bottom - max(h, top), 0),
        ]
        # Slice the in-bounds part first, then pad it out to the full box.
        return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
    return img[..., top:bottom, left:right]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an RGB (or already single-channel) tensor image to grayscale.

    With ``num_output_channels=3`` the gray channel is expanded back to the
    input shape.
    """
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        red, green, blue = img.unbind(dim=-3)
        # ITU-R 601-2 luma transform; closely follows the TF implementation:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        gray = (0.2989 * red + 0.587 * green + 0.114 * blue).to(img.dtype)
        gray = gray.unsqueeze(dim=-3)
    else:
        gray = img.clone()

    if num_output_channels == 3:
        return gray.expand(img.shape)
    return gray
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale pixel intensities: 0 gives black, 1 keeps the original image."""
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # Blending with an all-black image multiplies every pixel by the factor.
    black = torch.zeros_like(img)
    return _blend(img, black, brightness_factor)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast by blending the image with its mean-gray value."""
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [3, 1])

    num_channels = get_dimensions(img)[0]
    work_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    # The blend target is the scalar mean luminance of the whole image.
    if num_channels == 3:
        source = rgb_to_grayscale(img).to(work_dtype)
    else:
        source = img.to(work_dtype)
    mean = torch.mean(source, dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Shift the hue of an image by ``hue_factor`` (must lie in [-0.5, 0.5])."""
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])
    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    work = convert_image_dtype(img, torch.float32)

    # Rotate the hue channel in HSV space, wrapping around at 1.0.
    hsv = _rgb2hsv(work)
    hue, sat, val = hsv.unbind(dim=-3)
    hue = (hue + hue_factor) % 1.0
    rgb = _hsv2rgb(torch.stack((hue, sat, val), dim=-3))

    return convert_image_dtype(rgb, orig_dtype)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Blend with the grayscale image: 0 gives grayscale, 1 the original."""
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # Single-channel images have no saturation to adjust.
    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    return _blend(img, rgb_to_grayscale(img), saturation_factor)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    """Apply gamma correction ``gain * img ** gamma`` clamped to [0, 1].

    Integer inputs are converted to float32 for the computation and cast back
    to their original dtype afterwards.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")

    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    dtype = img.dtype
    work = img if torch.is_floating_point(img) else convert_image_dtype(img, torch.float32)

    corrected = (gain * work**gamma).clamp(0, 1)

    # No-op for floating inputs; rescales back for integer dtypes.
    return convert_image_dtype(corrected, dtype)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    """Linearly interpolate two images and clamp to img1's dtype range."""
    ratio = float(ratio)
    bound = _max_value(img1.dtype)
    blended = ratio * img1 + (1.0 - ratio) * img2
    return blended.clamp(0, bound).to(img1.dtype)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _rgb2hsv(img: Tensor) -> Tensor:
    """Convert a float RGB tensor image (..., 3, H, W) to HSV, all in [0, 1]."""
    r, g, b = img.unbind(dim=-3)

    # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
    # src/libImaging/Convert.c#L330
    maxc = torch.max(img, dim=-3).values
    minc = torch.min(img, dim=-3).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    # + S channel has division by `maxc`, which is zero only if `maxc = minc`
    # + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
    # we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
    ones = torch.ones_like(maxc)
    s = cr / torch.where(eqc, ones, maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing denominator with 1 when `eqc` is fine.
    cr_divisor = torch.where(eqc, ones, cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    # Hue depends on which channel holds the maximum; the three masked terms
    # are mutually exclusive, so their sum selects exactly one formula.
    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = hr + hg + hb
    # Map from sector units (0..6) to [0, 1) with wrap-around.
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc), dim=-3)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _hsv2rgb(img: Tensor) -> Tensor:
|
| 304 |
+
h, s, v = img.unbind(dim=-3)
|
| 305 |
+
i = torch.floor(h * 6.0)
|
| 306 |
+
f = (h * 6.0) - i
|
| 307 |
+
i = i.to(dtype=torch.int32)
|
| 308 |
+
|
| 309 |
+
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
|
| 310 |
+
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
|
| 311 |
+
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
|
| 312 |
+
i = i % 6
|
| 313 |
+
|
| 314 |
+
mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
|
| 315 |
+
|
| 316 |
+
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
|
| 317 |
+
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
|
| 318 |
+
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
|
| 319 |
+
a4 = torch.stack((a1, a2, a3), dim=-4)
|
| 320 |
+
|
| 321 |
+
return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
|
| 325 |
+
# padding is left, right, top, bottom
|
| 326 |
+
|
| 327 |
+
# crop if needed
|
| 328 |
+
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
|
| 329 |
+
neg_min_padding = [-min(x, 0) for x in padding]
|
| 330 |
+
crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
|
| 331 |
+
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
|
| 332 |
+
padding = [max(x, 0) for x in padding]
|
| 333 |
+
|
| 334 |
+
in_sizes = img.size()
|
| 335 |
+
|
| 336 |
+
_x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
|
| 337 |
+
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
|
| 338 |
+
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
|
| 339 |
+
x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
|
| 340 |
+
|
| 341 |
+
_y_indices = [i for i in range(in_sizes[-2])]
|
| 342 |
+
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
|
| 343 |
+
bottom_indices = [-(i + 1) for i in range(padding[3])]
|
| 344 |
+
y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
|
| 345 |
+
|
| 346 |
+
ndim = img.ndim
|
| 347 |
+
if ndim == 3:
|
| 348 |
+
return img[:, y_indices[:, None], x_indices[None, :]]
|
| 349 |
+
elif ndim == 4:
|
| 350 |
+
return img[:, :, y_indices[:, None], x_indices[None, :]]
|
| 351 |
+
else:
|
| 352 |
+
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
|
| 356 |
+
if isinstance(padding, int):
|
| 357 |
+
if torch.jit.is_scripting():
|
| 358 |
+
# This maybe unreachable
|
| 359 |
+
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
|
| 360 |
+
pad_left = pad_right = pad_top = pad_bottom = padding
|
| 361 |
+
elif len(padding) == 1:
|
| 362 |
+
pad_left = pad_right = pad_top = pad_bottom = padding[0]
|
| 363 |
+
elif len(padding) == 2:
|
| 364 |
+
pad_left = pad_right = padding[0]
|
| 365 |
+
pad_top = pad_bottom = padding[1]
|
| 366 |
+
else:
|
| 367 |
+
pad_left = padding[0]
|
| 368 |
+
pad_top = padding[1]
|
| 369 |
+
pad_right = padding[2]
|
| 370 |
+
pad_bottom = padding[3]
|
| 371 |
+
|
| 372 |
+
return [pad_left, pad_right, pad_top, pad_bottom]
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    """Pad a tensor image.

    ``padding`` is an int or a 1/2/4-element sequence (see _parse_pad_padding).
    ``padding_mode`` is one of "constant", "edge", "reflect" or "symmetric";
    ``fill`` is only used in "constant" mode. A ``None`` fill is treated as 0.
    """
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str to torch_pad's equivalent mode name
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation (torch_pad has no symmetric mode)
        return _pad_symmetric(img, p)

    # torch_pad requires a batched NCHW input; remember to undo the unsqueeze.
    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Here we temporarily cast input tensor to float
        # until pytorch issue is resolved :
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    if padding_mode in ("reflect", "replicate"):
        img = torch_pad(img, p, mode=padding_mode)
    else:
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    antialias: Optional[bool] = True,
) -> Tensor:
    """Resize a tensor image to ``size`` using torch.nn.functional.interpolate.

    Non-float inputs are temporarily promoted to float32 and cast back after
    resizing (with clamping for uint8 bicubic, which can overshoot).
    """
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        # We manually set it to False to avoid an error downstream in interpolate()
        # This behaviour is documented: the parameter is irrelevant for modes
        # that are not bilinear or bicubic. We used to raise an error here, but
        # now we don't as True is the default.
        antialias = False

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # interpolate() warns unless align_corners is explicitly False for the
    # bilinear/bicubic modes, and it must be None for all other modes.
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
        # Bicubic can produce values outside [0, 255]; clamp before casting back.
        img = img.clamp(min=0, max=255)

    return _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the shared arguments of the grid-based transforms.

    Raises ``TypeError``/``ValueError`` on invalid input; a fill of the wrong
    *type* only warns, but a sequence fill whose length cannot broadcast to
    the image's channel count raises.
    """

    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")

    # An affine matrix is 2x3, flattened to 6 values.
    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")

    # Perspective transforms use 8 coefficients.
    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill: length 1 broadcasts; otherwise it must match the channels.
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
|
| 517 |
+
need_squeeze = False
|
| 518 |
+
# make image NCHW
|
| 519 |
+
if img.ndim < 4:
|
| 520 |
+
img = img.unsqueeze(dim=0)
|
| 521 |
+
need_squeeze = True
|
| 522 |
+
|
| 523 |
+
out_dtype = img.dtype
|
| 524 |
+
need_cast = False
|
| 525 |
+
if out_dtype not in req_dtypes:
|
| 526 |
+
need_cast = True
|
| 527 |
+
req_dtype = req_dtypes[0]
|
| 528 |
+
img = img.to(req_dtype)
|
| 529 |
+
return img, need_cast, need_squeeze, out_dtype
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
|
| 533 |
+
if need_squeeze:
|
| 534 |
+
img = img.squeeze(dim=0)
|
| 535 |
+
|
| 536 |
+
if need_cast:
|
| 537 |
+
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
|
| 538 |
+
# it is better to round before cast
|
| 539 |
+
img = torch.round(img)
|
| 540 |
+
img = img.to(out_dtype)
|
| 541 |
+
|
| 542 |
+
return img
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
    """Sample ``img`` through ``grid`` via grid_sample, with optional fill.

    When ``fill`` is given, an all-ones mask channel is concatenated before
    sampling; after sampling it marks which output pixels came from inside
    the source image, so the fill color can be applied in one pass.
    """

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, mask), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        # A scalar fill becomes a 1-element list; a sequence fill broadcasts
        # per-channel.
        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            # Hard decision: a pixel is either inside (keep) or outside (fill).
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            # Soft mask blends the fill color proportionally at the borders.
            img = img * mask + (1.0 - mask) * fill_img

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def _gen_affine_grid(
|
| 580 |
+
theta: Tensor,
|
| 581 |
+
w: int,
|
| 582 |
+
h: int,
|
| 583 |
+
ow: int,
|
| 584 |
+
oh: int,
|
| 585 |
+
) -> Tensor:
|
| 586 |
+
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
|
| 587 |
+
# AffineGridGenerator.cpp#L18
|
| 588 |
+
# Difference with AffineGridGenerator is that:
|
| 589 |
+
# 1) we normalize grid values after applying theta
|
| 590 |
+
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
|
| 591 |
+
|
| 592 |
+
d = 0.5
|
| 593 |
+
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
|
| 594 |
+
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
|
| 595 |
+
base_grid[..., 0].copy_(x_grid)
|
| 596 |
+
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
|
| 597 |
+
base_grid[..., 1].copy_(y_grid)
|
| 598 |
+
base_grid[..., 2].fill_(1)
|
| 599 |
+
|
| 600 |
+
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
|
| 601 |
+
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
|
| 602 |
+
return output_grid.view(1, oh, ow, 2)
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def affine(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Affine-transform a tensor image with the row-major 2x3 ``matrix``.

    The output has the same spatial size as the input; out-of-image areas are
    filled with ``fill`` (see :func:`_apply_grid_transform`).
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    # Integer images are sampled in float32; float images keep their dtype.
    grid_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=grid_dtype, device=img.device).reshape(1, 2, 3)
    height, width = img.shape[-2], img.shape[-1]
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
|
| 622 |
+
|
| 623 |
+
# Inspired of PIL implementation:
|
| 624 |
+
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
|
| 625 |
+
|
| 626 |
+
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
|
| 627 |
+
# Points are shifted due to affine matrix torch convention about
|
| 628 |
+
# the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
|
| 629 |
+
pts = torch.tensor(
|
| 630 |
+
[
|
| 631 |
+
[-0.5 * w, -0.5 * h, 1.0],
|
| 632 |
+
[-0.5 * w, 0.5 * h, 1.0],
|
| 633 |
+
[0.5 * w, 0.5 * h, 1.0],
|
| 634 |
+
[0.5 * w, -0.5 * h, 1.0],
|
| 635 |
+
]
|
| 636 |
+
)
|
| 637 |
+
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
|
| 638 |
+
new_pts = torch.matmul(pts, theta.T)
|
| 639 |
+
min_vals, _ = new_pts.min(dim=0)
|
| 640 |
+
max_vals, _ = new_pts.max(dim=0)
|
| 641 |
+
|
| 642 |
+
# shift points to [0, w] and [0, h] interval to match PIL results
|
| 643 |
+
min_vals += torch.tensor((w * 0.5, h * 0.5))
|
| 644 |
+
max_vals += torch.tensor((w * 0.5, h * 0.5))
|
| 645 |
+
|
| 646 |
+
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
|
| 647 |
+
tol = 1e-4
|
| 648 |
+
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
|
| 649 |
+
cmin = torch.floor((min_vals / tol).trunc_() * tol)
|
| 650 |
+
size = cmax - cmin
|
| 651 |
+
return int(size[0]), int(size[1]) # w, h
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Rotate a tensor image using the precomputed inverse affine ``matrix``.

    With ``expand=True`` the output canvas grows to hold the whole rotated
    image; otherwise the input size is kept.
    """
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
    w, h = img.shape[-1], img.shape[-2]
    if expand:
        ow, oh = _compute_affine_output_size(matrix, w, h)
    else:
        ow, oh = w, h
    grid_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=grid_dtype, device=img.device).reshape(1, 2, 3)
    # grid will be generated on the same device as theta and img
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
|
| 673 |
+
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
|
| 674 |
+
# src/libImaging/Geometry.c#L394
|
| 675 |
+
|
| 676 |
+
#
|
| 677 |
+
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
|
| 678 |
+
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
|
| 679 |
+
#
|
| 680 |
+
theta1 = torch.tensor(
|
| 681 |
+
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
|
| 682 |
+
)
|
| 683 |
+
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
|
| 684 |
+
|
| 685 |
+
d = 0.5
|
| 686 |
+
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
|
| 687 |
+
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
|
| 688 |
+
base_grid[..., 0].copy_(x_grid)
|
| 689 |
+
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
|
| 690 |
+
base_grid[..., 1].copy_(y_grid)
|
| 691 |
+
base_grid[..., 2].fill_(1)
|
| 692 |
+
|
| 693 |
+
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
|
| 694 |
+
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
|
| 695 |
+
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
|
| 696 |
+
|
| 697 |
+
output_grid = output_grid1 / output_grid2 - 1.0
|
| 698 |
+
return output_grid.view(1, oh, ow, 2)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def perspective(
    img: Tensor,
    perspective_coeffs: List[float],
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Perspective-transform a tensor image with 8 PIL-style coefficients."""
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    # Output keeps the input's spatial size; integers are sampled in float32.
    out_w, out_h = img.shape[-1], img.shape[-2]
    grid_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=out_w, oh=out_h, dtype=grid_dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
|
| 728 |
+
ksize_half = (kernel_size - 1) * 0.5
|
| 729 |
+
|
| 730 |
+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device)
|
| 731 |
+
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
| 732 |
+
kernel1d = pdf / pdf.sum()
|
| 733 |
+
|
| 734 |
+
return kernel1d
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Return a 2D Gaussian kernel as the outer product of two 1D kernels.

    ``kernel_size`` and ``sigma`` are ordered (x, y); the result has shape
    (kernel_size[1], kernel_size[0]) — rows follow y, columns follow x.
    """
    kx = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
    ky = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
    return torch.mm(ky[:, None], kx[None, :])
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """Blur a tensor image with a depthwise Gaussian convolution.

    ``kernel_size`` is (kx, ky) and ``sigma`` is (sigma_x, sigma_y); the input
    dtype is restored on return.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    conv_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=conv_dtype, device=img.device)
    # One kernel copy per channel for a grouped (depthwise) convolution.
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def invert(img: Tensor) -> Tensor:
    """Invert an image: each value ``v`` becomes ``max_value(dtype) - v``."""
    _assert_image_tensor(img)
    ndim = img.ndim
    if ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {ndim}")
    _assert_channels(img, [1, 3])
    bound = _max_value(img.dtype)
    return bound - img
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
def posterize(img: Tensor, bits: int) -> Tensor:
    """Reduce each uint8 channel to ``bits`` bits by masking off the low bits."""
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
    _assert_channels(img, [1, 3])
    # -int(2 ** (8 - bits)) equals ~(2 ** (8 - bits) - 1) but stays JIT-friendly.
    keep_mask = -int(2 ** (8 - bits))
    return img & keep_mask
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Invert every pixel at or above ``threshold``; leave the rest untouched."""
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])
    if threshold > _max_value(img.dtype):
        raise TypeError("Threshold should be less than bound of img.")
    # Pixels >= threshold take their inverted value, the others stay as-is.
    return torch.where(img >= threshold, invert(img), img)
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
def _blurred_degenerate_image(img: Tensor) -> Tensor:
    """Return ``img`` smoothed by a fixed 3x3 kernel, used as the "degenerate"
    image that :func:`adjust_sharpness` blends against.

    The one-pixel border keeps the original values because the convolution is
    applied without padding and pasted back into the interior only.
    """
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    # 3x3 smoothing kernel with a heavier center weight, normalized to sum to 1
    # and replicated per channel for a grouped (depthwise) convolution.
    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
    result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
    result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)

    # Paste the blurred interior over a copy of the original image.
    result = img.clone()
    result[..., 1:-1, 1:-1] = result_tmp

    return result
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Sharpen (factor > 1) or smooth (factor < 1) an image; 1.0 is a no-op."""
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # A 3x3 smoothing kernel needs at least a 3x3 spatial extent to act on.
    if img.size(-1) <= 2 or img.size(-2) <= 2:
        return img

    degenerate = _blurred_degenerate_image(img)
    return _blend(img, degenerate, sharpness_factor)
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
def autocontrast(img: Tensor) -> Tensor:
    """Stretch each channel's value range to span the full dtype range.

    Per channel, the minimum maps to 0 and the maximum to the dtype's maximum
    value; channels with a constant value are left unchanged.
    """
    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")

    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    # Do the arithmetic in float for integer images.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
    scale = bound / (maximum - minimum)
    # Constant channels yield a non-finite scale; neutralize them so the
    # transform reduces to the identity for those channels.
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1

    return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def _scale_channel(img_chan: Tensor) -> Tensor:
|
| 864 |
+
# TODO: we should expect bincount to always be faster than histc, but this
|
| 865 |
+
# isn't always the case. Once
|
| 866 |
+
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
|
| 867 |
+
# block and only use bincount.
|
| 868 |
+
if img_chan.is_cuda:
|
| 869 |
+
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
|
| 870 |
+
else:
|
| 871 |
+
hist = torch.bincount(img_chan.reshape(-1), minlength=256)
|
| 872 |
+
|
| 873 |
+
nonzero_hist = hist[hist != 0]
|
| 874 |
+
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
|
| 875 |
+
if step == 0:
|
| 876 |
+
return img_chan
|
| 877 |
+
|
| 878 |
+
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
|
| 879 |
+
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
| 880 |
+
|
| 881 |
+
return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def _equalize_single_image(img: Tensor) -> Tensor:
    """Equalize one (C, H, W) image by scaling each channel independently."""
    channels = [_scale_channel(img[c]) for c in range(img.size(0))]
    return torch.stack(channels)
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def equalize(img: Tensor) -> Tensor:
    """Histogram-equalize a uint8 image (3D) or batch of images (4D)."""
    _assert_image_tensor(img)

    if not 3 <= img.ndim <= 4:
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")

    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)
    # Batched input: equalize each image independently.
    return torch.stack([_equalize_single_image(x) for x in img])
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float image tensor channel-wise: ``(tensor - mean) / std``.

    Raises:
        TypeError: if ``tensor`` is not floating point.
        ValueError: if ``tensor`` has fewer than 3 dims or any std entry is 0.
    """
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std_t == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    # Reshape 1D stats to (C, 1, 1) so they broadcast over H and W.
    if mean_t.ndim == 1:
        mean_t = mean_t.view(-1, 1, 1)
    if std_t.ndim == 1:
        std_t = std_t.view(-1, 1, 1)
    return tensor.sub_(mean_t).div_(std_t)
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Overwrite the ``h x w`` region of ``img`` at top-left (i, j) with ``v``."""
    _assert_image_tensor(img)
    if not inplace:
        img = img.clone()
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    img[..., rows, cols] = v
    return img
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
def _create_identity_grid(size: List[int]) -> Tensor:
|
| 942 |
+
hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
|
| 943 |
+
grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
|
| 944 |
+
return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Warp ``img`` by adding ``displacement`` to the identity sampling grid."""
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    hw = list(img.shape[-2:])
    displacement = displacement.to(img.device)

    # Identity grid plus per-pixel displacement gives the sampling locations.
    grid = _create_identity_grid(hw).to(img.device) + displacement
    return _apply_grid_transform(img, grid, interpolation, fill)
|
.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
warnings.warn(
|
| 7 |
+
"The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. "
|
| 8 |
+
"Please use the 'torchvision.transforms.functional' module instead."
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _is_tensor_video_clip(clip):
|
| 13 |
+
if not torch.is_tensor(clip):
|
| 14 |
+
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
| 15 |
+
|
| 16 |
+
if not clip.ndimension() == 4:
|
| 17 |
+
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
| 18 |
+
|
| 19 |
+
return True
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def crop(clip, i, j, h, w):
    """Crop a (C, T, H, W) clip to the ``h x w`` window at top-left (i, j).

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
    """
    if len(clip.size()) != 4:
        raise ValueError("clip should be a 4D tensor")
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return clip[..., rows, cols]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def resize(clip, target_size, interpolation_mode):
    """Resize a (C, T, H, W) clip to ``target_size`` = (height, width)."""
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    cropped = crop(clip, i, j, h, w)
    return resize(cropped, size, interpolation_mode)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def center_crop(clip, crop_size):
    """Crop the spatial center of a (C, T, H, W) clip to ``crop_size`` = (th, tw)."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    # Integer offsets that center the crop window.
    top = int(round((h - th) / 2.0))
    left = int(round((w - tw) / 2.0))
    return crop(clip, top, left, th, tw)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    # (T, H, W, C) -> (C, T, H, W), rescaled to [0, 1].
    permuted = clip.float().permute(3, 0, 1, 2)
    return permuted / 255.0
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean_t = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std_t = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    # Broadcast the per-channel stats over the (T, H, W) dimensions.
    clip.sub_(mean_t[:, None, None, None]).div_(std_t[:, None, None, None])
    return clip
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    # Flip along the width (last) axis.
    return clip.flip(-1)
|
.venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file is part of the private API. Please do not use directly these classes as they will be modified on
|
| 3 |
+
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn, Tensor
|
| 9 |
+
|
| 10 |
+
from . import functional as F, InterpolationMode
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"ObjectDetection",
|
| 15 |
+
"ImageClassification",
|
| 16 |
+
"VideoClassification",
|
| 17 |
+
"SemanticSegmentation",
|
| 18 |
+
"OpticalFlow",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ObjectDetection(nn.Module):
    """Preprocessing preset for detection weights: PIL -> tensor, then float in [0, 1]."""

    def forward(self, img: Tensor) -> Tensor:
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ImageClassification(nn.Module):
    """Preprocessing preset for image-classification weights.

    Pipeline: resize -> center crop -> (PIL to tensor) -> float in [0, 1] ->
    normalize with ``mean``/``std``.
    """

    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # Stored as lists/sequences in the shapes F.resize / F.center_crop accept.
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        img = F.center_crop(img, self.crop_size)
        # PIL inputs survive resize/crop; convert to tensor before dtype ops.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n crop_size={self.crop_size}"
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        # Human-readable summary used in the weights documentation.
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
        )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class VideoClassification(nn.Module):
    """Preprocessing preset for video-classification weights.

    Per frame: resize -> center crop -> float in [0, 1] -> normalize; then the
    output is permuted to ``(..., C, T, H, W)``.
    """

    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Union[Tuple[int], Tuple[int, int]],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        # Accept a single (T, C, H, W) clip by adding a batch dimension.
        need_squeeze = False
        if vid.ndim < 5:
            vid = vid.unsqueeze(dim=0)
            need_squeeze = True

        N, T, C, H, W = vid.shape
        # Fold batch and time together so the image transforms act per frame.
        vid = vid.view(-1, C, H, W)
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)
        H, W = self.crop_size
        vid = vid.view(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if need_squeeze:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n crop_size={self.crop_size}"
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        # Human-readable summary used in the weights documentation.
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
        )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class SemanticSegmentation(nn.Module):
    """Preprocessing preset for semantic-segmentation weights.

    Pipeline: optional resize -> (PIL to tensor) -> float in [0, 1] ->
    normalize with ``mean``/``std``.
    """

    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # None disables resizing; otherwise wrap in a list for F.resize.
        self.resize_size = [resize_size] if resize_size is not None else None
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        # PIL inputs survive resize; convert to tensor before dtype ops.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        # Human-readable summary used in the weights documentation.
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
        )
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class OpticalFlow(nn.Module):
    """Inference preprocessing for optical-flow models.

    Converts both frames to contiguous float tensors with values mapped
    from ``[0, 1]`` into ``[-1, 1]``.
    """

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        def _prepare(frame: Tensor) -> Tensor:
            if not isinstance(frame, Tensor):
                frame = F.pil_to_tensor(frame)
            frame = F.convert_image_dtype(frame, torch.float)
            # map [0, 1] into [-1, 1]
            frame = F.normalize(frame, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            return frame.contiguous()

        return _prepare(img1), _prepare(img2)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[-1.0, 1.0]``."
        )
|
.venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import numbers
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
|
| 7 |
+
from torchvision.transforms import RandomCrop, RandomResizedCrop
|
| 8 |
+
|
| 9 |
+
from . import _functional_video as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Names re-exported as the public surface of this (deprecated) module.
__all__ = [
    "RandomCropVideo",
    "RandomResizedCropVideo",
    "CenterCropVideo",
    "NormalizeVideo",
    "ToTensorVideo",
    "RandomHorizontalFlipVideo",
]


# Emitted once at import time: these classes live on only for backward
# compatibility; the supported replacements are in ``torchvision.transforms``.
warnings.warn(
    "The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in the future. "
    "Please use the 'torchvision.transforms' module instead."
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class RandomCropVideo(RandomCrop):
    """Crop a video clip at a random spatial location to a fixed size."""

    def __init__(self, size):
        # A bare number means a square crop; anything else is used as-is.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, OH, OW)
        """
        top, left, height, width = self.get_params(clip, self.size)
        return F.crop(clip, top, left, height, width)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RandomResizedCropVideo(RandomResizedCrop):
    """Crop a random area of a video clip and resize it to a given size.

    Args:
        size (int, tuple or list): expected output size ``(height, width)``;
            a single number yields a square output.
        scale (tuple): range of the area fraction of the random crop.
        ratio (tuple): range of the aspect ratio of the random crop.
        interpolation_mode (str): interpolation mode used for the resize.

    Raises:
        ValueError: if ``size`` is a tuple/list whose length is not 2.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        # Bug fix: previously only ``tuple`` was recognized, so a list such as
        # [h, w] fell into the scalar branch and silently produced the
        # malformed size ``([h, w], [h, w])``. Accept tuples and lists alike.
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, H, W)
        """
        i, j, h, w = self.get_params(clip, self.scale, self.ratio)
        return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CenterCropVideo:
    """Crop the spatial center of a video clip to a fixed size."""

    def __init__(self, crop_size):
        # A bare number means a square crop window.
        if isinstance(crop_size, numbers.Number):
            crop_size = (int(crop_size), int(crop_size))
        self.crop_size = crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: central cropping of video clip. Size is
                (C, T, crop_size, crop_size)
        """
        return F.center_crop(clip, self.crop_size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(crop_size={self.crop_size})"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class NormalizeVideo:
    """Channel-wise normalization of a video clip: subtract ``mean`` and
    divide by ``std``.

    Args:
        mean (3-tuple): per-channel RGB mean.
        std (3-tuple): per-channel RGB standard deviation.
        inplace (bool): whether to normalize the clip in place.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """Normalize ``clip`` of size (C, T, H, W) and return the result."""
        return F.normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ToTensorVideo:
    """Convert a uint8 video clip of size (T, H, W, C) to a float clip of
    size (C, T, H, W), rescaling values into ``[0, 1]`` by dividing by 255.
    """

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
        """
        return F.to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class RandomHorizontalFlipVideo:
    """Horizontally flip a video clip with a given probability.

    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        return F.hflip(clip) if random.random() < self.p else clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
|
.venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from . import functional as F, InterpolationMode
|
| 9 |
+
|
| 10 |
+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply the single augmentation operation ``op_name`` to ``img``.

    Args:
        img: Image to transform (PIL Image or Tensor).
        op_name: Operation identifier, e.g. "ShearX", "Rotate", "Equalize".
        magnitude: Operation strength; ignored by the parameterless ops
            ("AutoContrast", "Equalize", "Invert", "Identity").
        interpolation: Interpolation mode used by the geometric operations.
        fill: Pixel fill value(s) for the area outside the transformed image.

    Returns:
        The transformed image.

    Raises:
        ValueError: If ``op_name`` is not a recognized operation.
    """
    if op_name == "ShearX":
        # magnitude should be arctan(magnitude)
        # official autoaug: (1, level, 0, 0, 1, 0)
        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
        # compared to
        # torchvision: (1, tan(level), 0, 0, 1, 0)
        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[math.degrees(math.atan(magnitude)), 0.0],
            interpolation=interpolation,
            fill=fill,
            # NOTE: shears are pinned to the top-left corner (center=[0, 0]),
            # unlike the translate ops below which use the default center.
            center=[0, 0],
        )
    elif op_name == "ShearY":
        # magnitude should be arctan(magnitude)
        # See above
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[0.0, math.degrees(math.atan(magnitude))],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    elif op_name == "TranslateX":
        # Horizontal shift by an integer number of pixels.
        img = F.affine(
            img,
            angle=0.0,
            translate=[int(magnitude), 0],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "TranslateY":
        # Vertical shift by an integer number of pixels.
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, int(magnitude)],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "Rotate":
        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    elif op_name == "Brightness":
        # Photometric ops interpret 1.0 + magnitude as the enhancement factor.
        img = F.adjust_brightness(img, 1.0 + magnitude)
    elif op_name == "Color":
        img = F.adjust_saturation(img, 1.0 + magnitude)
    elif op_name == "Contrast":
        img = F.adjust_contrast(img, 1.0 + magnitude)
    elif op_name == "Sharpness":
        img = F.adjust_sharpness(img, 1.0 + magnitude)
    elif op_name == "Posterize":
        # magnitude is the number of bits to keep per channel.
        img = F.posterize(img, int(magnitude))
    elif op_name == "Solarize":
        # magnitude is the inversion threshold.
        img = F.solarize(img, magnitude)
    elif op_name == "AutoContrast":
        img = F.autocontrast(img)
    elif op_name == "Equalize":
        img = F.equalize(img)
    elif op_name == "Invert":
        img = F.invert(img)
    elif op_name == "Identity":
        pass
    else:
        raise ValueError(f"The provided operator {op_name} is not recognized.")
    return img
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class AutoAugmentPolicy(Enum):
    """Enumeration of the datasets for which learned AutoAugment policies
    are available: IMAGENET, CIFAR10 and SVHN.
    """

    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
|
| 104 |
+
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolved once here so forward() only indexes into the table.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        # Each sub-policy is a pair of (op_name, probability, magnitude_id)
        # triples; magnitude_id is None for parameterless ops.
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        # Maps each op name to (magnitude bins, signed?). A 0-d tensor marks a
        # parameterless op; ``image_size`` is (height, width).
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Get parameters for autoaugment transformation

        Returns:
            params required by the autoaugment transformation:
            a random sub-policy index in ``[0, transform_num)``, two per-op
            probabilities drawn from ``[0, 1)`` and two random 0/1 signs.
        """
        policy_id = int(torch.randint(transform_num, (1,)).item())
        probs = torch.rand((2,))
        signs = torch.randint(2, (2,))

        return policy_id, probs, signs

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor inputs need ``fill`` as a per-channel list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        transform_id, probs, signs = self.get_params(len(self.policies))

        # Fixed 10 magnitude bins for the learned policies above.
        op_meta = self._augmentation_space(10, (height, width))
        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
            # Each op of the chosen sub-policy fires independently with probability p.
            if probs[i] <= p:
                magnitudes, signed = op_meta[op_name]
                # Parameterless ops (magnitude_id is None) use a dummy magnitude of 0.
                magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                if signed and signs[i] == 0:
                    # Signed ops are negated half of the time.
                    magnitude *= -1.0
                img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class RandAugment(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    On each call, ``num_ops`` operations are drawn uniformly at random from the
    augmentation space and applied in sequence, all at the same ``magnitude``.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape; a PIL Image is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Map each op name to ``(magnitude bins, signed?)``; 0-d tensors mark
        parameterless ops. ``image_size`` is (height, width)."""
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor inputs need ``fill`` as a per-channel list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
        op_names = list(op_meta)
        for _ in range(self.num_ops):
            choice = int(torch.randint(len(op_meta), (1,)).item())
            name = op_names[choice]
            magnitudes, signed = op_meta[name]
            strength = 0.0 if magnitudes.ndim == 0 else float(magnitudes[self.magnitude].item())
            # Signed ops flip their direction half of the time. The randint is
            # only drawn for signed ops, preserving the RNG call order.
            if signed and torch.randint(2, (1,)):
                strength = -strength
            img = _apply_op(img, name, strength, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        params = ", ".join(
            [
                f"num_ops={self.num_ops}",
                f"magnitude={self.magnitude}",
                f"num_magnitude_bins={self.num_magnitude_bins}",
                f"interpolation={self.interpolation}",
                f"fill={self.fill}",
            ]
        )
        return f"{self.__class__.__name__}({params})"
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    One augmentation op and one magnitude bin are drawn uniformly at random per call.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Return the op catalog: ``{op_name: (magnitude bins, is_signed)}``.

        NOTE: insertion order matters — ``forward`` picks an op by positional index.
        """
        # 0-d placeholder for parameterless ops; shared since magnitudes are read-only.
        no_magnitude = torch.tensor(0.0)
        # The "wide" strength range shared by shear / color-jitter style ops.
        wide = torch.linspace(0.0, 0.99, num_bins)
        return {
            "Identity": (no_magnitude, False),
            "ShearX": (wide, True),
            "ShearY": (wide, True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (wide, True),
            "Color": (wide, True),
            "Contrast": (wide, True),
            "Sharpness": (wide, True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (no_magnitude, False),
            "Equalize": (no_magnitude, False),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        # Tensor inputs need an explicit per-channel fill list.
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        space = self._augmentation_space(self.num_magnitude_bins)
        chosen = list(space.keys())[int(torch.randint(len(space), (1,)).item())]
        magnitudes, signed = space[chosen]
        if magnitudes.ndim > 0:
            # Uniformly pick one of the magnitude bins.
            bin_index = torch.randint(len(magnitudes), (1,), dtype=torch.long)
            magnitude = float(magnitudes[bin_index].item())
        else:
            magnitude = 0.0
        # Signed ops get a random direction with probability 1/2.
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, chosen, magnitude, interpolation=self.interpolation, fill=fill)

    def __repr__(self) -> str:
        """Return a constructor-style string representation of this transform."""
        settings = [
            f"num_magnitude_bins={self.num_magnitude_bins}",
            f"interpolation={self.interpolation}",
            f"fill={self.fill}",
        ]
        return f"{self.__class__.__name__}({', '.join(settings)})"
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Number of magnitude bins in the augmentation space; `severity` indexes into it.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Return the op catalog: ``{op_name: (magnitude bins, is_signed)}``.

        ``image_size`` is (height, width); translation magnitudes scale with it.
        Insertion order matters — ``forward`` picks ops by positional index.
        """
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        # The color/intensity ops are optional because they can overlap with the
        # corruptions AugMix is typically evaluated against.
        if self.all_ops:
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        # PIL inputs are converted to tensors so mixing arithmetic works uniformly.
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        # Inverse of `_pil_to_tensor`, used to restore the caller's input type.
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(orig_img)
        if isinstance(orig_img, Tensor):
            img = orig_img
            # Tensor inputs need an explicit per-channel fill list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))

        # Work on a [B, C, H, W] view; leading dims (if any) collapse into B=1.
        orig_dims = list(img.shape)
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        # Start the mix with the original image weighted by the Beta sample.
        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # Negative chain_depth means a stochastic depth uniform in {1, 2, 3}.
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_meta), (1,)).item())
                op_name = list(op_meta.keys())[op_index]
                magnitudes, signed = op_meta[op_name]
                # Magnitude bin is drawn from [0, severity); parameterless ops use 0.0.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                # Signed ops get a random direction with probability 1/2.
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            # Accumulate this chain's contribution in place.
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the caller's original shape and dtype (mixing is done in float).
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        """Return a constructor-style string representation of this transform."""
        s = (
            f"{self.__class__.__name__}("
            f"severity={self.severity}"
            f", mixture_width={self.mixture_width}"
            f", chain_depth={self.chain_depth}"
            f", alpha={self.alpha}"
            f", all_ops={self.all_ops}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s
|
.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py
ADDED
|
@@ -0,0 +1,1586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numbers
|
| 3 |
+
import sys
|
| 4 |
+
import warnings
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Any, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from PIL.Image import Image as PILImage
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import accimage
|
| 16 |
+
except ImportError:
|
| 17 |
+
accimage = None
|
| 18 |
+
|
| 19 |
+
from ..utils import _log_api_usage_once
|
| 20 |
+
from . import _functional_pil as F_pil, _functional_tensor as F_t
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class InterpolationMode(Enum):
    """Interpolation modes
    Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
    and ``lanczos``.
    """

    # String values are the identifiers used by the resize backends in this module.
    NEAREST = "nearest"
    NEAREST_EXACT = "nearest-exact"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    # For PIL compatibility
    BOX = "box"
    HAMMING = "hamming"
    LANCZOS = "lanczos"
|
| 38 |
+
|
| 39 |
+
# TODO: Once torchscript supports Enums with staticmethod
|
| 40 |
+
# this can be put into InterpolationMode as staticmethod
|
| 41 |
+
# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
    """Translate a legacy PIL-style integer interpolation constant into an
    :class:`InterpolationMode` member.

    Raises ``KeyError`` for integers outside the known 0-5 set.
    """
    int_to_mode = {
        0: InterpolationMode.NEAREST,
        1: InterpolationMode.LANCZOS,
        2: InterpolationMode.BILINEAR,
        3: InterpolationMode.BICUBIC,
        4: InterpolationMode.BOX,
        5: InterpolationMode.HAMMING,
    }
    return int_to_mode[i]
+
|
| 52 |
+
|
| 53 |
+
# Inverse direction of `_interpolation_modes_from_int`: maps each InterpolationMode to
# the legacy PIL integer resampling constant. NEAREST_EXACT has no PIL counterpart and
# falls back to PIL's NEAREST (0).
pil_modes_mapping = {
    InterpolationMode.NEAREST: 0,
    InterpolationMode.BILINEAR: 2,
    InterpolationMode.BICUBIC: 3,
    InterpolationMode.NEAREST_EXACT: 0,
    InterpolationMode.BOX: 4,
    InterpolationMode.HAMMING: 5,
    InterpolationMode.LANCZOS: 1,
}

# Module-level alias for the PIL-image type check used throughout this file.
_is_pil_image = F_pil._is_pil_image
|
| 65 |
+
|
| 66 |
+
def get_dimensions(img: Tensor) -> List[int]:
    """Returns the dimensions of an image as [channels, height, width].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image dimensions.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(get_dimensions)
    # Dispatch to the tensor backend for tensors, otherwise the PIL backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.get_dimensions(img)
    return F_t.get_dimensions(img)
+
|
| 82 |
+
|
| 83 |
+
def get_image_size(img: Tensor) -> List[int]:
    """Returns the size of an image as [width, height].

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        List[int]: The image size.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(get_image_size)
    # Dispatch to the tensor backend for tensors, otherwise the PIL backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.get_image_size(img)
    return F_t.get_image_size(img)
+
|
| 99 |
+
|
| 100 |
+
def get_image_num_channels(img: Tensor) -> int:
    """Returns the number of channels of an image.

    Args:
        img (PIL Image or Tensor): The image to be checked.

    Returns:
        int: The number of channels.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(get_image_num_channels)
    # Dispatch to the tensor backend for tensors, otherwise the PIL backend.
    if not isinstance(img, torch.Tensor):
        return F_pil.get_image_num_channels(img)
    return F_t.get_image_num_channels(img)
+
|
| 116 |
+
|
| 117 |
+
@torch.jit.unused
def _is_numpy(img: Any) -> bool:
    """Return True if ``img`` is a ``numpy.ndarray``."""
    return isinstance(img, np.ndarray)
+
|
| 121 |
+
|
| 122 |
+
@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
    """Return True if the array has an image-like number of dimensions (HW or HWC).

    Assumes ``img`` is already known to be a numpy array — callers check ``_is_numpy``
    first; a non-array without ``.ndim`` would raise AttributeError here.
    """
    return img.ndim in {2, 3}
+
|
| 126 |
+
|
| 127 |
+
def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.ToTensor` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image. uint8 inputs come back as float scaled to [0, 1];
        other dtypes are returned without rescaling.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_tensor)
    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    default_float_dtype = torch.get_default_dtype()

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Grayscale HW -> HW1 so the HWC->CHW transpose below is uniform.
            pic = pic[:, :, None]

        # HWC -> CHW; `contiguous` because transpose produces a strided view.
        img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.to(dtype=default_float_dtype).div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage decodes straight into a preallocated float32 CHW buffer.
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic).to(dtype=default_float_dtype)

    # handle PIL Image
    # Map PIL modes to numpy dtypes; anything not listed decodes as uint8.
    # NOTE(review): "I;16" maps to int16 although PIL's I;16 is unsigned 16-bit —
    # values above 32767 would wrap; presumably a historical compromise. Verify.
    mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32}
    img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

    if pic.mode == "1":
        # Binary images decode as 0/1; scale to 0/255 so the div(255) below normalizes.
        img = 255 * img
    img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.to(dtype=default_float_dtype).div(255)
    else:
        return img
+
|
| 180 |
+
|
| 181 |
+
def pil_to_tensor(pic: Any) -> Tensor:
    """Convert a ``PIL Image`` to a tensor of the same type.
    This function does not support torchscript.

    See :class:`~torchvision.transforms.PILToTensor` for more details.

    .. note::

        A deep copy of the underlying array is performed.

    Args:
        pic (PIL Image): Image to be converted to tensor.

    Returns:
        Tensor: Converted image in CHW layout, dtype unchanged.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(pil_to_tensor)
    if not F_pil._is_pil_image(pic):
        raise TypeError(f"pic should be PIL Image. Got {type(pic)}")

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage format is always uint8 internally, so always return uint8 here
        buffer = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
        pic.copyto(buffer)
        return torch.as_tensor(buffer)

    # handle PIL Image: decode into a flat array, shape to HWC, then CHW.
    hwc = torch.as_tensor(np.array(pic, copy=True))
    hwc = hwc.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
    return hwc.permute((2, 0, 1))
+
|
| 215 |
+
|
| 216 |
+
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly
    This function does not support PIL Image.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        Tensor: Converted image

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(convert_image_dtype)
    # Tensor-only API; the actual conversion lives in the tensor backend.
    if isinstance(image, torch.Tensor):
        return F_t.convert_image_dtype(image, dtype)
    raise TypeError("Input img should be Tensor Image")
+
|
| 245 |
+
|
| 246 |
+
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: If ``pic`` is neither a Tensor nor an ndarray, or if no PIL mode
            can be determined for its dtype.
        ValueError: If ``pic`` is not 2/3-dimensional, has more than 4 channels, or
            ``mode`` is incompatible with the channel count / dtype.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_pil_image)

    if isinstance(pic, torch.Tensor):
        if pic.ndim == 3:
            # CHW -> HWC, the layout PIL expects.
            pic = pic.permute((1, 2, 0))
        # force=True also handles non-CPU / grad-carrying tensors.
        pic = pic.numpy(force=True)
    elif not isinstance(pic, np.ndarray):
        raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")

    if pic.ndim == 2:
        # if 2D image, add channel dimension (HWC)
        pic = np.expand_dims(pic, 2)
    if pic.ndim != 3:
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")

    if pic.shape[-1] > 4:
        raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")

    npimg = pic

    # Float images (unless an explicit "F" mode is requested) are assumed to be in
    # [0, 1] and rescaled to uint8 — TODO confirm callers never pass floats outside [0, 1].
    if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
        npimg = (npimg * 255).astype(np.uint8)

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = "L"
        elif npimg.dtype == np.int16:
            expected_mode = "I;16" if sys.byteorder == "little" else "I;16B"
        elif npimg.dtype == np.int32:
            expected_mode = "I"
        elif npimg.dtype == np.float32:
            expected_mode = "F"
        if mode is not None and mode != expected_mode:
            # Fixed: report the actual input dtype (npimg.dtype). The previous message
            # interpolated `np.dtype` — the numpy class itself — which was useless.
            raise ValueError(
                f"Incorrect mode ({mode}) supplied for input type {npimg.dtype}. Should be {expected_mode}"
            )
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ["LA"]
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "LA"

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")

        if mode is None and npimg.dtype == np.uint8:
            mode = "RGBA"
    else:
        permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
        if mode is None and npimg.dtype == np.uint8:
            mode = "RGB"

    if mode is None:
        raise TypeError(f"Input type {npimg.dtype} is not supported")

    return Image.fromarray(npimg, mode=mode)
+
|
| 326 |
+
|
| 327 |
+
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Normalize a float tensor image channel-wise with the given mean and std.

    This transform does not support PIL Images. By default it operates out of
    place, leaving the input tensor untouched.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(normalize)

    # Guard clause: normalization is tensor-only; PIL images are rejected.
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def _compute_resized_output_size(
|
| 354 |
+
image_size: Tuple[int, int],
|
| 355 |
+
size: Optional[List[int]],
|
| 356 |
+
max_size: Optional[int] = None,
|
| 357 |
+
allow_size_none: bool = False, # only True in v2
|
| 358 |
+
) -> List[int]:
|
| 359 |
+
h, w = image_size
|
| 360 |
+
short, long = (w, h) if w <= h else (h, w)
|
| 361 |
+
if size is None:
|
| 362 |
+
if not allow_size_none:
|
| 363 |
+
raise ValueError("This should never happen!!")
|
| 364 |
+
if not isinstance(max_size, int):
|
| 365 |
+
raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
|
| 366 |
+
new_short, new_long = int(max_size * short / long), max_size
|
| 367 |
+
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
|
| 368 |
+
elif len(size) == 1: # specified size only for the smallest edge
|
| 369 |
+
requested_new_short = size if isinstance(size, int) else size[0]
|
| 370 |
+
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
| 371 |
+
|
| 372 |
+
if max_size is not None:
|
| 373 |
+
if max_size <= requested_new_short:
|
| 374 |
+
raise ValueError(
|
| 375 |
+
f"max_size = {max_size} must be strictly greater than the requested "
|
| 376 |
+
f"size for the smaller edge size = {size}"
|
| 377 |
+
)
|
| 378 |
+
if new_long > max_size:
|
| 379 |
+
new_short, new_long = int(max_size * new_short / new_long), max_size
|
| 380 |
+
|
| 381 |
+
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
|
| 382 |
+
else: # specified both h and w
|
| 383 |
+
new_w, new_h = size[1], size[0]
|
| 384 |
+
return [new_h, new_w]
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> Tensor:
    r"""Resize the input image to the given size.

    Tensors are expected in [..., H, W] layout, where ... means an arbitrary
    number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. A (h, w) sequence is used
            as-is; a single int (or, in torchscript mode, a 1-element sequence
            ``[size, ]``) specifies the smaller edge, keeping the aspect ratio.
        interpolation (InterpolationMode): Interpolation mode defined by
            :class:`torchvision.transforms.InterpolationMode`; the matching
            Pillow integer constants are accepted too.
            Default: ``InterpolationMode.BILINEAR``. Tensors support only
            NEAREST, NEAREST_EXACT, BILINEAR and BICUBIC.
        max_size (int, optional): Upper bound for the longer edge of the
            resized image; ``size`` is overruled so the longer edge equals
            ``max_size`` when exceeded. Only valid when ``size`` specifies the
            smaller edge.
        antialias (bool, optional): Whether to apply antialiasing. Only
            affects tensors with bilinear/bicubic modes; PIL images are always
            antialiased on those modes, and other modes ignore it. ``None``
            means ``False`` for tensors and ``True`` for PIL (legacy).
            Default changed from ``None`` to ``True`` in v0.17. Default: ``True``.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)

    # Accept Pillow integer constants; anything else must be an InterpolationMode.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(size, (list, tuple)):
        if len(size) not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        # max_size only makes sense together with a smaller-edge spec.
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, image_height, image_width = get_dimensions(img)
    if isinstance(size, int):
        size = [size]
    output_size = _compute_resized_output_size((image_height, image_width), size, max_size)

    # Fast path: already at the requested size, nothing to do.
    if [image_height, image_width] == output_size:
        return img

    if not isinstance(img, torch.Tensor):
        if antialias is False:
            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.

    Tensors are expected in [..., H, W] layout; the number of supported
    leading dimensions depends on ``padding_mode`` (at most 2 for
    reflect/symmetric, at most 3 for edge, arbitrary for constant).

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Padding on each border. A single int pads
            all borders, a 2-sequence pads left/right then top/bottom, and a
            4-sequence pads left, top, right and bottom respectively. In
            torchscript mode use a sequence of length 1 instead of an int.
        fill (number or tuple): Pixel fill value for constant fill; a tuple of
            length 3 fills R, G, B on PIL images. Tensors accept numbers only.
            Default: 0.
        padding_mode (str): One of "constant", "edge", "reflect" or
            "symmetric":

            - constant: pads with ``fill``.
            - edge: repeats the edge value (5D tensors pad the last 3 dims).
            - reflect: mirrors without repeating the edge value, e.g.
              [1, 2, 3, 4] padded by 2 -> [3, 2, 1, 2, 3, 4, 3, 2].
            - symmetric: mirrors including the edge value, e.g.
              [1, 2, 3, 4] padded by 2 -> [2, 1, 1, 2, 3, 4, 4, 3].

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
    return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the given image at specified location and output size.

    Tensors are expected in [..., H, W] layout. If the image is smaller than
    the requested output size along any edge, it is zero-padded first.

    Args:
        img (PIL Image or Tensor): Image to be cropped; (0, 0) denotes the top
            left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.crop(img, top, left, height, width)
    return F_pil.crop(img, top, left, height, width)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.

    Tensors are expected in [..., H, W] layout. If the image is smaller than
    the requested size along any edge, it is zero-padded first and then
    center-cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. An int
            or single-element sequence is used for both dimensions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)

    # Normalize the size spec to a (height, width) pair.
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        # Pad symmetrically (any odd pixel goes right/bottom) so the requested
        # crop fits; PIL uses fill value 0 as well.
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, image_height, image_width = get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> Tensor:
    """Crop the given image and resize it to desired size.

    Tensors are expected in [..., H, W] layout. Notably used in
    :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped; (0, 0) denotes the top
            left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (InterpolationMode): Interpolation mode; the matching
            Pillow integer constants are accepted too.
            Default: ``InterpolationMode.BILINEAR``. Tensors support only
            NEAREST, NEAREST_EXACT, BILINEAR and BICUBIC.
        antialias (bool, optional): Whether to apply antialiasing; see
            ``resize`` for the full semantics. Default: ``True``.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    # Crop first, then resize the crop to the requested output size.
    cropped = crop(img, top, left, height, width)
    return resize(cropped, size, interpolation, antialias=antialias)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. Tensors are expected
            in [..., H, W] layout, where ... means an arbitrary number of
            leading dimensions.

    Returns:
        PIL Image or Tensor: Horizontally flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.hflip(img)
    return F_pil.hflip(img)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
|
| 675 |
+
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
|
| 676 |
+
|
| 677 |
+
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
|
| 678 |
+
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
|
| 679 |
+
|
| 680 |
+
Args:
|
| 681 |
+
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
|
| 682 |
+
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
|
| 683 |
+
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
|
| 684 |
+
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
|
| 685 |
+
|
| 686 |
+
Returns:
|
| 687 |
+
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
|
| 688 |
+
"""
|
| 689 |
+
if len(startpoints) != 4 or len(endpoints) != 4:
|
| 690 |
+
raise ValueError(
|
| 691 |
+
f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints."
|
| 692 |
+
)
|
| 693 |
+
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64)
|
| 694 |
+
|
| 695 |
+
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
|
| 696 |
+
a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
|
| 697 |
+
a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
|
| 698 |
+
|
| 699 |
+
b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)
|
| 700 |
+
# do least squares in double precision to prevent numerical issues
|
| 701 |
+
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32)
|
| 702 |
+
|
| 703 |
+
output: List[float] = res.tolist()
|
| 704 |
+
return output
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Perform perspective transform of the given image.

    Tensors are expected in [..., H, W] layout, where ... means an arbitrary
    number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): four [x, y] corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): the corresponding four corners of
            the transformed image.
        interpolation (InterpolationMode): Interpolation mode; the matching
            Pillow integer constants are accepted too.
            Default: ``InterpolationMode.BILINEAR``. Tensors support only
            NEAREST and BILINEAR.
        fill (sequence or number, optional): Pixel fill value for the area
            outside the transformed image. In torchscript mode a single value
            is not supported; use a sequence of length 1: ``[value, ]``.

    Returns:
        PIL Image or Tensor: transformed Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    # Accept Pillow integer constants; anything else must be an InterpolationMode.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(img, torch.Tensor):
        return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)

    pil_interpolation = pil_modes_mapping[interpolation]
    return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given image.

    Args:
        img (PIL Image or Tensor): Image to be flipped. Tensors are expected
            in [..., H, W] layout, where ... means an arbitrary number of
            leading dimensions.

    Returns:
        PIL Image or Tensor: Vertically flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.vflip(img)
    return F_pil.vflip(img)
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into its four corners and the central crop.

    Tensors are expected in [..., H, W] layout.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. An int or a
            single-element sequence produces a square (size, size) crop.

    Returns:
        tuple: tuple (tl, tr, bl, br, center) with the top-left, top-right,
        bottom-left, bottom-right and center crops.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)

    # Normalize the size spec to a (height, width) pair.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    # Anchor one crop at each image corner, plus one centered crop.
    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop(img, [crop_height, crop_width])

    return tl, tr, bl, br, center
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
def ten_crop(
    img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Generate ten cropped images from the given image.

    Produces the four corner crops and the central crop, plus the same five
    crops of the flipped image (horizontal flip by default). Tensors are
    expected in [..., H, W] layout.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of the crop. An int or a
            single-element sequence produces a square (size, size) crop.
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
        br_flip, center_flip) — the five crops of the original image followed
        by the five crops of the flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)

    # Normalize the size spec to a (height, width) pair.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)

    # Flip once, then take the same five crops of the flipped image.
    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are expected
            in [..., 1 or 3, H, W] layout, where ... means an arbitrary number
            of leading dimensions.
        brightness_factor (float): How much to adjust the brightness. Any
            non-negative number: 0 gives a black image, 1 the original image,
            and 2 doubles the brightness.

    Returns:
        PIL Image or Tensor: Brightness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_brightness(img, brightness_factor)
    return F_pil.adjust_brightness(img, brightness_factor)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are expected
            in [..., 1 or 3, H, W] layout, where ... means an arbitrary number
            of leading dimensions.
        contrast_factor (float): How much to adjust the contrast. Any
            non-negative number: 0 gives a solid gray image, 1 the original
            image, and 2 doubles the contrast.

    Returns:
        PIL Image or Tensor: Contrast adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_contrast(img, contrast_factor)
    return F_pil.adjust_contrast(img, contrast_factor)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are expected
            in [..., 1 or 3, H, W] layout, where ... means an arbitrary number
            of leading dimensions.
        saturation_factor (float): How much to adjust the saturation. 0 gives
            a black-and-white image, 1 the original image, and 2 doubles the
            saturation.

    Returns:
        PIL Image or Tensor: Saturation adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)

    # Dispatch on backend: tensors and PIL images have separate kernels.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_saturation(img, saturation_factor)
    return F_pil.adjust_saturation(img, saturation_factor)
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust the hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to the original image mode.

    `hue_factor` is the amount of shift in the H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
            Note: the pixel values of the input image have to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of the hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or Tensor: Hue adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    # Tensors are handled by the tensor backend; PIL images by the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_hue(img, hue_factor)
    return F_pil.adjust_hue(img, hue_factor)
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): PIL Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, modes with transparency (alpha channel) are not supported.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 makes the shadows darker, while gamma smaller
            than 1 makes dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image or Tensor: Gamma correction adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    # Route tensors to the tensor backend, PIL images to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_gamma(img, gamma, gain)
    return F_pil.adjust_gamma(img, gamma, gain)
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def _get_inverse_affine_matrix(
|
| 1007 |
+
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
|
| 1008 |
+
) -> List[float]:
|
| 1009 |
+
# Helper method to compute inverse matrix for affine transformation
|
| 1010 |
+
|
| 1011 |
+
# Pillow requires inverse affine transformation matrix:
|
| 1012 |
+
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
|
| 1013 |
+
#
|
| 1014 |
+
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
|
| 1015 |
+
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
|
| 1016 |
+
# RotateScaleShear is rotation with scale and shear matrix
|
| 1017 |
+
#
|
| 1018 |
+
# RotateScaleShear(a, s, (sx, sy)) =
|
| 1019 |
+
# = R(a) * S(s) * SHy(sy) * SHx(sx)
|
| 1020 |
+
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
|
| 1021 |
+
# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
|
| 1022 |
+
# [ 0 , 0 , 1 ]
|
| 1023 |
+
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
|
| 1024 |
+
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
|
| 1025 |
+
# [0, 1 ] [-tan(s), 1]
|
| 1026 |
+
#
|
| 1027 |
+
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
|
| 1028 |
+
|
| 1029 |
+
rot = math.radians(angle)
|
| 1030 |
+
sx = math.radians(shear[0])
|
| 1031 |
+
sy = math.radians(shear[1])
|
| 1032 |
+
|
| 1033 |
+
cx, cy = center
|
| 1034 |
+
tx, ty = translate
|
| 1035 |
+
|
| 1036 |
+
# RSS without scaling
|
| 1037 |
+
a = math.cos(rot - sy) / math.cos(sy)
|
| 1038 |
+
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
|
| 1039 |
+
c = math.sin(rot - sy) / math.cos(sy)
|
| 1040 |
+
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
|
| 1041 |
+
|
| 1042 |
+
if inverted:
|
| 1043 |
+
# Inverted rotation matrix with scale and shear
|
| 1044 |
+
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
|
| 1045 |
+
matrix = [d, -b, 0.0, -c, a, 0.0]
|
| 1046 |
+
matrix = [x / scale for x in matrix]
|
| 1047 |
+
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
|
| 1048 |
+
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
|
| 1049 |
+
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
|
| 1050 |
+
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
|
| 1051 |
+
matrix[2] += cx
|
| 1052 |
+
matrix[5] += cy
|
| 1053 |
+
else:
|
| 1054 |
+
matrix = [a, b, 0.0, c, d, 0.0]
|
| 1055 |
+
matrix = [x * scale for x in matrix]
|
| 1056 |
+
# Apply inverse of center translation: RSS * C^-1
|
| 1057 |
+
matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
|
| 1058 |
+
matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
|
| 1059 |
+
# Apply translation and center : T * C * RSS * C^-1
|
| 1060 |
+
matrix[2] += cx + tx
|
| 1061 |
+
matrix[5] += cy + ty
|
| 1062 |
+
|
| 1063 |
+
return matrix
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (number): rotation angle value in degrees, counter-clockwise.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
    Returns:
        PIL Image or Tensor: Rotated image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rotate)

    # Normalize `interpolation`: a bare int is treated as a legacy Pillow
    # constant and converted; anything else must already be an InterpolationMode.
    # NOTE(review): the isinstance ladder doubles as TorchScript type refinement —
    # keep statement order as-is.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    # PIL path: translate the enum back to the Pillow integer constant and delegate.
    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

    # Tensor path: the tensor backend expects the rotation center relative to
    # the image center; [0.0, 0.0] means "rotate around the image center".
    center_f = [0.0, 0.0]
    if center is not None:
        _, height, width = get_dimensions(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
+
def affine(
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    center: Optional[List[int]] = None,
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to transform.
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while
            the second value corresponds to a shear parallel to the y-axis.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

            .. note::
                In torchscript mode single int/float value is not supported, please use a sequence
                of length 1: ``[value, ]``.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.

    Returns:
        PIL Image or Tensor: Transformed image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)

    # Normalize `interpolation`: a bare int is a legacy Pillow constant.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    # Canonicalize argument shapes so downstream code only sees
    # float angle, list translate, and a 2-element list shear.
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    _, height, width = get_dimensions(img)
    if not isinstance(img, torch.Tensor):
        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [width * 0.5, height * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        # Reuse `height`/`width` computed above; the original code called
        # get_dimensions(img) a second time here, which was redundant.
        # Center values should be in pixel coordinates but translated such that
        # (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
# NOTE: to_grayscale() appears to be a stand-alone functional that is never
# called from the transform classes; it is presumably kept for backward
# compatibility — confirm before removing.
|
| 1244 |
+
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert a PIL image of any mode (RGB, HSV, LAB, etc) to grayscale.
    This transform does not support torch Tensor.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_grayscale)
    # Guard clause: only PIL images are accepted by this functional.
    if not isinstance(img, Image.Image):
        raise TypeError("Input should be PIL Image")
    return F_pil.to_grayscale(img, num_output_channels)
|
| 1265 |
+
|
| 1266 |
+
|
| 1267 |
+
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an RGB image to its grayscale version.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Note:
        Please, note that this method supports only RGB images as input. For inputs in other color spaces,
        please, consider using :meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    # Tensors use the tensor backend; PIL images reuse the PIL grayscale helper.
    if isinstance(img, torch.Tensor):
        return F_t.rgb_to_grayscale(img, num_output_channels)
    return F_pil.to_grayscale(img, num_output_channels)
|
| 1292 |
+
|
| 1293 |
+
|
| 1294 |
+
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase the input Tensor Image with given value.
    This transform does not support PIL Image.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace(bool, optional): For in-place operations. By default, is set False.

    Returns:
        Tensor Image: Erased image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    # Tensor-only functional: delegate directly when given a tensor.
    if isinstance(img, torch.Tensor):
        return F_t.erase(img, i, j, h, w, v, inplace=inplace)
    raise TypeError(f"img should be Tensor Image. Got {type(img)}")
|
| 1316 |
+
|
| 1317 |
+
|
| 1318 |
+
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Performs Gaussian blurring on the image by given kernel

    The convolution will be using reflection padding corresponding to the kernel size, to maintain the input shape.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most one leading dimension.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
            like ``(kx, ky)`` or a single integer for square kernels.

            .. note::
                In torchscript mode kernel_size as single int is not supported, use a sequence of
                length 1: ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
            sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
            same sigma in both X/Y directions. If None, then it is computed using
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            Default, None.

            .. note::
                In torchscript mode sigma as single float is
                not supported, use a sequence of length 1: ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_blur)
    # Validate and canonicalize kernel_size to a 2-element list of odd positive ints.
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    # Default sigma derived from kernel size (same formula as the docstring:
    # 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8 == ksize * 0.15 + 0.35).
    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    # NOTE(review): the `sigma is not None` guard looks redundant after the
    # assignment above, but it likely serves TorchScript Optional type
    # refinement — keep as-is.
    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    # PIL images are converted to tensors, blurred by the tensor backend,
    # then converted back to a PIL image of the original mode.
    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = pil_to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
|
| 1385 |
+
|
| 1386 |
+
|
| 1387 |
+
def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(invert)
    # Tensors go to the tensor backend; PIL images to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.invert(img)
    return F_pil.invert(img)
|
| 1405 |
+
|
| 1406 |
+
|
| 1407 |
+
def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to have its colors posterized.
            If img is torch Tensor, it should be of type torch.uint8, and
            it is expected to be in [..., 1 or 3, H, W] format, where ... means
            it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).

    Raises:
        ValueError: If ``bits`` is outside the range [0, 8].

    Returns:
        PIL Image or Tensor: Posterized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(posterize)
    if not (0 <= bits <= 8):
        # Fixed typo in the original error message ("number if bits").
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)
|
| 1429 |
+
|
| 1430 |
+
|
| 1431 |
+
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal or above this value are inverted.

    Returns:
        PIL Image or Tensor: Solarized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(solarize)
    # Tensor inputs are handled by the tensor backend, PIL inputs by the PIL one.
    if isinstance(img, torch.Tensor):
        return F_t.solarize(img, threshold)
    return F_pil.solarize(img, threshold)
|
| 1449 |
+
|
| 1450 |
+
|
| 1451 |
+
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_sharpness)
    # Dispatch by input type: tensor backend for tensors, PIL backend otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_sharpness(img, sharpness_factor)
    return F_pil.adjust_sharpness(img, sharpness_factor)
|
| 1471 |
+
|
| 1472 |
+
|
| 1473 |
+
def autocontrast(img: Tensor) -> Tensor:
    """Maximize contrast of an image by remapping its
    pixels per channel so that the lowest becomes black and the lightest
    becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(autocontrast)
    # Tensors use the tensor backend; PIL images use the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.autocontrast(img)
    return F_pil.autocontrast(img)
|
| 1493 |
+
|
| 1494 |
+
|
| 1495 |
+
def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image by applying
    a non-linear mapping to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(equalize)
    # Tensors are equalized by the tensor backend, PIL images by the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.equalize(img)
    return F_pil.equalize(img)
|
| 1516 |
+
|
| 1517 |
+
|
| 1518 |
+
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Transform a tensor image with elastic transformations.

    Given alpha and sigma, displacement vectors are generated for all pixels
    based on random offsets: alpha controls the strength and sigma the
    smoothness of the displacements. The displacements are added to an
    identity grid, and the resulting grid is used to ``grid_sample`` from the
    image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            A tensor input is expected in [..., 1 or 3, H, W] format, where
            ``...`` means any number of leading dimensions.
            A PIL input is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``,
            are accepted as well.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(elastic_transform)

    # Accept legacy Pillow integer interpolation constants, but steer callers
    # towards the InterpolationMode enum.
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    # PIL inputs are converted to a tensor, processed, and converted back.
    input_is_pil = not isinstance(img, torch.Tensor)
    if input_is_pil:
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        tensor_img = pil_to_tensor(img)
    else:
        tensor_img = img

    # The displacement field must match the image's spatial size exactly.
    shape = (1,) + tensor_img.shape[-2:] + (2,)
    if shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")

    # TODO: if image shape is [N1, N2, ..., C, H, W] and
    # displacement is [1, H, W, 2] we need to reshape input image
    # such grid_sampler takes internal code for 4D input
    output = F_t.elastic_transform(
        tensor_img,
        displacement,
        interpolation=interpolation.value,
        fill=fill,
    )

    if input_is_pil:
        output = to_pil_image(output, mode=img.mode)
    return output
|
.venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py
ADDED
|
@@ -0,0 +1,2153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numbers
|
| 3 |
+
import random
|
| 4 |
+
import warnings
|
| 5 |
+
from collections.abc import Sequence
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import accimage
|
| 13 |
+
except ImportError:
|
| 14 |
+
accimage = None
|
| 15 |
+
|
| 16 |
+
from ..utils import _log_api_usage_once
|
| 17 |
+
from . import functional as F
|
| 18 |
+
from .functional import _interpolation_modes_from_int, InterpolationMode
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"Compose",
|
| 22 |
+
"ToTensor",
|
| 23 |
+
"PILToTensor",
|
| 24 |
+
"ConvertImageDtype",
|
| 25 |
+
"ToPILImage",
|
| 26 |
+
"Normalize",
|
| 27 |
+
"Resize",
|
| 28 |
+
"CenterCrop",
|
| 29 |
+
"Pad",
|
| 30 |
+
"Lambda",
|
| 31 |
+
"RandomApply",
|
| 32 |
+
"RandomChoice",
|
| 33 |
+
"RandomOrder",
|
| 34 |
+
"RandomCrop",
|
| 35 |
+
"RandomHorizontalFlip",
|
| 36 |
+
"RandomVerticalFlip",
|
| 37 |
+
"RandomResizedCrop",
|
| 38 |
+
"FiveCrop",
|
| 39 |
+
"TenCrop",
|
| 40 |
+
"LinearTransformation",
|
| 41 |
+
"ColorJitter",
|
| 42 |
+
"RandomRotation",
|
| 43 |
+
"RandomAffine",
|
| 44 |
+
"Grayscale",
|
| 45 |
+
"RandomGrayscale",
|
| 46 |
+
"RandomPerspective",
|
| 47 |
+
"RandomErasing",
|
| 48 |
+
"GaussianBlur",
|
| 49 |
+
"InterpolationMode",
|
| 50 |
+
"RandomInvert",
|
| 51 |
+
"RandomPosterize",
|
| 52 |
+
"RandomSolarize",
|
| 53 |
+
"RandomAdjustSharpness",
|
| 54 |
+
"RandomAutocontrast",
|
| 55 |
+
"RandomEqualize",
|
| 56 |
+
"ElasticTransform",
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        # Thread the image through every transform, in order.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self) -> str:
        # One line per composed transform, indented under the class name.
        header = self.__class__.__name__ + "("
        body = "".join(f"\n    {transform}" for transform in self.transforms)
        return header + body + "\n)"
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ToTensor:
    """Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    This transform does not support torchscript.

    A PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] is
    converted to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0] if the PIL Image belongs to one of the modes
    (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has
    dtype = np.uint8. In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # All conversion work is delegated to the functional API.
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class PILToTensor:
    """Convert a PIL Image to a tensor of the same type - this does not scale values.

    This transform does not support torchscript.

    A PIL Image (H x W x C) is converted to a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Conversion (without value scaling) is delegated to the functional API.
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.dtype = dtype

    def forward(self, image):
        # The dtype conversion and value rescaling live in the functional API.
        return F.convert_image_dtype(image, self.dtype)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image

    This transform does not support torchscript.

    A torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C
    is converted to a PIL Image while adjusting the value range depending on
    the ``mode``.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``, ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        # Only advertise the mode when it was explicitly provided.
        mode_part = f"mode={self.mode}" if self.mode is not None else ""
        return f"{self.__class__.__name__}({mode_part})"
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        # Per-channel normalization is delegated to the functional API.
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means a maximum of two leading dimensions

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image. If the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``,
            ``size`` will be overruled so that the longer edge is equal to
            ``max_size``.
            As a result, the smaller edge may be shorter than ``size``. This
            is only supported if ``size`` is an int (or a sequence of length
            1 in torchscript mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True):
        super().__init__()
        _log_api_usage_once(self)
        # Validate size eagerly so misuse fails at construction time.
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size
        # Accept legacy Pillow integer constants by mapping them onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        detail = (
            f"(size={self.size}, interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias})"
        )
        return f"{self.__class__.__name__}{detail}"
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        # _setup_size normalizes int / 1-tuple inputs into an (h, w) pair.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(size={self.size})"
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class Pad(torch.nn.Module):
|
| 393 |
+
"""Pad the given image on all sides with the given "pad" value.
|
| 394 |
+
If the image is torch Tensor, it is expected
|
| 395 |
+
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
|
| 396 |
+
at most 3 leading dimensions for mode edge,
|
| 397 |
+
and an arbitrary number of leading dimensions for mode constant
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
padding (int or sequence): Padding on each border. If a single int is provided this
|
| 401 |
+
is used to pad all borders. If sequence of length 2 is provided this is the padding
|
| 402 |
+
on left/right and top/bottom respectively. If a sequence of length 4 is provided
|
| 403 |
+
this is the padding for the left, top, right and bottom borders respectively.
|
| 404 |
+
|
| 405 |
+
.. note::
|
| 406 |
+
In torchscript mode padding as single int is not supported, use a sequence of
|
| 407 |
+
length 1: ``[padding, ]``.
|
| 408 |
+
fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
|
| 409 |
+
length 3, it is used to fill R, G, B channels respectively.
|
| 410 |
+
This value is only used when the padding_mode is constant.
|
| 411 |
+
Only number is supported for torch Tensor.
|
| 412 |
+
Only int or tuple value is supported for PIL Image.
|
| 413 |
+
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
|
| 414 |
+
Default is constant.
|
| 415 |
+
|
| 416 |
+
- constant: pads with a constant value, this value is specified with fill
|
| 417 |
+
|
| 418 |
+
- edge: pads with the last value at the edge of the image.
|
| 419 |
+
If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
|
| 420 |
+
|
| 421 |
+
- reflect: pads with reflection of image without repeating the last value on the edge.
|
| 422 |
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
| 423 |
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
| 424 |
+
|
| 425 |
+
- symmetric: pads with reflection of image repeating the last value on the edge.
|
| 426 |
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
| 427 |
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
def __init__(self, padding, fill=0, padding_mode="constant"):
|
| 431 |
+
super().__init__()
|
| 432 |
+
_log_api_usage_once(self)
|
| 433 |
+
if not isinstance(padding, (numbers.Number, tuple, list)):
|
| 434 |
+
raise TypeError("Got inappropriate padding arg")
|
| 435 |
+
|
| 436 |
+
if not isinstance(fill, (numbers.Number, tuple, list)):
|
| 437 |
+
raise TypeError("Got inappropriate fill arg")
|
| 438 |
+
|
| 439 |
+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
|
| 440 |
+
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
|
| 441 |
+
|
| 442 |
+
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
|
| 443 |
+
raise ValueError(
|
| 444 |
+
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
self.padding = padding
|
| 448 |
+
self.fill = fill
|
| 449 |
+
self.padding_mode = padding_mode
|
| 450 |
+
|
| 451 |
+
def forward(self, img):
|
| 452 |
+
"""
|
| 453 |
+
Args:
|
| 454 |
+
img (PIL Image or Tensor): Image to be padded.
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
PIL Image or Tensor: Padded image.
|
| 458 |
+
"""
|
| 459 |
+
return F.pad(img, self.padding, self.fill, self.padding_mode)
|
| 460 |
+
|
| 461 |
+
def __repr__(self) -> str:
|
| 462 |
+
return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class Lambda:
|
| 466 |
+
"""Apply a user-defined lambda as a transform. This transform does not support torchscript.
|
| 467 |
+
|
| 468 |
+
Args:
|
| 469 |
+
lambd (function): Lambda/function to be used for transform.
|
| 470 |
+
"""
|
| 471 |
+
|
| 472 |
+
def __init__(self, lambd):
|
| 473 |
+
_log_api_usage_once(self)
|
| 474 |
+
if not callable(lambd):
|
| 475 |
+
raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
|
| 476 |
+
self.lambd = lambd
|
| 477 |
+
|
| 478 |
+
def __call__(self, img):
|
| 479 |
+
return self.lambd(img)
|
| 480 |
+
|
| 481 |
+
def __repr__(self) -> str:
|
| 482 |
+
return f"{self.__class__.__name__}()"
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class RandomTransforms:
|
| 486 |
+
"""Base class for a list of transformations with randomness
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
transforms (sequence): list of transformations
|
| 490 |
+
"""
|
| 491 |
+
|
| 492 |
+
def __init__(self, transforms):
|
| 493 |
+
_log_api_usage_once(self)
|
| 494 |
+
if not isinstance(transforms, Sequence):
|
| 495 |
+
raise TypeError("Argument transforms should be a sequence")
|
| 496 |
+
self.transforms = transforms
|
| 497 |
+
|
| 498 |
+
def __call__(self, *args, **kwargs):
|
| 499 |
+
raise NotImplementedError()
|
| 500 |
+
|
| 501 |
+
def __repr__(self) -> str:
|
| 502 |
+
format_string = self.__class__.__name__ + "("
|
| 503 |
+
for t in self.transforms:
|
| 504 |
+
format_string += "\n"
|
| 505 |
+
format_string += f" {t}"
|
| 506 |
+
format_string += "\n)"
|
| 507 |
+
return format_string
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class RandomApply(torch.nn.Module):
|
| 511 |
+
"""Apply randomly a list of transformations with a given probability.
|
| 512 |
+
|
| 513 |
+
.. note::
|
| 514 |
+
In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
|
| 515 |
+
transforms as shown below:
|
| 516 |
+
|
| 517 |
+
>>> transforms = transforms.RandomApply(torch.nn.ModuleList([
|
| 518 |
+
>>> transforms.ColorJitter(),
|
| 519 |
+
>>> ]), p=0.3)
|
| 520 |
+
>>> scripted_transforms = torch.jit.script(transforms)
|
| 521 |
+
|
| 522 |
+
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
|
| 523 |
+
`lambda` functions or ``PIL.Image``.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
transforms (sequence or torch.nn.Module): list of transformations
|
| 527 |
+
p (float): probability
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
def __init__(self, transforms, p=0.5):
|
| 531 |
+
super().__init__()
|
| 532 |
+
_log_api_usage_once(self)
|
| 533 |
+
self.transforms = transforms
|
| 534 |
+
self.p = p
|
| 535 |
+
|
| 536 |
+
def forward(self, img):
|
| 537 |
+
if self.p < torch.rand(1):
|
| 538 |
+
return img
|
| 539 |
+
for t in self.transforms:
|
| 540 |
+
img = t(img)
|
| 541 |
+
return img
|
| 542 |
+
|
| 543 |
+
def __repr__(self) -> str:
|
| 544 |
+
format_string = self.__class__.__name__ + "("
|
| 545 |
+
format_string += f"\n p={self.p}"
|
| 546 |
+
for t in self.transforms:
|
| 547 |
+
format_string += "\n"
|
| 548 |
+
format_string += f" {t}"
|
| 549 |
+
format_string += "\n)"
|
| 550 |
+
return format_string
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
class RandomOrder(RandomTransforms):
|
| 554 |
+
"""Apply a list of transformations in a random order. This transform does not support torchscript."""
|
| 555 |
+
|
| 556 |
+
def __call__(self, img):
|
| 557 |
+
order = list(range(len(self.transforms)))
|
| 558 |
+
random.shuffle(order)
|
| 559 |
+
for i in order:
|
| 560 |
+
img = self.transforms[i](img)
|
| 561 |
+
return img
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class RandomChoice(RandomTransforms):
|
| 565 |
+
"""Apply single transformation randomly picked from a list. This transform does not support torchscript."""
|
| 566 |
+
|
| 567 |
+
def __init__(self, transforms, p=None):
|
| 568 |
+
super().__init__(transforms)
|
| 569 |
+
if p is not None and not isinstance(p, Sequence):
|
| 570 |
+
raise TypeError("Argument p should be a sequence")
|
| 571 |
+
self.p = p
|
| 572 |
+
|
| 573 |
+
def __call__(self, *args):
|
| 574 |
+
t = random.choices(self.transforms, weights=self.p)[0]
|
| 575 |
+
return t(*args)
|
| 576 |
+
|
| 577 |
+
def __repr__(self) -> str:
|
| 578 |
+
return f"{super().__repr__()}(p={self.p})"
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class RandomCrop(torch.nn.Module):
|
| 582 |
+
"""Crop the given image at a random location.
|
| 583 |
+
If the image is torch Tensor, it is expected
|
| 584 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
|
| 585 |
+
but if non-constant padding is used, the input is expected to have at most 2 leading dimensions
|
| 586 |
+
|
| 587 |
+
Args:
|
| 588 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
| 589 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
| 590 |
+
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
| 591 |
+
padding (int or sequence, optional): Optional padding on each border
|
| 592 |
+
of the image. Default is None. If a single int is provided this
|
| 593 |
+
is used to pad all borders. If sequence of length 2 is provided this is the padding
|
| 594 |
+
on left/right and top/bottom respectively. If a sequence of length 4 is provided
|
| 595 |
+
this is the padding for the left, top, right and bottom borders respectively.
|
| 596 |
+
|
| 597 |
+
.. note::
|
| 598 |
+
In torchscript mode padding as single int is not supported, use a sequence of
|
| 599 |
+
length 1: ``[padding, ]``.
|
| 600 |
+
pad_if_needed (boolean): It will pad the image if smaller than the
|
| 601 |
+
desired size to avoid raising an exception. Since cropping is done
|
| 602 |
+
after padding, the padding seems to be done at a random offset.
|
| 603 |
+
fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
|
| 604 |
+
length 3, it is used to fill R, G, B channels respectively.
|
| 605 |
+
This value is only used when the padding_mode is constant.
|
| 606 |
+
Only number is supported for torch Tensor.
|
| 607 |
+
Only int or tuple value is supported for PIL Image.
|
| 608 |
+
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
|
| 609 |
+
Default is constant.
|
| 610 |
+
|
| 611 |
+
- constant: pads with a constant value, this value is specified with fill
|
| 612 |
+
|
| 613 |
+
- edge: pads with the last value at the edge of the image.
|
| 614 |
+
If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
|
| 615 |
+
|
| 616 |
+
- reflect: pads with reflection of image without repeating the last value on the edge.
|
| 617 |
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
| 618 |
+
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
| 619 |
+
|
| 620 |
+
- symmetric: pads with reflection of image repeating the last value on the edge.
|
| 621 |
+
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
| 622 |
+
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
@staticmethod
|
| 626 |
+
def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
|
| 627 |
+
"""Get parameters for ``crop`` for a random crop.
|
| 628 |
+
|
| 629 |
+
Args:
|
| 630 |
+
img (PIL Image or Tensor): Image to be cropped.
|
| 631 |
+
output_size (tuple): Expected output size of the crop.
|
| 632 |
+
|
| 633 |
+
Returns:
|
| 634 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
|
| 635 |
+
"""
|
| 636 |
+
_, h, w = F.get_dimensions(img)
|
| 637 |
+
th, tw = output_size
|
| 638 |
+
|
| 639 |
+
if h < th or w < tw:
|
| 640 |
+
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
| 641 |
+
|
| 642 |
+
if w == tw and h == th:
|
| 643 |
+
return 0, 0, h, w
|
| 644 |
+
|
| 645 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
| 646 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
| 647 |
+
return i, j, th, tw
|
| 648 |
+
|
| 649 |
+
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
|
| 650 |
+
super().__init__()
|
| 651 |
+
_log_api_usage_once(self)
|
| 652 |
+
|
| 653 |
+
self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
|
| 654 |
+
|
| 655 |
+
self.padding = padding
|
| 656 |
+
self.pad_if_needed = pad_if_needed
|
| 657 |
+
self.fill = fill
|
| 658 |
+
self.padding_mode = padding_mode
|
| 659 |
+
|
| 660 |
+
def forward(self, img):
|
| 661 |
+
"""
|
| 662 |
+
Args:
|
| 663 |
+
img (PIL Image or Tensor): Image to be cropped.
|
| 664 |
+
|
| 665 |
+
Returns:
|
| 666 |
+
PIL Image or Tensor: Cropped image.
|
| 667 |
+
"""
|
| 668 |
+
if self.padding is not None:
|
| 669 |
+
img = F.pad(img, self.padding, self.fill, self.padding_mode)
|
| 670 |
+
|
| 671 |
+
_, height, width = F.get_dimensions(img)
|
| 672 |
+
# pad the width if needed
|
| 673 |
+
if self.pad_if_needed and width < self.size[1]:
|
| 674 |
+
padding = [self.size[1] - width, 0]
|
| 675 |
+
img = F.pad(img, padding, self.fill, self.padding_mode)
|
| 676 |
+
# pad the height if needed
|
| 677 |
+
if self.pad_if_needed and height < self.size[0]:
|
| 678 |
+
padding = [0, self.size[0] - height]
|
| 679 |
+
img = F.pad(img, padding, self.fill, self.padding_mode)
|
| 680 |
+
|
| 681 |
+
i, j, h, w = self.get_params(img, self.size)
|
| 682 |
+
|
| 683 |
+
return F.crop(img, i, j, h, w)
|
| 684 |
+
|
| 685 |
+
def __repr__(self) -> str:
|
| 686 |
+
return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
class RandomHorizontalFlip(torch.nn.Module):
|
| 690 |
+
"""Horizontally flip the given image randomly with a given probability.
|
| 691 |
+
If the image is torch Tensor, it is expected
|
| 692 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
|
| 693 |
+
dimensions
|
| 694 |
+
|
| 695 |
+
Args:
|
| 696 |
+
p (float): probability of the image being flipped. Default value is 0.5
|
| 697 |
+
"""
|
| 698 |
+
|
| 699 |
+
def __init__(self, p=0.5):
|
| 700 |
+
super().__init__()
|
| 701 |
+
_log_api_usage_once(self)
|
| 702 |
+
self.p = p
|
| 703 |
+
|
| 704 |
+
def forward(self, img):
|
| 705 |
+
"""
|
| 706 |
+
Args:
|
| 707 |
+
img (PIL Image or Tensor): Image to be flipped.
|
| 708 |
+
|
| 709 |
+
Returns:
|
| 710 |
+
PIL Image or Tensor: Randomly flipped image.
|
| 711 |
+
"""
|
| 712 |
+
if torch.rand(1) < self.p:
|
| 713 |
+
return F.hflip(img)
|
| 714 |
+
return img
|
| 715 |
+
|
| 716 |
+
def __repr__(self) -> str:
|
| 717 |
+
return f"{self.__class__.__name__}(p={self.p})"
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
class RandomVerticalFlip(torch.nn.Module):
|
| 721 |
+
"""Vertically flip the given image randomly with a given probability.
|
| 722 |
+
If the image is torch Tensor, it is expected
|
| 723 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
|
| 724 |
+
dimensions
|
| 725 |
+
|
| 726 |
+
Args:
|
| 727 |
+
p (float): probability of the image being flipped. Default value is 0.5
|
| 728 |
+
"""
|
| 729 |
+
|
| 730 |
+
def __init__(self, p=0.5):
|
| 731 |
+
super().__init__()
|
| 732 |
+
_log_api_usage_once(self)
|
| 733 |
+
self.p = p
|
| 734 |
+
|
| 735 |
+
def forward(self, img):
|
| 736 |
+
"""
|
| 737 |
+
Args:
|
| 738 |
+
img (PIL Image or Tensor): Image to be flipped.
|
| 739 |
+
|
| 740 |
+
Returns:
|
| 741 |
+
PIL Image or Tensor: Randomly flipped image.
|
| 742 |
+
"""
|
| 743 |
+
if torch.rand(1) < self.p:
|
| 744 |
+
return F.vflip(img)
|
| 745 |
+
return img
|
| 746 |
+
|
| 747 |
+
def __repr__(self) -> str:
|
| 748 |
+
return f"{self.__class__.__name__}(p={self.p})"
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
class RandomPerspective(torch.nn.Module):
|
| 752 |
+
"""Performs a random perspective transformation of the given image with a given probability.
|
| 753 |
+
If the image is torch Tensor, it is expected
|
| 754 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
| 755 |
+
|
| 756 |
+
Args:
|
| 757 |
+
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
|
| 758 |
+
Default is 0.5.
|
| 759 |
+
p (float): probability of the image being transformed. Default is 0.5.
|
| 760 |
+
interpolation (InterpolationMode): Desired interpolation enum defined by
|
| 761 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
| 762 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
| 763 |
+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
|
| 764 |
+
fill (sequence or number): Pixel fill value for the area outside the transformed
|
| 765 |
+
image. Default is ``0``. If given a number, the value is used for all bands respectively.
|
| 766 |
+
"""
|
| 767 |
+
|
| 768 |
+
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
|
| 769 |
+
super().__init__()
|
| 770 |
+
_log_api_usage_once(self)
|
| 771 |
+
self.p = p
|
| 772 |
+
|
| 773 |
+
if isinstance(interpolation, int):
|
| 774 |
+
interpolation = _interpolation_modes_from_int(interpolation)
|
| 775 |
+
|
| 776 |
+
self.interpolation = interpolation
|
| 777 |
+
self.distortion_scale = distortion_scale
|
| 778 |
+
|
| 779 |
+
if fill is None:
|
| 780 |
+
fill = 0
|
| 781 |
+
elif not isinstance(fill, (Sequence, numbers.Number)):
|
| 782 |
+
raise TypeError("Fill should be either a sequence or a number.")
|
| 783 |
+
|
| 784 |
+
self.fill = fill
|
| 785 |
+
|
| 786 |
+
def forward(self, img):
|
| 787 |
+
"""
|
| 788 |
+
Args:
|
| 789 |
+
img (PIL Image or Tensor): Image to be Perspectively transformed.
|
| 790 |
+
|
| 791 |
+
Returns:
|
| 792 |
+
PIL Image or Tensor: Randomly transformed image.
|
| 793 |
+
"""
|
| 794 |
+
|
| 795 |
+
fill = self.fill
|
| 796 |
+
channels, height, width = F.get_dimensions(img)
|
| 797 |
+
if isinstance(img, Tensor):
|
| 798 |
+
if isinstance(fill, (int, float)):
|
| 799 |
+
fill = [float(fill)] * channels
|
| 800 |
+
else:
|
| 801 |
+
fill = [float(f) for f in fill]
|
| 802 |
+
|
| 803 |
+
if torch.rand(1) < self.p:
|
| 804 |
+
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
|
| 805 |
+
return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
|
| 806 |
+
return img
|
| 807 |
+
|
| 808 |
+
@staticmethod
|
| 809 |
+
def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
|
| 810 |
+
"""Get parameters for ``perspective`` for a random perspective transform.
|
| 811 |
+
|
| 812 |
+
Args:
|
| 813 |
+
width (int): width of the image.
|
| 814 |
+
height (int): height of the image.
|
| 815 |
+
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
|
| 816 |
+
|
| 817 |
+
Returns:
|
| 818 |
+
List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
|
| 819 |
+
List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
|
| 820 |
+
"""
|
| 821 |
+
half_height = height // 2
|
| 822 |
+
half_width = width // 2
|
| 823 |
+
topleft = [
|
| 824 |
+
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
|
| 825 |
+
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
|
| 826 |
+
]
|
| 827 |
+
topright = [
|
| 828 |
+
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
|
| 829 |
+
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
|
| 830 |
+
]
|
| 831 |
+
botright = [
|
| 832 |
+
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
|
| 833 |
+
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
|
| 834 |
+
]
|
| 835 |
+
botleft = [
|
| 836 |
+
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
|
| 837 |
+
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
|
| 838 |
+
]
|
| 839 |
+
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
|
| 840 |
+
endpoints = [topleft, topright, botright, botleft]
|
| 841 |
+
return startpoints, endpoints
|
| 842 |
+
|
| 843 |
+
def __repr__(self) -> str:
|
| 844 |
+
return f"{self.__class__.__name__}(p={self.p})"
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
class RandomResizedCrop(torch.nn.Module):
|
| 848 |
+
"""Crop a random portion of image and resize it to a given size.
|
| 849 |
+
|
| 850 |
+
If the image is torch Tensor, it is expected
|
| 851 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
|
| 852 |
+
|
| 853 |
+
A crop of the original image is made: the crop has a random area (H * W)
|
| 854 |
+
and a random aspect ratio. This crop is finally resized to the given
|
| 855 |
+
size. This is popularly used to train the Inception networks.
|
| 856 |
+
|
| 857 |
+
Args:
|
| 858 |
+
size (int or sequence): expected output size of the crop, for each edge. If size is an
|
| 859 |
+
int instead of sequence like (h, w), a square output size ``(size, size)`` is
|
| 860 |
+
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
| 861 |
+
|
| 862 |
+
.. note::
|
| 863 |
+
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
|
| 864 |
+
scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
|
| 865 |
+
before resizing. The scale is defined with respect to the area of the original image.
|
| 866 |
+
ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
|
| 867 |
+
resizing.
|
| 868 |
+
interpolation (InterpolationMode): Desired interpolation enum defined by
|
| 869 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
| 870 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
|
| 871 |
+
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
|
| 872 |
+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
|
| 873 |
+
antialias (bool, optional): Whether to apply antialiasing.
|
| 874 |
+
It only affects **tensors** with bilinear or bicubic modes and it is
|
| 875 |
+
ignored otherwise: on PIL images, antialiasing is always applied on
|
| 876 |
+
bilinear or bicubic modes; on other modes (for PIL images and
|
| 877 |
+
tensors), antialiasing makes no sense and this parameter is ignored.
|
| 878 |
+
Possible values are:
|
| 879 |
+
|
| 880 |
+
- ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
|
| 881 |
+
Other mode aren't affected. This is probably what you want to use.
|
| 882 |
+
- ``False``: will not apply antialiasing for tensors on any mode. PIL
|
| 883 |
+
images are still antialiased on bilinear or bicubic modes, because
|
| 884 |
+
PIL doesn't support no antialias.
|
| 885 |
+
- ``None``: equivalent to ``False`` for tensors and ``True`` for
|
| 886 |
+
PIL images. This value exists for legacy reasons and you probably
|
| 887 |
+
don't want to use it unless you really know what you are doing.
|
| 888 |
+
|
| 889 |
+
The default value changed from ``None`` to ``True`` in
|
| 890 |
+
v0.17, for the PIL and Tensor backends to be consistent.
|
| 891 |
+
"""
|
| 892 |
+
|
| 893 |
+
def __init__(
|
| 894 |
+
self,
|
| 895 |
+
size,
|
| 896 |
+
scale=(0.08, 1.0),
|
| 897 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
| 898 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 899 |
+
antialias: Optional[bool] = True,
|
| 900 |
+
):
|
| 901 |
+
super().__init__()
|
| 902 |
+
_log_api_usage_once(self)
|
| 903 |
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
|
| 904 |
+
|
| 905 |
+
if not isinstance(scale, Sequence):
|
| 906 |
+
raise TypeError("Scale should be a sequence")
|
| 907 |
+
if not isinstance(ratio, Sequence):
|
| 908 |
+
raise TypeError("Ratio should be a sequence")
|
| 909 |
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
| 910 |
+
warnings.warn("Scale and ratio should be of kind (min, max)")
|
| 911 |
+
|
| 912 |
+
if isinstance(interpolation, int):
|
| 913 |
+
interpolation = _interpolation_modes_from_int(interpolation)
|
| 914 |
+
|
| 915 |
+
self.interpolation = interpolation
|
| 916 |
+
self.antialias = antialias
|
| 917 |
+
self.scale = scale
|
| 918 |
+
self.ratio = ratio
|
| 919 |
+
|
| 920 |
+
@staticmethod
|
| 921 |
+
def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
|
| 922 |
+
"""Get parameters for ``crop`` for a random sized crop.
|
| 923 |
+
|
| 924 |
+
Args:
|
| 925 |
+
img (PIL Image or Tensor): Input image.
|
| 926 |
+
scale (list): range of scale of the origin size cropped
|
| 927 |
+
ratio (list): range of aspect ratio of the origin aspect ratio cropped
|
| 928 |
+
|
| 929 |
+
Returns:
|
| 930 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
| 931 |
+
sized crop.
|
| 932 |
+
"""
|
| 933 |
+
_, height, width = F.get_dimensions(img)
|
| 934 |
+
area = height * width
|
| 935 |
+
|
| 936 |
+
log_ratio = torch.log(torch.tensor(ratio))
|
| 937 |
+
for _ in range(10):
|
| 938 |
+
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
|
| 939 |
+
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
|
| 940 |
+
|
| 941 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
| 942 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
| 943 |
+
|
| 944 |
+
if 0 < w <= width and 0 < h <= height:
|
| 945 |
+
i = torch.randint(0, height - h + 1, size=(1,)).item()
|
| 946 |
+
j = torch.randint(0, width - w + 1, size=(1,)).item()
|
| 947 |
+
return i, j, h, w
|
| 948 |
+
|
| 949 |
+
# Fallback to central crop
|
| 950 |
+
in_ratio = float(width) / float(height)
|
| 951 |
+
if in_ratio < min(ratio):
|
| 952 |
+
w = width
|
| 953 |
+
h = int(round(w / min(ratio)))
|
| 954 |
+
elif in_ratio > max(ratio):
|
| 955 |
+
h = height
|
| 956 |
+
w = int(round(h * max(ratio)))
|
| 957 |
+
else: # whole image
|
| 958 |
+
w = width
|
| 959 |
+
h = height
|
| 960 |
+
i = (height - h) // 2
|
| 961 |
+
j = (width - w) // 2
|
| 962 |
+
return i, j, h, w
|
| 963 |
+
|
| 964 |
+
def forward(self, img):
|
| 965 |
+
"""
|
| 966 |
+
Args:
|
| 967 |
+
img (PIL Image or Tensor): Image to be cropped and resized.
|
| 968 |
+
|
| 969 |
+
Returns:
|
| 970 |
+
PIL Image or Tensor: Randomly cropped and resized image.
|
| 971 |
+
"""
|
| 972 |
+
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
| 973 |
+
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
|
| 974 |
+
|
| 975 |
+
def __repr__(self) -> str:
|
| 976 |
+
interpolate_str = self.interpolation.value
|
| 977 |
+
format_string = self.__class__.__name__ + f"(size={self.size}"
|
| 978 |
+
format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
|
| 979 |
+
format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
|
| 980 |
+
format_string += f", interpolation={interpolate_str}"
|
| 981 |
+
format_string += f", antialias={self.antialias})"
|
| 982 |
+
return format_string
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class FiveCrop(torch.nn.Module):
|
| 986 |
+
"""Crop the given image into four corners and the central crop.
|
| 987 |
+
If the image is torch Tensor, it is expected
|
| 988 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading
|
| 989 |
+
dimensions
|
| 990 |
+
|
| 991 |
+
.. Note::
|
| 992 |
+
This transform returns a tuple of images and there may be a mismatch in the number of
|
| 993 |
+
inputs and targets your Dataset returns. See below for an example of how to deal with
|
| 994 |
+
this.
|
| 995 |
+
|
| 996 |
+
Args:
|
| 997 |
+
size (sequence or int): Desired output size of the crop. If size is an ``int``
|
| 998 |
+
instead of sequence like (h, w), a square crop of size (size, size) is made.
|
| 999 |
+
If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
| 1000 |
+
|
| 1001 |
+
Example:
|
| 1002 |
+
>>> transform = Compose([
|
| 1003 |
+
>>> FiveCrop(size), # this is a list of PIL Images
|
| 1004 |
+
>>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
|
| 1005 |
+
>>> ])
|
| 1006 |
+
>>> #In your test loop you can do the following:
|
| 1007 |
+
>>> input, target = batch # input is a 5d tensor, target is 2d
|
| 1008 |
+
>>> bs, ncrops, c, h, w = input.size()
|
| 1009 |
+
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
|
| 1010 |
+
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
|
| 1011 |
+
"""
|
| 1012 |
+
|
| 1013 |
+
def __init__(self, size):
|
| 1014 |
+
super().__init__()
|
| 1015 |
+
_log_api_usage_once(self)
|
| 1016 |
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
|
| 1017 |
+
|
| 1018 |
+
def forward(self, img):
|
| 1019 |
+
"""
|
| 1020 |
+
Args:
|
| 1021 |
+
img (PIL Image or Tensor): Image to be cropped.
|
| 1022 |
+
|
| 1023 |
+
Returns:
|
| 1024 |
+
tuple of 5 images. Image can be PIL Image or Tensor
|
| 1025 |
+
"""
|
| 1026 |
+
return F.five_crop(img, self.size)
|
| 1027 |
+
|
| 1028 |
+
def __repr__(self) -> str:
|
| 1029 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
class TenCrop(torch.nn.Module):
    """Produce ten crops of an image: the four corner crops and the center
    crop, plus a flipped copy of each of those five (horizontal flipping by
    default, vertical when ``vertical_flip`` is set).

    If the image is a torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    .. Note::
        This transform returns a tuple of images, so the number of inputs may
        no longer match the number of targets your Dataset returns. See the
        example below for one way of dealing with that.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as
            (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Example:
        >>> transform = Compose([
        >>>     TenCrop(size),  # this is a tuple of PIL Images
        >>>     Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops]))  # returns a 4D tensor
        >>> ])
        >>> # In your test loop you can do the following:
        >>> input, target = batch  # input is a 5d tensor, target is 2d
        >>> bs, ncrops, c, h, w = input.size()
        >>> result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        _log_api_usage_once(self)
        self.vertical_flip = vertical_flip
        # Normalize an int / 1-length sequence into an (h, w) pair.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        crops = F.ten_crop(img, self.size, self.vertical_flip)
        return crops

    def __repr__(self) -> str:
        return "{}(size={}, vertical_flip={})".format(self.__class__.__name__, self.size, self.vertical_flip)
class LinearTransformation(torch.nn.Module):
    """Apply an offline-computed linear map (e.g. a whitening transform) to a
    tensor image.

    This transform does not support PIL Image. The input tensor is flattened,
    ``mean_vector`` is subtracted from it, the result is multiplied by
    ``transformation_matrix``, and the product is reshaped back to the input's
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        rows = transformation_matrix.size(0)
        cols = transformation_matrix.size(1)

        # The matrix must be square and agree with the mean vector's length,
        # and both tensors must live on one device with one dtype.
        if rows != cols:
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )
        if mean_vector.size(0) != rows:
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )
        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )
        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        orig_shape = tensor.shape
        flat_dim = orig_shape[-3] * orig_shape[-2] * orig_shape[-1]
        if flat_dim != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                f"[{orig_shape[-3]} x {orig_shape[-2]} x {orig_shape[-1]}] != "
                f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        # Flatten, center, apply the matrix, then restore the original shape.
        centered = tensor.view(-1, flat_dim) - self.mean_vector
        matrix = self.transformation_matrix.to(centered.dtype)
        return torch.mm(centered, matrix).view(orig_shape)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # Each attribute is either a normalized (min, max) tuple or None when
        # the requested jitter is a no-op (see _check_input).
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        # Hue is an additive shift centered at 0 (not a multiplicative factor
        # centered at 1), may be negative, and is bounded to [-0.5, 0.5], so
        # the clip-at-zero behavior is disabled for it.
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Validate one jitter argument and normalize it to a (min, max) tuple.

        A single number ``v`` becomes ``[center - v, center + v]`` (with the
        lower bound clipped at 0 when ``clip_first_on_zero``). A 2-sequence is
        taken as an explicit range. Returns ``None`` when the range collapses
        to ``(center, center)``, signalling "do not apply this adjustment".

        Raises:
            ValueError: if a single number is negative, or the range falls
                outside ``bound`` or is not ordered (min <= max).
            TypeError: if ``value`` is neither a number nor a 2-sequence.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        # Random application order of the four adjustments. NOTE: the RNG is
        # consumed in a fixed sequence (randperm, then the four uniform draws),
        # so reordering these statements would change results for a fixed seed.
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Apply each enabled adjustment (factor is not None) in the randomly
        # drawn order; disabled adjustments are skipped.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s
class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        # Accept legacy Pillow integer constants and map them onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        # A scalar is expanded to the symmetric range (-degrees, +degrees).
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

        self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        # Single uniform draw from [degrees[0], degrees[1]].
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        # The tensor backend expects fill as a per-channel list of floats;
        # PIL inputs keep the value as given.
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        # Only non-default optional arguments (center, fill) are shown.
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", expand={self.expand}"
        if self.center is not None:
            format_string += f", center={self.center}"
        if self.fill is not None:
            format_string += f", fill={self.fill}"
        format_string += ")"
        return format_string
class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)

        # Accept legacy Pillow integer constants and map them onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        # translate entries are fractions of image size, hence must be in [0, 1].
        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        # shear may be a number, a 2-sequence (x only) or a 4-sequence (x and y).
        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Returns:
            params to be passed to the affine transformation
        """
        # NOTE: the RNG draws happen in a fixed order (angle, translation,
        # scale, shear); reordering them would change the sampled values for a
        # fixed seed.
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            # translate holds fractions; img_size is (width, height) here.
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        # y-shear is only drawn when a 4-length shear range was supplied.
        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        # The tensor backend expects fill as a per-channel list of floats;
        # PIL inputs keep the value as given.
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        # Only non-default arguments are included in the repr.
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        s += f", translate={self.translate}" if self.translate is not None else ""
        s += f", scale={self.scale}" if self.scale is not None else ""
        s += f", shear={self.shear}" if self.shear is not None else ""
        s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
        s += f", fill={self.fill}" if self.fill != 0 else ""
        s += f", center={self.center}" if self.center is not None else ""
        s += ")"

        return s
class Grayscale(torch.nn.Module):
    """Convert an image to grayscale.

    If the image is a torch Tensor, it is expected to have [..., 3, H, W]
    shape, where ... means an arbitrary number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        gray = F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
        return gray

    def __repr__(self) -> str:
        return "{}(num_output_channels={})".format(self.__class__.__name__, self.num_output_channels)
class RandomGrayscale(torch.nn.Module):
    """Randomly convert an image to grayscale with a probability of p (default 0.1).

    If the image is a torch Tensor, it is expected to have [..., 3, H, W]
    shape, where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        # Preserve the input's channel count so the output shape is unchanged.
        channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) >= self.p:
            return img
        return F.rgb_to_grayscale(img, num_output_channels=channels)

    def __repr__(self) -> str:
        return "{}(p={})".format(self.__class__.__name__, self.p)
class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
        inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        # Unordered (min, max) pairs only warn, they do not raise.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        # Aspect ratio is sampled log-uniformly so that e.g. 1/r and r are
        # equally likely.
        log_ratio = torch.log(torch.tensor(ratio))
        # Rejection sampling: up to 10 attempts to find a box that fits
        # strictly inside the image.
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                # "random" mode: fill with standard-normal noise per pixel.
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                # Broadcast the per-channel (or single) value over the box.
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        # (a full-image "erase" with v == img leaves the input unchanged).
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                # "random" (the only accepted str) is encoded as None for get_params.
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            # A sequence value must be a single number or one entry per channel.
            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
        return s
class GaussianBlur(torch.nn.Module):
    """Blur an image with a Gaussian kernel whose sigma is chosen at random.

    If the image is a torch Tensor, it is expected to have [..., C, H, W] shape,
    where ... means at most one leading dimension.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation used to build
            the blurring kernel. A single float fixes sigma; a (min, max) pair makes
            sigma a uniform random draw from that range on every call.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        _log_api_usage_once(self)
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        # Gaussian kernels must have an odd, positive extent in each dimension.
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        # Normalize sigma to a (min, max) pair, validating as we go.
        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            self.sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
            self.sigma = sigma
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        sample = torch.empty(1)
        return sample.uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        # The same sigma is used along both spatial axes.
        chosen_sigma = self.get_params(self.sigma[0], self.sigma[1])
        return F.gaussian_blur(img, self.kernel_size, [chosen_sigma, chosen_sigma])

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
|
| 1817 |
+
|
| 1818 |
+
|
| 1819 |
+
def _setup_size(size, error_msg):
|
| 1820 |
+
if isinstance(size, numbers.Number):
|
| 1821 |
+
return int(size), int(size)
|
| 1822 |
+
|
| 1823 |
+
if isinstance(size, Sequence) and len(size) == 1:
|
| 1824 |
+
return size[0], size[0]
|
| 1825 |
+
|
| 1826 |
+
if len(size) != 2:
|
| 1827 |
+
raise ValueError(error_msg)
|
| 1828 |
+
|
| 1829 |
+
return size
|
| 1830 |
+
|
| 1831 |
+
|
| 1832 |
+
def _check_sequence_input(x, name, req_sizes):
|
| 1833 |
+
msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
|
| 1834 |
+
if not isinstance(x, Sequence):
|
| 1835 |
+
raise TypeError(f"{name} should be a sequence of length {msg}.")
|
| 1836 |
+
if len(x) not in req_sizes:
|
| 1837 |
+
raise ValueError(f"{name} should be a sequence of length {msg}.")
|
| 1838 |
+
|
| 1839 |
+
|
| 1840 |
+
def _setup_angle(x, name, req_sizes=(2,)):
|
| 1841 |
+
if isinstance(x, numbers.Number):
|
| 1842 |
+
if x < 0:
|
| 1843 |
+
raise ValueError(f"If {name} is a single number, it must be positive.")
|
| 1844 |
+
x = [-x, x]
|
| 1845 |
+
else:
|
| 1846 |
+
_check_sequence_input(x, name, req_sizes)
|
| 1847 |
+
|
| 1848 |
+
return [float(d) for d in x]
|
| 1849 |
+
|
| 1850 |
+
|
| 1851 |
+
class RandomInvert(torch.nn.Module):
    """Invert the colors of an image with a given probability.

    Tensor inputs are expected in [..., 1 or 3, H, W] layout, where ... means an
    arbitrary number of leading dimensions. PIL Images must be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        # Draw once per call; skip the inversion when the draw misses [0, p).
        if torch.rand(1).item() >= self.p:
            return img
        return F.invert(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
|
| 1880 |
+
|
| 1881 |
+
|
| 1882 |
+
class RandomPosterize(torch.nn.Module):
    """Randomly posterize an image by keeping only ``bits`` bits per color channel.

    Tensor inputs must be of type torch.uint8 and have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions. PIL Images must be
    in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        apply_op = torch.rand(1).item() < self.p
        return F.posterize(img, self.bits) if apply_op else img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"
|
| 1913 |
+
|
| 1914 |
+
|
| 1915 |
+
class RandomSolarize(torch.nn.Module):
    """Randomly solarize an image by inverting all pixel values above a threshold.

    Tensor inputs are expected in [..., 1 or 3, H, W] format, where ... means an
    arbitrary number of leading dimensions. PIL Images must be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        # Pass the input through unchanged when the random draw misses [0, p).
        if torch.rand(1).item() >= self.p:
            return img
        return F.solarize(img, self.threshold)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"
|
| 1946 |
+
|
| 1947 |
+
|
| 1948 |
+
class RandomAdjustSharpness(torch.nn.Module):
    """Randomly adjust the sharpness of an image with a given probability.

    Tensor inputs are expected to have [..., 1 or 3, H, W] shape, where ... means
    an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        if torch.rand(1).item() >= self.p:
            return img
        return F.adjust_sharpness(img, self.sharpness_factor)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"
|
| 1979 |
+
|
| 1980 |
+
|
| 1981 |
+
class RandomAutocontrast(torch.nn.Module):
    """Apply autocontrast to an image with a given probability.

    Tensor inputs are expected to have [..., 1 or 3, H, W] shape, where ... means
    an arbitrary number of leading dimensions. PIL Images must be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        return F.autocontrast(img) if torch.rand(1).item() < self.p else img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
|
| 2010 |
+
|
| 2011 |
+
|
| 2012 |
+
class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of an image with a given probability.

    Tensor inputs are expected to have [..., 1 or 3, H, W] shape, where ... means
    an arbitrary number of leading dimensions. PIL Images must be in mode "P",
    "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        # Identity unless the per-call draw lands inside [0, p).
        if torch.rand(1).item() >= self.p:
            return img
        return F.equalize(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
|
| 2041 |
+
|
| 2042 |
+
|
| 2043 |
+
class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        # Validate alpha first, then normalize it to a 2-element list below.
        if not isinstance(alpha, (float, Sequence)):
            raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
        if isinstance(alpha, Sequence) and len(alpha) != 2:
            raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
        if isinstance(alpha, Sequence):
            for element in alpha:
                if not isinstance(element, float):
                    raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")

        # Normalize: a scalar (or length-1 sequence) is duplicated per spatial axis.
        if isinstance(alpha, float):
            alpha = [float(alpha), float(alpha)]
        if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
            alpha = [alpha[0], alpha[0]]

        self.alpha = alpha

        # Same validation/normalization scheme for sigma.
        if not isinstance(sigma, (float, Sequence)):
            raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
        if isinstance(sigma, Sequence) and len(sigma) != 2:
            raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
        if isinstance(sigma, Sequence):
            for element in sigma:
                if not isinstance(element, float):
                    raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")

        if isinstance(sigma, float):
            sigma = [float(sigma), float(sigma)]
        if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
            sigma = [sigma[0], sigma[0]]

        self.sigma = sigma

        # Legacy PIL integer constants are mapped onto the InterpolationMode enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        # fill is stored as a list of floats (one entry means "all bands").
        if isinstance(fill, (int, float)):
            fill = [float(fill)]
        elif isinstance(fill, (list, tuple)):
            fill = [float(f) for f in fill]
        else:
            raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        """Build a random displacement field of shape (1, H, W, 2) for grid_sample.

        Offsets are drawn uniformly in [-1, 1), optionally smoothed with a
        Gaussian blur whose kernel size is derived from sigma, then scaled by
        alpha and normalized by the image size along each axis.
        """
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            # Kernel extent derived from sigma (int(8 * sigma + 1)).
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        _, height, width = F.get_dimensions(tensor)
        # A fresh displacement field is sampled on every call.
        displacement = self.get_params(self.alpha, self.sigma, [height, width])
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self):
        format_string = self.__class__.__name__
        format_string += f"(alpha={self.alpha}"
        format_string += f", sigma={self.sigma}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", fill={self.fill})"
        return format_string
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
|
| 2 |
+
|
| 3 |
+
from . import functional # usort: skip
|
| 4 |
+
|
| 5 |
+
from ._transform import Transform # usort: skip
|
| 6 |
+
|
| 7 |
+
from ._augment import CutMix, JPEG, MixUp, RandomErasing
|
| 8 |
+
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
|
| 9 |
+
from ._color import (
|
| 10 |
+
ColorJitter,
|
| 11 |
+
Grayscale,
|
| 12 |
+
RandomAdjustSharpness,
|
| 13 |
+
RandomAutocontrast,
|
| 14 |
+
RandomChannelPermutation,
|
| 15 |
+
RandomEqualize,
|
| 16 |
+
RandomGrayscale,
|
| 17 |
+
RandomInvert,
|
| 18 |
+
RandomPhotometricDistort,
|
| 19 |
+
RandomPosterize,
|
| 20 |
+
RandomSolarize,
|
| 21 |
+
RGB,
|
| 22 |
+
)
|
| 23 |
+
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
|
| 24 |
+
from ._geometry import (
|
| 25 |
+
CenterCrop,
|
| 26 |
+
ElasticTransform,
|
| 27 |
+
FiveCrop,
|
| 28 |
+
Pad,
|
| 29 |
+
RandomAffine,
|
| 30 |
+
RandomCrop,
|
| 31 |
+
RandomHorizontalFlip,
|
| 32 |
+
RandomIoUCrop,
|
| 33 |
+
RandomPerspective,
|
| 34 |
+
RandomResize,
|
| 35 |
+
RandomResizedCrop,
|
| 36 |
+
RandomRotation,
|
| 37 |
+
RandomShortestSize,
|
| 38 |
+
RandomVerticalFlip,
|
| 39 |
+
RandomZoomOut,
|
| 40 |
+
Resize,
|
| 41 |
+
ScaleJitter,
|
| 42 |
+
TenCrop,
|
| 43 |
+
)
|
| 44 |
+
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat
|
| 45 |
+
from ._misc import (
|
| 46 |
+
ConvertImageDtype,
|
| 47 |
+
GaussianBlur,
|
| 48 |
+
GaussianNoise,
|
| 49 |
+
Identity,
|
| 50 |
+
Lambda,
|
| 51 |
+
LinearTransformation,
|
| 52 |
+
Normalize,
|
| 53 |
+
SanitizeBoundingBoxes,
|
| 54 |
+
ToDtype,
|
| 55 |
+
)
|
| 56 |
+
from ._temporal import UniformTemporalSubsample
|
| 57 |
+
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
|
| 58 |
+
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
|
| 59 |
+
|
| 60 |
+
from ._deprecated import ToTensor # usort: skip
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc
ADDED
|
Binary file (25.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc
ADDED
|
Binary file (39.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc
ADDED
|
Binary file (27.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc
ADDED
|
Binary file (3.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc
ADDED
|
Binary file (29.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc
ADDED
|
Binary file (9.84 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc
ADDED
|
Binary file (5.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numbers
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
| 5 |
+
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn.functional import one_hot
|
| 9 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten
|
| 10 |
+
from torchvision import transforms as _transforms, tv_tensors
|
| 11 |
+
from torchvision.transforms.v2 import functional as F
|
| 12 |
+
|
| 13 |
+
from ._transform import _RandomApplyTransform, Transform
|
| 14 |
+
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RandomErasing(_RandomApplyTransform):
    """Randomly select a rectangle region in the input image or video and erase its pixels.

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed.
        scale (tuple of float, optional): range of proportion of erased area against input image.
        ratio (tuple of float, optional): range of aspect ratio of erased area.
        value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
        inplace (bool, optional): boolean to make this transform inplace. Default set to False.

    Returns:
        Erased input.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
    """

    _v1_transform_cls = _transforms.RandomErasing

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The v1 transform encodes "random per-pixel values" as the string "random",
        # while this class stores it as value=None; translate back when exporting.
        return dict(
            super()._extract_params_for_v1_transform(),
            value="random" if self.value is None else self.value,
        )

    def __init__(
        self,
        p: float = 0.5,
        scale: Sequence[float] = (0.02, 0.33),
        ratio: Sequence[float] = (0.3, 3.3),
        value: float = 0.0,
        inplace: bool = False,
    ):
        super().__init__(p=p)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        # Misordered ranges only warn (matching the v1 transform); an out-of-[0, 1]
        # scale is a hard error.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        self.scale = scale
        self.ratio = ratio
        # Normalize value: list of floats, or None meaning "random per-pixel noise".
        if isinstance(value, (int, float)):
            self.value = [float(value)]
        elif isinstance(value, str):
            self.value = None
        elif isinstance(value, (list, tuple)):
            self.value = [float(v) for v in value]
        else:
            self.value = value
        self.inplace = inplace

        # Pre-compute log(ratio) once so aspect ratios can be sampled
        # log-uniformly in _get_params.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Bounding boxes and masks are currently passed through unmodified;
        # warn so users know this behavior may change.
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the erase rectangle (i, j, h, w) and its fill values v.

        Makes up to 10 attempts to find a rectangle that fits inside the image;
        on failure, returns v=None which _transform treats as a no-op.
        """
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and not (len(self.value) in (1, img_c)):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        area = img_h * img_w

        log_ratio = self._log_ratio
        for _ in range(10):
            # Target area and aspect ratio for this attempt.
            erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            # Reject rectangles that do not fit strictly inside the image.
            if not (h < img_h and w < img_w):
                continue

            if self.value is None:
                # value=None means fill with per-pixel Gaussian noise.
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(self.value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            break
        else:
            # All attempts failed: signal "do nothing" via v=None.
            i, j, h, w, v = 0, 0, img_h, img_w, None

        return dict(i=i, j=j, h=h, w=w, v=v)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # v=None (sampling failed) leaves the input untouched.
        if params["v"] is not None:
            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)

        return inpt
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class _BaseMixUpCutMix(Transform):
    """Shared machinery for :class:`MixUp` and :class:`CutMix`.

    Handles label validation, Beta-distribution sampling, and dispatching of
    the per-input ``_transform`` over the flattened (pytree) inputs.
    Subclasses implement ``_get_params`` and ``_transform``.
    """

    def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
        super().__init__()
        self.alpha = float(alpha)
        # Symmetric Beta(alpha, alpha) from which the mixing coefficient lam is drawn.
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        """Validate labels, sample params, and apply ``_transform`` to each eligible input."""
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        # These input types cannot be mixed batch-wise, so reject them up front.
        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        if labels.ndim not in (1, 2):
            raise ValueError(
                f"labels should be index based with shape (batch_size,) "
                f"or probability based with shape (batch_size, num_classes), "
                f"but got a tensor of shape {labels.shape} instead."
            )
        if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
            raise ValueError(
                f"When passing 2D labels, "
                f"the number of elements in last dimension must match num_classes: "
                f"{labels.shape[-1]} != {self.num_classes}. "
                f"You can Leave num_classes to None."
            )
        if labels.ndim == 1 and self.num_classes is None:
            raise ValueError("num_classes must be passed if the labels are index-based (1D)")

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
            ),
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
        needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        """Ensure the input is batched (4D image / 5D video) and matches the label batch size."""
        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        """Return the convex combination ``lam * label + (1 - lam) * label.roll(1, 0)``.

        1D index labels are one-hot encoded first; integer labels are promoted
        to float so the in-place arithmetic below is valid.
        """
        if label.ndim == 1:
            label = one_hot(label, num_classes=self.num_classes)  # type: ignore[arg-type]
        if not label.dtype.is_floating_point:
            label = label.float()
        # roll(1, 0) pairs each sample with its predecessor in the batch.
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class MixUp(_BaseMixUpCutMix):
    """Apply MixUp to the provided batch of images and labels.

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention.)

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # A single mixing coefficient is shared by the whole batch.
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)
        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            # Pixel-wise convex combination with the batch rolled by one sample,
            # mirroring the label mixing in _mixup_label.
            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                output = tv_tensors.wrap(output, like=inpt)

            return output
        else:
            # Anything else (e.g. unrelated tensors) is passed through untouched.
            return inpt
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class CutMix(_BaseMixUpCutMix):
    """Apply CutMix to the provided batch of images and labels.

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention.)

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the cut box and the lambda adjusted to the box's actual area."""
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_size(flat_inputs)

        # Box center, drawn uniformly over the image.
        r_x = torch.randint(W, size=(1,))
        r_y = torch.randint(H, size=(1,))

        # Half-extents chosen so the box covers roughly a (1 - lam) fraction of the area.
        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # Clamp to the image bounds; the box may be truncated at the borders.
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        # Recompute lambda from the clamped box so labels match the mixed pixels exactly.
        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=params["lam_adjusted"])
        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            # Paste the box region from the rolled batch (each sample's predecessor).
            x1, y1, x2, y2 = params["box"]
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                output = tv_tensors.wrap(output, like=inpt)

            return output
        else:
            # Anything else (e.g. unrelated tensors) is passed through untouched.
            return inpt
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class JPEG(Transform):
    """Simulate JPEG artifacts by compressing and decompressing the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.
    """

    def __init__(self, quality: Union[int, Sequence[int]]):
        super().__init__()
        # A single int is treated as the degenerate range (q, q).
        if isinstance(quality, int):
            quality = [quality, quality]
        else:
            _check_sequence_input(quality, "quality", req_sizes=(2,))

        lo, hi = quality[0], quality[1]
        if not (1 <= lo <= hi <= 100 and isinstance(lo, int) and isinstance(hi, int)):
            raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")

        self.quality = quality

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Draw one quality value per call; randint's upper bound is exclusive,
        # hence the + 1 to make the range inclusive of both ends.
        lo, hi = self.quality
        return dict(quality=torch.randint(lo, hi + 1, ()).item())

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
|
| 3 |
+
|
| 4 |
+
import PIL.Image
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
|
| 8 |
+
from torchvision import transforms as _transforms, tv_tensors
|
| 9 |
+
from torchvision.transforms import _functional_tensor as _FT
|
| 10 |
+
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
|
| 11 |
+
from torchvision.transforms.v2.functional._geometry import _check_interpolation
|
| 12 |
+
from torchvision.transforms.v2.functional._meta import get_size
|
| 13 |
+
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
|
| 14 |
+
|
| 15 |
+
from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Union of every input type the auto-augment transforms operate on.
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _AutoAugmentBase(Transform):
    """Common base for the auto-augmentation transforms.

    Provides input flattening/extraction (exactly one image or video per
    sample) and the shared dispatch from a transform id + magnitude to the
    corresponding functional call.
    """

    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        # Normalized per-type fill mapping used by the dispatch below.
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return v1-compatible params; a dict-valued ``fill`` cannot be scripted."""
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        """Pick a uniformly random (key, value) pair from ``dct``."""
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        """Flatten the inputs and extract the single image or video to augment.

        Returns ``((flat_inputs, spec, idx), image_or_video)`` where ``idx`` is
        the position of the image/video inside ``flat_inputs``, so it can be
        re-inserted later. Raises ``TypeError`` when there is no image/video,
        more than one, or an unsupported input type.
        """
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        """Put the augmented image/video back at its original position and unflatten."""
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        """Dispatch ``transform_id`` with the given ``magnitude`` to the functional API."""
        # Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
        image = cast(torch.Tensor, image)
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision: (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            # Solarize magnitudes are fractions of the dtype's max value.
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class AutoAugment(_AutoAugmentBase):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """
    _v1_transform_cls = _transforms.AutoAugment

    # Maps each transform id to (magnitudes_fn, signed): magnitudes_fn returns the
    # num_bins candidate magnitudes (or None for parameterless transforms), and
    # signed says whether the magnitude may be randomly negated.
    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        """Return the learned sub-policies for ``policy``.

        Each sub-policy is a pair of ``(transform_id, probability, magnitude_idx)``
        triples; ``magnitude_idx`` is ``None`` for parameterless transforms.
        """
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        """Apply one randomly chosen sub-policy to the single image/video in the inputs."""
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            # Each transform of the sub-policy is applied independently with its probability.
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            # AutoAugment policies index into 10 magnitude bins.
            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                # Signed transforms are negated with probability 0.5.
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class RandAugment(_AutoAugmentBase):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.RandAugment

    # NOTE: insertion order matters — ``_get_random_item`` selects ops by
    # position in this mapping, so reordering entries would change which op a
    # given RNG draw picks under a fixed seed.
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins
        self.magnitude = magnitude
        self.num_ops = num_ops

    def forward(self, *inputs: Any) -> Any:
        """Apply ``num_ops`` randomly chosen ops, all at the fixed ``magnitude`` bin."""
        flat_inputs_with_spec, sample = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(sample)  # type: ignore[arg-type]

        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is None:
                # Ops like Identity/AutoContrast/Equalize take no magnitude.
                magnitude = 0.0
            else:
                magnitude = float(magnitudes[self.magnitude])
                # Signed ops flip direction with probability 0.5.
                if signed and torch.rand(()) <= 0.5:
                    magnitude = -magnitude
            sample = self._apply_image_or_video_transform(
                sample, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, sample)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class TrivialAugmentWide(_AutoAugmentBase):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide

    # NOTE: insertion order matters — ``_get_random_item`` selects ops by
    # position in this mapping; do not reorder entries.
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        """Apply exactly one randomly chosen op at a uniformly random magnitude bin."""
        flat_inputs_with_spec, sample = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(sample)  # type: ignore[arg-type]

        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is None:
            # Ops like Identity/AutoContrast/Equalize take no magnitude.
            magnitude = 0.0
        else:
            # Pick a random bin, then flip the sign half of the time for signed ops.
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude = -magnitude

        sample = self._apply_image_or_video_transform(
            sample, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, sample)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class AugMix(_AutoAugmentBase):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.AugMix

    # Ops used when ``all_ops=False``. This subset excludes the color ops
    # (Brightness/Color/Contrast/Sharpness) that the full space below adds.
    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    # Full space used when ``all_ops=True``: the partial space plus color ops.
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        # Number of magnitude bins; ``severity`` selects an upper bound for the
        # random bin index drawn per op in ``forward``.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        """Mix the original sample with ``mixture_width`` randomly augmented versions of itself."""
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)  # type: ignore[arg-type]

        # All the mixing arithmetic below is done on tensors, so PIL inputs are
        # converted up front and converted back at the end.
        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        # Normalize to a batched layout (videos get an extra leading dim) so the
        # per-sample Dirichlet weights can be broadcast against the data.
        orig_dims = list(image_or_video.shape)
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        # Start the mixture with the original weighted by its Beta share, then
        # accumulate each augmentation chain in-place.
        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # Fixed depth if chain_depth > 0, otherwise a random depth in [1, 3].
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    # Random bin bounded by severity; signed ops flip sign half the time.
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill)  # type: ignore[assignment]
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        # Drop the batching dims again and restore the input dtype (the mix is float).
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections.abc
|
| 2 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import transforms as _transforms
|
| 6 |
+
from torchvision.transforms.v2 import functional as F, Transform
|
| 7 |
+
|
| 8 |
+
from ._transform import _RandomApplyTransform
|
| 9 |
+
from ._utils import query_chw
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Grayscale(Transform):
    """Convert images or videos to grayscale.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 3 or 1, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image
    """

    _v1_transform_cls = _transforms.Grayscale

    def __init__(self, num_output_channels: int = 1):
        super().__init__()
        self.num_output_channels = num_output_channels

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Deterministic transform: ``params`` is unused, everything is
        # delegated to the functional kernel.
        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class RandomGrayscale(_RandomApplyTransform):
    """Randomly convert image or videos to grayscale with a probability of p (default 0.1).

    If the input is a :class:`torch.Tensor`, it is expected to have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions

    The output has the same number of channels as the input.

    Args:
        p (float): probability that image should be converted to grayscale.
    """

    _v1_transform_cls = _transforms.RandomGrayscale

    def __init__(self, p: float = 0.1) -> None:
        super().__init__(p=p)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Record the input channel count so the grayscale output keeps the
        # same number of channels as the input.
        channels, *_ = query_chw(flat_inputs)
        return {"num_input_channels": channels}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RGB(Transform):
    """Convert images or videos to RGB (if they are already not RGB).

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions
    """

    def __init__(self):
        super().__init__()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Deterministic transform: ``params`` is unused.
        return self._call_kernel(F.grayscale_to_rgb, inpt)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ColorJitter(Transform):
    """Randomly change the brightness, contrast, saturation and hue of an image or video.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    _v1_transform_cls = _transforms.ColorJitter

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Disabled jitters are stored as None here, but the v1 ColorJitter
        # expects 0 for "no jitter", hence the ``value or 0`` mapping.
        return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

    def __init__(
        self,
        brightness: Optional[Union[float, Sequence[float]]] = None,
        contrast: Optional[Union[float, Sequence[float]]] = None,
        saturation: Optional[Union[float, Sequence[float]]] = None,
        hue: Optional[Union[float, Sequence[float]]] = None,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        # Hue is an additive shift centered at 0 and bounded to [-0.5, 0.5];
        # unlike the multiplicative factors above, its lower bound may be negative.
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _check_input(
        self,
        value: Optional[Union[float, Sequence[float]]],
        name: str,
        center: float = 1.0,
        bound: Tuple[float, float] = (0, float("inf")),
        clip_first_on_zero: bool = True,
    ) -> Optional[Tuple[float, float]]:
        """Normalize a jitter argument into a validated ``(min, max)`` range.

        A single non-negative number ``v`` becomes ``[center - v, center + v]``
        (lower end optionally clipped at 0); a 2-sequence is used as-is.
        Returns ``None`` when the jitter is disabled or degenerate (both
        endpoints equal ``center``), which callers treat as "skip this jitter".
        """
        if value is None:
            return None

        if isinstance(value, (int, float)):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
            value = [float(v) for v in value]
        else:
            raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # A degenerate range centered on the identity value means "no-op".
        return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

    @staticmethod
    def _generate_value(left: float, right: float) -> float:
        # Uniform sample drawn through torch's RNG so it honors manual seeding.
        return torch.empty(1).uniform_(left, right).item()

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # fn_idx is a random permutation of the four adjustments; the
        # application order in ``_transform`` therefore varies per call.
        fn_idx = torch.randperm(4)

        b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
        c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
        s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
        h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])

        return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        output = inpt
        brightness_factor = params["brightness_factor"]
        contrast_factor = params["contrast_factor"]
        saturation_factor = params["saturation_factor"]
        hue_factor = params["hue_factor"]
        # Apply the enabled adjustments in the randomly permuted order; a None
        # factor means that jitter is disabled and its slot is skipped.
        for fn_id in params["fn_idx"]:
            if fn_id == 0 and brightness_factor is not None:
                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
        return output
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RandomChannelPermutation(Transform):
    """Randomly permute the channels of an image or video"""

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Draw a fresh permutation of the channel indices for every sample.
        channels, *_ = query_chw(flat_inputs)
        return {"permutation": torch.randperm(channels)}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.permute_channels, inpt, params["permutation"])
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class RandomPhotometricDistort(Transform):
    """Randomly distorts the image or video as used in `SSD: Single Shot
    MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.

    This transform relies on :class:`~torchvision.transforms.v2.ColorJitter`
    under the hood to adjust the contrast, saturation, hue, brightness, and also
    randomly permutes channels.

    Args:
        brightness (tuple of float (min, max), optional): How much to jitter brightness.
            brightness_factor is chosen uniformly from [min, max]. Should be non negative numbers.
        contrast (tuple of float (min, max), optional): How much to jitter contrast.
            contrast_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
        saturation (tuple of float (min, max), optional): How much to jitter saturation.
            saturation_factor is chosen uniformly from [min, max]. Should be non negative numbers.
        hue (tuple of float (min, max), optional): How much to jitter hue.
            hue_factor is chosen uniformly from [min, max]. Should have -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        p (float, optional) probability each distortion operation (contrast, saturation, ...) to be applied.
            Default is 0.5.
    """

    def __init__(
        self,
        brightness: Tuple[float, float] = (0.875, 1.125),
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        p: float = 0.5,
    ):
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast
        self.hue = hue
        self.saturation = saturation
        self.p = p

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        num_channels, *_ = query_chw(flat_inputs)
        # Each factor is enabled independently with probability ``p``; a
        # disabled factor is stored as None and skipped in ``_transform``.
        params: Dict[str, Any] = {}
        for key, bounds in (
            ("brightness_factor", self.brightness),
            ("contrast_factor", self.contrast),
            ("saturation_factor", self.saturation),
            ("hue_factor", self.hue),
        ):
            params[key] = ColorJitter._generate_value(bounds[0], bounds[1]) if torch.rand(1) < self.p else None
        # Contrast is applied either before or after saturation/hue, chosen
        # with a fair coin toss, as in the SSD reference implementation.
        params["contrast_before"] = bool(torch.rand(()) < 0.5)
        params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
        return params

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        out = inpt
        contrast = params["contrast_factor"]
        if params["brightness_factor"] is not None:
            out = self._call_kernel(F.adjust_brightness, out, brightness_factor=params["brightness_factor"])
        if contrast is not None and params["contrast_before"]:
            out = self._call_kernel(F.adjust_contrast, out, contrast_factor=contrast)
        if params["saturation_factor"] is not None:
            out = self._call_kernel(F.adjust_saturation, out, saturation_factor=params["saturation_factor"])
        if params["hue_factor"] is not None:
            out = self._call_kernel(F.adjust_hue, out, hue_factor=params["hue_factor"])
        if contrast is not None and not params["contrast_before"]:
            out = self._call_kernel(F.adjust_contrast, out, contrast_factor=contrast)
        if params["channel_permutation"] is not None:
            out = self._call_kernel(F.permute_channels, out, permutation=params["channel_permutation"])
        return out
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class RandomEqualize(_RandomApplyTransform):
    """Equalize the histogram of the given image or video with a given probability.

    Tensor inputs are expected to have ``[..., 1 or 3, H, W]`` shape, where
    ``...`` is an arbitrary number of leading dimensions. PIL images are
    expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomEqualize

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Equalization takes no extra parameters beyond the input itself.
        equalized = self._call_kernel(F.equalize, inpt)
        return equalized
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class RandomInvert(_RandomApplyTransform):
    """Inverts the colors of the given image or video with a given probability.

    Tensor inputs are expected in ``[..., 1 or 3, H, W]`` format, where ``...``
    is an arbitrary number of leading dimensions. PIL images are expected to
    be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomInvert

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Color inversion is parameter-free; just dispatch to the kernel.
        inverted = self._call_kernel(F.invert, inpt)
        return inverted
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class RandomPosterize(_RandomApplyTransform):
    """Posterize the image or video with a given probability by reducing the
    number of bits for each color channel.

    Tensor inputs should be of type ``torch.uint8`` with ``[..., 1 or 3, H, W]``
    shape, where ``...`` is an arbitrary number of leading dimensions.
    PIL images are expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomPosterize

    def __init__(self, bits: int, p: float = 0.5) -> None:
        # ``p`` is handled by the random-apply base class; we only keep ``bits``.
        super().__init__(p=p)
        self.bits = bits

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        result = self._call_kernel(F.posterize, inpt, bits=self.bits)
        return result
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class RandomSolarize(_RandomApplyTransform):
    """Solarize the image or video with a given probability by inverting all pixel
    values above a threshold.

    Tensor inputs are expected in ``[..., 1 or 3, H, W]`` format, where ``...``
    is an arbitrary number of leading dimensions. PIL images are expected to
    be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomSolarize

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The v1 transform requires a plain Python float for ``threshold``.
        v1_params = super()._extract_params_for_v1_transform()
        return {**v1_params, "threshold": float(v1_params["threshold"])}

    def __init__(self, threshold: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.threshold = threshold

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        solarized = self._call_kernel(F.solarize, inpt, threshold=self.threshold)
        return solarized
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class RandomAutocontrast(_RandomApplyTransform):
    """Autocontrast the pixels of the given image or video with a given probability.

    Tensor inputs are expected to have ``[..., 1 or 3, H, W]`` shape, where
    ``...`` is an arbitrary number of leading dimensions. PIL images are
    expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomAutocontrast

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Autocontrast is parameter-free; dispatch straight to the kernel.
        adjusted = self._call_kernel(F.autocontrast, inpt)
        return adjusted
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class RandomAdjustSharpness(_RandomApplyTransform):
    """Adjust the sharpness of the image or video with a given probability.

    Tensor inputs are expected to have ``[..., 1 or 3, H, W]`` shape, where
    ``...`` is an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomAdjustSharpness

    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
        # ``p`` is consumed by the random-apply base class.
        super().__init__(p=p)
        self.sharpness_factor = sharpness_factor

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        sharpened = self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
        return sharpened
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchvision import transforms as _transforms
|
| 7 |
+
from torchvision.transforms.v2 import Transform
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Compose(Transform):
    """Composes several transforms together.

    This transform does not support torchscript; see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        # Validate eagerly so misuse fails at construction, not at call time.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        elif not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # With a single input, each stage receives the bare value; with
        # several, the tuple is splatted back into the next stage.
        unpack = len(inputs) > 1
        result: Any = inputs if unpack else inputs[0]
        for stage in self.transforms:
            result = stage(*inputs)
            inputs = result if unpack else (result,)
        return result

    def extra_repr(self) -> str:
        # One indented line per contained transform.
        return "\n".join(f"    {t}" for t in self.transforms)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RandomApply(Transform):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The v1 transform takes exactly these two constructor arguments.
        return {"transforms": self.transforms, "p": self.p}

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1

        # Skip the whole pipeline with probability (1 - p).
        if torch.rand(1) >= self.p:
            return inputs if unpack else inputs[0]

        result: Any = inputs if unpack else inputs[0]
        for stage in self.transforms:
            result = stage(*inputs)
            inputs = result if unpack else (result,)
        return result

    def extra_repr(self) -> str:
        return "\n".join(f"    {t}" for t in self.transforms)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class RandomChoice(Transform):
    """Apply single transformation randomly picked from a list.

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        # Default to a uniform distribution over the transforms.
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Normalize so the stored weights always sum to 1.
        total = sum(p)
        self.p = [weight / total for weight in p]

    def forward(self, *inputs: Any) -> Any:
        # Draw one index according to the normalized weights and delegate.
        chosen = int(torch.multinomial(torch.tensor(self.p), 1))
        return self.transforms[chosen](*inputs)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class RandomOrder(Transform):
    """Apply a list of transformations in a random order.

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1
        result: Any = inputs if unpack else inputs[0]
        # Shuffle the application order on every call.
        for position in torch.randperm(len(self.transforms)):
            result = self.transforms[position](*inputs)
            inputs = result if unpack else (result,)
        return result
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Any, Dict, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import PIL.Image
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision.transforms import functional as _F
|
| 8 |
+
|
| 9 |
+
from torchvision.transforms.v2 import Transform
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ToTensor(Transform):
    """[DEPRECATED] Use ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`` instead.

    Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    .. warning::
        :class:`v2.ToTensor` is deprecated and will be removed in a future release.
        Please use instead ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``.
        Output is equivalent up to float precision.

    This transform does not support torchscript.


    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    # Only these input types are converted; everything else passes through.
    _transformed_types = (PIL.Image.Image, np.ndarray)

    def __init__(self) -> None:
        # Fix: the original concatenated the last two literals without a
        # separating space, emitting "...scale=True)])`.Output is equivalent".
        warnings.warn(
            "The transform `ToTensor()` is deprecated and will be removed in a future release. "
            "Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`. "
            "Output is equivalent up to float precision."
        )
        super().__init__()

    def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
        # Delegate to the v1 functional implementation for exact parity.
        return _F.to_tensor(inpt)
|
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py
ADDED
|
@@ -0,0 +1,1416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numbers
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
|
| 5 |
+
|
| 6 |
+
import PIL.Image
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from torchvision import transforms as _transforms, tv_tensors
|
| 10 |
+
from torchvision.ops.boxes import box_iou
|
| 11 |
+
from torchvision.transforms.functional import _get_perspective_coeffs
|
| 12 |
+
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
|
| 13 |
+
from torchvision.transforms.v2.functional._utils import _FillType
|
| 14 |
+
|
| 15 |
+
from ._transform import _RandomApplyTransform
|
| 16 |
+
from ._utils import (
|
| 17 |
+
_check_padding_arg,
|
| 18 |
+
_check_padding_mode_arg,
|
| 19 |
+
_check_sequence_input,
|
| 20 |
+
_get_fill,
|
| 21 |
+
_setup_angle,
|
| 22 |
+
_setup_fill_arg,
|
| 23 |
+
_setup_number_or_seq,
|
| 24 |
+
_setup_size,
|
| 25 |
+
get_bounding_boxes,
|
| 26 |
+
has_all,
|
| 27 |
+
has_any,
|
| 28 |
+
is_pure_tensor,
|
| 29 |
+
query_size,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RandomHorizontalFlip(_RandomApplyTransform):
    """Horizontally flip the input with a given probability.

    Inputs that are a :class:`torch.Tensor` or a ``TVTensor`` (e.g.
    :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) may carry an arbitrary
    number of leading batch dimensions: images as ``[..., C, H, W]``,
    bounding boxes as ``[..., 4]``.

    Args:
        p (float, optional): probability of the input being flipped. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomHorizontalFlip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The flip itself is parameter-free; the base class handles ``p``.
        flipped = self._call_kernel(F.horizontal_flip, inpt)
        return flipped
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class RandomVerticalFlip(_RandomApplyTransform):
    """Vertically flip the input with a given probability.

    Inputs that are a :class:`torch.Tensor` or a ``TVTensor`` (e.g.
    :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) may carry an arbitrary
    number of leading batch dimensions: images as ``[..., C, H, W]``,
    bounding boxes as ``[..., 4]``.

    Args:
        p (float, optional): probability of the input being flipped. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomVerticalFlip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The flip itself is parameter-free; the base class handles ``p``.
        flipped = self._call_kernel(F.vertical_flip, inpt)
        return flipped
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Resize(Transform):
|
| 70 |
+
"""Resize the input to the given size.
|
| 71 |
+
|
| 72 |
+
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
|
| 73 |
+
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
|
| 74 |
+
it can have arbitrary number of leading batch dimensions. For example,
|
| 75 |
+
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
size (sequence, int, or None): Desired
|
| 79 |
+
output size.
|
| 80 |
+
|
| 81 |
+
- If size is a sequence like (h, w), output size will be matched to this.
|
| 82 |
+
- If size is an int, smaller edge of the image will be matched to this
|
| 83 |
+
number. i.e, if height > width, then image will be rescaled to
|
| 84 |
+
(size * height / width, size).
|
| 85 |
+
- If size is None, the output shape is determined by the ``max_size``
|
| 86 |
+
parameter.
|
| 87 |
+
|
| 88 |
+
.. note::
|
| 89 |
+
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
|
| 90 |
+
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
|
| 91 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
| 92 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
|
| 93 |
+
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
|
| 94 |
+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
|
| 95 |
+
max_size (int, optional): The maximum allowed for the longer edge of
|
| 96 |
+
the resized image.
|
| 97 |
+
|
| 98 |
+
- If ``size`` is an int: if the longer edge of the image is greater
|
| 99 |
+
than ``max_size`` after being resized according to ``size``,
|
| 100 |
+
``size`` will be overruled so that the longer edge is equal to
|
| 101 |
+
``max_size``. As a result, the smaller edge may be shorter than
|
| 102 |
+
``size``. This is only supported if ``size`` is an int (or a
|
| 103 |
+
sequence of length 1 in torchscript mode).
|
| 104 |
+
- If ``size`` is None: the longer edge of the image will be matched
|
| 105 |
+
to max_size. i.e, if height > width, then image will be rescaled
|
| 106 |
+
to (max_size, max_size * width / height).
|
| 107 |
+
|
| 108 |
+
This should be left to ``None`` (default) when ``size`` is a
|
| 109 |
+
sequence.
|
| 110 |
+
|
| 111 |
+
antialias (bool, optional): Whether to apply antialiasing.
|
| 112 |
+
It only affects **tensors** with bilinear or bicubic modes and it is
|
| 113 |
+
ignored otherwise: on PIL images, antialiasing is always applied on
|
| 114 |
+
bilinear or bicubic modes; on other modes (for PIL images and
|
| 115 |
+
tensors), antialiasing makes no sense and this parameter is ignored.
|
| 116 |
+
Possible values are:
|
| 117 |
+
|
| 118 |
+
- ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
|
| 119 |
+
Other mode aren't affected. This is probably what you want to use.
|
| 120 |
+
- ``False``: will not apply antialiasing for tensors on any mode. PIL
|
| 121 |
+
images are still antialiased on bilinear or bicubic modes, because
|
| 122 |
+
PIL doesn't support no antialias.
|
| 123 |
+
- ``None``: equivalent to ``False`` for tensors and ``True`` for
|
| 124 |
+
PIL images. This value exists for legacy reasons and you probably
|
| 125 |
+
don't want to use it unless you really know what you are doing.
|
| 126 |
+
|
| 127 |
+
The default value changed from ``None`` to ``True`` in
|
| 128 |
+
v0.17, for the PIL and Tensor backends to be consistent.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
_v1_transform_cls = _transforms.Resize
|
| 132 |
+
|
| 133 |
+
def __init__(
    self,
    size: Union[int, Sequence[int], None],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> None:
    super().__init__()

    # Normalize ``size``: an int becomes a one-element list, a 1- or 2-element
    # sequence becomes a list, and ``None`` is only valid together with an
    # integer ``max_size`` (resize driven purely by the longer edge).
    if isinstance(size, int):
        size = [size]
    elif isinstance(size, Sequence) and len(size) in {1, 2}:
        size = list(size)
    elif size is None:
        if not isinstance(max_size, int):
            raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
    else:
        raise ValueError(
            f"size can be an integer, a sequence of one or two integers, or None, but got {size} instead."
        )
    self.size = size

    self.interpolation = interpolation
    self.max_size = max_size
    self.antialias = antialias
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    # Delegate to the functional resize kernel with the parameters stored at
    # construction time; ``params`` is unused because Resize is deterministic.
    return self._call_kernel(
        F.resize,
        inpt,
        self.size,
        interpolation=self.interpolation,
        max_size=self.max_size,
        antialias=self.antialias,
    )
class CenterCrop(Transform):
    """Crop the input at the center.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    _v1_transform_cls = _transforms.CenterCrop

    def __init__(self, size: Union[int, Sequence[int]]):
        super().__init__()
        # _setup_size canonicalizes int / length-1 / length-2 inputs to (h, w).
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.center_crop, inpt, output_size=self.size)
class RandomResizedCrop(Transform):
    """Crop a random portion of the input and resize it to a given size.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    A crop of the original input is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float, optional): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float, optional): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    _v1_transform_cls = _transforms.RandomResizedCrop

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        scale: Tuple[float, float] = (0.08, 1.0),
        ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.scale = scale
        self.ratio = ratio
        self.interpolation = interpolation
        self.antialias = antialias

        # Pre-compute log-ratio bounds so aspect ratios are sampled
        # log-uniformly in _get_params.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_size(flat_inputs)
        area = height * width

        log_ratio = self._log_ratio
        # Rejection-sample up to 10 candidate crops that fit inside the input.
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                break
        else:
            # Fallback to central crop
            in_ratio = float(width) / float(height)
            if in_ratio < min(self.ratio):
                w = width
                h = int(round(w / min(self.ratio)))
            elif in_ratio > max(self.ratio):
                h = height
                w = int(round(h * max(self.ratio)))
            else:  # whole image
                w = width
                h = height
            i = (height - h) // 2
            j = (width - w) // 2

        return dict(top=i, left=j, height=h, width=w)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(
            F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
        )
class FiveCrop(Transform):
    """Crop the image or video into four corners and the central crop.

    If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
    :class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
    For example, the image can have ``[..., C, H, W]`` shape.

    .. Note::
        This transform returns a tuple of images and there may be a mismatch in the number of
        inputs and targets your Dataset returns. See below for an example of how to deal with
        this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
        >>> class BatchMultiCrop(transforms.Transform):
        ...     def forward(self, sample: Tuple[Tuple[Union[tv_tensors.Image, tv_tensors.Video], ...], int]):
        ...         images_or_videos, labels = sample
        ...         batch_size = len(images_or_videos)
        ...         image_or_video = images_or_videos[0]
        ...         images_or_videos = tv_tensors.wrap(torch.stack(images_or_videos), like=image_or_video)
        ...         labels = torch.full((batch_size,), label, device=images_or_videos.device)
        ...         return images_or_videos, labels
        ...
        >>> image = tv_tensors.Image(torch.rand(3, 256, 256))
        >>> label = 3
        >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
        >>> images, labels = transform(image, label)
        >>> images.shape
        torch.Size([5, 3, 224, 224])
        >>> labels
        tensor([3, 3, 3, 3, 3])
    """

    _v1_transform_cls = _transforms.FiveCrop

    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # BoundingBoxes/Mask inputs are currently passed through untouched;
        # warn so users are not surprised when this behavior changes.
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.five_crop, inpt, self.size)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
class TenCrop(Transform):
    """Crop the image or video into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).

    If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
    :class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
    For example, the image can have ``[..., C, H, W]`` shape.

    See :class:`~torchvision.transforms.v2.FiveCrop` for an example.

    .. Note::
        This transform returns a tuple of images and there may be a mismatch in the number of
        inputs and targets your Dataset returns. See below for an example of how to deal with
        this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool, optional): Use vertical flipping instead of horizontal
    """

    _v1_transform_cls = _transforms.TenCrop

    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # BoundingBoxes/Mask inputs are currently passed through untouched;
        # warn so users are not surprised when this behavior changes.
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
class Pad(Transform):
    """Pad the input on all sides with the given "pad" value.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is "constant".

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    _v1_transform_cls = _transforms.Pad

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The v1 (scriptable) Pad only supports scalar fill values.
        params = super()._extract_params_for_v1_transform()

        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

        return params

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        _check_padding_arg(padding)
        _check_padding_mode_arg(padding_mode)

        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
        if not isinstance(padding, int):
            padding = list(padding)
        self.padding = padding
        self.fill = fill
        self._fill = _setup_fill_arg(fill)
        self.padding_mode = padding_mode

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Resolve the per-type fill value for this specific input type.
        fill = _get_fill(self._fill, type(inpt))
        return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
class RandomZoomOut(_RandomApplyTransform):
    """ "Zoom out" transformation from
    `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.

    This transformation randomly pads images, videos, bounding boxes and masks creating a zoom out effect.
    Output spatial size is randomly sampled from original size up to a maximum size configured
    with ``side_range`` parameter:

    .. code-block:: python

        r = uniform_sample(side_range[0], side_range[1])
        output_width = input_width * r
        output_height = input_height * r

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to
            scale the input size.
        p (float, optional): probability that the zoom operation will be performed.
    """

    def __init__(
        self,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        side_range: Sequence[float] = (1.0, 4.0),
        p: float = 0.5,
    ) -> None:
        super().__init__(p=p)

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        _check_sequence_input(side_range, "side_range", req_sizes=(2,))

        self.side_range = side_range
        # Zooming out only makes sense for factors >= 1 ordered (min, max).
        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
            raise ValueError(f"Invalid side range provided {side_range}.")

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_h, orig_w = query_size(flat_inputs)

        # Sample the canvas scale factor, then place the original input at a
        # uniformly random position inside the enlarged canvas via padding.
        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_width = int(orig_w * r)
        canvas_height = int(orig_h * r)

        r = torch.rand(2)
        left = int((canvas_width - orig_w) * r[0])
        top = int((canvas_height - orig_h) * r[1])
        right = canvas_width - (left + orig_w)
        bottom = canvas_height - (top + orig_h)
        padding = [left, top, right, bottom]

        return dict(padding=padding)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
        return self._call_kernel(F.pad, inpt, **params, fill=fill)
class RandomRotation(Transform):
    """Rotate the input by angle.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center (see note below) and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

            .. note::

                In theory, setting ``center`` has no effect if ``expand=True``, since the image center will become the
                center of rotation. In practice however, due to numerical precision, this can lead to off-by-one
                differences of the resulting image size compared to using the image center in the first place. Thus, when
                setting ``expand=True``, it's best to leave ``center=None`` (default).
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    _v1_transform_cls = _transforms.RandomRotation

    def __init__(
        self,
        degrees: Union[numbers.Number, Sequence],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__()
        # _setup_angle turns a scalar d into (-d, +d) and validates sequences.
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
        self.interpolation = interpolation
        self.expand = expand

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
        return dict(angle=angle)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
        return self._call_kernel(
            F.rotate,
            inpt,
            **params,
            interpolation=self.interpolation,
            expand=self.expand,
            center=self.center,
            fill=fill,
        )
class RandomAffine(Transform):
|
| 639 |
+
"""Random affine transformation the input keeping center invariant.
|
| 640 |
+
|
| 641 |
+
If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
|
| 642 |
+
:class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
|
| 643 |
+
it can have arbitrary number of leading batch dimensions. For example,
|
| 644 |
+
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
|
| 645 |
+
|
| 646 |
+
Args:
|
| 647 |
+
degrees (sequence or number): Range of degrees to select from.
|
| 648 |
+
If degrees is a number instead of sequence like (min, max), the range of degrees
|
| 649 |
+
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
|
| 650 |
+
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
|
| 651 |
+
and vertical translations. For example translate=(a, b), then horizontal shift
|
| 652 |
+
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
|
| 653 |
+
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
|
| 654 |
+
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
|
| 655 |
+
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
|
| 656 |
+
shear (sequence or number, optional): Range of degrees to select from.
|
| 657 |
+
If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
|
| 658 |
+
will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
|
| 659 |
+
range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
|
| 660 |
+
an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
|
| 661 |
+
Will not apply shear by default.
|
| 662 |
+
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
|
| 663 |
+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
| 664 |
+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
| 665 |
+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
|
| 666 |
+
fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
|
| 667 |
+
Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
|
| 668 |
+
Fill value can be also a dictionary mapping data type to the fill value, e.g.
|
| 669 |
+
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
|
| 670 |
+
``Mask`` will be filled with 0.
|
| 671 |
+
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
|
| 672 |
+
Default is the center of the image.
|
| 673 |
+
|
| 674 |
+
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
|
| 675 |
+
|
| 676 |
+
"""
|
| 677 |
+
|
| 678 |
+
_v1_transform_cls = _transforms.RandomAffine
|
| 679 |
+
|
| 680 |
+
def __init__(
|
| 681 |
+
self,
|
| 682 |
+
degrees: Union[numbers.Number, Sequence],
|
| 683 |
+
translate: Optional[Sequence[float]] = None,
|
| 684 |
+
scale: Optional[Sequence[float]] = None,
|
| 685 |
+
shear: Optional[Union[int, float, Sequence[float]]] = None,
|
| 686 |
+
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
|
| 687 |
+
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
|
| 688 |
+
center: Optional[List[float]] = None,
|
| 689 |
+
) -> None:
|
| 690 |
+
super().__init__()
|
| 691 |
+
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
|
| 692 |
+
if translate is not None:
|
| 693 |
+
_check_sequence_input(translate, "translate", req_sizes=(2,))
|
| 694 |
+
for t in translate:
|
| 695 |
+
if not (0.0 <= t <= 1.0):
|
| 696 |
+
raise ValueError("translation values should be between 0 and 1")
|
| 697 |
+
self.translate = translate
|
| 698 |
+
if scale is not None:
|
| 699 |
+
_check_sequence_input(scale, "scale", req_sizes=(2,))
|
| 700 |
+
for s in scale:
|
| 701 |
+
if s <= 0:
|
| 702 |
+
raise ValueError("scale values should be positive")
|
| 703 |
+
self.scale = scale
|
| 704 |
+
|
| 705 |
+
if shear is not None:
|
| 706 |
+
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
|
| 707 |
+
else:
|
| 708 |
+
self.shear = shear
|
| 709 |
+
|
| 710 |
+
self.interpolation = interpolation
|
| 711 |
+
self.fill = fill
|
| 712 |
+
self._fill = _setup_fill_arg(fill)
|
| 713 |
+
|
| 714 |
+
if center is not None:
|
| 715 |
+
_check_sequence_input(center, "center", req_sizes=(2,))
|
| 716 |
+
|
| 717 |
+
self.center = center
|
| 718 |
+
|
| 719 |
+
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
|
| 720 |
+
height, width = query_size(flat_inputs)
|
| 721 |
+
|
| 722 |
+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
|
| 723 |
+
if self.translate is not None:
|
| 724 |
+
max_dx = float(self.translate[0] * width)
|
| 725 |
+
max_dy = float(self.translate[1] * height)
|
| 726 |
+
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
|
| 727 |
+
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
|
| 728 |
+
translate = (tx, ty)
|
| 729 |
+
else:
|
| 730 |
+
translate = (0, 0)
|
| 731 |
+
|
| 732 |
+
if self.scale is not None:
|
| 733 |
+
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
|
| 734 |
+
else:
|
| 735 |
+
scale = 1.0
|
| 736 |
+
|
| 737 |
+
shear_x = shear_y = 0.0
|
| 738 |
+
if self.shear is not None:
|
| 739 |
+
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
|
| 740 |
+
if len(self.shear) == 4:
|
| 741 |
+
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
|
| 742 |
+
|
| 743 |
+
shear = (shear_x, shear_y)
|
| 744 |
+
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
|
| 745 |
+
|
| 746 |
+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Apply the sampled affine parameters to one input via the ``F.affine`` kernel."""
    # fill may be a per-type mapping; resolve it for this input's type.
    fill_value = _get_fill(self._fill, type(inpt))
    return self._call_kernel(
        F.affine,
        inpt,
        **params,
        interpolation=self.interpolation,
        fill=fill_value,
        center=self.center,
    )
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class RandomCrop(Transform):
    """Crop the input at a random location.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        size (sequence or int): Desired output size of the crop. An int produces a square
            ``(size, size)`` crop; a sequence of length 1 is interpreted as ``(size[0], size[0])``.
        padding (int or sequence, optional): Optional padding on each border of the image.
            Default is None. A single int pads all borders equally; a sequence of length 2
            pads left/right and top/bottom respectively; a sequence of length 4 pads the
            left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean, optional): Pad the input if it is smaller than the desired
            size, to avoid raising an exception. Since cropping is done after padding, the
            padding seems to be done at a random offset.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. A tuple of length 3 fills the R, G, B channels respectively. It can
            also be a dictionary mapping data type to fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled
            with 127 and ``Mask`` will be filled with 0.
        padding_mode (str, optional): Type of padding: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill
            - edge: pads with the last value at the edge of the image.
            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]
            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    _v1_transform_cls = _transforms.RandomCrop

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Translate this transform's state into the v1 ``RandomCrop`` parameter layout."""
        params = super()._extract_params_for_v1_transform()

        # v1 can only be scripted with a scalar fill value.
        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

        # We store (left, right, top, bottom); v1 expects (left, top, right, bottom).
        padding = self.padding
        if padding is not None:
            pad_left, pad_right, pad_top, pad_bottom = padding
            padding = [pad_left, pad_top, pad_right, pad_bottom]
        params["padding"] = padding

        return params

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        padding: Optional[Union[int, Sequence[int]]] = None,
        pad_if_needed: bool = False,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        # Only validate padding arguments when padding can actually occur.
        if pad_if_needed or padding is not None:
            if padding is not None:
                _check_padding_arg(padding)
            _check_padding_mode_arg(padding_mode)

        # Normalized to (left, right, top, bottom), or None when no static padding was requested.
        self.padding = F._geometry._parse_pad_padding(padding) if padding else None  # type: ignore[arg-type]
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self._fill = _setup_fill_arg(fill)
        self.padding_mode = padding_mode

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Compute the effective padding and sample the random crop origin."""
        padded_height, padded_width = query_size(flat_inputs)

        if self.padding is not None:
            pad_left, pad_right, pad_top, pad_bottom = self.padding
            padded_height += pad_top + pad_bottom
            padded_width += pad_left + pad_right
        else:
            pad_left = pad_right = pad_top = pad_bottom = 0

        cropped_height, cropped_width = self.size

        if self.pad_if_needed:
            # Grow both sides of each undersized axis until the crop fits.
            if padded_height < cropped_height:
                shortfall = cropped_height - padded_height
                pad_top += shortfall
                pad_bottom += shortfall
                padded_height += 2 * shortfall

            if padded_width < cropped_width:
                shortfall = cropped_width - padded_width
                pad_left += shortfall
                pad_right += shortfall
                padded_width += 2 * shortfall

        if padded_height < cropped_height or padded_width < cropped_width:
            raise ValueError(
                f"Required crop size {(cropped_height, cropped_width)} is larger than "
                f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
            )

        # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
        padding = [pad_left, pad_top, pad_right, pad_bottom]
        needs_pad = any(padding)

        if padded_height > cropped_height:
            top = int(torch.randint(0, padded_height - cropped_height + 1, size=()))
            needs_vert_crop = True
        else:
            needs_vert_crop, top = False, 0

        if padded_width > cropped_width:
            left = int(torch.randint(0, padded_width - cropped_width + 1, size=()))
            needs_horz_crop = True
        else:
            needs_horz_crop, left = False, 0

        return dict(
            needs_crop=needs_vert_crop or needs_horz_crop,
            top=top,
            left=left,
            height=cropped_height,
            width=cropped_width,
            needs_pad=needs_pad,
            padding=padding,
        )

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Pad (if required) and then crop one input with the sampled parameters."""
        if params["needs_pad"]:
            fill = _get_fill(self._fill, type(inpt))
            inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

        if params["needs_crop"]:
            inpt = self._call_kernel(
                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
            )

        return inpt
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
class RandomPerspective(_RandomApplyTransform):
    """Perform a random perspective transformation of the input with a given probability.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        distortion_scale (float, optional): Degree of distortion, in the range ``[0, 1]``.
            Default is 0.5.
        p (float, optional): probability of the input being transformed. Default is 0.5.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. A tuple of length 3 fills the R, G, B channels respectively. It can
            also be a dictionary mapping data type to fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled
            with 127 and ``Mask`` will be filled with 0.
    """

    _v1_transform_cls = _transforms.RandomPerspective

    def __init__(
        self,
        distortion_scale: float = 0.5,
        p: float = 0.5,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__(p=p)

        if not (0 <= distortion_scale <= 1):
            raise ValueError("Argument distortion_scale value should be between 0 and 1")

        self.distortion_scale = distortion_scale
        self.interpolation = interpolation
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the four displaced corners and derive the perspective coefficients."""
        height, width = query_size(flat_inputs)

        half_height = height // 2
        half_width = width // 2
        # Each corner may move by at most distortion_scale times half the image extent.
        bound_height = int(self.distortion_scale * half_height) + 1
        bound_width = int(self.distortion_scale * half_width) + 1

        def _rand(low: int, high: int) -> int:
            return int(torch.randint(low, high, size=(1,)))

        # Corners are sampled in the same (x, y) order: TL, TR, BR, BL.
        topleft = [_rand(0, bound_width), _rand(0, bound_height)]
        topright = [_rand(width - bound_width, width), _rand(0, bound_height)]
        botright = [_rand(width - bound_width, width), _rand(height - bound_height, height)]
        botleft = [_rand(0, bound_width), _rand(height - bound_height, height)]

        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
        return dict(coefficients=perspective_coeffs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the sampled perspective warp to one input via the ``F.perspective`` kernel."""
        fill_value = _get_fill(self._fill, type(inpt))
        # Start/end points are None because we pass the precomputed coefficients instead.
        return self._call_kernel(
            F.perspective,
            inpt,
            startpoints=None,
            endpoints=None,
            fill=fill_value,
            interpolation=self.interpolation,
            **params,
        )
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
class ElasticTransform(Transform):
    """Transform the input with elastic transformations.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to transform the input.

    .. note::
        Implementation to transform bounding boxes is approximative (not exact).
        We construct an approximation of the inverse grid as ``inverse_grid = identity - displacement``.
        This is not an exact inverse of the grid used to transform images, i.e. ``grid = identity + displacement``.
        Our assumption is that ``displacement * displacement`` is small and can be ignored.
        Large displacements would lead to large errors in the approximation.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats, optional): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats, optional): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. A tuple of length 3 fills the R, G, B channels respectively. It can
            also be a dictionary mapping data type to fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled
            with 127 and ``Mask`` will be filled with 0.
    """

    _v1_transform_cls = _transforms.ElasticTransform

    def __init__(
        self,
        alpha: Union[float, Sequence[float]] = 50.0,
        sigma: Union[float, Sequence[float]] = 5.0,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__()
        # Both parameters are normalized to one value per spatial axis.
        self.alpha = _setup_number_or_seq(alpha, "alpha")
        self.sigma = _setup_number_or_seq(sigma, "sigma")

        self.interpolation = interpolation
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Build the random displacement field shared by all inputs of this call."""
        size = list(query_size(flat_inputs))

        def _axis_displacement(axis: int) -> torch.Tensor:
            # Random offsets in [-1, 1], optionally Gaussian-smoothed, scaled by alpha
            # and normalized by the spatial extent of this axis.
            d = torch.rand([1, 1] + size) * 2 - 1
            sigma = self.sigma[axis]
            if sigma > 0.0:
                kernel_size = int(8 * sigma + 1)
                # if kernel size is even we have to make it odd
                if kernel_size % 2 == 0:
                    kernel_size += 1
                d = self._call_kernel(F.gaussian_blur, d, [kernel_size, kernel_size], list(self.sigma))
            return d * self.alpha[axis] / size[axis]

        dx = _axis_displacement(0)
        dy = _axis_displacement(1)
        displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
        return dict(displacement=displacement)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Warp one input with the shared displacement field via the ``F.elastic`` kernel."""
        fill_value = _get_fill(self._fill, type(inpt))
        return self._call_kernel(
            F.elastic,
            inpt,
            **params,
            fill=fill_value,
            interpolation=self.interpolation,
        )
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
class RandomIoUCrop(Transform):
    """Random IoU crop transformation from
    `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.

    This transformation requires an image or video data and ``tv_tensors.BoundingBoxes`` in the input.

    .. warning::
        In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
        must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
        after or later in the transforms pipeline.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        min_scale (float, optional): Minimum factors to scale the input size.
        max_scale (float, optional): Maximum factors to scale the input size.
        min_aspect_ratio (float, optional): Minimum aspect ratio for the cropped image or video.
        max_aspect_ratio (float, optional): Maximum aspect ratio for the cropped image or video.
        sampler_options (list of float, optional): List of minimal IoU (Jaccard) overlap between all the boxes and
            a cropped image or video. Default, ``None`` which corresponds to ``[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]``
        trials (int, optional): Number of trials to find a crop for a given value of minimal IoU (Jaccard) overlap.
            Default, 40.
    """

    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1.0,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2.0,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
    ):
        super().__init__()
        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        if sampler_options is None:
            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options
        self.trials = trials

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        """Require at least one image-like input plus bounding boxes in the sample."""
        has_boxes = has_all(flat_inputs, tv_tensors.BoundingBoxes)
        has_image = has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor)
        if not (has_boxes and has_image):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
                "and bounding boxes. Sample can also contain masks."
            )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Repeatedly sample candidate crops until one satisfies the IoU constraint.

        Returns an empty dict when the sampled option means "leave the input as-is".
        """
        orig_h, orig_w = query_size(flat_inputs)
        bboxes = get_bounding_boxes(flat_inputs)

        while True:
            # sample an option
            option_idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[option_idx]
            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                return dict()

            for _ in range(self.trials):
                # check the aspect ratio limitations
                scale_factors = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * scale_factors[0])
                new_h = int(orig_h * scale_factors[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
                    continue

                # check for 0 area crops
                offset_factors = torch.rand(2)
                left = int((orig_w - new_w) * offset_factors[0])
                top = int((orig_h - new_h) * offset_factors[1])
                right = left + new_w
                bottom = top + new_h
                if left == right or top == bottom:
                    continue

                # check for any valid boxes with centers within the crop area
                xyxy_bboxes = F.convert_bounding_box_format(
                    bboxes.as_subclass(torch.Tensor),
                    bboxes.format,
                    tv_tensors.BoundingBoxFormat.XYXY,
                )
                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
                if not is_within_crop_area.any():
                    continue

                # check at least 1 box with jaccard limitations
                xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
                ious = box_iou(
                    xyxy_bboxes,
                    torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
                )
                if ious.max() < min_jaccard_overlap:
                    continue

                return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Crop one input; boxes whose centers fall outside the crop are zeroed."""
        # An empty params dict encodes the "leave as-is" option.
        if not params:
            return inpt

        output = self._call_kernel(
            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
        )

        if isinstance(output, tv_tensors.BoundingBoxes):
            # We "mark" the invalid boxes as degenerate, and they can be
            # removed by a later call to SanitizeBoundingBoxes()
            output[~params["is_within_crop_area"]] = 0

        return output
|
| 1212 |
+
|
| 1213 |
+
|
| 1214 |
+
class ScaleJitter(Transform):
    """Perform Large Scale Jitter on the input according to
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        target_size (tuple of int): Target size. This parameter defines base scale for jittering,
            e.g. ``min(target_size[0] / width, target_size[1] / height)``.
        scale_range (tuple of float, optional): Minimum and maximum of the scale range. Default, ``(0.1, 2.0)``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample a jittered output size anchored at the base fit-to-target ratio."""
        orig_height, orig_width = query_size(flat_inputs)

        scale_min, scale_max = self.scale_range
        jitter = scale_min + torch.rand(1) * (scale_max - scale_min)
        # Base ratio fits the input inside target_size; jitter scales it up or down.
        ratio = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * jitter

        return dict(size=(int(orig_height * ratio), int(orig_width * ratio)))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Resize one input to the jittered size."""
        return self._call_kernel(
            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
        )
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
class RandomShortestSize(Transform):
    """Randomly resize the input.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        min_size (int or sequence of int): Minimum spatial size. Single integer value or a sequence of integer values.
        max_size (int, optional): Maximum spatial size. Default, None.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int], int],
        max_size: Optional[int] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ):
        super().__init__()
        # Normalize min_size to a list of candidate shortest-side lengths.
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Pick a target shortest side and derive the output size, capped by max_size."""
        orig_height, orig_width = query_size(flat_inputs)

        chosen_min = self.min_size[int(torch.randint(len(self.min_size), ()))]
        ratio = chosen_min / min(orig_height, orig_width)
        if self.max_size is not None:
            # Cap the ratio so the longer side never exceeds max_size.
            ratio = min(ratio, self.max_size / max(orig_height, orig_width))

        return dict(size=(int(orig_height * ratio), int(orig_width * ratio)))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Resize one input to the sampled size."""
        return self._call_kernel(
            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
        )
|
| 1347 |
+
|
| 1348 |
+
|
| 1349 |
+
class RandomResize(Transform):
    """Randomly resize the input.

    This transformation can be used together with ``RandomCrop`` as data augmentations to train
    models on image segmentation task.

    Output spatial size is randomly sampled from the half-open interval
    ``[min_size, max_size)`` — ``max_size`` itself is never drawn, because the
    sampling uses ``torch.randint``, whose upper bound is exclusive:

    .. code-block:: python

        size = uniform_sample(min_size, max_size)
        output_width = size
        output_height = size

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        min_size (int): Minimum output size for random sampling
        max_size (int): Maximum output size for random sampling (exclusive)
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # torch.randint samples from [min_size, max_size): max_size is exclusive.
        size = int(torch.randint(self.min_size, self.max_size, ()))
        # NOTE(review): a one-element size list typically makes F.resize match the
        # *smaller* edge to `size`, preserving aspect ratio, rather than producing
        # a square output as the class docstring example suggests — confirm
        # against the F.resize documentation.
        return dict(size=[size])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Dispatch to the resize kernel appropriate for inpt's type.
        return self._call_kernel(
            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
        )