koichi12 commited on
Commit
769e5f0
·
verified ·
1 Parent(s): 208efc9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py +3 -0
  15. .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py +172 -0
  18. .venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py +2 -0
  19. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py +393 -0
  28. .venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py +962 -0
  29. .venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py +114 -0
  30. .venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py +216 -0
  31. .venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py +174 -0
  32. .venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py +615 -0
  33. .venv/lib/python3.11/site-packages/torchvision/transforms/functional.py +1586 -0
  34. .venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py +2153 -0
  35. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py +60 -0
  36. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py +369 -0
  46. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py +627 -0
  47. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py +376 -0
  48. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py +174 -0
  49. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py +50 -0
  50. .venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py +1416 -0
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.41 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/caltech.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/coco.cpython-311.pyc ADDED
Binary file (7.09 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fakedata.cpython-311.pyc ADDED
Binary file (3.86 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flickr.cpython-311.pyc ADDED
Binary file (9.37 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/hmdb51.cpython-311.pyc ADDED
Binary file (8.16 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenet.cpython-311.pyc ADDED
Binary file (16.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kinetics.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/mnist.cpython-311.pyc ADDED
Binary file (33.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/utils.cpython-311.pyc ADDED
Binary file (27.2 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/video_utils.cpython-311.pyc ADDED
Binary file (22.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/voc.cpython-311.pyc ADDED
Binary file (13.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/widerface.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler
2
+
3
+ __all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (373 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/__pycache__/clip_sampler.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/samplers/clip_sampler.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import cast, Iterator, List, Optional, Sized, Union
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.utils.data import Sampler
7
+ from torchvision.datasets.video_utils import VideoClips
8
+
9
+
10
class DistributedSampler(Sampler):
    """
    Extension of DistributedSampler, as discussed in
    https://github.com/pytorch/pytorch/issues/23430

    Shards a dataset across ``num_replicas`` workers so that every worker
    receives the same number of samples. Indices are handed out in contiguous
    chunks of ``group_size``, and the index list is padded (wrapping back to
    the beginning) until it divides evenly across replicas.

    Example:
        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        num_replicas: 4
        shuffle: False

        when group_size = 1
        RANK    |   shard_dataset
        =========================
        rank_0  |  [0, 4, 8, 12]
        rank_1  |  [1, 5, 9, 13]
        rank_2  |  [2, 6, 10, 0]
        rank_3  |  [3, 7, 11, 1]

        when group_size = 2
        RANK    |   shard_dataset
        =========================
        rank_0  |  [0, 1, 8, 9]
        rank_1  |  [2, 3, 10, 11]
        rank_2  |  [4, 5, 12, 13]
        rank_3  |  [6, 7, 0, 1]
    """

    def __init__(
        self,
        dataset: Sized,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        group_size: int = 1,
    ) -> None:
        # Fall back to the default process group when replica count / rank are omitted.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if len(dataset) % group_size != 0:
            raise ValueError(
                f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
            )
        self.dataset = dataset
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        num_groups = len(dataset) // group_size
        # Every replica gets the same (rounded-up) number of groups.
        self.num_group_samples = int(math.ceil(num_groups * 1.0 / self.num_replicas))
        self.num_samples = self.num_group_samples * group_size
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        # Deterministic shuffle keyed on the epoch so all replicas agree on the order.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        indices: Union[torch.Tensor, List[int]]
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=rng).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Pad by wrapping around so the list divides evenly across replicas.
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # One row per group; replica `rank` takes every `num_replicas`-th row.
        total_group_size = self.total_size // self.group_size
        grouped = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
        shard = grouped[self.rank : total_group_size : self.num_replicas, :]
        indices = torch.reshape(shard, (-1,)).tolist()
        assert len(indices) == self.num_samples

        # When wrapping another sampler, translate positions into its indices.
        if isinstance(self.dataset, Sampler):
            orig_indices = list(iter(self.dataset))
            indices = [orig_indices[i] for i in indices]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        # Called by the training loop so each epoch shuffles differently.
        self.epoch = epoch
103
+
104
+
105
class UniformClipSampler(Sampler):
    """
    Sample `num_video_clips_per_video` clips for each video, equally spaced.
    When number of unique clips in the video is fewer than num_video_clips_per_video,
    repeat the clips until `num_video_clips_per_video` clips are collected

    Args:
        video_clips (VideoClips): video clips to sample from
        num_clips_per_video (int): number of clips to be sampled per video
    """

    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
        if not isinstance(video_clips, VideoClips):
            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
        self.video_clips = video_clips
        self.num_clips_per_video = num_clips_per_video

    def __iter__(self) -> Iterator[int]:
        chunks = []
        offset = 0
        # Pick num_clips_per_video evenly spaced positions within each video's clip range.
        for clips in self.video_clips.clips:
            num_clips = len(clips)
            if num_clips == 0:
                # Corner case where video decoding fails: skip the video entirely.
                continue
            picks = torch.linspace(offset, offset + num_clips - 1, steps=self.num_clips_per_video)
            chunks.append(picks.floor().to(torch.int64))
            offset += num_clips
        return iter(cast(List[int], torch.cat(chunks).tolist()))

    def __len__(self) -> int:
        # Videos whose decoding produced zero clips contribute nothing.
        return sum(self.num_clips_per_video for clips in self.video_clips.clips if len(clips) > 0)
139
+
140
+
141
class RandomClipSampler(Sampler):
    """
    Samples at most `max_video_clips_per_video` clips for each video randomly

    Args:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """

    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
        if not isinstance(video_clips, VideoClips):
            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self) -> Iterator[int]:
        chunks = []
        offset = 0
        # Draw up to max_clips_per_video random clip indices from each video.
        for clips in self.video_clips.clips:
            num_clips = len(clips)
            keep = min(num_clips, self.max_clips_per_video)
            chunks.append(torch.randperm(num_clips)[:keep] + offset)
            offset += num_clips
        flat = torch.cat(chunks)
        # Shuffle across videos so consecutive samples are not grouped per video.
        return iter(flat[torch.randperm(len(flat))].tolist())

    def __len__(self) -> int:
        return sum(min(len(clips), self.max_clips_per_video) for clips in self.video_clips.clips)
.venv/lib/python3.11/site-packages/torchvision/transforms/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transforms import *
2
+ from .autoaugment import *
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (263 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_pil.cpython-311.pyc ADDED
Binary file (21.7 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_tensor.cpython-311.pyc ADDED
Binary file (52 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_functional_video.cpython-311.pyc ADDED
Binary file (6.23 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_presets.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/_transforms_video.cpython-311.pyc ADDED
Binary file (9.19 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/autoaugment.cpython-311.pyc ADDED
Binary file (34 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/__pycache__/functional.cpython-311.pyc ADDED
Binary file (85.7 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_pil.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image, ImageEnhance, ImageOps
7
+
8
+ try:
9
+ import accimage
10
+ except ImportError:
11
+ accimage = None
12
+
13
+
14
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """Return True when *img* is a PIL Image (or an accimage.Image if accimage is installed)."""
    accepted = (Image.Image, accimage.Image) if accimage is not None else (Image.Image,)
    return isinstance(img, accepted)
20
+
21
+
22
@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
    """Return ``[channels, height, width]`` of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    # accimage images expose `channels` instead of `getbands`.
    channels = len(img.getbands()) if hasattr(img, "getbands") else img.channels
    width, height = img.size
    return [channels, height, width]
32
+
33
+
34
@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
    """Return ``[width, height]`` of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    return list(img.size)
39
+
40
+
41
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
    """Return the number of channels (bands) of a PIL image."""
    if not _is_pil_image(img):
        raise TypeError(f"Unexpected type {type(img)}")
    # accimage images expose `channels` instead of `getbands`.
    return len(img.getbands()) if hasattr(img, "getbands") else img.channels
49
+
50
+
51
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Return *img* mirrored horizontally (left/right flip)."""
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
57
+
58
+
59
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Return *img* mirrored vertically (top/bottom flip)."""
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
65
+
66
+
67
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
    """Scale image brightness; 0 gives a black image, 1 the original, >1 brighter."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Brightness(img).enhance(brightness_factor)
75
+
76
+
77
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
    """Scale image contrast; 0 gives a solid gray image, 1 the original, >1 more contrast."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Contrast(img).enhance(contrast_factor)
85
+
86
+
87
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
    """Scale color saturation; 0 gives a grayscale image, 1 the original, >1 more saturated."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Color(img).enhance(saturation_factor)
95
+
96
+
97
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
    """Rotate the hue channel of *img* by ``hue_factor`` (must lie in [-0.5, 0.5])."""
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    input_mode = img.mode
    # Single-channel / non-color modes carry no hue information; return unchanged.
    if input_mode in {"L", "1", "I", "F"}:
        return img

    h, s, v = img.convert("HSV").split()

    np_h = np.array(h, dtype=np.uint8)
    # This will over/underflow, as desired: uint8 wraparound is exactly a hue rotation.
    np_h += np.array(hue_factor * 255).astype(np.uint8)
    shifted_h = Image.fromarray(np_h, "L")

    return Image.merge("HSV", (shifted_h, s, v)).convert(input_mode)
119
+
120
+
121
@torch.jit.unused
def adjust_gamma(
    img: Image.Image,
    gamma: float,
    gain: float = 1.0,
) -> Image.Image:
    """Apply gamma correction ``out = gain * (in / 255) ** gamma`` via a lookup table."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    input_mode = img.mode
    rgb = img.convert("RGB")
    # One 256-entry LUT, repeated for the three RGB channels; use PIL's
    # point-function to apply it in a single fast C-level pass.
    lut = [int((255 + 1 - 1e-3) * gain * pow(level / 255.0, gamma)) for level in range(256)] * 3
    return rgb.point(lut).convert(input_mode)
141
+
142
+
143
@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    """Pad the given PIL Image on its borders.

    Args:
        img: image to pad.
        padding: an int (all sides), a 1-element sequence (all sides), a
            2-element sequence (left/right, top/bottom) or a 4-element
            sequence (left, top, right, bottom). Negative values crop
            the corresponding side instead of padding it.
        fill: pixel fill value, only used when ``padding_mode == "constant"``.
        padding_mode: one of "constant", "edge", "reflect", "symmetric".

    Returns:
        The padded (and possibly cropped, for negative padding) image.

    Raises:
        TypeError: if img, padding, fill or padding_mode has the wrong type.
        ValueError: for a padding sequence of unsupported length or an
            unknown padding mode.
    """

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            # Palette images lose their palette through ImageOps.expand; re-attach it.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, **opts)
    else:
        # Normalize every accepted `padding` form into per-side amounts.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        p = [pad_left, pad_top, pad_right, pad_bottom]
        # Negative padding means cropping from that side (np.pad rejects negatives).
        cropping = -np.minimum(p, 0)

        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))

        # Only the non-negative remainder is actually padded.
        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

        if img.mode == "P":
            # Round-trip through numpy drops the palette; save and restore it.
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
221
+
222
+
223
@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop the ``height`` x ``width`` region of *img* whose top-left corner is (top, left)."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    # PIL's crop box is (left, upper, right, lower).
    right = left + width
    bottom = top + height
    return img.crop((left, top, right, bottom))
236
+
237
+
238
@torch.jit.unused
def resize(
    img: Image.Image,
    size: Union[List[int], int],
    interpolation: int = Image.BILINEAR,
) -> Image.Image:
    """Resize *img* to ``size`` given as ``[height, width]``."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not (isinstance(size, list) and len(size) == 2):
        raise TypeError(f"Got inappropriate size arg: {size}")

    # PIL expects (width, height), while `size` arrives as [height, width].
    return img.resize((size[1], size[0]), interpolation)
251
+
252
+
253
@torch.jit.unused
def _parse_fill(
    fill: Optional[Union[float, List[float], Tuple[float, ...]]],
    img: Image.Image,
    name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
    """Normalize a fill value for PIL affine-style transforms.

    Broadcasts a scalar fill to the image's channel count, validates
    sequence lengths, and casts the values to int for non-float ("F")
    image modes. The result is returned as a one-entry dict keyed by
    *name* so it can be splatted into a PIL call as a keyword argument.
    """

    # Process fill color for affine transforms
    num_channels = get_image_num_channels(img)
    if fill is None:
        fill = 0
    # A scalar on a multi-channel image must be repeated per channel.
    if isinstance(fill, (int, float)) and num_channels > 1:
        fill = tuple([fill] * num_channels)
    if isinstance(fill, (list, tuple)):
        if len(fill) == 1:
            fill = fill * num_channels
        elif len(fill) != num_channels:
            msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
            raise ValueError(msg.format(len(fill), num_channels))

        fill = tuple(fill)  # type: ignore[arg-type]

    # Non-float image modes require integer pixel values.
    if img.mode != "F":
        if isinstance(fill, (list, tuple)):
            fill = tuple(int(x) for x in fill)
        else:
            fill = int(fill)

    return {name: fill}
282
+
283
+
284
@torch.jit.unused
def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = Image.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Apply the (inverse) affine transform *matrix* to *img*, keeping its size."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    fill_kwargs = _parse_fill(fill, img)
    return img.transform(img.size, Image.AFFINE, matrix, interpolation, **fill_kwargs)
298
+
299
+
300
@torch.jit.unused
def rotate(
    img: Image.Image,
    angle: float,
    interpolation: int = Image.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Rotate *img* by ``angle`` degrees counter-clockwise around ``center``."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    fill_kwargs = _parse_fill(fill, img)
    return img.rotate(angle, interpolation, expand, center, **fill_kwargs)
315
+
316
+
317
@torch.jit.unused
def perspective(
    img: Image.Image,
    perspective_coeffs: List[float],
    interpolation: int = Image.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    """Apply a perspective transform described by 8 coefficients, keeping the image size."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    fill_kwargs = _parse_fill(fill, img)
    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **fill_kwargs)
331
+
332
+
333
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
    """Convert *img* to grayscale with 1 channel, or 3 identical channels when requested."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    gray = img.convert("L")
    if num_output_channels == 1:
        return gray

    # Replicate the single luminance channel into an RGB image.
    np_gray = np.array(gray, dtype=np.uint8)
    return Image.fromarray(np.dstack([np_gray, np_gray, np_gray]), "RGB")
349
+
350
+
351
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert the colors of *img* (negative image)."""
    if _is_pil_image(img):
        return ImageOps.invert(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
356
+
357
+
358
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Reduce each color channel of *img* to ``bits`` bits."""
    if _is_pil_image(img):
        return ImageOps.posterize(img, bits)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
363
+
364
+
365
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Invert all pixel values of *img* above ``threshold``."""
    if _is_pil_image(img):
        return ImageOps.solarize(img, threshold)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
370
+
371
+
372
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
    """Scale image sharpness; 0 gives a blurred image, 1 the original, >1 sharper."""
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageEnhance.Sharpness(img).enhance(sharpness_factor)
380
+
381
+
382
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize image contrast by remapping the darkest/lightest pixels to black/white."""
    if _is_pil_image(img):
        return ImageOps.autocontrast(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
387
+
388
+
389
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the image histogram to a uniform distribution of gray values."""
    if _is_pil_image(img):
        return ImageOps.equalize(img)
    raise TypeError(f"img should be PIL Image. Got {type(img)}")
.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_tensor.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn.functional import conv2d, grid_sample, interpolate, pad as torch_pad
7
+
8
+
9
+ def _is_tensor_a_torch_image(x: Tensor) -> bool:
10
+ return x.ndim >= 2
11
+
12
+
13
def _assert_image_tensor(img: Tensor) -> None:
    """Raise TypeError unless *img* qualifies as a torch image tensor."""
    if _is_tensor_a_torch_image(img):
        return
    raise TypeError("Tensor is not a torch image.")
16
+
17
+
18
def get_dimensions(img: Tensor) -> List[int]:
    """Return ``[channels, height, width]``; a 2-D tensor counts as single-channel."""
    _assert_image_tensor(img)
    height, width = img.shape[-2:]
    channels = img.shape[-3] if img.ndim > 2 else 1
    return [channels, height, width]
23
+
24
+
25
def get_image_size(img: Tensor) -> List[int]:
    # Returns (w, h) of tensor image — note width first, matching the PIL convention.
    _assert_image_tensor(img)
    height, width = img.shape[-2:]
    return [width, height]
29
+
30
+
31
def get_image_num_channels(img: Tensor) -> int:
    """Return the channel count of a tensor image; a 2-D tensor counts as one channel."""
    _assert_image_tensor(img)
    if img.ndim > 2:
        return img.shape[-3]
    if img.ndim == 2:
        return 1

    raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
39
+
40
+
41
+ def _max_value(dtype: torch.dtype) -> int:
42
+ if dtype == torch.uint8:
43
+ return 255
44
+ elif dtype == torch.int8:
45
+ return 127
46
+ elif dtype == torch.int16:
47
+ return 32767
48
+ elif dtype == torch.uint16:
49
+ return 65535
50
+ elif dtype == torch.int32:
51
+ return 2147483647
52
+ elif dtype == torch.int64:
53
+ return 9223372036854775807
54
+ else:
55
+ # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
56
+ # easy.
57
+ return 1
58
+
59
+
60
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
    """Raise ``TypeError`` if the channel count of ``img`` is not in ``permitted``."""
    channels = get_dimensions(img)[0]
    if channels not in permitted:
        raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {channels}")
64
+
65
+
66
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype``, rescaling values between value ranges.

    Floats are assumed to be in [0, 1]; integers span their full dtype range.

    Raises:
        RuntimeError: for float->int casts that cannot be performed safely
            (the float mantissa cannot represent every target integer).
    """
    if image.dtype == dtype:
        return image

    if image.is_floating_point():
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            # float -> float: plain cast, no rescaling needed.
            return image.to(dtype)

        # float -> int: refuse lossy casts.
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The cast from {image.dtype} to {dtype} cannot be performed safely.")

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in [0, 1], (float * 255).to(uint) hits 255 only for exactly 1.0;
        # `max + 1 - eps` distributes float ranges onto ints more evenly.
        eps = 1e-3
        max_val = float(_max_value(dtype))
        return image.mul(max_val + 1.0 - eps).to(dtype)
    else:
        input_max = float(_max_value(image.dtype))

        # int -> float: cast first, then normalize into [0, 1].
        # TODO: replace with dtype.is_floating_point when torchscript supports it
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype) / input_max

        output_max = float(_max_value(dtype))

        # int -> int: scale by an integral factor. The factor is forced to int
        # because under torch jit a float factor makes // and * give different results.
        if input_max > output_max:
            factor = int((input_max + 1) // (output_max + 1))
            return torch.div(image, factor, rounding_mode="floor").to(dtype)
        else:
            factor = int((output_max + 1) // (input_max + 1))
            return image.to(dtype) * factor
116
+
117
+
118
def vflip(img: Tensor) -> Tensor:
    """Flip a tensor image vertically (along the height axis)."""
    _assert_image_tensor(img)
    return img.flip(-2)
122
+
123
+
124
def hflip(img: Tensor) -> Tensor:
    """Flip a tensor image horizontally (along the width axis)."""
    _assert_image_tensor(img)
    return img.flip(-1)
128
+
129
+
130
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop a tensor image; any region outside the image is zero-padded (PIL-compatible)."""
    _assert_image_tensor(img)

    _, h, w = get_dimensions(img)
    right = left + width
    bottom = top + height

    # Fully in-bounds crop: a plain slice.
    if left >= 0 and top >= 0 and right <= w and bottom <= h:
        return img[..., top:bottom, left:right]

    # Crop window sticks out of the image: slice the in-bounds part and pad
    # the remainder with zeros on each side.
    padding_ltrb = [
        max(-left + min(0, right), 0),
        max(-top + min(0, bottom), 0),
        max(right - max(w, left), 0),
        max(bottom - max(h, top), 0),
    ]
    return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
146
+
147
+
148
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an RGB tensor image to grayscale with ITU-R 601-2 luma weights.

    A single-channel input is returned as a copy; ``num_output_channels=3``
    replicates the gray channel three times.
    """
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    if img.shape[-3] == 3:
        r, g, b = img.unbind(dim=-3)
        # Same weights as the TF implementation:
        # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
        l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype).unsqueeze(dim=-3)
    else:
        l_img = img.clone()

    return l_img.expand(img.shape) if num_output_channels == 3 else l_img
169
+
170
+
171
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Scale brightness: 0 gives a black image, 1 the original, >1 brighter."""
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])
    # Blending against zeros is multiplicative scaling with dtype-range clamping.
    return _blend(img, torch.zeros_like(img), brightness_factor)
180
+
181
+
182
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust contrast by blending the image with its mean gray level."""
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
    _assert_image_tensor(img)
    _assert_channels(img, [3, 1])

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    if get_dimensions(img)[0] == 3:
        mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
    else:
        mean = torch.mean(img.to(dtype), dim=(-3, -2, -1), keepdim=True)

    return _blend(img, mean, contrast_factor)
197
+
198
+
199
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Shift the hue channel by ``hue_factor`` (in [-0.5, 0.5]) via an HSV round trip."""
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor image")
    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    orig_dtype = img.dtype
    hsv = _rgb2hsv(convert_image_dtype(img, torch.float32))
    h, s, v = hsv.unbind(dim=-3)
    # Hue is cyclic in [0, 1); wrap after the shift.
    h = (h + hue_factor) % 1.0
    rgb = _hsv2rgb(torch.stack((h, s, v), dim=-3))
    return convert_image_dtype(rgb, orig_dtype)
222
+
223
+
224
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Blend the image with its grayscale version: 0 gives grayscale, 1 the original."""
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
        return img

    return _blend(img, rgb_to_grayscale(img), saturation_factor)
236
+
237
+
238
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    """Gamma-correct an image: ``out = gain * img ** gamma``, clamped to [0, 1].

    Integer images are converted to float32 for the computation and cast back.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("Input img should be a Tensor.")
    _assert_channels(img, [1, 3])

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    dtype = img.dtype
    result = img if torch.is_floating_point(img) else convert_image_dtype(img, torch.float32)
    result = (gain * result**gamma).clamp(0, 1)
    # No-op for floating point inputs; restores the integer range otherwise.
    return convert_image_dtype(result, dtype)
256
+
257
+
258
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    """Linear blend ``ratio * img1 + (1 - ratio) * img2`` clamped to img1's dtype range."""
    weight = float(ratio)
    bound = _max_value(img1.dtype)
    mixed = weight * img1 + (1.0 - weight) * img2
    return mixed.clamp(0, bound).to(img1.dtype)
262
+
263
+
264
+ def _rgb2hsv(img: Tensor) -> Tensor:
265
+ r, g, b = img.unbind(dim=-3)
266
+
267
+ # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
268
+ # src/libImaging/Convert.c#L330
269
+ maxc = torch.max(img, dim=-3).values
270
+ minc = torch.min(img, dim=-3).values
271
+
272
+ # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
273
+ # from happening in the results, because
274
+ # + S channel has division by `maxc`, which is zero only if `maxc = minc`
275
+ # + H channel has division by `(maxc - minc)`.
276
+ #
277
+ # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
278
+ # we don't need to deal with it in case we save the NaN in a buffer in
279
+ # backprop, if it is ever supported, but it doesn't hurt to do so.
280
+ eqc = maxc == minc
281
+
282
+ cr = maxc - minc
283
+ # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
284
+ ones = torch.ones_like(maxc)
285
+ s = cr / torch.where(eqc, ones, maxc)
286
+ # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
287
+ # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
288
+ # would not matter what values `rc`, `gc`, and `bc` have here, and thus
289
+ # replacing denominator with 1 when `eqc` is fine.
290
+ cr_divisor = torch.where(eqc, ones, cr)
291
+ rc = (maxc - r) / cr_divisor
292
+ gc = (maxc - g) / cr_divisor
293
+ bc = (maxc - b) / cr_divisor
294
+
295
+ hr = (maxc == r) * (bc - gc)
296
+ hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
297
+ hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
298
+ h = hr + hg + hb
299
+ h = torch.fmod((h / 6.0 + 1.0), 1.0)
300
+ return torch.stack((h, s, maxc), dim=-3)
301
+
302
+
303
+ def _hsv2rgb(img: Tensor) -> Tensor:
304
+ h, s, v = img.unbind(dim=-3)
305
+ i = torch.floor(h * 6.0)
306
+ f = (h * 6.0) - i
307
+ i = i.to(dtype=torch.int32)
308
+
309
+ p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
310
+ q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
311
+ t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
312
+ i = i % 6
313
+
314
+ mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
315
+
316
+ a1 = torch.stack((v, q, p, p, t, v), dim=-3)
317
+ a2 = torch.stack((t, v, v, q, p, p), dim=-3)
318
+ a3 = torch.stack((p, p, t, v, v, q), dim=-3)
319
+ a4 = torch.stack((a1, a2, a3), dim=-4)
320
+
321
+ return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
322
+
323
+
324
+ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
325
+ # padding is left, right, top, bottom
326
+
327
+ # crop if needed
328
+ if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
329
+ neg_min_padding = [-min(x, 0) for x in padding]
330
+ crop_left, crop_right, crop_top, crop_bottom = neg_min_padding
331
+ img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
332
+ padding = [max(x, 0) for x in padding]
333
+
334
+ in_sizes = img.size()
335
+
336
+ _x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
337
+ left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
338
+ right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
339
+ x_indices = torch.tensor(left_indices + _x_indices + right_indices, device=img.device)
340
+
341
+ _y_indices = [i for i in range(in_sizes[-2])]
342
+ top_indices = [i for i in range(padding[2] - 1, -1, -1)]
343
+ bottom_indices = [-(i + 1) for i in range(padding[3])]
344
+ y_indices = torch.tensor(top_indices + _y_indices + bottom_indices, device=img.device)
345
+
346
+ ndim = img.ndim
347
+ if ndim == 3:
348
+ return img[:, y_indices[:, None], x_indices[None, :]]
349
+ elif ndim == 4:
350
+ return img[:, :, y_indices[:, None], x_indices[None, :]]
351
+ else:
352
+ raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
353
+
354
+
355
+ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
356
+ if isinstance(padding, int):
357
+ if torch.jit.is_scripting():
358
+ # This maybe unreachable
359
+ raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
360
+ pad_left = pad_right = pad_top = pad_bottom = padding
361
+ elif len(padding) == 1:
362
+ pad_left = pad_right = pad_top = pad_bottom = padding[0]
363
+ elif len(padding) == 2:
364
+ pad_left = pad_right = padding[0]
365
+ pad_top = pad_bottom = padding[1]
366
+ else:
367
+ pad_left = padding[0]
368
+ pad_top = padding[1]
369
+ pad_right = padding[2]
370
+ pad_bottom = padding[3]
371
+
372
+ return [pad_left, pad_right, pad_top, pad_bottom]
373
+
374
+
375
def pad(
    img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
    """Pad a tensor image.

    Args:
        img: image tensor of shape (..., H, W).
        padding: int or 1/2/4-element sequence (see ``_parse_pad_padding``).
        fill: constant value for ``padding_mode="constant"``; None means 0.
        padding_mode: one of "constant", "edge", "reflect", "symmetric".
    """
    _assert_image_tensor(img)

    if fill is None:
        fill = 0

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list):
        # TODO: Jit is failing on loading this op when scripted and saved
        # https://github.com/pytorch/pytorch/issues/81100
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # torch's name for edge padding is "replicate"
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # torch_pad has no symmetric mode; use the dedicated implementation.
        return _pad_symmetric(img, p)

    # torch_pad wants a batch dimension; add one temporarily when missing.
    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Non-constant modes only support float until
        # https://github.com/pytorch/pytorch/issues/40763 is resolved.
        need_cast = True
        img = img.to(torch.float32)

    if padding_mode in ("reflect", "replicate"):
        img = torch_pad(img, p, mode=padding_mode)
    else:
        img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)
    if need_cast:
        img = img.to(out_dtype)

    return img
439
+
440
+
441
def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    antialias: Optional[bool] = True,
) -> Tensor:
    """Resize a tensor image with ``torch.nn.functional.interpolate``."""
    _assert_image_tensor(img)

    if isinstance(size, tuple):
        size = list(size)

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        # Silently disable antialias for modes where interpolate() would error;
        # this is documented behaviour (the flag is irrelevant for other modes,
        # and True is the default, so raising would be hostile).
        antialias = False

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

    # Define align_corners explicitly to avoid interpolate() warnings.
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
        # bicubic can overshoot; clamp before the uint8 round-trip.
        img = img.clamp(min=0, max=255)

    return _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
475
+
476
+
477
def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[Union[int, float, List[float]]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    """Validate the shared arguments of the grid transforms (affine, rotate, perspective)."""
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor")

    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")
    if matrix is not None and len(matrix) != 6:
        raise ValueError("Argument matrix should have 6 float values")
    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    # A bad fill type is only warned about, not rejected (historical leniency).
    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # A multi-element fill must broadcast over the image channels.
    num_channels = get_dimensions(img)[0]
    if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
        msg = (
            "The number of elements in 'fill' cannot broadcast to match the number of "
            "channels of the image ({} != {})"
        )
        raise ValueError(msg.format(len(fill), num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
514
+
515
+
516
+ def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
517
+ need_squeeze = False
518
+ # make image NCHW
519
+ if img.ndim < 4:
520
+ img = img.unsqueeze(dim=0)
521
+ need_squeeze = True
522
+
523
+ out_dtype = img.dtype
524
+ need_cast = False
525
+ if out_dtype not in req_dtypes:
526
+ need_cast = True
527
+ req_dtype = req_dtypes[0]
528
+ img = img.to(req_dtype)
529
+ return img, need_cast, need_squeeze, out_dtype
530
+
531
+
532
+ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype) -> Tensor:
533
+ if need_squeeze:
534
+ img = img.squeeze(dim=0)
535
+
536
+ if need_cast:
537
+ if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
538
+ # it is better to round before cast
539
+ img = torch.round(img)
540
+ img = img.to(out_dtype)
541
+
542
+ return img
543
+
544
+
545
def _apply_grid_transform(
    img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]
) -> Tensor:
    """Sample ``img`` through ``grid`` with grid_sample, painting out-of-image areas with ``fill``."""
    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype])

    if img.shape[0] > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
        img = torch.cat((img, mask), dim=1)

    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        mask = img[:, -1:, :, :]  # N * 1 * H * W
        img = img[:, :-1, :, :]  # N * C * H * W
        mask = mask.expand_as(img)
        fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1)
        fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
        if mode == "nearest":
            # Binary in/out decision for nearest; soft blend for bilinear.
            mask = mask < 0.5
            img[mask] = fill_img[mask]
        else:  # 'bilinear'
            img = img * mask + (1.0 - mask) * fill_img

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
577
+
578
+
579
+ def _gen_affine_grid(
580
+ theta: Tensor,
581
+ w: int,
582
+ h: int,
583
+ ow: int,
584
+ oh: int,
585
+ ) -> Tensor:
586
+ # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
587
+ # AffineGridGenerator.cpp#L18
588
+ # Difference with AffineGridGenerator is that:
589
+ # 1) we normalize grid values after applying theta
590
+ # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
591
+
592
+ d = 0.5
593
+ base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
594
+ x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
595
+ base_grid[..., 0].copy_(x_grid)
596
+ y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
597
+ base_grid[..., 1].copy_(y_grid)
598
+ base_grid[..., 2].fill_(1)
599
+
600
+ rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
601
+ output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
602
+ return output_grid.view(1, oh, ow, 2)
603
+
604
+
605
def affine(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Apply the inverse affine transform ``matrix`` (6 floats) to a tensor image."""
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    height, width = img.shape[-2], img.shape[-1]
    # The grid is generated on the same device as theta and img.
    grid = _gen_affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
619
+
620
+
621
+ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
622
+
623
+ # Inspired of PIL implementation:
624
+ # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
625
+
626
+ # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
627
+ # Points are shifted due to affine matrix torch convention about
628
+ # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
629
+ pts = torch.tensor(
630
+ [
631
+ [-0.5 * w, -0.5 * h, 1.0],
632
+ [-0.5 * w, 0.5 * h, 1.0],
633
+ [0.5 * w, 0.5 * h, 1.0],
634
+ [0.5 * w, -0.5 * h, 1.0],
635
+ ]
636
+ )
637
+ theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
638
+ new_pts = torch.matmul(pts, theta.T)
639
+ min_vals, _ = new_pts.min(dim=0)
640
+ max_vals, _ = new_pts.max(dim=0)
641
+
642
+ # shift points to [0, w] and [0, h] interval to match PIL results
643
+ min_vals += torch.tensor((w * 0.5, h * 0.5))
644
+ max_vals += torch.tensor((w * 0.5, h * 0.5))
645
+
646
+ # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
647
+ tol = 1e-4
648
+ cmax = torch.ceil((max_vals / tol).trunc_() * tol)
649
+ cmin = torch.floor((min_vals / tol).trunc_() * tol)
650
+ size = cmax - cmin
651
+ return int(size[0]), int(size[1]) # w, h
652
+
653
+
654
def rotate(
    img: Tensor,
    matrix: List[float],
    interpolation: str = "nearest",
    expand: bool = False,
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Rotate via the inverse affine ``matrix``; ``expand`` grows the canvas to fit."""
    _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])

    w, h = img.shape[-1], img.shape[-2]
    ow, oh = _compute_affine_output_size(matrix, w, h) if expand else (w, h)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
    # The grid is generated on the same device as theta and img.
    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
670
+
671
+
672
+ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> Tensor:
673
+ # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
674
+ # src/libImaging/Geometry.c#L394
675
+
676
+ #
677
+ # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
678
+ # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
679
+ #
680
+ theta1 = torch.tensor(
681
+ [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
682
+ )
683
+ theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
684
+
685
+ d = 0.5
686
+ base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
687
+ x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
688
+ base_grid[..., 0].copy_(x_grid)
689
+ y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
690
+ base_grid[..., 1].copy_(y_grid)
691
+ base_grid[..., 2].fill_(1)
692
+
693
+ rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
694
+ output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
695
+ output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
696
+
697
+ output_grid = output_grid1 / output_grid2 - 1.0
698
+ return output_grid.view(1, oh, ow, 2)
699
+
700
+
701
def perspective(
    img: Tensor,
    perspective_coeffs: List[float],
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Apply the projective transform described by 8 coefficients to a tensor image."""
    if not (isinstance(img, torch.Tensor)):
        raise TypeError("Input img should be Tensor.")

    _assert_image_tensor(img)
    _assert_grid_transform_inputs(
        img,
        matrix=None,
        interpolation=interpolation,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    ow, oh = img.shape[-1], img.shape[-2]
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
    return _apply_grid_transform(img, grid, interpolation, fill=fill)
725
+
726
+
727
+ def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
728
+ ksize_half = (kernel_size - 1) * 0.5
729
+
730
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device)
731
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
732
+ kernel1d = pdf / pdf.sum()
733
+
734
+ return kernel1d
735
+
736
+
737
def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Return a separable 2D Gaussian kernel as the outer product of two 1D kernels."""
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
    return torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
744
+
745
+
746
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """Blur with a depthwise Gaussian convolution; output size matches input size."""
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")
    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    # One kernel per channel: grouped conv2d then acts depthwise.
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom); reflect-pad so the conv keeps H x W.
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    return _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
765
+
766
+
767
def invert(img: Tensor) -> Tensor:
    """Invert image colors: ``max_value_of_dtype - img``."""
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])
    return _max_value(img.dtype) - img
777
+
778
+
779
def posterize(img: Tensor, bits: int) -> Tensor:
    """Keep only the ``bits`` most significant bits of each uint8 channel value."""
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
    _assert_channels(img, [1, 3])

    # JIT-friendly spelling of ~(2 ** (8 - bits) - 1): drops the low (8 - bits) bits.
    mask = -int(2 ** (8 - bits))
    return img & mask
791
+
792
+
793
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Invert all pixel values at or above ``threshold``; leave the rest unchanged."""
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    if threshold > _max_value(img.dtype):
        raise TypeError("Threshold should be less than bound of img.")

    return torch.where(img >= threshold, invert(img), img)
807
+
808
+
809
def _blurred_degenerate_image(img: Tensor) -> Tensor:
    """Return ``img`` smoothed by the 3x3 sharpness kernel; border pixels stay intact."""
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    # Normalized 3x3 kernel with center weight 5 (ones elsewhere).
    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])
    result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
    result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)

    # The valid convolution shrinks by one pixel per side; paste it into a copy
    # so the border keeps its original values.
    result = img.clone()
    result[..., 1:-1, 1:-1] = result_tmp
    return result
825
+
826
+
827
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Blend with a blurred copy: 0 gives the blurred image, 1 the original, >1 sharper."""
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")
    _assert_image_tensor(img)
    _assert_channels(img, [1, 3])

    # Too small for the 3x3 blur kernel to have an interior: return unchanged.
    if img.size(-1) <= 2 or img.size(-2) <= 2:
        return img

    return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
839
+
840
+
841
def autocontrast(img: Tensor) -> Tensor:
    """Stretch each channel so its min maps to 0 and its max to the dtype bound."""
    _assert_image_tensor(img)
    if img.ndim < 3:
        raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {img.ndim}")
    _assert_channels(img, [1, 3])

    bound = _max_value(img.dtype)
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
    maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)

    # Flat channels (min == max) make the scale infinite; leave them as-is.
    scale = bound / (maximum - minimum)
    eq_idxs = torch.isfinite(scale).logical_not()
    minimum[eq_idxs] = 0
    scale[eq_idxs] = 1

    return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
861
+
862
+
863
+ def _scale_channel(img_chan: Tensor) -> Tensor:
864
+ # TODO: we should expect bincount to always be faster than histc, but this
865
+ # isn't always the case. Once
866
+ # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
867
+ # block and only use bincount.
868
+ if img_chan.is_cuda:
869
+ hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
870
+ else:
871
+ hist = torch.bincount(img_chan.reshape(-1), minlength=256)
872
+
873
+ nonzero_hist = hist[hist != 0]
874
+ step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
875
+ if step == 0:
876
+ return img_chan
877
+
878
+ lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
879
+ lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
880
+
881
+ return lut[img_chan.to(torch.int64)].to(torch.uint8)
882
+
883
+
884
def _equalize_single_image(img: Tensor) -> Tensor:
    """Equalize each channel of a single (C, H, W) image independently."""
    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
886
+
887
+
888
def equalize(img: Tensor) -> Tensor:
    """Histogram-equalize a uint8 image (3D) or a batch of images (4D)."""
    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError(f"Input image tensor should have 3 or 4 dimensions, but found {img.ndim}")
    if img.dtype != torch.uint8:
        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {img.dtype}")
    _assert_channels(img, [1, 3])

    if img.ndim == 3:
        return _equalize_single_image(img)
    return torch.stack([_equalize_single_image(x) for x in img])
903
+
904
+
905
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
    """Channel-wise normalize a float image: ``(tensor - mean) / std``.

    Raises:
        TypeError: if the input is not a float tensor.
        ValueError: if the tensor is not (..., C, H, W) or any std entry is zero.
    """
    _assert_image_tensor(tensor)

    if not tensor.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")

    if tensor.ndim < 3:
        raise ValueError(
            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
        )

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
    # Reshape 1D stats to (C, 1, 1) so they broadcast over H and W.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    return tensor.sub_(mean).div_(std)
929
+
930
+
931
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Overwrite the ``h`` x ``w`` region at (row ``i``, col ``j``) with ``v``."""
    _assert_image_tensor(img)
    if not inplace:
        img = img.clone()
    img[..., i : i + h, j : j + w] = v
    return img
939
+
940
+
941
+ def _create_identity_grid(size: List[int]) -> Tensor:
942
+ hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
943
+ grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
944
+ return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2
945
+
946
+
947
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: str = "bilinear",
    fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor:
    """Warp ``img`` by adding ``displacement`` to an identity sampling grid."""
    if not (isinstance(img, torch.Tensor)):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    size = list(img.shape[-2:])
    identity_grid = _create_identity_grid(size).to(img.device)
    grid = identity_grid + displacement.to(img.device)
    return _apply_grid_transform(img, grid, interpolation, fill)
.venv/lib/python3.11/site-packages/torchvision/transforms/_functional_video.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import torch
4
+
5
+
6
# This private module is kept only for backward compatibility; warn on import.
_DEPRECATION_MSG = (
    "The 'torchvision.transforms._functional_video' module is deprecated since 0.12 and will be removed in the future. "
    "Please use the 'torchvision.transforms.functional' module instead."
)
warnings.warn(_DEPRECATION_MSG)
10
+
11
+
12
+ def _is_tensor_video_clip(clip):
13
+ if not torch.is_tensor(clip):
14
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
15
+
16
+ if not clip.ndimension() == 4:
17
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
18
+
19
+ return True
20
+
21
+
22
def crop(clip, i, j, h, w):
    """Crop the spatial region ``[i:i+h, j:j+w]`` out of a video clip.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
    """
    if clip.dim() != 4:
        raise ValueError("clip should be a 4D tensor")
    return clip[..., i : i + h, j : j + w]
30
+
31
+
32
def resize(clip, target_size, interpolation_mode):
    """Resize a batched clip to ``target_size = (height, width)`` via ``interpolate``."""
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    # align_corners=False matches the historical behavior of this helper.
    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
36
+
37
+
38
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    # Crop first, then rescale the crop to the requested output size.
    cropped = crop(clip, i, j, h, w)
    return resize(cropped, size, interpolation_mode)
56
+
57
+
58
def center_crop(clip, crop_size):
    """Crop the centered ``crop_size = (th, tw)`` region out of a (C, T, H, W) clip."""
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    th, tw = crop_size
    h, w = clip.size(-2), clip.size(-1)
    if h < th or w < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    # Round to the nearest pixel so the crop is as centered as possible.
    top = int(round((h - th) / 2.0))
    left = int(round((w - tw) / 2.0))
    return crop(clip, top, left, th, tw)
69
+
70
+
71
def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    # Validation inlined from _is_tensor_video_clip so this helper stands alone.
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    # (T, H, W, C) -> (C, T, H, W), rescaled from [0, 255] to [0.0, 1.0].
    return clip.float().permute(3, 0, 1, 2) / 255.0
84
+
85
+
86
def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    if not inplace:
        clip = clip.clone()
    mean_t = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std_t = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    # Broadcast the per-channel statistics over the (C, T, H, W) layout.
    clip.sub_(mean_t[:, None, None, None]).div_(std_t[:, None, None, None])
    return clip
103
+
104
+
105
def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    # Flip along the width axis (the last dimension).
    return clip.flip(-1)
.venv/lib/python3.11/site-packages/torchvision/transforms/_presets.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is part of the private API. Please do not use directly these classes as they will be modified on
3
+ future versions without warning. The classes should be accessed only via the transforms argument of Weights.
4
+ """
5
+ from typing import Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+
10
+ from . import functional as F, InterpolationMode
11
+
12
+
13
+ __all__ = [
14
+ "ObjectDetection",
15
+ "ImageClassification",
16
+ "VideoClassification",
17
+ "SemanticSegmentation",
18
+ "OpticalFlow",
19
+ ]
20
+
21
+
22
class ObjectDetection(nn.Module):
    """Inference preset for detection models: tensor conversion plus float rescaling."""

    def forward(self, img: Tensor) -> Tensor:
        # PIL inputs are first converted to a uint8 tensor.
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float)

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )
36
+
37
+
38
class ImageClassification(nn.Module):
    """ImageNet-style eval preset: resize, center-crop, rescale to float, normalize."""

    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # The functional ops below take list arguments, so normalize once here.
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        img = F.center_crop(img, self.crop_size)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        return F.normalize(img, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n crop_size={self.crop_size}"
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
        )
83
+
84
+
85
class VideoClassification(nn.Module):
    """Eval preset for video models: resize, center-crop, rescale, normalize, permute."""

    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Union[Tuple[int], Tuple[int, int]],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        # Accept a single (T, C, H, W) clip by temporarily adding a batch axis.
        need_squeeze = vid.ndim < 5
        if need_squeeze:
            vid = vid.unsqueeze(dim=0)

        # Flatten batch and time so per-frame image ops can be applied at once.
        N, T, C, H, W = vid.shape
        vid = vid.view(-1, C, H, W)
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)
        H, W = self.crop_size
        vid = vid.view(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if need_squeeze:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n crop_size={self.crop_size}"
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
        )
144
+
145
+
146
class SemanticSegmentation(nn.Module):
    """Eval preset for segmentation models: optional resize, rescale, normalize."""

    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # ``None`` disables resizing; otherwise wrap in a list for F.resize.
        self.resize_size = [resize_size] if resize_size is not None else None
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        return F.normalize(img, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
        )
188
+
189
+
190
class OpticalFlow(nn.Module):
    """Preset for optical-flow models: both frames converted to float in ``[-1, 1]``."""

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        img1 = F.convert_image_dtype(img1, torch.float)
        img2 = F.convert_image_dtype(img2, torch.float)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        return img1.contiguous(), img2.contiguous()

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[-1.0, 1.0]``."
        )
