koichi12 committed on
Commit 04b7ba0 · verified · 1 Parent(s): f196197

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/torchvision/_C.so +3 -0
  3. .venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py +490 -0
  4. .venv/lib/python3.11/site-packages/torchvision/datasets/country211.py +58 -0
  5. .venv/lib/python3.11/site-packages/torchvision/datasets/folder.py +337 -0
  6. .venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py +103 -0
  7. .venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py +242 -0
  8. .venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py +168 -0
  9. .venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py +110 -0
  10. .venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py +130 -0
  11. .venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py +197 -0
  12. .venv/lib/python3.11/site-packages/torchvision/io/__init__.py +76 -0
  13. .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py +8 -0
  20. .venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py +513 -0
  21. .venv/lib/python3.11/site-packages/torchvision/io/image.py +436 -0
  22. .venv/lib/python3.11/site-packages/torchvision/io/video.py +438 -0
  23. .venv/lib/python3.11/site-packages/torchvision/io/video_reader.py +294 -0
  24. .venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py +7 -0
  25. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py +540 -0
  42. .venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py +268 -0
  43. .venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py +244 -0
  44. .venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py +846 -0
  45. .venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py +775 -0
  46. .venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py +118 -0
  47. .venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py +25 -0
  48. .venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py +474 -0
  49. .venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py +590 -0
  50. .venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py +903 -0
.gitattributes CHANGED
@@ -345,3 +345,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
  .venv/lib/python3.11/site-packages/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
  .venv/lib/python3.11/site-packages/multidict/_multidict.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
  .venv/lib/python3.11/site-packages/torchvision/image.so filter=lfs diff=lfs merge=lfs -text
+ .venv/lib/python3.11/site-packages/torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torchvision/_C.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb7e1b7570bd8fc14f9497793f89e188ccf161d7c14ca1f236e00368779ee609
+ size 7746688
.venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py ADDED
@@ -0,0 +1,490 @@
1
+ import itertools
2
+ import os
3
+ from abc import ABC, abstractmethod
4
+ from glob import glob
5
+ from pathlib import Path
6
+ from typing import Callable, List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from ..io.image import decode_png, read_file
13
+ from .utils import _read_pfm, verify_str_arg
14
+ from .vision import VisionDataset
15
+
16
+ T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
17
+ T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
18
+
19
+
20
+ __all__ = (
21
+ "KittiFlow",
22
+ "Sintel",
23
+ "FlyingThings3D",
24
+ "FlyingChairs",
25
+ "HD1K",
26
+ )
27
+
28
+
29
+ class FlowDataset(ABC, VisionDataset):
30
+ # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
31
+ # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
32
+ # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
33
+ _has_builtin_flow_mask = False
34
+
35
+ def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
36
+
37
+ super().__init__(root=root)
38
+ self.transforms = transforms
39
+
40
+ self._flow_list: List[str] = []
41
+ self._image_list: List[List[str]] = []
42
+
43
+ def _read_img(self, file_name: str) -> Image.Image:
44
+ img = Image.open(file_name)
45
+ if img.mode != "RGB":
46
+ img = img.convert("RGB") # type: ignore[assignment]
47
+ return img
48
+
49
+ @abstractmethod
50
+ def _read_flow(self, file_name: str):
51
+ # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
52
+ pass
53
+
54
+ def __getitem__(self, index: int) -> Union[T1, T2]:
55
+
56
+ img1 = self._read_img(self._image_list[index][0])
57
+ img2 = self._read_img(self._image_list[index][1])
58
+
59
+ if self._flow_list: # it will be empty for some dataset when split="test"
60
+ flow = self._read_flow(self._flow_list[index])
61
+ if self._has_builtin_flow_mask:
62
+ flow, valid_flow_mask = flow
63
+ else:
64
+ valid_flow_mask = None
65
+ else:
66
+ flow = valid_flow_mask = None
67
+
68
+ if self.transforms is not None:
69
+ img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)
70
+
71
+ if self._has_builtin_flow_mask or valid_flow_mask is not None:
72
+ # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
73
+ return img1, img2, flow, valid_flow_mask
74
+ else:
75
+ return img1, img2, flow
76
+
77
+ def __len__(self) -> int:
78
+ return len(self._image_list)
79
+
80
+ def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
81
+ return torch.utils.data.ConcatDataset([self] * v)
82
+
83
+
84
+ class Sintel(FlowDataset):
85
+ """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
86
+
87
+ The dataset is expected to have the following structure: ::
88
+
89
+ root
90
+ Sintel
91
+ testing
92
+ clean
93
+ scene_1
94
+ scene_2
95
+ ...
96
+ final
97
+ scene_1
98
+ scene_2
99
+ ...
100
+ training
101
+ clean
102
+ scene_1
103
+ scene_2
104
+ ...
105
+ final
106
+ scene_1
107
+ scene_2
108
+ ...
109
+ flow
110
+ scene_1
111
+ scene_2
112
+ ...
113
+
114
+ Args:
115
+ root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
116
+ split (string, optional): The dataset split, either "train" (default) or "test"
117
+ pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
118
+ details on the different passes.
119
+ transforms (callable, optional): A function/transform that takes in
120
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
121
+ ``valid_flow_mask`` is expected for consistency with other datasets which
122
+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ root: Union[str, Path],
128
+ split: str = "train",
129
+ pass_name: str = "clean",
130
+ transforms: Optional[Callable] = None,
131
+ ) -> None:
132
+ super().__init__(root=root, transforms=transforms)
133
+
134
+ verify_str_arg(split, "split", valid_values=("train", "test"))
135
+ verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
136
+ passes = ["clean", "final"] if pass_name == "both" else [pass_name]
137
+
138
+ root = Path(root) / "Sintel"
139
+ flow_root = root / "training" / "flow"
140
+
141
+ for pass_name in passes:
142
+ split_dir = "training" if split == "train" else split
143
+ image_root = root / split_dir / pass_name
144
+ for scene in os.listdir(image_root):
145
+ image_list = sorted(glob(str(image_root / scene / "*.png")))
146
+ for i in range(len(image_list) - 1):
147
+ self._image_list += [[image_list[i], image_list[i + 1]]]
148
+
149
+ if split == "train":
150
+ self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
151
+
152
+ def __getitem__(self, index: int) -> Union[T1, T2]:
153
+ """Return example at given index.
154
+
155
+ Args:
156
+ index(int): The index of the example to retrieve
157
+
158
+ Returns:
159
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
160
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
161
+ ``flow`` is None if ``split="test"``.
162
+ If a valid flow mask is generated within the ``transforms`` parameter,
163
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
164
+ """
165
+ return super().__getitem__(index)
166
+
167
+ def _read_flow(self, file_name: str) -> np.ndarray:
168
+ return _read_flo(file_name)
169
+
170
+
171
+ class KittiFlow(FlowDataset):
172
+ """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).
173
+
174
+ The dataset is expected to have the following structure: ::
175
+
176
+ root
177
+ KittiFlow
178
+ testing
179
+ image_2
180
+ training
181
+ image_2
182
+ flow_occ
183
+
184
+ Args:
185
+ root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
186
+ split (string, optional): The dataset split, either "train" (default) or "test"
187
+ transforms (callable, optional): A function/transform that takes in
188
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
189
+ """
190
+
191
+ _has_builtin_flow_mask = True
192
+
193
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
194
+ super().__init__(root=root, transforms=transforms)
195
+
196
+ verify_str_arg(split, "split", valid_values=("train", "test"))
197
+
198
+ root = Path(root) / "KittiFlow" / (split + "ing")
199
+ images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
200
+ images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
201
+
202
+ if not images1 or not images2:
203
+ raise FileNotFoundError(
204
+ "Could not find the Kitti flow images. Please make sure the directory structure is correct."
205
+ )
206
+
207
+ for img1, img2 in zip(images1, images2):
208
+ self._image_list += [[img1, img2]]
209
+
210
+ if split == "train":
211
+ self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
212
+
213
+ def __getitem__(self, index: int) -> Union[T1, T2]:
214
+ """Return example at given index.
215
+
216
+ Args:
217
+ index(int): The index of the example to retrieve
218
+
219
+ Returns:
220
+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
221
+ where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
222
+ indicating which flow values are valid. The flow is a numpy array of
223
+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
224
+ ``split="test"``.
225
+ """
226
+ return super().__getitem__(index)
227
+
228
+ def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
229
+ return _read_16bits_png_with_flow_and_valid_mask(file_name)
230
+
231
+
232
+ class FlyingChairs(FlowDataset):
233
+ """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.
234
+
235
+ You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
236
+
237
+ The dataset is expected to have the following structure: ::
238
+
239
+ root
240
+ FlyingChairs
241
+ data
242
+ 00001_flow.flo
243
+ 00001_img1.ppm
244
+ 00001_img2.ppm
245
+ ...
246
+ FlyingChairs_train_val.txt
247
+
248
+
249
+ Args:
250
+ root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
251
+ split (string, optional): The dataset split, either "train" (default) or "val"
252
+ transforms (callable, optional): A function/transform that takes in
253
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
254
+ ``valid_flow_mask`` is expected for consistency with other datasets which
255
+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
256
+ """
257
+
258
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
259
+ super().__init__(root=root, transforms=transforms)
260
+
261
+ verify_str_arg(split, "split", valid_values=("train", "val"))
262
+
263
+ root = Path(root) / "FlyingChairs"
264
+ images = sorted(glob(str(root / "data" / "*.ppm")))
265
+ flows = sorted(glob(str(root / "data" / "*.flo")))
266
+
267
+ split_file_name = "FlyingChairs_train_val.txt"
268
+
269
+ if not os.path.exists(root / split_file_name):
270
+ raise FileNotFoundError(
271
+ "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
272
+ )
273
+
274
+ split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
275
+ for i in range(len(flows)):
276
+ split_id = split_list[i]
277
+ if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
278
+ self._flow_list += [flows[i]]
279
+ self._image_list += [[images[2 * i], images[2 * i + 1]]]
280
+
281
+ def __getitem__(self, index: int) -> Union[T1, T2]:
282
+ """Return example at given index.
283
+
284
+ Args:
285
+ index(int): The index of the example to retrieve
286
+
287
+ Returns:
288
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
289
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
290
+ ``flow`` is None if ``split="val"``.
291
+ If a valid flow mask is generated within the ``transforms`` parameter,
292
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
293
+ """
294
+ return super().__getitem__(index)
295
+
296
+ def _read_flow(self, file_name: str) -> np.ndarray:
297
+ return _read_flo(file_name)
298
+
299
+
300
+ class FlyingThings3D(FlowDataset):
301
+ """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
302
+
303
+ The dataset is expected to have the following structure: ::
304
+
305
+ root
306
+ FlyingThings3D
307
+ frames_cleanpass
308
+ TEST
309
+ TRAIN
310
+ frames_finalpass
311
+ TEST
312
+ TRAIN
313
+ optical_flow
314
+ TEST
315
+ TRAIN
316
+
317
+ Args:
318
+ root (str or ``pathlib.Path``): Root directory of the FlyingThings3D Dataset.
319
+ split (string, optional): The dataset split, either "train" (default) or "test"
320
+ pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
321
+ details on the different passes.
322
+ camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
323
+ transforms (callable, optional): A function/transform that takes in
324
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
325
+ ``valid_flow_mask`` is expected for consistency with other datasets which
326
+ return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
327
+ """
328
+
329
+ def __init__(
330
+ self,
331
+ root: Union[str, Path],
332
+ split: str = "train",
333
+ pass_name: str = "clean",
334
+ camera: str = "left",
335
+ transforms: Optional[Callable] = None,
336
+ ) -> None:
337
+ super().__init__(root=root, transforms=transforms)
338
+
339
+ verify_str_arg(split, "split", valid_values=("train", "test"))
340
+ split = split.upper()
341
+
342
+ verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
343
+ passes = {
344
+ "clean": ["frames_cleanpass"],
345
+ "final": ["frames_finalpass"],
346
+ "both": ["frames_cleanpass", "frames_finalpass"],
347
+ }[pass_name]
348
+
349
+ verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
350
+ cameras = ["left", "right"] if camera == "both" else [camera]
351
+
352
+ root = Path(root) / "FlyingThings3D"
353
+
354
+ directions = ("into_future", "into_past")
355
+ for pass_name, camera, direction in itertools.product(passes, cameras, directions):
356
+ image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
357
+ image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)
358
+
359
+ flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
360
+ flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)
361
+
362
+ if not image_dirs or not flow_dirs:
363
+ raise FileNotFoundError(
364
+ "Could not find the FlyingThings3D flow images. "
365
+ "Please make sure the directory structure is correct."
366
+ )
367
+
368
+ for image_dir, flow_dir in zip(image_dirs, flow_dirs):
369
+ images = sorted(glob(str(image_dir / "*.png")))
370
+ flows = sorted(glob(str(flow_dir / "*.pfm")))
371
+ for i in range(len(flows) - 1):
372
+ if direction == "into_future":
373
+ self._image_list += [[images[i], images[i + 1]]]
374
+ self._flow_list += [flows[i]]
375
+ elif direction == "into_past":
376
+ self._image_list += [[images[i + 1], images[i]]]
377
+ self._flow_list += [flows[i + 1]]
378
+
379
+ def __getitem__(self, index: int) -> Union[T1, T2]:
380
+ """Return example at given index.
381
+
382
+ Args:
383
+ index(int): The index of the example to retrieve
384
+
385
+ Returns:
386
+ tuple: A 3-tuple with ``(img1, img2, flow)``.
387
+ The flow is a numpy array of shape (2, H, W) and the images are PIL images.
388
+ ``flow`` is None if ``split="test"``.
389
+ If a valid flow mask is generated within the ``transforms`` parameter,
390
+ a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
391
+ """
392
+ return super().__getitem__(index)
393
+
394
+ def _read_flow(self, file_name: str) -> np.ndarray:
395
+ return _read_pfm(file_name)
396
+
397
+
398
+ class HD1K(FlowDataset):
399
+ """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.
400
+
401
+ The dataset is expected to have the following structure: ::
402
+
403
+ root
404
+ hd1k
405
+ hd1k_challenge
406
+ image_2
407
+ hd1k_flow_gt
408
+ flow_occ
409
+ hd1k_input
410
+ image_2
411
+
412
+ Args:
413
+ root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
414
+ split (string, optional): The dataset split, either "train" (default) or "test"
415
+ transforms (callable, optional): A function/transform that takes in
416
+ ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
417
+ """
418
+
419
+ _has_builtin_flow_mask = True
420
+
421
+ def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
422
+ super().__init__(root=root, transforms=transforms)
423
+
424
+ verify_str_arg(split, "split", valid_values=("train", "test"))
425
+
426
+ root = Path(root) / "hd1k"
427
+ if split == "train":
428
+ # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
429
+ for seq_idx in range(36):
430
+ flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
431
+ images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
432
+ for i in range(len(flows) - 1):
433
+ self._flow_list += [flows[i]]
434
+ self._image_list += [[images[i], images[i + 1]]]
435
+ else:
436
+ images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
437
+ images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
438
+ for image1, image2 in zip(images1, images2):
439
+ self._image_list += [[image1, image2]]
440
+
441
+ if not self._image_list:
442
+ raise FileNotFoundError(
443
+ "Could not find the HD1K images. Please make sure the directory structure is correct."
444
+ )
445
+
446
+ def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
447
+ return _read_16bits_png_with_flow_and_valid_mask(file_name)
448
+
449
+ def __getitem__(self, index: int) -> Union[T1, T2]:
450
+ """Return example at given index.
451
+
452
+ Args:
453
+ index(int): The index of the example to retrieve
454
+
455
+ Returns:
456
+ tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
457
+ is a numpy boolean mask of shape (H, W)
458
+ indicating which flow values are valid. The flow is a numpy array of
459
+ shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
460
+ ``split="test"``.
461
+ """
462
+ return super().__getitem__(index)
463
+
464
+
465
+ def _read_flo(file_name: str) -> np.ndarray:
466
+ """Read .flo file in Middlebury format"""
467
+ # Code adapted from:
468
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
469
+ # Everything needs to be in little Endian according to
470
+ # https://vision.middlebury.edu/flow/code/flow-code/README.txt
471
+ with open(file_name, "rb") as f:
472
+ magic = np.fromfile(f, "c", count=4).tobytes()
473
+ if magic != b"PIEH":
474
+ raise ValueError("Magic number incorrect. Invalid .flo file")
475
+
476
+ w = int(np.fromfile(f, "<i4", count=1))
477
+ h = int(np.fromfile(f, "<i4", count=1))
478
+ data = np.fromfile(f, "<f4", count=2 * w * h)
479
+ return data.reshape(h, w, 2).transpose(2, 0, 1)
480
+
481
+
482
+ def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
483
+
484
+ flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
485
+ flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
486
+ flow = (flow - 2**15) / 64  # KITTI stores flow as uint16: value = flow_component * 64 + 2**15 (see the KITTI flow devkit readme)
487
+ valid_flow_mask = valid_flow_mask.bool()
488
+
489
+ # For consistency with other datasets, we convert to numpy
490
+ return flow.numpy(), valid_flow_mask.numpy()
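
For context, a minimal usage sketch of the optical-flow datasets defined above; the "./data" root and the presence of the extracted Sintel/KittiFlow folders are illustrative assumptions, not part of this commit:

from torchvision.datasets import KittiFlow, Sintel

# Sintel has no built-in validity mask, so it yields a 3-tuple (assumes ./data/Sintel exists).
sintel = Sintel(root="./data", split="train", pass_name="clean")
img1, img2, flow = sintel[0]        # two PIL images and a (2, H, W) float numpy array

# KittiFlow sets _has_builtin_flow_mask = True, so it yields a 4-tuple (assumes ./data/KittiFlow exists).
kitti = KittiFlow(root="./data", split="train")
img1, img2, flow, valid = kitti[0]  # valid is a (H, W) boolean numpy mask
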
.venv/lib/python3.11/site-packages/torchvision/datasets/country211.py ADDED
@@ -0,0 +1,58 @@
1
+ from pathlib import Path
2
+ from typing import Callable, Optional, Union
3
+
4
+ from .folder import ImageFolder
5
+ from .utils import download_and_extract_archive, verify_str_arg
6
+
7
+
8
+ class Country211(ImageFolder):
9
+ """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.
10
+
11
+ This dataset was built by filtering the images from the YFCC100m dataset
12
+ that have GPS coordinates corresponding to an ISO-3166 country code. The
13
+ dataset is balanced by sampling 150 train images, 50 validation images, and
14
+ 100 test images for each country.
15
+
16
+ Args:
17
+ root (str or ``pathlib.Path``): Root directory of the dataset.
18
+ split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
19
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
20
+ version. E.g, ``transforms.RandomCrop``.
21
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
22
+ download (bool, optional): If True, downloads the dataset from the internet and puts it into
23
+ ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24
+ """
25
+
26
+ _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
27
+ _MD5 = "84988d7644798601126c29e9877aab6a"
28
+
29
+ def __init__(
30
+ self,
31
+ root: Union[str, Path],
32
+ split: str = "train",
33
+ transform: Optional[Callable] = None,
34
+ target_transform: Optional[Callable] = None,
35
+ download: bool = False,
36
+ ) -> None:
37
+ self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
38
+
39
+ root = Path(root).expanduser()
40
+ self.root = str(root)
41
+ self._base_folder = root / "country211"
42
+
43
+ if download:
44
+ self._download()
45
+
46
+ if not self._check_exists():
47
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
48
+
49
+ super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
50
+ self.root = str(root)
51
+
52
+ def _check_exists(self) -> bool:
53
+ return self._base_folder.exists() and self._base_folder.is_dir()
54
+
55
+ def _download(self) -> None:
56
+ if self._check_exists():
57
+ return
58
+ download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
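
A short usage sketch for Country211; the root path and download=True are illustrative assumptions:

from torchvision import transforms
from torchvision.datasets import Country211

# First use downloads and extracts country211.tgz into ./data/country211/ (assumed root).
ds = Country211(root="./data", split="valid", download=True, transform=transforms.ToTensor())
image, label = ds[0]                # image tensor and an integer class index
print(len(ds), ds.classes[label])   # class names come from the underlying ImageFolder
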
.venv/lib/python3.11/site-packages/torchvision/datasets/folder.py ADDED
@@ -0,0 +1,337 @@
1
+ import os
2
+ import os.path
3
+ from pathlib import Path
4
+ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
5
+
6
+ from PIL import Image
7
+
8
+ from .vision import VisionDataset
9
+
10
+
11
+ def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
12
+ """Checks if a file is an allowed extension.
13
+
14
+ Args:
15
+ filename (string): path to a file
16
+ extensions (tuple of strings): extensions to consider (lowercase)
17
+
18
+ Returns:
19
+ bool: True if the filename ends with one of given extensions
20
+ """
21
+ return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
22
+
23
+
24
+ def is_image_file(filename: str) -> bool:
25
+ """Checks if a file is an allowed image extension.
26
+
27
+ Args:
28
+ filename (string): path to a file
29
+
30
+ Returns:
31
+ bool: True if the filename ends with a known image extension
32
+ """
33
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
34
+
35
+
36
+ def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
37
+ """Finds the class folders in a dataset.
38
+
39
+ See :class:`DatasetFolder` for details.
40
+ """
41
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
42
+ if not classes:
43
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
44
+
45
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
46
+ return classes, class_to_idx
47
+
48
+
49
+ def make_dataset(
50
+ directory: Union[str, Path],
51
+ class_to_idx: Optional[Dict[str, int]] = None,
52
+ extensions: Optional[Union[str, Tuple[str, ...]]] = None,
53
+ is_valid_file: Optional[Callable[[str], bool]] = None,
54
+ allow_empty: bool = False,
55
+ ) -> List[Tuple[str, int]]:
56
+ """Generates a list of samples of a form (path_to_sample, class).
57
+
58
+ See :class:`DatasetFolder` for details.
59
+
60
+ Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
61
+ by default.
62
+ """
63
+ directory = os.path.expanduser(directory)
64
+
65
+ if class_to_idx is None:
66
+ _, class_to_idx = find_classes(directory)
67
+ elif not class_to_idx:
68
+ raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
69
+
70
+ both_none = extensions is None and is_valid_file is None
71
+ both_something = extensions is not None and is_valid_file is not None
72
+ if both_none or both_something:
73
+ raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
74
+
75
+ if extensions is not None:
76
+
77
+ def is_valid_file(x: str) -> bool:
78
+ return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
79
+
80
+ is_valid_file = cast(Callable[[str], bool], is_valid_file)
81
+
82
+ instances = []
83
+ available_classes = set()
84
+ for target_class in sorted(class_to_idx.keys()):
85
+ class_index = class_to_idx[target_class]
86
+ target_dir = os.path.join(directory, target_class)
87
+ if not os.path.isdir(target_dir):
88
+ continue
89
+ for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
90
+ for fname in sorted(fnames):
91
+ path = os.path.join(root, fname)
92
+ if is_valid_file(path):
93
+ item = path, class_index
94
+ instances.append(item)
95
+
96
+ if target_class not in available_classes:
97
+ available_classes.add(target_class)
98
+
99
+ empty_classes = set(class_to_idx.keys()) - available_classes
100
+ if empty_classes and not allow_empty:
101
+ msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
102
+ if extensions is not None:
103
+ msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
104
+ raise FileNotFoundError(msg)
105
+
106
+ return instances
107
+
108
+
109
+ class DatasetFolder(VisionDataset):
110
+ """A generic data loader.
111
+
112
+ This default directory structure can be customized by overriding the
113
+ :meth:`find_classes` method.
114
+
115
+ Args:
116
+ root (str or ``pathlib.Path``): Root directory path.
117
+ loader (callable): A function to load a sample given its path.
118
+ extensions (tuple[string]): A list of allowed extensions.
119
+ both extensions and is_valid_file should not be passed.
120
+ transform (callable, optional): A function/transform that takes in
121
+ a sample and returns a transformed version.
122
+ E.g, ``transforms.RandomCrop`` for images.
123
+ target_transform (callable, optional): A function/transform that takes
124
+ in the target and transforms it.
125
+ is_valid_file (callable, optional): A function that takes path of a file
126
+ and checks if the file is a valid file (used to check for corrupt files);
127
+ both extensions and is_valid_file should not be passed.
128
+ allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
129
+ An error is raised on empty folders if False (default).
130
+
131
+ Attributes:
132
+ classes (list): List of the class names sorted alphabetically.
133
+ class_to_idx (dict): Dict with items (class_name, class_index).
134
+ samples (list): List of (sample path, class_index) tuples
135
+ targets (list): The class_index value for each image in the dataset
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ root: Union[str, Path],
141
+ loader: Callable[[str], Any],
142
+ extensions: Optional[Tuple[str, ...]] = None,
143
+ transform: Optional[Callable] = None,
144
+ target_transform: Optional[Callable] = None,
145
+ is_valid_file: Optional[Callable[[str], bool]] = None,
146
+ allow_empty: bool = False,
147
+ ) -> None:
148
+ super().__init__(root, transform=transform, target_transform=target_transform)
149
+ classes, class_to_idx = self.find_classes(self.root)
150
+ samples = self.make_dataset(
151
+ self.root,
152
+ class_to_idx=class_to_idx,
153
+ extensions=extensions,
154
+ is_valid_file=is_valid_file,
155
+ allow_empty=allow_empty,
156
+ )
157
+
158
+ self.loader = loader
159
+ self.extensions = extensions
160
+
161
+ self.classes = classes
162
+ self.class_to_idx = class_to_idx
163
+ self.samples = samples
164
+ self.targets = [s[1] for s in samples]
165
+
166
+ @staticmethod
167
+ def make_dataset(
168
+ directory: Union[str, Path],
169
+ class_to_idx: Dict[str, int],
170
+ extensions: Optional[Tuple[str, ...]] = None,
171
+ is_valid_file: Optional[Callable[[str], bool]] = None,
172
+ allow_empty: bool = False,
173
+ ) -> List[Tuple[str, int]]:
174
+ """Generates a list of samples of a form (path_to_sample, class).
175
+
176
+ This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
177
+
178
+ Args:
179
+ directory (str): root dataset directory, corresponding to ``self.root``.
180
+ class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
181
+ extensions (optional): A list of allowed extensions.
182
+ Either extensions or is_valid_file should be passed. Defaults to None.
183
+ is_valid_file (optional): A function that takes path of a file
184
+ and checks if the file is a valid file
185
+ (used to check for corrupt files); both extensions and
186
+ is_valid_file should not be passed. Defaults to None.
187
+ allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
188
+ An error is raised on empty folders if False (default).
189
+
190
+ Raises:
191
+ ValueError: In case ``class_to_idx`` is empty.
192
+ ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
193
+ FileNotFoundError: In case no valid file was found for any class.
194
+
195
+ Returns:
196
+ List[Tuple[str, int]]: samples of a form (path_to_sample, class)
197
+ """
198
+ if class_to_idx is None:
199
+ # prevent potential bug since make_dataset() would use the class_to_idx logic of the
200
+ # find_classes() function, instead of using that of the find_classes() method, which
201
+ # is potentially overridden and thus could have a different logic.
202
+ raise ValueError("The class_to_idx parameter cannot be None.")
203
+ return make_dataset(
204
+ directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
205
+ )
206
+
207
+ def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
208
+ """Find the class folders in a dataset structured as follows::
209
+
210
+ directory/
211
+ ├── class_x
212
+ │ ├── xxx.ext
213
+ │ ├── xxy.ext
214
+ │ └── ...
215
+ │ └── xxz.ext
216
+ └── class_y
217
+ ├── 123.ext
218
+ ├── nsdf3.ext
219
+ └── ...
220
+ └── asd932_.ext
221
+
222
+ This method can be overridden to only consider
223
+ a subset of classes, or to adapt to a different dataset directory structure.
224
+
225
+ Args:
226
+ directory(str): Root directory path, corresponding to ``self.root``
227
+
228
+ Raises:
229
+ FileNotFoundError: If ``dir`` has no class folders.
230
+
231
+ Returns:
232
+ (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
233
+ """
234
+ return find_classes(directory)
235
+
236
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
237
+ """
238
+ Args:
239
+ index (int): Index
240
+
241
+ Returns:
242
+ tuple: (sample, target) where target is class_index of the target class.
243
+ """
244
+ path, target = self.samples[index]
245
+ sample = self.loader(path)
246
+ if self.transform is not None:
247
+ sample = self.transform(sample)
248
+ if self.target_transform is not None:
249
+ target = self.target_transform(target)
250
+
251
+ return sample, target
252
+
253
+ def __len__(self) -> int:
254
+ return len(self.samples)
255
+
256
+
257
+ IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
258
+
259
+
260
+ def pil_loader(path: str) -> Image.Image:
261
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
262
+ with open(path, "rb") as f:
263
+ img = Image.open(f)
264
+ return img.convert("RGB")
265
+
266
+
267
+ # TODO: specify the return type
268
+ def accimage_loader(path: str) -> Any:
269
+ import accimage
270
+
271
+ try:
272
+ return accimage.Image(path)
273
+ except OSError:
274
+ # Potentially a decoding problem, fall back to PIL.Image
275
+ return pil_loader(path)
276
+
277
+
278
+ def default_loader(path: str) -> Any:
279
+ from torchvision import get_image_backend
280
+
281
+ if get_image_backend() == "accimage":
282
+ return accimage_loader(path)
283
+ else:
284
+ return pil_loader(path)
285
+
286
+
287
+ class ImageFolder(DatasetFolder):
288
+ """A generic data loader where the images are arranged in this way by default: ::
289
+
290
+ root/dog/xxx.png
291
+ root/dog/xxy.png
292
+ root/dog/[...]/xxz.png
293
+
294
+ root/cat/123.png
295
+ root/cat/nsdf3.png
296
+ root/cat/[...]/asd932_.png
297
+
298
+ This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
299
+ the same methods can be overridden to customize the dataset.
300
+
301
+ Args:
302
+ root (str or ``pathlib.Path``): Root directory path.
303
+ transform (callable, optional): A function/transform that takes in a PIL image
304
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
305
+ target_transform (callable, optional): A function/transform that takes in the
306
+ target and transforms it.
307
+ loader (callable, optional): A function to load an image given its path.
308
+ is_valid_file (callable, optional): A function that takes path of an Image file
309
+ and checks if the file is a valid file (used to check for corrupt files).
310
+ allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
311
+ An error is raised on empty folders if False (default).
312
+
313
+ Attributes:
314
+ classes (list): List of the class names sorted alphabetically.
315
+ class_to_idx (dict): Dict with items (class_name, class_index).
316
+ imgs (list): List of (image path, class_index) tuples
317
+ """
318
+
319
+ def __init__(
320
+ self,
321
+ root: Union[str, Path],
322
+ transform: Optional[Callable] = None,
323
+ target_transform: Optional[Callable] = None,
324
+ loader: Callable[[str], Any] = default_loader,
325
+ is_valid_file: Optional[Callable[[str], bool]] = None,
326
+ allow_empty: bool = False,
327
+ ):
328
+ super().__init__(
329
+ root,
330
+ loader,
331
+ IMG_EXTENSIONS if is_valid_file is None else None,
332
+ transform=transform,
333
+ target_transform=target_transform,
334
+ is_valid_file=is_valid_file,
335
+ allow_empty=allow_empty,
336
+ )
337
+ self.imgs = self.samples
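
A sketch of how DatasetFolder/ImageFolder turn a class-per-directory layout into (sample, target) pairs; the ./data/pets tree is a made-up example:

# Assumed layout: ./data/pets/cat/001.png, ./data/pets/dog/001.png, ...
from torchvision import transforms
from torchvision.datasets import ImageFolder

ds = ImageFolder(root="./data/pets", transform=transforms.ToTensor())
print(ds.classes)        # e.g. ['cat', 'dog'], i.e. folder names sorted alphabetically
print(ds.class_to_idx)   # e.g. {'cat': 0, 'dog': 1}
img, target = ds[0]      # image loaded via default_loader (PIL), then transformed
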
.venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import pathlib
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ import PIL
6
+
7
+ from .folder import make_dataset
8
+ from .utils import download_and_extract_archive, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+
12
+ class GTSRB(VisionDataset):
13
+ """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
14
+
15
+ Args:
16
+ root (str or ``pathlib.Path``): Root directory of the dataset.
17
+ split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
18
+ transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
19
+ version. E.g, ``transforms.RandomCrop``.
20
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
21
+ download (bool, optional): If True, downloads the dataset from the internet and
22
+ puts it in root directory. If dataset is already downloaded, it is not
23
+ downloaded again.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ root: Union[str, pathlib.Path],
29
+ split: str = "train",
30
+ transform: Optional[Callable] = None,
31
+ target_transform: Optional[Callable] = None,
32
+ download: bool = False,
33
+ ) -> None:
34
+
35
+ super().__init__(root, transform=transform, target_transform=target_transform)
36
+
37
+ self._split = verify_str_arg(split, "split", ("train", "test"))
38
+ self._base_folder = pathlib.Path(root) / "gtsrb"
39
+ self._target_folder = (
40
+ self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
41
+ )
42
+
43
+ if download:
44
+ self.download()
45
+
46
+ if not self._check_exists():
47
+ raise RuntimeError("Dataset not found. You can use download=True to download it")
48
+
49
+ if self._split == "train":
50
+ samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
51
+ else:
52
+ with open(self._base_folder / "GT-final_test.csv") as csv_file:
53
+ samples = [
54
+ (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
55
+ for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
56
+ ]
57
+
58
+ self._samples = samples
59
+ self.transform = transform
60
+ self.target_transform = target_transform
61
+
62
+ def __len__(self) -> int:
63
+ return len(self._samples)
64
+
65
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
66
+
67
+ path, target = self._samples[index]
68
+ sample = PIL.Image.open(path).convert("RGB")
69
+
70
+ if self.transform is not None:
71
+ sample = self.transform(sample)
72
+
73
+ if self.target_transform is not None:
74
+ target = self.target_transform(target)
75
+
76
+ return sample, target
77
+
78
+ def _check_exists(self) -> bool:
79
+ return self._target_folder.is_dir()
80
+
81
+ def download(self) -> None:
82
+ if self._check_exists():
83
+ return
84
+
85
+ base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
86
+
87
+ if self._split == "train":
88
+ download_and_extract_archive(
89
+ f"{base_url}GTSRB-Training_fixed.zip",
90
+ download_root=str(self._base_folder),
91
+ md5="513f3c79a4c5141765e10e952eaa2478",
92
+ )
93
+ else:
94
+ download_and_extract_archive(
95
+ f"{base_url}GTSRB_Final_Test_Images.zip",
96
+ download_root=str(self._base_folder),
97
+ md5="c7e4e6327067d32654124b0fe9e82185",
98
+ )
99
+ download_and_extract_archive(
100
+ f"{base_url}GTSRB_Final_Test_GT.zip",
101
+ download_root=str(self._base_folder),
102
+ md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
103
+ )
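
A hedged usage sketch for GTSRB; the root path and download flag are assumptions (for the test split, download() also fetches the ground-truth CSV used above):

from torchvision.datasets import GTSRB

train = GTSRB(root="./data", split="train", download=True)
test = GTSRB(root="./data", split="test", download=True)
img, class_id = train[0]   # PIL RGB image and an integer class id (GTSRB has 43 classes)
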
.venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py ADDED
@@ -0,0 +1,242 @@
1
+ import os
2
+ import os.path
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ from PIL import Image
7
+
8
+ from .utils import download_and_extract_archive, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+ CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
12
+
13
+ DATASET_URLS = {
14
+ "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
15
+ "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
16
+ "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
17
+ "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
18
+ "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
19
+ "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
20
+ }
21
+
22
+ DATASET_MD5 = {
23
+ "2017": "7c784ea5e424efaec655bd392f87301f",
24
+ "2018": "b1c6952ce38f31868cc50ea72d066cc3",
25
+ "2019": "c60a6e2962c9b8ccbd458d12c8582644",
26
+ "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
27
+ "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
28
+ "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
29
+ }
30
+
31
+
32
+ class INaturalist(VisionDataset):
33
+ """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.
34
+
35
+ Args:
36
+ root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
37
+ This class does not require/use annotation files.
38
+ version (string, optional): Which version of the dataset to download/use. One of
39
+ '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
40
+ Default: `2021_train`.
41
+ target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
42
+
43
+ - ``full``: the full category (species)
44
+ - ``kingdom``: e.g. "Animalia"
45
+ - ``phylum``: e.g. "Arthropoda"
46
+ - ``class``: e.g. "Insecta"
47
+ - ``order``: e.g. "Coleoptera"
48
+ - ``family``: e.g. "Cleridae"
49
+ - ``genus``: e.g. "Trichodes"
50
+
51
+ for 2017-2019 versions, one of:
52
+
53
+ - ``full``: the full (numeric) category
54
+ - ``super``: the super category, e.g. "Amphibians"
55
+
56
+ Can also be a list to output a tuple with all specified target types.
57
+ Defaults to ``full``.
58
+ transform (callable, optional): A function/transform that takes in a PIL image
59
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
60
+ target_transform (callable, optional): A function/transform that takes in the
61
+ target and transforms it.
62
+ download (bool, optional): If true, downloads the dataset from the internet and
63
+ puts it in root directory. If dataset is already downloaded, it is not
64
+ downloaded again.
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ root: Union[str, Path],
70
+ version: str = "2021_train",
71
+ target_type: Union[List[str], str] = "full",
72
+ transform: Optional[Callable] = None,
73
+ target_transform: Optional[Callable] = None,
74
+ download: bool = False,
75
+ ) -> None:
76
+ self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
77
+
78
+ super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
79
+
80
+ os.makedirs(root, exist_ok=True)
81
+ if download:
82
+ self.download()
83
+
84
+ if not self._check_integrity():
85
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
86
+
87
+ self.all_categories: List[str] = []
88
+
89
+ # map: category type -> name of category -> index
90
+ self.categories_index: Dict[str, Dict[str, int]] = {}
91
+
92
+ # list indexed by category id, containing mapping from category type -> index
93
+ self.categories_map: List[Dict[str, int]] = []
94
+
95
+ if not isinstance(target_type, list):
96
+ target_type = [target_type]
97
+ if self.version[:4] == "2021":
98
+ self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
99
+ self._init_2021()
100
+ else:
101
+ self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
102
+ self._init_pre2021()
103
+
104
+ # index of all files: (full category id, filename)
105
+ self.index: List[Tuple[int, str]] = []
106
+
107
+ for dir_index, dir_name in enumerate(self.all_categories):
108
+ files = os.listdir(os.path.join(self.root, dir_name))
109
+ for fname in files:
110
+ self.index.append((dir_index, fname))
111
+
112
+ def _init_2021(self) -> None:
113
+ """Initialize based on 2021 layout"""
114
+
115
+ self.all_categories = sorted(os.listdir(self.root))
116
+
117
+ # map: category type -> name of category -> index
118
+ self.categories_index = {k: {} for k in CATEGORIES_2021}
119
+
120
+ for dir_index, dir_name in enumerate(self.all_categories):
121
+ pieces = dir_name.split("_")
122
+ if len(pieces) != 8:
123
+ raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
124
+ if pieces[0] != f"{dir_index:05d}":
125
+ raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
126
+ cat_map = {}
127
+ for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
128
+ if name in self.categories_index[cat]:
129
+ cat_id = self.categories_index[cat][name]
130
+ else:
131
+ cat_id = len(self.categories_index[cat])
132
+ self.categories_index[cat][name] = cat_id
133
+ cat_map[cat] = cat_id
134
+ self.categories_map.append(cat_map)
135
+
136
+ def _init_pre2021(self) -> None:
137
+ """Initialize based on 2017-2019 layout"""
138
+
139
+ # map: category type -> name of category -> index
140
+ self.categories_index = {"super": {}}
141
+
142
+ cat_index = 0
143
+ super_categories = sorted(os.listdir(self.root))
144
+ for sindex, scat in enumerate(super_categories):
145
+ self.categories_index["super"][scat] = sindex
146
+ subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
147
+ for subcat in subcategories:
148
+ if self.version == "2017":
149
+ # this version does not use ids as directory names
150
+ subcat_i = cat_index
151
+ cat_index += 1
152
+ else:
153
+ try:
154
+ subcat_i = int(subcat)
155
+ except ValueError:
156
+ raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
157
+ if subcat_i >= len(self.categories_map):
158
+ old_len = len(self.categories_map)
159
+ self.categories_map.extend([{}] * (subcat_i - old_len + 1))
160
+ self.all_categories.extend([""] * (subcat_i - old_len + 1))
161
+ if self.categories_map[subcat_i]:
162
+ raise RuntimeError(f"Duplicate category {subcat}")
163
+ self.categories_map[subcat_i] = {"super": sindex}
164
+ self.all_categories[subcat_i] = os.path.join(scat, subcat)
165
+
166
+ # validate the dictionary
167
+ for cindex, c in enumerate(self.categories_map):
168
+ if not c:
169
+ raise RuntimeError(f"Missing category {cindex}")
170
+
171
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
172
+ """
173
+ Args:
174
+ index (int): Index
175
+
176
+ Returns:
177
+ tuple: (image, target) where the type of target specified by target_type.
178
+ """
179
+
180
+ cat_id, fname = self.index[index]
181
+ img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
182
+
183
+ target: Any = []
184
+ for t in self.target_type:
185
+ if t == "full":
186
+ target.append(cat_id)
187
+ else:
188
+ target.append(self.categories_map[cat_id][t])
189
+ target = tuple(target) if len(target) > 1 else target[0]
190
+
191
+ if self.transform is not None:
192
+ img = self.transform(img)
193
+
194
+ if self.target_transform is not None:
195
+ target = self.target_transform(target)
196
+
197
+ return img, target
198
+
199
+ def __len__(self) -> int:
200
+ return len(self.index)
201
+
202
+ def category_name(self, category_type: str, category_id: int) -> str:
203
+ """
204
+ Args:
205
+ category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
206
+ category_id(int): an index (class id) from this category
207
+
208
+ Returns:
209
+ the name of the category
210
+ """
211
+ if category_type == "full":
212
+ return self.all_categories[category_id]
213
+ else:
214
+ if category_type not in self.categories_index:
215
+ raise ValueError(f"Invalid category type '{category_type}'")
216
+ else:
217
+ for name, id in self.categories_index[category_type].items():
218
+ if id == category_id:
219
+ return name
220
+ raise ValueError(f"Invalid category id {category_id} for {category_type}")
221
+
222
+ def _check_integrity(self) -> bool:
223
+ return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
224
+
225
+ def download(self) -> None:
226
+ if self._check_integrity():
227
+ raise RuntimeError(
228
+ f"The directory {self.root} already exists. "
229
+ f"If you want to re-download or re-extract the images, delete the directory."
230
+ )
231
+
232
+ base_root = os.path.dirname(self.root)
233
+
234
+ download_and_extract_archive(
235
+ DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
236
+ )
237
+
238
+ orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
239
+ if not os.path.exists(orig_dir_name):
240
+ raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
241
+ os.rename(orig_dir_name, self.root)
242
+ print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
.venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py ADDED
@@ -0,0 +1,168 @@
1
+ import io
2
+ import os.path
3
+ import pickle
4
+ import string
5
+ from collections.abc import Iterable
6
+ from pathlib import Path
7
+ from typing import Any, Callable, cast, List, Optional, Tuple, Union
8
+
9
+ from PIL import Image
10
+
11
+ from .utils import iterable_to_str, verify_str_arg
12
+ from .vision import VisionDataset
13
+
14
+
15
+ class LSUNClass(VisionDataset):
16
+ def __init__(
17
+ self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
18
+ ) -> None:
19
+ import lmdb
20
+
21
+ super().__init__(root, transform=transform, target_transform=target_transform)
22
+
23
+ self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
24
+ with self.env.begin(write=False) as txn:
25
+ self.length = txn.stat()["entries"]
26
+ cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
27
+ if os.path.isfile(cache_file):
28
+ self.keys = pickle.load(open(cache_file, "rb"))
29
+ else:
30
+ with self.env.begin(write=False) as txn:
31
+ self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
32
+ pickle.dump(self.keys, open(cache_file, "wb"))
33
+
34
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
35
+ img, target = None, None
36
+ env = self.env
37
+ with env.begin(write=False) as txn:
38
+ imgbuf = txn.get(self.keys[index])
39
+
40
+ buf = io.BytesIO()
41
+ buf.write(imgbuf)
42
+ buf.seek(0)
43
+ img = Image.open(buf).convert("RGB")
44
+
45
+ if self.transform is not None:
46
+ img = self.transform(img)
47
+
48
+ if self.target_transform is not None:
49
+ target = self.target_transform(target)
50
+
51
+ return img, target
52
+
53
+ def __len__(self) -> int:
54
+ return self.length
55
+
56
+
57
+ class LSUN(VisionDataset):
58
+ """`LSUN <https://www.yf.io/p/lsun>`_ dataset.
59
+
60
+ You will need to install the ``lmdb`` package to use this dataset: run
61
+ ``pip install lmdb``
62
+
63
+ Args:
64
+ root (str or ``pathlib.Path``): Root directory for the database files.
65
+ classes (string or list): One of {'train', 'val', 'test'} or a list of
66
+ categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
67
+ transform (callable, optional): A function/transform that takes in a PIL image
68
+ and returns a transformed version. E.g., ``transforms.RandomCrop``
69
+ target_transform (callable, optional): A function/transform that takes in the
70
+ target and transforms it.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ root: Union[str, Path],
76
+ classes: Union[str, List[str]] = "train",
77
+ transform: Optional[Callable] = None,
78
+ target_transform: Optional[Callable] = None,
79
+ ) -> None:
80
+ super().__init__(root, transform=transform, target_transform=target_transform)
81
+ self.classes = self._verify_classes(classes)
82
+
83
+ # for each class, create an LSUNClassDataset
84
+ self.dbs = []
85
+ for c in self.classes:
86
+ self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
87
+
88
+ self.indices = []
89
+ count = 0
90
+ for db in self.dbs:
91
+ count += len(db)
92
+ self.indices.append(count)
93
+
94
+ self.length = count
95
+
96
+ def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
97
+ categories = [
98
+ "bedroom",
99
+ "bridge",
100
+ "church_outdoor",
101
+ "classroom",
102
+ "conference_room",
103
+ "dining_room",
104
+ "kitchen",
105
+ "living_room",
106
+ "restaurant",
107
+ "tower",
108
+ ]
109
+ dset_opts = ["train", "val", "test"]
110
+
111
+ try:
112
+ classes = cast(str, classes)
113
+ verify_str_arg(classes, "classes", dset_opts)
114
+ if classes == "test":
115
+ classes = [classes]
116
+ else:
117
+ classes = [c + "_" + classes for c in categories]
118
+ except ValueError:
119
+ if not isinstance(classes, Iterable):
120
+ msg = "Expected type str or Iterable for argument classes, but got type {}."
121
+ raise ValueError(msg.format(type(classes)))
122
+
123
+ classes = list(classes)
124
+ msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
125
+ for c in classes:
126
+ verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
127
+ c_short = c.split("_")
128
+ category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
129
+
130
+ msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
131
+ msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
132
+ verify_str_arg(category, valid_values=categories, custom_msg=msg)
133
+
134
+ msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
135
+ verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
136
+
137
+ return classes
138
+
139
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
140
+ """
141
+ Args:
142
+ index (int): Index
143
+
144
+ Returns:
145
+ tuple: Tuple (image, target) where target is the index of the target category.
146
+ """
147
+ target = 0
148
+ sub = 0
149
+ for ind in self.indices:
150
+ if index < ind:
151
+ break
152
+ target += 1
153
+ sub = ind
154
+
155
+ db = self.dbs[target]
156
+ index = index - sub
157
+
158
+ if self.target_transform is not None:
159
+ target = self.target_transform(target)
160
+
161
+ img, _ = db[index]
162
+ return img, target
163
+
164
+ def __len__(self) -> int:
165
+ return self.length
166
+
167
+ def extra_repr(self) -> str:
168
+ return "Classes: {classes}".format(**self.__dict__)
.venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ from PIL import Image
6
+
7
+ from .utils import check_integrity, download_and_extract_archive, download_url
8
+ from .vision import VisionDataset
9
+
10
+
11
+ class SBU(VisionDataset):
12
+ """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
13
+
14
+ Args:
15
+ root (str or ``pathlib.Path``): Root directory of dataset where tarball
16
+ ``SBUCaptionedPhotoDataset.tar.gz`` exists.
17
+ transform (callable, optional): A function/transform that takes in a PIL image
18
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
19
+ target_transform (callable, optional): A function/transform that takes in the
20
+ target and transforms it.
21
+ download (bool, optional): If True, downloads the dataset from the internet and
22
+ puts it in root directory. If dataset is already downloaded, it is not
23
+ downloaded again.
24
+ """
25
+
26
+ url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
27
+ filename = "SBUCaptionedPhotoDataset.tar.gz"
28
+ md5_checksum = "9aec147b3488753cf758b4d493422285"
29
+
30
+ def __init__(
31
+ self,
32
+ root: Union[str, Path],
33
+ transform: Optional[Callable] = None,
34
+ target_transform: Optional[Callable] = None,
35
+ download: bool = True,
36
+ ) -> None:
37
+ super().__init__(root, transform=transform, target_transform=target_transform)
38
+
39
+ if download:
40
+ self.download()
41
+
42
+ if not self._check_integrity():
43
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
44
+
45
+ # Read the caption for each photo
46
+ self.photos = []
47
+ self.captions = []
48
+
49
+ file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
50
+ file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
51
+
52
+ for line1, line2 in zip(open(file1), open(file2)):
53
+ url = line1.rstrip()
54
+ photo = os.path.basename(url)
55
+ filename = os.path.join(self.root, "dataset", photo)
56
+ if os.path.exists(filename):
57
+ caption = line2.rstrip()
58
+ self.photos.append(photo)
59
+ self.captions.append(caption)
60
+
61
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
62
+ """
63
+ Args:
64
+ index (int): Index
65
+
66
+ Returns:
67
+ tuple: (image, target) where target is a caption for the photo.
68
+ """
69
+ filename = os.path.join(self.root, "dataset", self.photos[index])
70
+ img = Image.open(filename).convert("RGB")
71
+ if self.transform is not None:
72
+ img = self.transform(img)
73
+
74
+ target = self.captions[index]
75
+ if self.target_transform is not None:
76
+ target = self.target_transform(target)
77
+
78
+ return img, target
79
+
80
+ def __len__(self) -> int:
81
+ """The number of photos in the dataset."""
82
+ return len(self.photos)
83
+
84
+ def _check_integrity(self) -> bool:
85
+ """Check the md5 checksum of the downloaded tarball."""
86
+ root = self.root
87
+ fpath = os.path.join(root, self.filename)
88
+ if not check_integrity(fpath, self.md5_checksum):
89
+ return False
90
+ return True
91
+
92
+ def download(self) -> None:
93
+ """Download and extract the tarball, and download each individual photo."""
94
+
95
+ if self._check_integrity():
96
+ print("Files already downloaded and verified")
97
+ return
98
+
99
+ download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
100
+
101
+ # Download individual photos
102
+ with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
103
+ for line in fh:
104
+ url = line.rstrip()
105
+ try:
106
+ download_url(url, os.path.join(self.root, "dataset"))
107
+ except OSError:
108
+ # The images point to public images on Flickr.
109
+ # Note: Images might be removed by users at anytime.
110
+ pass
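A minimal usage sketch for the SBU wrapper above; the root path is a placeholder, and because the photos are fetched individually from Flickr, any removed images are simply absent from the index:

from torchvision.datasets import SBU

# Hedged sketch: download=True extracts the caption/url lists and then fetches
# each photo; only photos that actually exist on disk end up in the dataset.
dataset = SBU(root="data/sbu", download=True)
img, caption = dataset[0]   # caption is the raw text string for the photo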
.venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py ADDED
@@ -0,0 +1,130 @@
1
+ import os.path
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from .utils import check_integrity, download_url, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+
12
+ class SVHN(VisionDataset):
13
+ """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
14
+ Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
15
+ we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
16
+ expect the class labels to be in the range `[0, C-1]`
17
+
18
+ .. warning::
19
+
20
+ This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
21
+
22
+ Args:
23
+ root (str or ``pathlib.Path``): Root directory of the dataset where the data is stored.
24
+ split (string): One of {'train', 'test', 'extra'}.
25
+ Accordingly, the dataset is selected; 'extra' is the extra training set.
26
+ transform (callable, optional): A function/transform that takes in a PIL image
27
+ and returns a transformed version. E.g., ``transforms.RandomCrop``
28
+ target_transform (callable, optional): A function/transform that takes in the
29
+ target and transforms it.
30
+ download (bool, optional): If true, downloads the dataset from the internet and
31
+ puts it in root directory. If dataset is already downloaded, it is not
32
+ downloaded again.
33
+
34
+ """
35
+
36
+ split_list = {
37
+ "train": [
38
+ "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
39
+ "train_32x32.mat",
40
+ "e26dedcc434d2e4c54c9b2d4a06d8373",
41
+ ],
42
+ "test": [
43
+ "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
44
+ "test_32x32.mat",
45
+ "eb5a983be6a315427106f1b164d9cef3",
46
+ ],
47
+ "extra": [
48
+ "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
49
+ "extra_32x32.mat",
50
+ "a93ce644f1a588dc4d68dda5feec44a7",
51
+ ],
52
+ }
53
+
54
+ def __init__(
55
+ self,
56
+ root: Union[str, Path],
57
+ split: str = "train",
58
+ transform: Optional[Callable] = None,
59
+ target_transform: Optional[Callable] = None,
60
+ download: bool = False,
61
+ ) -> None:
62
+ super().__init__(root, transform=transform, target_transform=target_transform)
63
+ self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
64
+ self.url = self.split_list[split][0]
65
+ self.filename = self.split_list[split][1]
66
+ self.file_md5 = self.split_list[split][2]
67
+
68
+ if download:
69
+ self.download()
70
+
71
+ if not self._check_integrity():
72
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
73
+
74
+ # import here rather than at top of file because this is
75
+ # an optional dependency for torchvision
76
+ import scipy.io as sio
77
+
78
+ # reading(loading) mat file as array
79
+ loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
80
+
81
+ self.data = loaded_mat["X"]
82
+ # loading from the .mat file gives an np.ndarray of type np.uint8
83
+ # converting to np.int64, so that we have a LongTensor after
84
+ # the conversion from the numpy array
85
+ # the squeeze is needed to obtain a 1D tensor
86
+ self.labels = loaded_mat["y"].astype(np.int64).squeeze()
87
+
88
+ # the svhn dataset assigns the class label "10" to the digit 0
89
+ # this makes it inconsistent with several loss functions
90
+ # which expect the class labels to be in the range [0, C-1]
91
+ np.place(self.labels, self.labels == 10, 0)
92
+ self.data = np.transpose(self.data, (3, 2, 0, 1))
93
+
94
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
95
+ """
96
+ Args:
97
+ index (int): Index
98
+
99
+ Returns:
100
+ tuple: (image, target) where target is index of the target class.
101
+ """
102
+ img, target = self.data[index], int(self.labels[index])
103
+
104
+ # doing this so that it is consistent with all other datasets
105
+ # to return a PIL Image
106
+ img = Image.fromarray(np.transpose(img, (1, 2, 0)))
107
+
108
+ if self.transform is not None:
109
+ img = self.transform(img)
110
+
111
+ if self.target_transform is not None:
112
+ target = self.target_transform(target)
113
+
114
+ return img, target
115
+
116
+ def __len__(self) -> int:
117
+ return len(self.data)
118
+
119
+ def _check_integrity(self) -> bool:
120
+ root = self.root
121
+ md5 = self.split_list[self.split][2]
122
+ fpath = os.path.join(root, self.filename)
123
+ return check_integrity(fpath, md5)
124
+
125
+ def download(self) -> None:
126
+ md5 = self.split_list[self.split][2]
127
+ download_url(self.url, self.root, self.filename, md5)
128
+
129
+ def extra_repr(self) -> str:
130
+ return "Split: {split}".format(**self.__dict__)
.venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py ADDED
@@ -0,0 +1,197 @@
1
+ import os
2
+ from os.path import abspath, expanduser
3
+ from pathlib import Path
4
+
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
11
+ from .vision import VisionDataset
12
+
13
+
14
+ class WIDERFace(VisionDataset):
15
+ """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
16
+
17
+ Args:
18
+ root (str or ``pathlib.Path``): Root directory where images and annotations are downloaded to.
19
+ Expects the following folder structure if download=False:
20
+
21
+ .. code::
22
+
23
+ <root>
24
+ └── widerface
25
+ ├── wider_face_split ('wider_face_split.zip' if compressed)
26
+ ├── WIDER_train ('WIDER_train.zip' if compressed)
27
+ ├── WIDER_val ('WIDER_val.zip' if compressed)
28
+ └── WIDER_test ('WIDER_test.zip' if compressed)
29
+ split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
30
+ Defaults to ``train``.
31
+ transform (callable, optional): A function/transform that takes in a PIL image
32
+ and returns a transformed version. E.g., ``transforms.RandomCrop``
33
+ target_transform (callable, optional): A function/transform that takes in the
34
+ target and transforms it.
35
+ download (bool, optional): If true, downloads the dataset from the internet and
36
+ puts it in root directory. If dataset is already downloaded, it is not
37
+ downloaded again.
38
+
39
+ .. warning::
40
+
41
+ To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
42
+
43
+ """
44
+
45
+ BASE_FOLDER = "widerface"
46
+ FILE_LIST = [
47
+ # File ID MD5 Hash Filename
48
+ ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
49
+ ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
50
+ ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
51
+ ]
52
+ ANNOTATIONS_FILE = (
53
+ "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
54
+ "0e3767bcf0e326556d407bf5bff5d27c",
55
+ "wider_face_split.zip",
56
+ )
57
+
58
+ def __init__(
59
+ self,
60
+ root: Union[str, Path],
61
+ split: str = "train",
62
+ transform: Optional[Callable] = None,
63
+ target_transform: Optional[Callable] = None,
64
+ download: bool = False,
65
+ ) -> None:
66
+ super().__init__(
67
+ root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
68
+ )
69
+ # check arguments
70
+ self.split = verify_str_arg(split, "split", ("train", "val", "test"))
71
+
72
+ if download:
73
+ self.download()
74
+
75
+ if not self._check_integrity():
76
+ raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
77
+
78
+ self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
79
+ if self.split in ("train", "val"):
80
+ self.parse_train_val_annotations_file()
81
+ else:
82
+ self.parse_test_annotations_file()
83
+
84
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
85
+ """
86
+ Args:
87
+ index (int): Index
88
+
89
+ Returns:
90
+ tuple: (image, target) where target is a dict of annotations for all faces in the image.
91
+ target=None for the test split.
92
+ """
93
+
94
+ # stay consistent with other datasets and return a PIL Image
95
+ img = Image.open(self.img_info[index]["img_path"]) # type: ignore[arg-type]
96
+
97
+ if self.transform is not None:
98
+ img = self.transform(img)
99
+
100
+ target = None if self.split == "test" else self.img_info[index]["annotations"]
101
+ if self.target_transform is not None:
102
+ target = self.target_transform(target)
103
+
104
+ return img, target
105
+
106
+ def __len__(self) -> int:
107
+ return len(self.img_info)
108
+
109
+ def extra_repr(self) -> str:
110
+ lines = ["Split: {split}"]
111
+ return "\n".join(lines).format(**self.__dict__)
112
+
113
+ def parse_train_val_annotations_file(self) -> None:
114
+ filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
115
+ filepath = os.path.join(self.root, "wider_face_split", filename)
116
+
117
+ with open(filepath) as f:
118
+ lines = f.readlines()
119
+ file_name_line, num_boxes_line, box_annotation_line = True, False, False
120
+ num_boxes, box_counter = 0, 0
121
+ labels = []
122
+ for line in lines:
123
+ line = line.rstrip()
124
+ if file_name_line:
125
+ img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
126
+ img_path = abspath(expanduser(img_path))
127
+ file_name_line = False
128
+ num_boxes_line = True
129
+ elif num_boxes_line:
130
+ num_boxes = int(line)
131
+ num_boxes_line = False
132
+ box_annotation_line = True
133
+ elif box_annotation_line:
134
+ box_counter += 1
135
+ line_split = line.split(" ")
136
+ line_values = [int(x) for x in line_split]
137
+ labels.append(line_values)
138
+ if box_counter >= num_boxes:
139
+ box_annotation_line = False
140
+ file_name_line = True
141
+ labels_tensor = torch.tensor(labels)
142
+ self.img_info.append(
143
+ {
144
+ "img_path": img_path,
145
+ "annotations": {
146
+ "bbox": labels_tensor[:, 0:4].clone(), # x, y, width, height
147
+ "blur": labels_tensor[:, 4].clone(),
148
+ "expression": labels_tensor[:, 5].clone(),
149
+ "illumination": labels_tensor[:, 6].clone(),
150
+ "occlusion": labels_tensor[:, 7].clone(),
151
+ "pose": labels_tensor[:, 8].clone(),
152
+ "invalid": labels_tensor[:, 9].clone(),
153
+ },
154
+ }
155
+ )
156
+ box_counter = 0
157
+ labels.clear()
158
+ else:
159
+ raise RuntimeError(f"Error parsing annotation file {filepath}")
160
+
161
+ def parse_test_annotations_file(self) -> None:
162
+ filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
163
+ filepath = abspath(expanduser(filepath))
164
+ with open(filepath) as f:
165
+ lines = f.readlines()
166
+ for line in lines:
167
+ line = line.rstrip()
168
+ img_path = os.path.join(self.root, "WIDER_test", "images", line)
169
+ img_path = abspath(expanduser(img_path))
170
+ self.img_info.append({"img_path": img_path})
171
+
172
+ def _check_integrity(self) -> bool:
173
+ # Allow original archive to be deleted (zip). Only need the extracted images
174
+ all_files = self.FILE_LIST.copy()
175
+ all_files.append(self.ANNOTATIONS_FILE)
176
+ for (_, md5, filename) in all_files:
177
+ file, ext = os.path.splitext(filename)
178
+ extracted_dir = os.path.join(self.root, file)
179
+ if not os.path.exists(extracted_dir):
180
+ return False
181
+ return True
182
+
183
+ def download(self) -> None:
184
+ if self._check_integrity():
185
+ print("Files already downloaded and verified")
186
+ return
187
+
188
+ # download and extract image data
189
+ for (file_id, md5, filename) in self.FILE_LIST:
190
+ download_file_from_google_drive(file_id, self.root, filename, md5)
191
+ filepath = os.path.join(self.root, filename)
192
+ extract_archive(filepath)
193
+
194
+ # download and extract annotation files
195
+ download_and_extract_archive(
196
+ url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
197
+ )
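A minimal usage sketch for the WIDERFace wrapper above; ``gdown`` is needed for the Google Drive downloads, and the root path is a placeholder:

from torchvision.datasets import WIDERFace

# Hedged sketch: for the train/val splits the target is a dict of per-face
# tensors (bbox in x, y, w, h plus blur/expression/illumination/occlusion/
# pose/invalid flags); for the test split the target is None.
faces = WIDERFace(root="data", split="val", download=True)
img, target = faces[0]
print(target["bbox"].shape)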
.venv/lib/python3.11/site-packages/torchvision/io/__init__.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import Any, Dict, Iterator
2
+
3
+ import torch
4
+
5
+ from ..utils import _log_api_usage_once
6
+
7
+ try:
8
+ from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
9
+ except ModuleNotFoundError:
10
+ _HAS_GPU_VIDEO_DECODER = False
11
+
12
+ from ._video_opt import (
13
+ _HAS_CPU_VIDEO_DECODER,
14
+ _HAS_VIDEO_OPT,
15
+ _probe_video_from_file,
16
+ _probe_video_from_memory,
17
+ _read_video_from_file,
18
+ _read_video_from_memory,
19
+ _read_video_timestamps_from_file,
20
+ _read_video_timestamps_from_memory,
21
+ Timebase,
22
+ VideoMetaData,
23
+ )
24
+ from .image import (
25
+ decode_gif,
26
+ decode_image,
27
+ decode_jpeg,
28
+ decode_png,
29
+ decode_webp,
30
+ encode_jpeg,
31
+ encode_png,
32
+ ImageReadMode,
33
+ read_file,
34
+ read_image,
35
+ write_file,
36
+ write_jpeg,
37
+ write_png,
38
+ )
39
+ from .video import read_video, read_video_timestamps, write_video
40
+ from .video_reader import VideoReader
41
+
42
+
43
+ __all__ = [
44
+ "write_video",
45
+ "read_video",
46
+ "read_video_timestamps",
47
+ "_read_video_from_file",
48
+ "_read_video_timestamps_from_file",
49
+ "_probe_video_from_file",
50
+ "_read_video_from_memory",
51
+ "_read_video_timestamps_from_memory",
52
+ "_probe_video_from_memory",
53
+ "_HAS_CPU_VIDEO_DECODER",
54
+ "_HAS_VIDEO_OPT",
55
+ "_HAS_GPU_VIDEO_DECODER",
56
+ "_read_video_clip_from_memory",
57
+ "_read_video_meta_data",
58
+ "VideoMetaData",
59
+ "Timebase",
60
+ "ImageReadMode",
61
+ "decode_image",
62
+ "decode_jpeg",
63
+ "decode_png",
64
+ "decode_heic",
65
+ "decode_webp",
66
+ "decode_gif",
67
+ "encode_jpeg",
68
+ "encode_png",
69
+ "read_file",
70
+ "read_image",
71
+ "write_file",
72
+ "write_jpeg",
73
+ "write_png",
74
+ "Video",
75
+ "VideoReader",
76
+ ]
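The module above only re-exports the image and video entry points; a hedged sketch of typical calls through ``torchvision.io`` follows (file names are placeholders):

import torchvision.io as io

# Hedged sketch: decode an image from raw bytes, write it back out as PNG,
# and read a short video clip with timestamps expressed in seconds.
img = io.decode_image(io.read_file("example.jpg"), mode=io.ImageReadMode.RGB)
io.write_png(img, "out.png")
frames, audio, meta = io.read_video("clip.mp4", pts_unit="sec")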
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc ADDED
Binary file (464 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc ADDED
Binary file (23.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc ADDED
Binary file (24.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc ADDED
Binary file (21.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py ADDED
@@ -0,0 +1,8 @@
1
+ from ..extension import _load_library
2
+
3
+
4
+ try:
5
+ _load_library("gpu_decoder")
6
+ _HAS_GPU_VIDEO_DECODER = True
7
+ except (ImportError, OSError):
8
+ _HAS_GPU_VIDEO_DECODER = False
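The flag set above is what downstream code can use to feature-detect the optional CUDA video decoder; a minimal sketch:

from torchvision.io import _HAS_GPU_VIDEO_DECODER

# Hedged sketch: fall back to the CPU backend when the GPU decoder extension
# was not built into this torchvision installation.
if _HAS_GPU_VIDEO_DECODER:
    print("CUDA video decoder available")
else:
    print("falling back to the CPU video_reader / pyav backends")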
.venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py ADDED
@@ -0,0 +1,513 @@
1
+ import math
2
+ import warnings
3
+ from fractions import Fraction
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+
8
+ from ..extension import _load_library
9
+
10
+
11
+ try:
12
+ _load_library("video_reader")
13
+ _HAS_CPU_VIDEO_DECODER = True
14
+ except (ImportError, OSError):
15
+ _HAS_CPU_VIDEO_DECODER = False
16
+
17
+ _HAS_VIDEO_OPT = _HAS_CPU_VIDEO_DECODER # For BC
18
+ default_timebase = Fraction(0, 1)
19
+
20
+
21
+ # simple class for torch scripting
22
+ # the complex Fraction class from fractions module is not scriptable
23
+ class Timebase:
24
+ __annotations__ = {"numerator": int, "denominator": int}
25
+ __slots__ = ["numerator", "denominator"]
26
+
27
+ def __init__(
28
+ self,
29
+ numerator: int,
30
+ denominator: int,
31
+ ) -> None:
32
+ self.numerator = numerator
33
+ self.denominator = denominator
34
+
35
+
36
+ class VideoMetaData:
37
+ __annotations__ = {
38
+ "has_video": bool,
39
+ "video_timebase": Timebase,
40
+ "video_duration": float,
41
+ "video_fps": float,
42
+ "has_audio": bool,
43
+ "audio_timebase": Timebase,
44
+ "audio_duration": float,
45
+ "audio_sample_rate": float,
46
+ }
47
+ __slots__ = [
48
+ "has_video",
49
+ "video_timebase",
50
+ "video_duration",
51
+ "video_fps",
52
+ "has_audio",
53
+ "audio_timebase",
54
+ "audio_duration",
55
+ "audio_sample_rate",
56
+ ]
57
+
58
+ def __init__(self) -> None:
59
+ self.has_video = False
60
+ self.video_timebase = Timebase(0, 1)
61
+ self.video_duration = 0.0
62
+ self.video_fps = 0.0
63
+ self.has_audio = False
64
+ self.audio_timebase = Timebase(0, 1)
65
+ self.audio_duration = 0.0
66
+ self.audio_sample_rate = 0.0
67
+
68
+
69
+ def _validate_pts(pts_range: Tuple[int, int]) -> None:
70
+
71
+ if pts_range[0] > pts_range[1] > 0:
72
+ raise ValueError(
73
+ f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
74
+ )
75
+
76
+
77
+ def _fill_info(
78
+ vtimebase: torch.Tensor,
79
+ vfps: torch.Tensor,
80
+ vduration: torch.Tensor,
81
+ atimebase: torch.Tensor,
82
+ asample_rate: torch.Tensor,
83
+ aduration: torch.Tensor,
84
+ ) -> VideoMetaData:
85
+ """
86
+ Build and populate a VideoMetaData struct with info about the video
87
+ """
88
+ meta = VideoMetaData()
89
+ if vtimebase.numel() > 0:
90
+ meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
91
+ timebase = vtimebase[0].item() / float(vtimebase[1].item())
92
+ if vduration.numel() > 0:
93
+ meta.has_video = True
94
+ meta.video_duration = float(vduration.item()) * timebase
95
+ if vfps.numel() > 0:
96
+ meta.video_fps = float(vfps.item())
97
+ if atimebase.numel() > 0:
98
+ meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
99
+ timebase = atimebase[0].item() / float(atimebase[1].item())
100
+ if aduration.numel() > 0:
101
+ meta.has_audio = True
102
+ meta.audio_duration = float(aduration.item()) * timebase
103
+ if asample_rate.numel() > 0:
104
+ meta.audio_sample_rate = float(asample_rate.item())
105
+
106
+ return meta
107
+
108
+
109
+ def _align_audio_frames(
110
+ aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
111
+ ) -> torch.Tensor:
112
+ start, end = aframe_pts[0], aframe_pts[-1]
113
+ num_samples = aframes.size(0)
114
+ step_per_aframe = float(end - start + 1) / float(num_samples)
115
+ s_idx = 0
116
+ e_idx = num_samples
117
+ if start < audio_pts_range[0]:
118
+ s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
119
+ if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
120
+ e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
121
+ return aframes[s_idx:e_idx, :]
122
+
123
+
124
+ def _read_video_from_file(
125
+ filename: str,
126
+ seek_frame_margin: float = 0.25,
127
+ read_video_stream: bool = True,
128
+ video_width: int = 0,
129
+ video_height: int = 0,
130
+ video_min_dimension: int = 0,
131
+ video_max_dimension: int = 0,
132
+ video_pts_range: Tuple[int, int] = (0, -1),
133
+ video_timebase: Fraction = default_timebase,
134
+ read_audio_stream: bool = True,
135
+ audio_samples: int = 0,
136
+ audio_channels: int = 0,
137
+ audio_pts_range: Tuple[int, int] = (0, -1),
138
+ audio_timebase: Fraction = default_timebase,
139
+ ) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
140
+ """
141
+ Reads a video from a file, returning both the video frames and the audio frames
142
+
143
+ Args:
144
+ filename (str): path to the video file
145
+ seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
146
+ when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
147
+ read_video_stream (int, optional): whether to read the video stream. If yes, set to 1. Otherwise, 0
148
+ video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
149
+ the size of decoded frames:
150
+
151
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
152
+ and video_max_dimension = 0, keep the original frame resolution
153
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
154
+ and video_max_dimension = 0, keep the aspect ratio and resize the
155
+ frame so that shorter edge size is video_min_dimension
156
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
157
+ and video_max_dimension != 0, keep the aspect ratio and resize
158
+ the frame so that longer edge size is video_max_dimension
159
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
160
+ and video_max_dimension != 0, resize the frame so that shorter
161
+ edge size is video_min_dimension, and longer edge size is
162
+ video_max_dimension. The aspect ratio may not be preserved
163
+ - When video_width = 0, video_height != 0, video_min_dimension = 0,
164
+ and video_max_dimension = 0, keep the aspect ratio and resize
165
+ the frame so that frame video_height is $video_height
166
+ - When video_width != 0, video_height == 0, video_min_dimension = 0,
167
+ and video_max_dimension = 0, keep the aspect ratio and resize
168
+ the frame so that frame video_width is $video_width
169
+ - When video_width != 0, video_height != 0, video_min_dimension = 0,
170
+ and video_max_dimension = 0, resize the frame so that frame
171
+ video_width and video_height are set to $video_width and
172
+ $video_height, respectively
173
+ video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
174
+ video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
175
+ read_audio_stream (int, optional): whether to read the audio stream. If yes, set to 1. Otherwise, 0
176
+ audio_samples (int, optional): audio sampling rate
177
+ audio_channels (int, optional): number of audio channels
178
+ audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
179
+ audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream
180
+
181
+ Returns:
182
+ vframes (Tensor[T, H, W, C]): the `T` video frames
183
+ aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
184
+ `K` is the number of audio_channels
185
+ info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
186
+ and audio_fps (int)
187
+ """
188
+ _validate_pts(video_pts_range)
189
+ _validate_pts(audio_pts_range)
190
+
191
+ result = torch.ops.video_reader.read_video_from_file(
192
+ filename,
193
+ seek_frame_margin,
194
+ 0, # getPtsOnly
195
+ read_video_stream,
196
+ video_width,
197
+ video_height,
198
+ video_min_dimension,
199
+ video_max_dimension,
200
+ video_pts_range[0],
201
+ video_pts_range[1],
202
+ video_timebase.numerator,
203
+ video_timebase.denominator,
204
+ read_audio_stream,
205
+ audio_samples,
206
+ audio_channels,
207
+ audio_pts_range[0],
208
+ audio_pts_range[1],
209
+ audio_timebase.numerator,
210
+ audio_timebase.denominator,
211
+ )
212
+ vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
213
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
214
+ if aframes.numel() > 0:
215
+ # when audio stream is found
216
+ aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
217
+ return vframes, aframes, info
218
+
219
+
220
+ def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
221
+ """
222
+ Decode all video- and audio frames in the video. Only pts
223
+ (presentation timestamp) is returned. The actual frame pixel data is not
224
+ copied. Thus, it is much faster than read_video(...)
225
+ """
226
+ result = torch.ops.video_reader.read_video_from_file(
227
+ filename,
228
+ 0, # seek_frame_margin
229
+ 1, # getPtsOnly
230
+ 1, # read_video_stream
231
+ 0, # video_width
232
+ 0, # video_height
233
+ 0, # video_min_dimension
234
+ 0, # video_max_dimension
235
+ 0, # video_start_pts
236
+ -1, # video_end_pts
237
+ 0, # video_timebase_num
238
+ 1, # video_timebase_den
239
+ 1, # read_audio_stream
240
+ 0, # audio_samples
241
+ 0, # audio_channels
242
+ 0, # audio_start_pts
243
+ -1, # audio_end_pts
244
+ 0, # audio_timebase_num
245
+ 1, # audio_timebase_den
246
+ )
247
+ _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
248
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
249
+
250
+ vframe_pts = vframe_pts.numpy().tolist()
251
+ aframe_pts = aframe_pts.numpy().tolist()
252
+ return vframe_pts, aframe_pts, info
253
+
254
+
255
+ def _probe_video_from_file(filename: str) -> VideoMetaData:
256
+ """
257
+ Probe a video file and return VideoMetaData with info about the video
258
+ """
259
+ result = torch.ops.video_reader.probe_video_from_file(filename)
260
+ vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
261
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
262
+ return info
263
+
264
+
265
+ def _read_video_from_memory(
266
+ video_data: torch.Tensor,
267
+ seek_frame_margin: float = 0.25,
268
+ read_video_stream: int = 1,
269
+ video_width: int = 0,
270
+ video_height: int = 0,
271
+ video_min_dimension: int = 0,
272
+ video_max_dimension: int = 0,
273
+ video_pts_range: Tuple[int, int] = (0, -1),
274
+ video_timebase_numerator: int = 0,
275
+ video_timebase_denominator: int = 1,
276
+ read_audio_stream: int = 1,
277
+ audio_samples: int = 0,
278
+ audio_channels: int = 0,
279
+ audio_pts_range: Tuple[int, int] = (0, -1),
280
+ audio_timebase_numerator: int = 0,
281
+ audio_timebase_denominator: int = 1,
282
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
283
+ """
284
+ Reads a video from memory, returning both the video frames and the audio frames.
285
+ This function is torchscriptable.
286
+
287
+ Args:
288
+ video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
289
+ compressed video content stored in either 1) torch.Tensor 2) python bytes
290
+ seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
291
+ Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
292
+ read_video_stream (int, optional): whether to read the video stream. If yes, set to 1. Otherwise, 0
293
+ video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
294
+ the size of decoded frames:
295
+
296
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
297
+ and video_max_dimension = 0, keep the original frame resolution
298
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
299
+ and video_max_dimension = 0, keep the aspect ratio and resize the
300
+ frame so that shorter edge size is video_min_dimension
301
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
302
+ and video_max_dimension != 0, keep the aspect ratio and resize
303
+ the frame so that longer edge size is video_max_dimension
304
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
305
+ and video_max_dimension != 0, resize the frame so that shorter
306
+ edge size is video_min_dimension, and longer edge size is
307
+ video_max_dimension. The aspect ratio may not be preserved
308
+ - When video_width = 0, video_height != 0, video_min_dimension = 0,
309
+ and video_max_dimension = 0, keep the aspect ratio and resize
310
+ the frame so that frame video_height is $video_height
311
+ - When video_width != 0, video_height == 0, video_min_dimension = 0,
312
+ and video_max_dimension = 0, keep the aspect ratio and resize
313
+ the frame so that frame video_width is $video_width
314
+ - When video_width != 0, video_height != 0, video_min_dimension = 0,
315
+ and video_max_dimension = 0, resize the frame so that frame
316
+ video_width and video_height are set to $video_width and
317
+ $video_height, respectively
318
+ video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
319
+ video_timebase_numerator / video_timebase_denominator (float, optional): a rational
320
+ number which denotes timebase in video stream
321
+ read_audio_stream (int, optional): whether to read the audio stream. If yes, set to 1. Otherwise, 0
322
+ audio_samples (int, optional): audio sampling rate
323
+ audio_channels (int, optional): number of audio channels
324
+ audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
325
+ audio_timebase_numerator / audio_timebase_denominator (float, optional):
326
+ a rational number which denotes time base in audio stream
327
+
328
+ Returns:
329
+ vframes (Tensor[T, H, W, C]): the `T` video frames
330
+ aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
331
+ `K` is the number of channels
332
+ """
333
+
334
+ _validate_pts(video_pts_range)
335
+ _validate_pts(audio_pts_range)
336
+
337
+ if not isinstance(video_data, torch.Tensor):
338
+ with warnings.catch_warnings():
339
+ # Ignore the warning because we actually don't modify the buffer in this function
340
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
341
+ video_data = torch.frombuffer(video_data, dtype=torch.uint8)
342
+
343
+ result = torch.ops.video_reader.read_video_from_memory(
344
+ video_data,
345
+ seek_frame_margin,
346
+ 0, # getPtsOnly
347
+ read_video_stream,
348
+ video_width,
349
+ video_height,
350
+ video_min_dimension,
351
+ video_max_dimension,
352
+ video_pts_range[0],
353
+ video_pts_range[1],
354
+ video_timebase_numerator,
355
+ video_timebase_denominator,
356
+ read_audio_stream,
357
+ audio_samples,
358
+ audio_channels,
359
+ audio_pts_range[0],
360
+ audio_pts_range[1],
361
+ audio_timebase_numerator,
362
+ audio_timebase_denominator,
363
+ )
364
+
365
+ vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
366
+
367
+ if aframes.numel() > 0:
368
+ # when audio stream is found
369
+ aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
370
+
371
+ return vframes, aframes
372
+
373
+
374
+ def _read_video_timestamps_from_memory(
375
+ video_data: torch.Tensor,
376
+ ) -> Tuple[List[int], List[int], VideoMetaData]:
377
+ """
378
+ Decode all frames in the video. Only pts (presentation timestamp) is returned.
379
+ The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
380
+ is much faster than read_video(...)
381
+ """
382
+ if not isinstance(video_data, torch.Tensor):
383
+ with warnings.catch_warnings():
384
+ # Ignore the warning because we actually don't modify the buffer in this function
385
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
386
+ video_data = torch.frombuffer(video_data, dtype=torch.uint8)
387
+ result = torch.ops.video_reader.read_video_from_memory(
388
+ video_data,
389
+ 0, # seek_frame_margin
390
+ 1, # getPtsOnly
391
+ 1, # read_video_stream
392
+ 0, # video_width
393
+ 0, # video_height
394
+ 0, # video_min_dimension
395
+ 0, # video_max_dimension
396
+ 0, # video_start_pts
397
+ -1, # video_end_pts
398
+ 0, # video_timebase_num
399
+ 1, # video_timebase_den
400
+ 1, # read_audio_stream
401
+ 0, # audio_samples
402
+ 0, # audio_channels
403
+ 0, # audio_start_pts
404
+ -1, # audio_end_pts
405
+ 0, # audio_timebase_num
406
+ 1, # audio_timebase_den
407
+ )
408
+ _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
409
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
410
+
411
+ vframe_pts = vframe_pts.numpy().tolist()
412
+ aframe_pts = aframe_pts.numpy().tolist()
413
+ return vframe_pts, aframe_pts, info
414
+
415
+
416
+ def _probe_video_from_memory(
417
+ video_data: torch.Tensor,
418
+ ) -> VideoMetaData:
419
+ """
420
+ Probe a video in memory and return VideoMetaData with info about the video
421
+ This function is torchscriptable
422
+ """
423
+ if not isinstance(video_data, torch.Tensor):
424
+ with warnings.catch_warnings():
425
+ # Ignore the warning because we actually don't modify the buffer in this function
426
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
427
+ video_data = torch.frombuffer(video_data, dtype=torch.uint8)
428
+ result = torch.ops.video_reader.probe_video_from_memory(video_data)
429
+ vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
430
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
431
+ return info
432
+
433
+
434
+ def _read_video(
435
+ filename: str,
436
+ start_pts: Union[float, Fraction] = 0,
437
+ end_pts: Optional[Union[float, Fraction]] = None,
438
+ pts_unit: str = "pts",
439
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
440
+ if end_pts is None:
441
+ end_pts = float("inf")
442
+
443
+ if pts_unit == "pts":
444
+ warnings.warn(
445
+ "The pts_unit 'pts' gives wrong results and will be removed in a "
446
+ + "follow-up version. Please use pts_unit 'sec'."
447
+ )
448
+
449
+ info = _probe_video_from_file(filename)
450
+
451
+ has_video = info.has_video
452
+ has_audio = info.has_audio
453
+
454
+ def get_pts(time_base):
455
+ start_offset = start_pts
456
+ end_offset = end_pts
457
+ if pts_unit == "sec":
458
+ start_offset = int(math.floor(start_pts * (1 / time_base)))
459
+ if end_offset != float("inf"):
460
+ end_offset = int(math.ceil(end_pts * (1 / time_base)))
461
+ if end_offset == float("inf"):
462
+ end_offset = -1
463
+ return start_offset, end_offset
464
+
465
+ video_pts_range = (0, -1)
466
+ video_timebase = default_timebase
467
+ if has_video:
468
+ video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
469
+ video_pts_range = get_pts(video_timebase)
470
+
471
+ audio_pts_range = (0, -1)
472
+ audio_timebase = default_timebase
473
+ if has_audio:
474
+ audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
475
+ audio_pts_range = get_pts(audio_timebase)
476
+
477
+ vframes, aframes, info = _read_video_from_file(
478
+ filename,
479
+ read_video_stream=True,
480
+ video_pts_range=video_pts_range,
481
+ video_timebase=video_timebase,
482
+ read_audio_stream=True,
483
+ audio_pts_range=audio_pts_range,
484
+ audio_timebase=audio_timebase,
485
+ )
486
+ _info = {}
487
+ if has_video:
488
+ _info["video_fps"] = info.video_fps
489
+ if has_audio:
490
+ _info["audio_fps"] = info.audio_sample_rate
491
+
492
+ return vframes, aframes, _info
493
+
494
+
495
+ def _read_video_timestamps(
496
+ filename: str, pts_unit: str = "pts"
497
+ ) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
498
+ if pts_unit == "pts":
499
+ warnings.warn(
500
+ "The pts_unit 'pts' gives wrong results and will be removed in a "
501
+ + "follow-up version. Please use pts_unit 'sec'."
502
+ )
503
+
504
+ pts: Union[List[int], List[Fraction]]
505
+ pts, _, info = _read_video_timestamps_from_file(filename)
506
+
507
+ if pts_unit == "sec":
508
+ video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
509
+ pts = [x * video_time_base for x in pts]
510
+
511
+ video_fps = info.video_fps if info.has_video else None
512
+
513
+ return pts, video_fps
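The helpers above back the ``video_reader`` backend; user code is expected to go through the public wrappers instead, roughly as sketched below (``set_video_backend`` is part of the top-level ``torchvision`` API, and the clip path is a placeholder):

import torchvision
from torchvision.io import read_video, read_video_timestamps

# Hedged sketch: selecting the backend fails if the video_reader extension
# could not be loaded (_HAS_CPU_VIDEO_DECODER above would then be False).
torchvision.set_video_backend("video_reader")
pts, fps = read_video_timestamps("clip.mp4", pts_unit="sec")
frames, audio, info = read_video("clip.mp4", start_pts=0, end_pts=2, pts_unit="sec")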
.venv/lib/python3.11/site-packages/torchvision/io/image.py ADDED
@@ -0,0 +1,436 @@
1
+ from enum import Enum
2
+ from typing import List, Union
3
+ from warnings import warn
4
+
5
+ import torch
6
+
7
+ from ..extension import _load_library
8
+ from ..utils import _log_api_usage_once
9
+
10
+
11
+ try:
12
+ _load_library("image")
13
+ except (ImportError, OSError) as e:
14
+ warn(
15
+ f"Failed to load image Python extension: '{e}'"
16
+ f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
17
+ f"Otherwise, there might be something wrong with your environment. "
18
+ f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
19
+ )
20
+
21
+
22
+ class ImageReadMode(Enum):
23
+ """Allow automatic conversion to RGB, RGBA, etc while decoding.
24
+
25
+ .. note::
26
+
27
+ You don't need to use this struct, you can just pass strings to all
28
+ ``mode`` parameters, e.g. ``mode="RGB"``.
29
+
30
+ The different available modes are the following.
31
+
32
+ - UNCHANGED: loads the image as-is
33
+ - RGB: converts to RGB
34
+ - RGBA: converts to RGB with transparency (also aliased as RGB_ALPHA)
35
+ - GRAY: converts to grayscale
36
+ - GRAY_ALPHA: converts to grayscale with transparency
37
+
38
+ .. note::
39
+
40
+ Some decoders won't support all possible values, e.g. GRAY and
41
+ GRAY_ALPHA are only supported for PNG and JPEG images.
42
+ """
43
+
44
+ UNCHANGED = 0
45
+ GRAY = 1
46
+ GRAY_ALPHA = 2
47
+ RGB = 3
48
+ RGB_ALPHA = 4
49
+ RGBA = RGB_ALPHA # Alias for convenience
50
+
51
+
52
+ def read_file(path: str) -> torch.Tensor:
53
+ """
54
+ Return the bytes contents of a file as a uint8 1D Tensor.
55
+
56
+ Args:
57
+ path (str or ``pathlib.Path``): the path to the file to be read
58
+
59
+ Returns:
60
+ data (Tensor)
61
+ """
62
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
63
+ _log_api_usage_once(read_file)
64
+ data = torch.ops.image.read_file(str(path))
65
+ return data
66
+
67
+
68
+ def write_file(filename: str, data: torch.Tensor) -> None:
69
+ """
70
+ Write the content of an uint8 1D tensor to a file.
71
+
72
+ Args:
73
+ filename (str or ``pathlib.Path``): the path to the file to be written
74
+ data (Tensor): the contents to be written to the output file
75
+ """
76
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
77
+ _log_api_usage_once(write_file)
78
+ torch.ops.image.write_file(str(filename), data)
79
+
80
+
81
+ def decode_png(
82
+ input: torch.Tensor,
83
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
84
+ apply_exif_orientation: bool = False,
85
+ ) -> torch.Tensor:
86
+ """
87
+ Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
88
+
89
+ The values of the output tensor are in uint8 in [0, 255] for most cases. If
90
+ the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
91
+ (supported from torchvision ``0.21``). Since uint16 support is limited in
92
+ pytorch, we recommend calling
93
+ :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
94
+ after this function to convert the decoded image into a uint8 or float
95
+ tensor.
96
+
97
+ Args:
98
+ input (Tensor[1]): a one dimensional uint8 tensor containing
99
+ the raw bytes of the PNG image.
100
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
101
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
102
+ for available modes.
103
+ apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
104
+ Default: False.
105
+
106
+ Returns:
107
+ output (Tensor[image_channels, image_height, image_width])
108
+ """
109
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
110
+ _log_api_usage_once(decode_png)
111
+ if isinstance(mode, str):
112
+ mode = ImageReadMode[mode.upper()]
113
+ output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
114
+ return output
115
+
116
+
117
+ def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
118
+ """
119
+ Takes an input tensor in CHW layout and returns a buffer with the contents
120
+ of its corresponding PNG file.
121
+
122
+ Args:
123
+ input (Tensor[channels, image_height, image_width]): int8 image tensor of
124
+ ``c`` channels, where ``c`` must 3 or 1.
125
+ compression_level (int): Compression factor for the resulting file, it must be a number
126
+ between 0 and 9. Default: 6
127
+
128
+ Returns:
129
+ Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
130
+ PNG file.
131
+ """
132
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
133
+ _log_api_usage_once(encode_png)
134
+ output = torch.ops.image.encode_png(input, compression_level)
135
+ return output
136
+
137
+
138
+ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
139
+ """
140
+ Takes an input tensor in CHW layout (or HW in the case of grayscale images)
141
+ and saves it in a PNG file.
142
+
143
+ Args:
144
+ input (Tensor[channels, image_height, image_width]): int8 image tensor of
145
+ ``c`` channels, where ``c`` must be 1 or 3.
146
+ filename (str or ``pathlib.Path``): Path to save the image.
147
+ compression_level (int): Compression factor for the resulting file, it must be a number
148
+ between 0 and 9. Default: 6
149
+ """
150
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
151
+ _log_api_usage_once(write_png)
152
+ output = encode_png(input, compression_level)
153
+ write_file(filename, output)
154
+
155
+
156
+ def decode_jpeg(
157
+ input: Union[torch.Tensor, List[torch.Tensor]],
158
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
159
+ device: Union[str, torch.device] = "cpu",
160
+ apply_exif_orientation: bool = False,
161
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
162
+ """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA.
163
+
164
+ The values of the output tensor are uint8 between 0 and 255.
165
+
166
+ .. note::
167
+ When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``.
168
+ When using CPU the performance is equivalent.
169
+ The CUDA version of this function has explicitly been designed with thread-safety in mind.
170
+ This function does not return partial results in case of an error.
171
+
172
+ Args:
173
+ input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
174
+ the raw bytes of the JPEG image. The tensor(s) must be on CPU,
175
+ regardless of the ``device`` parameter.
176
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
177
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
178
+ for available modes.
179
+ device (str or torch.device): The device on which the decoded image will
180
+ be stored. If a cuda device is specified, the image will be decoded
181
+ with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
182
+ supported for CUDA version >= 10.1
183
+
184
+ .. betastatus:: device parameter
185
+
186
+ .. warning::
187
+ There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
188
+ Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
189
+ apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
190
+ Default: False. Only implemented for JPEG format on CPU.
191
+
192
+ Returns:
193
+ output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
194
+ The values of the output tensor(s) are uint8 between 0 and 255.
195
+ ``output.device`` will be set to the specified ``device``
196
+
197
+
198
+ """
199
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
200
+ _log_api_usage_once(decode_jpeg)
201
+ if isinstance(device, str):
202
+ device = torch.device(device)
203
+ if isinstance(mode, str):
204
+ mode = ImageReadMode[mode.upper()]
205
+
206
+ if isinstance(input, list):
207
+ if len(input) == 0:
208
+ raise ValueError("Input list must contain at least one element")
209
+ if not all(isinstance(t, torch.Tensor) for t in input):
210
+ raise ValueError("All elements of the input list must be tensors.")
211
+ if not all(t.device.type == "cpu" for t in input):
212
+ raise ValueError("Input list must contain tensors on CPU.")
213
+ if device.type == "cuda":
214
+ return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
215
+ else:
216
+ return [torch.ops.image.decode_jpeg(img, mode.value, apply_exif_orientation) for img in input]
217
+
218
+ else: # input is tensor
219
+ if input.device.type != "cpu":
220
+ raise ValueError("Input tensor must be a CPU tensor")
221
+ if device.type == "cuda":
222
+ return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
223
+ else:
224
+ return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
225
+
226
+
227
+ def encode_jpeg(
228
+ input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75
229
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
230
+ """Encode RGB tensor(s) into raw encoded jpeg bytes, on CPU or CUDA.
231
+
232
+ .. note::
233
+ Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
234
+ For CPU tensors the performance is equivalent.
235
+
236
+ Args:
237
+ input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
238
+ (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
239
+ quality (int): Quality of the resulting JPEG file(s). Must be a number between
240
+ 1 and 100. Default: 75
241
+
242
+ Returns:
243
+ output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.
244
+ """
245
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
246
+ _log_api_usage_once(encode_jpeg)
247
+ if quality < 1 or quality > 100:
248
+ raise ValueError("Image quality should be a positive number between 1 and 100")
249
+ if isinstance(input, list):
250
+ if not input:
251
+ raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
252
+ if input[0].device.type == "cuda":
253
+ return torch.ops.image.encode_jpegs_cuda(input, quality)
254
+ else:
255
+ return [torch.ops.image.encode_jpeg(image, quality) for image in input]
256
+ else: # single input tensor
257
+ if input.device.type == "cuda":
258
+ return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
259
+ else:
260
+ return torch.ops.image.encode_jpeg(input, quality)
261
+
262
+
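# --- Editor's sketch (not part of the diff): batched encode_jpeg usage. Per the note above,
# a list of CUDA tensors is encoded in a single batched call; the GPU branch assumes an
# nvjpeg-enabled build. Image sizes and quality are arbitrary.
import torch
from torchvision.io import encode_jpeg

imgs = [torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8) for _ in range(4)]
cpu_jpegs = encode_jpeg(imgs, quality=90)                            # list of 1-D uint8 tensors
if torch.cuda.is_available():
    gpu_jpegs = encode_jpeg([im.cuda() for im in imgs], quality=90)  # one batched nvjpeg call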
263
+ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
264
+ """
265
+ Takes an input tensor in CHW layout and saves it in a JPEG file.
266
+
267
+ Args:
268
+ input (Tensor[channels, image_height, image_width]): uint8 image tensor of ``c``
269
+ channels, where ``c`` must be 1 or 3.
270
+ filename (str or ``pathlib.Path``): Path to save the image.
271
+ quality (int): Quality of the resulting JPEG file, it must be a number
272
+ between 1 and 100. Default: 75
273
+ """
274
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
275
+ _log_api_usage_once(write_jpeg)
276
+ output = encode_jpeg(input, quality)
277
+ assert isinstance(output, torch.Tensor) # Needed for torchscript
278
+ write_file(filename, output)
279
+
280
+
281
+ def decode_image(
282
+ input: Union[torch.Tensor, str],
283
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
284
+ apply_exif_orientation: bool = False,
285
+ ) -> torch.Tensor:
286
+ """Decode an image into a uint8 tensor, from a path or from raw encoded bytes.
287
+
288
+ Currently supported image formats are jpeg, png, gif and webp.
289
+
290
+ The values of the output tensor are in uint8 in [0, 255] for most cases.
291
+
292
+ If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
293
+ (supported from torchvision ``0.21``). Since uint16 support is limited in
294
+ pytorch, we recommend calling
295
+ :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
296
+ after this function to convert the decoded image into a uint8 or float
297
+ tensor.
298
+
299
+ Args:
300
+ input (Tensor or str or ``pathlib.Path``): The image to decode. If a
301
+ tensor is passed, it must be one dimensional uint8 tensor containing
302
+ the raw bytes of the image. Otherwise, this must be a path to the image file.
303
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
304
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
305
+ for available modes.
306
+ apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
307
+ Only applies to JPEG and PNG images. Default: False.
308
+
309
+ Returns:
310
+ output (Tensor[image_channels, image_height, image_width])
311
+ """
312
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
313
+ _log_api_usage_once(decode_image)
314
+ if not isinstance(input, torch.Tensor):
315
+ input = read_file(str(input))
316
+ if isinstance(mode, str):
317
+ mode = ImageReadMode[mode.upper()]
318
+ output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
319
+ return output
320
+
321
+
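# --- Editor's sketch (not part of the diff): decode_image from a hypothetical path, followed
# by the to_dtype conversion recommended above for 16-bit PNGs (for 8-bit inputs it simply
# rescales the uint8 values to float in [0, 1]).
import torch
from torchvision.io import decode_image
from torchvision.transforms.v2.functional import to_dtype

img = decode_image("sample.png", mode="RGB")               # uint8, or uint16 for 16-bit PNGs
img_f32 = to_dtype(img, dtype=torch.float32, scale=True)   # float32 in [0, 1]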
322
+ def read_image(
323
+ path: str,
324
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
325
+ apply_exif_orientation: bool = False,
326
+ ) -> torch.Tensor:
327
+ """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
328
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
329
+ _log_api_usage_once(read_image)
330
+ data = read_file(path)
331
+ return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)
332
+
333
+
334
+ def decode_gif(input: torch.Tensor) -> torch.Tensor:
335
+ """
336
+ Decode a GIF image into a 3 or 4 dimensional RGB Tensor.
337
+
338
+ The values of the output tensor are uint8 between 0 and 255.
339
+ The output tensor has shape ``(C, H, W)`` if there is only one image in the
340
+ GIF, and ``(N, C, H, W)`` if there are ``N`` images.
341
+
342
+ Args:
343
+ input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
344
+ the raw bytes of the GIF image.
345
+
346
+ Returns:
347
+ output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
348
+ """
349
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
350
+ _log_api_usage_once(decode_gif)
351
+ return torch.ops.image.decode_gif(input)
352
+
353
+
354
+ def decode_webp(
355
+ input: torch.Tensor,
356
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
357
+ ) -> torch.Tensor:
358
+ """
359
+ Decode a WEBP image into a 3 dimensional RGB[A] Tensor.
360
+
361
+ The values of the output tensor are uint8 between 0 and 255.
362
+
363
+ Args:
364
+ input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
365
+ the raw bytes of the WEBP image.
366
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
367
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
368
+ for available modes.
369
+
370
+ Returns:
371
+ Decoded image (Tensor[image_channels, image_height, image_width])
372
+ """
373
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
374
+ _log_api_usage_once(decode_webp)
375
+ if isinstance(mode, str):
376
+ mode = ImageReadMode[mode.upper()]
377
+ return torch.ops.image.decode_webp(input, mode.value)
378
+
379
+
380
+ def _decode_avif(
381
+ input: torch.Tensor,
382
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
383
+ ) -> torch.Tensor:
384
+ """
385
+ Decode an AVIF image into a 3 dimensional RGB[A] Tensor.
386
+
387
+ The values of the output tensor are in uint8 in [0, 255] for most images. If
388
+ the image has a bit-depth of more than 8, then the output tensor is uint16
389
+ in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
390
+ calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
391
+ ``scale=True`` after this function to convert the decoded image into a uint8
392
+ or float tensor.
393
+
394
+ Args:
395
+ input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
396
+ the raw bytes of the AVIF image.
397
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
398
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
399
+ for available modes.
400
+
401
+ Returns:
402
+ Decoded image (Tensor[image_channels, image_height, image_width])
403
+ """
404
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
405
+ _log_api_usage_once(_decode_avif)
406
+ if isinstance(mode, str):
407
+ mode = ImageReadMode[mode.upper()]
408
+ return torch.ops.image.decode_avif(input, mode.value)
409
+
410
+
411
+ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
412
+ """
413
+ Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
414
+
415
+ The values of the output tensor are in uint8 in [0, 255] for most images. If
416
+ the image has a bit-depth of more than 8, then the output tensor is uint16
417
+ in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
418
+ calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
419
+ ``scale=True`` after this function to convert the decoded image into a uint8
420
+ or float tensor.
421
+
422
+ Args:
423
+ input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
424
+ the raw bytes of the HEIC image.
425
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
426
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
427
+ for available modes.
428
+
429
+ Returns:
430
+ Decoded image (Tensor[image_channels, image_height, image_width])
431
+ """
432
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
433
+ _log_api_usage_once(_decode_heic)
434
+ if isinstance(mode, str):
435
+ mode = ImageReadMode[mode.upper()]
436
+ return torch.ops.image.decode_heic(input, mode.value)
.venv/lib/python3.11/site-packages/torchvision/io/video.py ADDED
@@ -0,0 +1,438 @@
1
+ import gc
2
+ import math
3
+ import os
4
+ import re
5
+ import warnings
6
+ from fractions import Fraction
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from ..utils import _log_api_usage_once
13
+ from . import _video_opt
14
+
15
+ try:
16
+ import av
17
+
18
+ av.logging.set_level(av.logging.ERROR)
19
+ if not hasattr(av.video.frame.VideoFrame, "pict_type"):
20
+ av = ImportError(
21
+ """\
22
+ Your version of PyAV is too old for the necessary video operations in torchvision.
23
+ If you are on Python 3.5, you will have to build from source (the conda-forge
24
+ packages are not up-to-date). See
25
+ https://github.com/mikeboers/PyAV#installation for instructions on how to
26
+ install PyAV on your system.
27
+ """
28
+ )
29
+ except ImportError:
30
+ av = ImportError(
31
+ """\
32
+ PyAV is not installed, and is necessary for the video operations in torchvision.
33
+ See https://github.com/mikeboers/PyAV#installation for instructions on how to
34
+ install PyAV on your system.
35
+ """
36
+ )
37
+
38
+
39
+ def _check_av_available() -> None:
40
+ if isinstance(av, Exception):
41
+ raise av
42
+
43
+
44
+ def _av_available() -> bool:
45
+ return not isinstance(av, Exception)
46
+
47
+
48
+ # PyAV has some reference cycles
49
+ _CALLED_TIMES = 0
50
+ _GC_COLLECTION_INTERVAL = 10
51
+
52
+
53
+ def write_video(
54
+ filename: str,
55
+ video_array: torch.Tensor,
56
+ fps: float,
57
+ video_codec: str = "libx264",
58
+ options: Optional[Dict[str, Any]] = None,
59
+ audio_array: Optional[torch.Tensor] = None,
60
+ audio_fps: Optional[float] = None,
61
+ audio_codec: Optional[str] = None,
62
+ audio_options: Optional[Dict[str, Any]] = None,
63
+ ) -> None:
64
+ """
65
+ Writes a 4d tensor in [T, H, W, C] format into a video file
66
+
67
+ .. warning::
68
+
69
+ In the near future, we intend to centralize PyTorch's video decoding
70
+ capabilities within the `torchcodec
71
+ <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
72
+ try it out and share your feedback, as the torchvision video decoders
73
+ will eventually be deprecated.
74
+
75
+ Args:
76
+ filename (str): path where the video will be saved
77
+ video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
78
+ as a uint8 tensor in [T, H, W, C] format
79
+ fps (Number): video frames per second
80
+ video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
81
+ options (Dict): dictionary containing options to be passed into the PyAV video stream
82
+ audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
83
+ and N is the number of samples
84
+ audio_fps (Number): audio sample rate, typically 44100 or 48000
85
+ audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
86
+ audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
87
+ """
88
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
89
+ _log_api_usage_once(write_video)
90
+ _check_av_available()
91
+ video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)
92
+
93
+ # PyAV only accepts an integral frame rate and may raise an OverflowException
94
+ # otherwise, so round a float fps before adding the video stream
95
+ if isinstance(fps, float):
96
+ fps = np.round(fps)
97
+
98
+ with av.open(filename, mode="w") as container:
99
+ stream = container.add_stream(video_codec, rate=fps)
100
+ stream.width = video_array.shape[2]
101
+ stream.height = video_array.shape[1]
102
+ stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
103
+ stream.options = options or {}
104
+
105
+ if audio_array is not None:
106
+ audio_format_dtypes = {
107
+ "dbl": "<f8",
108
+ "dblp": "<f8",
109
+ "flt": "<f4",
110
+ "fltp": "<f4",
111
+ "s16": "<i2",
112
+ "s16p": "<i2",
113
+ "s32": "<i4",
114
+ "s32p": "<i4",
115
+ "u8": "u1",
116
+ "u8p": "u1",
117
+ }
118
+ a_stream = container.add_stream(audio_codec, rate=audio_fps)
119
+ a_stream.options = audio_options or {}
120
+
121
+ num_channels = audio_array.shape[0]
122
+ audio_layout = "stereo" if num_channels > 1 else "mono"
123
+ audio_sample_fmt = container.streams.audio[0].format.name
124
+
125
+ format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
126
+ audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype)
127
+
128
+ frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
129
+
130
+ frame.sample_rate = audio_fps
131
+
132
+ for packet in a_stream.encode(frame):
133
+ container.mux(packet)
134
+
135
+ for packet in a_stream.encode():
136
+ container.mux(packet)
137
+
138
+ for img in video_array:
139
+ frame = av.VideoFrame.from_ndarray(img, format="rgb24")
140
+ frame.pict_type = "NONE"
141
+ for packet in stream.encode(frame):
142
+ container.mux(packet)
143
+
144
+ # Flush stream
145
+ for packet in stream.encode():
146
+ container.mux(packet)
147
+
148
+
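# --- Editor's sketch (not part of the diff): write 30 frames of synthetic noise at 25 fps.
# "out.mp4" is a hypothetical path; PyAV with an H.264 encoder is assumed to be installed.
import torch
from torchvision.io import write_video

frames = torch.randint(0, 256, (30, 128, 128, 3), dtype=torch.uint8)  # [T, H, W, C]
write_video("out.mp4", frames, fps=25)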
149
+ def _read_from_stream(
150
+ container: "av.container.Container",
151
+ start_offset: float,
152
+ end_offset: float,
153
+ pts_unit: str,
154
+ stream: "av.stream.Stream",
155
+ stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
156
+ ) -> List["av.frame.Frame"]:
157
+ global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
158
+ _CALLED_TIMES += 1
159
+ if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
160
+ gc.collect()
161
+
162
+ if pts_unit == "sec":
163
+ # TODO: we should change all of this from ground up to simply take
164
+ # sec and convert to MS in C++
165
+ start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
166
+ if end_offset != float("inf"):
167
+ end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
168
+ else:
169
+ warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
170
+
171
+ frames = {}
172
+ should_buffer = True
173
+ max_buffer_size = 5
174
+ if stream.type == "video":
175
+ # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
176
+ # so need to buffer some extra frames to sort everything
177
+ # properly
178
+ extradata = stream.codec_context.extradata
179
+ # overly complicated way of finding if `divx_packed` is set, following
180
+ # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
181
+ if extradata and b"DivX" in extradata:
182
+ # can't use regex directly because of some weird characters sometimes...
183
+ pos = extradata.find(b"DivX")
184
+ d = extradata[pos:]
185
+ o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
186
+ if o is None:
187
+ o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
188
+ if o is not None:
189
+ should_buffer = o.group(3) == b"p"
190
+ seek_offset = start_offset
191
+ # some files don't seek to the right location, so better be safe here
192
+ seek_offset = max(seek_offset - 1, 0)
193
+ if should_buffer:
194
+ # FIXME this is kind of a hack, but we will jump to the previous keyframe
195
+ # so this will be safe
196
+ seek_offset = max(seek_offset - max_buffer_size, 0)
197
+ try:
198
+ # TODO check if stream needs to always be the video stream here or not
199
+ container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
200
+ except av.AVError:
201
+ # TODO add some warnings in this case
202
+ # print("Corrupted file?", container.name)
203
+ return []
204
+ buffer_count = 0
205
+ try:
206
+ for _idx, frame in enumerate(container.decode(**stream_name)):
207
+ frames[frame.pts] = frame
208
+ if frame.pts >= end_offset:
209
+ if should_buffer and buffer_count < max_buffer_size:
210
+ buffer_count += 1
211
+ continue
212
+ break
213
+ except av.AVError:
214
+ # TODO add a warning
215
+ pass
216
+ # ensure that the results are sorted wrt the pts
217
+ result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
218
+ if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
219
+ # if there is no frame that exactly matches the pts of start_offset
220
+ # add the last frame smaller than start_offset, to guarantee that
221
+ # we will have all the necessary data. This is most useful for audio
222
+ preceding_frames = [i for i in frames if i < start_offset]
223
+ if len(preceding_frames) > 0:
224
+ first_frame_pts = max(preceding_frames)
225
+ result.insert(0, frames[first_frame_pts])
226
+ return result
227
+
228
+
229
+ def _align_audio_frames(
230
+ aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
231
+ ) -> torch.Tensor:
232
+ start, end = audio_frames[0].pts, audio_frames[-1].pts
233
+ total_aframes = aframes.shape[1]
234
+ step_per_aframe = (end - start + 1) / total_aframes
235
+ s_idx = 0
236
+ e_idx = total_aframes
237
+ if start < ref_start:
238
+ s_idx = int((ref_start - start) / step_per_aframe)
239
+ if end > ref_end:
240
+ e_idx = int((ref_end - end) / step_per_aframe)
241
+ return aframes[:, s_idx:e_idx]
242
+
243
+
244
+ def read_video(
245
+ filename: str,
246
+ start_pts: Union[float, Fraction] = 0,
247
+ end_pts: Optional[Union[float, Fraction]] = None,
248
+ pts_unit: str = "pts",
249
+ output_format: str = "THWC",
250
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
251
+ """
252
+ Reads a video from a file, returning both the video frames and the audio frames
253
+
254
+ .. warning::
255
+
256
+ In the near future, we intend to centralize PyTorch's video decoding
257
+ capabilities within the `torchcodec
258
+ <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
259
+ try it out and share your feedback, as the torchvision video decoders
260
+ will eventually be deprecated.
261
+
262
+ Args:
263
+ filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
264
+ start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
265
+ The start presentation time of the video
266
+ end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
267
+ The end presentation time
268
+ pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
269
+ either 'pts' or 'sec'. Defaults to 'pts'.
270
+ output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
271
+
272
+ Returns:
273
+ vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
274
+ aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
275
+ info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
276
+ """
277
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
278
+ _log_api_usage_once(read_video)
279
+
280
+ output_format = output_format.upper()
281
+ if output_format not in ("THWC", "TCHW"):
282
+ raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
283
+
284
+ from torchvision import get_video_backend
285
+
286
+ if get_video_backend() != "pyav":
287
+ if not os.path.exists(filename):
288
+ raise RuntimeError(f"File not found: {filename}")
289
+ vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
290
+ else:
291
+ _check_av_available()
292
+
293
+ if end_pts is None:
294
+ end_pts = float("inf")
295
+
296
+ if end_pts < start_pts:
297
+ raise ValueError(
298
+ f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
299
+ )
300
+
301
+ info = {}
302
+ video_frames = []
303
+ audio_frames = []
304
+ audio_timebase = _video_opt.default_timebase
305
+
306
+ try:
307
+ with av.open(filename, metadata_errors="ignore") as container:
308
+ if container.streams.audio:
309
+ audio_timebase = container.streams.audio[0].time_base
310
+ if container.streams.video:
311
+ video_frames = _read_from_stream(
312
+ container,
313
+ start_pts,
314
+ end_pts,
315
+ pts_unit,
316
+ container.streams.video[0],
317
+ {"video": 0},
318
+ )
319
+ video_fps = container.streams.video[0].average_rate
320
+ # guard against potentially corrupted files
321
+ if video_fps is not None:
322
+ info["video_fps"] = float(video_fps)
323
+
324
+ if container.streams.audio:
325
+ audio_frames = _read_from_stream(
326
+ container,
327
+ start_pts,
328
+ end_pts,
329
+ pts_unit,
330
+ container.streams.audio[0],
331
+ {"audio": 0},
332
+ )
333
+ info["audio_fps"] = container.streams.audio[0].rate
334
+
335
+ except av.AVError:
336
+ # TODO raise a warning?
337
+ pass
338
+
339
+ vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
340
+ aframes_list = [frame.to_ndarray() for frame in audio_frames]
341
+
342
+ if vframes_list:
343
+ vframes = torch.as_tensor(np.stack(vframes_list))
344
+ else:
345
+ vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
346
+
347
+ if aframes_list:
348
+ aframes = np.concatenate(aframes_list, 1)
349
+ aframes = torch.as_tensor(aframes)
350
+ if pts_unit == "sec":
351
+ start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
352
+ if end_pts != float("inf"):
353
+ end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
354
+ aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
355
+ else:
356
+ aframes = torch.empty((1, 0), dtype=torch.float32)
357
+
358
+ if output_format == "TCHW":
359
+ # [T,H,W,C] --> [T,C,H,W]
360
+ vframes = vframes.permute(0, 3, 1, 2)
361
+
362
+ return vframes, aframes, info
363
+
364
+
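# --- Editor's sketch (not part of the diff): read back the hypothetical file written above.
# pts_unit="sec" is used because the code warns that the default "pts" gives wrong results.
from torchvision.io import read_video

vframes, aframes, info = read_video("out.mp4", pts_unit="sec", output_format="TCHW")
print(vframes.shape, info.get("video_fps"))   # e.g. torch.Size([30, 3, 128, 128]) 25.0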
365
+ def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
366
+ extradata = container.streams[0].codec_context.extradata
367
+ if extradata is None:
368
+ return False
369
+ if b"Lavc" in extradata:
370
+ return True
371
+ return False
372
+
373
+
374
+ def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
375
+ if _can_read_timestamps_from_packets(container):
376
+ # fast path
377
+ return [x.pts for x in container.demux(video=0) if x.pts is not None]
378
+ else:
379
+ return [x.pts for x in container.decode(video=0) if x.pts is not None]
380
+
381
+
382
+ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
383
+ """
384
+ List the video frames timestamps.
385
+
386
+ .. warning::
387
+
388
+ In the near future, we intend to centralize PyTorch's video decoding
389
+ capabilities within the `torchcodec
390
+ <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
391
+ try it out and share your feedback, as the torchvision video decoders
392
+ will eventually be deprecated.
393
+
394
+ Note that the function decodes the whole video frame-by-frame.
395
+
396
+ Args:
397
+ filename (str): path to the video file
398
+ pts_unit (str, optional): unit in which timestamp values will be returned
399
+ either 'pts' or 'sec'. Defaults to 'pts'.
400
+
401
+ Returns:
402
+ pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
403
+ presentation timestamps for each one of the frames in the video.
404
+ video_fps (float, optional): the frame rate for the video
405
+
406
+ """
407
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
408
+ _log_api_usage_once(read_video_timestamps)
409
+ from torchvision import get_video_backend
410
+
411
+ if get_video_backend() != "pyav":
412
+ return _video_opt._read_video_timestamps(filename, pts_unit)
413
+
414
+ _check_av_available()
415
+
416
+ video_fps = None
417
+ pts = []
418
+
419
+ try:
420
+ with av.open(filename, metadata_errors="ignore") as container:
421
+ if container.streams.video:
422
+ video_stream = container.streams.video[0]
423
+ video_time_base = video_stream.time_base
424
+ try:
425
+ pts = _decode_video_timestamps(container)
426
+ except av.AVError:
427
+ warnings.warn(f"Failed decoding frames for file {filename}")
428
+ video_fps = float(video_stream.average_rate)
429
+ except av.AVError as e:
430
+ msg = f"Failed to open container for {filename}; Caught error: {e}"
431
+ warnings.warn(msg, RuntimeWarning)
432
+
433
+ pts.sort()
434
+
435
+ if pts_unit == "sec":
436
+ pts = [x * video_time_base for x in pts]
437
+
438
+ return pts, video_fps
.venv/lib/python3.11/site-packages/torchvision/io/video_reader.py ADDED
@@ -0,0 +1,294 @@
1
+ import io
2
+ import warnings
3
+
4
+ from typing import Any, Dict, Iterator
5
+
6
+ import torch
7
+
8
+ from ..utils import _log_api_usage_once
9
+
10
+ from ._video_opt import _HAS_CPU_VIDEO_DECODER
11
+
12
+ if _HAS_CPU_VIDEO_DECODER:
13
+
14
+ def _has_video_opt() -> bool:
15
+ return True
16
+
17
+ else:
18
+
19
+ def _has_video_opt() -> bool:
20
+ return False
21
+
22
+
23
+ try:
24
+ import av
25
+
26
+ av.logging.set_level(av.logging.ERROR)
27
+ if not hasattr(av.video.frame.VideoFrame, "pict_type"):
28
+ av = ImportError(
29
+ """\
30
+ Your version of PyAV is too old for the necessary video operations in torchvision.
31
+ If you are on Python 3.5, you will have to build from source (the conda-forge
32
+ packages are not up-to-date). See
33
+ https://github.com/mikeboers/PyAV#installation for instructions on how to
34
+ install PyAV on your system.
35
+ """
36
+ )
37
+ except ImportError:
38
+ av = ImportError(
39
+ """\
40
+ PyAV is not installed, and is necessary for the video operations in torchvision.
41
+ See https://github.com/mikeboers/PyAV#installation for instructions on how to
42
+ install PyAV on your system.
43
+ """
44
+ )
45
+
46
+
47
+ class VideoReader:
48
+ """
49
+ Fine-grained video-reading API.
50
+ Supports frame-by-frame reading of various streams from a single video
51
+ container. Much like the previous video_reader API, it supports the following
52
+ backends: video_reader, pyav, and cuda.
53
+ Backends can be set via `torchvision.set_video_backend` function.
54
+
55
+ .. warning::
56
+
57
+ In the near future, we intend to centralize PyTorch's video decoding
58
+ capabilities within the `torchcodec
59
+ <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
60
+ try it out and share your feedback, as the torchvision video decoders
61
+ will eventually be deprecated.
62
+
63
+ .. betastatus:: VideoReader class
64
+
65
+ Example:
66
+ The following example creates a :mod:`VideoReader` object, seeks to the 2s
67
+ point, and returns a single frame::
68
+
69
+ import torchvision
70
+ video_path = "path_to_a_test_video"
71
+ reader = torchvision.io.VideoReader(video_path, "video")
72
+ reader.seek(2.0)
73
+ frame = next(reader)
74
+
75
+ :mod:`VideoReader` implements the iterable API, which makes it suitable for
76
+ use in conjunction with :mod:`itertools` for more advanced reading.
77
+ As such, we can use a :mod:`VideoReader` instance inside for loops::
78
+
79
+ reader.seek(2)
80
+ for frame in reader:
81
+ frames.append(frame['data'])
82
+ # additionally, `seek` implements a fluent API, so we can do
83
+ for frame in reader.seek(2):
84
+ frames.append(frame['data'])
85
+
86
+ With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
87
+ following code::
88
+
89
+ for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
90
+ frames.append(frame['data'])
91
+
92
+ and similarly, reading 10 frames after the 2s timestamp can be achieved
93
+ as follows::
94
+
95
+ for frame in itertools.islice(reader.seek(2), 10):
96
+ frames.append(frame['data'])
97
+
98
+ .. note::
99
+
100
+ Each stream descriptor consists of two parts: stream type (e.g. 'video') and
101
+ a unique stream id (which is determined by the video encoding).
102
+ In this way, if the video container contains multiple
103
+ streams of the same type, users can access the one they want.
104
+ If only stream type is passed, the decoder auto-detects first stream of that type.
105
+
106
+ Args:
107
+ src (string, bytes object, or tensor): The media source.
108
+ If string-type, it must be a file path supported by FFMPEG.
109
+ If bytes, should be an in-memory representation of a file supported by FFMPEG.
110
+ If Tensor, it is interpreted internally as byte buffer.
111
+ It must be one-dimensional, of type ``torch.uint8``.
112
+
113
+ stream (string, optional): descriptor of the required stream, followed by the stream id,
114
+ in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
115
+ Currently available options include ``['video', 'audio']``
116
+
117
+ num_threads (int, optional): number of threads used by the codec to decode video.
118
+ Default value (0) enables multithreading with a codec-dependent heuristic. The performance
119
+ will depend on the version of FFMPEG codecs supported.
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ src: str,
125
+ stream: str = "video",
126
+ num_threads: int = 0,
127
+ ) -> None:
128
+ _log_api_usage_once(self)
129
+ from .. import get_video_backend
130
+
131
+ self.backend = get_video_backend()
132
+ if isinstance(src, str):
133
+ if not src:
134
+ raise ValueError("src cannot be empty")
135
+ elif isinstance(src, bytes):
136
+ if self.backend in ["cuda"]:
137
+ raise RuntimeError(
138
+ "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
139
+ )
140
+ elif self.backend == "pyav":
141
+ src = io.BytesIO(src)
142
+ else:
143
+ with warnings.catch_warnings():
144
+ # Ignore the warning because we actually don't modify the buffer in this function
145
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
146
+ src = torch.frombuffer(src, dtype=torch.uint8)
147
+ elif isinstance(src, torch.Tensor):
148
+ if self.backend in ["cuda", "pyav"]:
149
+ raise RuntimeError(
150
+ "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
151
+ )
152
+ else:
153
+ raise ValueError(f"src must be either string, Tensor or bytes object. Got {type(src)}")
154
+
155
+ if self.backend == "cuda":
156
+ device = torch.device("cuda")
157
+ self._c = torch.classes.torchvision.GPUDecoder(src, device)
158
+
159
+ elif self.backend == "video_reader":
160
+ if isinstance(src, str):
161
+ self._c = torch.classes.torchvision.Video(src, stream, num_threads)
162
+ elif isinstance(src, torch.Tensor):
163
+ self._c = torch.classes.torchvision.Video("", "", 0)
164
+ self._c.init_from_memory(src, stream, num_threads)
165
+
166
+ elif self.backend == "pyav":
167
+ self.container = av.open(src, metadata_errors="ignore")
168
+ # TODO: load metadata
169
+ stream_type = stream.split(":")[0]
170
+ stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
171
+ self.pyav_stream = {stream_type: stream_id}
172
+ self._c = self.container.decode(**self.pyav_stream)
173
+
174
+ # TODO: add extradata exception
175
+
176
+ else:
177
+ raise RuntimeError("Unknown video backend: {}".format(self.backend))
178
+
179
+ def __next__(self) -> Dict[str, Any]:
180
+ """Decodes and returns the next frame of the current stream.
181
+ Frames are encoded as a dict with mandatory
182
+ data and pts fields, where data is a tensor, and pts is a
183
+ presentation timestamp of the frame expressed in seconds
184
+ as a float.
185
+
186
+ Returns:
187
+ (dict): a dictionary containing the decoded frame (``data``)
188
+ and corresponding timestamp (``pts``) in seconds
189
+
190
+ """
191
+ if self.backend == "cuda":
192
+ frame = self._c.next()
193
+ if frame.numel() == 0:
194
+ raise StopIteration
195
+ return {"data": frame, "pts": None}
196
+ elif self.backend == "video_reader":
197
+ frame, pts = self._c.next()
198
+ else:
199
+ try:
200
+ frame = next(self._c)
201
+ pts = float(frame.pts * frame.time_base)
202
+ if "video" in self.pyav_stream:
203
+ frame = torch.as_tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
204
+ elif "audio" in self.pyav_stream:
205
+ frame = torch.as_tensor(frame.to_ndarray()).permute(1, 0)
206
+ else:
207
+ frame = None
208
+ except av.error.EOFError:
209
+ raise StopIteration
210
+
211
+ if frame.numel() == 0:
212
+ raise StopIteration
213
+
214
+ return {"data": frame, "pts": pts}
215
+
216
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
217
+ return self
218
+
219
+ def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
220
+ """Seek within current stream.
221
+
222
+ Args:
223
+ time_s (float): seek time in seconds
224
+ keyframes_only (bool): whether to seek only to keyframes
225
+
226
+ .. note::
227
+ The current implementation performs a so-called precise seek. This
228
+ means that after a seek, the next call to :mod:`next()` will return the
229
+ frame with the exact timestamp if it exists, or
230
+ the first frame with a timestamp larger than ``time_s``.
231
+ """
232
+ if self.backend in ["cuda", "video_reader"]:
233
+ self._c.seek(time_s, keyframes_only)
234
+ else:
235
+ # handle special case as pyav doesn't catch it
236
+ if time_s < 0:
237
+ time_s = 0
238
+ temp_str = self.container.streams.get(**self.pyav_stream)[0]
239
+ offset = int(round(time_s / temp_str.time_base))
240
+ if not keyframes_only:
241
+ warnings.warn("Accurate seek is not implemented for pyav backend")
242
+ self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
243
+ self._c = self.container.decode(**self.pyav_stream)
244
+ return self
245
+
246
+ def get_metadata(self) -> Dict[str, Any]:
247
+ """Returns video metadata
248
+
249
+ Returns:
250
+ (dict): dictionary containing duration and frame rate for every stream
251
+ """
252
+ if self.backend == "pyav":
253
+ metadata = {} # type: Dict[str, Any]
254
+ for stream in self.container.streams:
255
+ if stream.type not in metadata:
256
+ if stream.type == "video":
257
+ rate_n = "fps"
258
+ else:
259
+ rate_n = "framerate"
260
+ metadata[stream.type] = {rate_n: [], "duration": []}
261
+
262
+ rate = getattr(stream, "average_rate", None) or stream.sample_rate
263
+
264
+ metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
265
+ metadata[stream.type][rate_n].append(float(rate))
266
+ return metadata
267
+ return self._c.get_metadata()
268
+
269
+ def set_current_stream(self, stream: str) -> bool:
270
+ """Set current stream.
271
+ Explicitly define the stream we are operating on.
272
+
273
+ Args:
274
+ stream (string): descriptor of the required stream. Defaults to ``"video:0"``
275
+ Currently available stream types include ``['video', 'audio']``.
276
+ Each descriptor consists of two parts: stream type (e.g. 'video') and
277
+ a unique stream id (which is determined by the video encoding).
278
+ In this way, if the video container contains multiple
279
+ streams of the same type, users can access the one they want.
280
+ If only stream type is passed, the decoder auto-detects first stream
281
+ of that type and returns it.
282
+
283
+ Returns:
284
+ (bool): True on success, False otherwise
285
+ """
286
+ if self.backend == "cuda":
287
+ warnings.warn("GPU decoding only works with video stream.")
288
+ if self.backend == "pyav":
289
+ stream_type = stream.split(":")[0]
290
+ stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
291
+ self.pyav_stream = {stream_type: stream_id}
292
+ self._c = self.container.decode(**self.pyav_stream)
293
+ return True
294
+ return self._c.set_current_stream(stream)
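# --- Editor's sketch (not part of the diff): read metadata with VideoReader, then grab
# roughly one second of frames starting at t=2s. "clip.mp4" is a hypothetical path and the
# pyav backend is assumed to be active.
import itertools
import torchvision

reader = torchvision.io.VideoReader("clip.mp4", "video")
meta = reader.get_metadata()                 # e.g. {"video": {"fps": [25.0], "duration": [10.0]}}
fps = meta["video"]["fps"][0]
frames = [f["data"] for f in itertools.islice(reader.seek(2.0), int(fps))]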
.venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .faster_rcnn import *
2
+ from .fcos import *
3
+ from .keypoint_rcnn import *
4
+ from .mask_rcnn import *
5
+ from .retinanet import *
6
+ from .ssd import *
7
+ from .ssdlite import *
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (414 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (28.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc ADDED
Binary file (13.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc ADDED
Binary file (39.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc ADDED
Binary file (42.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc ADDED
Binary file (6.29 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc ADDED
Binary file (1.65 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc ADDED
Binary file (23.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc ADDED
Binary file (28 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc ADDED
Binary file (42.7 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc ADDED
Binary file (45.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc ADDED
Binary file (37.2 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc ADDED
Binary file (19.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py ADDED
@@ -0,0 +1,540 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+ from torch.nn import functional as F
8
+ from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
9
+
10
+
11
+ class BalancedPositiveNegativeSampler:
12
+ """
13
+ This class samples batches, ensuring that they contain a fixed proportion of positives
14
+ """
15
+
16
+ def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
17
+ """
18
+ Args:
19
+ batch_size_per_image (int): number of elements to be selected per image
20
+ positive_fraction (float): percentage of positive elements per batch
21
+ """
22
+ self.batch_size_per_image = batch_size_per_image
23
+ self.positive_fraction = positive_fraction
24
+
25
+ def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
26
+ """
27
+ Args:
28
+ matched_idxs: list of tensors containing -1, 0 or positive values.
29
+ Each tensor corresponds to a specific image.
30
+ -1 values are ignored, 0 are considered as negatives and > 0 as
31
+ positives.
32
+
33
+ Returns:
34
+ pos_idx (list[tensor])
35
+ neg_idx (list[tensor])
36
+
37
+ Returns two lists of binary masks for each image.
38
+ The first list contains the positive elements that were selected,
39
+ and the second list the negative examples.
40
+ """
41
+ pos_idx = []
42
+ neg_idx = []
43
+ for matched_idxs_per_image in matched_idxs:
44
+ positive = torch.where(matched_idxs_per_image >= 1)[0]
45
+ negative = torch.where(matched_idxs_per_image == 0)[0]
46
+
47
+ num_pos = int(self.batch_size_per_image * self.positive_fraction)
48
+ # protect against not enough positive examples
49
+ num_pos = min(positive.numel(), num_pos)
50
+ num_neg = self.batch_size_per_image - num_pos
51
+ # protect against not enough negative examples
52
+ num_neg = min(negative.numel(), num_neg)
53
+
54
+ # randomly select positive and negative examples
55
+ perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
56
+ perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
57
+
58
+ pos_idx_per_image = positive[perm1]
59
+ neg_idx_per_image = negative[perm2]
60
+
61
+ # create binary mask from indices
62
+ pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
63
+ neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
64
+
65
+ pos_idx_per_image_mask[pos_idx_per_image] = 1
66
+ neg_idx_per_image_mask[neg_idx_per_image] = 1
67
+
68
+ pos_idx.append(pos_idx_per_image_mask)
69
+ neg_idx.append(neg_idx_per_image_mask)
70
+
71
+ return pos_idx, neg_idx
72
+
73
+
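# --- Editor's sketch (not part of the diff): BalancedPositiveNegativeSampler on a single
# image, using the convention documented above (-1 ignore, 0 negative, >= 1 positive).
# Values and batch size are arbitrary assumptions.
import torch
from torchvision.models.detection._utils import BalancedPositiveNegativeSampler

sampler = BalancedPositiveNegativeSampler(batch_size_per_image=8, positive_fraction=0.5)
matched_idxs = [torch.tensor([-1, 0, 0, 2, 5, 0, 1, 0, 0, 3])]
pos_masks, neg_masks = sampler(matched_idxs)
print(int(pos_masks[0].sum()), int(neg_masks[0].sum()))   # 4 positives, 4 negatives sampled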
74
+ @torch.jit._script_if_tracing
75
+ def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
76
+ """
77
+ Encode a set of proposals with respect to some
78
+ reference boxes
79
+
80
+ Args:
81
+ reference_boxes (Tensor): reference boxes
82
+ proposals (Tensor): boxes to be encoded
83
+ weights (Tensor[4]): the weights for ``(x, y, w, h)``
84
+ """
85
+
86
+ # perform some unpacking to make it JIT-fusion friendly
87
+ wx = weights[0]
88
+ wy = weights[1]
89
+ ww = weights[2]
90
+ wh = weights[3]
91
+
92
+ proposals_x1 = proposals[:, 0].unsqueeze(1)
93
+ proposals_y1 = proposals[:, 1].unsqueeze(1)
94
+ proposals_x2 = proposals[:, 2].unsqueeze(1)
95
+ proposals_y2 = proposals[:, 3].unsqueeze(1)
96
+
97
+ reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
98
+ reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
99
+ reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
100
+ reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
101
+
102
+ # implementation starts here
103
+ ex_widths = proposals_x2 - proposals_x1
104
+ ex_heights = proposals_y2 - proposals_y1
105
+ ex_ctr_x = proposals_x1 + 0.5 * ex_widths
106
+ ex_ctr_y = proposals_y1 + 0.5 * ex_heights
107
+
108
+ gt_widths = reference_boxes_x2 - reference_boxes_x1
109
+ gt_heights = reference_boxes_y2 - reference_boxes_y1
110
+ gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
111
+ gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
112
+
113
+ targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
114
+ targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
115
+ targets_dw = ww * torch.log(gt_widths / ex_widths)
116
+ targets_dh = wh * torch.log(gt_heights / ex_heights)
117
+
118
+ targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
119
+ return targets
120
+
121
+
122
+ class BoxCoder:
123
+ """
124
+ This class encodes and decodes a set of bounding boxes into
125
+ the representation used for training the regressors.
126
+ """
127
+
128
+ def __init__(
129
+ self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
130
+ ) -> None:
131
+ """
132
+ Args:
133
+ weights (4-element tuple)
134
+ bbox_xform_clip (float)
135
+ """
136
+ self.weights = weights
137
+ self.bbox_xform_clip = bbox_xform_clip
138
+
139
+ def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
140
+ boxes_per_image = [len(b) for b in reference_boxes]
141
+ reference_boxes = torch.cat(reference_boxes, dim=0)
142
+ proposals = torch.cat(proposals, dim=0)
143
+ targets = self.encode_single(reference_boxes, proposals)
144
+ return targets.split(boxes_per_image, 0)
145
+
146
+ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
147
+ """
148
+ Encode a set of proposals with respect to some
149
+ reference boxes
150
+
151
+ Args:
152
+ reference_boxes (Tensor): reference boxes
153
+ proposals (Tensor): boxes to be encoded
154
+ """
155
+ dtype = reference_boxes.dtype
156
+ device = reference_boxes.device
157
+ weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
158
+ targets = encode_boxes(reference_boxes, proposals, weights)
159
+
160
+ return targets
161
+
162
+ def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
163
+ torch._assert(
164
+ isinstance(boxes, (list, tuple)),
165
+ "This function expects boxes of type list or tuple.",
166
+ )
167
+ torch._assert(
168
+ isinstance(rel_codes, torch.Tensor),
169
+ "This function expects rel_codes of type torch.Tensor.",
170
+ )
171
+ boxes_per_image = [b.size(0) for b in boxes]
172
+ concat_boxes = torch.cat(boxes, dim=0)
173
+ box_sum = 0
174
+ for val in boxes_per_image:
175
+ box_sum += val
176
+ if box_sum > 0:
177
+ rel_codes = rel_codes.reshape(box_sum, -1)
178
+ pred_boxes = self.decode_single(rel_codes, concat_boxes)
179
+ if box_sum > 0:
180
+ pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
181
+ return pred_boxes
182
+
183
+ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
184
+ """
185
+ From a set of original boxes and encoded relative box offsets,
186
+ get the decoded boxes.
187
+
188
+ Args:
189
+ rel_codes (Tensor): encoded boxes
190
+ boxes (Tensor): reference boxes.
191
+ """
192
+
193
+ boxes = boxes.to(rel_codes.dtype)
194
+
195
+ widths = boxes[:, 2] - boxes[:, 0]
196
+ heights = boxes[:, 3] - boxes[:, 1]
197
+ ctr_x = boxes[:, 0] + 0.5 * widths
198
+ ctr_y = boxes[:, 1] + 0.5 * heights
199
+
200
+ wx, wy, ww, wh = self.weights
201
+ dx = rel_codes[:, 0::4] / wx
202
+ dy = rel_codes[:, 1::4] / wy
203
+ dw = rel_codes[:, 2::4] / ww
204
+ dh = rel_codes[:, 3::4] / wh
205
+
206
+ # Prevent sending too large values into torch.exp()
207
+ dw = torch.clamp(dw, max=self.bbox_xform_clip)
208
+ dh = torch.clamp(dh, max=self.bbox_xform_clip)
209
+
210
+ pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
211
+ pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
212
+ pred_w = torch.exp(dw) * widths[:, None]
213
+ pred_h = torch.exp(dh) * heights[:, None]
214
+
215
+ # Distance from center to box's corner.
216
+ c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
217
+ c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
218
+
219
+ pred_boxes1 = pred_ctr_x - c_to_c_w
220
+ pred_boxes2 = pred_ctr_y - c_to_c_h
221
+ pred_boxes3 = pred_ctr_x + c_to_c_w
222
+ pred_boxes4 = pred_ctr_y + c_to_c_h
223
+ pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
224
+ return pred_boxes
225
+
226
+
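# --- Editor's sketch (not part of the diff): BoxCoder round trip. Ground-truth boxes are
# encoded relative to proposals and then decoded back; the weights and xyxy coordinates
# below are arbitrary assumptions.
import torch
from torchvision.models.detection._utils import BoxCoder

coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
proposals = [torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 25.0]])]
gt_boxes = [torch.tensor([[1.0, 1.0, 11.0, 12.0], [4.0, 6.0, 22.0, 24.0]])]
rel_codes = coder.encode(gt_boxes, proposals)             # per-image list of (dx, dy, dw, dh)
decoded = coder.decode(torch.cat(rel_codes), proposals)   # (2, 1, 4), recovers gt_boxes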
227
+ class BoxLinearCoder:
228
+ """
229
+ The linear box-to-box transform defined in FCOS. The transformation is parameterized
230
+ by the distance from the center of (square) src box to 4 edges of the target box.
231
+ """
232
+
233
+ def __init__(self, normalize_by_size: bool = True) -> None:
234
+ """
235
+ Args:
236
+ normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
237
+ """
238
+ self.normalize_by_size = normalize_by_size
239
+
240
+ def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
241
+ """
242
+ Encode a set of proposals with respect to some reference boxes
243
+
244
+ Args:
245
+ reference_boxes (Tensor): reference boxes
246
+ proposals (Tensor): boxes to be encoded
247
+
248
+ Returns:
249
+ Tensor: the encoded relative box offsets that can be used to
250
+ decode the boxes.
251
+
252
+ """
253
+
254
+ # get the center of reference_boxes
255
+ reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
256
+ reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
257
+
258
+ # get box regression transformation deltas
259
+ target_l = reference_boxes_ctr_x - proposals[..., 0]
260
+ target_t = reference_boxes_ctr_y - proposals[..., 1]
261
+ target_r = proposals[..., 2] - reference_boxes_ctr_x
262
+ target_b = proposals[..., 3] - reference_boxes_ctr_y
263
+
264
+ targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
265
+
266
+ if self.normalize_by_size:
267
+ reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
268
+ reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
269
+ reference_boxes_size = torch.stack(
270
+ (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
271
+ )
272
+ targets = targets / reference_boxes_size
273
+ return targets
274
+
275
+ def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
276
+
277
+ """
278
+ From a set of original boxes and encoded relative box offsets,
279
+ get the decoded boxes.
280
+
281
+ Args:
282
+ rel_codes (Tensor): encoded boxes
283
+ boxes (Tensor): reference boxes.
284
+
285
+ Returns:
286
+ Tensor: the predicted boxes with the encoded relative box offsets.
287
+
288
+ .. note::
289
+ This method assumes that ``rel_codes`` and ``boxes`` have the same size along the 0th dimension, i.e. ``len(rel_codes) == len(boxes)``.
290
+
291
+ """
292
+
293
+ boxes = boxes.to(dtype=rel_codes.dtype)
294
+
295
+ ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
296
+ ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
297
+
298
+ if self.normalize_by_size:
299
+ boxes_w = boxes[..., 2] - boxes[..., 0]
300
+ boxes_h = boxes[..., 3] - boxes[..., 1]
301
+
302
+ list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
303
+ rel_codes = rel_codes * list_box_size
304
+
305
+ pred_boxes1 = ctr_x - rel_codes[..., 0]
306
+ pred_boxes2 = ctr_y - rel_codes[..., 1]
307
+ pred_boxes3 = ctr_x + rel_codes[..., 2]
308
+ pred_boxes4 = ctr_y + rel_codes[..., 3]
309
+
310
+ pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
311
+ return pred_boxes
312
+
313
+
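# --- Editor's sketch (not part of the diff): BoxLinearCoder with the FCOS convention, where
# the reference boxes are anchors and the "proposals" are matched ground-truth boxes.
# Coordinates are arbitrary assumptions.
import torch
from torchvision.models.detection._utils import BoxLinearCoder

coder = BoxLinearCoder(normalize_by_size=True)
anchors = torch.tensor([[4.0, 4.0, 12.0, 12.0]])   # square reference box
gt = torch.tensor([[2.0, 3.0, 14.0, 13.0]])        # matched ground-truth box
deltas = coder.encode(anchors, gt)                 # (l, t, r, b) distances, normalized by anchor size
recovered = coder.decode(deltas, anchors)          # equals gt up to floating point error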
314
+ class Matcher:
315
+ """
316
+ This class assigns to each predicted "element" (e.g., a box) a ground-truth
317
+ element. Each predicted element will have exactly zero or one matches; each
318
+ ground-truth element may be assigned to zero or more predicted elements.
319
+
320
+ Matching is based on the MxN match_quality_matrix, which characterizes how well
321
+ each (ground-truth, predicted) pair matches. For example, if the elements are
322
+ boxes, the matrix may contain box IoU overlap values.
323
+
324
+ The matcher returns a tensor of size N containing the index of the ground-truth
325
+ element m that matches to prediction n. If there is no match, a negative value
326
+ is returned.
327
+ """
328
+
329
+ BELOW_LOW_THRESHOLD = -1
330
+ BETWEEN_THRESHOLDS = -2
331
+
332
+ __annotations__ = {
333
+ "BELOW_LOW_THRESHOLD": int,
334
+ "BETWEEN_THRESHOLDS": int,
335
+ }
336
+
337
+ def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
338
+ """
339
+ Args:
340
+ high_threshold (float): quality values greater than or equal to
341
+ this value are candidate matches.
342
+ low_threshold (float): a lower quality threshold used to stratify
343
+ matches into three levels:
344
+ 1) matches >= high_threshold
345
+ 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
346
+ 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
347
+ allow_low_quality_matches (bool): if True, produce additional matches
348
+ for predictions that have only low-quality match candidates. See
349
+ set_low_quality_matches_ for more details.
350
+ """
351
+ self.BELOW_LOW_THRESHOLD = -1
352
+ self.BETWEEN_THRESHOLDS = -2
353
+ torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
354
+ self.high_threshold = high_threshold
355
+ self.low_threshold = low_threshold
356
+ self.allow_low_quality_matches = allow_low_quality_matches
357
+
358
+ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
359
+ """
360
+ Args:
361
+ match_quality_matrix (Tensor[float]): an MxN tensor, containing the
362
+ pairwise quality between M ground-truth elements and N predicted elements.
363
+
364
+ Returns:
365
+ matches (Tensor[int64]): an N-length tensor where element i is the index of a matched gt in
366
+ [0, M - 1] or a negative value indicating that prediction i could not
367
+ be matched.
368
+ """
369
+ if match_quality_matrix.numel() == 0:
370
+ # empty targets or proposals not supported during training
371
+ if match_quality_matrix.shape[0] == 0:
372
+ raise ValueError("No ground-truth boxes available for one of the images during training")
373
+ else:
374
+ raise ValueError("No proposal boxes available for one of the images during training")
375
+
376
+ # match_quality_matrix is M (gt) x N (predicted)
377
+ # Max over gt elements (dim 0) to find best gt candidate for each prediction
378
+ matched_vals, matches = match_quality_matrix.max(dim=0)
379
+ if self.allow_low_quality_matches:
380
+ all_matches = matches.clone()
381
+ else:
382
+ all_matches = None # type: ignore[assignment]
383
+
384
+ # Assign candidate matches with low quality to negative (unassigned) values
385
+ below_low_threshold = matched_vals < self.low_threshold
386
+ between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
387
+ matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
388
+ matches[between_thresholds] = self.BETWEEN_THRESHOLDS
389
+
390
+ if self.allow_low_quality_matches:
391
+ if all_matches is None:
392
+ torch._assert(False, "all_matches should not be None")
393
+ else:
394
+ self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
395
+
396
+ return matches
397
+
398
+ def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
399
+ """
400
+ Produce additional matches for predictions that have only low-quality matches.
401
+ Specifically, for each ground-truth find the set of predictions that have
402
+ maximum overlap with it (including ties); for each prediction in that set, if
403
+ it is unmatched, then match it to the ground-truth with which it has the highest
404
+ quality value.
405
+ """
406
+ # For each gt, find the prediction with which it has the highest quality
407
+ highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
408
+ # Find the highest quality match available, even if it is low, including ties
409
+ gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
410
+ # Example gt_pred_pairs_of_highest_quality:
411
+ # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
412
+ # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
413
+ # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
414
+ # Note how gt items 1, 2, 3, and 5 each have two ties
415
+
416
+ pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
417
+ matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
418
+
419
+
420
+ class SSDMatcher(Matcher):
421
+ def __init__(self, threshold: float) -> None:
422
+ super().__init__(threshold, threshold, allow_low_quality_matches=False)
423
+
424
+ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
425
+ matches = super().__call__(match_quality_matrix)
426
+
427
+ # For each gt, find the prediction with which it has the highest quality
428
+ _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
429
+ matches[highest_quality_pred_foreach_gt] = torch.arange(
430
+ highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
431
+ )
432
+
433
+ return matches
434
+
435
+
436
+ def overwrite_eps(model: nn.Module, eps: float) -> None:
437
+ """
438
+ This method overwrites the default eps values of all the
439
+ FrozenBatchNorm2d layers of the model with the provided value.
440
+ This is necessary to address the BC-breaking change introduced
441
+ by the bug-fix at pytorch/vision#2933. The overwrite is applied
442
+ only when the pretrained weights are loaded to maintain compatibility
443
+ with previous versions.
444
+
445
+ Args:
446
+ model (nn.Module): The model on which we perform the overwrite.
447
+ eps (float): The new value of eps.
448
+ """
449
+ for module in model.modules():
450
+ if isinstance(module, FrozenBatchNorm2d):
451
+ module.eps = eps
452
+
453
+
454
+ def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
455
+ """
456
+ This method retrieves the number of output channels of a specific model.
457
+
458
+ Args:
459
+ model (nn.Module): The model for which we estimate the out_channels.
460
+ It should return a single Tensor or an OrderedDict[Tensor].
461
+ size (Tuple[int, int]): The size (wxh) of the input.
462
+
463
+ Returns:
464
+ out_channels (List[int]): A list of the output channels of the model.
465
+ """
466
+ in_training = model.training
467
+ model.eval()
468
+
469
+ with torch.no_grad():
470
+ # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
471
+ device = next(model.parameters()).device
472
+ tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
473
+ features = model(tmp_img)
474
+ if isinstance(features, torch.Tensor):
475
+ features = OrderedDict([("0", features)])
476
+ out_channels = [x.size(1) for x in features.values()]
477
+
478
+ if in_training:
479
+ model.train()
480
+
481
+ return out_channels
482
+
483
+
484
+ @torch.jit.unused
485
+ def _fake_cast_onnx(v: Tensor) -> int:
486
+ return v # type: ignore[return-value]
487
+
488
+
489
+ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
490
+ """
491
+ ONNX spec requires the k-value to be less than or equal to the number of inputs along
492
+ provided dim. Certain models use the number of elements along a particular axis instead of K
493
+ if K exceeds the number of elements along that axis. Previously, python's min() function was
494
+ used to determine whether to use the provided k-value or the specified dim axis value.
495
+
496
+ However, in cases where the model is being exported in tracing mode, python min() is
497
+ evaluated statically, causing the model to be traced incorrectly and eventually fail at the topk node.
498
+ In order to avoid this situation, in tracing mode, torch.min() is used instead.
499
+
500
+ Args:
501
+ input (Tensor): The original input tensor.
502
+ orig_kval (int): The provided k-value.
503
+ axis(int): Axis along which we retrieve the input size.
504
+
505
+ Returns:
506
+ min_kval (int): Appropriately selected k-value.
507
+ """
508
+ if not torch.jit.is_tracing():
509
+ return min(orig_kval, input.size(axis))
510
+ axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
511
+ min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
512
+ return _fake_cast_onnx(min_kval)
513
+
514
+
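A minimal sketch of how _topk_min behaves outside of tracing, where it reduces to python min(); this assumes the module is importable as torchvision.models.detection._utils and calls the private helper purely for illustration.

import torch
from torchvision.models.detection import _utils as det_utils

scores = torch.rand(1, 500)
# Requesting k=1000 along an axis that only has 500 elements clamps k to 500.
k = det_utils._topk_min(scores, orig_kval=1000, axis=1)
print(k)  # 500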
515
+ def _box_loss(
516
+ type: str,
517
+ box_coder: BoxCoder,
518
+ anchors_per_image: Tensor,
519
+ matched_gt_boxes_per_image: Tensor,
520
+ bbox_regression_per_image: Tensor,
521
+ cnf: Optional[Dict[str, float]] = None,
522
+ ) -> Tensor:
523
+ torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
524
+
525
+ if type == "l1":
526
+ target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
527
+ return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
528
+ elif type == "smooth_l1":
529
+ target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
530
+ beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
531
+ return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
532
+ else:
533
+ bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
534
+ eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
535
+ if type == "ciou":
536
+ return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
537
+ if type == "diou":
538
+ return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
539
+ # otherwise giou
540
+ return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
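A minimal sketch exercising the Matcher above on a toy IoU matrix, assuming the module is importable as torchvision.models.detection._utils; the thresholds and quality values are illustrative only.

import torch
from torchvision.models.detection import _utils as det_utils

# 2 ground-truth boxes vs. 4 predictions; rows are gt, columns are predictions.
iou = torch.tensor([
    [0.9, 0.4, 0.1, 0.0],
    [0.2, 0.6, 0.3, 0.0],
])

matcher = det_utils.Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)
matches = matcher(iou)
# matches[i] is the gt index assigned to prediction i, or a negative marker for
# discarded / between-threshold predictions. With allow_low_quality_matches=True
# the best prediction for each gt is kept even if it falls below high_threshold.
print(matches)  # tensor([ 0,  1, -2, -1])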
.venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py ADDED
@@ -0,0 +1,268 @@
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from .image_list import ImageList
8
+
9
+
10
+ class AnchorGenerator(nn.Module):
11
+ """
12
+ Module that generates anchors for a set of feature maps and
13
+ image sizes.
14
+
15
+ The module supports computing anchors at multiple sizes and aspect ratios
16
+ per feature map. This module assumes aspect ratio = height / width for
17
+ each anchor.
18
+
19
+ sizes and aspect_ratios should have the same number of elements, and it should
20
+ correspond to the number of feature maps.
21
+
22
+ sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
23
+ and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
24
+ per spatial location for feature map i.
25
+
26
+ Args:
27
+ sizes (Tuple[Tuple[int]]):
28
+ aspect_ratios (Tuple[Tuple[float]]):
29
+ """
30
+
31
+ __annotations__ = {
32
+ "cell_anchors": List[torch.Tensor],
33
+ }
34
+
35
+ def __init__(
36
+ self,
37
+ sizes=((128, 256, 512),),
38
+ aspect_ratios=((0.5, 1.0, 2.0),),
39
+ ):
40
+ super().__init__()
41
+
42
+ if not isinstance(sizes[0], (list, tuple)):
43
+ # TODO change this
44
+ sizes = tuple((s,) for s in sizes)
45
+ if not isinstance(aspect_ratios[0], (list, tuple)):
46
+ aspect_ratios = (aspect_ratios,) * len(sizes)
47
+
48
+ self.sizes = sizes
49
+ self.aspect_ratios = aspect_ratios
50
+ self.cell_anchors = [
51
+ self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
52
+ ]
53
+
54
+ # TODO: https://github.com/pytorch/pytorch/issues/26792
55
+ # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
56
+ # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
57
+ # This method assumes aspect ratio = height / width for an anchor.
58
+ def generate_anchors(
59
+ self,
60
+ scales: List[int],
61
+ aspect_ratios: List[float],
62
+ dtype: torch.dtype = torch.float32,
63
+ device: torch.device = torch.device("cpu"),
64
+ ) -> Tensor:
65
+ scales = torch.as_tensor(scales, dtype=dtype, device=device)
66
+ aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
67
+ h_ratios = torch.sqrt(aspect_ratios)
68
+ w_ratios = 1 / h_ratios
69
+
70
+ ws = (w_ratios[:, None] * scales[None, :]).view(-1)
71
+ hs = (h_ratios[:, None] * scales[None, :]).view(-1)
72
+
73
+ base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
74
+ return base_anchors.round()
75
+
76
+ def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
77
+ self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
78
+
79
+ def num_anchors_per_location(self) -> List[int]:
80
+ return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
81
+
82
+ # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
83
+ # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
84
+ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
85
+ anchors = []
86
+ cell_anchors = self.cell_anchors
87
+ torch._assert(cell_anchors is not None, "cell_anchors should not be None")
88
+ torch._assert(
89
+ len(grid_sizes) == len(strides) == len(cell_anchors),
90
+ "Anchors should be Tuple[Tuple[int]] because each feature "
91
+ "map could potentially have different sizes and aspect ratios. "
92
+ "There needs to be a match between the number of "
93
+ "feature maps passed and the number of sizes / aspect ratios specified.",
94
+ )
95
+
96
+ for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
97
+ grid_height, grid_width = size
98
+ stride_height, stride_width = stride
99
+ device = base_anchors.device
100
+
101
+ # For output anchor, compute [x_center, y_center, x_center, y_center]
102
+ shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
103
+ shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
104
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
105
+ shift_x = shift_x.reshape(-1)
106
+ shift_y = shift_y.reshape(-1)
107
+ shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
108
+
109
+ # For every (base anchor, output anchor) pair,
110
+ # offset each zero-centered base anchor by the center of the output anchor.
111
+ anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
112
+
113
+ return anchors
114
+
115
+ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
116
+ grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
117
+ image_size = image_list.tensors.shape[-2:]
118
+ dtype, device = feature_maps[0].dtype, feature_maps[0].device
119
+ strides = [
120
+ [
121
+ torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
122
+ torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
123
+ ]
124
+ for g in grid_sizes
125
+ ]
126
+ self.set_cell_anchors(dtype, device)
127
+ anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
128
+ anchors: List[List[torch.Tensor]] = []
129
+ for _ in range(len(image_list.image_sizes)):
130
+ anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
131
+ anchors.append(anchors_in_image)
132
+ anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
133
+ return anchors
134
+
135
+
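A minimal sketch of AnchorGenerator on dummy feature maps, assuming ImageList is importable from torchvision.models.detection.image_list as referenced above; the sizes, aspect ratios and tensor shapes are illustrative only.

import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

images = torch.rand(2, 3, 256, 256)
image_list = ImageList(images, image_sizes=[(256, 256), (256, 256)])
# Two feature maps, at strides 16 and 32 relative to the 256x256 input.
feature_maps = [torch.rand(2, 8, 16, 16), torch.rand(2, 8, 8, 8)]

anchor_gen = AnchorGenerator(
    sizes=((32, 64), (128,)),            # one tuple per feature map
    aspect_ratios=((0.5, 1.0), (1.0,)),  # one tuple per feature map
)
anchors = anchor_gen(image_list, feature_maps)
# One [num_anchors, 4] tensor per image in the batch:
# 16*16 locations * 4 anchors + 8*8 locations * 1 anchor = 1088 anchors.
print([a.shape for a in anchors])  # [torch.Size([1088, 4]), torch.Size([1088, 4])]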
136
+ class DefaultBoxGenerator(nn.Module):
137
+ """
138
+ This module generates the default boxes of SSD for a set of feature maps and image sizes.
139
+
140
+ Args:
141
+ aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
142
+ min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
143
+ of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
144
+ max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
145
+ of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
146
+ scales (List[float], optional): The scales of the default boxes. If not provided they will be estimated using
147
+ the ``min_ratio`` and ``max_ratio`` parameters.
148
+ steps (List[int], optional): A hyper-parameter that affects the tiling of default boxes. If not provided
149
+ it will be estimated from the data.
150
+ clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
151
+ is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ aspect_ratios: List[List[int]],
157
+ min_ratio: float = 0.15,
158
+ max_ratio: float = 0.9,
159
+ scales: Optional[List[float]] = None,
160
+ steps: Optional[List[int]] = None,
161
+ clip: bool = True,
162
+ ):
163
+ super().__init__()
164
+ if steps is not None and len(aspect_ratios) != len(steps):
165
+ raise ValueError("aspect_ratios and steps should have the same length")
166
+ self.aspect_ratios = aspect_ratios
167
+ self.steps = steps
168
+ self.clip = clip
169
+ num_outputs = len(aspect_ratios)
170
+
171
+ # Estimation of default boxes scales
172
+ if scales is None:
173
+ if num_outputs > 1:
174
+ range_ratio = max_ratio - min_ratio
175
+ self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
176
+ self.scales.append(1.0)
177
+ else:
178
+ self.scales = [min_ratio, max_ratio]
179
+ else:
180
+ self.scales = scales
181
+
182
+ self._wh_pairs = self._generate_wh_pairs(num_outputs)
183
+
184
+ def _generate_wh_pairs(
185
+ self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
186
+ ) -> List[Tensor]:
187
+ _wh_pairs: List[Tensor] = []
188
+ for k in range(num_outputs):
189
+ # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
190
+ s_k = self.scales[k]
191
+ s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
192
+ wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]
193
+
194
+ # Adding 2 pairs for each aspect ratio of the feature map k
195
+ for ar in self.aspect_ratios[k]:
196
+ sq_ar = math.sqrt(ar)
197
+ w = self.scales[k] * sq_ar
198
+ h = self.scales[k] / sq_ar
199
+ wh_pairs.extend([[w, h], [h, w]])
200
+
201
+ _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
202
+ return _wh_pairs
203
+
204
+ def num_anchors_per_location(self) -> List[int]:
205
+ # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
206
+ return [2 + 2 * len(r) for r in self.aspect_ratios]
207
+
208
+ # Default Boxes calculation based on page 6 of SSD paper
209
+ def _grid_default_boxes(
210
+ self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
211
+ ) -> Tensor:
212
+ default_boxes = []
213
+ for k, f_k in enumerate(grid_sizes):
214
+ # Now add the default boxes for each width-height pair
215
+ if self.steps is not None:
216
+ x_f_k = image_size[1] / self.steps[k]
217
+ y_f_k = image_size[0] / self.steps[k]
218
+ else:
219
+ y_f_k, x_f_k = f_k
220
+
221
+ shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
222
+ shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
223
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
224
+ shift_x = shift_x.reshape(-1)
225
+ shift_y = shift_y.reshape(-1)
226
+
227
+ shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
228
+ # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
229
+ _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
230
+ wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)
231
+
232
+ default_box = torch.cat((shifts, wh_pairs), dim=1)
233
+
234
+ default_boxes.append(default_box)
235
+
236
+ return torch.cat(default_boxes, dim=0)
237
+
238
+ def __repr__(self) -> str:
239
+ s = (
240
+ f"{self.__class__.__name__}("
241
+ f"aspect_ratios={self.aspect_ratios}"
242
+ f", clip={self.clip}"
243
+ f", scales={self.scales}"
244
+ f", steps={self.steps}"
245
+ ")"
246
+ )
247
+ return s
248
+
249
+ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
250
+ grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
251
+ image_size = image_list.tensors.shape[-2:]
252
+ dtype, device = feature_maps[0].dtype, feature_maps[0].device
253
+ default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
254
+ default_boxes = default_boxes.to(device)
255
+
256
+ dboxes = []
257
+ x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
258
+ for _ in image_list.image_sizes:
259
+ dboxes_in_image = default_boxes
260
+ dboxes_in_image = torch.cat(
261
+ [
262
+ (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
263
+ (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
264
+ ],
265
+ -1,
266
+ )
267
+ dboxes.append(dboxes_in_image)
268
+ return dboxes
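A minimal sketch of DefaultBoxGenerator on dummy feature maps, under the same assumptions as the AnchorGenerator sketch above (ImageList importable from torchvision.models.detection.image_list); the aspect ratios and shapes are illustrative only.

import torch
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList

images = torch.rand(2, 3, 300, 300)
image_list = ImageList(images, image_sizes=[(300, 300), (300, 300)])
feature_maps = [torch.rand(2, 16, 10, 10), torch.rand(2, 16, 5, 5)]

box_gen = DefaultBoxGenerator(aspect_ratios=[[2], [2, 3]])
# 2 + 2 * len(ratios) boxes per location: 4 for the first map, 6 for the second.
print(box_gen.num_anchors_per_location())  # [4, 6]

dboxes = box_gen(image_list, feature_maps)
# 10*10*4 + 5*5*6 = 550 default boxes per image, in (x1, y1, x2, y2) pixel coordinates.
print([d.shape for d in dboxes])  # [torch.Size([550, 4]), torch.Size([550, 4])]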
.venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py ADDED
@@ -0,0 +1,244 @@
1
+ import warnings
2
+ from typing import Callable, Dict, List, Optional, Union
3
+
4
+ from torch import nn, Tensor
5
+ from torchvision.ops import misc as misc_nn_ops
6
+ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
7
+
8
+ from .. import mobilenet, resnet
9
+ from .._api import _get_enum_from_fn, WeightsEnum
10
+ from .._utils import handle_legacy_interface, IntermediateLayerGetter
11
+
12
+
13
+ class BackboneWithFPN(nn.Module):
14
+ """
15
+ Adds a FPN on top of a model.
16
+ Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
17
+ extract a submodel that returns the feature maps specified in return_layers.
18
+ The same limitations of IntermediateLayerGetter apply here.
19
+ Args:
20
+ backbone (nn.Module)
21
+ return_layers (Dict[name, new_name]): a dict containing the names
22
+ of the modules for which the activations will be returned as
23
+ the key of the dict, and the value of the dict is the name
24
+ of the returned activation (which the user can specify).
25
+ in_channels_list (List[int]): number of channels for each feature map
26
+ that is returned, in the order they are present in the OrderedDict
27
+ out_channels (int): number of channels in the FPN.
28
+ norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
29
+ Attributes:
30
+ out_channels (int): the number of channels in the FPN
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ backbone: nn.Module,
36
+ return_layers: Dict[str, str],
37
+ in_channels_list: List[int],
38
+ out_channels: int,
39
+ extra_blocks: Optional[ExtraFPNBlock] = None,
40
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
41
+ ) -> None:
42
+ super().__init__()
43
+
44
+ if extra_blocks is None:
45
+ extra_blocks = LastLevelMaxPool()
46
+
47
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
48
+ self.fpn = FeaturePyramidNetwork(
49
+ in_channels_list=in_channels_list,
50
+ out_channels=out_channels,
51
+ extra_blocks=extra_blocks,
52
+ norm_layer=norm_layer,
53
+ )
54
+ self.out_channels = out_channels
55
+
56
+ def forward(self, x: Tensor) -> Dict[str, Tensor]:
57
+ x = self.body(x)
58
+ x = self.fpn(x)
59
+ return x
60
+
61
+
62
+ @handle_legacy_interface(
63
+ weights=(
64
+ "pretrained",
65
+ lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
66
+ ),
67
+ )
68
+ def resnet_fpn_backbone(
69
+ *,
70
+ backbone_name: str,
71
+ weights: Optional[WeightsEnum],
72
+ norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
73
+ trainable_layers: int = 3,
74
+ returned_layers: Optional[List[int]] = None,
75
+ extra_blocks: Optional[ExtraFPNBlock] = None,
76
+ ) -> BackboneWithFPN:
77
+ """
78
+ Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
79
+
80
+ Examples::
81
+
82
+ >>> import torch
83
+ >>> from torchvision.models import ResNet50_Weights
84
+ >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
85
+ >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
86
+ >>> # get some dummy image
87
+ >>> x = torch.rand(1,3,64,64)
88
+ >>> # compute the output
89
+ >>> output = backbone(x)
90
+ >>> print([(k, v.shape) for k, v in output.items()])
91
+ >>> # returns
92
+ >>> [('0', torch.Size([1, 256, 16, 16])),
93
+ >>> ('1', torch.Size([1, 256, 8, 8])),
94
+ >>> ('2', torch.Size([1, 256, 4, 4])),
95
+ >>> ('3', torch.Size([1, 256, 2, 2])),
96
+ >>> ('pool', torch.Size([1, 256, 1, 1]))]
97
+
98
+ Args:
99
+ backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
100
+ 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
101
+ weights (WeightsEnum, optional): The pretrained weights for the model
102
+ norm_layer (callable): it is recommended to use the default value. For details visit:
103
+ (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
104
+ trainable_layers (int): number of trainable (not frozen) layers starting from final block.
105
+ Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
106
+ returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
107
+ By default, all layers are returned.
108
+ extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
109
+ be performed. It is expected to take the fpn features, the original
110
+ features and the names of the original features as input, and returns
111
+ a new list of feature maps and their corresponding names. By
112
+ default, a ``LastLevelMaxPool`` is used.
113
+ """
114
+ backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
115
+ return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
116
+
117
+
118
+ def _resnet_fpn_extractor(
119
+ backbone: resnet.ResNet,
120
+ trainable_layers: int,
121
+ returned_layers: Optional[List[int]] = None,
122
+ extra_blocks: Optional[ExtraFPNBlock] = None,
123
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
124
+ ) -> BackboneWithFPN:
125
+
126
+ # select layers that won't be frozen
127
+ if trainable_layers < 0 or trainable_layers > 5:
128
+ raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
129
+ layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
130
+ if trainable_layers == 5:
131
+ layers_to_train.append("bn1")
132
+ for name, parameter in backbone.named_parameters():
133
+ if all([not name.startswith(layer) for layer in layers_to_train]):
134
+ parameter.requires_grad_(False)
135
+
136
+ if extra_blocks is None:
137
+ extra_blocks = LastLevelMaxPool()
138
+
139
+ if returned_layers is None:
140
+ returned_layers = [1, 2, 3, 4]
141
+ if min(returned_layers) <= 0 or max(returned_layers) >= 5:
142
+ raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
143
+ return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
144
+
145
+ in_channels_stage2 = backbone.inplanes // 8
146
+ in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
147
+ out_channels = 256
148
+ return BackboneWithFPN(
149
+ backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
150
+ )
151
+
152
+
153
+ def _validate_trainable_layers(
154
+ is_trained: bool,
155
+ trainable_backbone_layers: Optional[int],
156
+ max_value: int,
157
+ default_value: int,
158
+ ) -> int:
159
+ # don't freeze any layers if pretrained model or backbone is not used
160
+ if not is_trained:
161
+ if trainable_backbone_layers is not None:
162
+ warnings.warn(
163
+ "Changing trainable_backbone_layers has no effect if "
164
+ "neither pretrained nor pretrained_backbone have been set to True, "
165
+ f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
166
+ )
167
+ trainable_backbone_layers = max_value
168
+
169
+ # by default freeze first blocks
170
+ if trainable_backbone_layers is None:
171
+ trainable_backbone_layers = default_value
172
+ if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
173
+ raise ValueError(
174
+ f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
175
+ )
176
+ return trainable_backbone_layers
177
+
178
+
179
+ @handle_legacy_interface(
180
+ weights=(
181
+ "pretrained",
182
+ lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
183
+ ),
184
+ )
185
+ def mobilenet_backbone(
186
+ *,
187
+ backbone_name: str,
188
+ weights: Optional[WeightsEnum],
189
+ fpn: bool,
190
+ norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
191
+ trainable_layers: int = 2,
192
+ returned_layers: Optional[List[int]] = None,
193
+ extra_blocks: Optional[ExtraFPNBlock] = None,
194
+ ) -> nn.Module:
195
+ backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
196
+ return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
197
+
198
+
199
+ def _mobilenet_extractor(
200
+ backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
201
+ fpn: bool,
202
+ trainable_layers: int,
203
+ returned_layers: Optional[List[int]] = None,
204
+ extra_blocks: Optional[ExtraFPNBlock] = None,
205
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
206
+ ) -> nn.Module:
207
+ backbone = backbone.features
208
+ # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
209
+ # The first and last blocks are always included because they are the C0 (conv1) and Cn.
210
+ stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
211
+ num_stages = len(stage_indices)
212
+
213
+ # find the index of the layer from which we won't freeze
214
+ if trainable_layers < 0 or trainable_layers > num_stages:
215
+ raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
216
+ freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
217
+
218
+ for b in backbone[:freeze_before]:
219
+ for parameter in b.parameters():
220
+ parameter.requires_grad_(False)
221
+
222
+ out_channels = 256
223
+ if fpn:
224
+ if extra_blocks is None:
225
+ extra_blocks = LastLevelMaxPool()
226
+
227
+ if returned_layers is None:
228
+ returned_layers = [num_stages - 2, num_stages - 1]
229
+ if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
230
+ raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
231
+ return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
232
+
233
+ in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
234
+ return BackboneWithFPN(
235
+ backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
236
+ )
237
+ else:
238
+ m = nn.Sequential(
239
+ backbone,
240
+ # depthwise linear combination of channels to reduce their size
241
+ nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
242
+ )
243
+ m.out_channels = out_channels # type: ignore[assignment]
244
+ return m
.venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torchvision.ops import MultiScaleRoIAlign
7
+
8
+ from ...ops import misc as misc_nn_ops
9
+ from ...transforms._presets import ObjectDetection
10
+ from .._api import register_model, Weights, WeightsEnum
11
+ from .._meta import _COCO_CATEGORIES
12
+ from .._utils import _ovewrite_value_param, handle_legacy_interface
13
+ from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
14
+ from ..resnet import resnet50, ResNet50_Weights
15
+ from ._utils import overwrite_eps
16
+ from .anchor_utils import AnchorGenerator
17
+ from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
18
+ from .generalized_rcnn import GeneralizedRCNN
19
+ from .roi_heads import RoIHeads
20
+ from .rpn import RegionProposalNetwork, RPNHead
21
+ from .transform import GeneralizedRCNNTransform
22
+
23
+
24
+ __all__ = [
25
+ "FasterRCNN",
26
+ "FasterRCNN_ResNet50_FPN_Weights",
27
+ "FasterRCNN_ResNet50_FPN_V2_Weights",
28
+ "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
29
+ "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
30
+ "fasterrcnn_resnet50_fpn",
31
+ "fasterrcnn_resnet50_fpn_v2",
32
+ "fasterrcnn_mobilenet_v3_large_fpn",
33
+ "fasterrcnn_mobilenet_v3_large_320_fpn",
34
+ ]
35
+
36
+
37
+ def _default_anchorgen():
38
+ anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
39
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
40
+ return AnchorGenerator(anchor_sizes, aspect_ratios)
41
+
42
+
43
+ class FasterRCNN(GeneralizedRCNN):
44
+ """
45
+ Implements Faster R-CNN.
46
+
47
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
48
+ image, and should be in 0-1 range. Different images can have different sizes.
49
+
50
+ The behavior of the model changes depending on if it is in training or evaluation mode.
51
+
52
+ During training, the model expects both the input tensors and targets (list of dictionary),
53
+ containing:
54
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
55
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
56
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
57
+
58
+ The model returns a Dict[Tensor] during training, containing the classification and regression
59
+ losses for both the RPN and the R-CNN.
60
+
61
+ During inference, the model requires only the input tensors, and returns the post-processed
62
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
63
+ follows:
64
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
65
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
66
+ - labels (Int64Tensor[N]): the predicted labels for each image
67
+ - scores (Tensor[N]): the scores or each prediction
68
+
69
+ Args:
70
+ backbone (nn.Module): the network used to compute the features for the model.
71
+ It should contain an out_channels attribute, which indicates the number of output
72
+ channels that each feature map has (and it should be the same for all feature maps).
73
+ The backbone should return a single Tensor or and OrderedDict[Tensor].
74
+ num_classes (int): number of output classes of the model (including the background).
75
+ If box_predictor is specified, num_classes should be None.
76
+ min_size (int): Images are rescaled before feeding them to the backbone:
77
+ we attempt to preserve the aspect ratio and scale the shorter edge
78
+ to ``min_size``. If the resulting longer edge exceeds ``max_size``,
79
+ then downscale so that the longer edge does not exceed ``max_size``.
80
+ This may result in the shorter edge beeing lower than ``min_size``.
81
+ max_size (int): See ``min_size``.
82
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
83
+ They are generally the mean values of the dataset on which the backbone has been trained
84
+ on
85
+ image_std (Tuple[float, float, float]): std values used for input normalization.
86
+ They are generally the std values of the dataset on which the backbone has been trained on
87
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
88
+ maps.
89
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
90
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
91
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
92
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
93
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
94
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
95
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
96
+ considered as positive during training of the RPN.
97
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
98
+ considered as negative during training of the RPN.
99
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
100
+ for computing the loss
101
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
102
+ of the RPN
103
+ rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
104
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
105
+ the locations indicated by the bounding boxes
106
+ box_head (nn.Module): module that takes the cropped feature maps as input
107
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
108
+ classification logits and box regression deltas.
109
+ box_score_thresh (float): during inference, only return proposals with a classification score
110
+ greater than box_score_thresh
111
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
112
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
113
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
114
+ considered as positive during training of the classification head
115
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
116
+ considered as negative during training of the classification head
117
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
118
+ classification head
119
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
120
+ of the classification head
121
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
122
+ bounding boxes
123
+
124
+ Example::
125
+
126
+ >>> import torch
127
+ >>> import torchvision
128
+ >>> from torchvision.models.detection import FasterRCNN
129
+ >>> from torchvision.models.detection.rpn import AnchorGenerator
130
+ >>> # load a pre-trained model for classification and return
131
+ >>> # only the features
132
+ >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
133
+ >>> # FasterRCNN needs to know the number of
134
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
135
+ >>> # so we need to add it here
136
+ >>> backbone.out_channels = 1280
137
+ >>>
138
+ >>> # let's make the RPN generate 5 x 3 anchors per spatial
139
+ >>> # location, with 5 different sizes and 3 different aspect
140
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
141
+ >>> # map could potentially have different sizes and
142
+ >>> # aspect ratios
143
+ >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
144
+ >>> aspect_ratios=((0.5, 1.0, 2.0),))
145
+ >>>
146
+ >>> # let's define what are the feature maps that we will
147
+ >>> # use to perform the region of interest cropping, as well as
148
+ >>> # the size of the crop after rescaling.
149
+ >>> # if your backbone returns a Tensor, featmap_names is expected to
150
+ >>> # be ['0']. More generally, the backbone should return an
151
+ >>> # OrderedDict[Tensor], and in featmap_names you can choose which
152
+ >>> # feature maps to use.
153
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
154
+ >>> output_size=7,
155
+ >>> sampling_ratio=2)
156
+ >>>
157
+ >>> # put the pieces together inside a FasterRCNN model
158
+ >>> model = FasterRCNN(backbone,
159
+ >>> num_classes=2,
160
+ >>> rpn_anchor_generator=anchor_generator,
161
+ >>> box_roi_pool=roi_pooler)
162
+ >>> model.eval()
163
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
164
+ >>> predictions = model(x)
165
+ """
166
+
167
+ def __init__(
168
+ self,
169
+ backbone,
170
+ num_classes=None,
171
+ # transform parameters
172
+ min_size=800,
173
+ max_size=1333,
174
+ image_mean=None,
175
+ image_std=None,
176
+ # RPN parameters
177
+ rpn_anchor_generator=None,
178
+ rpn_head=None,
179
+ rpn_pre_nms_top_n_train=2000,
180
+ rpn_pre_nms_top_n_test=1000,
181
+ rpn_post_nms_top_n_train=2000,
182
+ rpn_post_nms_top_n_test=1000,
183
+ rpn_nms_thresh=0.7,
184
+ rpn_fg_iou_thresh=0.7,
185
+ rpn_bg_iou_thresh=0.3,
186
+ rpn_batch_size_per_image=256,
187
+ rpn_positive_fraction=0.5,
188
+ rpn_score_thresh=0.0,
189
+ # Box parameters
190
+ box_roi_pool=None,
191
+ box_head=None,
192
+ box_predictor=None,
193
+ box_score_thresh=0.05,
194
+ box_nms_thresh=0.5,
195
+ box_detections_per_img=100,
196
+ box_fg_iou_thresh=0.5,
197
+ box_bg_iou_thresh=0.5,
198
+ box_batch_size_per_image=512,
199
+ box_positive_fraction=0.25,
200
+ bbox_reg_weights=None,
201
+ **kwargs,
202
+ ):
203
+
204
+ if not hasattr(backbone, "out_channels"):
205
+ raise ValueError(
206
+ "backbone should contain an attribute out_channels "
207
+ "specifying the number of output channels (assumed to be the "
208
+ "same for all the levels)"
209
+ )
210
+
211
+ if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
212
+ raise TypeError(
213
+ f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
214
+ )
215
+ if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
216
+ raise TypeError(
217
+ f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
218
+ )
219
+
220
+ if num_classes is not None:
221
+ if box_predictor is not None:
222
+ raise ValueError("num_classes should be None when box_predictor is specified")
223
+ else:
224
+ if box_predictor is None:
225
+ raise ValueError("num_classes should not be None when box_predictor is not specified")
226
+
227
+ out_channels = backbone.out_channels
228
+
229
+ if rpn_anchor_generator is None:
230
+ rpn_anchor_generator = _default_anchorgen()
231
+ if rpn_head is None:
232
+ rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
233
+
234
+ rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
235
+ rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
236
+
237
+ rpn = RegionProposalNetwork(
238
+ rpn_anchor_generator,
239
+ rpn_head,
240
+ rpn_fg_iou_thresh,
241
+ rpn_bg_iou_thresh,
242
+ rpn_batch_size_per_image,
243
+ rpn_positive_fraction,
244
+ rpn_pre_nms_top_n,
245
+ rpn_post_nms_top_n,
246
+ rpn_nms_thresh,
247
+ score_thresh=rpn_score_thresh,
248
+ )
249
+
250
+ if box_roi_pool is None:
251
+ box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
252
+
253
+ if box_head is None:
254
+ resolution = box_roi_pool.output_size[0]
255
+ representation_size = 1024
256
+ box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
257
+
258
+ if box_predictor is None:
259
+ representation_size = 1024
260
+ box_predictor = FastRCNNPredictor(representation_size, num_classes)
261
+
262
+ roi_heads = RoIHeads(
263
+ # Box
264
+ box_roi_pool,
265
+ box_head,
266
+ box_predictor,
267
+ box_fg_iou_thresh,
268
+ box_bg_iou_thresh,
269
+ box_batch_size_per_image,
270
+ box_positive_fraction,
271
+ bbox_reg_weights,
272
+ box_score_thresh,
273
+ box_nms_thresh,
274
+ box_detections_per_img,
275
+ )
276
+
277
+ if image_mean is None:
278
+ image_mean = [0.485, 0.456, 0.406]
279
+ if image_std is None:
280
+ image_std = [0.229, 0.224, 0.225]
281
+ transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
282
+
283
+ super().__init__(backbone, rpn, roi_heads, transform)
284
+
285
+
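A minimal sketch of the min_size/max_size rescaling described in the docstring above, assuming GeneralizedRCNNTransform is importable from torchvision.models.detection.transform; the exact padded shape depends on the default size_divisible and is only indicative.

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(
    min_size=800, max_size=1333,
    image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225],
)
# For a 600x1400 image, scaling the shorter edge to 800 would push the longer edge
# past 1333, so the longer edge is capped at 1333 and the shorter edge lands around
# 571, i.e. below min_size, as noted above.
images, _ = transform([torch.rand(3, 600, 1400)])
print(images.tensors.shape)  # e.g. torch.Size([1, 3, 576, 1344]) after padding to a stride multiple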
286
+ class TwoMLPHead(nn.Module):
287
+ """
288
+ Standard heads for FPN-based models
289
+
290
+ Args:
291
+ in_channels (int): number of input channels
292
+ representation_size (int): size of the intermediate representation
293
+ """
294
+
295
+ def __init__(self, in_channels, representation_size):
296
+ super().__init__()
297
+
298
+ self.fc6 = nn.Linear(in_channels, representation_size)
299
+ self.fc7 = nn.Linear(representation_size, representation_size)
300
+
301
+ def forward(self, x):
302
+ x = x.flatten(start_dim=1)
303
+
304
+ x = F.relu(self.fc6(x))
305
+ x = F.relu(self.fc7(x))
306
+
307
+ return x
308
+
309
+
310
+ class FastRCNNConvFCHead(nn.Sequential):
311
+ def __init__(
312
+ self,
313
+ input_size: Tuple[int, int, int],
314
+ conv_layers: List[int],
315
+ fc_layers: List[int],
316
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
317
+ ):
318
+ """
319
+ Args:
320
+ input_size (Tuple[int, int, int]): the input size in CHW format.
321
+ conv_layers (list): feature dimensions of each Convolution layer
322
+ fc_layers (list): feature dimensions of each FCN layer
323
+ norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
324
+ """
325
+ in_channels, in_height, in_width = input_size
326
+
327
+ blocks = []
328
+ previous_channels = in_channels
329
+ for current_channels in conv_layers:
330
+ blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
331
+ previous_channels = current_channels
332
+ blocks.append(nn.Flatten())
333
+ previous_channels = previous_channels * in_height * in_width
334
+ for current_channels in fc_layers:
335
+ blocks.append(nn.Linear(previous_channels, current_channels))
336
+ blocks.append(nn.ReLU(inplace=True))
337
+ previous_channels = current_channels
338
+
339
+ super().__init__(*blocks)
340
+ for layer in self.modules():
341
+ if isinstance(layer, nn.Conv2d):
342
+ nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
343
+ if layer.bias is not None:
344
+ nn.init.zeros_(layer.bias)
345
+
346
+
347
+ class FastRCNNPredictor(nn.Module):
348
+ """
349
+ Standard classification + bounding box regression layers
350
+ for Fast R-CNN.
351
+
352
+ Args:
353
+ in_channels (int): number of input channels
354
+ num_classes (int): number of output classes (including background)
355
+ """
356
+
357
+ def __init__(self, in_channels, num_classes):
358
+ super().__init__()
359
+ self.cls_score = nn.Linear(in_channels, num_classes)
360
+ self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
361
+
362
+ def forward(self, x):
363
+ if x.dim() == 4:
364
+ torch._assert(
365
+ list(x.shape[2:]) == [1, 1],
366
+ f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
367
+ )
368
+ x = x.flatten(start_dim=1)
369
+ scores = self.cls_score(x)
370
+ bbox_deltas = self.bbox_pred(x)
371
+
372
+ return scores, bbox_deltas
373
+
374
+
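A minimal sketch of the common fine-tuning pattern of swapping in a fresh FastRCNNPredictor for a custom class count; weights=None and weights_backbone=None are used here only to keep the sketch download-free (in practice one would load pretrained weights), and num_classes=3 is an arbitrary example.

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
# The default box head outputs a 1024-dim representation for this model.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=3)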
375
+ _COMMON_META = {
376
+ "categories": _COCO_CATEGORIES,
377
+ "min_size": (1, 1),
378
+ }
379
+
380
+
381
+ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
382
+ COCO_V1 = Weights(
383
+ url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
384
+ transforms=ObjectDetection,
385
+ meta={
386
+ **_COMMON_META,
387
+ "num_params": 41755286,
388
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
389
+ "_metrics": {
390
+ "COCO-val2017": {
391
+ "box_map": 37.0,
392
+ }
393
+ },
394
+ "_ops": 134.38,
395
+ "_file_size": 159.743,
396
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
397
+ },
398
+ )
399
+ DEFAULT = COCO_V1
400
+
401
+
402
+ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
403
+ COCO_V1 = Weights(
404
+ url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
405
+ transforms=ObjectDetection,
406
+ meta={
407
+ **_COMMON_META,
408
+ "num_params": 43712278,
409
+ "recipe": "https://github.com/pytorch/vision/pull/5763",
410
+ "_metrics": {
411
+ "COCO-val2017": {
412
+ "box_map": 46.7,
413
+ }
414
+ },
415
+ "_ops": 280.371,
416
+ "_file_size": 167.104,
417
+ "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
418
+ },
419
+ )
420
+ DEFAULT = COCO_V1
421
+
422
+
423
+ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
424
+ COCO_V1 = Weights(
425
+ url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
426
+ transforms=ObjectDetection,
427
+ meta={
428
+ **_COMMON_META,
429
+ "num_params": 19386354,
430
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
431
+ "_metrics": {
432
+ "COCO-val2017": {
433
+ "box_map": 32.8,
434
+ }
435
+ },
436
+ "_ops": 4.494,
437
+ "_file_size": 74.239,
438
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
439
+ },
440
+ )
441
+ DEFAULT = COCO_V1
442
+
443
+
444
+ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
445
+ COCO_V1 = Weights(
446
+ url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
447
+ transforms=ObjectDetection,
448
+ meta={
449
+ **_COMMON_META,
450
+ "num_params": 19386354,
451
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
452
+ "_metrics": {
453
+ "COCO-val2017": {
454
+ "box_map": 22.8,
455
+ }
456
+ },
457
+ "_ops": 0.719,
458
+ "_file_size": 74.239,
459
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
460
+ },
461
+ )
462
+ DEFAULT = COCO_V1
463
+
464
+
465
+ @register_model()
466
+ @handle_legacy_interface(
467
+ weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
468
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
469
+ )
470
+ def fasterrcnn_resnet50_fpn(
471
+ *,
472
+ weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
473
+ progress: bool = True,
474
+ num_classes: Optional[int] = None,
475
+ weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
476
+ trainable_backbone_layers: Optional[int] = None,
477
+ **kwargs: Any,
478
+ ) -> FasterRCNN:
479
+ """
480
+ Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
481
+ Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
482
+ paper.
483
+
484
+ .. betastatus:: detection module
485
+
486
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
487
+ image, and should be in ``0-1`` range. Different images can have different sizes.
488
+
489
+ The behavior of the model changes depending on whether it is in training or evaluation mode.
490
+
491
+ During training, the model expects both the input tensors and a targets (list of dictionary),
492
+ containing:
493
+
494
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
495
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
496
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
497
+
498
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
499
+ losses for both the RPN and the R-CNN.
500
+
501
+ During inference, the model requires only the input tensors, and returns the post-processed
502
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
503
+ follows, where ``N`` is the number of detections:
504
+
505
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
506
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
507
+ - labels (``Int64Tensor[N]``): the predicted labels for each detection
508
+ - scores (``Tensor[N]``): the scores of each detection
509
+
510
+ For more details on the output, you may refer to :ref:`instance_seg_output`.
511
+
512
+ Faster R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.
513
+
514
+ Example::
515
+
516
+ >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
517
+ >>> # For training
518
+ >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
519
+ >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
520
+ >>> labels = torch.randint(1, 91, (4, 11))
521
+ >>> images = list(image for image in images)
522
+ >>> targets = []
523
+ >>> for i in range(len(images)):
524
+ >>> d = {}
525
+ >>> d['boxes'] = boxes[i]
526
+ >>> d['labels'] = labels[i]
527
+ >>> targets.append(d)
528
+ >>> output = model(images, targets)
529
+ >>> # For inference
530
+ >>> model.eval()
531
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
532
+ >>> predictions = model(x)
533
+ >>>
534
+ >>> # optionally, if you want to export the model to ONNX:
535
+ >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
536
+
537
+ Args:
538
+ weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
539
+ pretrained weights to use. See
540
+ :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
541
+ more details, and possible values. By default, no pre-trained
542
+ weights are used.
543
+ progress (bool, optional): If True, displays a progress bar of the
544
+ download to stderr. Default is True.
545
+ num_classes (int, optional): number of output classes of the model (including the background)
546
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
547
+ pretrained weights for the backbone.
548
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
549
+ final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
550
+ trainable. If ``None`` is passed (the default) this value is set to 3.
551
+ **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
552
+ base class. Please refer to the `source code
553
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
554
+ for more details about this class.
555
+
556
+ .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
557
+ :members:
558
+ """
559
+ weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
560
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
561
+
562
+ if weights is not None:
563
+ weights_backbone = None
564
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
565
+ elif num_classes is None:
566
+ num_classes = 91
567
+
568
+ is_trained = weights is not None or weights_backbone is not None
569
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
570
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
571
+
572
+ backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
573
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
574
+ model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
575
+
576
+ if weights is not None:
577
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
578
+ if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
579
+ overwrite_eps(model, 0.0)
580
+
581
+ return model
582
+
583
+
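A minimal sketch of the num_classes handling in the builder above: with no weights and no explicit num_classes, the head defaults to 91 COCO classes; weights_backbone=None keeps the sketch download-free.

from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
# num_classes was not passed and no weights were given, so it defaults to 91.
print(model.roi_heads.box_predictor.cls_score.out_features)  # 91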
584
+ @register_model()
585
+ @handle_legacy_interface(
586
+ weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
587
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
588
+ )
589
+ def fasterrcnn_resnet50_fpn_v2(
590
+ *,
591
+ weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
592
+ progress: bool = True,
593
+ num_classes: Optional[int] = None,
594
+ weights_backbone: Optional[ResNet50_Weights] = None,
595
+ trainable_backbone_layers: Optional[int] = None,
596
+ **kwargs: Any,
597
+ ) -> FasterRCNN:
598
+ """
599
+ Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
600
+ Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
601
+
602
+ .. betastatus:: detection module
603
+
604
+ It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
605
+ :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
606
+ details.
607
+
608
+ Args:
609
+ weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
610
+ pretrained weights to use. See
611
+ :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
612
+ more details, and possible values. By default, no pre-trained
613
+ weights are used.
614
+ progress (bool, optional): If True, displays a progress bar of the
615
+ download to stderr. Default is True.
616
+ num_classes (int, optional): number of output classes of the model (including the background)
617
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
618
+ pretrained weights for the backbone.
619
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
620
+ final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
621
+ trainable. If ``None`` is passed (the default) this value is set to 3.
622
+ **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
623
+ base class. Please refer to the `source code
624
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
625
+ for more details about this class.
626
+
627
+ .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
628
+ :members:
629
+ """
630
+ weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
631
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
632
+
633
+ if weights is not None:
634
+ weights_backbone = None
635
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
636
+ elif num_classes is None:
637
+ num_classes = 91
638
+
639
+ is_trained = weights is not None or weights_backbone is not None
640
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
641
+
642
+ backbone = resnet50(weights=weights_backbone, progress=progress)
643
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
644
+ rpn_anchor_generator = _default_anchorgen()
645
+ rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
646
+ box_head = FastRCNNConvFCHead(
647
+ (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
648
+ )
649
+ model = FasterRCNN(
650
+ backbone,
651
+ num_classes=num_classes,
652
+ rpn_anchor_generator=rpn_anchor_generator,
653
+ rpn_head=rpn_head,
654
+ box_head=box_head,
655
+ **kwargs,
656
+ )
657
+
658
+ if weights is not None:
659
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
660
+
661
+ return model
662
+
663
+
664
+ def _fasterrcnn_mobilenet_v3_large_fpn(
665
+ *,
666
+ weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
667
+ progress: bool,
668
+ num_classes: Optional[int],
669
+ weights_backbone: Optional[MobileNet_V3_Large_Weights],
670
+ trainable_backbone_layers: Optional[int],
671
+ **kwargs: Any,
672
+ ) -> FasterRCNN:
673
+ if weights is not None:
674
+ weights_backbone = None
675
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
676
+ elif num_classes is None:
677
+ num_classes = 91
678
+
679
+ is_trained = weights is not None or weights_backbone is not None
680
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
681
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
682
+
683
+ backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
684
+ backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
685
+ anchor_sizes = (
686
+ (
687
+ 32,
688
+ 64,
689
+ 128,
690
+ 256,
691
+ 512,
692
+ ),
693
+ ) * 3
694
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
695
+ model = FasterRCNN(
696
+ backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
697
+ )
698
+
699
+ if weights is not None:
700
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
701
+
702
+ return model
703
+
704
+
705
+ @register_model()
706
+ @handle_legacy_interface(
707
+ weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
708
+ weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
709
+ )
710
+ def fasterrcnn_mobilenet_v3_large_320_fpn(
711
+ *,
712
+ weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
713
+ progress: bool = True,
714
+ num_classes: Optional[int] = None,
715
+ weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
716
+ trainable_backbone_layers: Optional[int] = None,
717
+ **kwargs: Any,
718
+ ) -> FasterRCNN:
719
+ """
720
+ Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
721
+
722
+ .. betastatus:: detection module
723
+
724
+ It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
725
+ :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
726
+ details.
727
+
728
+ Example::
729
+
730
+ >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
731
+ >>> model.eval()
732
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
733
+ >>> predictions = model(x)
734
+
735
+ Args:
736
+ weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
737
+ pretrained weights to use. See
738
+ :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
739
+ more details, and possible values. By default, no pre-trained
740
+ weights are used.
741
+ progress (bool, optional): If True, displays a progress bar of the
742
+ download to stderr. Default is True.
743
+ num_classes (int, optional): number of output classes of the model (including the background)
744
+ weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
745
+ pretrained weights for the backbone.
746
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
747
+ final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
748
+ trainable. If ``None`` is passed (the default) this value is set to 3.
749
+ **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
750
+ base class. Please refer to the `source code
751
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
752
+ for more details about this class.
753
+
754
+ .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
755
+ :members:
756
+ """
757
+ weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
758
+ weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
759
+
760
+ defaults = {
761
+ "min_size": 320,
762
+ "max_size": 640,
763
+ "rpn_pre_nms_top_n_test": 150,
764
+ "rpn_post_nms_top_n_test": 150,
765
+ "rpn_score_thresh": 0.05,
766
+ }
767
+
768
+ kwargs = {**defaults, **kwargs}
769
+ return _fasterrcnn_mobilenet_v3_large_fpn(
770
+ weights=weights,
771
+ progress=progress,
772
+ num_classes=num_classes,
773
+ weights_backbone=weights_backbone,
774
+ trainable_backbone_layers=trainable_backbone_layers,
775
+ **kwargs,
776
+ )
777
+
778
+
779
+ @register_model()
780
+ @handle_legacy_interface(
781
+ weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
782
+ weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
783
+ )
784
+ def fasterrcnn_mobilenet_v3_large_fpn(
785
+ *,
786
+ weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
787
+ progress: bool = True,
788
+ num_classes: Optional[int] = None,
789
+ weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
790
+ trainable_backbone_layers: Optional[int] = None,
791
+ **kwargs: Any,
792
+ ) -> FasterRCNN:
793
+ """
794
+ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
795
+
796
+ .. betastatus:: detection module
797
+
798
+ It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
799
+ :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
800
+ details.
801
+
802
+ Example::
803
+
804
+ >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
805
+ >>> model.eval()
806
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
807
+ >>> predictions = model(x)
808
+
809
+ Args:
810
+ weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
811
+ pretrained weights to use. See
812
+ :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
813
+ more details, and possible values. By default, no pre-trained
814
+ weights are used.
815
+ progress (bool, optional): If True, displays a progress bar of the
816
+ download to stderr. Default is True.
817
+ num_classes (int, optional): number of output classes of the model (including the background)
818
+ weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
819
+ pretrained weights for the backbone.
820
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
821
+ final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
822
+ trainable. If ``None`` is passed (the default) this value is set to 3.
823
+ **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
824
+ base class. Please refer to the `source code
825
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
826
+ for more details about this class.
827
+
828
+ .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
829
+ :members:
830
+ """
831
+ weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
832
+ weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
833
+
834
+ defaults = {
835
+ "rpn_score_thresh": 0.05,
836
+ }
837
+
838
+ kwargs = {**defaults, **kwargs}
839
+ return _fasterrcnn_mobilenet_v3_large_fpn(
840
+ weights=weights,
841
+ progress=progress,
842
+ num_classes=num_classes,
843
+ weights_backbone=weights_backbone,
844
+ trainable_backbone_layers=trainable_backbone_layers,
845
+ **kwargs,
846
+ )
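A standalone illustrative sketch (not part of the file above) of how the builders defined in this diff might be called; the 5-class fine-tuning setup, input sizes, and rpn_score_thresh value are assumptions for illustration, and torchvision must be installed:

import torch
from torchvision.models.detection import (
    fasterrcnn_mobilenet_v3_large_fpn,
    fasterrcnn_resnet50_fpn_v2,
)

# Randomly initialized v2 variant for fine-tuning on a hypothetical dataset with
# 4 object classes plus background (num_classes=5); no weights are downloaded.
model = fasterrcnn_resnet50_fpn_v2(
    weights=None,
    weights_backbone=None,
    num_classes=5,
    trainable_backbone_layers=3,
)

# Extra kwargs (e.g. rpn_score_thresh) are merged over the builder defaults and
# forwarded to the FasterRCNN base class.
mobile_model = fasterrcnn_mobilenet_v3_large_fpn(
    weights=None, weights_backbone=None, num_classes=5, rpn_score_thresh=0.1
)

model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 480, 640)])
print(predictions[0]["boxes"].shape)  # (num_detections, 4)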
.venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py ADDED
@@ -0,0 +1,775 @@
1
+ import math
2
+ import warnings
3
+ from collections import OrderedDict
4
+ from functools import partial
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+
10
+ from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss
11
+ from ...ops.feature_pyramid_network import LastLevelP6P7
12
+ from ...transforms._presets import ObjectDetection
13
+ from ...utils import _log_api_usage_once
14
+ from .._api import register_model, Weights, WeightsEnum
15
+ from .._meta import _COCO_CATEGORIES
16
+ from .._utils import _ovewrite_value_param, handle_legacy_interface
17
+ from ..resnet import resnet50, ResNet50_Weights
18
+ from . import _utils as det_utils
19
+ from .anchor_utils import AnchorGenerator
20
+ from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
21
+ from .transform import GeneralizedRCNNTransform
22
+
23
+
24
+ __all__ = [
25
+ "FCOS",
26
+ "FCOS_ResNet50_FPN_Weights",
27
+ "fcos_resnet50_fpn",
28
+ ]
29
+
30
+
31
+ class FCOSHead(nn.Module):
32
+ """
33
+ A regression and classification head for use in FCOS.
34
+
35
+ Args:
36
+ in_channels (int): number of channels of the input feature
37
+ num_anchors (int): number of anchors to be predicted
38
+ num_classes (int): number of classes to be predicted
39
+ num_convs (Optional[int]): number of conv layer of head. Default: 4.
40
+ """
41
+
42
+ __annotations__ = {
43
+ "box_coder": det_utils.BoxLinearCoder,
44
+ }
45
+
46
+ def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
47
+ super().__init__()
48
+ self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
49
+ self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
50
+ self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)
51
+
52
+ def compute_loss(
53
+ self,
54
+ targets: List[Dict[str, Tensor]],
55
+ head_outputs: Dict[str, Tensor],
56
+ anchors: List[Tensor],
57
+ matched_idxs: List[Tensor],
58
+ ) -> Dict[str, Tensor]:
59
+
60
+ cls_logits = head_outputs["cls_logits"] # [N, HWA, C]
61
+ bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4]
62
+ bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1]
63
+
64
+ all_gt_classes_targets = []
65
+ all_gt_boxes_targets = []
66
+ for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
67
+ if len(targets_per_image["labels"]) == 0:
68
+ gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
69
+ gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
70
+ else:
71
+ gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
72
+ gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
73
+ gt_classes_targets[matched_idxs_per_image < 0] = -1 # background
74
+ all_gt_classes_targets.append(gt_classes_targets)
75
+ all_gt_boxes_targets.append(gt_boxes_targets)
76
+
77
+ # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
78
+ all_gt_boxes_targets, all_gt_classes_targets, anchors = (
79
+ torch.stack(all_gt_boxes_targets),
80
+ torch.stack(all_gt_classes_targets),
81
+ torch.stack(anchors),
82
+ )
83
+
84
+ # compute foreground
85
+ foregroud_mask = all_gt_classes_targets >= 0
86
+ num_foreground = foregroud_mask.sum().item()
87
+
88
+ # classification loss
89
+ gt_classes_targets = torch.zeros_like(cls_logits)
90
+ gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
91
+ loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
92
+
93
+ # amp issue: pred_boxes need to convert float
94
+ pred_boxes = self.box_coder.decode(bbox_regression, anchors)
95
+
96
+ # regression loss: GIoU loss
97
+ loss_bbox_reg = generalized_box_iou_loss(
98
+ pred_boxes[foregroud_mask],
99
+ all_gt_boxes_targets[foregroud_mask],
100
+ reduction="sum",
101
+ )
102
+
103
+ # ctrness loss
104
+
105
+ bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)
106
+
107
+ if len(bbox_reg_targets) == 0:
108
+ gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
109
+ else:
110
+ left_right = bbox_reg_targets[:, :, [0, 2]]
111
+ top_bottom = bbox_reg_targets[:, :, [1, 3]]
112
+ gt_ctrness_targets = torch.sqrt(
113
+ (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
114
+ * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
115
+ )
116
+ pred_centerness = bbox_ctrness.squeeze(dim=2)
117
+ loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
118
+ pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
119
+ )
120
+
121
+ return {
122
+ "classification": loss_cls / max(1, num_foreground),
123
+ "bbox_regression": loss_bbox_reg / max(1, num_foreground),
124
+ "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
125
+ }
126
+
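For reference, a standalone sketch (not part of the file above) of the centerness target computed in compute_loss: for per-location regression targets (left, top, right, bottom) it equals sqrt(min(l, r)/max(l, r) * min(t, b)/max(t, b)), i.e. 1.0 at the box center and approaching 0 near the border. The numbers below are illustrative:

import torch

# Two example locations: one at the exact box center, one shifted horizontally.
bbox_reg_targets = torch.tensor([[4.0, 10.0, 4.0, 10.0],
                                 [1.0, 10.0, 7.0, 10.0]])
left_right = bbox_reg_targets[:, [0, 2]]
top_bottom = bbox_reg_targets[:, [1, 3]]
gt_ctrness_targets = torch.sqrt(
    (left_right.min(dim=-1).values / left_right.max(dim=-1).values)
    * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values)
)
print(gt_ctrness_targets)  # tensor([1.0000, 0.3780])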
127
+ def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
128
+ cls_logits = self.classification_head(x)
129
+ bbox_regression, bbox_ctrness = self.regression_head(x)
130
+ return {
131
+ "cls_logits": cls_logits,
132
+ "bbox_regression": bbox_regression,
133
+ "bbox_ctrness": bbox_ctrness,
134
+ }
135
+
136
+
137
+ class FCOSClassificationHead(nn.Module):
138
+ """
139
+ A classification head for use in FCOS.
140
+
141
+ Args:
142
+ in_channels (int): number of channels of the input feature.
143
+ num_anchors (int): number of anchors to be predicted.
144
+ num_classes (int): number of classes to be predicted.
145
+ num_convs (Optional[int]): number of conv layer. Default: 4.
146
+ prior_probability (Optional[float]): probability of prior. Default: 0.01.
147
+ norm_layer: Module specifying the normalization layer to use.
148
+ """
149
+
150
+ def __init__(
151
+ self,
152
+ in_channels: int,
153
+ num_anchors: int,
154
+ num_classes: int,
155
+ num_convs: int = 4,
156
+ prior_probability: float = 0.01,
157
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
158
+ ) -> None:
159
+ super().__init__()
160
+
161
+ self.num_classes = num_classes
162
+ self.num_anchors = num_anchors
163
+
164
+ if norm_layer is None:
165
+ norm_layer = partial(nn.GroupNorm, 32)
166
+
167
+ conv = []
168
+ for _ in range(num_convs):
169
+ conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
170
+ conv.append(norm_layer(in_channels))
171
+ conv.append(nn.ReLU())
172
+ self.conv = nn.Sequential(*conv)
173
+
174
+ for layer in self.conv.children():
175
+ if isinstance(layer, nn.Conv2d):
176
+ torch.nn.init.normal_(layer.weight, std=0.01)
177
+ torch.nn.init.constant_(layer.bias, 0)
178
+
179
+ self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
180
+ torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
181
+ torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
182
+
183
+ def forward(self, x: List[Tensor]) -> Tensor:
184
+ all_cls_logits = []
185
+
186
+ for features in x:
187
+ cls_logits = self.conv(features)
188
+ cls_logits = self.cls_logits(cls_logits)
189
+
190
+ # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
191
+ N, _, H, W = cls_logits.shape
192
+ cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
193
+ cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
194
+ cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, K)
195
+
196
+ all_cls_logits.append(cls_logits)
197
+
198
+ return torch.cat(all_cls_logits, dim=1)
199
+
200
+
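A tiny standalone sketch (not part of the file above) of the permute/reshape performed in the forward above, mapping (N, A * K, H, W) logits to (N, HWA, K); the shapes (one anchor, three classes, a 2x2 feature map) are illustrative:

import torch

N, A, K, H, W = 1, 1, 3, 2, 2
cls_logits = torch.arange(N * A * K * H * W, dtype=torch.float32).view(N, A * K, H, W)

out = cls_logits.view(N, A, K, H, W)   # split the anchor and class dimensions
out = out.permute(0, 3, 4, 1, 2)       # (N, H, W, A, K)
out = out.reshape(N, -1, K)            # (N, H*W*A, K)
print(out.shape)                       # torch.Size([1, 4, 3])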
201
+ class FCOSRegressionHead(nn.Module):
202
+ """
203
+ A regression head for use in FCOS, which combines regression branch and center-ness branch.
204
+ This combination yields better performance.
205
+
206
+ Reference: `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
207
+
208
+ Args:
209
+ in_channels (int): number of channels of the input feature
210
+ num_anchors (int): number of anchors to be predicted
211
+ num_convs (Optional[int]): number of conv layer. Default: 4.
212
+ norm_layer: Module specifying the normalization layer to use.
213
+ """
214
+
215
+ def __init__(
216
+ self,
217
+ in_channels: int,
218
+ num_anchors: int,
219
+ num_convs: int = 4,
220
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
221
+ ):
222
+ super().__init__()
223
+
224
+ if norm_layer is None:
225
+ norm_layer = partial(nn.GroupNorm, 32)
226
+
227
+ conv = []
228
+ for _ in range(num_convs):
229
+ conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
230
+ conv.append(norm_layer(in_channels))
231
+ conv.append(nn.ReLU())
232
+ self.conv = nn.Sequential(*conv)
233
+
234
+ self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
235
+ self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
236
+ for layer in [self.bbox_reg, self.bbox_ctrness]:
237
+ torch.nn.init.normal_(layer.weight, std=0.01)
238
+ torch.nn.init.zeros_(layer.bias)
239
+
240
+ for layer in self.conv.children():
241
+ if isinstance(layer, nn.Conv2d):
242
+ torch.nn.init.normal_(layer.weight, std=0.01)
243
+ torch.nn.init.zeros_(layer.bias)
244
+
245
+ def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
246
+ all_bbox_regression = []
247
+ all_bbox_ctrness = []
248
+
249
+ for features in x:
250
+ bbox_feature = self.conv(features)
251
+ bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
252
+ bbox_ctrness = self.bbox_ctrness(bbox_feature)
253
+
254
+ # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
255
+ N, _, H, W = bbox_regression.shape
256
+ bbox_regression = bbox_regression.view(N, -1, 4, H, W)
257
+ bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
258
+ bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
259
+ all_bbox_regression.append(bbox_regression)
260
+
261
+ # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
262
+ bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
263
+ bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
264
+ bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
265
+ all_bbox_ctrness.append(bbox_ctrness)
266
+
267
+ return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)
268
+
269
+
270
+ class FCOS(nn.Module):
271
+ """
272
+ Implements FCOS.
273
+
274
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
275
+ image, and should be in 0-1 range. Different images can have different sizes.
276
+
277
+ The behavior of the model changes depending on if it is in training or evaluation mode.
278
+
279
+ During training, the model expects both the input tensors and targets (list of dictionary),
280
+ containing:
281
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
282
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
283
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
284
+
285
+ The model returns a Dict[Tensor] during training, containing the classification, regression
286
+ and centerness losses.
287
+
288
+ During inference, the model requires only the input tensors, and returns the post-processed
289
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
290
+ follows:
291
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
292
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
293
+ - labels (Int64Tensor[N]): the predicted labels for each image
294
+ - scores (Tensor[N]): the scores for each prediction
295
+
296
+ Args:
297
+ backbone (nn.Module): the network used to compute the features for the model.
298
+ It should contain an out_channels attribute, which indicates the number of output
299
+ channels that each feature map has (and it should be the same for all feature maps).
300
+ The backbone should return a single Tensor or an OrderedDict[Tensor].
301
+ num_classes (int): number of output classes of the model (including the background).
302
+ min_size (int): Images are rescaled before feeding them to the backbone:
303
+ we attempt to preserve the aspect ratio and scale the shorter edge
304
+ to ``min_size``. If the resulting longer edge exceeds ``max_size``,
305
+ then downscale so that the longer edge does not exceed ``max_size``.
306
+ This may result in the shorter edge being lower than ``min_size``.
307
+ max_size (int): See ``min_size``.
308
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
309
+ They are generally the mean values of the dataset on which the backbone has been
310
+ trained.
311
+ image_std (Tuple[float, float, float]): std values used for input normalization.
312
+ They are generally the std values of the dataset on which the backbone has been trained.
313
+ anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
314
+ maps. For FCOS, only one anchor is set per position of each level, with width and height equal to
315
+ the stride of the feature map and aspect ratio 1.0, so the anchor center is equivalent to the point
316
+ in the FCOS paper.
317
+ head (nn.Module): Module run on top of the feature pyramid.
318
+ Defaults to a module containing a classification and regression module.
319
+ center_sampling_radius (float): radius of the "center" of a ground-truth box,
320
+ within which all anchor points are labeled positive.
321
+ score_thresh (float): Score threshold used for postprocessing the detections.
322
+ nms_thresh (float): NMS threshold used for postprocessing the detections.
323
+ detections_per_img (int): Number of best detections to keep after NMS.
324
+ topk_candidates (int): Number of best detections to keep before NMS.
325
+
326
+ Example:
327
+
328
+ >>> import torch
329
+ >>> import torchvision
330
+ >>> from torchvision.models.detection import FCOS
331
+ >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
332
+ >>> # load a pre-trained model for classification and return
333
+ >>> # only the features
334
+ >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
335
+ >>> # FCOS needs to know the number of
336
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
337
+ >>> # so we need to add it here
338
+ >>> backbone.out_channels = 1280
339
+ >>>
340
+ >>> # FCOS uses a single anchor per spatial location, so we define
341
+ >>> # one anchor size per feature-map level and a single aspect
342
+ >>> # ratio of 1.0. We have a Tuple[Tuple[int]] because each feature
343
+ >>> # map could potentially have a different size and
344
+ >>> # aspect ratio
345
+ >>> anchor_generator = AnchorGenerator(
346
+ >>> sizes=((8,), (16,), (32,), (64,), (128,)),
347
+ >>> aspect_ratios=((1.0,),)
348
+ >>> )
349
+ >>>
350
+ >>> # put the pieces together inside a FCOS model
351
+ >>> model = FCOS(
352
+ >>> backbone,
353
+ >>> num_classes=80,
354
+ >>> anchor_generator=anchor_generator,
355
+ >>> )
356
+ >>> model.eval()
357
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
358
+ >>> predictions = model(x)
359
+ """
360
+
361
+ __annotations__ = {
362
+ "box_coder": det_utils.BoxLinearCoder,
363
+ }
364
+
365
+ def __init__(
366
+ self,
367
+ backbone: nn.Module,
368
+ num_classes: int,
369
+ # transform parameters
370
+ min_size: int = 800,
371
+ max_size: int = 1333,
372
+ image_mean: Optional[List[float]] = None,
373
+ image_std: Optional[List[float]] = None,
374
+ # Anchor parameters
375
+ anchor_generator: Optional[AnchorGenerator] = None,
376
+ head: Optional[nn.Module] = None,
377
+ center_sampling_radius: float = 1.5,
378
+ score_thresh: float = 0.2,
379
+ nms_thresh: float = 0.6,
380
+ detections_per_img: int = 100,
381
+ topk_candidates: int = 1000,
382
+ **kwargs,
383
+ ):
384
+ super().__init__()
385
+ _log_api_usage_once(self)
386
+
387
+ if not hasattr(backbone, "out_channels"):
388
+ raise ValueError(
389
+ "backbone should contain an attribute out_channels "
390
+ "specifying the number of output channels (assumed to be the "
391
+ "same for all the levels)"
392
+ )
393
+ self.backbone = backbone
394
+
395
+ if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
396
+ raise TypeError(
397
+ f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}"
398
+ )
399
+
400
+ if anchor_generator is None:
401
+ anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map
402
+ aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor
403
+ anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
404
+ self.anchor_generator = anchor_generator
405
+ if self.anchor_generator.num_anchors_per_location()[0] != 1:
406
+ raise ValueError(
407
+ f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
408
+ )
409
+
410
+ if head is None:
411
+ head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
412
+ self.head = head
413
+
414
+ self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
415
+
416
+ if image_mean is None:
417
+ image_mean = [0.485, 0.456, 0.406]
418
+ if image_std is None:
419
+ image_std = [0.229, 0.224, 0.225]
420
+ self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
421
+
422
+ self.center_sampling_radius = center_sampling_radius
423
+ self.score_thresh = score_thresh
424
+ self.nms_thresh = nms_thresh
425
+ self.detections_per_img = detections_per_img
426
+ self.topk_candidates = topk_candidates
427
+
428
+ # used only on torchscript mode
429
+ self._has_warned = False
430
+
431
+ @torch.jit.unused
432
+ def eager_outputs(
433
+ self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
434
+ ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
435
+ if self.training:
436
+ return losses
437
+
438
+ return detections
439
+
440
+ def compute_loss(
441
+ self,
442
+ targets: List[Dict[str, Tensor]],
443
+ head_outputs: Dict[str, Tensor],
444
+ anchors: List[Tensor],
445
+ num_anchors_per_level: List[int],
446
+ ) -> Dict[str, Tensor]:
447
+ matched_idxs = []
448
+ for anchors_per_image, targets_per_image in zip(anchors, targets):
449
+ if targets_per_image["boxes"].numel() == 0:
450
+ matched_idxs.append(
451
+ torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
452
+ )
453
+ continue
454
+
455
+ gt_boxes = targets_per_image["boxes"]
456
+ gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2 # Nx2
457
+ anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N
458
+ anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
459
+ # center sampling: anchor point must be close enough to gt center.
460
+ pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
461
+ dim=2
462
+ ).values < self.center_sampling_radius * anchor_sizes[:, None]
463
+ # compute pairwise distance between N points and M boxes
464
+ x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
465
+ x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M)
466
+ pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M)
467
+
468
+ # anchor point must be inside gt
469
+ pairwise_match &= pairwise_dist.min(dim=2).values > 0
470
+
471
+ # each anchor is only responsible for certain scale range.
472
+ lower_bound = anchor_sizes * 4
473
+ lower_bound[: num_anchors_per_level[0]] = 0
474
+ upper_bound = anchor_sizes * 8
475
+ upper_bound[-num_anchors_per_level[-1] :] = float("inf")
476
+ pairwise_dist = pairwise_dist.max(dim=2).values
477
+ pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])
478
+
479
+ # match the GT box with minimum area, if there are multiple GT matches
480
+ gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # N
481
+ pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
482
+ min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match
483
+ matched_idx[min_values < 1e-5] = -1 # unmatched anchors are assigned -1
484
+
485
+ matched_idxs.append(matched_idx)
486
+
487
+ return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
488
+
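The lower_bound/upper_bound logic above assigns each pyramid level a scale range of roughly (4 * stride, 8 * stride) on the largest regression distance, with the first level open at 0 and the last at infinity. A standalone sketch (not part of the file above), using the default strides (8, 16, 32, 64, 128) purely for illustration:

strides = [8, 16, 32, 64, 128]
bounds = []
for i, s in enumerate(strides):
    lower = 0 if i == 0 else 4 * s
    upper = float("inf") if i == len(strides) - 1 else 8 * s
    bounds.append((lower, upper))
print(bounds)  # [(0, 64), (64, 128), (128, 256), (256, 512), (512, inf)]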
489
+ def postprocess_detections(
490
+ self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
491
+ ) -> List[Dict[str, Tensor]]:
492
+ class_logits = head_outputs["cls_logits"]
493
+ box_regression = head_outputs["bbox_regression"]
494
+ box_ctrness = head_outputs["bbox_ctrness"]
495
+
496
+ num_images = len(image_shapes)
497
+
498
+ detections: List[Dict[str, Tensor]] = []
499
+
500
+ for index in range(num_images):
501
+ box_regression_per_image = [br[index] for br in box_regression]
502
+ logits_per_image = [cl[index] for cl in class_logits]
503
+ box_ctrness_per_image = [bc[index] for bc in box_ctrness]
504
+ anchors_per_image, image_shape = anchors[index], image_shapes[index]
505
+
506
+ image_boxes = []
507
+ image_scores = []
508
+ image_labels = []
509
+
510
+ for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
511
+ box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
512
+ ):
513
+ num_classes = logits_per_level.shape[-1]
514
+
515
+ # remove low scoring boxes
516
+ scores_per_level = torch.sqrt(
517
+ torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
518
+ ).flatten()
519
+ keep_idxs = scores_per_level > self.score_thresh
520
+ scores_per_level = scores_per_level[keep_idxs]
521
+ topk_idxs = torch.where(keep_idxs)[0]
522
+
523
+ # keep only topk scoring predictions
524
+ num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
525
+ scores_per_level, idxs = scores_per_level.topk(num_topk)
526
+ topk_idxs = topk_idxs[idxs]
527
+
528
+ anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
529
+ labels_per_level = topk_idxs % num_classes
530
+
531
+ boxes_per_level = self.box_coder.decode(
532
+ box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
533
+ )
534
+ boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
535
+
536
+ image_boxes.append(boxes_per_level)
537
+ image_scores.append(scores_per_level)
538
+ image_labels.append(labels_per_level)
539
+
540
+ image_boxes = torch.cat(image_boxes, dim=0)
541
+ image_scores = torch.cat(image_scores, dim=0)
542
+ image_labels = torch.cat(image_labels, dim=0)
543
+
544
+ # non-maximum suppression
545
+ keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
546
+ keep = keep[: self.detections_per_img]
547
+
548
+ detections.append(
549
+ {
550
+ "boxes": image_boxes[keep],
551
+ "scores": image_scores[keep],
552
+ "labels": image_labels[keep],
553
+ }
554
+ )
555
+
556
+ return detections
557
+
558
+ def forward(
559
+ self,
560
+ images: List[Tensor],
561
+ targets: Optional[List[Dict[str, Tensor]]] = None,
562
+ ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
563
+ """
564
+ Args:
565
+ images (list[Tensor]): images to be processed
566
+ targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
567
+
568
+ Returns:
569
+ result (list[BoxList] or dict[Tensor]): the output from the model.
570
+ During training, it returns a dict[Tensor] which contains the losses.
571
+ During testing, it returns a list[BoxList] that contains additional fields
572
+ like `scores`, `labels` and `mask` (for Mask R-CNN models).
573
+ """
574
+ if self.training:
575
+
576
+ if targets is None:
577
+ torch._assert(False, "targets should not be none when in training mode")
578
+ else:
579
+ for target in targets:
580
+ boxes = target["boxes"]
581
+ torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
582
+ torch._assert(
583
+ len(boxes.shape) == 2 and boxes.shape[-1] == 4,
584
+ f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
585
+ )
586
+
587
+ original_image_sizes: List[Tuple[int, int]] = []
588
+ for img in images:
589
+ val = img.shape[-2:]
590
+ torch._assert(
591
+ len(val) == 2,
592
+ f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
593
+ )
594
+ original_image_sizes.append((val[0], val[1]))
595
+
596
+ # transform the input
597
+ images, targets = self.transform(images, targets)
598
+
599
+ # Check for degenerate boxes
600
+ if targets is not None:
601
+ for target_idx, target in enumerate(targets):
602
+ boxes = target["boxes"]
603
+ degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
604
+ if degenerate_boxes.any():
605
+ # print the first degenerate box
606
+ bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
607
+ degen_bb: List[float] = boxes[bb_idx].tolist()
608
+ torch._assert(
609
+ False,
610
+ f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
611
+ )
612
+
613
+ # get the features from the backbone
614
+ features = self.backbone(images.tensors)
615
+ if isinstance(features, torch.Tensor):
616
+ features = OrderedDict([("0", features)])
617
+
618
+ features = list(features.values())
619
+
620
+ # compute the fcos heads outputs using the features
621
+ head_outputs = self.head(features)
622
+
623
+ # create the set of anchors
624
+ anchors = self.anchor_generator(images, features)
625
+ # recover level sizes
626
+ num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
627
+
628
+ losses = {}
629
+ detections: List[Dict[str, Tensor]] = []
630
+ if self.training:
631
+ if targets is None:
632
+ torch._assert(False, "targets should not be none when in training mode")
633
+ else:
634
+ # compute the losses
635
+ losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
636
+ else:
637
+ # split outputs per level
638
+ split_head_outputs: Dict[str, List[Tensor]] = {}
639
+ for k in head_outputs:
640
+ split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
641
+ split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
642
+
643
+ # compute the detections
644
+ detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
645
+ detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
646
+
647
+ if torch.jit.is_scripting():
648
+ if not self._has_warned:
649
+ warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
650
+ self._has_warned = True
651
+ return losses, detections
652
+ return self.eager_outputs(losses, detections)
653
+
654
+
655
+ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
656
+ COCO_V1 = Weights(
657
+ url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
658
+ transforms=ObjectDetection,
659
+ meta={
660
+ "num_params": 32269600,
661
+ "categories": _COCO_CATEGORIES,
662
+ "min_size": (1, 1),
663
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
664
+ "_metrics": {
665
+ "COCO-val2017": {
666
+ "box_map": 39.2,
667
+ }
668
+ },
669
+ "_ops": 128.207,
670
+ "_file_size": 123.608,
671
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
672
+ },
673
+ )
674
+ DEFAULT = COCO_V1
675
+
676
+
677
+ @register_model()
678
+ @handle_legacy_interface(
679
+ weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
680
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
681
+ )
682
+ def fcos_resnet50_fpn(
683
+ *,
684
+ weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
685
+ progress: bool = True,
686
+ num_classes: Optional[int] = None,
687
+ weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
688
+ trainable_backbone_layers: Optional[int] = None,
689
+ **kwargs: Any,
690
+ ) -> FCOS:
691
+ """
692
+ Constructs a FCOS model with a ResNet-50-FPN backbone.
693
+
694
+ .. betastatus:: detection module
695
+
696
+ Reference: `FCOS: Fully Convolutional One-Stage Object Detection <https://arxiv.org/abs/1904.01355>`_.
697
+ `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
698
+
699
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
700
+ image, and should be in ``0-1`` range. Different images can have different sizes.
701
+
702
+ The behavior of the model changes depending on if it is in training or evaluation mode.
703
+
704
+ During training, the model expects both the input tensors and targets (list of dictionary),
705
+ containing:
706
+
707
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
708
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
709
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
710
+
711
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
712
+ losses.
713
+
714
+ During inference, the model requires only the input tensors, and returns the post-processed
715
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
716
+ follows, where ``N`` is the number of detections:
717
+
718
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
719
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
720
+ - labels (``Int64Tensor[N]``): the predicted labels for each detection
721
+ - scores (``Tensor[N]``): the scores of each detection
722
+
723
+ For more details on the output, you may refer to :ref:`instance_seg_output`.
724
+
725
+ Example:
726
+
727
+ >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)
728
+ >>> model.eval()
729
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
730
+ >>> predictions = model(x)
731
+
732
+ Args:
733
+ weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The
734
+ pretrained weights to use. See
735
+ :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`
736
+ below for more details, and possible values. By default, no
737
+ pre-trained weights are used.
738
+ progress (bool): If True, displays a progress bar of the download to stderr
739
+ num_classes (int, optional): number of output classes of the model (including the background)
740
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
741
+ the backbone.
742
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
743
+ from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
744
+ trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
745
+ **kwargs: parameters passed to the ``torchvision.models.detection.FCOS``
746
+ base class. Please refer to the `source code
747
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/fcos.py>`_
748
+ for more details about this class.
749
+
750
+ .. autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights
751
+ :members:
752
+ """
753
+ weights = FCOS_ResNet50_FPN_Weights.verify(weights)
754
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
755
+
756
+ if weights is not None:
757
+ weights_backbone = None
758
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
759
+ elif num_classes is None:
760
+ num_classes = 91
761
+
762
+ is_trained = weights is not None or weights_backbone is not None
763
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
764
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
765
+
766
+ backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
767
+ backbone = _resnet_fpn_extractor(
768
+ backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
769
+ )
770
+ model = FCOS(backbone, num_classes, **kwargs)
771
+
772
+ if weights is not None:
773
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
774
+
775
+ return model
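A minimal standalone sketch (not part of the file above) of the two output modes of the builder defined here: a loss dict in training mode and per-image detection dicts in eval mode. The dummy boxes, labels, and the reduced min_size/max_size are illustrative assumptions, and torchvision must be installed:

import torch
from torchvision.models.detection import fcos_resnet50_fpn

# Randomly initialized model (no downloads); min_size/max_size shrink the resize
# target so the example runs quickly on CPU.
model = fcos_resnet50_fpn(weights=None, weights_backbone=None, num_classes=91,
                          min_size=320, max_size=512)

images = [torch.rand(3, 320, 320), torch.rand(3, 300, 400)]
targets = [
    {"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]), "labels": torch.tensor([1])},
    {"boxes": torch.tensor([[30.0, 30.0, 150.0, 180.0]]), "labels": torch.tensor([7])},
]

model.train()
losses = model(images, targets)
print(sorted(losses))          # ['bbox_ctrness', 'bbox_regression', 'classification']

model.eval()
with torch.no_grad():
    detections = model(images)
print(list(detections[0]))     # ['boxes', 'scores', 'labels']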
.venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py ADDED
@@ -0,0 +1,118 @@
1
+ """
2
+ Implements the Generalized R-CNN framework
3
+ """
4
+
5
+ import warnings
6
+ from collections import OrderedDict
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from ...utils import _log_api_usage_once
13
+
14
+
15
+ class GeneralizedRCNN(nn.Module):
16
+ """
17
+ Main class for Generalized R-CNN.
18
+
19
+ Args:
20
+ backbone (nn.Module):
21
+ rpn (nn.Module):
22
+ roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
23
+ detections / masks from it.
24
+ transform (nn.Module): performs the data transformation from the inputs to feed into
25
+ the model
26
+ """
27
+
28
+ def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
29
+ super().__init__()
30
+ _log_api_usage_once(self)
31
+ self.transform = transform
32
+ self.backbone = backbone
33
+ self.rpn = rpn
34
+ self.roi_heads = roi_heads
35
+ # used only on torchscript mode
36
+ self._has_warned = False
37
+
38
+ @torch.jit.unused
39
+ def eager_outputs(self, losses, detections):
40
+ # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
41
+ if self.training:
42
+ return losses
43
+
44
+ return detections
45
+
46
+ def forward(self, images, targets=None):
47
+ # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
48
+ """
49
+ Args:
50
+ images (list[Tensor]): images to be processed
51
+ targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
52
+
53
+ Returns:
54
+ result (list[BoxList] or dict[Tensor]): the output from the model.
55
+ During training, it returns a dict[Tensor] which contains the losses.
56
+ During testing, it returns list[BoxList] contains additional fields
57
+ like `scores`, `labels` and `mask` (for Mask R-CNN models).
58
+
59
+ """
60
+ if self.training:
61
+ if targets is None:
62
+ torch._assert(False, "targets should not be none when in training mode")
63
+ else:
64
+ for target in targets:
65
+ boxes = target["boxes"]
66
+ if isinstance(boxes, torch.Tensor):
67
+ torch._assert(
68
+ len(boxes.shape) == 2 and boxes.shape[-1] == 4,
69
+ f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
70
+ )
71
+ else:
72
+ torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
73
+
74
+ original_image_sizes: List[Tuple[int, int]] = []
75
+ for img in images:
76
+ val = img.shape[-2:]
77
+ torch._assert(
78
+ len(val) == 2,
79
+ f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
80
+ )
81
+ original_image_sizes.append((val[0], val[1]))
82
+
83
+ images, targets = self.transform(images, targets)
84
+
85
+ # Check for degenerate boxes
86
+ # TODO: Move this to a function
87
+ if targets is not None:
88
+ for target_idx, target in enumerate(targets):
89
+ boxes = target["boxes"]
90
+ degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
91
+ if degenerate_boxes.any():
92
+ # print the first degenerate box
93
+ bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
94
+ degen_bb: List[float] = boxes[bb_idx].tolist()
95
+ torch._assert(
96
+ False,
97
+ "All bounding boxes should have positive height and width."
98
+ f" Found invalid box {degen_bb} for target at index {target_idx}.",
99
+ )
100
+
101
+ features = self.backbone(images.tensors)
102
+ if isinstance(features, torch.Tensor):
103
+ features = OrderedDict([("0", features)])
104
+ proposals, proposal_losses = self.rpn(images, features, targets)
105
+ detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
106
+ detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator]
107
+
108
+ losses = {}
109
+ losses.update(detector_losses)
110
+ losses.update(proposal_losses)
111
+
112
+ if torch.jit.is_scripting():
113
+ if not self._has_warned:
114
+ warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
115
+ self._has_warned = True
116
+ return losses, detections
117
+ else:
118
+ return self.eager_outputs(losses, detections)
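A standalone sketch (not part of the file above) of the degenerate-box check in the forward above: a target box is rejected when x2 <= x1 or y2 <= y1. The coordinates below are illustrative:

import torch

boxes = torch.tensor([
    [10.0, 10.0, 50.0, 60.0],   # valid
    [30.0, 40.0, 30.0, 80.0],   # zero width  -> degenerate
    [25.0, 90.0, 60.0, 70.0],   # y2 < y1     -> degenerate
])
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
print(degenerate_boxes.any(dim=1))  # tensor([False,  True,  True])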
.venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py ADDED
@@ -0,0 +1,25 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
+ class ImageList:
8
+ """
9
+ Structure that holds a list of images (of possibly
10
+ varying sizes) as a single tensor.
11
+ This works by padding the images to the same size,
12
+ and storing in a field the original sizes of each image
13
+
14
+ Args:
15
+ tensors (tensor): Tensor containing images.
16
+ image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
17
+ """
18
+
19
+ def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
20
+ self.tensors = tensors
21
+ self.image_sizes = image_sizes
22
+
23
+ def to(self, device: torch.device) -> "ImageList":
24
+ cast_tensor = self.tensors.to(device)
25
+ return ImageList(cast_tensor, self.image_sizes)
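A minimal standalone sketch (not part of the file above) of how an ImageList is typically produced: two images of different sizes are padded into one batch tensor while their pre-padding sizes are kept for postprocessing (inside torchvision this packing is performed by GeneralizedRCNNTransform; the sizes below are illustrative):

import torch
from torchvision.models.detection.image_list import ImageList

img_a = torch.rand(3, 300, 400)
img_b = torch.rand(3, 280, 350)

batch = torch.zeros(2, 3, 300, 400)   # pad both images to the largest H x W
batch[0] = img_a
batch[1, :, :280, :350] = img_b

image_list = ImageList(batch, [(300, 400), (280, 350)])
print(image_list.tensors.shape)       # torch.Size([2, 3, 300, 400])
print(image_list.image_sizes)         # [(300, 400), (280, 350)]
image_list = image_list.to(torch.device("cpu"))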
.venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py ADDED
@@ -0,0 +1,474 @@
1
+ from typing import Any, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torchvision.ops import MultiScaleRoIAlign
6
+
7
+ from ...ops import misc as misc_nn_ops
8
+ from ...transforms._presets import ObjectDetection
9
+ from .._api import register_model, Weights, WeightsEnum
10
+ from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
11
+ from .._utils import _ovewrite_value_param, handle_legacy_interface
12
+ from ..resnet import resnet50, ResNet50_Weights
13
+ from ._utils import overwrite_eps
14
+ from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
15
+ from .faster_rcnn import FasterRCNN
16
+
17
+
18
+ __all__ = [
19
+ "KeypointRCNN",
20
+ "KeypointRCNN_ResNet50_FPN_Weights",
21
+ "keypointrcnn_resnet50_fpn",
22
+ ]
23
+
24
+
25
+ class KeypointRCNN(FasterRCNN):
26
+ """
27
+ Implements Keypoint R-CNN.
28
+
29
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
30
+ image, and should be in 0-1 range. Different images can have different sizes.
31
+
32
+ The behavior of the model changes depending on if it is in training or evaluation mode.
33
+
34
+ During training, the model expects both the input tensors and targets (list of dictionary),
35
+ containing:
36
+
37
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
38
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
39
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
40
+ - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
41
+ format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
42
+
43
+ The model returns a Dict[Tensor] during training, containing the classification and regression
44
+ losses for both the RPN and the R-CNN, and the keypoint loss.
45
+
46
+ During inference, the model requires only the input tensors, and returns the post-processed
47
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
48
+ follows:
49
+
50
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
51
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
52
+ - labels (Int64Tensor[N]): the predicted labels for each image
53
+ - scores (Tensor[N]): the scores for each prediction
54
+ - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
55
+
56
+ Args:
57
+ backbone (nn.Module): the network used to compute the features for the model.
58
+ It should contain an out_channels attribute, which indicates the number of output
59
+ channels that each feature map has (and it should be the same for all feature maps).
60
+ The backbone should return a single Tensor or an OrderedDict[Tensor].
61
+ num_classes (int): number of output classes of the model (including the background).
62
+ If box_predictor is specified, num_classes should be None.
63
+ min_size (int): Images are rescaled before feeding them to the backbone:
64
+ we attempt to preserve the aspect ratio and scale the shorter edge
65
+ to ``min_size``. If the resulting longer edge exceeds ``max_size``,
66
+ then downscale so that the longer edge does not exceed ``max_size``.
67
+ This may result in the shorter edge being lower than ``min_size``.
68
+ max_size (int): See ``min_size``.
69
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
70
+ They are generally the mean values of the dataset on which the backbone has been
71
+ trained.
72
+ image_std (Tuple[float, float, float]): std values used for input normalization.
73
+ They are generally the std values of the dataset on which the backbone has been trained.
74
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
75
+ maps.
76
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
77
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
78
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
79
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
80
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
81
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
82
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
83
+ considered as positive during training of the RPN.
84
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
85
+ considered as negative during training of the RPN.
86
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
87
+ for computing the loss
88
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
89
+ of the RPN
90
+ rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
91
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
92
+ the locations indicated by the bounding boxes
93
+ box_head (nn.Module): module that takes the cropped feature maps as input
94
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
95
+ classification logits and box regression deltas.
96
+ box_score_thresh (float): during inference, only return proposals with a classification score
97
+ greater than box_score_thresh
98
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
99
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
100
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
101
+ considered as positive during training of the classification head
102
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
103
+ considered as negative during training of the classification head
104
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
105
+ classification head
106
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
107
+ of the classification head
108
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
109
+ bounding boxes
110
+ keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
111
+ the locations indicated by the bounding boxes, which will be used for the keypoint head.
112
+ keypoint_head (nn.Module): module that takes the cropped feature maps as input
113
+ keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
114
+ heatmap logits
115
+
116
+ Example::
117
+
118
+ >>> import torch
119
+ >>> import torchvision
120
+ >>> from torchvision.models.detection import KeypointRCNN
121
+ >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
122
+ >>>
123
+ >>> # load a pre-trained model for classification and return
124
+ >>> # only the features
125
+ >>> backbone = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT).features
126
+ >>> # KeypointRCNN needs to know the number of
127
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
128
+ >>> # so we need to add it here
129
+ >>> backbone.out_channels = 1280
130
+ >>>
131
+ >>> # let's make the RPN generate 5 x 3 anchors per spatial
132
+ >>> # location, with 5 different sizes and 3 different aspect
133
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
134
+ >>> # map could potentially have different sizes and
135
+ >>> # aspect ratios
136
+ >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
137
+ >>> aspect_ratios=((0.5, 1.0, 2.0),))
138
+ >>>
139
+ >>> # let's define what are the feature maps that we will
140
+ >>> # use to perform the region of interest cropping, as well as
141
+ >>> # the size of the crop after rescaling.
142
+ >>> # if your backbone returns a Tensor, featmap_names is expected to
143
+ >>> # be ['0']. More generally, the backbone should return an
144
+ >>> # OrderedDict[Tensor], and in featmap_names you can choose which
145
+ >>> # feature maps to use.
146
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
147
+ >>> output_size=7,
148
+ >>> sampling_ratio=2)
149
+ >>>
150
+ >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
151
+ >>> output_size=14,
152
+ >>> sampling_ratio=2)
153
+ >>> # put the pieces together inside a KeypointRCNN model
154
+ >>> model = KeypointRCNN(backbone,
155
+ >>> num_classes=2,
156
+ >>> rpn_anchor_generator=anchor_generator,
157
+ >>> box_roi_pool=roi_pooler,
158
+ >>> keypoint_roi_pool=keypoint_roi_pooler)
159
+ >>> model.eval()
160
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
161
+ >>> predictions = model(x)
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ backbone,
167
+ num_classes=None,
168
+ # transform parameters
169
+ min_size=None,
170
+ max_size=1333,
171
+ image_mean=None,
172
+ image_std=None,
173
+ # RPN parameters
174
+ rpn_anchor_generator=None,
175
+ rpn_head=None,
176
+ rpn_pre_nms_top_n_train=2000,
177
+ rpn_pre_nms_top_n_test=1000,
178
+ rpn_post_nms_top_n_train=2000,
179
+ rpn_post_nms_top_n_test=1000,
180
+ rpn_nms_thresh=0.7,
181
+ rpn_fg_iou_thresh=0.7,
182
+ rpn_bg_iou_thresh=0.3,
183
+ rpn_batch_size_per_image=256,
184
+ rpn_positive_fraction=0.5,
185
+ rpn_score_thresh=0.0,
186
+ # Box parameters
187
+ box_roi_pool=None,
188
+ box_head=None,
189
+ box_predictor=None,
190
+ box_score_thresh=0.05,
191
+ box_nms_thresh=0.5,
192
+ box_detections_per_img=100,
193
+ box_fg_iou_thresh=0.5,
194
+ box_bg_iou_thresh=0.5,
195
+ box_batch_size_per_image=512,
196
+ box_positive_fraction=0.25,
197
+ bbox_reg_weights=None,
198
+ # keypoint parameters
199
+ keypoint_roi_pool=None,
200
+ keypoint_head=None,
201
+ keypoint_predictor=None,
202
+ num_keypoints=None,
203
+ **kwargs,
204
+ ):
205
+
206
+ if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
207
+ raise TypeError(
208
+ "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
209
+ )
210
+ if min_size is None:
211
+ min_size = (640, 672, 704, 736, 768, 800)
212
+
213
+ if num_keypoints is not None:
214
+ if keypoint_predictor is not None:
215
+ raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
216
+ else:
217
+ num_keypoints = 17
218
+
219
+ out_channels = backbone.out_channels
220
+
221
+ if keypoint_roi_pool is None:
222
+ keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
223
+
224
+ if keypoint_head is None:
225
+ keypoint_layers = tuple(512 for _ in range(8))
226
+ keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)
227
+
228
+ if keypoint_predictor is None:
229
+ keypoint_dim_reduced = 512 # == keypoint_layers[-1]
230
+ keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
231
+
232
+ super().__init__(
233
+ backbone,
234
+ num_classes,
235
+ # transform parameters
236
+ min_size,
237
+ max_size,
238
+ image_mean,
239
+ image_std,
240
+ # RPN-specific parameters
241
+ rpn_anchor_generator,
242
+ rpn_head,
243
+ rpn_pre_nms_top_n_train,
244
+ rpn_pre_nms_top_n_test,
245
+ rpn_post_nms_top_n_train,
246
+ rpn_post_nms_top_n_test,
247
+ rpn_nms_thresh,
248
+ rpn_fg_iou_thresh,
249
+ rpn_bg_iou_thresh,
250
+ rpn_batch_size_per_image,
251
+ rpn_positive_fraction,
252
+ rpn_score_thresh,
253
+ # Box parameters
254
+ box_roi_pool,
255
+ box_head,
256
+ box_predictor,
257
+ box_score_thresh,
258
+ box_nms_thresh,
259
+ box_detections_per_img,
260
+ box_fg_iou_thresh,
261
+ box_bg_iou_thresh,
262
+ box_batch_size_per_image,
263
+ box_positive_fraction,
264
+ bbox_reg_weights,
265
+ **kwargs,
266
+ )
267
+
268
+ self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
269
+ self.roi_heads.keypoint_head = keypoint_head
270
+ self.roi_heads.keypoint_predictor = keypoint_predictor
271
+
272
+
273
+ class KeypointRCNNHeads(nn.Sequential):
274
+ def __init__(self, in_channels, layers):
275
+ d = []
276
+ next_feature = in_channels
277
+ for out_channels in layers:
278
+ d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
279
+ d.append(nn.ReLU(inplace=True))
280
+ next_feature = out_channels
281
+ super().__init__(*d)
282
+ for m in self.children():
283
+ if isinstance(m, nn.Conv2d):
284
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
285
+ nn.init.constant_(m.bias, 0)
286
+
287
+
288
+ class KeypointRCNNPredictor(nn.Module):
289
+ def __init__(self, in_channels, num_keypoints):
290
+ super().__init__()
291
+ input_features = in_channels
292
+ deconv_kernel = 4
293
+ self.kps_score_lowres = nn.ConvTranspose2d(
294
+ input_features,
295
+ num_keypoints,
296
+ deconv_kernel,
297
+ stride=2,
298
+ padding=deconv_kernel // 2 - 1,
299
+ )
300
+ nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
301
+ nn.init.constant_(self.kps_score_lowres.bias, 0)
302
+ self.up_scale = 2
303
+ self.out_channels = num_keypoints
304
+
305
+ def forward(self, x):
306
+ x = self.kps_score_lowres(x)
307
+ return torch.nn.functional.interpolate(
308
+ x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
309
+ )
310
+
311
+
312
+ _COMMON_META = {
313
+ "categories": _COCO_PERSON_CATEGORIES,
314
+ "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
315
+ "min_size": (1, 1),
316
+ }
317
+
318
+
319
+ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
320
+ COCO_LEGACY = Weights(
321
+ url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
322
+ transforms=ObjectDetection,
323
+ meta={
324
+ **_COMMON_META,
325
+ "num_params": 59137258,
326
+ "recipe": "https://github.com/pytorch/vision/issues/1606",
327
+ "_metrics": {
328
+ "COCO-val2017": {
329
+ "box_map": 50.6,
330
+ "kp_map": 61.1,
331
+ }
332
+ },
333
+ "_ops": 133.924,
334
+ "_file_size": 226.054,
335
+ "_docs": """
336
+ These weights were produced by following a similar training recipe as in the paper but use a checkpoint
337
+ from an early epoch.
338
+ """,
339
+ },
340
+ )
341
+ COCO_V1 = Weights(
342
+ url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
343
+ transforms=ObjectDetection,
344
+ meta={
345
+ **_COMMON_META,
346
+ "num_params": 59137258,
347
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
348
+ "_metrics": {
349
+ "COCO-val2017": {
350
+ "box_map": 54.6,
351
+ "kp_map": 65.0,
352
+ }
353
+ },
354
+ "_ops": 137.42,
355
+ "_file_size": 226.054,
356
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
357
+ },
358
+ )
359
+ DEFAULT = COCO_V1
360
+
361
+
362
+ @register_model()
363
+ @handle_legacy_interface(
364
+ weights=(
365
+ "pretrained",
366
+ lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
367
+ if kwargs["pretrained"] == "legacy"
368
+ else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
369
+ ),
370
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
371
+ )
372
+ def keypointrcnn_resnet50_fpn(
373
+ *,
374
+ weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
375
+ progress: bool = True,
376
+ num_classes: Optional[int] = None,
377
+ num_keypoints: Optional[int] = None,
378
+ weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
379
+ trainable_backbone_layers: Optional[int] = None,
380
+ **kwargs: Any,
381
+ ) -> KeypointRCNN:
382
+ """
383
+ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
384
+
385
+ .. betastatus:: detection module
386
+
387
+ Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
388
+
389
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
390
+ image, and should be in ``0-1`` range. Different images can have different sizes.
391
+
392
+ The behavior of the model changes depending on if it is in training or evaluation mode.
393
+
394
+ During training, the model expects both the input tensors and targets (list of dictionaries),
395
+ containing:
396
+
397
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
398
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
399
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
400
+ - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
401
+ format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
402
+
403
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
404
+ losses for both the RPN and the R-CNN, and the keypoint loss.
405
+
406
+ During inference, the model requires only the input tensors, and returns the post-processed
407
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
408
+ follows, where ``N`` is the number of detected instances:
409
+
410
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
411
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
412
+ - labels (``Int64Tensor[N]``): the predicted labels for each instance
413
+ - scores (``Tensor[N]``): the scores for each instance
414
+ - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
415
+
416
+ For more details on the output, you may refer to :ref:`instance_seg_output`.
417
+
418
+ Keypoint R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.
419
+
420
+ Example::
421
+
422
+ >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
423
+ >>> model.eval()
424
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
425
+ >>> predictions = model(x)
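+ >>> # Illustrative only (assumes at least one detection): with the default COCO weights,
+ >>> # each instance carries K = 17 keypoints in the ``[x, y, visibility]`` format above
+ >>> keypoints = predictions[0]["keypoints"] # FloatTensor[N, K, 3]
+ >>> boxes, scores = predictions[0]["boxes"], predictions[0]["scores"]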
426
+ >>>
427
+ >>> # optionally, if you want to export the model to ONNX:
428
+ >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
429
+
430
+ Args:
431
+ weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
432
+ pretrained weights to use. See
433
+ :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
434
+ below for more details, and possible values. By default, no
435
+ pre-trained weights are used.
436
+ progress (bool): If True, displays a progress bar of the download to stderr
437
+ num_classes (int, optional): number of output classes of the model (including the background)
438
+ num_keypoints (int, optional): number of keypoints
439
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
440
+ pretrained weights for the backbone.
441
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
442
+ Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
443
+ passed (the default) this value is set to 3.
444
+
445
+ .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
446
+ :members:
447
+ """
448
+ weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
449
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
450
+
451
+ if weights is not None:
452
+ weights_backbone = None
453
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
454
+ num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
455
+ else:
456
+ if num_classes is None:
457
+ num_classes = 2
458
+ if num_keypoints is None:
459
+ num_keypoints = 17
460
+
461
+ is_trained = weights is not None or weights_backbone is not None
462
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
463
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
464
+
465
+ backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
466
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
467
+ model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
468
+
469
+ if weights is not None:
470
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
471
+ if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
472
+ overwrite_eps(model, 0.0)
473
+
474
+ return model
.venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py ADDED
@@ -0,0 +1,590 @@
1
+ from collections import OrderedDict
2
+ from typing import Any, Callable, Optional
3
+
4
+ from torch import nn
5
+ from torchvision.ops import MultiScaleRoIAlign
6
+
7
+ from ...ops import misc as misc_nn_ops
8
+ from ...transforms._presets import ObjectDetection
9
+ from .._api import register_model, Weights, WeightsEnum
10
+ from .._meta import _COCO_CATEGORIES
11
+ from .._utils import _ovewrite_value_param, handle_legacy_interface
12
+ from ..resnet import resnet50, ResNet50_Weights
13
+ from ._utils import overwrite_eps
14
+ from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
15
+ from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
16
+
17
+
18
+ __all__ = [
19
+ "MaskRCNN",
20
+ "MaskRCNN_ResNet50_FPN_Weights",
21
+ "MaskRCNN_ResNet50_FPN_V2_Weights",
22
+ "maskrcnn_resnet50_fpn",
23
+ "maskrcnn_resnet50_fpn_v2",
24
+ ]
25
+
26
+
27
+ class MaskRCNN(FasterRCNN):
28
+ """
29
+ Implements Mask R-CNN.
30
+
31
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
32
+ image, and should be in 0-1 range. Different images can have different sizes.
33
+
34
+ The behavior of the model changes depending on if it is in training or evaluation mode.
35
+
36
+ During training, the model expects both the input tensors and targets (list of dictionaries),
37
+ containing:
38
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
39
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
40
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
41
+ - masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance
42
+
43
+ The model returns a Dict[Tensor] during training, containing the classification and regression
44
+ losses for both the RPN and the R-CNN, and the mask loss.
45
+
46
+ During inference, the model requires only the input tensors, and returns the post-processed
47
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
48
+ follows:
49
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
50
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
51
+ - labels (Int64Tensor[N]): the predicted labels for each image
52
+ - scores (Tensor[N]): the scores for each prediction
53
+ - masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
54
+ obtain the final segmentation masks, the soft masks can be thresholded, generally
55
+ with a value of 0.5 (mask >= 0.5)
56
+
57
+ Args:
58
+ backbone (nn.Module): the network used to compute the features for the model.
59
+ It should contain an out_channels attribute, which indicates the number of output
60
+ channels that each feature map has (and it should be the same for all feature maps).
61
+ The backbone should return a single Tensor or an OrderedDict[Tensor].
62
+ num_classes (int): number of output classes of the model (including the background).
63
+ If box_predictor is specified, num_classes should be None.
64
+ min_size (int): Images are rescaled before feeding them to the backbone:
65
+ we attempt to preserve the aspect ratio and scale the shorter edge
66
+ to ``min_size``. If the resulting longer edge exceeds ``max_size``,
67
+ then downscale so that the longer edge does not exceed ``max_size``.
68
+ This may result in the shorter edge being lower than ``min_size``.
69
+ max_size (int): See ``min_size``.
70
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
71
+ They are generally the mean values of the dataset on which the backbone has been trained
72
+ on
73
+ image_std (Tuple[float, float, float]): std values used for input normalization.
74
+ They are generally the std values of the dataset on which the backbone has been trained on
75
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
76
+ maps.
77
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
78
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
79
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
80
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
81
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
82
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
83
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
84
+ considered as positive during training of the RPN.
85
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
86
+ considered as negative during training of the RPN.
87
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
88
+ for computing the loss
89
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
90
+ of the RPN
91
+ rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
92
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
93
+ the locations indicated by the bounding boxes
94
+ box_head (nn.Module): module that takes the cropped feature maps as input
95
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
96
+ classification logits and box regression deltas.
97
+ box_score_thresh (float): during inference, only return proposals with a classification score
98
+ greater than box_score_thresh
99
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
100
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
101
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
102
+ considered as positive during training of the classification head
103
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
104
+ considered as negative during training of the classification head
105
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
106
+ classification head
107
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
108
+ of the classification head
109
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
110
+ bounding boxes
111
+ mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
112
+ the locations indicated by the bounding boxes, which will be used for the mask head.
113
+ mask_head (nn.Module): module that takes the cropped feature maps as input
114
+ mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
115
+ segmentation mask logits
116
+
117
+ Example::
118
+
119
+ >>> import torch
120
+ >>> import torchvision
121
+ >>> from torchvision.models.detection import MaskRCNN
122
+ >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
123
+ >>>
124
+ >>> # load a pre-trained model for classification and return
125
+ >>> # only the features
126
+ >>> backbone = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT).features
127
+ >>> # MaskRCNN needs to know the number of
128
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280
129
+ >>> # so we need to add it here,
130
+ >>> backbone.out_channels = 1280
131
+ >>>
132
+ >>> # let's make the RPN generate 5 x 3 anchors per spatial
133
+ >>> # location, with 5 different sizes and 3 different aspect
134
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
135
+ >>> # map could potentially have different sizes and
136
+ >>> # aspect ratios
137
+ >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
138
+ >>> aspect_ratios=((0.5, 1.0, 2.0),))
139
+ >>>
140
+ >>> # let's define what are the feature maps that we will
141
+ >>> # use to perform the region of interest cropping, as well as
142
+ >>> # the size of the crop after rescaling.
143
+ >>> # if your backbone returns a Tensor, featmap_names is expected to
144
+ >>> # be ['0']. More generally, the backbone should return an
145
+ >>> # OrderedDict[Tensor], and in featmap_names you can choose which
146
+ >>> # feature maps to use.
147
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
148
+ >>> output_size=7,
149
+ >>> sampling_ratio=2)
150
+ >>>
151
+ >>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
152
+ >>> output_size=14,
153
+ >>> sampling_ratio=2)
154
+ >>> # put the pieces together inside a MaskRCNN model
155
+ >>> model = MaskRCNN(backbone,
156
+ >>> num_classes=2,
157
+ >>> rpn_anchor_generator=anchor_generator,
158
+ >>> box_roi_pool=roi_pooler,
159
+ >>> mask_roi_pool=mask_roi_pooler)
160
+ >>> model.eval()
161
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
162
+ >>> predictions = model(x)
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ backbone,
168
+ num_classes=None,
169
+ # transform parameters
170
+ min_size=800,
171
+ max_size=1333,
172
+ image_mean=None,
173
+ image_std=None,
174
+ # RPN parameters
175
+ rpn_anchor_generator=None,
176
+ rpn_head=None,
177
+ rpn_pre_nms_top_n_train=2000,
178
+ rpn_pre_nms_top_n_test=1000,
179
+ rpn_post_nms_top_n_train=2000,
180
+ rpn_post_nms_top_n_test=1000,
181
+ rpn_nms_thresh=0.7,
182
+ rpn_fg_iou_thresh=0.7,
183
+ rpn_bg_iou_thresh=0.3,
184
+ rpn_batch_size_per_image=256,
185
+ rpn_positive_fraction=0.5,
186
+ rpn_score_thresh=0.0,
187
+ # Box parameters
188
+ box_roi_pool=None,
189
+ box_head=None,
190
+ box_predictor=None,
191
+ box_score_thresh=0.05,
192
+ box_nms_thresh=0.5,
193
+ box_detections_per_img=100,
194
+ box_fg_iou_thresh=0.5,
195
+ box_bg_iou_thresh=0.5,
196
+ box_batch_size_per_image=512,
197
+ box_positive_fraction=0.25,
198
+ bbox_reg_weights=None,
199
+ # Mask parameters
200
+ mask_roi_pool=None,
201
+ mask_head=None,
202
+ mask_predictor=None,
203
+ **kwargs,
204
+ ):
205
+
206
+ if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
207
+ raise TypeError(
208
+ f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
209
+ )
210
+
211
+ if num_classes is not None:
212
+ if mask_predictor is not None:
213
+ raise ValueError("num_classes should be None when mask_predictor is specified")
214
+
215
+ out_channels = backbone.out_channels
216
+
217
+ if mask_roi_pool is None:
218
+ mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
219
+
220
+ if mask_head is None:
221
+ mask_layers = (256, 256, 256, 256)
222
+ mask_dilation = 1
223
+ mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
224
+
225
+ if mask_predictor is None:
226
+ mask_predictor_in_channels = 256 # == mask_layers[-1]
227
+ mask_dim_reduced = 256
228
+ mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
229
+
230
+ super().__init__(
231
+ backbone,
232
+ num_classes,
233
+ # transform parameters
234
+ min_size,
235
+ max_size,
236
+ image_mean,
237
+ image_std,
238
+ # RPN-specific parameters
239
+ rpn_anchor_generator,
240
+ rpn_head,
241
+ rpn_pre_nms_top_n_train,
242
+ rpn_pre_nms_top_n_test,
243
+ rpn_post_nms_top_n_train,
244
+ rpn_post_nms_top_n_test,
245
+ rpn_nms_thresh,
246
+ rpn_fg_iou_thresh,
247
+ rpn_bg_iou_thresh,
248
+ rpn_batch_size_per_image,
249
+ rpn_positive_fraction,
250
+ rpn_score_thresh,
251
+ # Box parameters
252
+ box_roi_pool,
253
+ box_head,
254
+ box_predictor,
255
+ box_score_thresh,
256
+ box_nms_thresh,
257
+ box_detections_per_img,
258
+ box_fg_iou_thresh,
259
+ box_bg_iou_thresh,
260
+ box_batch_size_per_image,
261
+ box_positive_fraction,
262
+ bbox_reg_weights,
263
+ **kwargs,
264
+ )
265
+
266
+ self.roi_heads.mask_roi_pool = mask_roi_pool
267
+ self.roi_heads.mask_head = mask_head
268
+ self.roi_heads.mask_predictor = mask_predictor
269
+
270
+
271
+ class MaskRCNNHeads(nn.Sequential):
272
+ _version = 2
273
+
274
+ def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None):
275
+ """
276
+ Args:
277
+ in_channels (int): number of input channels
278
+ layers (list): feature dimensions of each FCN layer
279
+ dilation (int): dilation rate of kernel
280
+ norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
281
+ """
282
+ blocks = []
283
+ next_feature = in_channels
284
+ for layer_features in layers:
285
+ blocks.append(
286
+ misc_nn_ops.Conv2dNormActivation(
287
+ next_feature,
288
+ layer_features,
289
+ kernel_size=3,
290
+ stride=1,
291
+ padding=dilation,
292
+ dilation=dilation,
293
+ norm_layer=norm_layer,
294
+ )
295
+ )
296
+ next_feature = layer_features
297
+
298
+ super().__init__(*blocks)
299
+ for layer in self.modules():
300
+ if isinstance(layer, nn.Conv2d):
301
+ nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
302
+ if layer.bias is not None:
303
+ nn.init.zeros_(layer.bias)
304
+
305
+ def _load_from_state_dict(
306
+ self,
307
+ state_dict,
308
+ prefix,
309
+ local_metadata,
310
+ strict,
311
+ missing_keys,
312
+ unexpected_keys,
313
+ error_msgs,
314
+ ):
315
+ version = local_metadata.get("version", None)
316
+
317
+ if version is None or version < 2:
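+ # Remap pre-v2 checkpoint keys: the flat "mask_fcn{i+1}" conv layers became
+ # Conv2dNormActivation blocks, so e.g. "mask_fcn1.weight" now lives at "0.0.weight".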
318
+ num_blocks = len(self)
319
+ for i in range(num_blocks):
320
+ for type in ["weight", "bias"]:
321
+ old_key = f"{prefix}mask_fcn{i+1}.{type}"
322
+ new_key = f"{prefix}{i}.0.{type}"
323
+ if old_key in state_dict:
324
+ state_dict[new_key] = state_dict.pop(old_key)
325
+
326
+ super()._load_from_state_dict(
327
+ state_dict,
328
+ prefix,
329
+ local_metadata,
330
+ strict,
331
+ missing_keys,
332
+ unexpected_keys,
333
+ error_msgs,
334
+ )
335
+
336
+
337
+ class MaskRCNNPredictor(nn.Sequential):
338
+ def __init__(self, in_channels, dim_reduced, num_classes):
339
+ super().__init__(
340
+ OrderedDict(
341
+ [
342
+ ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
343
+ ("relu", nn.ReLU(inplace=True)),
344
+ ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
345
+ ]
346
+ )
347
+ )
348
+
349
+ for name, param in self.named_parameters():
350
+ if "weight" in name:
351
+ nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
352
+ # elif "bias" in name:
353
+ # nn.init.constant_(param, 0)
354
+
355
+
356
+ _COMMON_META = {
357
+ "categories": _COCO_CATEGORIES,
358
+ "min_size": (1, 1),
359
+ }
360
+
361
+
362
+ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
363
+ COCO_V1 = Weights(
364
+ url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
365
+ transforms=ObjectDetection,
366
+ meta={
367
+ **_COMMON_META,
368
+ "num_params": 44401393,
369
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
370
+ "_metrics": {
371
+ "COCO-val2017": {
372
+ "box_map": 37.9,
373
+ "mask_map": 34.6,
374
+ }
375
+ },
376
+ "_ops": 134.38,
377
+ "_file_size": 169.84,
378
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
379
+ },
380
+ )
381
+ DEFAULT = COCO_V1
382
+
383
+
384
+ class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
385
+ COCO_V1 = Weights(
386
+ url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth",
387
+ transforms=ObjectDetection,
388
+ meta={
389
+ **_COMMON_META,
390
+ "num_params": 46359409,
391
+ "recipe": "https://github.com/pytorch/vision/pull/5773",
392
+ "_metrics": {
393
+ "COCO-val2017": {
394
+ "box_map": 47.4,
395
+ "mask_map": 41.8,
396
+ }
397
+ },
398
+ "_ops": 333.577,
399
+ "_file_size": 177.219,
400
+ "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
401
+ },
402
+ )
403
+ DEFAULT = COCO_V1
404
+
405
+
406
+ @register_model()
407
+ @handle_legacy_interface(
408
+ weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
409
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
410
+ )
411
+ def maskrcnn_resnet50_fpn(
412
+ *,
413
+ weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
414
+ progress: bool = True,
415
+ num_classes: Optional[int] = None,
416
+ weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
417
+ trainable_backbone_layers: Optional[int] = None,
418
+ **kwargs: Any,
419
+ ) -> MaskRCNN:
420
+ """Mask R-CNN model with a ResNet-50-FPN backbone from the `Mask R-CNN
421
+ <https://arxiv.org/abs/1703.06870>`_ paper.
422
+
423
+ .. betastatus:: detection module
424
+
425
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
426
+ image, and should be in ``0-1`` range. Different images can have different sizes.
427
+
428
+ The behavior of the model changes depending on if it is in training or evaluation mode.
429
+
430
+ During training, the model expects both the input tensors and targets (list of dictionaries),
431
+ containing:
432
+
433
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
434
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
435
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
436
+ - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
437
+
438
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
439
+ losses for both the RPN and the R-CNN, and the mask loss.
440
+
441
+ During inference, the model requires only the input tensors, and returns the post-processed
442
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
443
+ follows, where ``N`` is the number of detected instances:
444
+
445
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
446
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
447
+ - labels (``Int64Tensor[N]``): the predicted labels for each instance
448
+ - scores (``Tensor[N]``): the scores for each instance
449
+ - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
450
+ obtain the final segmentation masks, the soft masks can be thresholded, generally
451
+ with a value of 0.5 (``mask >= 0.5``)
452
+
453
+ For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.
454
+
455
+ Mask R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.
456
+
457
+ Example::
458
+
459
+ >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
460
+ >>> model.eval()
461
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
462
+ >>> predictions = model(x)
463
+ >>>
464
+ >>> # optionally, if you want to export the model to ONNX:
465
+ >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11)
466
+
467
+ Args:
468
+ weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights`, optional): The
469
+ pretrained weights to use. See
470
+ :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights` below for
471
+ more details, and possible values. By default, no pre-trained
472
+ weights are used.
473
+ progress (bool, optional): If True, displays a progress bar of the
474
+ download to stderr. Default is True.
475
+ num_classes (int, optional): number of output classes of the model (including the background)
476
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
477
+ pretrained weights for the backbone.
478
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
479
+ final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
480
+ trainable. If ``None`` is passed (the default) this value is set to 3.
481
+ **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
482
+ base class. Please refer to the `source code
483
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
484
+ for more details about this class.
485
+
486
+ .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights
487
+ :members:
488
+ """
489
+ weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
490
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
491
+
492
+ if weights is not None:
493
+ weights_backbone = None
494
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
495
+ elif num_classes is None:
496
+ num_classes = 91
497
+
498
+ is_trained = weights is not None or weights_backbone is not None
499
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
500
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
501
+
502
+ backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
503
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
504
+ model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
505
+
506
+ if weights is not None:
507
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
508
+ if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
509
+ overwrite_eps(model, 0.0)
510
+
511
+ return model
512
+
513
+
514
+ @register_model()
515
+ @handle_legacy_interface(
516
+ weights=("pretrained", MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
517
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
518
+ )
519
+ def maskrcnn_resnet50_fpn_v2(
520
+ *,
521
+ weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
522
+ progress: bool = True,
523
+ num_classes: Optional[int] = None,
524
+ weights_backbone: Optional[ResNet50_Weights] = None,
525
+ trainable_backbone_layers: Optional[int] = None,
526
+ **kwargs: Any,
527
+ ) -> MaskRCNN:
528
+ """Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer
529
+ Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`_ paper.
530
+
531
+ .. betastatus:: detection module
532
+
533
+ See :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details.
534
+
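+ Example::
+
+ >>> # A minimal sketch mirroring the v1 example above
+ >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT)
+ >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+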
535
+ Args:
536
+ weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights`, optional): The
537
+ pretrained weights to use. See
538
+ :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights` below for
539
+ more details, and possible values. By default, no pre-trained
540
+ weights are used.
541
+ progress (bool, optional): If True, displays a progress bar of the
542
+ download to stderr. Default is True.
543
+ num_classes (int, optional): number of output classes of the model (including the background)
544
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
545
+ pretrained weights for the backbone.
546
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
547
+ final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
548
+ trainable. If ``None`` is passed (the default) this value is set to 3.
549
+ **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
550
+ base class. Please refer to the `source code
551
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
552
+ for more details about this class.
553
+
554
+ .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights
555
+ :members:
556
+ """
557
+ weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights)
558
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
559
+
560
+ if weights is not None:
561
+ weights_backbone = None
562
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
563
+ elif num_classes is None:
564
+ num_classes = 91
565
+
566
+ is_trained = weights is not None or weights_backbone is not None
567
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
568
+
569
+ backbone = resnet50(weights=weights_backbone, progress=progress)
570
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
571
+ rpn_anchor_generator = _default_anchorgen()
572
+ rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
573
+ box_head = FastRCNNConvFCHead(
574
+ (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
575
+ )
576
+ mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
577
+ model = MaskRCNN(
578
+ backbone,
579
+ num_classes=num_classes,
580
+ rpn_anchor_generator=rpn_anchor_generator,
581
+ rpn_head=rpn_head,
582
+ box_head=box_head,
583
+ mask_head=mask_head,
584
+ **kwargs,
585
+ )
586
+
587
+ if weights is not None:
588
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
589
+
590
+ return model
.venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py ADDED
@@ -0,0 +1,903 @@
1
+ import math
2
+ import warnings
3
+ from collections import OrderedDict
4
+ from functools import partial
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+
10
+ from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
11
+ from ...ops.feature_pyramid_network import LastLevelP6P7
12
+ from ...transforms._presets import ObjectDetection
13
+ from ...utils import _log_api_usage_once
14
+ from .._api import register_model, Weights, WeightsEnum
15
+ from .._meta import _COCO_CATEGORIES
16
+ from .._utils import _ovewrite_value_param, handle_legacy_interface
17
+ from ..resnet import resnet50, ResNet50_Weights
18
+ from . import _utils as det_utils
19
+ from ._utils import _box_loss, overwrite_eps
20
+ from .anchor_utils import AnchorGenerator
21
+ from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
22
+ from .transform import GeneralizedRCNNTransform
23
+
24
+
25
+ __all__ = [
26
+ "RetinaNet",
27
+ "RetinaNet_ResNet50_FPN_Weights",
28
+ "RetinaNet_ResNet50_FPN_V2_Weights",
29
+ "retinanet_resnet50_fpn",
30
+ "retinanet_resnet50_fpn_v2",
31
+ ]
32
+
33
+
34
+ def _sum(x: List[Tensor]) -> Tensor:
35
+ res = x[0]
36
+ for i in x[1:]:
37
+ res = res + i
38
+ return res
39
+
40
+
41
+ def _v1_to_v2_weights(state_dict, prefix):
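+ # Remap pre-v2 RetinaNet head checkpoints: conv layers stored flat as conv.0 / conv.2 /
+ # conv.4 / conv.6 move to conv.{i}.0 inside the new Conv2dNormActivation blocks.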
42
+ for i in range(4):
43
+ for type in ["weight", "bias"]:
44
+ old_key = f"{prefix}conv.{2*i}.{type}"
45
+ new_key = f"{prefix}conv.{i}.0.{type}"
46
+ if old_key in state_dict:
47
+ state_dict[new_key] = state_dict.pop(old_key)
48
+
49
+
50
+ def _default_anchorgen():
51
+ anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
52
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
53
+ anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
54
+ return anchor_generator
55
+
56
+
57
+ class RetinaNetHead(nn.Module):
58
+ """
59
+ A regression and classification head for use in RetinaNet.
60
+
61
+ Args:
62
+ in_channels (int): number of channels of the input feature
63
+ num_anchors (int): number of anchors to be predicted
64
+ num_classes (int): number of classes to be predicted
65
+ norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
66
+ """
67
+
68
+ def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
69
+ super().__init__()
70
+ self.classification_head = RetinaNetClassificationHead(
71
+ in_channels, num_anchors, num_classes, norm_layer=norm_layer
72
+ )
73
+ self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)
74
+
75
+ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
76
+ # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
77
+ return {
78
+ "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
79
+ "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
80
+ }
81
+
82
+ def forward(self, x):
83
+ # type: (List[Tensor]) -> Dict[str, Tensor]
84
+ return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}
85
+
86
+
87
+ class RetinaNetClassificationHead(nn.Module):
88
+ """
89
+ A classification head for use in RetinaNet.
90
+
91
+ Args:
92
+ in_channels (int): number of channels of the input feature
93
+ num_anchors (int): number of anchors to be predicted
94
+ num_classes (int): number of classes to be predicted
95
+ norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
96
+ """
97
+
98
+ _version = 2
99
+
100
+ def __init__(
101
+ self,
102
+ in_channels,
103
+ num_anchors,
104
+ num_classes,
105
+ prior_probability=0.01,
106
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
107
+ ):
108
+ super().__init__()
109
+
110
+ conv = []
111
+ for _ in range(4):
112
+ conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
113
+ self.conv = nn.Sequential(*conv)
114
+
115
+ for layer in self.conv.modules():
116
+ if isinstance(layer, nn.Conv2d):
117
+ torch.nn.init.normal_(layer.weight, std=0.01)
118
+ if layer.bias is not None:
119
+ torch.nn.init.constant_(layer.bias, 0)
120
+
121
+ self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
122
+ torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
123
+ torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
124
+
125
+ self.num_classes = num_classes
126
+ self.num_anchors = num_anchors
127
+
128
+ # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
129
+ # TorchScript doesn't support class attributes.
130
+ # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
131
+ self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
132
+
133
+ def _load_from_state_dict(
134
+ self,
135
+ state_dict,
136
+ prefix,
137
+ local_metadata,
138
+ strict,
139
+ missing_keys,
140
+ unexpected_keys,
141
+ error_msgs,
142
+ ):
143
+ version = local_metadata.get("version", None)
144
+
145
+ if version is None or version < 2:
146
+ _v1_to_v2_weights(state_dict, prefix)
147
+
148
+ super()._load_from_state_dict(
149
+ state_dict,
150
+ prefix,
151
+ local_metadata,
152
+ strict,
153
+ missing_keys,
154
+ unexpected_keys,
155
+ error_msgs,
156
+ )
157
+
158
+ def compute_loss(self, targets, head_outputs, matched_idxs):
159
+ # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
160
+ losses = []
161
+
162
+ cls_logits = head_outputs["cls_logits"]
163
+
164
+ for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
165
+ # determine only the foreground
166
+ foreground_idxs_per_image = matched_idxs_per_image >= 0
167
+ num_foreground = foreground_idxs_per_image.sum()
168
+
169
+ # create the target classification
170
+ gt_classes_target = torch.zeros_like(cls_logits_per_image)
171
+ gt_classes_target[
172
+ foreground_idxs_per_image,
173
+ targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
174
+ ] = 1.0
175
+
176
+ # find indices for which anchors should be ignored
177
+ valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
178
+
179
+ # compute the classification loss
180
+ losses.append(
181
+ sigmoid_focal_loss(
182
+ cls_logits_per_image[valid_idxs_per_image],
183
+ gt_classes_target[valid_idxs_per_image],
184
+ reduction="sum",
185
+ )
186
+ / max(1, num_foreground)
187
+ )
188
+
189
+ return _sum(losses) / len(targets)
190
+
191
+ def forward(self, x):
192
+ # type: (List[Tensor]) -> Tensor
193
+ all_cls_logits = []
194
+
195
+ for features in x:
196
+ cls_logits = self.conv(features)
197
+ cls_logits = self.cls_logits(cls_logits)
198
+
199
+ # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
200
+ N, _, H, W = cls_logits.shape
201
+ cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
202
+ cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
203
+ cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, K)
204
+
205
+ all_cls_logits.append(cls_logits)
206
+
207
+ return torch.cat(all_cls_logits, dim=1)
208
+
209
+
210
+ class RetinaNetRegressionHead(nn.Module):
211
+ """
212
+ A regression head for use in RetinaNet.
213
+
214
+ Args:
215
+ in_channels (int): number of channels of the input feature
216
+ num_anchors (int): number of anchors to be predicted
217
+ norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
218
+ """
219
+
220
+ _version = 2
221
+
222
+ __annotations__ = {
223
+ "box_coder": det_utils.BoxCoder,
224
+ }
225
+
226
+ def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
227
+ super().__init__()
228
+
229
+ conv = []
230
+ for _ in range(4):
231
+ conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
232
+ self.conv = nn.Sequential(*conv)
233
+
234
+ self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
235
+ torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
236
+ torch.nn.init.zeros_(self.bbox_reg.bias)
237
+
238
+ for layer in self.conv.modules():
239
+ if isinstance(layer, nn.Conv2d):
240
+ torch.nn.init.normal_(layer.weight, std=0.01)
241
+ if layer.bias is not None:
242
+ torch.nn.init.zeros_(layer.bias)
243
+
244
+ self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
245
+ self._loss_type = "l1"
246
+
247
+ def _load_from_state_dict(
248
+ self,
249
+ state_dict,
250
+ prefix,
251
+ local_metadata,
252
+ strict,
253
+ missing_keys,
254
+ unexpected_keys,
255
+ error_msgs,
256
+ ):
257
+ version = local_metadata.get("version", None)
258
+
259
+ if version is None or version < 2:
260
+ _v1_to_v2_weights(state_dict, prefix)
261
+
262
+ super()._load_from_state_dict(
263
+ state_dict,
264
+ prefix,
265
+ local_metadata,
266
+ strict,
267
+ missing_keys,
268
+ unexpected_keys,
269
+ error_msgs,
270
+ )
271
+
272
+ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
273
+ # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
274
+ losses = []
275
+
276
+ bbox_regression = head_outputs["bbox_regression"]
277
+
278
+ for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
279
+ targets, bbox_regression, anchors, matched_idxs
280
+ ):
281
+ # determine only the foreground indices, ignore the rest
282
+ foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
283
+ num_foreground = foreground_idxs_per_image.numel()
284
+
285
+ # select only the foreground boxes
286
+ matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
287
+ bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
288
+ anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
289
+
290
+ # compute the loss
291
+ losses.append(
292
+ _box_loss(
293
+ self._loss_type,
294
+ self.box_coder,
295
+ anchors_per_image,
296
+ matched_gt_boxes_per_image,
297
+ bbox_regression_per_image,
298
+ )
299
+ / max(1, num_foreground)
300
+ )
301
+
302
+ return _sum(losses) / max(1, len(targets))
303
+
304
+ def forward(self, x):
305
+ # type: (List[Tensor]) -> Tensor
306
+ all_bbox_regression = []
307
+
308
+ for features in x:
309
+ bbox_regression = self.conv(features)
310
+ bbox_regression = self.bbox_reg(bbox_regression)
311
+
312
+ # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
313
+ N, _, H, W = bbox_regression.shape
314
+ bbox_regression = bbox_regression.view(N, -1, 4, H, W)
315
+ bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
316
+ bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
317
+
318
+ all_bbox_regression.append(bbox_regression)
319
+
320
+ return torch.cat(all_bbox_regression, dim=1)
321
+
322
+
323
+ class RetinaNet(nn.Module):
324
+ """
325
+ Implements RetinaNet.
326
+
327
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
328
+ image, and should be in 0-1 range. Different images can have different sizes.
329
+
330
+ The behavior of the model changes depending on if it is in training or evaluation mode.
331
+
332
+ During training, the model expects both the input tensors and targets (list of dictionaries),
333
+ containing:
334
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
335
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
336
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
337
+
338
+ The model returns a Dict[Tensor] during training, containing the classification and regression
339
+ losses.
340
+
341
+ During inference, the model requires only the input tensors, and returns the post-processed
342
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
343
+ follows:
344
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
345
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
346
+ - labels (Int64Tensor[N]): the predicted labels for each image
347
+ - scores (Tensor[N]): the scores for each prediction
348
+
349
+ Args:
350
+ backbone (nn.Module): the network used to compute the features for the model.
351
+ It should contain an out_channels attribute, which indicates the number of output
352
+ channels that each feature map has (and it should be the same for all feature maps).
353
+ The backbone should return a single Tensor or an OrderedDict[Tensor].
354
+ num_classes (int): number of output classes of the model (including the background).
355
+ min_size (int): Images are rescaled before feeding them to the backbone:
356
+ we attempt to preserve the aspect ratio and scale the shorter edge
357
+ to ``min_size``. If the resulting longer edge exceeds ``max_size``,
358
+ then downscale so that the longer edge does not exceed ``max_size``.
359
+ This may result in the shorter edge being lower than ``min_size``.
360
+ max_size (int): See ``min_size``.
361
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
362
+ They are generally the mean values of the dataset on which the backbone has been trained
363
+ on
364
+ image_std (Tuple[float, float, float]): std values used for input normalization.
365
+ They are generally the std values of the dataset that the backbone has been trained on.
366
+ anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
367
+ maps.
368
+ head (nn.Module): Module run on top of the feature pyramid.
369
+ Defaults to a module containing a classification and regression module.
370
+ score_thresh (float): Score threshold used for postprocessing the detections.
371
+ nms_thresh (float): NMS threshold used for postprocessing the detections.
372
+ detections_per_img (int): Number of best detections to keep after NMS.
373
+ fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
374
+ considered as positive during training.
375
+ bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
376
+ considered as negative during training.
377
+ topk_candidates (int): Number of best detections to keep before NMS.
378
+
379
+ Example:
380
+
381
+ >>> import torch
382
+ >>> import torchvision
383
+ >>> from torchvision.models.detection import RetinaNet
384
+ >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
385
+ >>> # load a pre-trained model for classification and return
386
+ >>> # only the features
387
+ >>> backbone = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT).features
388
+ >>> # RetinaNet needs to know the number of
389
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
390
+ >>> # so we need to add it here
391
+ >>> backbone.out_channels = 1280
392
+ >>>
393
+ >>> # let's make the network generate 5 x 3 anchors per spatial
394
+ >>> # location, with 5 different sizes and 3 different aspect
395
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
396
+ >>> # map could potentially have different sizes and
397
+ >>> # aspect ratios
398
+ >>> anchor_generator = AnchorGenerator(
399
+ >>> sizes=((32, 64, 128, 256, 512),),
400
+ >>> aspect_ratios=((0.5, 1.0, 2.0),)
401
+ >>> )
402
+ >>>
403
+ >>> # put the pieces together inside a RetinaNet model
404
+ >>> model = RetinaNet(backbone,
405
+ >>> num_classes=2,
406
+ >>> anchor_generator=anchor_generator)
407
+ >>> model.eval()
408
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
409
+ >>> predictions = model(x)
410
+ """
411
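+ # Training-mode usage (editor's hedged sketch; the docstring example above covers inference):
+ #
+ #     model.train()
+ #     images = [torch.rand(3, 300, 400)]
+ #     targets = [{"boxes": torch.tensor([[10.0, 20.0, 100.0, 150.0]]),
+ #                 "labels": torch.tensor([1])}]
+ #     loss_dict = model(images, targets)  # e.g. {"classification": ..., "bbox_regression": ...}
+ #     total_loss = sum(loss_dict.values())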
+
412
+ __annotations__ = {
413
+ "box_coder": det_utils.BoxCoder,
414
+ "proposal_matcher": det_utils.Matcher,
415
+ }
416
+
417
+ def __init__(
418
+ self,
419
+ backbone,
420
+ num_classes,
421
+ # transform parameters
422
+ min_size=800,
423
+ max_size=1333,
424
+ image_mean=None,
425
+ image_std=None,
426
+ # Anchor parameters
427
+ anchor_generator=None,
428
+ head=None,
429
+ proposal_matcher=None,
430
+ score_thresh=0.05,
431
+ nms_thresh=0.5,
432
+ detections_per_img=300,
433
+ fg_iou_thresh=0.5,
434
+ bg_iou_thresh=0.4,
435
+ topk_candidates=1000,
436
+ **kwargs,
437
+ ):
438
+ super().__init__()
439
+ _log_api_usage_once(self)
440
+
441
+ if not hasattr(backbone, "out_channels"):
442
+ raise ValueError(
443
+ "backbone should contain an attribute out_channels "
444
+ "specifying the number of output channels (assumed to be the "
445
+ "same for all the levels)"
446
+ )
447
+ self.backbone = backbone
448
+
449
+ if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
450
+ raise TypeError(
451
+ f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
452
+ )
453
+
454
+ if anchor_generator is None:
455
+ anchor_generator = _default_anchorgen()
456
+ self.anchor_generator = anchor_generator
457
+
458
+ if head is None:
459
+ head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
460
+ self.head = head
461
+
462
+ if proposal_matcher is None:
463
+ proposal_matcher = det_utils.Matcher(
464
+ fg_iou_thresh,
465
+ bg_iou_thresh,
466
+ allow_low_quality_matches=True,
467
+ )
468
+ self.proposal_matcher = proposal_matcher
469
+
470
+ self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
471
+
472
+ if image_mean is None:
473
+ image_mean = [0.485, 0.456, 0.406]
474
+ if image_std is None:
475
+ image_std = [0.229, 0.224, 0.225]
476
+ self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
477
+
478
+ self.score_thresh = score_thresh
479
+ self.nms_thresh = nms_thresh
480
+ self.detections_per_img = detections_per_img
481
+ self.topk_candidates = topk_candidates
482
+
483
+ # used only in torchscript mode
484
+ self._has_warned = False
485
+
486
+ @torch.jit.unused
487
+ def eager_outputs(self, losses, detections):
488
+ # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
489
+ if self.training:
490
+ return losses
491
+
492
+ return detections
493
+
494
+ def compute_loss(self, targets, head_outputs, anchors):
495
+ # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
496
+ matched_idxs = []
497
+ for anchors_per_image, targets_per_image in zip(anchors, targets):
498
+ if targets_per_image["boxes"].numel() == 0:
499
+ matched_idxs.append(
500
+ torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
501
+ )
502
+ continue
503
+
504
+ match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
505
+ matched_idxs.append(self.proposal_matcher(match_quality_matrix))
506
+
507
+ return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
508
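+ # Matching semantics (editor's note, assuming torchvision's det_utils.Matcher conventions):
+ # matched_idxs[i] >= 0 is the index of the ground-truth box matched to anchor i, while the
+ # matcher's negative sentinel values mark background or discarded anchors; images with no
+ # ground-truth boxes are filled with -1 so that every anchor is treated as background.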
+
509
+ def postprocess_detections(self, head_outputs, anchors, image_shapes):
510
+ # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
511
+ class_logits = head_outputs["cls_logits"]
512
+ box_regression = head_outputs["bbox_regression"]
513
+
514
+ num_images = len(image_shapes)
515
+
516
+ detections: List[Dict[str, Tensor]] = []
517
+
518
+ for index in range(num_images):
519
+ box_regression_per_image = [br[index] for br in box_regression]
520
+ logits_per_image = [cl[index] for cl in class_logits]
521
+ anchors_per_image, image_shape = anchors[index], image_shapes[index]
522
+
523
+ image_boxes = []
524
+ image_scores = []
525
+ image_labels = []
526
+
527
+ for box_regression_per_level, logits_per_level, anchors_per_level in zip(
528
+ box_regression_per_image, logits_per_image, anchors_per_image
529
+ ):
530
+ num_classes = logits_per_level.shape[-1]
531
+
532
+ # remove low scoring boxes
533
+ scores_per_level = torch.sigmoid(logits_per_level).flatten()
534
+ keep_idxs = scores_per_level > self.score_thresh
535
+ scores_per_level = scores_per_level[keep_idxs]
536
+ topk_idxs = torch.where(keep_idxs)[0]
537
+
538
+ # keep only topk scoring predictions
539
+ num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
540
+ scores_per_level, idxs = scores_per_level.topk(num_topk)
541
+ topk_idxs = topk_idxs[idxs]
542
+
543
+ anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
544
+ labels_per_level = topk_idxs % num_classes
545
+
546
+ boxes_per_level = self.box_coder.decode_single(
547
+ box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
548
+ )
549
+ boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
550
+
551
+ image_boxes.append(boxes_per_level)
552
+ image_scores.append(scores_per_level)
553
+ image_labels.append(labels_per_level)
554
+
555
+ image_boxes = torch.cat(image_boxes, dim=0)
556
+ image_scores = torch.cat(image_scores, dim=0)
557
+ image_labels = torch.cat(image_labels, dim=0)
558
+
559
+ # non-maximum suppression
560
+ keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
561
+ keep = keep[: self.detections_per_img]
562
+
563
+ detections.append(
564
+ {
565
+ "boxes": image_boxes[keep],
566
+ "scores": image_scores[keep],
567
+ "labels": image_labels[keep],
568
+ }
569
+ )
570
+
571
+ return detections
572
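+ # Index decoding (editor's worked example): scores_per_level above is the flattened
+ # (num_anchors * num_classes) sigmoid map, so a kept flat index decodes as
+ # anchor = idx // num_classes and label = idx % num_classes; with num_classes = 3,
+ # flat index 7 corresponds to anchor 2, class label 1.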
+
573
+ def forward(self, images, targets=None):
574
+ # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
575
+ """
576
+ Args:
577
+ images (list[Tensor]): images to be processed
578
+ targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
579
+
580
+ Returns:
581
+ result (list[BoxList] or dict[Tensor]): the output from the model.
582
+ During training, it returns a dict[Tensor] which contains the losses.
583
+ During testing, it returns a list[BoxList] that contains additional fields
584
+ like `scores`, `labels` and `mask` (for Mask R-CNN models).
585
+
586
+ """
587
+ if self.training:
588
+ if targets is None:
589
+ torch._assert(False, "targets should not be none when in training mode")
590
+ else:
591
+ for target in targets:
592
+ boxes = target["boxes"]
593
+ torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
594
+ torch._assert(
595
+ len(boxes.shape) == 2 and boxes.shape[-1] == 4,
596
+ "Expected target boxes to be a tensor of shape [N, 4].",
597
+ )
598
+
599
+ # get the original image sizes
600
+ original_image_sizes: List[Tuple[int, int]] = []
601
+ for img in images:
602
+ val = img.shape[-2:]
603
+ torch._assert(
604
+ len(val) == 2,
605
+ f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
606
+ )
607
+ original_image_sizes.append((val[0], val[1]))
608
+
609
+ # transform the input
610
+ images, targets = self.transform(images, targets)
611
+
612
+ # Check for degenerate boxes
613
+ # TODO: Move this to a function
614
+ if targets is not None:
615
+ for target_idx, target in enumerate(targets):
616
+ boxes = target["boxes"]
617
+ degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
618
+ if degenerate_boxes.any():
619
+ # print the first degenerate box
620
+ bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
621
+ degen_bb: List[float] = boxes[bb_idx].tolist()
622
+ torch._assert(
623
+ False,
624
+ "All bounding boxes should have positive height and width."
625
+ f" Found invalid box {degen_bb} for target at index {target_idx}.",
626
+ )
627
+
628
+ # get the features from the backbone
629
+ features = self.backbone(images.tensors)
630
+ if isinstance(features, torch.Tensor):
631
+ features = OrderedDict([("0", features)])
632
+
633
+ # TODO: Do we want a list or a dict?
634
+ features = list(features.values())
635
+
636
+ # compute the retinanet heads outputs using the features
637
+ head_outputs = self.head(features)
638
+
639
+ # create the set of anchors
640
+ anchors = self.anchor_generator(images, features)
641
+
642
+ losses = {}
643
+ detections: List[Dict[str, Tensor]] = []
644
+ if self.training:
645
+ if targets is None:
646
+ torch._assert(False, "targets should not be none when in training mode")
647
+ else:
648
+ # compute the losses
649
+ losses = self.compute_loss(targets, head_outputs, anchors)
650
+ else:
651
+ # recover level sizes
652
+ num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
653
+ HW = 0
654
+ for v in num_anchors_per_level:
655
+ HW += v
656
+ HWA = head_outputs["cls_logits"].size(1)
657
+ A = HWA // HW
658
+ num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
659
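+ # Editor's worked example: for two levels with H * W of 10000 and 2500, HW = 12500;
+ # if cls_logits has HWA = 112500 entries per image then A = 112500 // 12500 = 9 and
+ # num_anchors_per_level becomes [90000, 22500].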
+
660
+ # split outputs per level
661
+ split_head_outputs: Dict[str, List[Tensor]] = {}
662
+ for k in head_outputs:
663
+ split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
664
+ split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
665
+
666
+ # compute the detections
667
+ detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
668
+ detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
669
+
670
+ if torch.jit.is_scripting():
671
+ if not self._has_warned:
672
+ warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
673
+ self._has_warned = True
674
+ return losses, detections
675
+ return self.eager_outputs(losses, detections)
676
+
677
+
678
+ _COMMON_META = {
679
+ "categories": _COCO_CATEGORIES,
680
+ "min_size": (1, 1),
681
+ }
682
+
683
+
684
+ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
685
+ COCO_V1 = Weights(
686
+ url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
687
+ transforms=ObjectDetection,
688
+ meta={
689
+ **_COMMON_META,
690
+ "num_params": 34014999,
691
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
692
+ "_metrics": {
693
+ "COCO-val2017": {
694
+ "box_map": 36.4,
695
+ }
696
+ },
697
+ "_ops": 151.54,
698
+ "_file_size": 130.267,
699
+ "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
700
+ },
701
+ )
702
+ DEFAULT = COCO_V1
703
+
704
+
705
+ class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
706
+ COCO_V1 = Weights(
707
+ url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
708
+ transforms=ObjectDetection,
709
+ meta={
710
+ **_COMMON_META,
711
+ "num_params": 38198935,
712
+ "recipe": "https://github.com/pytorch/vision/pull/5756",
713
+ "_metrics": {
714
+ "COCO-val2017": {
715
+ "box_map": 41.5,
716
+ }
717
+ },
718
+ "_ops": 152.238,
719
+ "_file_size": 146.037,
720
+ "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
721
+ },
722
+ )
723
+ DEFAULT = COCO_V1
724
+
725
+
726
+ @register_model()
727
+ @handle_legacy_interface(
728
+ weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
729
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
730
+ )
731
+ def retinanet_resnet50_fpn(
732
+ *,
733
+ weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
734
+ progress: bool = True,
735
+ num_classes: Optional[int] = None,
736
+ weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
737
+ trainable_backbone_layers: Optional[int] = None,
738
+ **kwargs: Any,
739
+ ) -> RetinaNet:
740
+ """
741
+ Constructs a RetinaNet model with a ResNet-50-FPN backbone.
742
+
743
+ .. betastatus:: detection module
744
+
745
+ Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
746
+
747
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
748
+ image, and should be in ``0-1`` range. Different images can have different sizes.
749
+
750
+ The behavior of the model changes depending on whether it is in training or evaluation mode.
751
+
752
+ During training, the model expects both the input tensors and targets (list of dictionary),
753
+ containing:
754
+
755
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
756
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
757
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
758
+
759
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
760
+ losses.
761
+
762
+ During inference, the model requires only the input tensors, and returns the post-processed
763
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
764
+ follows, where ``N`` is the number of detections:
765
+
766
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
767
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
768
+ - labels (``Int64Tensor[N]``): the predicted labels for each detection
769
+ - scores (``Tensor[N]``): the scores of each detection
770
+
771
+ For more details on the output, you may refer to :ref:`instance_seg_output`.
772
+
773
+ Example::
774
+
775
+ >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
776
+ >>> model.eval()
777
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
778
+ >>> predictions = model(x)
779
+
780
+ Args:
781
+ weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
782
+ pretrained weights to use. See
783
+ :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
784
+ below for more details, and possible values. By default, no
785
+ pre-trained weights are used.
786
+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
787
+ num_classes (int, optional): number of output classes of the model (including the background)
788
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
789
+ the backbone.
790
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
791
+ Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
792
+ passed (the default) this value is set to 3.
793
+ **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
794
+ base class. Please refer to the `source code
795
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
796
+ for more details about this class.
797
+
798
+ .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
799
+ :members:
800
+ """
801
+ weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
802
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
803
+
804
+ if weights is not None:
805
+ weights_backbone = None
806
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
807
+ elif num_classes is None:
808
+ num_classes = 91
809
+
810
+ is_trained = weights is not None or weights_backbone is not None
811
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
812
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
813
+
814
+ backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
815
+ # skip P2 because it generates too many anchors (according to their paper)
816
+ backbone = _resnet_fpn_extractor(
817
+ backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
818
+ )
819
+ model = RetinaNet(backbone, num_classes, **kwargs)
820
+
821
+ if weights is not None:
822
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
823
+ if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
824
+ overwrite_eps(model, 0.0)
825
+
826
+ return model
827
+
828
+
829
+ @register_model()
830
+ @handle_legacy_interface(
831
+ weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
832
+ weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
833
+ )
834
+ def retinanet_resnet50_fpn_v2(
835
+ *,
836
+ weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
837
+ progress: bool = True,
838
+ num_classes: Optional[int] = None,
839
+ weights_backbone: Optional[ResNet50_Weights] = None,
840
+ trainable_backbone_layers: Optional[int] = None,
841
+ **kwargs: Any,
842
+ ) -> RetinaNet:
843
+ """
844
+ Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
845
+
846
+ .. betastatus:: detection module
847
+
848
+ Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
849
+ <https://arxiv.org/abs/1912.02424>`_.
850
+
851
+ See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
852
+
853
+ Args:
854
+ weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
855
+ pretrained weights to use. See
856
+ :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
857
+ below for more details, and possible values. By default, no
858
+ pre-trained weights are used.
859
+ progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
860
+ num_classes (int, optional): number of output classes of the model (including the background)
861
+ weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
862
+ the backbone.
863
+ trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
864
+ Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
865
+ passed (the default) this value is set to 3.
866
+ **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
867
+ base class. Please refer to the `source code
868
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
869
+ for more details about this class.
870
+
871
+ .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
872
+ :members:
873
+ """
874
+ weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
875
+ weights_backbone = ResNet50_Weights.verify(weights_backbone)
876
+
877
+ if weights is not None:
878
+ weights_backbone = None
879
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
880
+ elif num_classes is None:
881
+ num_classes = 91
882
+
883
+ is_trained = weights is not None or weights_backbone is not None
884
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
885
+
886
+ backbone = resnet50(weights=weights_backbone, progress=progress)
887
+ backbone = _resnet_fpn_extractor(
888
+ backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
889
+ )
890
+ anchor_generator = _default_anchorgen()
891
+ head = RetinaNetHead(
892
+ backbone.out_channels,
893
+ anchor_generator.num_anchors_per_location()[0],
894
+ num_classes,
895
+ norm_layer=partial(nn.GroupNorm, 32),
896
+ )
897
+ head.regression_head._loss_type = "giou"
898
+ model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
899
+
900
+ if weights is not None:
901
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
902
+
903
+ return model
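+
+
+ # Usage sketch (editor's addition, hedged): load the COCO weights for inference, or pass a
+ # custom num_classes (which includes the background class) together with backbone weights
+ # when fine-tuning on another dataset:
+ #
+ #     model = retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT)
+ #     model = retinanet_resnet50_fpn_v2(num_classes=3,
+ #                                       weights_backbone=ResNet50_Weights.IMAGENET1K_V1)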