216
+ )
.venv/lib/python3.11/site-packages/torchvision/transforms/_transforms_video.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import numbers
4
+ import random
5
+ import warnings
6
+
7
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
8
+
9
+ from . import _functional_video as F
10
+
11
+
12
+ __all__ = [
13
+ "RandomCropVideo",
14
+ "RandomResizedCropVideo",
15
+ "CenterCropVideo",
16
+ "NormalizeVideo",
17
+ "ToTensorVideo",
18
+ "RandomHorizontalFlipVideo",
19
+ ]
20
+
21
+
22
# This private module is kept only for backward compatibility; warn on import.
_DEPRECATION_MSG = (
    "The 'torchvision.transforms._transforms_video' module is deprecated since 0.12 and will be removed in the future. "
    "Please use the 'torchvision.transforms' module instead."
)
warnings.warn(_DEPRECATION_MSG)
26
+
27
+
28
class RandomCropVideo(RandomCrop):
    """Crop a video clip at a random spatial location."""

    def __init__(self, size):
        # NOTE(review): RandomCrop.__init__ is intentionally not called here;
        # only get_params is reused from the parent implementation.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, OH, OW)
        """
        top, left, height, width = self.get_params(clip, self.size)
        return F.crop(clip, top, left, height, width)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
48
+
49
+
50
class RandomResizedCropVideo(RandomResizedCrop):
    """Randomly crop a clip and resize the crop to a fixed output size."""

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation_mode="bilinear",
    ):
        # NOTE(review): RandomResizedCrop.__init__ is intentionally not called;
        # only get_params is reused from the parent implementation.
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode
        self.scale = scale
        self.ratio = ratio

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: randomly cropped/resized video clip.
                size is (C, T, H, W)
        """
        top, left, height, width = self.get_params(clip, self.scale, self.ratio)
        return F.resized_crop(clip, top, left, height, width, self.size, self.interpolation_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
82
+
83
+
84
class CenterCropVideo:
    """Crop the centered ``crop_size`` region of every frame in a clip."""

    def __init__(self, crop_size):
        if isinstance(crop_size, numbers.Number):
            self.crop_size = (int(crop_size), int(crop_size))
        else:
            self.crop_size = crop_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        Returns:
            torch.tensor: central cropping of video clip. Size is
            (C, T, crop_size, crop_size)
        """
        return F.center_crop(clip, self.crop_size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(crop_size={self.crop_size})"
103
+
104
+
105
class NormalizeVideo:
    """
    Normalize the video clip by mean subtraction and division by standard deviation
    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        return F.normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
128
+
129
+
130
class ToTensorVideo:
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    """

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
        """
        return F.to_tensor(clip)

    def __repr__(self) -> str:
        return self.__class__.__name__
150
+
151
+
152
class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (C, T, H, W)
        Return:
            clip (torch.tensor): Size is (C, T, H, W)
        """
        # Flip with probability p; otherwise return the clip unchanged.
        return F.hflip(clip) if random.random() < self.p else clip

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
.venv/lib/python3.11/site-packages/torchvision/transforms/autoaugment.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from enum import Enum
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from . import functional as F, InterpolationMode
9
+
10
+ __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
11
+
12
+
13
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply a single named augmentation op to ``img`` with the given signed magnitude.

    Raises ``ValueError`` for an unrecognized ``op_name``.
    """
    if op_name == "ShearX":
        # magnitude should be arctan(magnitude)
        # official autoaug: (1, level, 0, 0, 1, 0)
        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
        # compared to
        # torchvision: (1, tan(level), 0, 0, 1, 0)
        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
        return F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[math.degrees(math.atan(magnitude)), 0.0],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    if op_name == "ShearY":
        # magnitude should be arctan(magnitude) — see ShearX above.
        return F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[0.0, math.degrees(math.atan(magnitude))],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    if op_name == "TranslateX":
        return F.affine(
            img,
            angle=0.0,
            translate=[int(magnitude), 0],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    if op_name == "TranslateY":
        return F.affine(
            img,
            angle=0.0,
            translate=[0, int(magnitude)],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    if op_name == "Rotate":
        return F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    if op_name == "Brightness":
        return F.adjust_brightness(img, 1.0 + magnitude)
    if op_name == "Color":
        return F.adjust_saturation(img, 1.0 + magnitude)
    if op_name == "Contrast":
        return F.adjust_contrast(img, 1.0 + magnitude)
    if op_name == "Sharpness":
        return F.adjust_sharpness(img, 1.0 + magnitude)
    if op_name == "Posterize":
        return F.posterize(img, int(magnitude))
    if op_name == "Solarize":
        return F.solarize(img, magnitude)
    if op_name == "AutoContrast":
        return F.autocontrast(img)
    if op_name == "Equalize":
        return F.equalize(img)
    if op_name == "Invert":
        return F.invert(img)
    if op_name == "Identity":
        # No-op: return the input unchanged.
        return img
    raise ValueError(f"The provided operator {op_name} is not recognized.")
91
+
92
+
93
class AutoAugmentPolicy(Enum):
    """AutoAugment policies learned on different datasets.

    Available policies are IMAGENET, CIFAR10 and SVHN.
    """

    IMAGENET = "imagenet"
    CIFAR10 = "cifar10"
    SVHN = "svhn"
101
+
102
+
103
+ # FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
104
class AutoAugment(torch.nn.Module):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.interpolation = interpolation
        self.fill = fill
        # Resolve the sub-policy table once at construction time.
        self.policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        # Each entry is a pair of (op_name, probability, magnitude_bin) triples.
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        return {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),
        }

    @staticmethod
    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
        """Get parameters for autoaugment transformation

        Returns:
            params required by the autoaugment transformation
        """
        # NOTE: the RNG call order (randint, rand, randint) is part of the
        # reproducibility contract — do not reorder.
        policy_id = int(torch.randint(transform_num, (1,)).item())
        probs = torch.rand((2,))
        signs = torch.randint(2, (2,))

        return policy_id, probs, signs

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # Tensor inputs need an explicit per-channel fill list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            elif fill is not None:
                fill = [float(f) for f in fill]

        transform_id, probs, signs = self.get_params(len(self.policies))

        op_meta = self._augmentation_space(10, (height, width))
        for idx, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
            if probs[idx] <= p:
                magnitudes, signed = op_meta[op_name]
                magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                if signed and signs[idx] == 0:
                    magnitude *= -1.0
                img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"
285
+
286
+
287
+ class RandAugment(torch.nn.Module):
288
+ r"""RandAugment data augmentation method based on
289
+ `"RandAugment: Practical automated data augmentation with a reduced search space"
290
+ <https://arxiv.org/abs/1909.13719>`_.
291
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
292
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
293
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
294
+
295
+ Args:
296
+ num_ops (int): Number of augmentation transformations to apply sequentially.
297
+ magnitude (int): Magnitude for all the transformations.
298
+ num_magnitude_bins (int): The number of different magnitude values.
299
+ interpolation (InterpolationMode): Desired interpolation enum defined by
300
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
301
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
302
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
303
+ image. If given a number, the value is used for all bands respectively.
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ num_ops: int = 2,
309
+ magnitude: int = 9,
310
+ num_magnitude_bins: int = 31,
311
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
312
+ fill: Optional[List[float]] = None,
313
+ ) -> None:
314
+ super().__init__()
315
+ self.num_ops = num_ops
316
+ self.magnitude = magnitude
317
+ self.num_magnitude_bins = num_magnitude_bins
318
+ self.interpolation = interpolation
319
+ self.fill = fill
320
+
321
+ def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
322
+ return {
323
+ # op_name: (magnitudes, signed)
324
+ "Identity": (torch.tensor(0.0), False),
325
+ "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
326
+ "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
327
+ "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
328
+ "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
329
+ "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
330
+ "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
331
+ "Color": (torch.linspace(0.0, 0.9, num_bins), True),
332
+ "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
333
+ "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
334
+ "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
335
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
336
+ "AutoContrast": (torch.tensor(0.0), False),
337
+ "Equalize": (torch.tensor(0.0), False),
338
+ }
339
+
340
+ def forward(self, img: Tensor) -> Tensor:
341
+ """
342
+ img (PIL Image or Tensor): Image to be transformed.
343
+
344
+ Returns:
345
+ PIL Image or Tensor: Transformed image.
346
+ """
347
+ fill = self.fill
348
+ channels, height, width = F.get_dimensions(img)
349
+ if isinstance(img, Tensor):
350
+ if isinstance(fill, (int, float)):
351
+ fill = [float(fill)] * channels
352
+ elif fill is not None:
353
+ fill = [float(f) for f in fill]
354
+
355
+ op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
356
+ for _ in range(self.num_ops):
357
+ op_index = int(torch.randint(len(op_meta), (1,)).item())
358
+ op_name = list(op_meta.keys())[op_index]
359
+ magnitudes, signed = op_meta[op_name]
360
+ magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
361
+ if signed and torch.randint(2, (1,)):
362
+ magnitude *= -1.0
363
+ img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
364
+
365
+ return img
366
+
367
+ def __repr__(self) -> str:
368
+ s = (
369
+ f"{self.__class__.__name__}("
370
+ f"num_ops={self.num_ops}"
371
+ f", magnitude={self.magnitude}"
372
+ f", num_magnitude_bins={self.num_magnitude_bins}"
373
+ f", interpolation={self.interpolation}"
374
+ f", fill={self.fill}"
375
+ f")"
376
+ )
377
+ return s
378
+
379
+
380
+ class TrivialAugmentWide(torch.nn.Module):
381
+ r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
382
+ `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
383
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
384
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
385
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
386
+
387
+ Args:
388
+ num_magnitude_bins (int): The number of different magnitude values.
389
+ interpolation (InterpolationMode): Desired interpolation enum defined by
390
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
391
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
392
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
393
+ image. If given a number, the value is used for all bands respectively.
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ num_magnitude_bins: int = 31,
399
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
400
+ fill: Optional[List[float]] = None,
401
+ ) -> None:
402
+ super().__init__()
403
+ self.num_magnitude_bins = num_magnitude_bins
404
+ self.interpolation = interpolation
405
+ self.fill = fill
406
+
407
+ def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
408
+ return {
409
+ # op_name: (magnitudes, signed)
410
+ "Identity": (torch.tensor(0.0), False),
411
+ "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
412
+ "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
413
+ "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
414
+ "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
415
+ "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
416
+ "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
417
+ "Color": (torch.linspace(0.0, 0.99, num_bins), True),
418
+ "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
419
+ "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
420
+ "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
421
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
422
+ "AutoContrast": (torch.tensor(0.0), False),
423
+ "Equalize": (torch.tensor(0.0), False),
424
+ }
425
+
426
+ def forward(self, img: Tensor) -> Tensor:
427
+ """
428
+ img (PIL Image or Tensor): Image to be transformed.
429
+
430
+ Returns:
431
+ PIL Image or Tensor: Transformed image.
432
+ """
433
+ fill = self.fill
434
+ channels, height, width = F.get_dimensions(img)
435
+ if isinstance(img, Tensor):
436
+ if isinstance(fill, (int, float)):
437
+ fill = [float(fill)] * channels
438
+ elif fill is not None:
439
+ fill = [float(f) for f in fill]
440
+
441
+ op_meta = self._augmentation_space(self.num_magnitude_bins)
442
+ op_index = int(torch.randint(len(op_meta), (1,)).item())
443
+ op_name = list(op_meta.keys())[op_index]
444
+ magnitudes, signed = op_meta[op_name]
445
+ magnitude = (
446
+ float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
447
+ if magnitudes.ndim > 0
448
+ else 0.0
449
+ )
450
+ if signed and torch.randint(2, (1,)):
451
+ magnitude *= -1.0
452
+
453
+ return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
454
+
455
+ def __repr__(self) -> str:
456
+ s = (
457
+ f"{self.__class__.__name__}("
458
+ f"num_magnitude_bins={self.num_magnitude_bins}"
459
+ f", interpolation={self.interpolation}"
460
+ f", fill={self.fill}"
461
+ f")"
462
+ )
463
+ return s
464
+
465
+
466
+ class AugMix(torch.nn.Module):
467
+ r"""AugMix data augmentation method based on
468
+ `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
469
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
470
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
471
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
472
+
473
+ Args:
474
+ severity (int): The severity of base augmentation operators. Default is ``3``.
475
+ mixture_width (int): The number of augmentation chains. Default is ``3``.
476
+ chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
477
+ Default is ``-1``.
478
+ alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
479
+ all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
480
+ interpolation (InterpolationMode): Desired interpolation enum defined by
481
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
482
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
483
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
484
+ image. If given a number, the value is used for all bands respectively.
485
+ """
486
+
487
+ def __init__(
488
+ self,
489
+ severity: int = 3,
490
+ mixture_width: int = 3,
491
+ chain_depth: int = -1,
492
+ alpha: float = 1.0,
493
+ all_ops: bool = True,
494
+ interpolation: InterpolationMode = InterpolationMode.BILINEAR,
495
+ fill: Optional[List[float]] = None,
496
+ ) -> None:
497
+ super().__init__()
498
+ self._PARAMETER_MAX = 10
499
+ if not (1 <= severity <= self._PARAMETER_MAX):
500
+ raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
501
+ self.severity = severity
502
+ self.mixture_width = mixture_width
503
+ self.chain_depth = chain_depth
504
+ self.alpha = alpha
505
+ self.all_ops = all_ops
506
+ self.interpolation = interpolation
507
+ self.fill = fill
508
+
509
+ def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
510
+ s = {
511
+ # op_name: (magnitudes, signed)
512
+ "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
513
+ "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
514
+ "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
515
+ "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
516
+ "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
517
+ "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
518
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
519
+ "AutoContrast": (torch.tensor(0.0), False),
520
+ "Equalize": (torch.tensor(0.0), False),
521
+ }
522
+ if self.all_ops:
523
+ s.update(
524
+ {
525
+ "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
526
+ "Color": (torch.linspace(0.0, 0.9, num_bins), True),
527
+ "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
528
+ "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
529
+ }
530
+ )
531
+ return s
532
+
533
+ @torch.jit.unused
534
+ def _pil_to_tensor(self, img) -> Tensor:
535
+ return F.pil_to_tensor(img)
536
+
537
+ @torch.jit.unused
538
+ def _tensor_to_pil(self, img: Tensor):
539
+ return F.to_pil_image(img)
540
+
541
+ def _sample_dirichlet(self, params: Tensor) -> Tensor:
542
+ # Must be on a separate method so that we can overwrite it in tests.
543
+ return torch._sample_dirichlet(params)
544
+
545
+ def forward(self, orig_img: Tensor) -> Tensor:
546
+ """
547
+ img (PIL Image or Tensor): Image to be transformed.
548
+
549
+ Returns:
550
+ PIL Image or Tensor: Transformed image.
551
+ """
552
+ fill = self.fill
553
+ channels, height, width = F.get_dimensions(orig_img)
554
+ if isinstance(orig_img, Tensor):
555
+ img = orig_img
556
+ if isinstance(fill, (int, float)):
557
+ fill = [float(fill)] * channels
558
+ elif fill is not None:
559
+ fill = [float(f) for f in fill]
560
+ else:
561
+ img = self._pil_to_tensor(orig_img)
562
+
563
+ op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))
564
+
565
+ orig_dims = list(img.shape)
566
+ batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
567
+ batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
568
+
569
+ # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
570
+ # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
571
+ m = self._sample_dirichlet(
572
+ torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
573
+ )
574
+
575
+ # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
576
+ combined_weights = self._sample_dirichlet(
577
+ torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
578
+ ) * m[:, 1].view([batch_dims[0], -1])
579
+
580
+ mix = m[:, 0].view(batch_dims) * batch
581
+ for i in range(self.mixture_width):
582
+ aug = batch
583
+ depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
584
+ for _ in range(depth):
585
+ op_index = int(torch.randint(len(op_meta), (1,)).item())
586
+ op_name = list(op_meta.keys())[op_index]
587
+ magnitudes, signed = op_meta[op_name]
588
+ magnitude = (
589
+ float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
590
+ if magnitudes.ndim > 0
591
+ else 0.0
592
+ )
593
+ if signed and torch.randint(2, (1,)):
594
+ magnitude *= -1.0
595
+ aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
596
+ mix.add_(combined_weights[:, i].view(batch_dims) * aug)
597
+ mix = mix.view(orig_dims).to(dtype=img.dtype)
598
+
599
+ if not isinstance(orig_img, Tensor):
600
+ return self._tensor_to_pil(mix)
601
+ return mix
602
+
603
+ def __repr__(self) -> str:
604
+ s = (
605
+ f"{self.__class__.__name__}("
606
+ f"severity={self.severity}"
607
+ f", mixture_width={self.mixture_width}"
608
+ f", chain_depth={self.chain_depth}"
609
+ f", alpha={self.alpha}"
610
+ f", all_ops={self.all_ops}"
611
+ f", interpolation={self.interpolation}"
612
+ f", fill={self.fill}"
613
+ f")"
614
+ )
615
+ return s
.venv/lib/python3.11/site-packages/torchvision/transforms/functional.py ADDED
@@ -0,0 +1,1586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import sys
4
+ import warnings
5
+ from enum import Enum
6
+ from typing import Any, List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from PIL.Image import Image as PILImage
12
+ from torch import Tensor
13
+
14
+ try:
15
+ import accimage
16
+ except ImportError:
17
+ accimage = None
18
+
19
+ from ..utils import _log_api_usage_once
20
+ from . import _functional_pil as F_pil, _functional_tensor as F_t
21
+
22
+
23
+ class InterpolationMode(Enum):
24
+ """Interpolation modes
25
+ Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
26
+ and ``lanczos``.
27
+ """
28
+
29
+ NEAREST = "nearest"
30
+ NEAREST_EXACT = "nearest-exact"
31
+ BILINEAR = "bilinear"
32
+ BICUBIC = "bicubic"
33
+ # For PIL compatibility
34
+ BOX = "box"
35
+ HAMMING = "hamming"
36
+ LANCZOS = "lanczos"
37
+
38
+
39
+ # TODO: Once torchscript supports Enums with staticmethod
40
+ # this can be put into InterpolationMode as staticmethod
41
+ def _interpolation_modes_from_int(i: int) -> InterpolationMode:
42
+ inverse_modes_mapping = {
43
+ 0: InterpolationMode.NEAREST,
44
+ 2: InterpolationMode.BILINEAR,
45
+ 3: InterpolationMode.BICUBIC,
46
+ 4: InterpolationMode.BOX,
47
+ 5: InterpolationMode.HAMMING,
48
+ 1: InterpolationMode.LANCZOS,
49
+ }
50
+ return inverse_modes_mapping[i]
51
+
52
+
53
+ pil_modes_mapping = {
54
+ InterpolationMode.NEAREST: 0,
55
+ InterpolationMode.BILINEAR: 2,
56
+ InterpolationMode.BICUBIC: 3,
57
+ InterpolationMode.NEAREST_EXACT: 0,
58
+ InterpolationMode.BOX: 4,
59
+ InterpolationMode.HAMMING: 5,
60
+ InterpolationMode.LANCZOS: 1,
61
+ }
62
+
63
+ _is_pil_image = F_pil._is_pil_image
64
+
65
+
66
+ def get_dimensions(img: Tensor) -> List[int]:
67
+ """Returns the dimensions of an image as [channels, height, width].
68
+
69
+ Args:
70
+ img (PIL Image or Tensor): The image to be checked.
71
+
72
+ Returns:
73
+ List[int]: The image dimensions.
74
+ """
75
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
76
+ _log_api_usage_once(get_dimensions)
77
+ if isinstance(img, torch.Tensor):
78
+ return F_t.get_dimensions(img)
79
+
80
+ return F_pil.get_dimensions(img)
81
+
82
+
83
+ def get_image_size(img: Tensor) -> List[int]:
84
+ """Returns the size of an image as [width, height].
85
+
86
+ Args:
87
+ img (PIL Image or Tensor): The image to be checked.
88
+
89
+ Returns:
90
+ List[int]: The image size.
91
+ """
92
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
93
+ _log_api_usage_once(get_image_size)
94
+ if isinstance(img, torch.Tensor):
95
+ return F_t.get_image_size(img)
96
+
97
+ return F_pil.get_image_size(img)
98
+
99
+
100
+ def get_image_num_channels(img: Tensor) -> int:
101
+ """Returns the number of channels of an image.
102
+
103
+ Args:
104
+ img (PIL Image or Tensor): The image to be checked.
105
+
106
+ Returns:
107
+ int: The number of channels.
108
+ """
109
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
110
+ _log_api_usage_once(get_image_num_channels)
111
+ if isinstance(img, torch.Tensor):
112
+ return F_t.get_image_num_channels(img)
113
+
114
+ return F_pil.get_image_num_channels(img)
115
+
116
+
117
+ @torch.jit.unused
118
+ def _is_numpy(img: Any) -> bool:
119
+ return isinstance(img, np.ndarray)
120
+
121
+
122
+ @torch.jit.unused
123
+ def _is_numpy_image(img: Any) -> bool:
124
+ return img.ndim in {2, 3}
125
+
126
+
127
+ def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor:
128
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
129
+ This function does not support torchscript.
130
+
131
+ See :class:`~torchvision.transforms.ToTensor` for more details.
132
+
133
+ Args:
134
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
135
+
136
+ Returns:
137
+ Tensor: Converted image.
138
+ """
139
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
140
+ _log_api_usage_once(to_tensor)
141
+ if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
142
+ raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
143
+
144
+ if _is_numpy(pic) and not _is_numpy_image(pic):
145
+ raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
146
+
147
+ default_float_dtype = torch.get_default_dtype()
148
+
149
+ if isinstance(pic, np.ndarray):
150
+ # handle numpy array
151
+ if pic.ndim == 2:
152
+ pic = pic[:, :, None]
153
+
154
+ img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
155
+ # backward compatibility
156
+ if isinstance(img, torch.ByteTensor):
157
+ return img.to(dtype=default_float_dtype).div(255)
158
+ else:
159
+ return img
160
+
161
+ if accimage is not None and isinstance(pic, accimage.Image):
162
+ nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
163
+ pic.copyto(nppic)
164
+ return torch.from_numpy(nppic).to(dtype=default_float_dtype)
165
+
166
+ # handle PIL Image
167
+ mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32}
168
+ img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
169
+
170
+ if pic.mode == "1":
171
+ img = 255 * img
172
+ img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
173
+ # put it from HWC to CHW format
174
+ img = img.permute((2, 0, 1)).contiguous()
175
+ if isinstance(img, torch.ByteTensor):
176
+ return img.to(dtype=default_float_dtype).div(255)
177
+ else:
178
+ return img
179
+
180
+
181
+ def pil_to_tensor(pic: Any) -> Tensor:
182
+ """Convert a ``PIL Image`` to a tensor of the same type.
183
+ This function does not support torchscript.
184
+
185
+ See :class:`~torchvision.transforms.PILToTensor` for more details.
186
+
187
+ .. note::
188
+
189
+ A deep copy of the underlying array is performed.
190
+
191
+ Args:
192
+ pic (PIL Image): Image to be converted to tensor.
193
+
194
+ Returns:
195
+ Tensor: Converted image.
196
+ """
197
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
198
+ _log_api_usage_once(pil_to_tensor)
199
+ if not F_pil._is_pil_image(pic):
200
+ raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
201
+
202
+ if accimage is not None and isinstance(pic, accimage.Image):
203
+ # accimage format is always uint8 internally, so always return uint8 here
204
+ nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
205
+ pic.copyto(nppic)
206
+ return torch.as_tensor(nppic)
207
+
208
+ # handle PIL Image
209
+ img = torch.as_tensor(np.array(pic, copy=True))
210
+ img = img.view(pic.size[1], pic.size[0], F_pil.get_image_num_channels(pic))
211
+ # put it from HWC to CHW format
212
+ img = img.permute((2, 0, 1))
213
+ return img
214
+
215
+
216
+ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
217
+ """Convert a tensor image to the given ``dtype`` and scale the values accordingly
218
+ This function does not support PIL Image.
219
+
220
+ Args:
221
+ image (torch.Tensor): Image to be converted
222
+ dtype (torch.dtype): Desired data type of the output
223
+
224
+ Returns:
225
+ Tensor: Converted image
226
+
227
+ .. note::
228
+
229
+ When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
230
+ If converted back and forth, this mismatch has no effect.
231
+
232
+ Raises:
233
+ RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
234
+ well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
235
+ overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
236
+ of the integer ``dtype``.
237
+ """
238
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
239
+ _log_api_usage_once(convert_image_dtype)
240
+ if not isinstance(image, torch.Tensor):
241
+ raise TypeError("Input img should be Tensor Image")
242
+
243
+ return F_t.convert_image_dtype(image, dtype)
244
+
245
+
246
+ def to_pil_image(pic, mode=None):
247
+ """Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.
248
+
249
+ See :class:`~torchvision.transforms.ToPILImage` for more details.
250
+
251
+ Args:
252
+ pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
253
+ mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
254
+
255
+ .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
256
+
257
+ Returns:
258
+ PIL Image: Image converted to PIL Image.
259
+ """
260
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
261
+ _log_api_usage_once(to_pil_image)
262
+
263
+ if isinstance(pic, torch.Tensor):
264
+ if pic.ndim == 3:
265
+ pic = pic.permute((1, 2, 0))
266
+ pic = pic.numpy(force=True)
267
+ elif not isinstance(pic, np.ndarray):
268
+ raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
269
+
270
+ if pic.ndim == 2:
271
+ # if 2D image, add channel dimension (HWC)
272
+ pic = np.expand_dims(pic, 2)
273
+ if pic.ndim != 3:
274
+ raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
275
+
276
+ if pic.shape[-1] > 4:
277
+ raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
278
+
279
+ npimg = pic
280
+
281
+ if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
282
+ npimg = (npimg * 255).astype(np.uint8)
283
+
284
+ if npimg.shape[2] == 1:
285
+ expected_mode = None
286
+ npimg = npimg[:, :, 0]
287
+ if npimg.dtype == np.uint8:
288
+ expected_mode = "L"
289
+ elif npimg.dtype == np.int16:
290
+ expected_mode = "I;16" if sys.byteorder == "little" else "I;16B"
291
+ elif npimg.dtype == np.int32:
292
+ expected_mode = "I"
293
+ elif npimg.dtype == np.float32:
294
+ expected_mode = "F"
295
+ if mode is not None and mode != expected_mode:
296
+ raise ValueError(f"Incorrect mode ({mode}) supplied for input type {np.dtype}. Should be {expected_mode}")
297
+ mode = expected_mode
298
+
299
+ elif npimg.shape[2] == 2:
300
+ permitted_2_channel_modes = ["LA"]
301
+ if mode is not None and mode not in permitted_2_channel_modes:
302
+ raise ValueError(f"Only modes {permitted_2_channel_modes} are supported for 2D inputs")
303
+
304
+ if mode is None and npimg.dtype == np.uint8:
305
+ mode = "LA"
306
+
307
+ elif npimg.shape[2] == 4:
308
+ permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
309
+ if mode is not None and mode not in permitted_4_channel_modes:
310
+ raise ValueError(f"Only modes {permitted_4_channel_modes} are supported for 4D inputs")
311
+
312
+ if mode is None and npimg.dtype == np.uint8:
313
+ mode = "RGBA"
314
+ else:
315
+ permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
316
+ if mode is not None and mode not in permitted_3_channel_modes:
317
+ raise ValueError(f"Only modes {permitted_3_channel_modes} are supported for 3D inputs")
318
+ if mode is None and npimg.dtype == np.uint8:
319
+ mode = "RGB"
320
+
321
+ if mode is None:
322
+ raise TypeError(f"Input type {npimg.dtype} is not supported")
323
+
324
+ return Image.fromarray(npimg, mode=mode)
325
+
326
+
327
+ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
328
+ """Normalize a float tensor image with mean and standard deviation.
329
+ This transform does not support PIL Image.
330
+
331
+ .. note::
332
+ This transform acts out of place by default, i.e., it does not mutates the input tensor.
333
+
334
+ See :class:`~torchvision.transforms.Normalize` for more details.
335
+
336
+ Args:
337
+ tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
338
+ mean (sequence): Sequence of means for each channel.
339
+ std (sequence): Sequence of standard deviations for each channel.
340
+ inplace(bool,optional): Bool to make this operation inplace.
341
+
342
+ Returns:
343
+ Tensor: Normalized Tensor image.
344
+ """
345
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
346
+ _log_api_usage_once(normalize)
347
+ if not isinstance(tensor, torch.Tensor):
348
+ raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
349
+
350
+ return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
351
+
352
+
353
+ def _compute_resized_output_size(
354
+ image_size: Tuple[int, int],
355
+ size: Optional[List[int]],
356
+ max_size: Optional[int] = None,
357
+ allow_size_none: bool = False, # only True in v2
358
+ ) -> List[int]:
359
+ h, w = image_size
360
+ short, long = (w, h) if w <= h else (h, w)
361
+ if size is None:
362
+ if not allow_size_none:
363
+ raise ValueError("This should never happen!!")
364
+ if not isinstance(max_size, int):
365
+ raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
366
+ new_short, new_long = int(max_size * short / long), max_size
367
+ new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
368
+ elif len(size) == 1: # specified size only for the smallest edge
369
+ requested_new_short = size if isinstance(size, int) else size[0]
370
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
371
+
372
+ if max_size is not None:
373
+ if max_size <= requested_new_short:
374
+ raise ValueError(
375
+ f"max_size = {max_size} must be strictly greater than the requested "
376
+ f"size for the smaller edge size = {size}"
377
+ )
378
+ if new_long > max_size:
379
+ new_short, new_long = int(max_size * new_short / new_long), max_size
380
+
381
+ new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
382
+ else: # specified both h and w
383
+ new_w, new_h = size[1], size[0]
384
+ return [new_h, new_w]
385
+
386
+
387
def resize(
    img: Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> Tensor:
    r"""Resize the input image to the given size.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. A (h, w) pair is used
            as-is; a single int (or 1-element sequence) resizes the smaller
            edge to that value while preserving the aspect ratio.
        interpolation (InterpolationMode): Desired interpolation mode;
            the corresponding Pillow integer constants are accepted as well.
            Default: ``InterpolationMode.BILINEAR``.
        max_size (int, optional): Upper bound for the longer edge of the
            resized image; only valid when ``size`` specifies the smaller
            edge. When the resize would exceed it, the longer edge is clamped
            to ``max_size`` (the smaller edge may then end up below ``size``).
        antialias (bool, optional): Whether to anti-alias. Only affects
            tensors with bilinear/bicubic modes; PIL images always
            anti-alias for those modes. Default: ``True``.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resize)

    # Normalize the interpolation argument: Pillow integer constants are
    # mapped to the matching InterpolationMode enum member.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(size, (list, tuple)):
        n_elems = len(size)
        if n_elems not in [1, 2]:
            raise ValueError(
                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
            )
        if max_size is not None and n_elems != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )

    _, img_h, img_w = get_dimensions(img)
    size_list = [size] if isinstance(size, int) else size
    output_size = _compute_resized_output_size((img_h, img_w), size_list, max_size)

    # Nothing to do when the requested size matches the current one.
    if output_size == [img_h, img_w]:
        return img

    if isinstance(img, torch.Tensor):
        return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)

    # PIL backend: antialias cannot be disabled, so only warn about it.
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return F_pil.resize(img, size=output_size, interpolation=pil_modes_mapping[interpolation])
480
+
481
+
482
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given image on all sides with the given "pad" value.

    Tensors are expected in [..., H, W] layout; the number of allowed leading
    dimensions depends on ``padding_mode`` (arbitrary for ``constant``, at
    most 3 for ``edge``, at most 2 for ``reflect``/``symmetric``).

    Args:
        img (PIL Image or Tensor): Image to be padded.
        padding (int or sequence): Border padding. A single int pads every
            border; a 2-sequence pads left/right then top/bottom; a
            4-sequence pads left, top, right, bottom respectively.
        fill (number or tuple): Pixel fill value, used only when
            ``padding_mode`` is ``constant``. Tensors accept only a number;
            PIL images accept an int or a tuple. Default: 0.
        padding_mode (str): One of ``constant``, ``edge``, ``reflect`` or
            ``symmetric``. Default: ``constant``.

    Returns:
        PIL Image or Tensor: Padded image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(pad)
    # Dispatch on input type: tensors go through the tensor backend,
    # everything else through the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
    return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
529
+
530
+
531
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop a ``height`` x ``width`` region whose top-left corner is at
    (``top``, ``left``); (0, 0) is the top-left corner of the image.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions. Regions extending past the image are zero-padded first.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        top (int): Vertical offset of the crop box.
        left (int): Horizontal offset of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(crop)
    # Tensor inputs use the tensor backend; anything else is treated as PIL.
    if isinstance(img, torch.Tensor):
        return F_t.crop(img, top, left, height, width)
    return F_pil.crop(img, top, left, height, width)
554
+
555
+
556
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crop the given image at its center.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions. When the image is smaller than ``output_size`` along an
    edge, it is first zero-padded so that the crop fits.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. An
            int or single-element sequence is used for both dimensions.

    Returns:
        PIL Image or Tensor: Center-cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)

    # Normalize output_size to an (h, w) pair.
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, img_h, img_w = get_dimensions(img)
    crop_h, crop_w = output_size

    if crop_w > img_w or crop_h > img_h:
        # Pad symmetrically (extra pixel goes to the right/bottom) so the
        # requested crop fits; PIL's default fill value of 0 is used.
        missing_w = max(crop_w - img_w, 0)
        missing_h = max(crop_h - img_h, 0)
        padding_ltrb = [missing_w // 2, missing_h // 2, (missing_w + 1) // 2, (missing_h + 1) // 2]
        img = pad(img, padding_ltrb, fill=0)
        _, img_h, img_w = get_dimensions(img)
        if crop_w == img_w and crop_h == img_h:
            return img

    top = int(round((img_h - crop_h) / 2.0))
    left = int(round((img_w - crop_w) / 2.0))
    return crop(img, top, left, crop_h, crop_w)
595
+
596
+
597
def resized_crop(
    img: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> Tensor:
    """Crop the given image and resize the crop to the desired size.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions. Notably used by
    :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped; (0, 0) is its
            top-left corner.
        top (int): Vertical offset of the crop box.
        left (int): Horizontal offset of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Output size, with the same semantics as
            ``resize``.
        interpolation (InterpolationMode): Desired interpolation mode;
            Pillow integer constants are accepted as well.
            Default: ``InterpolationMode.BILINEAR``.
        antialias (bool, optional): Whether to anti-alias; same semantics
            as in ``resize``. Default: ``True``.

    Returns:
        PIL Image or Tensor: Cropped and resized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(resized_crop)
    # Composition of the two primitives: crop first, then resize the crop.
    cropped = crop(img, top, left, height, width)
    return resize(cropped, size, interpolation, antialias=antialias)
652
+
653
+
654
def hflip(img: Tensor) -> Tensor:
    """Flip the given image horizontally (left-right).

    Args:
        img (PIL Image or Tensor): Image to flip. Tensors are expected in
            [..., H, W] layout with any number of leading dimensions.

    Returns:
        PIL Image or Tensor: Horizontally flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(hflip)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.hflip(img)
    return F_pil.hflip(img)
672
+
673
+
674
+ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
675
+ """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
676
+
677
+ In Perspective Transform each pixel (x, y) in the original image gets transformed as,
678
+ (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
679
+
680
+ Args:
681
+ startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
682
+ ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
683
+ endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
684
+ ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
685
+
686
+ Returns:
687
+ octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
688
+ """
689
+ if len(startpoints) != 4 or len(endpoints) != 4:
690
+ raise ValueError(
691
+ f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints."
692
+ )
693
+ a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64)
694
+
695
+ for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
696
+ a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
697
+ a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
698
+
699
+ b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)
700
+ # do least squares in double precision to prevent numerical issues
701
+ res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32)
702
+
703
+ output: List[float] = res.tolist()
704
+ return output
705
+
706
+
707
def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Apply a perspective transform to the given image.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions.

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): Four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the
            original image.
        endpoints (list of list of ints): The matching four corners of the
            transformed image.
        interpolation (InterpolationMode): Desired interpolation mode;
            Pillow integer constants are accepted as well.
            Default: ``InterpolationMode.BILINEAR``.
        fill (sequence or number, optional): Fill value for the area outside
            the transformed image; a single number is used for all bands.

    Returns:
        PIL Image or Tensor: Transformed image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(perspective)

    # Solve for the transform coefficients first (validates the corner lists).
    coeffs = _get_perspective_coeffs(startpoints, endpoints)

    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if isinstance(img, torch.Tensor):
        return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
    return F_pil.perspective(img, coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
755
+
756
+
757
def vflip(img: Tensor) -> Tensor:
    """Flip the given image vertically (top-bottom).

    Args:
        img (PIL Image or Tensor): Image to flip. Tensors are expected in
            [..., H, W] layout with any number of leading dimensions.

    Returns:
        PIL Image or Tensor: Vertically flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(vflip)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.vflip(img)
    return F_pil.vflip(img)
775
+
776
+
777
def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the four corners and the center of the given image.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions.

    .. Note::
        This transform returns a tuple of images, so the number of outputs
        may not match the number of targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Output size of each crop. An int or a
            single-element sequence produces a square (size, size) crop.

    Returns:
        tuple: (tl, tr, bl, br, center) — the top-left, top-right,
        bottom-left, bottom-right and center crops.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(five_crop)

    # Normalize size to an (h, w) pair.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    _, img_h, img_w = get_dimensions(img)
    crop_h, crop_w = size
    if crop_w > img_w or crop_h > img_h:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (img_h, img_w)))

    # Offsets of the bottom/right crops.
    bottom = img_h - crop_h
    right = img_w - crop_w

    tl = crop(img, 0, 0, crop_h, crop_w)
    tr = crop(img, 0, right, crop_h, crop_w)
    bl = crop(img, bottom, 0, crop_h, crop_w)
    br = crop(img, bottom, right, crop_h, crop_w)
    center = center_crop(img, [crop_h, crop_w])

    return tl, tr, bl, br, center
820
+
821
+
822
def ten_crop(
    img: Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Produce ten crops: the five crops of the image plus the five crops of
    its flipped version (horizontal flip by default).

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions.

    .. Note::
        This transform returns a tuple of images, so the number of outputs
        may not match the number of targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Output size of each crop. An int or a
            single-element sequence produces a square (size, size) crop.
        vertical_flip (bool): Flip vertically instead of horizontally.

    Returns:
        tuple: (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip,
        center_flip) — corner and center crops of the original and of the
        flipped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(ten_crop)

    # Normalize size to an (h, w) pair.
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    first_five = five_crop(img, size)
    flipped = vflip(img) if vertical_flip else hflip(img)
    second_five = five_crop(flipped, size)
    return first_five + second_five
866
+
867
+
868
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    """Adjust the brightness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are
            expected in [..., 1 or 3, H, W] layout with any number of
            leading dimensions.
        brightness_factor (float): Non-negative multiplier: 0 yields a black
            image, 1 the original, 2 doubles the brightness.

    Returns:
        PIL Image or Tensor: Brightness-adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_brightness)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_brightness(img, brightness_factor)
    return F_pil.adjust_brightness(img, brightness_factor)
888
+
889
+
890
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    """Adjust the contrast of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are
            expected in [..., 1 or 3, H, W] layout with any number of
            leading dimensions.
        contrast_factor (float): Non-negative multiplier: 0 yields a solid
            gray image, 1 the original, 2 doubles the contrast.

    Returns:
        PIL Image or Tensor: Contrast-adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_contrast)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_contrast(img, contrast_factor)
    return F_pil.adjust_contrast(img, contrast_factor)
910
+
911
+
912
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    """Adjust the color saturation of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are
            expected in [..., 1 or 3, H, W] layout with any number of
            leading dimensions.
        saturation_factor (float): Non-negative multiplier: 0 yields a
            black-and-white image, 1 the original, 2 doubles the saturation.

    Returns:
        PIL Image or Tensor: Saturation-adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_saturation)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_saturation(img, saturation_factor)
    return F_pil.adjust_saturation(img, saturation_factor)
932
+
933
+
934
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    """Adjust the hue of an image.

    The image is converted to HSV, the hue channel is cyclically shifted by
    ``hue_factor``, and the image is converted back to its original mode.
    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are
            expected in [..., 1 or 3, H, W] layout with any number of
            leading dimensions; pixel values must be non-negative for the
            HSV conversion, so normalized images with negative values (or
            interpolations producing them) are not supported. PIL modes
            "1", "I", "F" and modes with an alpha channel are not supported.
        hue_factor (float): Hue shift in [-0.5, 0.5]. 0 means no change;
            both -0.5 and 0.5 give complementary colors.

    Returns:
        PIL Image or Tensor: Hue-adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_hue)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_hue(img, hue_factor)
    return F_pil.adjust_hue(img, hue_factor)
971
+
972
+
973
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
    r"""Perform gamma correction (power-law transform) on an image.

    RGB intensities are adjusted as

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image or Tensor): Image to be adjusted. Tensors are
            expected in [..., 1 or 3, H, W] layout with any number of
            leading dimensions. PIL modes with an alpha channel are not
            supported.
        gamma (float): Non-negative exponent :math:`\gamma`; values above 1
            darken the shadows, values below 1 lighten dark regions.
        gain (float): Constant multiplier. Default: 1.

    Returns:
        PIL Image or Tensor: Gamma-corrected image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_gamma)
    # Dispatch on input type: tensor backend for tensors, PIL otherwise.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_gamma(img, gamma, gain)
    return F_pil.adjust_gamma(img, gamma, gain)
1004
+
1005
+
1006
+ def _get_inverse_affine_matrix(
1007
+ center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
1008
+ ) -> List[float]:
1009
+ # Helper method to compute inverse matrix for affine transformation
1010
+
1011
+ # Pillow requires inverse affine transformation matrix:
1012
+ # Affine matrix is : M = T * C * RotateScaleShear * C^-1
1013
+ #
1014
+ # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
1015
+ # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
1016
+ # RotateScaleShear is rotation with scale and shear matrix
1017
+ #
1018
+ # RotateScaleShear(a, s, (sx, sy)) =
1019
+ # = R(a) * S(s) * SHy(sy) * SHx(sx)
1020
+ # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
1021
+ # [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
1022
+ # [ 0 , 0 , 1 ]
1023
+ # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
1024
+ # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
1025
+ # [0, 1 ] [-tan(s), 1]
1026
+ #
1027
+ # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
1028
+
1029
+ rot = math.radians(angle)
1030
+ sx = math.radians(shear[0])
1031
+ sy = math.radians(shear[1])
1032
+
1033
+ cx, cy = center
1034
+ tx, ty = translate
1035
+
1036
+ # RSS without scaling
1037
+ a = math.cos(rot - sy) / math.cos(sy)
1038
+ b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
1039
+ c = math.sin(rot - sy) / math.cos(sy)
1040
+ d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
1041
+
1042
+ if inverted:
1043
+ # Inverted rotation matrix with scale and shear
1044
+ # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
1045
+ matrix = [d, -b, 0.0, -c, a, 0.0]
1046
+ matrix = [x / scale for x in matrix]
1047
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
1048
+ matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
1049
+ matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
1050
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
1051
+ matrix[2] += cx
1052
+ matrix[5] += cy
1053
+ else:
1054
+ matrix = [a, b, 0.0, c, d, 0.0]
1055
+ matrix = [x * scale for x in matrix]
1056
+ # Apply inverse of center translation: RSS * C^-1
1057
+ matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
1058
+ matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
1059
+ # Apply translation and center : T * C * RSS * C^-1
1060
+ matrix[2] += cx + tx
1061
+ matrix[5] += cy + ty
1062
+
1063
+ return matrix
1064
+
1065
+
1066
def rotate(
    img: Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[int]] = None,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Rotate the image by ``angle`` degrees counter-clockwise.

    Tensors are expected in [..., H, W] layout with any number of leading
    dimensions.

    Args:
        img (PIL Image or Tensor): Image to be rotated.
        angle (number): Rotation angle in degrees, counter-clockwise.
        interpolation (InterpolationMode): Desired interpolation mode;
            Pillow integer constants are accepted as well.
            Default: ``InterpolationMode.NEAREST``.
        expand (bool, optional): When true, enlarge the output so the whole
            rotated image fits; otherwise keep the input size. Expansion
            assumes rotation about the center with no translation.
        center (sequence, optional): Rotation center with origin at the
            upper-left corner. Defaults to the image center.
        fill (sequence or number, optional): Fill value for the area outside
            the rotated image; a single number is used for all bands.

    Returns:
        PIL Image or Tensor: Rotated image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rotate)

    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    if not isinstance(img, torch.Tensor):
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)

    # Tensor path: express the center in coordinates where (0, 0) is the
    # image center, as expected by the affine helper.
    if center is None:
        center_f = [0.0, 0.0]
    else:
        _, height, width = get_dimensions(img)
        center_f = [1.0 * (center[0] - width * 0.5), 1.0 * (center[1] - height * 0.5)]

    # The affine and rotate implementations currently disagree on the angle
    # direction, hence the sign flip.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
1133
+
1134
+
1135
def affine(
    img: Tensor,
    angle: float,
    translate: List[int],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    center: Optional[List[int]] = None,
) -> Tensor:
    """Apply affine transformation on the image keeping image center invariant.

    If the image is torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to transform.
        angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
            If a sequence is specified, the first value corresponds to a shear parallel to the x-axis, while
            the second value corresponds to a shear parallel to the y-axis.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands. In torchscript mode a single
            int/float value is not supported; use a sequence of length 1: ``[value, ]``.
        center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.

    Returns:
        PIL Image or Tensor: Transformed image.

    Raises:
        TypeError: if any argument has an unexpected type.
        ValueError: if ``translate``/``shear`` have the wrong length or ``scale`` is not positive.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(affine)

    # Accept legacy Pillow integer constants for interpolation.
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise TypeError(
            "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
        )

    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    # Normalize argument forms: float angle, list translate, 2-element list shear.
    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    _, height, width = get_dimensions(img)
    if not isinstance(img, torch.Tensor):
        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
        # it is visually better to estimate the center without 0.5 offset
        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
        if center is None:
            center = [width * 0.5, height * 0.5]
        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0)
        # corresponds to image center. Reuses the dimensions computed above — the
        # original implementation redundantly queried get_dimensions() a second time here.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
    return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
1239
+
1240
+
1241
# NOTE: to_grayscale() is not referenced by the transform classes; it is kept as a
# public functional for backward compatibility with existing user code.
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
    """Convert a PIL image of any mode (RGB, HSV, LAB, etc) to grayscale.

    This transform does not support torch Tensor.

    Args:
        img (PIL Image): PIL Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image.
            Value can be 1 or 3. Default is 1.

    Returns:
        PIL Image: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b

    Raises:
        TypeError: if ``img`` is not a PIL Image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(to_grayscale)
    # Guard clause: only PIL images are accepted here.
    if not isinstance(img, Image.Image):
        raise TypeError("Input should be PIL Image")
    return F_pil.to_grayscale(img, num_output_channels)
1265
+
1266
+
1267
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Convert an RGB image to its grayscale version.

    If the image is torch Tensor, it is expected to have [..., 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Note:
        Only RGB images are supported as input. For inputs in other color spaces,
        consider using :meth:`~torchvision.transforms.functional.to_grayscale`
        with a PIL Image.

    Args:
        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
        num_output_channels (int): number of channels of the output image.
            Value can be 1 or 3. Default, 1.

    Returns:
        PIL Image or Tensor: Grayscale version of the image.

        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(rgb_to_grayscale)
    # Dispatch: tensors go to the tensor backend, everything else to PIL.
    if isinstance(img, torch.Tensor):
        return F_t.rgb_to_grayscale(img, num_output_channels)
    return F_pil.to_grayscale(img, num_output_channels)
1292
+
1293
+
1294
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
    """Erase a rectangular region of a Tensor Image with the given value.

    This transform does not support PIL Image.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): row of the upper-left corner of the region.
        j (int): column of the upper-left corner of the region.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace (bool, optional): For in-place operations. By default, is set False.

    Returns:
        Tensor Image: Erased image.

    Raises:
        TypeError: if ``img`` is not a torch Tensor.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(erase)
    # Tensor-only operation: reject PIL images (and anything else) up front.
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor Image. Got {type(img)}")
    return F_t.erase(img, i, j, h, w, v, inplace=inplace)
1316
+
1317
+
1318
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Perform Gaussian blurring on the image with the given kernel.

    The convolution uses reflection padding matching the kernel size, so the
    input shape is maintained. If the image is torch Tensor, it is expected to
    have [..., H, W] shape, where ... means at most one leading dimension.

    Args:
        img (PIL Image or Tensor): Image to be blurred
        kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of
            integers like ``(kx, ky)`` or a single integer for square kernels. In
            torchscript mode a single int is not supported; use ``[ksize, ]``.
        sigma (sequence of floats or float, optional): Gaussian kernel standard deviation.
            Can be a sequence of floats like ``(sigma_x, sigma_y)`` or a single float for
            the same sigma in both X/Y directions. If None, it is computed from
            ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
            In torchscript mode a single float is not supported; use ``[sigma, ]``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.

    Raises:
        TypeError: on unexpected argument or image types.
        ValueError: on even/non-positive kernel sizes or non-positive sigmas.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(gaussian_blur)
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        # Derive sigma from the kernel size. The original code reused the name
        # ``ksize`` here, shadowing the validation loop variable above.
        sigma = [ks * 0.15 + 0.35 for ks in kernel_size]
    elif not isinstance(sigma, (int, float, list, tuple)):
        # The original guarded this with ``sigma is not None`` after the branch
        # above had already ensured it — a dead check, folded into elif here.
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")

    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    # PIL inputs are converted to tensors, blurred, then converted back.
    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")

        t_img = pil_to_tensor(img)

    output = F_t.gaussian_blur(t_img, kernel_size, sigma)

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
1385
+
1386
+
1387
def invert(img: Tensor) -> Tensor:
    """Invert the colors of an RGB/grayscale image.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: Color inverted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(invert)
    # Dispatch: tensors go to the tensor backend, everything else to PIL.
    if isinstance(img, torch.Tensor):
        return F_t.invert(img)
    return F_pil.invert(img)
1405
+
1406
+
1407
def posterize(img: Tensor, bits: int) -> Tensor:
    """Posterize an image by reducing the number of bits for each color channel.

    Args:
        img (PIL Image or Tensor): Image to have its colors posterized.
            If img is torch Tensor, it should be of type torch.uint8, and
            it is expected to be in [..., 1 or 3, H, W] format, where ... means
            it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        bits (int): The number of bits to keep for each channel (0-8).

    Returns:
        PIL Image or Tensor: Posterized image.

    Raises:
        ValueError: if ``bits`` is outside the [0, 8] range.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(posterize)
    if not (0 <= bits <= 8):
        # Fixed message typo ("number if bits" -> "number of bits").
        raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")

    if not isinstance(img, torch.Tensor):
        return F_pil.posterize(img, bits)

    return F_t.posterize(img, bits)
1429
+
1430
+
1431
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an RGB/grayscale image by inverting all pixel values above a threshold.

    Args:
        img (PIL Image or Tensor): Image to have its colors inverted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".
        threshold (float): All pixels equal or above this value are inverted.

    Returns:
        PIL Image or Tensor: Solarized image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(solarize)
    # Dispatch: tensors go to the tensor backend, everything else to PIL.
    if isinstance(img, torch.Tensor):
        return F_t.solarize(img, threshold)
    return F_pil.solarize(img, threshold)
1449
+
1450
+
1451
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(adjust_sharpness)
    # Dispatch: tensors go to the tensor backend, everything else to PIL.
    if isinstance(img, torch.Tensor):
        return F_t.adjust_sharpness(img, sharpness_factor)
    return F_pil.adjust_sharpness(img, sharpness_factor)
1471
+
1472
+
1473
def autocontrast(img: Tensor) -> Tensor:
    """Maximize contrast of an image by remapping its pixels per channel so
    that the lowest becomes black and the lightest becomes white.

    Args:
        img (PIL Image or Tensor): Image on which autocontrast is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was autocontrasted.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(autocontrast)
    # Dispatch: tensors go to the tensor backend, everything else to PIL.
    if isinstance(img, torch.Tensor):
        return F_t.autocontrast(img)
    return F_pil.autocontrast(img)
1493
+
1494
+
1495
def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of an image by applying a non-linear mapping to
    the input in order to create a uniform distribution of grayscale values in
    the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(equalize)
    # Dispatch: tensors go to the tensor backend, everything else to PIL.
    if isinstance(img, torch.Tensor):
        return F_t.equalize(img)
    return F_pil.equalize(img)
1516
+
1517
+
1518
def elastic_transform(
    img: Tensor,
    displacement: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        img (PIL Image or Tensor): Image on which elastic_transform is applied.
            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
            where ... means it can have an arbitrary number of leading dimensions.
            If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
        displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Default is ``InterpolationMode.BILINEAR``.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.

    Returns:
        PIL Image or Tensor: Elastically transformed image, of the same type as ``img``.

    Raises:
        TypeError: if ``displacement`` is not a Tensor or ``img`` is neither PIL Image nor Tensor.
        ValueError: if ``displacement`` does not have shape [1, H, W, 2] for this image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(elastic_transform)
    # Backward compatibility with integer value
    if isinstance(interpolation, int):
        warnings.warn(
            "Argument interpolation should be of type InterpolationMode instead of int. "
            "Please, use InterpolationMode enum."
        )
        interpolation = _interpolation_modes_from_int(interpolation)

    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    # PIL inputs are converted to tensors here, processed by the tensor backend,
    # and converted back to a PIL image of the same mode at the end.
    t_img = img
    if not isinstance(img, torch.Tensor):
        if not F_pil._is_pil_image(img):
            raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
        t_img = pil_to_tensor(img)

    # The displacement field must match [1, H, W, 2] for this (possibly converted) image.
    shape = t_img.shape
    shape = (1,) + shape[-2:] + (2,)
    if shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")

    # TODO: if image shape is [N1, N2, ..., C, H, W] and
    # displacement is [1, H, W, 2] we need to reshape input image
    # such grid_sampler takes internal code for 4D input
    output = F_t.elastic_transform(
        t_img,
        displacement,
        interpolation=interpolation.value,
        fill=fill,
    )

    if not isinstance(img, torch.Tensor):
        output = to_pil_image(output, mode=img.mode)
    return output
.venv/lib/python3.11/site-packages/torchvision/transforms/transforms.py ADDED
@@ -0,0 +1,2153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import random
4
+ import warnings
5
+ from collections.abc import Sequence
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ try:
12
+ import accimage
13
+ except ImportError:
14
+ accimage = None
15
+
16
+ from ..utils import _log_api_usage_once
17
+ from . import functional as F
18
+ from .functional import _interpolation_modes_from_int, InterpolationMode
19
+
20
# Public API of torchvision.transforms: restricts `from torchvision.transforms import *`
# to the transform classes below plus the InterpolationMode enum.
__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
    "ElasticTransform",
]
58
+
59
+
60
class Compose:
    """Chain several transforms so they are applied in sequence.

    This container does not support torchscript; see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        To script a pipeline, wrap scriptable transformations — i.e. ones that
        work with ``torch.Tensor`` and require no ``lambda`` functions or
        ``PIL.Image`` — in ``torch.nn.Sequential`` instead:

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)
    """

    def __init__(self, transforms):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        # Feed each transform's output into the next one.
        for transform in self.transforms:
            img = transform(img)
        return img

    def __repr__(self) -> str:
        # One indented line per inner transform, matching the historical format.
        body = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
105
+
106
+
107
class ToTensor:
    """Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        # Record API usage; this transform is stateless otherwise.
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Delegates entirely to the functional implementation.
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
141
+
142
+
143
class PILToTensor:
    """Convert a PIL Image to a tensor of the same type - this does not scale values.

    This transform does not support torchscript.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        # Record API usage; this transform is stateless otherwise.
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Delegates entirely to the functional implementation.
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
170
+
171
+
172
class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # Target dtype applied to every image passed through forward().
        self.dtype = dtype

    def forward(self, image):
        # Delegates conversion and value rescaling to the functional implementation.
        return F.convert_image_dtype(image, self.dtype)
199
+
200
+
201
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image

    This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while adjusting the value range depending on the ``mode``.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """Convert ``pic`` to a PIL Image.

        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        # Only show the mode when it was explicitly configured.
        args = f"mode={self.mode}" if self.mode is not None else ""
        return f"{self.__class__.__name__}({args})"
242
+
243
+
244
class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.inplace = inplace
        self.mean = mean
        self.std = std

    def forward(self, tensor: Tensor) -> Tensor:
        """Normalize ``tensor`` channel-wise.

        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        normalized = F.normalize(tensor, self.mean, self.std, self.inplace)
        return normalized

    def __repr__(self) -> str:
        # NOTE: ``inplace`` is intentionally left out to match the historical repr format.
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
281
+
282
+
283
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means a maximum of two leading dimensions

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of the resized image. If the longer
            edge of the image is greater than ``max_size`` after being resized according to ``size``, ``size``
            will be overruled so that the longer edge is equal to ``max_size``. As a result, the smaller edge
            may be shorter than ``size``. This is only supported if ``size`` is an int (or a sequence of
            length 1 in torchscript mode).
        antialias (bool, optional): Whether to apply antialiasing. It only affects **tensors** with bilinear
            or bicubic modes and it is ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and tensors), antialiasing makes no
            sense and this parameter is ignored. Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True):
        super().__init__()
        _log_api_usage_once(self)
        # Validate `size` up front so misuse fails at construction, not at call time.
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")

        # Pillow integer constants (e.g. PIL.Image.BILINEAR) are mapped onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.size = size
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """Resize the image.

        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(size={self.size}, "
            f"interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias})"
        )
359
+
360
+
361
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        # Normalize int / length-1 sequence inputs into an (h, w) pair.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """Center-crop the image.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
390
+
391
+
392
class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        # Reject unusable arguments eagerly so the error points at construction.
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")
        if not isinstance(fill, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate fill arg")
        if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
        if isinstance(padding, Sequence) and len(padding) not in (1, 2, 4):
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Pad the image.

        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
463
+
464
+
465
class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        # Fail fast on non-callables instead of erroring at apply time.
        if not callable(lambd):
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
        self.lambd = lambd

    def __call__(self, img):
        # Simply delegate to the user-supplied callable.
        return self.lambd(img)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
483
+
484
+
485
class RandomTransforms:
    """Base class for a list of transformations with randomness

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        _log_api_usage_once(self)
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence")
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        # Subclasses decide how randomness selects/orders the transforms.
        raise NotImplementedError()

    def __repr__(self) -> str:
        body = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
508
+
509
+
510
class RandomApply(torch.nn.Module):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # Draw once; with probability (1 - p) the whole pipeline is skipped.
        if self.p < torch.rand(1):
            return img
        out = img
        for transform in self.transforms:
            out = transform(out)
        return out

    def __repr__(self) -> str:
        parts = [f"\n    p={self.p}"]
        parts.extend(f"\n    {t}" for t in self.transforms)
        return self.__class__.__name__ + "(" + "".join(parts) + "\n)"
551
+
552
+
553
class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript."""

    def __call__(self, img):
        # Shuffle the indices (single RNG draw from `random`), then chain
        # the transforms in that order.
        indices = list(range(len(self.transforms)))
        random.shuffle(indices)
        result = img
        for idx in indices:
            result = self.transforms[idx](result)
        return result
562
+
563
+
564
class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript."""

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        # ``p`` must be a sequence of weights (or None for uniform sampling).
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        # random.choices with k=1 returns a one-element list; unpack it.
        (chosen,) = random.choices(self.transforms, weights=self.p)
        return chosen(*args)

    def __repr__(self) -> str:
        return f"{super().__repr__()}(p={self.p})"
579
+
580
+
581
class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        _, height, width = F.get_dimensions(img)
        target_h, target_w = output_size

        if height < target_h or width < target_w:
            raise ValueError(
                f"Required crop size {(target_h, target_w)} is larger than input image size {(height, width)}"
            )

        # Exact-size input: the only valid crop is the whole image.
        if width == target_w and height == target_h:
            return 0, 0, height, width

        # Draw the top offset first, then the left offset, to keep RNG
        # consumption identical across versions.
        top = torch.randint(0, height - target_h + 1, size=(1,)).item()
        left = torch.randint(0, width - target_w + 1, size=(1,)).item()
        return top, left, target_h, target_w

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Randomly crop the image.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        _, height, width = F.get_dimensions(img)
        if self.pad_if_needed:
            # Pad width first, then height, so the result matches the
            # historical behavior exactly.
            if width < self.size[1]:
                img = F.pad(img, [self.size[1] - width, 0], self.fill, self.padding_mode)
            if height < self.size[0]:
                img = F.pad(img, [0, self.size[0] - height], self.fill, self.padding_mode)

        top, left, crop_h, crop_w = self.get_params(img, self.size)
        return F.crop(img, top, left, crop_h, crop_w)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
687
+
688
+
689
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """Flip the image with probability ``p``.

        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # One draw from torch's RNG decides whether to flip.
        flip = torch.rand(1) < self.p
        if flip:
            return F.hflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
718
+
719
+
720
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """Flip the image with probability ``p``.

        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # One draw from torch's RNG decides whether to flip.
        flip = torch.rand(1) < self.p
        if flip:
            return F.vflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
749
+
750
+
751
class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

        # Pillow integer constants are mapped onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        # ``fill=None`` is normalized to 0; anything else must be number-like.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

    def forward(self, img):
        """Randomly apply a perspective transform to the image.

        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """
        # For tensor inputs the fill value is expanded to one float per channel.
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            fill = [float(fill)] * channels if isinstance(fill, (int, float)) else [float(f) for f in fill]

        if torch.rand(1) < self.p:
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        # Maximum perturbation of each corner, derived from the half extents.
        dx = int(distortion_scale * (width // 2))
        dy = int(distortion_scale * (height // 2))

        def _draw(low: int, high: int) -> int:
            return int(torch.randint(low, high, size=(1,)).item())

        # Draw order (x then y, corner by corner) matches the original
        # implementation so RNG consumption is identical.
        topleft = [_draw(0, dx + 1), _draw(0, dy + 1)]
        topright = [_draw(width - dx - 1, width), _draw(0, dy + 1)]
        botright = [_draw(width - dx - 1, width), _draw(height - dy - 1, height)]
        botleft = [_draw(0, dx + 1), _draw(height - dy - 1, height)]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
845
+
846
+
847
class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        # Misordered bounds are tolerated (uniform_ accepts them) but flagged.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        # Pillow integer constants are mapped onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, height, width = F.get_dimensions(img)
        area = height * width

        # Sample the aspect ratio in log space so e.g. 3/4 and 4/3 are equally likely.
        log_ratio = torch.log(torch.tensor(ratio))
        # Rejection-sample up to 10 times; draw order (area, then ratio, then
        # offsets) matches the original implementation exactly.
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < crop_w <= width and 0 < crop_h <= height:
                top = torch.randint(0, height - crop_h + 1, size=(1,)).item()
                left = torch.randint(0, width - crop_w + 1, size=(1,)).item()
                return top, left, crop_h, crop_w

        # Fallback to central crop, clamped to the closest admissible ratio.
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            crop_w = width
            crop_h = int(round(crop_w / min(ratio)))
        elif in_ratio > max(ratio):
            crop_h = height
            crop_w = int(round(crop_h * max(ratio)))
        else:  # whole image
            crop_w = width
            crop_h = height
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        return top, left, crop_h, crop_w

    def forward(self, img):
        """Randomly crop and resize the image.

        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, top, left, crop_h, crop_w, self.size, self.interpolation, antialias=self.antialias)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(size={self.size}"
            f", scale={tuple(round(s, 4) for s in self.scale)}"
            f", ratio={tuple(round(r, 4) for r in self.ratio)}"
            f", interpolation={self.interpolation.value}"
            f", antialias={self.antialias})"
        )
983
+
984
+
985
class FiveCrop(torch.nn.Module):
    """Produce five crops of an image: the four corners plus the center.

    Works on PIL Images and on Tensors of shape [..., H, W], where ``...``
    is any number of leading dimensions.

    .. Note::
        This transform returns a tuple of images, so the number of outputs
        may no longer match the number of targets your Dataset returns; see
        the example below for one way to handle that.

    Args:
        size (sequence or int): Desired output size of the crop. An int (or
            a length-1 sequence) is interpreted as a square crop of that side.

    Example:
        >>> transform = Compose([
        >>>     FiveCrop(size),  # produces a tuple of PIL Images
        >>>     Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])),  # 4D tensor
        >>> ])
        >>> # In the test loop:
        >>> input, target = batch  # input is 5D, target is 2D
        >>> bs, ncrops, c, h, w = input.size()
        >>> result = model(input.view(-1, c, h, w))  # fold crops into the batch
        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # average over crops
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        # Normalize int / length-1 sequence input into an (h, w) pair.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """Return the tuple of five crops (four corners + center) of ``img``.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        return F.five_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
1030
+
1031
+
1032
class TenCrop(torch.nn.Module):
    """Produce ten crops: four corners and center, plus a flipped copy of each.

    Horizontal flipping is used unless ``vertical_flip`` is set. Works on PIL
    Images and on Tensors of shape [..., H, W], where ``...`` is any number
    of leading dimensions.

    .. Note::
        This transform returns a tuple of images, so the number of outputs
        may no longer match the number of targets your Dataset returns; see
        the example below for one way to handle that.

    Args:
        size (sequence or int): Desired output size of the crop. An int (or
            a length-1 sequence) is interpreted as a square crop of that side.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
        >>> transform = Compose([
        >>>     TenCrop(size),  # produces a tuple of PIL Images
        >>>     Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])),  # 4D tensor
        >>> ])
        >>> # In the test loop:
        >>> input, target = batch  # input is 5D, target is 2D
        >>> bs, ncrops, c, h, w = input.size()
        >>> result = model(input.view(-1, c, h, w))  # fold crops into the batch
        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # average over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        _log_api_usage_once(self)
        # Normalize int / length-1 sequence input into an (h, w) pair.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """Return the tuple of ten crops of ``img``.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"
1080
+
1081
+
1082
class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        # Validate shape, device and dtype agreement up front so errors are
        # raised at construction time, not on the first forward pass.
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened.

        Returns:
            Tensor: Transformed image.

        Raises:
            ValueError: if the flattened input size does not match the
                transformation matrix, or devices differ.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape. "
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        # reshape (not view) so non-contiguous inputs (e.g. the result of a
        # permute) are handled instead of raising a RuntimeError; results are
        # identical wherever view would have succeeded.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector
        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        tensor = transformed_tensor.reshape(shape)
        return tensor

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s
1165
+
1166
+
1167
class ColorJitter(torch.nn.Module):
    """Randomly jitter the brightness, contrast, saturation and hue of an image.

    Accepts Tensors of shape [..., 1 or 3, H, W] and PIL Images, except
    modes "1", "I", "F" and modes with an alpha channel.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            Hue jitter requires non-negative pixel values for the HSV conversion,
            so it will not work on images normalized into a range with negative
            values or interpolated in a way that produces them.
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        # Hue is symmetric around 0 and bounded by half a hue cycle.
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        # Normalize a scalar into a symmetric [center - v, center + v] range;
        # a 2-sequence is taken as an explicit [min, max].
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # A degenerate range sitting exactly on `center` is a no-op; store
        # None so forward() skips that adjustment entirely.
        return None if value[0] == value[1] == center else tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Sample the factors and application order for one forward pass.

        Args:
            brightness (tuple of float (min, max), optional): Range for the
                brightness factor, or None to disable.
            contrast (tuple of float (min, max), optional): Range for the
                contrast factor, or None to disable.
            saturation (tuple of float (min, max), optional): Range for the
                saturation factor, or None to disable.
            hue (tuple of float (min, max), optional): Range for the hue
                factor, or None to disable.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        # Randomize the order the four adjustments are applied in.
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Apply the enabled adjustments in the sampled order.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
1293
+
1294
+
1295
class RandomRotation(torch.nn.Module):
    """Rotate the image by a randomly sampled angle.

    Accepts PIL Images and Tensors of shape [..., H, W], where ``...`` is
    any number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from. A
            single number ``d`` is interpreted as the range (-d, +d).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): If true, expand the output to hold the whole
            rotated image; otherwise keep the input size. The flag assumes
            rotation around the center and no translation.
        center (sequence, optional): Center of rotation, (x, y), with the
            origin at the upper left corner. Defaults to the image center.
        fill (sequence or number): Pixel fill value for the area outside the
            rotated image. Default is ``0``. A single number is used for all bands.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        # Map legacy Pillow integer constants onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

        self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Sample a rotation angle uniformly from ``degrees``.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        return float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # The tensor backend needs an explicit per-channel fill list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        pieces = [
            f"degrees={self.degrees}",
            f"interpolation={self.interpolation.value}",
            f"expand={self.expand}",
        ]
        if self.center is not None:
            pieces.append(f"center={self.center}")
        if self.fill is not None:
            pieces.append(f"fill={self.fill}")
        return f"{self.__class__.__name__}({', '.join(pieces)})"
1385
+
1386
+
1387
class RandomAffine(torch.nn.Module):
    """Apply a random affine transformation, keeping the image center invariant.

    Accepts PIL Images and Tensors of shape [..., H, W], where ``...`` is
    any number of leading dimensions.

    Args:
        degrees (sequence or number): Rotation range; a single number ``d``
            means (-d, +d). Set to 0 to deactivate rotations.
        translate (tuple, optional): Maximum absolute fractions (a, b) for
            horizontal and vertical translation: dx is sampled from
            (-img_width * a, img_width * a) and dy from
            (-img_height * b, img_height * b). Will not translate by default.
        scale (tuple, optional): Scaling interval (a, b); the factor is
            sampled uniformly from it. Will keep original scale by default.
        shear (sequence or number, optional): Shear range in degrees. A
            number or 2-sequence shears parallel to the x axis only; a
            4-sequence adds a y-axis range (shear[2], shear[3]).
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the
            transformed image. Default is ``0``. A single number is used for all bands.
        center (sequence, optional): Center of rotation, (x, y), with the
            origin at the upper left corner. Defaults to the image center.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)

        # Map legacy Pillow integer constants onto the enum.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) if shear is not None else shear

        self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))
        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Sample (angle, (dx, dy), scale, (shear_x, shear_y)) for ``affine``.

        Returns:
            params to be passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())

        if translate is not None:
            # img_size is (width, height) here, kept for BC on get_params.
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            translations = (
                int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())),
                int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())),
            )
        else:
            translations = (0, 0)

        scale = 1.0 if scale_ranges is None else float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        return angle, translations, scale, (shear_x, shear_y)

    def forward(self, img):
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # The tensor backend needs an explicit per-channel fill list.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        if self.translate is not None:
            s += f", translate={self.translate}"
        if self.scale is not None:
            s += f", scale={self.scale}"
        if self.shear is not None:
            s += f", shear={self.shear}"
        if self.interpolation != InterpolationMode.NEAREST:
            s += f", interpolation={self.interpolation.value}"
        if self.fill != 0:
            s += f", fill={self.fill}"
        if self.center is not None:
            s += f", center={self.center}"
        s += ")"

        return s
1542
+
1543
+
1544
class Grayscale(torch.nn.Module):
    """Convert an image to grayscale.

    Accepts PIL Images and Tensors of shape [..., 3, H, W], where ``...`` is
    any number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """Return the grayscale version of ``img``.

        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
1577
+
1578
+
1579
class RandomGrayscale(torch.nn.Module):
    """Convert an image to grayscale with probability ``p`` (default 0.1).

    Accepts PIL Images and Tensors of shape [..., 3, H, W], where ``...`` is
    any number of leading dimensions.

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """Possibly convert ``img`` to grayscale.

        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        # Keep the channel count of the input so the output shape matches.
        num_output_channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
1615
+
1616
+
1617
class RandomErasing(torch.nn.Module):
    """Randomly erase a rectangular region of a torch.Tensor image.

    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. Default is 0. A single int erases all pixels
            with that value; a 3-tuple erases the R, G, B channels
            respectively; the string 'random' erases with random values.
        inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Sample the region and fill for ``erase``.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. None means "random"
                (each pixel erased with a random value); a length-1 list is
                interpreted as the scalar ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        log_ratio = torch.log(torch.tensor(ratio))
        # Rejection-sample up to 10 region proposals; give up afterwards.
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() if value is None else torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # No valid proposal found: return params describing the whole,
        # unmodified image.
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
1751
+
1752
+
1753
+ class GaussianBlur(torch.nn.Module):
1754
+ """Blurs image with randomly chosen Gaussian blur.
1755
+ If the image is torch Tensor, it is expected
1756
+ to have [..., C, H, W] shape, where ... means at most one leading dimension.
1757
+
1758
+ Args:
1759
+ kernel_size (int or sequence): Size of the Gaussian kernel.
1760
+ sigma (float or tuple of float (min, max)): Standard deviation to be used for
1761
+ creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
1762
+ of float (min, max), sigma is chosen uniformly at random to lie in the
1763
+ given range.
1764
+
1765
+ Returns:
1766
+ PIL Image or Tensor: Gaussian blurred version of the input image.
1767
+
1768
+ """
1769
+
1770
+ def __init__(self, kernel_size, sigma=(0.1, 2.0)):
1771
+ super().__init__()
1772
+ _log_api_usage_once(self)
1773
+ self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
1774
+ for ks in self.kernel_size:
1775
+ if ks <= 0 or ks % 2 == 0:
1776
+ raise ValueError("Kernel size value should be an odd and positive number.")
1777
+
1778
+ if isinstance(sigma, numbers.Number):
1779
+ if sigma <= 0:
1780
+ raise ValueError("If sigma is a single number, it must be positive.")
1781
+ sigma = (sigma, sigma)
1782
+ elif isinstance(sigma, Sequence) and len(sigma) == 2:
1783
+ if not 0.0 < sigma[0] <= sigma[1]:
1784
+ raise ValueError("sigma values should be positive and of the form (min, max).")
1785
+ else:
1786
+ raise ValueError("sigma should be a single number or a list/tuple with length 2.")
1787
+
1788
+ self.sigma = sigma
1789
+
1790
+ @staticmethod
1791
+ def get_params(sigma_min: float, sigma_max: float) -> float:
1792
+ """Choose sigma for random gaussian blurring.
1793
+
1794
+ Args:
1795
+ sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
1796
+ sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
1797
+
1798
+ Returns:
1799
+ float: Standard deviation to be passed to calculate kernel for gaussian blurring.
1800
+ """
1801
+ return torch.empty(1).uniform_(sigma_min, sigma_max).item()
1802
+
1803
+ def forward(self, img: Tensor) -> Tensor:
1804
+ """
1805
+ Args:
1806
+ img (PIL Image or Tensor): image to be blurred.
1807
+
1808
+ Returns:
1809
+ PIL Image or Tensor: Gaussian blurred image
1810
+ """
1811
+ sigma = self.get_params(self.sigma[0], self.sigma[1])
1812
+ return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
1813
+
1814
+ def __repr__(self) -> str:
1815
+ s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
1816
+ return s
1817
+
1818
+
1819
+ def _setup_size(size, error_msg):
1820
+ if isinstance(size, numbers.Number):
1821
+ return int(size), int(size)
1822
+
1823
+ if isinstance(size, Sequence) and len(size) == 1:
1824
+ return size[0], size[0]
1825
+
1826
+ if len(size) != 2:
1827
+ raise ValueError(error_msg)
1828
+
1829
+ return size
1830
+
1831
+
1832
+ def _check_sequence_input(x, name, req_sizes):
1833
+ msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
1834
+ if not isinstance(x, Sequence):
1835
+ raise TypeError(f"{name} should be a sequence of length {msg}.")
1836
+ if len(x) not in req_sizes:
1837
+ raise ValueError(f"{name} should be a sequence of length {msg}.")
1838
+
1839
+
1840
+ def _setup_angle(x, name, req_sizes=(2,)):
1841
+ if isinstance(x, numbers.Number):
1842
+ if x < 0:
1843
+ raise ValueError(f"If {name} is a single number, it must be positive.")
1844
+ x = [-x, x]
1845
+ else:
1846
+ _check_sequence_input(x, name, req_sizes)
1847
+
1848
+ return [float(d) for d in x]
1849
+
1850
+
1851
+ class RandomInvert(torch.nn.Module):
1852
+ """Inverts the colors of the given image randomly with a given probability.
1853
+ If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
1854
+ where ... means it can have an arbitrary number of leading dimensions.
1855
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
1856
+
1857
+ Args:
1858
+ p (float): probability of the image being color inverted. Default value is 0.5
1859
+ """
1860
+
1861
+ def __init__(self, p=0.5):
1862
+ super().__init__()
1863
+ _log_api_usage_once(self)
1864
+ self.p = p
1865
+
1866
+ def forward(self, img):
1867
+ """
1868
+ Args:
1869
+ img (PIL Image or Tensor): Image to be inverted.
1870
+
1871
+ Returns:
1872
+ PIL Image or Tensor: Randomly color inverted image.
1873
+ """
1874
+ if torch.rand(1).item() < self.p:
1875
+ return F.invert(img)
1876
+ return img
1877
+
1878
+ def __repr__(self) -> str:
1879
+ return f"{self.__class__.__name__}(p={self.p})"
1880
+
1881
+
1882
+ class RandomPosterize(torch.nn.Module):
1883
+ """Posterize the image randomly with a given probability by reducing the
1884
+ number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
1885
+ and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
1886
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
1887
+
1888
+ Args:
1889
+ bits (int): number of bits to keep for each channel (0-8)
1890
+ p (float): probability of the image being posterized. Default value is 0.5
1891
+ """
1892
+
1893
+ def __init__(self, bits, p=0.5):
1894
+ super().__init__()
1895
+ _log_api_usage_once(self)
1896
+ self.bits = bits
1897
+ self.p = p
1898
+
1899
+ def forward(self, img):
1900
+ """
1901
+ Args:
1902
+ img (PIL Image or Tensor): Image to be posterized.
1903
+
1904
+ Returns:
1905
+ PIL Image or Tensor: Randomly posterized image.
1906
+ """
1907
+ if torch.rand(1).item() < self.p:
1908
+ return F.posterize(img, self.bits)
1909
+ return img
1910
+
1911
+ def __repr__(self) -> str:
1912
+ return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"
1913
+
1914
+
1915
+ class RandomSolarize(torch.nn.Module):
1916
+ """Solarize the image randomly with a given probability by inverting all pixel
1917
+ values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
1918
+ where ... means it can have an arbitrary number of leading dimensions.
1919
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
1920
+
1921
+ Args:
1922
+ threshold (float): all pixels equal or above this value are inverted.
1923
+ p (float): probability of the image being solarized. Default value is 0.5
1924
+ """
1925
+
1926
+ def __init__(self, threshold, p=0.5):
1927
+ super().__init__()
1928
+ _log_api_usage_once(self)
1929
+ self.threshold = threshold
1930
+ self.p = p
1931
+
1932
+ def forward(self, img):
1933
+ """
1934
+ Args:
1935
+ img (PIL Image or Tensor): Image to be solarized.
1936
+
1937
+ Returns:
1938
+ PIL Image or Tensor: Randomly solarized image.
1939
+ """
1940
+ if torch.rand(1).item() < self.p:
1941
+ return F.solarize(img, self.threshold)
1942
+ return img
1943
+
1944
+ def __repr__(self) -> str:
1945
+ return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"
1946
+
1947
+
1948
+ class RandomAdjustSharpness(torch.nn.Module):
1949
+ """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
1950
+ it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
1951
+
1952
+ Args:
1953
+ sharpness_factor (float): How much to adjust the sharpness. Can be
1954
+ any non-negative number. 0 gives a blurred image, 1 gives the
1955
+ original image while 2 increases the sharpness by a factor of 2.
1956
+ p (float): probability of the image being sharpened. Default value is 0.5
1957
+ """
1958
+
1959
+ def __init__(self, sharpness_factor, p=0.5):
1960
+ super().__init__()
1961
+ _log_api_usage_once(self)
1962
+ self.sharpness_factor = sharpness_factor
1963
+ self.p = p
1964
+
1965
+ def forward(self, img):
1966
+ """
1967
+ Args:
1968
+ img (PIL Image or Tensor): Image to be sharpened.
1969
+
1970
+ Returns:
1971
+ PIL Image or Tensor: Randomly sharpened image.
1972
+ """
1973
+ if torch.rand(1).item() < self.p:
1974
+ return F.adjust_sharpness(img, self.sharpness_factor)
1975
+ return img
1976
+
1977
+ def __repr__(self) -> str:
1978
+ return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"
1979
+
1980
+
1981
+ class RandomAutocontrast(torch.nn.Module):
1982
+ """Autocontrast the pixels of the given image randomly with a given probability.
1983
+ If the image is torch Tensor, it is expected
1984
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
1985
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
1986
+
1987
+ Args:
1988
+ p (float): probability of the image being autocontrasted. Default value is 0.5
1989
+ """
1990
+
1991
+ def __init__(self, p=0.5):
1992
+ super().__init__()
1993
+ _log_api_usage_once(self)
1994
+ self.p = p
1995
+
1996
+ def forward(self, img):
1997
+ """
1998
+ Args:
1999
+ img (PIL Image or Tensor): Image to be autocontrasted.
2000
+
2001
+ Returns:
2002
+ PIL Image or Tensor: Randomly autocontrasted image.
2003
+ """
2004
+ if torch.rand(1).item() < self.p:
2005
+ return F.autocontrast(img)
2006
+ return img
2007
+
2008
+ def __repr__(self) -> str:
2009
+ return f"{self.__class__.__name__}(p={self.p})"
2010
+
2011
+
2012
+ class RandomEqualize(torch.nn.Module):
2013
+ """Equalize the histogram of the given image randomly with a given probability.
2014
+ If the image is torch Tensor, it is expected
2015
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
2016
+ If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
2017
+
2018
+ Args:
2019
+ p (float): probability of the image being equalized. Default value is 0.5
2020
+ """
2021
+
2022
+ def __init__(self, p=0.5):
2023
+ super().__init__()
2024
+ _log_api_usage_once(self)
2025
+ self.p = p
2026
+
2027
+ def forward(self, img):
2028
+ """
2029
+ Args:
2030
+ img (PIL Image or Tensor): Image to be equalized.
2031
+
2032
+ Returns:
2033
+ PIL Image or Tensor: Randomly equalized image.
2034
+ """
2035
+ if torch.rand(1).item() < self.p:
2036
+ return F.equalize(img)
2037
+ return img
2038
+
2039
+ def __repr__(self) -> str:
2040
+ return f"{self.__class__.__name__}(p={self.p})"
2041
+
2042
+
2043
+ class ElasticTransform(torch.nn.Module):
2044
+ """Transform a tensor image with elastic transformations.
2045
+ Given alpha and sigma, it will generate displacement
2046
+ vectors for all pixels based on random offsets. Alpha controls the strength
2047
+ and sigma controls the smoothness of the displacements.
2048
+ The displacements are added to an identity grid and the resulting grid is
2049
+ used to grid_sample from the image.
2050
+
2051
+ Applications:
2052
+ Randomly transforms the morphology of objects in images and produces a
2053
+ see-through-water-like effect.
2054
+
2055
+ Args:
2056
+ alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
2057
+ sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
2058
+ interpolation (InterpolationMode): Desired interpolation enum defined by
2059
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
2060
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
2061
+ The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
2062
+ fill (sequence or number): Pixel fill value for the area outside the transformed
2063
+ image. Default is ``0``. If given a number, the value is used for all bands respectively.
2064
+
2065
+ """
2066
+
2067
+ def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
2068
+ super().__init__()
2069
+ _log_api_usage_once(self)
2070
+ if not isinstance(alpha, (float, Sequence)):
2071
+ raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
2072
+ if isinstance(alpha, Sequence) and len(alpha) != 2:
2073
+ raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
2074
+ if isinstance(alpha, Sequence):
2075
+ for element in alpha:
2076
+ if not isinstance(element, float):
2077
+ raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")
2078
+
2079
+ if isinstance(alpha, float):
2080
+ alpha = [float(alpha), float(alpha)]
2081
+ if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
2082
+ alpha = [alpha[0], alpha[0]]
2083
+
2084
+ self.alpha = alpha
2085
+
2086
+ if not isinstance(sigma, (float, Sequence)):
2087
+ raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
2088
+ if isinstance(sigma, Sequence) and len(sigma) != 2:
2089
+ raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
2090
+ if isinstance(sigma, Sequence):
2091
+ for element in sigma:
2092
+ if not isinstance(element, float):
2093
+ raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")
2094
+
2095
+ if isinstance(sigma, float):
2096
+ sigma = [float(sigma), float(sigma)]
2097
+ if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
2098
+ sigma = [sigma[0], sigma[0]]
2099
+
2100
+ self.sigma = sigma
2101
+
2102
+ if isinstance(interpolation, int):
2103
+ interpolation = _interpolation_modes_from_int(interpolation)
2104
+ self.interpolation = interpolation
2105
+
2106
+ if isinstance(fill, (int, float)):
2107
+ fill = [float(fill)]
2108
+ elif isinstance(fill, (list, tuple)):
2109
+ fill = [float(f) for f in fill]
2110
+ else:
2111
+ raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
2112
+ self.fill = fill
2113
+
2114
+ @staticmethod
2115
+ def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
2116
+ dx = torch.rand([1, 1] + size) * 2 - 1
2117
+ if sigma[0] > 0.0:
2118
+ kx = int(8 * sigma[0] + 1)
2119
+ # if kernel size is even we have to make it odd
2120
+ if kx % 2 == 0:
2121
+ kx += 1
2122
+ dx = F.gaussian_blur(dx, [kx, kx], sigma)
2123
+ dx = dx * alpha[0] / size[0]
2124
+
2125
+ dy = torch.rand([1, 1] + size) * 2 - 1
2126
+ if sigma[1] > 0.0:
2127
+ ky = int(8 * sigma[1] + 1)
2128
+ # if kernel size is even we have to make it odd
2129
+ if ky % 2 == 0:
2130
+ ky += 1
2131
+ dy = F.gaussian_blur(dy, [ky, ky], sigma)
2132
+ dy = dy * alpha[1] / size[1]
2133
+ return torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
2134
+
2135
+ def forward(self, tensor: Tensor) -> Tensor:
2136
+ """
2137
+ Args:
2138
+ tensor (PIL Image or Tensor): Image to be transformed.
2139
+
2140
+ Returns:
2141
+ PIL Image or Tensor: Transformed image.
2142
+ """
2143
+ _, height, width = F.get_dimensions(tensor)
2144
+ displacement = self.get_params(self.alpha, self.sigma, [height, width])
2145
+ return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)
2146
+
2147
+ def __repr__(self):
2148
+ format_string = self.__class__.__name__
2149
+ format_string += f"(alpha={self.alpha}"
2150
+ format_string += f", sigma={self.sigma}"
2151
+ format_string += f", interpolation={self.interpolation}"
2152
+ format_string += f", fill={self.fill})"
2153
+ return format_string
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
2
+
3
+ from . import functional # usort: skip
4
+
5
+ from ._transform import Transform # usort: skip
6
+
7
+ from ._augment import CutMix, JPEG, MixUp, RandomErasing
8
+ from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
9
+ from ._color import (
10
+ ColorJitter,
11
+ Grayscale,
12
+ RandomAdjustSharpness,
13
+ RandomAutocontrast,
14
+ RandomChannelPermutation,
15
+ RandomEqualize,
16
+ RandomGrayscale,
17
+ RandomInvert,
18
+ RandomPhotometricDistort,
19
+ RandomPosterize,
20
+ RandomSolarize,
21
+ RGB,
22
+ )
23
+ from ._container import Compose, RandomApply, RandomChoice, RandomOrder
24
+ from ._geometry import (
25
+ CenterCrop,
26
+ ElasticTransform,
27
+ FiveCrop,
28
+ Pad,
29
+ RandomAffine,
30
+ RandomCrop,
31
+ RandomHorizontalFlip,
32
+ RandomIoUCrop,
33
+ RandomPerspective,
34
+ RandomResize,
35
+ RandomResizedCrop,
36
+ RandomRotation,
37
+ RandomShortestSize,
38
+ RandomVerticalFlip,
39
+ RandomZoomOut,
40
+ Resize,
41
+ ScaleJitter,
42
+ TenCrop,
43
+ )
44
+ from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat
45
+ from ._misc import (
46
+ ConvertImageDtype,
47
+ GaussianBlur,
48
+ GaussianNoise,
49
+ Identity,
50
+ Lambda,
51
+ LinearTransformation,
52
+ Normalize,
53
+ SanitizeBoundingBoxes,
54
+ ToDtype,
55
+ )
56
+ from ._temporal import UniformTemporalSubsample
57
+ from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
58
+ from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
59
+
60
+ from ._deprecated import ToTensor # usort: skip
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_augment.cpython-311.pyc ADDED
Binary file (25.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_auto_augment.cpython-311.pyc ADDED
Binary file (39.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_color.cpython-311.pyc ADDED
Binary file (27.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_deprecated.cpython-311.pyc ADDED
Binary file (3.09 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_misc.cpython-311.pyc ADDED
Binary file (29.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_temporal.cpython-311.pyc ADDED
Binary file (1.99 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_transform.cpython-311.pyc ADDED
Binary file (9.84 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_type_conversion.cpython-311.pyc ADDED
Binary file (5.36 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_augment.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import warnings
4
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
5
+
6
+ import PIL.Image
7
+ import torch
8
+ from torch.nn.functional import one_hot
9
+ from torch.utils._pytree import tree_flatten, tree_unflatten
10
+ from torchvision import transforms as _transforms, tv_tensors
11
+ from torchvision.transforms.v2 import functional as F
12
+
13
+ from ._transform import _RandomApplyTransform, Transform
14
+ from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
15
+
16
+
17
+ class RandomErasing(_RandomApplyTransform):
18
+ """Randomly select a rectangle region in the input image or video and erase its pixels.
19
+
20
+ This transform does not support PIL Image.
21
+ 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
22
+
23
+ Args:
24
+ p (float, optional): probability that the random erasing operation will be performed.
25
+ scale (tuple of float, optional): range of proportion of erased area against input image.
26
+ ratio (tuple of float, optional): range of aspect ratio of erased area.
27
+ value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
28
+ erase all pixels. If a tuple of length 3, it is used to erase
29
+ R, G, B channels respectively.
30
+ If a str of 'random', erasing each pixel with random values.
31
+ inplace (bool, optional): boolean to make this transform inplace. Default set to False.
32
+
33
+ Returns:
34
+ Erased input.
35
+
36
+ Example:
37
+ >>> from torchvision.transforms import v2 as transforms
38
+ >>>
39
+ >>> transform = transforms.Compose([
40
+ >>> transforms.RandomHorizontalFlip(),
41
+ >>> transforms.PILToTensor(),
42
+ >>> transforms.ConvertImageDtype(torch.float),
43
+ >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
44
+ >>> transforms.RandomErasing(),
45
+ >>> ])
46
+ """
47
+
48
+ _v1_transform_cls = _transforms.RandomErasing
49
+
50
+ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
51
+ return dict(
52
+ super()._extract_params_for_v1_transform(),
53
+ value="random" if self.value is None else self.value,
54
+ )
55
+
56
+ def __init__(
57
+ self,
58
+ p: float = 0.5,
59
+ scale: Sequence[float] = (0.02, 0.33),
60
+ ratio: Sequence[float] = (0.3, 3.3),
61
+ value: float = 0.0,
62
+ inplace: bool = False,
63
+ ):
64
+ super().__init__(p=p)
65
+ if not isinstance(value, (numbers.Number, str, tuple, list)):
66
+ raise TypeError("Argument value should be either a number or str or a sequence")
67
+ if isinstance(value, str) and value != "random":
68
+ raise ValueError("If value is str, it should be 'random'")
69
+ if not isinstance(scale, Sequence):
70
+ raise TypeError("Scale should be a sequence")
71
+ if not isinstance(ratio, Sequence):
72
+ raise TypeError("Ratio should be a sequence")
73
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
74
+ warnings.warn("Scale and ratio should be of kind (min, max)")
75
+ if scale[0] < 0 or scale[1] > 1:
76
+ raise ValueError("Scale should be between 0 and 1")
77
+ self.scale = scale
78
+ self.ratio = ratio
79
+ if isinstance(value, (int, float)):
80
+ self.value = [float(value)]
81
+ elif isinstance(value, str):
82
+ self.value = None
83
+ elif isinstance(value, (list, tuple)):
84
+ self.value = [float(v) for v in value]
85
+ else:
86
+ self.value = value
87
+ self.inplace = inplace
88
+
89
+ self._log_ratio = torch.log(torch.tensor(self.ratio))
90
+
91
+ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
92
+ if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
93
+ warnings.warn(
94
+ f"{type(self).__name__}() is currently passing through inputs of type "
95
+ f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
96
+ )
97
+ return super()._call_kernel(functional, inpt, *args, **kwargs)
98
+
99
+ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
100
+ img_c, img_h, img_w = query_chw(flat_inputs)
101
+
102
+ if self.value is not None and not (len(self.value) in (1, img_c)):
103
+ raise ValueError(
104
+ f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
105
+ )
106
+
107
+ area = img_h * img_w
108
+
109
+ log_ratio = self._log_ratio
110
+ for _ in range(10):
111
+ erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
112
+ aspect_ratio = torch.exp(
113
+ torch.empty(1).uniform_(
114
+ log_ratio[0], # type: ignore[arg-type]
115
+ log_ratio[1], # type: ignore[arg-type]
116
+ )
117
+ ).item()
118
+
119
+ h = int(round(math.sqrt(erase_area * aspect_ratio)))
120
+ w = int(round(math.sqrt(erase_area / aspect_ratio)))
121
+ if not (h < img_h and w < img_w):
122
+ continue
123
+
124
+ if self.value is None:
125
+ v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
126
+ else:
127
+ v = torch.tensor(self.value)[:, None, None]
128
+
129
+ i = torch.randint(0, img_h - h + 1, size=(1,)).item()
130
+ j = torch.randint(0, img_w - w + 1, size=(1,)).item()
131
+ break
132
+ else:
133
+ i, j, h, w, v = 0, 0, img_h, img_w, None
134
+
135
+ return dict(i=i, j=j, h=h, w=w, v=v)
136
+
137
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
138
+ if params["v"] is not None:
139
+ inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)
140
+
141
+ return inpt
142
+
143
+
144
+ class _BaseMixUpCutMix(Transform):
145
+ def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
146
+ super().__init__()
147
+ self.alpha = float(alpha)
148
+ self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
149
+
150
+ self.num_classes = num_classes
151
+
152
+ self._labels_getter = _parse_labels_getter(labels_getter)
153
+
154
+ def forward(self, *inputs):
155
+ inputs = inputs if len(inputs) > 1 else inputs[0]
156
+ flat_inputs, spec = tree_flatten(inputs)
157
+ needs_transform_list = self._needs_transform_list(flat_inputs)
158
+
159
+ if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
160
+ raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
161
+
162
+ labels = self._labels_getter(inputs)
163
+ if not isinstance(labels, torch.Tensor):
164
+ raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
165
+ if labels.ndim not in (1, 2):
166
+ raise ValueError(
167
+ f"labels should be index based with shape (batch_size,) "
168
+ f"or probability based with shape (batch_size, num_classes), "
169
+ f"but got a tensor of shape {labels.shape} instead."
170
+ )
171
+ if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
172
+ raise ValueError(
173
+ f"When passing 2D labels, "
174
+ f"the number of elements in last dimension must match num_classes: "
175
+ f"{labels.shape[-1]} != {self.num_classes}. "
176
+ f"You can Leave num_classes to None."
177
+ )
178
+ if labels.ndim == 1 and self.num_classes is None:
179
+ raise ValueError("num_classes must be passed if the labels are index-based (1D)")
180
+
181
+ params = {
182
+ "labels": labels,
183
+ "batch_size": labels.shape[0],
184
+ **self._get_params(
185
+ [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
186
+ ),
187
+ }
188
+
189
+ # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
190
+ # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
191
+ needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
192
+ flat_outputs = [
193
+ self._transform(inpt, params) if needs_transform else inpt
194
+ for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
195
+ ]
196
+
197
+ return tree_unflatten(flat_outputs, spec)
198
+
199
+ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
200
+ expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
201
+ if inpt.ndim != expected_num_dims:
202
+ raise ValueError(
203
+ f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
204
+ )
205
+ if inpt.shape[0] != batch_size:
206
+ raise ValueError(
207
+ f"The batch size of the image or video does not match the batch size of the labels: "
208
+ f"{inpt.shape[0]} != {batch_size}."
209
+ )
210
+
211
+ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
212
+ if label.ndim == 1:
213
+ label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type]
214
+ if not label.dtype.is_floating_point:
215
+ label = label.float()
216
+ return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
217
+
218
+
219
+ class MixUp(_BaseMixUpCutMix):
220
+ """Apply MixUp to the provided batch of images and labels.
221
+
222
+ Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.
223
+
224
+ .. note::
225
+ This transform is meant to be used on **batches** of samples, not
226
+ individual images. See
227
+ :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
228
+ examples.
229
+ The sample pairing is deterministic and done by matching consecutive
230
+ samples in the batch, so the batch needs to be shuffled (this is an
231
+ implementation detail, not a guaranteed convention.)
232
+
233
+ In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
234
+ into a tensor of shape ``(batch_size, num_classes)``.
235
+
236
+ Args:
237
+ alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
238
+ num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
239
+ Can be None only if the labels are already one-hot-encoded.
240
+ labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
241
+ By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
242
+ common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
243
+ It can also be a callable that takes the same input as the transform, and returns the labels.
244
+ """
245
+
246
+ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
247
+ return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]
248
+
249
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
250
+ lam = params["lam"]
251
+
252
+ if inpt is params["labels"]:
253
+ return self._mixup_label(inpt, lam=lam)
254
+ elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
255
+ self._check_image_or_video(inpt, batch_size=params["batch_size"])
256
+
257
+ output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
258
+
259
+ if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
260
+ output = tv_tensors.wrap(output, like=inpt)
261
+
262
+ return output
263
+ else:
264
+ return inpt
265
+
266
+
267
+ class CutMix(_BaseMixUpCutMix):
268
+ """Apply CutMix to the provided batch of images and labels.
269
+
270
+ Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
271
+ <https://arxiv.org/abs/1905.04899>`_.
272
+
273
+ .. note::
274
+ This transform is meant to be used on **batches** of samples, not
275
+ individual images. See
276
+ :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
277
+ examples.
278
+ The sample pairing is deterministic and done by matching consecutive
279
+ samples in the batch, so the batch needs to be shuffled (this is an
280
+ implementation detail, not a guaranteed convention.)
281
+
282
+ In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
283
+ into a tensor of shape ``(batch_size, num_classes)``.
284
+
285
+ Args:
286
+ alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
287
+ num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
288
+ Can be None only if the labels are already one-hot-encoded.
289
+ labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
290
+ By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
291
+ common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
292
+ It can also be a callable that takes the same input as the transform, and returns the labels.
293
+ """
294
+
295
+ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
296
+ lam = float(self._dist.sample(())) # type: ignore[arg-type]
297
+
298
+ H, W = query_size(flat_inputs)
299
+
300
+ r_x = torch.randint(W, size=(1,))
301
+ r_y = torch.randint(H, size=(1,))
302
+
303
+ r = 0.5 * math.sqrt(1.0 - lam)
304
+ r_w_half = int(r * W)
305
+ r_h_half = int(r * H)
306
+
307
+ x1 = int(torch.clamp(r_x - r_w_half, min=0))
308
+ y1 = int(torch.clamp(r_y - r_h_half, min=0))
309
+ x2 = int(torch.clamp(r_x + r_w_half, max=W))
310
+ y2 = int(torch.clamp(r_y + r_h_half, max=H))
311
+ box = (x1, y1, x2, y2)
312
+
313
+ lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
314
+
315
+ return dict(box=box, lam_adjusted=lam_adjusted)
316
+
317
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
318
+ if inpt is params["labels"]:
319
+ return self._mixup_label(inpt, lam=params["lam_adjusted"])
320
+ elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
321
+ self._check_image_or_video(inpt, batch_size=params["batch_size"])
322
+
323
+ x1, y1, x2, y2 = params["box"]
324
+ rolled = inpt.roll(1, 0)
325
+ output = inpt.clone()
326
+ output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
327
+
328
+ if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
329
+ output = tv_tensors.wrap(output, like=inpt)
330
+
331
+ return output
332
+ else:
333
+ return inpt
334
+
335
+
336
+ class JPEG(Transform):
337
+ """Apply JPEG compression and decompression to the given images.
338
+
339
+ If the input is a :class:`torch.Tensor`, it is expected
340
+ to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
341
+ where ... means an arbitrary number of leading dimensions.
342
+
343
+ Args:
344
+ quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
345
+ If quality is a sequence like (min, max), it specifies the range of JPEG quality to
346
+ randomly select from (inclusive of both ends).
347
+
348
+ Returns:
349
+ image with JPEG compression.
350
+ """
351
+
352
+ def __init__(self, quality: Union[int, Sequence[int]]):
353
+ super().__init__()
354
+ if isinstance(quality, int):
355
+ quality = [quality, quality]
356
+ else:
357
+ _check_sequence_input(quality, "quality", req_sizes=(2,))
358
+
359
+ if not (1 <= quality[0] <= quality[1] <= 100 and isinstance(quality[0], int) and isinstance(quality[1], int)):
360
+ raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")
361
+
362
+ self.quality = quality
363
+
364
+ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
365
+ quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
366
+ return dict(quality=quality)
367
+
368
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
369
+ return self._call_kernel(F.jpeg, inpt, quality=params["quality"])
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_auto_augment.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
3
+
4
+ import PIL.Image
5
+ import torch
6
+
7
+ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
8
+ from torchvision import transforms as _transforms, tv_tensors
9
+ from torchvision.transforms import _functional_tensor as _FT
10
+ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
11
+ from torchvision.transforms.v2.functional._geometry import _check_interpolation
12
+ from torchvision.transforms.v2.functional._meta import get_size
13
+ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
14
+
15
+ from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
16
+
17
+
18
+ ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]
19
+
20
+
21
class _AutoAugmentBase(Transform):
    """Shared machinery for the AutoAugment family of transforms.

    Subclasses provide an augmentation space (a dict mapping transform ids to
    ``(magnitudes_fn, signed)`` pairs) and a ``forward`` that samples from it;
    this base class handles fill/interpolation setup, locating the single
    image/video in an arbitrarily nested input, and dispatching a sampled
    operation to the corresponding functional kernel.
    """

    def __init__(
        self,
        *,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__()
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        # Normalized per-type fill mapping consumed by the functional kernels.
        self._fill = _setup_fill_arg(fill)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return constructor params for the v1 fallback; dict fills cannot be scripted."""
        params = super()._extract_params_for_v1_transform()

        if isinstance(params["fill"], dict):
            raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")

        return params

    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
        """Pick one ``(key, value)`` pair uniformly at random from *dct*."""
        keys = tuple(dct.keys())
        key = keys[int(torch.randint(len(keys), ()))]
        return key, dct[key]

    def _flatten_and_extract_image_or_video(
        self,
        inputs: Any,
        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
        """Flatten *inputs* and return the single contained image/video plus restore info.

        Returns ``((flat_inputs, spec, idx), image_or_video)`` where ``idx`` is the
        position of the extracted item in ``flat_inputs``. Raises ``TypeError`` if no
        image/video is found, if more than one is found, or if an unsupported type
        (bounding boxes, masks) is present.
        """
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        needs_transform_list = self._needs_transform_list(flat_inputs)

        image_or_videos = []
        for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
            if needs_transform and check_type(
                inpt,
                (
                    tv_tensors.Image,
                    PIL.Image.Image,
                    is_pure_tensor,
                    tv_tensors.Video,
                ),
            ):
                image_or_videos.append((idx, inpt))
            elif isinstance(inpt, unsupported_types):
                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

        if not image_or_videos:
            raise TypeError("Found no image in the sample.")
        if len(image_or_videos) > 1:
            raise TypeError(
                f"Auto augment transformations are only properly defined for a single image or video, "
                f"but found {len(image_or_videos)}."
            )

        idx, image_or_video = image_or_videos[0]
        return (flat_inputs, spec, idx), image_or_video

    def _unflatten_and_insert_image_or_video(
        self,
        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
        image_or_video: ImageOrVideo,
    ) -> Any:
        """Put the (transformed) image/video back at its original position and rebuild the pytree."""
        flat_inputs, spec, idx = flat_inputs_with_spec
        flat_inputs[idx] = image_or_video
        return tree_unflatten(flat_inputs, spec)

    def _apply_image_or_video_transform(
        self,
        image: ImageOrVideo,
        transform_id: str,
        magnitude: float,
        interpolation: Union[InterpolationMode, int],
        fill: Dict[Union[Type, str], _FillTypeJIT],
    ) -> ImageOrVideo:
        """Dispatch *transform_id* with the given *magnitude* to the matching functional op."""
        # Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
        image = cast(torch.Tensor, image)
        fill_ = _get_fill(fill, type(image))

        if transform_id == "Identity":
            return image
        elif transform_id == "ShearX":
            # magnitude should be arctan(magnitude)
            # official autoaug: (1, level, 0, 0, 1, 0)
            # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
            # compared to
            # torchvision: (1, tan(level), 0, 0, 1, 0)
            # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[math.degrees(math.atan(magnitude)), 0.0],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "ShearY":
            # magnitude should be arctan(magnitude)
            # See above
            return F.affine(
                image,
                angle=0.0,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, math.degrees(math.atan(magnitude))],
                interpolation=interpolation,
                fill=fill_,
                center=[0, 0],
            )
        elif transform_id == "TranslateX":
            return F.affine(
                image,
                angle=0.0,
                translate=[int(magnitude), 0],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "TranslateY":
            return F.affine(
                image,
                angle=0.0,
                translate=[0, int(magnitude)],
                scale=1.0,
                interpolation=interpolation,
                shear=[0.0, 0.0],
                fill=fill_,
            )
        elif transform_id == "Rotate":
            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
        elif transform_id == "Brightness":
            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
        elif transform_id == "Color":
            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
        elif transform_id == "Contrast":
            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
        elif transform_id == "Sharpness":
            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
        elif transform_id == "Posterize":
            return F.posterize(image, bits=int(magnitude))
        elif transform_id == "Solarize":
            # Scale the [0, 1] magnitude to the dtype's value range (255.0 for PIL images).
            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
            return F.solarize(image, threshold=bound * magnitude)
        elif transform_id == "AutoContrast":
            return F.autocontrast(image)
        elif transform_id == "Equalize":
            return F.equalize(image)
        elif transform_id == "Invert":
            return F.invert(image)
        else:
            raise ValueError(f"No transform available for {transform_id}")
176
+
177
+
178
class AutoAugment(_AutoAugmentBase):
    r"""AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy, optional): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """
    _v1_transform_cls = _transforms.AutoAugment

    # Maps each transform id to (magnitudes_fn(num_bins, height, width), signed).
    # ``signed`` marks ops whose magnitude is negated with 50% probability at sample time.
    _AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
        "Invert": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.policy = policy
        self._policies = self._get_policies(policy)

    def _get_policies(
        self, policy: AutoAugmentPolicy
    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
        """Return the list of sub-policies for *policy*.

        Each sub-policy is a pair of ``(transform_id, probability, magnitude_idx)``
        triples; ``magnitude_idx`` is ``None`` for ops without a magnitude.
        """
        if policy == AutoAugmentPolicy.IMAGENET:
            return [
                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
            ]
        elif policy == AutoAugmentPolicy.CIFAR10:
            return [
                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
            ]
        elif policy == AutoAugmentPolicy.SVHN:
            return [
                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
            ]
        else:
            raise ValueError(f"The provided policy {policy} is not recognized.")

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        # Pick one sub-policy uniformly at random.
        policy = self._policies[int(torch.randint(len(self._policies), ()))]

        for transform_id, probability, magnitude_idx in policy:
            # Each op in the sub-policy is applied independently with its own probability.
            if not torch.rand(()) <= probability:
                continue

            magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

            # AutoAugment uses a fixed resolution of 10 magnitude bins.
            magnitudes = magnitudes_fn(10, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[magnitude_idx])
                # Signed ops flip direction with 50% probability.
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0

            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
350
+
351
+
352
class RandAugment(_AutoAugmentBase):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int, optional): Number of augmentation transformations to apply sequentially.
        magnitude (int, optional): Magnitude for all the transformations.
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.RandAugment
    # Maps each transform id to (magnitudes_fn(num_bins, height, width), signed).
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
            True,
        ),
        "TranslateY": (
            lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
            True,
        ),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 9,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        # Apply ``num_ops`` uniformly sampled ops in sequence, all at the fixed
        # magnitude index ``self.magnitude``.
        for _ in range(self.num_ops):
            transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
            magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
            if magnitudes is not None:
                magnitude = float(magnitudes[self.magnitude])
                # Signed ops flip direction with 50% probability.
                if signed and torch.rand(()) <= 0.5:
                    magnitude *= -1
            else:
                magnitude = 0.0
            image_or_video = self._apply_image_or_video_transform(
                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
            )

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
432
+
433
+
434
class TrivialAugmentWide(_AutoAugmentBase):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int, optional): The number of different magnitude values.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.TrivialAugmentWide
    # Maps each transform id to (magnitudes_fn(num_bins, height, width), signed).
    # Ranges are wider than AutoAugment/RandAugment, per the "Wide" variant of the paper.
    _AUGMENTATION_SPACE = {
        "Identity": (lambda num_bins, height, width: None, False),
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }

    def __init__(
        self,
        num_magnitude_bins: int = 31,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ):
        super().__init__(interpolation=interpolation, fill=fill)
        self.num_magnitude_bins = num_magnitude_bins

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(image_or_video)  # type: ignore[arg-type]

        # TrivialAugment applies exactly one uniformly sampled op ...
        transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

        magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
        if magnitudes is not None:
            # ... at a uniformly sampled magnitude bin.
            magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
            if signed and torch.rand(()) <= 0.5:
                magnitude *= -1
        else:
            magnitude = 0.0

        image_or_video = self._apply_image_or_video_transform(
            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
        )
        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
501
+
502
+
503
class AugMix(_AutoAugmentBase):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.

    This transformation works on images and videos only.

    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int, optional): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
        chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    _v1_transform_cls = _transforms.AugMix

    # Ops used when ``all_ops=False`` (the paper's default space, which excludes
    # color/contrast-style ops that overlap with common test corruptions).
    _PARTIAL_AUGMENTATION_SPACE = {
        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
        "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
        "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
        "Posterize": (
            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
            False,
        ),
        "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
        "AutoContrast": (lambda num_bins, height, width: None, False),
        "Equalize": (lambda num_bins, height, width: None, False),
    }
    # Full space used when ``all_ops=True``.
    _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
        **_PARTIAL_AUGMENTATION_SPACE,
        "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
        "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
    }

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
    ) -> None:
        super().__init__(interpolation=interpolation, fill=fill)
        # Number of magnitude bins; severity indexes into the first ``severity`` bins.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops

    def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, *inputs: Any) -> Any:
        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
        height, width = get_size(orig_image_or_video)  # type: ignore[arg-type]

        # PIL inputs are converted to tensors for the mixing arithmetic and
        # converted back at the end.
        if isinstance(orig_image_or_video, torch.Tensor):
            image_or_video = orig_image_or_video
        else:  # isinstance(inpt, PIL.Image.Image):
            image_or_video = F.pil_to_tensor(orig_image_or_video)

        augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

        orig_dims = list(image_or_video.shape)
        # Work on a batched view: 4D for images, 5D for videos.
        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
        batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
        # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
        # augmented image or video.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].reshape([batch_dims[0], -1])

        # Start the mix with the weighted original, then accumulate each chain.
        mix = m[:, 0].reshape(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # Negative chain_depth means a stochastic depth in [1, 3].
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

                magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
                if magnitudes is not None:
                    # Sample the magnitude bin uniformly from the first ``severity`` bins.
                    magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
                    if signed and torch.rand(()) <= 0.5:
                        magnitude *= -1
                else:
                    magnitude = 0.0

                aug = self._apply_image_or_video_transform(aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill)  # type: ignore[assignment]
            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
        # Drop the synthetic batch dims and restore the input dtype.
        mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
        elif isinstance(orig_image_or_video, PIL.Image.Image):
            mix = F.to_pil_image(mix)

        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_color.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
+
4
+ import torch
5
+ from torchvision import transforms as _transforms
6
+ from torchvision.transforms.v2 import functional as F, Transform
7
+
8
+ from ._transform import _RandomApplyTransform
9
+ from ._utils import query_chw
10
+
11
+
12
class Grayscale(Transform):
    """Convert images or videos to grayscale.

    Tensor inputs are expected to be of shape ``[..., 3 or 1, H, W]``, where
    ``...`` denotes any number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image
    """

    _v1_transform_cls = _transforms.Grayscale

    def __init__(self, num_output_channels: int = 1):
        super().__init__()
        # Stored once and forwarded to the grayscale kernel for every input.
        self.num_output_channels = num_output_channels

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Deterministic transform: no sampled params are consulted.
        return self._call_kernel(
            F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels
        )
30
+
31
+
32
class RandomGrayscale(_RandomApplyTransform):
    """Randomly convert image or videos to grayscale with a probability of p (default 0.1).

    Tensor inputs are expected to be of shape ``[..., 3 or 1, H, W]``, where
    ``...`` denotes any number of leading dimensions.

    The output has the same number of channels as the input.

    Args:
        p (float): probability that image should be converted to grayscale.
    """

    _v1_transform_cls = _transforms.RandomGrayscale

    def __init__(self, p: float = 0.1) -> None:
        super().__init__(p=p)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Record the input channel count so the output keeps the same number
        # of channels (e.g. a 3-channel input stays 3-channel after graying).
        channels, *_ = query_chw(flat_inputs)
        return dict(num_input_channels=channels)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(
            F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]
        )
55
+
56
+
57
class RGB(Transform):
    """Convert images or videos to RGB (if they are already not RGB).

    Tensor inputs are expected to be of shape ``[..., 1 or 3, H, W]``, where
    ``...`` denotes any number of leading dimensions.
    """

    def __init__(self):
        super().__init__()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Deterministic: grayscale inputs are expanded to RGB, RGB passes through.
        return self._call_kernel(F.grayscale_to_rgb, inpt)
69
+
70
+
71
class ColorJitter(Transform):
    """Randomly change the brightness, contrast, saturation and hue of an image or video.

    If the input is a :class:`torch.Tensor`, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    _v1_transform_cls = _transforms.ColorJitter

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The v1 ColorJitter represents a disabled factor as 0 rather than
        # None, so map None -> 0 when handing params over.
        return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

    def __init__(
        self,
        brightness: Optional[Union[float, Sequence[float]]] = None,
        contrast: Optional[Union[float, Sequence[float]]] = None,
        saturation: Optional[Union[float, Sequence[float]]] = None,
        hue: Optional[Union[float, Sequence[float]]] = None,
    ) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        # Hue is special-cased: it is centered at 0, bounded to [-0.5, 0.5],
        # and a single number expands symmetrically without clipping at 0.
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    def _check_input(
        self,
        value: Optional[Union[float, Sequence[float]]],
        name: str,
        center: float = 1.0,
        bound: Tuple[float, float] = (0, float("inf")),
        clip_first_on_zero: bool = True,
    ) -> Optional[Tuple[float, float]]:
        """Normalize a jitter argument to a validated ``(min, max)`` range.

        A single non-negative number ``v`` becomes ``[center - v, center + v]``
        (with the lower end clipped at 0 when ``clip_first_on_zero``); a
        2-element sequence is taken as-is. Returns None when the argument is
        None or the resulting range collapses to the no-op ``(center, center)``.
        """
        if value is None:
            return None

        if isinstance(value, (int, float)):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
            value = [float(v) for v in value]
        else:
            raise TypeError(f"{name}={value} should be a single number or a sequence with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # A degenerate range equal to the identity factor disables the op.
        return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))

    @staticmethod
    def _generate_value(left: float, right: float) -> float:
        # Uniform sample from [left, right] as a Python float.
        return torch.empty(1).uniform_(left, right).item()

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # The four adjustments are applied in a freshly shuffled order each
        # call; fn_idx records that order for _transform.
        fn_idx = torch.randperm(4)

        b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
        c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1])
        s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1])
        h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1])

        return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        output = inpt
        brightness_factor = params["brightness_factor"]
        contrast_factor = params["contrast_factor"]
        saturation_factor = params["saturation_factor"]
        hue_factor = params["hue_factor"]
        # Apply the enabled adjustments in the order sampled by _get_params;
        # disabled ones (factor is None) are skipped.
        for fn_id in params["fn_idx"]:
            if fn_id == 0 and brightness_factor is not None:
                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
        return output
171
+
172
+
173
class RandomChannelPermutation(Transform):
    """Randomly permute the channels of an image or video"""

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # One permutation per call, sized by the input's channel count, shared
        # across all images/videos of the sample.
        channels, *_ = query_chw(flat_inputs)
        return dict(permutation=torch.randperm(channels))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.permute_channels, inpt, params["permutation"])
182
+
183
+
184
class RandomPhotometricDistort(Transform):
    """Randomly distorts the image or video as used in `SSD: Single Shot
    MultiBox Detector <https://arxiv.org/abs/1512.02325>`_.

    This transform relies on :class:`~torchvision.transforms.v2.ColorJitter`
    under the hood to adjust the contrast, saturation, hue, brightness, and also
    randomly permutes channels.

    Args:
        brightness (tuple of float (min, max), optional): How much to jitter brightness.
            brightness_factor is chosen uniformly from [min, max]. Should be non negative numbers.
        contrast (tuple of float (min, max), optional): How much to jitter contrast.
            contrast_factor is chosen uniformly from [min, max]. Should be non-negative numbers.
        saturation (tuple of float (min, max), optional): How much to jitter saturation.
            saturation_factor is chosen uniformly from [min, max]. Should be non negative numbers.
        hue (tuple of float (min, max), optional): How much to jitter hue.
            hue_factor is chosen uniformly from [min, max]. Should have -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        p (float, optional) probability each distortion operation (contrast, saturation, ...) to be applied.
            Default is 0.5.
    """

    def __init__(
        self,
        brightness: Tuple[float, float] = (0.875, 1.125),
        contrast: Tuple[float, float] = (0.5, 1.5),
        saturation: Tuple[float, float] = (0.5, 1.5),
        hue: Tuple[float, float] = (-0.05, 0.05),
        p: float = 0.5,
    ):
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast
        self.hue = hue
        self.saturation = saturation
        self.p = p

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        num_channels, *_ = query_chw(flat_inputs)
        # Each distortion is independently enabled with probability p; a
        # disabled op is recorded as None and skipped in _transform.
        params: Dict[str, Any] = {
            key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
            for key, range in [
                ("brightness_factor", self.brightness),
                ("contrast_factor", self.contrast),
                ("saturation_factor", self.saturation),
                ("hue_factor", self.hue),
            ]
        }
        # Coin flip (50/50, independent of p) deciding whether contrast is
        # applied before or after the saturation/hue adjustments.
        params["contrast_before"] = bool(torch.rand(()) < 0.5)
        params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
        return params

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Fixed pipeline: brightness, [contrast], saturation, hue, [contrast],
        # channel permutation — contrast runs in exactly one of the two slots,
        # chosen by the "contrast_before" flag sampled in _get_params.
        if params["brightness_factor"] is not None:
            inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
        if params["contrast_factor"] is not None and params["contrast_before"]:
            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
        if params["saturation_factor"] is not None:
            inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
        if params["hue_factor"] is not None:
            inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
        if params["contrast_factor"] is not None and not params["contrast_before"]:
            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
        if params["channel_permutation"] is not None:
            inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
        return inpt
252
+
253
+
254
class RandomEqualize(_RandomApplyTransform):
    """Equalize the histogram of the given image or video with a given probability.

    Tensor inputs are expected to be of shape ``[..., 1 or 3, H, W]``, where
    ``...`` denotes any number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomEqualize

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The apply-or-not coin flip is handled by _RandomApplyTransform.
        return self._call_kernel(F.equalize, inpt)
269
+
270
+
271
class RandomInvert(_RandomApplyTransform):
    """Inverts the colors of the given image or video with a given probability.

    Tensor inputs are expected to be of shape ``[..., 1 or 3, H, W]``, where
    ``...`` denotes any number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomInvert

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The apply-or-not coin flip is handled by _RandomApplyTransform.
        return self._call_kernel(F.invert, inpt)
286
+
287
+
288
class RandomPosterize(_RandomApplyTransform):
    """Posterize the image or video with a given probability by reducing the
    number of bits for each color channel.

    Tensor inputs should be of dtype torch.uint8 and of shape
    ``[..., 1 or 3, H, W]``, where ``...`` denotes any number of leading
    dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomPosterize

    def __init__(self, bits: int, p: float = 0.5) -> None:
        super().__init__(p=p)
        # Number of high-order bits kept per channel when the op fires.
        self.bits = bits

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.posterize, inpt, bits=self.bits)
309
+
310
+
311
class RandomSolarize(_RandomApplyTransform):
    """Solarize the image or video with a given probability by inverting all pixel
    values above a threshold.

    Tensor inputs are expected to be of shape ``[..., 1 or 3, H, W]``, where
    ``...`` denotes any number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomSolarize

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # The v1 transform stores the threshold as a plain float.
        params = super()._extract_params_for_v1_transform()
        params["threshold"] = float(params["threshold"])
        return params

    def __init__(self, threshold: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.threshold = threshold

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.solarize, inpt, threshold=self.threshold)
337
+
338
+
339
class RandomAutocontrast(_RandomApplyTransform):
    """Autocontrast the pixels of the given image or video with a given probability.

    Tensor inputs are expected to be of shape ``[..., 1 or 3, H, W]``, where
    ``...`` denotes any number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomAutocontrast

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The apply-or-not coin flip is handled by _RandomApplyTransform.
        return self._call_kernel(F.autocontrast, inpt)
354
+
355
+
356
class RandomAdjustSharpness(_RandomApplyTransform):
    """Adjust the sharpness of the image or video with a given probability.

    Tensor inputs are expected to be of shape ``[..., 1 or 3, H, W]``, where
    ``...`` denotes any number of leading dimensions.

    Args:
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomAdjustSharpness

    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.sharpness_factor = sharpness_factor

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Union
2
+
3
+ import torch
4
+
5
+ from torch import nn
6
+ from torchvision import transforms as _transforms
7
+ from torchvision.transforms.v2 import Transform
8
+
9
+
10
class Compose(Transform):
    """Composes several transforms together.

    This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        elif not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # With a single positional argument, feed the bare output of each
        # stage back in; with several, the output is already a tuple to splat.
        unpack = len(inputs) > 1
        for stage in self.transforms:
            out = stage(*inputs)
            inputs = out if unpack else (out,)
        return out

    def extra_repr(self) -> str:
        return "\n".join(f"    {t}" for t in self.transforms)
60
+
61
+
62
class RandomApply(Transform):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return dict(transforms=self.transforms, p=self.p)

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1

        # One coin flip gates the entire list: either all transforms run,
        # or the inputs pass through untouched.
        if torch.rand(1) >= self.p:
            return inputs if unpack else inputs[0]

        for stage in self.transforms:
            out = stage(*inputs)
            inputs = out if unpack else (out,)
        return out

    def extra_repr(self) -> str:
        return "\n".join(f"    {t}" for t in self.transforms)
114
+
115
+
116
class RandomChoice(Transform):
    """Apply single transformation randomly picked from a list.

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        # Unspecified weights mean a uniform choice over all transforms.
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Normalize so the stored weights always sum to 1.
        total = sum(p)
        self.p = [weight / total for weight in p]

    def forward(self, *inputs: Any) -> Any:
        chosen = int(torch.multinomial(torch.tensor(self.p), 1))
        return self.transforms[chosen](*inputs)
151
+
152
+
153
class RandomOrder(Transform):
    """Apply a list of transformations in a random order.

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1
        # Fresh permutation per call; every transform runs exactly once.
        for position in torch.randperm(len(self.transforms)):
            out = self.transforms[position](*inputs)
            inputs = out if unpack else (out,)
        return out
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_deprecated.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Any, Dict, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ from torchvision.transforms import functional as _F
8
+
9
+ from torchvision.transforms.v2 import Transform
10
+
11
+
12
class ToTensor(Transform):
    """[DEPRECATED] Use ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`` instead.

    Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    .. warning::
        :class:`v2.ToTensor` is deprecated and will be removed in a future release.
        Please use instead ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``.
        Output is equivalent up to float precision.

    This transform does not support torchscript.


    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    # Only PIL images and numpy arrays are converted; every other input
    # type is passed through untouched by the dispatch machinery.
    _transformed_types = (PIL.Image.Image, np.ndarray)

    def __init__(self) -> None:
        # Deprecation notice is emitted once at construction time, not per call.
        warnings.warn(
            "The transform `ToTensor()` is deprecated and will be removed in a future release. "
            "Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`."
            "Output is equivalent up to float precision."
        )
        super().__init__()

    def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
        # Delegates to the legacy v1 functional implementation.
        return _F.to_tensor(inpt)
.venv/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py ADDED
@@ -0,0 +1,1416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import warnings
4
+ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
5
+
6
+ import PIL.Image
7
+ import torch
8
+
9
+ from torchvision import transforms as _transforms, tv_tensors
10
+ from torchvision.ops.boxes import box_iou
11
+ from torchvision.transforms.functional import _get_perspective_coeffs
12
+ from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
13
+ from torchvision.transforms.v2.functional._utils import _FillType
14
+
15
+ from ._transform import _RandomApplyTransform
16
+ from ._utils import (
17
+ _check_padding_arg,
18
+ _check_padding_mode_arg,
19
+ _check_sequence_input,
20
+ _get_fill,
21
+ _setup_angle,
22
+ _setup_fill_arg,
23
+ _setup_number_or_seq,
24
+ _setup_size,
25
+ get_bounding_boxes,
26
+ has_all,
27
+ has_any,
28
+ is_pure_tensor,
29
+ query_size,
30
+ )
31
+
32
+
33
class RandomHorizontalFlip(_RandomApplyTransform):
    """Horizontally flip the input with a given probability.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        p (float, optional): probability of the input being flipped. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomHorizontalFlip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The apply-or-not coin flip is handled by _RandomApplyTransform.
        return self._call_kernel(F.horizontal_flip, inpt)
49
+
50
+
51
class RandomVerticalFlip(_RandomApplyTransform):
    """Vertically flip the input with a given probability.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        p (float, optional): probability of the input being flipped. Default value is 0.5
    """

    _v1_transform_cls = _transforms.RandomVerticalFlip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # The apply-or-not coin flip is handled by _RandomApplyTransform.
        return self._call_kernel(F.vertical_flip, inpt)
67
+
68
+
69
class Resize(Transform):
    """Resize the input to the given size.

    The input may be a plain :class:`torch.Tensor` or any ``TVTensor``
    (e.g. :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes`) and may carry an arbitrary
    number of leading batch dimensions: images are ``[..., C, H, W]``, bounding
    boxes ``[..., 4]``.

    Args:
        size (sequence, int, or None): Desired output size.

            - A sequence like (h, w): the output is matched to it exactly.
            - An int: the smaller edge of the image is matched to this number
              while keeping the aspect ratio, i.e. if height > width the image
              is rescaled to (size * height / width, size).
            - None: the output shape is determined solely by ``max_size``.

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            For Tensor inputs only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
        max_size (int, optional): Maximum allowed length for the longer edge of
            the resized image.

            - If ``size`` is an int: when resizing by ``size`` would make the
              longer edge exceed ``max_size``, ``size`` is overruled so that the
              longer edge equals ``max_size`` (the smaller edge may then end up
              shorter than ``size``). Only supported when ``size`` is an int
              (or a length-1 sequence in torchscript mode).
            - If ``size`` is None: the longer edge is matched to ``max_size``
              while keeping the aspect ratio, i.e. if height > width the image
              is rescaled to (max_size, max_size * width / height).

            Leave this as ``None`` (the default) when ``size`` is a sequence.
        antialias (bool, optional): Whether to apply antialiasing. This only
            affects **tensors** resized with bilinear or bicubic modes and is
            ignored otherwise: PIL images are always antialiased for those
            modes, and for other modes (PIL or tensor) antialiasing makes no
            sense. Possible values:

            - ``True`` (default): antialias tensors for bilinear/bicubic modes;
              other modes are unaffected. This is probably what you want.
            - ``False``: never antialias tensors. PIL images are still
              antialiased on bilinear/bicubic modes, because PIL does not
              support disabling it.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. Legacy value; avoid unless you know what you are doing.

            The default changed from ``None`` to ``True`` in v0.17 to make the
            PIL and Tensor backends consistent.
    """

    _v1_transform_cls = _transforms.Resize

    def __init__(
        self,
        size: Union[int, Sequence[int], None],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()

        # Normalize `size` to a list, or keep None (output driven by max_size).
        if size is None:
            if not isinstance(max_size, int):
                raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
        elif isinstance(size, int):
            size = [size]
        elif isinstance(size, Sequence) and len(size) in {1, 2}:
            size = list(size)
        else:
            raise ValueError(
                f"size can be an integer, a sequence of one or two integers, or None, but got {size} instead."
            )
        self.size = size

        self.interpolation = interpolation
        self.max_size = max_size
        self.antialias = antialias

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # All parameters were validated/normalized in __init__; just dispatch.
        return self._call_kernel(
            F.resize,
            inpt,
            self.size,
            interpolation=self.interpolation,
            max_size=self.max_size,
            antialias=self.antialias,
        )
168
+
169
+
170
class CenterCrop(Transform):
    """Crop the input at its center.

    The input may be a plain :class:`torch.Tensor` or any ``TVTensor``
    (e.g. :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes`) and may carry an arbitrary
    number of leading batch dimensions: images are ``[..., C, H, W]``, bounding
    boxes ``[..., 4]``.

    If the image is smaller than the requested output along any edge, it is
    padded with 0 first and then center-cropped.

    Args:
        size (sequence or int): Desired output size of the crop. An int yields
            a square crop (size, size); a length-1 sequence is interpreted as
            (size[0], size[0]).
    """

    _v1_transform_cls = _transforms.CenterCrop

    def __init__(self, size: Union[int, Sequence[int]]):
        super().__init__()
        # _setup_size normalizes int / length-1 / length-2 inputs to (h, w).
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.center_crop, inpt, output_size=self.size)
194
+
195
+
196
class RandomResizedCrop(Transform):
    """Crop a random portion of the input and resize it to a given size.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    A crop of the original input is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float, optional): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float, optional): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The default value changed from ``None`` to ``True`` in
            v0.17, for the PIL and Tensor backends to be consistent.
    """

    _v1_transform_cls = _transforms.RandomResizedCrop

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        scale: Tuple[float, float] = (0.08, 1.0),
        ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        # _setup_size normalizes int / length-1 / length-2 inputs to (h, w).
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        # Misordered bounds only warn (matching the v1 transform); sampling
        # below would then draw from an empty/inverted range.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.scale = scale
        self.ratio = ratio
        self.interpolation = interpolation
        self.antialias = antialias

        # Precompute log-bounds once: aspect ratios are sampled uniformly in
        # log-space so that e.g. 3/4 and 4/3 are equally likely.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Rejection-sample a crop: up to 10 attempts to find a (scale, ratio)
        # combination that fits inside the input, else fall back to a crop
        # clamped to the closest valid aspect ratio, centered.
        height, width = query_size(flat_inputs)
        area = height * width

        log_ratio = self._log_ratio
        for _ in range(10):
            # Target area as a fraction of the ORIGINAL image area.
            target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()

            # Solve w*h = target_area with w/h = aspect_ratio.
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                # Valid crop: pick a uniformly random top-left corner.
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                break
        else:
            # Fallback to central crop
            in_ratio = float(width) / float(height)
            if in_ratio < min(self.ratio):
                w = width
                h = int(round(w / min(self.ratio)))
            elif in_ratio > max(self.ratio):
                h = height
                w = int(round(h * max(self.ratio)))
            else:  # whole image
                w = width
                h = height
            i = (height - h) // 2
            j = (width - w) // 2

        return dict(top=i, left=j, height=h, width=w)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # params carries top/left/height/width sampled once per call in
        # _get_params, so every input in the sample gets the same crop.
        return self._call_kernel(
            F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
        )
313
+
314
+
315
class FiveCrop(Transform):
    """Crop the image or video into four corners and the central crop.

    Accepts plain :class:`torch.Tensor` inputs as well as
    :class:`~torchvision.tv_tensors.Image` / :class:`~torchvision.tv_tensors.Video`
    ones, with any number of leading batch dimensions (``[..., C, H, W]``).

    .. Note::
        This transform returns a tuple of images and there may be a mismatch in the number of
        inputs and targets your Dataset returns. See below for an example of how to deal with
        this.

    Args:
        size (sequence or int): Desired output size of the crop. An ``int``
            yields a square crop of size (size, size); a length-1 sequence is
            interpreted as (size[0], size[0]).

    Example:
        >>> class BatchMultiCrop(transforms.Transform):
        ...     def forward(self, sample: Tuple[Tuple[Union[tv_tensors.Image, tv_tensors.Video], ...], int]):
        ...         images_or_videos, labels = sample
        ...         batch_size = len(images_or_videos)
        ...         image_or_video = images_or_videos[0]
        ...         images_or_videos = tv_tensors.wrap(torch.stack(images_or_videos), like=image_or_video)
        ...         labels = torch.full((batch_size,), label, device=images_or_videos.device)
        ...         return images_or_videos, labels
        ...
        >>> image = tv_tensors.Image(torch.rand(3, 256, 256))
        >>> label = 3
        >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
        >>> images, labels = transform(image, label)
        >>> images.shape
        torch.Size([5, 3, 224, 224])
        >>> labels
        tensor([3, 3, 3, 3, 3])
    """

    _v1_transform_cls = _transforms.FiveCrop

    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        # _setup_size normalizes int / length-1 / length-2 inputs to (h, w).
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # Bounding boxes and masks cannot be meaningfully five-cropped.
        if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Warn on pass-through of box/mask inputs; this behavior may change.
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.five_crop, inpt, self.size)
372
+
373
+
374
class TenCrop(Transform):
    """Crop the image or video into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).

    Accepts plain :class:`torch.Tensor` inputs as well as
    :class:`~torchvision.tv_tensors.Image` / :class:`~torchvision.tv_tensors.Video`
    ones, with any number of leading batch dimensions (``[..., C, H, W]``).

    See :class:`~torchvision.transforms.v2.FiveCrop` for an example.

    .. Note::
        This transform returns a tuple of images and there may be a mismatch in the number of
        inputs and targets your Dataset returns. See below for an example of how to deal with
        this.

    Args:
        size (sequence or int): Desired output size of the crop. An int yields a
            square crop (size, size); a length-1 sequence is interpreted as
            (size[0], size[0]).
        vertical_flip (bool, optional): Use vertical flipping instead of horizontal
    """

    _v1_transform_cls = _transforms.TenCrop

    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        # _setup_size normalizes int / length-1 / length-2 inputs to (h, w).
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # Bounding boxes and masks cannot be meaningfully ten-cropped.
        if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Warn on pass-through of box/mask inputs; this behavior may change.
        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
417
+
418
+
419
class Pad(Transform):
    """Pad the input on all sides with the given "pad" value.

    The input may be a plain :class:`torch.Tensor` or any ``TVTensor``
    (e.g. :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes`) and may carry an arbitrary
    number of leading batch dimensions: images are ``[..., C, H, W]``, bounding
    boxes ``[..., 4]``.

    Args:
        padding (int or sequence): Padding on each border. A single int pads all
            borders equally; a length-2 sequence gives left/right and top/bottom
            padding respectively; a length-4 sequence gives left, top, right and
            bottom padding respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. A tuple of length 3 fills R, G, B channels respectively.
            A dictionary maps input type to fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` fills ``Image``
            inputs with 127 and ``Mask`` inputs with 0.
        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is "constant".

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    _v1_transform_cls = _transforms.Pad

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        # The scripted v1 transform only accepts a single scalar fill value.
        fill = params["fill"]
        if fill is not None and not isinstance(fill, (int, float)):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

        return params

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        _check_padding_arg(padding)
        _check_padding_mode_arg(padding_mode)

        # Sequence[int] -> List[int]; required to keep mypy happy.
        self.padding = padding if isinstance(padding, int) else list(padding)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)
        self.padding_mode = padding_mode

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Resolve the per-type fill value (fill may be a dict keyed by type).
        fill_value = _get_fill(self._fill, type(inpt))
        return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill_value, padding_mode=self.padding_mode)  # type: ignore[arg-type]
489
+
490
+
491
class RandomZoomOut(_RandomApplyTransform):
    """ "Zoom out" transformation from
    `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.

    Randomly pads images, videos, bounding boxes and masks to create a zoom out
    effect. The output spatial size is sampled between the original size and a
    maximum controlled by ``side_range``:

    .. code-block:: python

        r = uniform_sample(side_range[0], side_range[1])
        output_width = input_width * r
        output_height = input_height * r

    The input may be a plain :class:`torch.Tensor` or any ``TVTensor``
    (e.g. :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes`) and may carry an arbitrary
    number of leading batch dimensions: images are ``[..., C, H, W]``, bounding
    boxes ``[..., 4]``.

    Args:
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. A tuple of length 3 fills R, G, B channels respectively.
            A dictionary maps input type to fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` fills ``Image``
            inputs with 127 and ``Mask`` inputs with 0.
        side_range (sequence of floats, optional): tuple of two floats defining
            the minimum and maximum factors by which the input size is scaled.
        p (float, optional): probability that the zoom operation will be performed.
    """

    def __init__(
        self,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        side_range: Sequence[float] = (1.0, 4.0),
        p: float = 0.5,
    ) -> None:
        super().__init__(p=p)

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        _check_sequence_input(side_range, "side_range", req_sizes=(2,))

        self.side_range = side_range
        # Factors below 1 would shrink the canvas; the range must be ordered.
        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
            raise ValueError(f"Invalid side range provided {side_range}.")

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Sample the canvas scale factor first, then a random placement of the
        # original image inside the enlarged canvas.
        img_h, img_w = query_size(flat_inputs)

        factor = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_w = int(img_w * factor)
        canvas_h = int(img_h * factor)

        offsets = torch.rand(2)
        left = int((canvas_w - img_w) * offsets[0])
        top = int((canvas_h - img_h) * offsets[1])
        right = canvas_w - (left + img_w)
        bottom = canvas_h - (top + img_h)

        return dict(padding=[left, top, right, bottom])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Zoom out == pad around the original content with the fill value.
        fill_value = _get_fill(self._fill, type(inpt))
        return self._call_kernel(F.pad, inpt, **params, fill=fill_value)
557
+
558
+
559
class RandomRotation(Transform):
    """Rotate the input by a randomly sampled angle.

    The input may be a plain :class:`torch.Tensor` or any ``TVTensor``
    (e.g. :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`,
    :class:`~torchvision.tv_tensors.BoundingBoxes`) and may carry an arbitrary
    number of leading batch dimensions: images are ``[..., C, H, W]``, bounding
    boxes ``[..., 4]``.

    Args:
        degrees (sequence or number): Range of degrees to select from. A single
            number ``d`` is interpreted as the range ``(-d, +d)``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            For Tensor inputs only ``InterpolationMode.NEAREST`` and ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
        expand (bool, optional): If true, expands the output to make it large
            enough to hold the entire rotated image. If false or omitted, the
            output has the same size as the input. Note that ``expand`` assumes
            rotation around the center (see note below) and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

            .. note::

                In theory, setting ``center`` has no effect if ``expand=True``, since the image center will become the
                center of rotation. In practice however, due to numerical precision, this can lead to off-by-one
                differences of the resulting image size compared to using the image center in the first place. Thus, when
                setting ``expand=True``, it's best to leave ``center=None`` (default).
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. A tuple of length 3 fills R, G, B channels respectively.
            A dictionary maps input type to fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` fills ``Image``
            inputs with 127 and ``Mask`` inputs with 0.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    _v1_transform_cls = _transforms.RandomRotation

    def __init__(
        self,
        degrees: Union[numbers.Number, Sequence],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__()
        # _setup_angle turns a scalar d into [-d, d] and validates sequences.
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
        self.interpolation = interpolation
        self.expand = expand

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))
        self.center = center

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # One angle per call, shared by every input in the sample.
        low, high = self.degrees[0], self.degrees[1]
        return dict(angle=torch.empty(1).uniform_(low, high).item())

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill_value = _get_fill(self._fill, type(inpt))
        return self._call_kernel(
            F.rotate,
            inpt,
            **params,
            interpolation=self.interpolation,
            expand=self.expand,
            center=self.center,
            fill=fill_value,
        )
636
+
637
+
638
class RandomAffine(Transform):
    """Random affine transformation the input keeping center invariant.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    _v1_transform_cls = _transforms.RandomAffine

    def __init__(
        self,
        degrees: Union[numbers.Number, Sequence],
        translate: Optional[Sequence[float]] = None,
        scale: Optional[Sequence[float]] = None,
        shear: Optional[Union[int, float, Sequence[float]]] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        center: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # _setup_angle turns a scalar d into [-d, d] and validates sequences.
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            # Translation bounds are fractions of the image size.
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate
        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            # Accepts a scalar, a 2-sequence (x-shear only) or a 4-sequence
            # (x-shear then y-shear ranges).
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # One parameter set per call, shared by every input in the sample.
        # NOTE: the order of the RNG draws below (angle, tx, ty, scale,
        # shear_x, shear_y) is part of the reproducibility contract for a
        # fixed seed — do not reorder.
        height, width = query_size(flat_inputs)

        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
        if self.translate is not None:
            # Convert fractional bounds into absolute pixel shifts.
            max_dx = float(self.translate[0] * width)
            max_dy = float(self.translate[1] * height)
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translate = (tx, ty)
        else:
            translate = (0, 0)

        if self.scale is not None:
            scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if self.shear is not None:
            shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
            # y-shear is only sampled when a 4-element range was provided.
            if len(self.shear) == 4:
                shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()

        shear = (shear_x, shear_y)
        return dict(angle=angle, translate=translate, scale=scale, shear=shear)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Resolve the per-type fill value (fill may be a dict keyed by type).
        fill = _get_fill(self._fill, type(inpt))
        return self._call_kernel(
            F.affine,
            inpt,
            **params,
            interpolation=self.interpolation,
            fill=fill,
            center=self.center,
        )
756
+
757
+
758
class RandomCrop(Transform):
    """Crop the input at a random location.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean, optional): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    _v1_transform_cls = _transforms.RandomCrop

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Build the constructor arguments for the torchscript-able v1 transform.
        params = super()._extract_params_for_v1_transform()

        # v1 scripting only supports a scalar fill value.
        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

        padding = self.padding
        if padding is not None:
            # self.padding is stored as (left, right, top, bottom); the v1
            # transform expects (left, top, right, bottom).
            pad_left, pad_right, pad_top, pad_bottom = padding
            padding = [pad_left, pad_top, pad_right, pad_bottom]
        params["padding"] = padding

        return params

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        padding: Optional[Union[int, Sequence[int]]] = None,
        pad_if_needed: bool = False,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        # Padding arguments only matter when some padding can actually happen.
        if pad_if_needed or padding is not None:
            if padding is not None:
                _check_padding_arg(padding)
            _check_padding_mode_arg(padding_mode)

        # Normalized to (left, right, top, bottom) order; see _get_params.
        self.padding = F._geometry._parse_pad_padding(padding) if padding else None  # type: ignore[arg-type]
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        # Pre-parse fill once; the per-input value is looked up in _transform.
        self._fill = _setup_fill_arg(fill)
        self.padding_mode = padding_mode

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Start from the input size and account for any explicit padding.
        padded_height, padded_width = query_size(flat_inputs)

        if self.padding is not None:
            pad_left, pad_right, pad_top, pad_bottom = self.padding
            padded_height += pad_top + pad_bottom
            padded_width += pad_left + pad_right
        else:
            pad_left = pad_right = pad_top = pad_bottom = 0

        cropped_height, cropped_width = self.size

        if self.pad_if_needed:
            # Grow both sides symmetrically until the crop fits.
            if padded_height < cropped_height:
                diff = cropped_height - padded_height

                pad_top += diff
                pad_bottom += diff
                padded_height += 2 * diff

            if padded_width < cropped_width:
                diff = cropped_width - padded_width

                pad_left += diff
                pad_right += diff
                padded_width += 2 * diff

        if padded_height < cropped_height or padded_width < cropped_width:
            raise ValueError(
                f"Required crop size {(cropped_height, cropped_width)} is larger than "
                f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
            )

        # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
        padding = [pad_left, pad_top, pad_right, pad_bottom]
        needs_pad = any(padding)

        # Only sample a random offset along a dimension that has slack.
        needs_vert_crop, top = (
            (True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
            if padded_height > cropped_height
            else (False, 0)
        )
        needs_horz_crop, left = (
            (True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
            if padded_width > cropped_width
            else (False, 0)
        )

        return dict(
            needs_crop=needs_vert_crop or needs_horz_crop,
            top=top,
            left=left,
            height=cropped_height,
            width=cropped_width,
            needs_pad=needs_pad,
            padding=padding,
        )

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Padding is applied first so that the crop offset is taken on the
        # padded canvas.
        if params["needs_pad"]:
            fill = _get_fill(self._fill, type(inpt))
            inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

        if params["needs_crop"]:
            inpt = self._call_kernel(
                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
            )

        return inpt
911
+
912
+
913
class RandomPerspective(_RandomApplyTransform):
    """Perform a random perspective transformation of the input with a given probability.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        distortion_scale (float, optional): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float, optional): probability of the input being transformed. Default is 0.5.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
    """

    _v1_transform_cls = _transforms.RandomPerspective

    def __init__(
        self,
        distortion_scale: float = 0.5,
        p: float = 0.5,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__(p=p)

        if not (0 <= distortion_scale <= 1):
            raise ValueError("Argument distortion_scale value should be between 0 and 1")

        self.distortion_scale = distortion_scale
        self.interpolation = interpolation
        self.fill = fill
        # Pre-parse fill once; the per-input value is looked up in _transform.
        self._fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_size(flat_inputs)

        distortion_scale = self.distortion_scale

        # Each destination corner is sampled inside a margin of
        # distortion_scale * (half extent) around the matching image corner.
        # The +1 keeps the randint upper bound strictly positive.
        half_height = height // 2
        half_width = width // 2
        bound_height = int(distortion_scale * half_height) + 1
        bound_width = int(distortion_scale * half_width) + 1
        topleft = [
            int(torch.randint(0, bound_width, size=(1,))),
            int(torch.randint(0, bound_height, size=(1,))),
        ]
        topright = [
            int(torch.randint(width - bound_width, width, size=(1,))),
            int(torch.randint(0, bound_height, size=(1,))),
        ]
        botright = [
            int(torch.randint(width - bound_width, width, size=(1,))),
            int(torch.randint(height - bound_height, height, size=(1,))),
        ]
        botleft = [
            int(torch.randint(0, bound_width, size=(1,))),
            int(torch.randint(height - bound_height, height, size=(1,))),
        ]
        # Source corners are the exact image corners; the perspective
        # coefficients map startpoints onto the jittered endpoints.
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
        return dict(coefficients=perspective_coeffs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
        # The precomputed coefficients are passed via **params, so the
        # start/end points are not needed by the kernel.
        return self._call_kernel(
            F.perspective,
            inpt,
            startpoints=None,
            endpoints=None,
            fill=fill,
            interpolation=self.interpolation,
            **params,
        )
996
+
997
+
998
class ElasticTransform(Transform):
    """Transform the input with elastic transformations.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to transform the input.

    .. note::
        Implementation to transform bounding boxes is approximative (not exact).
        We construct an approximation of the inverse grid as ``inverse_grid = identity - displacement``.
        This is not an exact inverse of the grid used to transform images, i.e. ``grid = identity + displacement``.
        Our assumption is that ``displacement * displacement`` is small and can be ignored.
        Large displacements would lead to large errors in the approximation.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats, optional): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats, optional): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
    """

    _v1_transform_cls = _transforms.ElasticTransform

    def __init__(
        self,
        alpha: Union[float, Sequence[float]] = 50.0,
        sigma: Union[float, Sequence[float]] = 5.0,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
    ) -> None:
        super().__init__()
        # Scalars are normalized to per-axis sequences.
        self.alpha = _setup_number_or_seq(alpha, "alpha")
        self.sigma = _setup_number_or_seq(sigma, "sigma")

        self.interpolation = interpolation
        self.fill = fill
        # Pre-parse fill once; the per-input value is looked up in _transform.
        self._fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        size = list(query_size(flat_inputs))

        # Random offsets in [-1, 1), optionally smoothed by a Gaussian blur
        # whose kernel size is derived from sigma, then scaled by alpha
        # relative to the spatial size.
        dx = torch.rand([1, 1] + size) * 2 - 1
        if self.sigma[0] > 0.0:
            kx = int(8 * self.sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
        dx = dx * self.alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if self.sigma[1] > 0.0:
            ky = int(8 * self.sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
        dy = dy * self.alpha[1] / size[1]
        displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
        return dict(displacement=displacement)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
        return self._call_kernel(
            F.elastic,
            inpt,
            **params,
            fill=fill,
            interpolation=self.interpolation,
        )
1086
+
1087
+
1088
class RandomIoUCrop(Transform):
    """Random IoU crop transformation from
    `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.

    This transformation requires an image or video data and ``tv_tensors.BoundingBoxes`` in the input.

    .. warning::
        In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
        must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
        after or later in the transforms pipeline.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        min_scale (float, optional): Minimum factors to scale the input size.
        max_scale (float, optional): Maximum factors to scale the input size.
        min_aspect_ratio (float, optional): Minimum aspect ratio for the cropped image or video.
        max_aspect_ratio (float, optional): Maximum aspect ratio for the cropped image or video.
        sampler_options (list of float, optional): List of minimal IoU (Jaccard) overlap between all the boxes and
            a cropped image or video. Default, ``None`` which corresponds to ``[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]``
        trials (int, optional): Number of trials to find a crop for a given value of minimal IoU (Jaccard) overlap.
            Default, 40.
    """

    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1.0,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2.0,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
    ):
        super().__init__()
        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        if sampler_options is None:
            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options
        self.trials = trials

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # The crop is driven by box coverage, so both an image-like input and
        # bounding boxes are mandatory.
        if not (
            has_all(flat_inputs, tv_tensors.BoundingBoxes)
            and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor)
        ):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
                "and bounding boxes. Sample can also contain masks."
            )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_h, orig_w = query_size(flat_inputs)
        bboxes = get_bounding_boxes(flat_inputs)

        # Keep sampling IoU thresholds until one of them yields a valid crop.
        while True:
            # sample an option
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                return dict()

            for _ in range(self.trials):
                # check the aspect ratio limitations
                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * r[0])
                new_h = int(orig_h * r[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
                    continue

                # check for 0 area crops
                r = torch.rand(2)
                left = int((orig_w - new_w) * r[0])
                top = int((orig_h - new_h) * r[1])
                right = left + new_w
                bottom = top + new_h
                if left == right or top == bottom:
                    continue

                # check for any valid boxes with centers within the crop area
                xyxy_bboxes = F.convert_bounding_box_format(
                    bboxes.as_subclass(torch.Tensor),
                    bboxes.format,
                    tv_tensors.BoundingBoxFormat.XYXY,
                )
                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
                if not is_within_crop_area.any():
                    continue

                # check at least 1 box with jaccard limitations
                xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
                ious = box_iou(
                    xyxy_bboxes,
                    torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
                )
                if ious.max() < min_jaccard_overlap:
                    continue

                return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # An empty params dict encodes the "leave as-is" option sampled above.
        if len(params) < 1:
            return inpt

        output = self._call_kernel(
            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
        )

        if isinstance(output, tv_tensors.BoundingBoxes):
            # We "mark" the invalid boxes as degenerate, and they can be
            # removed by a later call to SanitizeBoundingBoxes()
            output[~params["is_within_crop_area"]] = 0

        return output
1212
+
1213
+
1214
class ScaleJitter(Transform):
    """Apply Large Scale Jitter to the input, following
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    The input is resized by a random factor drawn uniformly from ``scale_range``,
    applied on top of the base ratio that fits the input into ``target_size``.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        target_size (tuple of int): Target size. Defines the base scale for jittering,
            i.e. ``min(target_size[0] / width, target_size[1] / height)``.
        scale_range (tuple of float, optional): Minimum and maximum of the scale range. Default, ``(0.1, 2.0)``.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing. Only affects **tensors**
            resized with bilinear or bicubic modes; PIL images are always antialiased in
            those modes, and other modes ignore this flag. ``True`` (default) enables it,
            ``False`` disables it for tensors, ``None`` is a legacy value equivalent to
            ``False`` for tensors and ``True`` for PIL images.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_size(flat_inputs)

        # Uniform jitter factor in [scale_range[0], scale_range[1]).
        low, high = self.scale_range
        jitter = low + torch.rand(1) * (high - low)
        # Base ratio that fits the input into target_size, then jittered.
        ratio = min(self.target_size[1] / height, self.target_size[0] / width) * jitter

        return dict(size=(int(height * ratio), int(width * ratio)))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Delegate the actual resizing to the functional kernel.
        return self._call_kernel(
            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
        )
1279
+
1280
+
1281
class RandomShortestSize(Transform):
    """Resize the input so that its shorter side matches a randomly chosen size.

    A value is picked uniformly from ``min_size`` and the input is rescaled so its
    shorter edge equals that value; if ``max_size`` is given, the scale is capped so
    the longer edge never exceeds it.

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        min_size (int or sequence of int): Minimum spatial size. Single integer value or a sequence of integer values.
        max_size (int, optional): Maximum spatial size. Default, None.
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing. Only affects **tensors**
            resized with bilinear or bicubic modes; PIL images are always antialiased in
            those modes, and other modes ignore this flag. ``True`` (default) enables it,
            ``False`` disables it for tensors, ``None`` is a legacy value equivalent to
            ``False`` for tensors and ``True`` for PIL images.
    """

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int], int],
        max_size: Optional[int] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ):
        super().__init__()
        # Normalize min_size to a list so a single int becomes a 1-element choice set.
        self.min_size = list(min_size) if not isinstance(min_size, int) else [min_size]
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_size(flat_inputs)

        # Pick one candidate shorter-side length uniformly at random.
        choice = int(torch.randint(len(self.min_size), ()))
        ratio = self.min_size[choice] / min(height, width)
        if self.max_size is not None:
            # Cap the scale so the longer side stays within max_size.
            ratio = min(ratio, self.max_size / max(height, width))

        return dict(size=(int(height * ratio), int(width * ratio)))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Delegate the actual resizing to the functional kernel.
        return self._call_kernel(
            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
        )
1347
+
1348
+
1349
class RandomResize(Transform):
    """Resize the input to a square of a randomly sampled size.

    Useful together with ``RandomCrop`` as a data augmentation for image
    segmentation training. The output side length is drawn uniformly from
    ``[min_size, max_size)``:

    .. code-block:: python

        size = uniform_sample(min_size, max_size)
        output_width = size
        output_height = size

    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
    it can have arbitrary number of leading batch dimensions. For example,
    the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

    Args:
        min_size (int): Minimum output size for random sampling
        max_size (int): Maximum output size for random sampling
        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing. Only affects **tensors**
            resized with bilinear or bicubic modes; PIL images are always antialiased in
            those modes, and other modes ignore this flag. ``True`` (default) enables it,
            ``False`` disables it for tensors, ``None`` is a legacy value equivalent to
            ``False`` for tensors and ``True`` for PIL images.
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
    ) -> None:
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Draw a single side length; a 1-element size list means "square output"
        # for the resize kernel.
        sampled = torch.randint(self.min_size, self.max_size, ())
        return dict(size=[int(sampled)])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Delegate the actual resizing to the functional kernel.
        return self._call_kernel(
            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
        )