diff --git a/.gitattributes b/.gitattributes
index ad8b6d0f7298a35b092e6e2056874de1293f1033..33871e11f36b23f15d45ea9da69c5331f0adb93f 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -345,3 +345,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
.venv/lib/python3.11/site-packages/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/multidict/_multidict.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torchvision/image.so filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
diff --git a/.venv/lib/python3.11/site-packages/torchvision/_C.so b/.venv/lib/python3.11/site-packages/torchvision/_C.so
new file mode 100644
index 0000000000000000000000000000000000000000..cf66ef52a6888e9e067c97915d7b5cc9c61d3887
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/_C.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb7e1b7570bd8fc14f9497793f89e188ccf161d7c14ca1f236e00368779ee609
+size 7746688
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d6247f03fadb984686f6ec892f0e9dc2e6a571
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/_optical_flow.py
@@ -0,0 +1,490 @@
+import itertools
+import os
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ..io.image import decode_png, read_file
+from .utils import _read_pfm, verify_str_arg
+from .vision import VisionDataset
+
# Return-type aliases for ``FlowDataset.__getitem__``:
# T1 is the 4-tuple returned by datasets with a built-in validity mask,
# T2 the 3-tuple returned by the others.
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]


# Public API of this module (the abstract base and helpers stay private).
__all__ = (
    "KittiFlow",
    "Sintel",
    "FlyingThings3D",
    "FlyingChairs",
    "HD1K",
)
+
+
class FlowDataset(ABC, VisionDataset):
    """Abstract base class for the optical-flow datasets in this module.

    Subclasses populate the parallel ``_flow_list`` / ``_image_list`` lists in
    their ``__init__`` and implement :meth:`_read_flow`.
    """

    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
    _has_builtin_flow_mask = False

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:

        super().__init__(root=root)
        self.transforms = transforms

        # Parallel lists: _image_list[i] is the (img1, img2) path pair matching _flow_list[i].
        self._flow_list: List[str] = []
        self._image_list: List[List[str]] = []

    def _read_img(self, file_name: str) -> Image.Image:
        """Open ``file_name`` with PIL, converting to RGB when needed."""
        img = Image.open(file_name)
        if img.mode != "RGB":
            img = img.convert("RGB")  # type: ignore[assignment]
        return img

    @abstractmethod
    def _read_flow(self, file_name: str):
        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
        pass

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return the example at ``index`` as a 3- or 4-tuple (see class comment)."""

        img1 = self._read_img(self._image_list[index][0])
        img2 = self._read_img(self._image_list[index][1])

        if self._flow_list:  # it will be empty for some dataset when split="test"
            flow = self._read_flow(self._flow_list[index])
            if self._has_builtin_flow_mask:
                flow, valid_flow_mask = flow
            else:
                valid_flow_mask = None
        else:
            flow = valid_flow_mask = None

        if self.transforms is not None:
            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)

        if self._has_builtin_flow_mask or valid_flow_mask is not None:
            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
            return img1, img2, flow, valid_flow_mask
        else:
            return img1, img2, flow

    def __len__(self) -> int:
        return len(self._image_list)

    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
        # Supports `3 * dataset`, a common trick to re-balance concatenated training sets.
        return torch.utils.data.ConcatDataset([self] * v)
+
+
class Sintel(FlowDataset):
    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                testing
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                training
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                    flow
                        scene_1
                        scene_2
                        ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
            details on the different passes.
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = ["clean", "final"] if pass_name == "both" else [pass_name]

        root = Path(root) / "Sintel"
        # Ground-truth flow only exists under "training"; the test split has images only.
        flow_root = root / "training" / "flow"
        split_dir = "training" if split == "train" else split  # hoisted: invariant across passes

        # Loop variable renamed from `pass_name` so it no longer shadows the parameter.
        for current_pass in passes:
            image_root = root / split_dir / current_pass
            for scene in os.listdir(image_root):
                image_list = sorted(glob(str(image_root / scene / "*.png")))
                # Consecutive frames form the (img1, img2) pairs: N images -> N - 1 pairs.
                for i in range(len(image_list) - 1):
                    self._image_list += [[image_list[i], image_list[i + 1]]]

                if split == "train":
                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="test"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        """Read a Middlebury ``.flo`` file into a (2, H, W) numpy array."""
        return _read_flo(file_name)
+
+
class KittiFlow(FlowDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).

    The dataset is expected to have the following structure: ::

        root
            KittiFlow
                testing
                    image_2
                training
                    image_2
                    flow_occ

    Args:
        root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    # KITTI flow pngs encode a per-pixel validity mask alongside the flow values.
    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # "train" -> "training", "test" -> "testing" (the on-disk folder names).
        root = Path(root) / "KittiFlow" / (split + "ing")
        # "*_10.png" is the first frame of each pair, "*_11.png" the second.
        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))

        if not images1 or not images2:
            raise FileNotFoundError(
                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
            )

        for img1, img2 in zip(images1, images2):
            self._image_list += [[img1, img2]]

        if split == "train":
            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
            ``split="test"``.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        """Decode a 16-bit KITTI flow png into ``(flow, valid_flow_mask)``."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt


    Args:
        root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"

        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        # One integer label per sample: 1 selects "train", 2 selects "val" (see the condition below).
        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        for i in range(len(flows)):
            split_id = split_list[i]
            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
                self._flow_list += [flows[i]]
                # Sample i uses images 2i and 2i + 1 (the *_img1.ppm / *_img2.ppm pair).
                self._image_list += [[images[2 * i], images[2 * i + 1]]]

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            Ground-truth flow is collected for both the "train" and "val" splits,
            so ``flow`` is never None here.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        """Read a Middlebury ``.flo`` file into a (2, H, W) numpy array."""
        return _read_flo(file_name)
+
+
class FlyingThings3D(FlowDataset):
    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            FlyingThings3D
                frames_cleanpass
                    TEST
                    TRAIN
                frames_finalpass
                    TEST
                    TRAIN
                optical_flow
                    TEST
                    TRAIN

    Args:
        root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
            details on the different passes.
        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        camera: str = "left",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        split = split.upper()  # the on-disk split folders are upper-case (TRAIN / TEST)

        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
        cameras = ["left", "right"] if camera == "both" else [camera]

        root = Path(root) / "FlyingThings3D"

        # NOTE(review): the loop variables below shadow the `pass_name` and `camera`
        # parameters; neither parameter is used again afterwards, so this is harmless.
        directions = ("into_future", "into_past")
        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
            image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)

            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
            flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)

            if not image_dirs or not flow_dirs:
                raise FileNotFoundError(
                    "Could not find the FlyingThings3D flow images. "
                    "Please make sure the directory structure is correct."
                )

            # Both globs use the same "*/*" pattern, so zip() pairs each image dir with its
            # flow dir by sorted position — assumes parallel directory trees; TODO confirm.
            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
                images = sorted(glob(str(image_dir / "*.png")))
                flows = sorted(glob(str(flow_dir / "*.pfm")))
                for i in range(len(flows) - 1):
                    if direction == "into_future":
                        self._image_list += [[images[i], images[i + 1]]]
                        self._flow_list += [flows[i]]
                    elif direction == "into_past":
                        # presumably the frame pair is reversed so the "into_past" flow
                        # still maps img1 -> img2 — verify against the dataset spec
                        self._image_list += [[images[i + 1], images[i]]]
                        self._flow_list += [flows[i + 1]]

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="test"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        """Read a ``.pfm`` flow file into a numpy array."""
        return _read_pfm(file_name)
+
+
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    # HD1K flow pngs encode a per-pixel validity mask alongside the flow values.
    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "hd1k"
        if split == "train":
            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
            for seq_idx in range(36):
                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
                for i in range(len(flows) - 1):
                    self._flow_list += [flows[i]]
                    self._image_list += [[images[i], images[i + 1]]]
        else:
            # Test split: consecutive "*10.png" / "*11.png" frames form the pairs; no ground truth.
            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
            for image1, image2 in zip(images1, images2):
                self._image_list += [[image1, image2]]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        """Decode a 16-bit flow png into ``(flow, valid_flow_mask)``."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
            is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
            ``split="test"``.
        """
        return super().__getitem__(index)
+
+
+def _read_flo(file_name: str) -> np.ndarray:
+ """Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+ # Everything needs to be in little Endian according to
+ # https://vision.middlebury.edu/flow/code/flow-code/README.txt
+ with open(file_name, "rb") as f:
+ magic = np.fromfile(f, "c", count=4).tobytes()
+ if magic != b"PIEH":
+ raise ValueError("Magic number incorrect. Invalid .flo file")
+
+ w = int(np.fromfile(f, " Tuple[np.ndarray, np.ndarray]:
+
+ flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
+ flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
+ flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive
+ valid_flow_mask = valid_flow_mask.bool()
+
+ # For consistency with other datasets, we convert to numpy
+ return flow.numpy(), valid_flow_mask.numpy()
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/country211.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/country211.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f82ee1226670ae274cc09dc0697dd4c51fb1c2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/country211.py
@@ -0,0 +1,58 @@
+from pathlib import Path
+from typing import Callable, Optional, Union
+
+from .folder import ImageFolder
+from .utils import download_and_extract_archive, verify_str_arg
+
+
class Country211(ImageFolder):
    """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.

    This dataset was built by filtering the images from the YFCC100m dataset
    that have GPS coordinate corresponding to a ISO-3166 country code. The
    dataset is balanced by sampling 150 train images, 50 validation images, and
    100 test images for each country.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it into
            ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
    """

    # Archive location and checksum used by _download().
    _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
    _MD5 = "84988d7644798601126c29e9877aab6a"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))

        root = Path(root).expanduser()
        # Set early so _download() can use self.root before super().__init__ runs.
        self.root = str(root)
        self._base_folder = root / "country211"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
        # ImageFolder.__init__ set self.root to the split folder; restore the user-provided root.
        self.root = str(root)

    def _check_exists(self) -> bool:
        # NOTE(review): only checks the base folder, not the individual split sub-folders.
        return self._base_folder.exists() and self._base_folder.is_dir()

    def _download(self) -> None:
        """Download and extract the archive unless the data is already present."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/folder.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2f65c7b615aa12ec4cdb2bd472d521a6880ad7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/folder.py
@@ -0,0 +1,337 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith accepts either a single suffix or a tuple of suffixes.
    if isinstance(extensions, str):
        suffixes = extensions
    else:
        suffixes = tuple(extensions)
    return filename.lower().endswith(suffixes)
+
+
def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    # IMG_EXTENSIONS is a tuple of lowercase suffixes, so a case-folded
    # endswith check is exactly the allowed-extension test.
    return filename.lower().endswith(IMG_EXTENSIONS)
+
+
def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    # Every immediate sub-directory is a class; files are ignored.
    class_names = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    class_names.sort()
    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    return class_names, {name: index for index, name in enumerate(class_names)}
+
+
def make_dataset(
    directory: Union[str, Path],
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    allow_empty: bool = False,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    # Exactly one of `extensions` / `is_valid_file` must be provided.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:

        def _matches_extensions(path: str) -> bool:
            return has_file_allowed_extension(path, extensions)  # type: ignore[arg-type]

        predicate = _matches_extensions
    else:
        predicate = cast(Callable[[str], bool], is_valid_file)

    samples: List[Tuple[str, int]] = []
    classes_with_samples = set()
    for class_name in sorted(class_to_idx.keys()):
        class_index = class_to_idx[class_name]
        class_dir = os.path.join(directory, class_name)
        if not os.path.isdir(class_dir):
            continue
        # Walk recursively so nested folders inside a class directory are included.
        for walk_root, _, file_names in sorted(os.walk(class_dir, followlinks=True)):
            for file_name in sorted(file_names):
                full_path = os.path.join(walk_root, file_name)
                if predicate(full_path):
                    samples.append((full_path, class_index))
                    classes_with_samples.add(class_name)

    empty_classes = set(class_to_idx.keys()) - classes_with_samples
    if empty_classes and not allow_empty:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return samples
+
+
class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: Union[str, Path],
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Go through the (possibly overridden) methods, not the module-level functions,
        # so subclasses can customize class discovery and sample collection.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(
            self.root,
            class_to_idx=class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Targets mirrors samples for convenient label access (e.g. class balancing).
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: Union[str, Path],
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.
            allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
                An error is raised on empty folders if False (default).

        Raises:
            ValueError: In case ``class_to_idx`` is empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
        )

    def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``dir`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)
+
+
# Lowercase suffixes recognized by `is_image_file` and used as `ImageFolder`'s default filter.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
def pil_loader(path: str) -> Image.Image:
    """Load the file at ``path`` as an RGB PIL image."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")
+
+
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load an image with accimage, falling back to PIL on decoding errors."""
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
+
+
def default_loader(path: str) -> Any:
    """Load an image using whichever backend torchvision is configured with."""
    from torchvision import get_image_backend

    loader = accimage_loader if get_image_backend() == "accimage" else pil_loader
    return loader(path)
+
+
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)
        allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
            An error is raised on empty folders if False (default).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        allow_empty: bool = False,
    ):
        super().__init__(
            root,
            loader,
            # extensions and is_valid_file are mutually exclusive in make_dataset,
            # so the default image extensions apply only when no is_valid_file is given.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            allow_empty=allow_empty,
        )
        # `imgs` is the documented alias of `samples` (see Attributes above).
        self.imgs = self.samples
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d012c70b22fd8209534a01a51fa9978c705d00
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/gtsrb.py
@@ -0,0 +1,103 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Tuple, Union
+
+import PIL
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class GTSRB(VisionDataset):
+    """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "gtsrb"
+        # Train images live in per-class subfolders under "Training"; test
+        # images sit in one flat folder and get labels from a separate CSV.
+        self._target_folder = (
+            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+        )
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        if self._split == "train":
+            # Labels are encoded in the directory structure (one folder per class).
+            samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
+        else:
+            # Labels come from the semicolon-separated ground-truth CSV
+            # (columns include "Filename" and "ClassId").
+            with open(self._base_folder / "GT-final_test.csv") as csv_file:
+                samples = [
+                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+                ]
+
+        self._samples = samples
+        # NOTE(review): VisionDataset.__init__ already stored transform and
+        # target_transform above; these reassignments are redundant but harmless.
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+
+        path, target = self._samples[index]
+        # Convert to RGB so every sample has a consistent 3-channel mode.
+        sample = PIL.Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            sample = self.transform(sample)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def _check_exists(self) -> bool:
+        # Presence of the extracted image folder is the only existence check;
+        # no checksum is re-verified after extraction.
+        return self._target_folder.is_dir()
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+
+        if self._split == "train":
+            download_and_extract_archive(
+                f"{base_url}GTSRB-Training_fixed.zip",
+                download_root=str(self._base_folder),
+                md5="513f3c79a4c5141765e10e952eaa2478",
+            )
+        else:
+            # The test split needs two archives: the images and the
+            # ground-truth CSV used by __init__.
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_Images.zip",
+                download_root=str(self._base_folder),
+                md5="c7e4e6327067d32654124b0fe9e82185",
+            )
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_GT.zip",
+                download_root=str(self._base_folder),
+                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+            )
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f9a77f56a085cbd1d78540165964ba4f861658
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/inaturalist.py
@@ -0,0 +1,242 @@
+import os
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
+
+DATASET_URLS = {
+ "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
+ "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
+ "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
+ "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
+ "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
+ "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
+}
+
+DATASET_MD5 = {
+ "2017": "7c784ea5e424efaec655bd392f87301f",
+ "2018": "b1c6952ce38f31868cc50ea72d066cc3",
+ "2019": "c60a6e2962c9b8ccbd458d12c8582644",
+ "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
+ "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
+ "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
+}
+
+
+class INaturalist(VisionDataset):
+    """`iNaturalist `_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where the image files are stored.
+            This class does not require/use annotation files.
+        version (string, optional): Which version of the dataset to download/use. One of
+            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
+            Default: `2021_train`.
+        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
+
+            - ``full``: the full category (species)
+            - ``kingdom``: e.g. "Animalia"
+            - ``phylum``: e.g. "Arthropoda"
+            - ``class``: e.g. "Insecta"
+            - ``order``: e.g. "Coleoptera"
+            - ``family``: e.g. "Cleridae"
+            - ``genus``: e.g. "Trichodes"
+
+            for 2017-2019 versions, one of:
+
+            - ``full``: the full (numeric) category
+            - ``super``: the super category, e.g. "Amphibians"
+
+            Can also be a list to output a tuple with all specified target types.
+            Defaults to ``full``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        version: str = "2021_train",
+        target_type: Union[List[str], str] = "full",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
+
+        # self.root is root/<version>; the caller-supplied root is its parent.
+        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
+
+        os.makedirs(root, exist_ok=True)
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.all_categories: List[str] = []
+
+        # map: category type -> name of category -> index
+        self.categories_index: Dict[str, Dict[str, int]] = {}
+
+        # list indexed by category id, containing mapping from category type -> index
+        self.categories_map: List[Dict[str, int]] = []
+
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        # The two layout generations use different directory schemes, so the
+        # valid target types and the index-building logic differ accordingly.
+        if self.version[:4] == "2021":
+            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
+            self._init_2021()
+        else:
+            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
+            self._init_pre2021()
+
+        # index of all files: (full category id, filename)
+        self.index: List[Tuple[int, str]] = []
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            files = os.listdir(os.path.join(self.root, dir_name))
+            for fname in files:
+                self.index.append((dir_index, fname))
+
+    def _init_2021(self) -> None:
+        """Initialize based on 2021 layout"""
+
+        self.all_categories = sorted(os.listdir(self.root))
+
+        # map: category type -> name of category -> index
+        self.categories_index = {k: {} for k in CATEGORIES_2021}
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            # 2021 directory names have 8 underscore-separated pieces: a 5-digit
+            # category id, then the six CATEGORIES_2021 taxonomy levels, then
+            # one trailing piece (presumably the species epithet — not used here).
+            pieces = dir_name.split("_")
+            if len(pieces) != 8:
+                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
+            # Sorted order must match the numeric ids embedded in the names.
+            if pieces[0] != f"{dir_index:05d}":
+                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
+            cat_map = {}
+            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
+                # Assign each distinct taxonomy name a dense per-type index.
+                if name in self.categories_index[cat]:
+                    cat_id = self.categories_index[cat][name]
+                else:
+                    cat_id = len(self.categories_index[cat])
+                    self.categories_index[cat][name] = cat_id
+                cat_map[cat] = cat_id
+            self.categories_map.append(cat_map)
+
+    def _init_pre2021(self) -> None:
+        """Initialize based on 2017-2019 layout"""
+
+        # map: category type -> name of category -> index
+        self.categories_index = {"super": {}}
+
+        cat_index = 0
+        super_categories = sorted(os.listdir(self.root))
+        for sindex, scat in enumerate(super_categories):
+            self.categories_index["super"][scat] = sindex
+            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
+            for subcat in subcategories:
+                if self.version == "2017":
+                    # this version does not use ids as directory names
+                    subcat_i = cat_index
+                    cat_index += 1
+                else:
+                    try:
+                        subcat_i = int(subcat)
+                    except ValueError:
+                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
+                # Grow the id-indexed lists on demand; placeholder entries are
+                # validated as filled-in below.
+                if subcat_i >= len(self.categories_map):
+                    old_len = len(self.categories_map)
+                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
+                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
+                if self.categories_map[subcat_i]:
+                    raise RuntimeError(f"Duplicate category {subcat}")
+                self.categories_map[subcat_i] = {"super": sindex}
+                self.all_categories[subcat_i] = os.path.join(scat, subcat)
+
+        # validate the dictionary
+        for cindex, c in enumerate(self.categories_map):
+            if not c:
+                raise RuntimeError(f"Missing category {cindex}")
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target specified by target_type.
+        """
+
+        cat_id, fname = self.index[index]
+        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
+
+        # A single requested target type yields a scalar; multiple yield a tuple.
+        target: Any = []
+        for t in self.target_type:
+            if t == "full":
+                target.append(cat_id)
+            else:
+                target.append(self.categories_map[cat_id][t])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def category_name(self, category_type: str, category_id: int) -> str:
+        """
+        Args:
+            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
+            category_id(int): an index (class id) from this category
+
+        Returns:
+            the name of the category
+        """
+        if category_type == "full":
+            return self.all_categories[category_id]
+        else:
+            if category_type not in self.categories_index:
+                raise ValueError(f"Invalid category type '{category_type}'")
+            else:
+                # Reverse lookup (name -> id inverted); linear in the number of
+                # names for this category type.
+                for name, id in self.categories_index[category_type].items():
+                    if id == category_id:
+                        return name
+                raise ValueError(f"Invalid category id {category_id} for {category_type}")
+
+    def _check_integrity(self) -> bool:
+        # Existence + non-emptiness only; no checksum of extracted contents.
+        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
+
+    def download(self) -> None:
+        if self._check_integrity():
+            raise RuntimeError(
+                f"The directory {self.root} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+
+        base_root = os.path.dirname(self.root)
+
+        download_and_extract_archive(
+            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
+        )
+
+        # NOTE(review): str.rstrip strips a trailing *character set*, not a
+        # suffix — ".tar.gz" happens to work for these archive basenames, but
+        # would over-strip any name ending in characters from ".targz";
+        # removesuffix(".tar.gz") would be the safe spelling.
+        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
+        if not os.path.exists(orig_dir_name):
+            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
+        # Rename the extracted folder to the version-named self.root.
+        os.rename(orig_dir_name, self.root)
+        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2f5e18b9912f9c9a82e898ba68b68e36ea300c4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/lsun.py
@@ -0,0 +1,168 @@
+import io
+import os.path
+import pickle
+import string
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Callable, cast, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class LSUNClass(VisionDataset):
+    # Dataset over a single LSUN LMDB database; each entry is one JPEG-encoded
+    # image keyed by an opaque byte string. Targets are always None here (the
+    # class label is assigned by the wrapping LSUN dataset).
+    def __init__(
+        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
+    ) -> None:
+        # lmdb is an optional dependency, imported lazily.
+        import lmdb
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
+        with self.env.begin(write=False) as txn:
+            self.length = txn.stat()["entries"]
+        # Enumerating all keys is slow, so they are pickled to a cache file.
+        # NOTE(review): the cache path is relative to the current working
+        # directory (not root), and the open() handles for pickle.load/dump are
+        # never explicitly closed — consider `with open(...)` blocks.
+        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
+        if os.path.isfile(cache_file):
+            self.keys = pickle.load(open(cache_file, "rb"))
+        else:
+            with self.env.begin(write=False) as txn:
+                self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
+            pickle.dump(self.keys, open(cache_file, "wb"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        img, target = None, None
+        env = self.env
+        with env.begin(write=False) as txn:
+            imgbuf = txn.get(self.keys[index])
+
+        # Decode the raw image bytes fetched from LMDB via an in-memory buffer.
+        buf = io.BytesIO()
+        buf.write(imgbuf)
+        buf.seek(0)
+        img = Image.open(buf).convert("RGB")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # target stays None unless a target_transform maps it to something else.
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+
+class LSUN(VisionDataset):
+    """`LSUN `_ dataset.
+
+    You will need to install the ``lmdb`` package to use this dataset: run
+    ``pip install lmdb``
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory for the database files.
+        classes (string or list): One of {'train', 'val', 'test'} or a list of
+            categories to load. e,g. ['bedroom_train', 'church_outdoor_train'].
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        classes: Union[str, List[str]] = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.classes = self._verify_classes(classes)
+
+        # for each class, create an LSUNClassDataset
+        self.dbs = []
+        for c in self.classes:
+            self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
+
+        # self.indices holds cumulative sizes: indices[i] is the total number
+        # of samples in dbs[0..i]; used by __getitem__ to locate the sub-db.
+        self.indices = []
+        count = 0
+        for db in self.dbs:
+            count += len(db)
+            self.indices.append(count)
+
+        self.length = count
+
+    def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
+        # Normalizes `classes` to a list of "<category>_<split>" strings.
+        # A bare split name expands to all ten categories for that split,
+        # except "test", which is a single database.
+        categories = [
+            "bedroom",
+            "bridge",
+            "church_outdoor",
+            "classroom",
+            "conference_room",
+            "dining_room",
+            "kitchen",
+            "living_room",
+            "restaurant",
+            "tower",
+        ]
+        dset_opts = ["train", "val", "test"]
+
+        try:
+            classes = cast(str, classes)
+            verify_str_arg(classes, "classes", dset_opts)
+            if classes == "test":
+                classes = [classes]
+            else:
+                classes = [c + "_" + classes for c in categories]
+        except ValueError:
+            # Not a bare split name: validate each "<category>_<split>" entry.
+            if not isinstance(classes, Iterable):
+                msg = "Expected type str or Iterable for argument classes, but got type {}."
+                raise ValueError(msg.format(type(classes)))
+
+            classes = list(classes)
+            msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
+            for c in classes:
+                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
+                # Split on the last underscore: category names may themselves
+                # contain underscores (e.g. "church_outdoor").
+                c_short = c.split("_")
+                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
+
+                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
+                verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+
+        return classes
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the index of the target category.
+        """
+        # Linear scan over cumulative sizes: `target` becomes the sub-db index
+        # (= class label) and `sub` the offset to subtract from the global index.
+        target = 0
+        sub = 0
+        for ind in self.indices:
+            if index < ind:
+                break
+            target += 1
+            sub = ind
+
+        db = self.dbs[target]
+        index = index - sub
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        # LSUNClass yields (image, None); the per-class None target is discarded.
+        img, _ = db[index]
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+    def extra_repr(self) -> str:
+        return "Classes: {classes}".format(**self.__dict__)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c349370a12a4a6c0abea9f6e6ab5bd06a84107e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/sbu.py
@@ -0,0 +1,110 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url
+from .vision import VisionDataset
+
+
+class SBU(VisionDataset):
+    """`SBU Captioned Photo `_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of dataset where tarball
+            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
+    filename = "SBUCaptionedPhotoDataset.tar.gz"
+    md5_checksum = "9aec147b3488753cf758b4d493422285"
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # Read the caption for each photo
+        self.photos = []
+        self.captions = []
+
+        # URLs and captions are parallel line-aligned files.
+        file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
+        file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
+
+        # Only index photos actually present on disk: Flickr images may have
+        # been removed upstream, so download() can't fetch all of them.
+        # NOTE(review): the two open() handles are never explicitly closed.
+        for line1, line2 in zip(open(file1), open(file2)):
+            url = line1.rstrip()
+            photo = os.path.basename(url)
+            filename = os.path.join(self.root, "dataset", photo)
+            if os.path.exists(filename):
+                caption = line2.rstrip()
+                self.photos.append(photo)
+                self.captions.append(caption)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a caption for the photo.
+        """
+        filename = os.path.join(self.root, "dataset", self.photos[index])
+        img = Image.open(filename).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = self.captions[index]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        """The number of photos in the dataset."""
+        return len(self.photos)
+
+    def _check_integrity(self) -> bool:
+        """Check the md5 checksum of the downloaded tarball."""
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        """Download and extract the tarball, and download each individual photo."""
+
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        # Positional args: (url, download_root, extract_root, filename, md5).
+        download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
+
+        # Download individual photos
+        with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
+            for line in fh:
+                url = line.rstrip()
+                try:
+                    download_url(url, os.path.join(self.root, "dataset"))
+                except OSError:
+                    # The images point to public images on Flickr.
+                    # Note: Images might be removed by users at anytime.
+                    pass
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d20d7db7e3ccc1d77c5949830cf7b610d88355f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/svhn.py
@@ -0,0 +1,130 @@
+import os.path
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SVHN(VisionDataset):
+    """`SVHN `_ Dataset.
+    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
+    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
+    expect the class labels to be in the range `[0, C-1]`
+
+    .. warning::
+
+        This class needs `scipy `_ to load data from `.mat` format.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory of the dataset where the data is stored.
+        split (string): One of {'train', 'test', 'extra'}.
+            Accordingly dataset is selected. 'extra' is Extra training set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    # Per split: [url, local filename, md5 checksum].
+    split_list = {
+        "train": [
+            "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
+            "train_32x32.mat",
+            "e26dedcc434d2e4c54c9b2d4a06d8373",
+        ],
+        "test": [
+            "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
+            "test_32x32.mat",
+            "eb5a983be6a315427106f1b164d9cef3",
+        ],
+        "extra": [
+            "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
+            "extra_32x32.mat",
+            "a93ce644f1a588dc4d68dda5feec44a7",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
+        self.url = self.split_list[split][0]
+        self.filename = self.split_list[split][1]
+        self.file_md5 = self.split_list[split][2]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # import here rather than at top of file because this is
+        # an optional dependency for torchvision
+        import scipy.io as sio
+
+        # reading(loading) mat file as array
+        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
+
+        self.data = loaded_mat["X"]
+        # loading from the .mat file gives an np.ndarray of type np.uint8
+        # converting to np.int64, so that we have a LongTensor after
+        # the conversion from the numpy array
+        # the squeeze is needed to obtain a 1D tensor
+        self.labels = loaded_mat["y"].astype(np.int64).squeeze()
+
+        # the svhn dataset assigns the class label "10" to the digit 0
+        # this makes it inconsistent with several loss functions
+        # which expect the class labels to be in the range [0, C-1]
+        np.place(self.labels, self.labels == 10, 0)
+        # Reorder axes so data is indexed (sample, channel, height, width).
+        self.data = np.transpose(self.data, (3, 2, 0, 1))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        # Verifies the on-disk .mat file against its known md5.
+        root = self.root
+        md5 = self.split_list[self.split][2]
+        fpath = os.path.join(root, self.filename)
+        return check_integrity(fpath, md5)
+
+    def download(self) -> None:
+        md5 = self.split_list[self.split][2]
+        download_url(self.url, self.root, self.filename, md5)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py
new file mode 100644
index 0000000000000000000000000000000000000000..71f4ce313c3d3ae69678f81a46df8425ad75b1c2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/widerface.py
@@ -0,0 +1,197 @@
+import os
+from os.path import abspath, expanduser
+from pathlib import Path
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class WIDERFace(VisionDataset):
+    """`WIDERFace `_ Dataset.
+
+    Args:
+        root (str or ``pathlib.Path``): Root directory where images and annotations are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+
+                └── widerface
+                    ├── wider_face_split ('wider_face_split.zip' if compressed)
+                    ├── WIDER_train ('WIDER_train.zip' if compressed)
+                    ├── WIDER_val ('WIDER_val.zip' if compressed)
+                    └── WIDER_test ('WIDER_test.zip' if compressed)
+        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    .. warning::
+
+        To download the dataset `gdown `_ is required.
+
+    """
+
+    BASE_FOLDER = "widerface"
+    FILE_LIST = [
+        # File ID                        MD5 Hash                            Filename
+        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
+        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
+        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
+    ]
+    ANNOTATIONS_FILE = (
+        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
+        "0e3767bcf0e326556d407bf5bff5d27c",
+        "wider_face_split.zip",
+    )
+
+    def __init__(
+        self,
+        root: Union[str, Path],
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(
+            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
+        )
+        # check arguments
+        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
+
+        # Per-image records: always an "img_path"; train/val also carry an
+        # "annotations" dict of per-face tensors.
+        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
+        if self.split in ("train", "val"):
+            self.parse_train_val_annotations_file()
+        else:
+            # The test split ships only a file list, no annotations.
+            self.parse_test_annotations_file()
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dict of annotations for all faces in the image.
+            target=None for the test split.
+        """
+
+        # stay consistent with other datasets and return a PIL Image
+        img = Image.open(self.img_info[index]["img_path"])  # type: ignore[arg-type]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = None if self.split == "test" else self.img_info[index]["annotations"]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.img_info)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def parse_train_val_annotations_file(self) -> None:
+        # The annotation file is a repeating record of:
+        #   <relative image path>
+        #   <number of boxes>
+        #   <one line of 10 ints per box>
+        # Parsed with a three-state machine (exactly one flag is True at a time).
+        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
+        filepath = os.path.join(self.root, "wider_face_split", filename)
+
+        with open(filepath) as f:
+            lines = f.readlines()
+            file_name_line, num_boxes_line, box_annotation_line = True, False, False
+            num_boxes, box_counter = 0, 0
+            labels = []
+            for line in lines:
+                line = line.rstrip()
+                if file_name_line:
+                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
+                    img_path = abspath(expanduser(img_path))
+                    file_name_line = False
+                    num_boxes_line = True
+                elif num_boxes_line:
+                    num_boxes = int(line)
+                    num_boxes_line = False
+                    box_annotation_line = True
+                elif box_annotation_line:
+                    box_counter += 1
+                    line_split = line.split(" ")
+                    line_values = [int(x) for x in line_split]
+                    labels.append(line_values)
+                    # NOTE(review): when num_boxes is 0 the annotation file still
+                    # contains one dummy box line, which this branch consumes
+                    # and records as a single annotation — confirm intended.
+                    if box_counter >= num_boxes:
+                        box_annotation_line = False
+                        file_name_line = True
+                        labels_tensor = torch.tensor(labels)
+                        self.img_info.append(
+                            {
+                                "img_path": img_path,
+                                "annotations": {
+                                    "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                                    "blur": labels_tensor[:, 4].clone(),
+                                    "expression": labels_tensor[:, 5].clone(),
+                                    "illumination": labels_tensor[:, 6].clone(),
+                                    "occlusion": labels_tensor[:, 7].clone(),
+                                    "pose": labels_tensor[:, 8].clone(),
+                                    "invalid": labels_tensor[:, 9].clone(),
+                                },
+                            }
+                        )
+                        box_counter = 0
+                        labels.clear()
+                else:
+                    raise RuntimeError(f"Error parsing annotation file {filepath}")
+
+    def parse_test_annotations_file(self) -> None:
+        # Test split: one relative image path per line, no boxes.
+        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
+        filepath = abspath(expanduser(filepath))
+        with open(filepath) as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.rstrip()
+                img_path = os.path.join(self.root, "WIDER_test", "images", line)
+                img_path = abspath(expanduser(img_path))
+                self.img_info.append({"img_path": img_path})
+
+    def _check_integrity(self) -> bool:
+        # Allow original archive to be deleted (zip). Only need the extracted images
+        all_files = self.FILE_LIST.copy()
+        all_files.append(self.ANNOTATIONS_FILE)
+        for (_, md5, filename) in all_files:
+            # Each archive "X.zip" is expected to have been extracted to a
+            # directory named "X" under root; only directory presence is checked.
+            file, ext = os.path.splitext(filename)
+            extracted_dir = os.path.join(self.root, file)
+            if not os.path.exists(extracted_dir):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        # download and extract image data
+        for (file_id, md5, filename) in self.FILE_LIST:
+            download_file_from_google_drive(file_id, self.root, filename, md5)
+            filepath = os.path.join(self.root, filename)
+            extract_archive(filepath)
+
+        # download and extract annotation files
+        download_and_extract_archive(
+            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
+        )
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a604ea1fdb645c68d0896cff05caa4b97778e7c8
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/io/__init__.py
@@ -0,0 +1,76 @@
+from typing import Any, Dict, Iterator
+
+import torch
+
+from ..utils import _log_api_usage_once
+
+try:
+ from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
+except ModuleNotFoundError:
+ _HAS_GPU_VIDEO_DECODER = False
+
+from ._video_opt import (
+ _HAS_CPU_VIDEO_DECODER,
+ _HAS_VIDEO_OPT,
+ _probe_video_from_file,
+ _probe_video_from_memory,
+ _read_video_from_file,
+ _read_video_from_memory,
+ _read_video_timestamps_from_file,
+ _read_video_timestamps_from_memory,
+ Timebase,
+ VideoMetaData,
+)
+from .image import (
+ decode_gif,
+ decode_image,
+ decode_jpeg,
+ decode_png,
+ decode_webp,
+ encode_jpeg,
+ encode_png,
+ ImageReadMode,
+ read_file,
+ read_image,
+ write_file,
+ write_jpeg,
+ write_png,
+)
+from .video import read_video, read_video_timestamps, write_video
+from .video_reader import VideoReader
+
+
+__all__ = [
+ "write_video",
+ "read_video",
+ "read_video_timestamps",
+ "_read_video_from_file",
+ "_read_video_timestamps_from_file",
+ "_probe_video_from_file",
+ "_read_video_from_memory",
+ "_read_video_timestamps_from_memory",
+ "_probe_video_from_memory",
+ "_HAS_CPU_VIDEO_DECODER",
+ "_HAS_VIDEO_OPT",
+ "_HAS_GPU_VIDEO_DECODER",
+ "_read_video_clip_from_memory",
+ "_read_video_meta_data",
+ "VideoMetaData",
+ "Timebase",
+ "ImageReadMode",
+ "decode_image",
+ "decode_jpeg",
+ "decode_png",
+ "decode_heic",
+ "decode_webp",
+ "decode_gif",
+ "encode_jpeg",
+ "encode_png",
+ "read_file",
+ "read_image",
+ "write_file",
+ "write_jpeg",
+ "write_png",
+ "Video",
+ "VideoReader",
+]
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16418ef1b1aa345b38f19e06a285247d4f628c9e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..076fd5476ceed0b26762655f48a6118ae4dee8c0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_load_gpu_decoder.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf0acab4a3182c2ab38c9530e28957f268ddcd20
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/_video_opt.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35ae402bb8d1b947e6626c2b003c86423f2a5661
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/image.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83ba465137aa433478eba1660fff8313bc32a023
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..558096b2431fc084bb0ac18cc3ff3f3371e7c228
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/io/__pycache__/video_reader.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py b/.venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd40c545d8201b67290e27bf74ce115774dace1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/io/_load_gpu_decoder.py
@@ -0,0 +1,8 @@
+from ..extension import _load_library
+
+
+try:
+ _load_library("gpu_decoder")
+ _HAS_GPU_VIDEO_DECODER = True
+except (ImportError, OSError):
+ _HAS_GPU_VIDEO_DECODER = False
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py b/.venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..69af045e7731882c949f9181f2f7e6d507edc4a1
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/io/_video_opt.py
@@ -0,0 +1,513 @@
+import math
+import warnings
+from fractions import Fraction
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..extension import _load_library
+
+
+try:
+ _load_library("video_reader")
+ _HAS_CPU_VIDEO_DECODER = True
+except (ImportError, OSError):
+ _HAS_CPU_VIDEO_DECODER = False
+
+_HAS_VIDEO_OPT = _HAS_CPU_VIDEO_DECODER # For BC
+default_timebase = Fraction(0, 1)
+
+
+# simple class for torch scripting
+# the complex Fraction class from fractions module is not scriptable
+class Timebase:
+ __annotations__ = {"numerator": int, "denominator": int}
+ __slots__ = ["numerator", "denominator"]
+
+ def __init__(
+ self,
+ numerator: int,
+ denominator: int,
+ ) -> None:
+ self.numerator = numerator
+ self.denominator = denominator
+
+
+class VideoMetaData:
+ __annotations__ = {
+ "has_video": bool,
+ "video_timebase": Timebase,
+ "video_duration": float,
+ "video_fps": float,
+ "has_audio": bool,
+ "audio_timebase": Timebase,
+ "audio_duration": float,
+ "audio_sample_rate": float,
+ }
+ __slots__ = [
+ "has_video",
+ "video_timebase",
+ "video_duration",
+ "video_fps",
+ "has_audio",
+ "audio_timebase",
+ "audio_duration",
+ "audio_sample_rate",
+ ]
+
+ def __init__(self) -> None:
+ self.has_video = False
+ self.video_timebase = Timebase(0, 1)
+ self.video_duration = 0.0
+ self.video_fps = 0.0
+ self.has_audio = False
+ self.audio_timebase = Timebase(0, 1)
+ self.audio_duration = 0.0
+ self.audio_sample_rate = 0.0
+
+
+def _validate_pts(pts_range: Tuple[int, int]) -> None:
+
+ if pts_range[0] > pts_range[1] > 0:
+ raise ValueError(
+ f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
+ )
+
+
+def _fill_info(
+ vtimebase: torch.Tensor,
+ vfps: torch.Tensor,
+ vduration: torch.Tensor,
+ atimebase: torch.Tensor,
+ asample_rate: torch.Tensor,
+ aduration: torch.Tensor,
+) -> VideoMetaData:
+ """
+ Build update VideoMetaData struct with info about the video
+ """
+ meta = VideoMetaData()
+ if vtimebase.numel() > 0:
+ meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
+ timebase = vtimebase[0].item() / float(vtimebase[1].item())
+ if vduration.numel() > 0:
+ meta.has_video = True
+ meta.video_duration = float(vduration.item()) * timebase
+ if vfps.numel() > 0:
+ meta.video_fps = float(vfps.item())
+ if atimebase.numel() > 0:
+ meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
+ timebase = atimebase[0].item() / float(atimebase[1].item())
+ if aduration.numel() > 0:
+ meta.has_audio = True
+ meta.audio_duration = float(aduration.item()) * timebase
+ if asample_rate.numel() > 0:
+ meta.audio_sample_rate = float(asample_rate.item())
+
+ return meta
+
+
+def _align_audio_frames(
+ aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
+) -> torch.Tensor:
+ start, end = aframe_pts[0], aframe_pts[-1]
+ num_samples = aframes.size(0)
+ step_per_aframe = float(end - start + 1) / float(num_samples)
+ s_idx = 0
+ e_idx = num_samples
+ if start < audio_pts_range[0]:
+ s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
+ if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
+ e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
+ return aframes[s_idx:e_idx, :]
+
+
+def _read_video_from_file(
+ filename: str,
+ seek_frame_margin: float = 0.25,
+ read_video_stream: bool = True,
+ video_width: int = 0,
+ video_height: int = 0,
+ video_min_dimension: int = 0,
+ video_max_dimension: int = 0,
+ video_pts_range: Tuple[int, int] = (0, -1),
+ video_timebase: Fraction = default_timebase,
+ read_audio_stream: bool = True,
+ audio_samples: int = 0,
+ audio_channels: int = 0,
+ audio_pts_range: Tuple[int, int] = (0, -1),
+ audio_timebase: Fraction = default_timebase,
+) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
+ """
+ Reads a video from a file, returning both the video frames and the audio frames
+
+ Args:
+ filename (str): path to the video file
+ seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
+ when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
+ read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
+ video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
+ the size of decoded frames:
+
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
+ and video_max_dimension = 0, keep the original frame resolution
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
+ and video_max_dimension = 0, keep the aspect ratio and resize the
+ frame so that shorter edge size is video_min_dimension
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
+ and video_max_dimension != 0, keep the aspect ratio and resize
+ the frame so that longer edge size is video_max_dimension
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
+ and video_max_dimension != 0, resize the frame so that shorter
+ edge size is video_min_dimension, and longer edge size is
+ video_max_dimension. The aspect ratio may not be preserved
+ - When video_width = 0, video_height != 0, video_min_dimension = 0,
+ and video_max_dimension = 0, keep the aspect ratio and resize
+ the frame so that frame video_height is $video_height
+ - When video_width != 0, video_height == 0, video_min_dimension = 0,
+ and video_max_dimension = 0, keep the aspect ratio and resize
+ the frame so that frame video_width is $video_width
+ - When video_width != 0, video_height != 0, video_min_dimension = 0,
+ and video_max_dimension = 0, resize the frame so that frame
+ video_width and video_height are set to $video_width and
+ $video_height, respectively
+ video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
+ video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
+ read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
+ audio_samples (int, optional): audio sampling rate
+ audio_channels (int optional): audio channels
+ audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
+ audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream
+
+ Returns
+ vframes (Tensor[T, H, W, C]): the `T` video frames
+ aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
+ `K` is the number of audio_channels
+ info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
+ and audio_fps (int)
+ """
+ _validate_pts(video_pts_range)
+ _validate_pts(audio_pts_range)
+
+ result = torch.ops.video_reader.read_video_from_file(
+ filename,
+ seek_frame_margin,
+ 0, # getPtsOnly
+ read_video_stream,
+ video_width,
+ video_height,
+ video_min_dimension,
+ video_max_dimension,
+ video_pts_range[0],
+ video_pts_range[1],
+ video_timebase.numerator,
+ video_timebase.denominator,
+ read_audio_stream,
+ audio_samples,
+ audio_channels,
+ audio_pts_range[0],
+ audio_pts_range[1],
+ audio_timebase.numerator,
+ audio_timebase.denominator,
+ )
+ vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+ if aframes.numel() > 0:
+ # when audio stream is found
+ aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
+ return vframes, aframes, info
+
+
+def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
+ """
+ Decode all video- and audio frames in the video. Only pts
+ (presentation timestamp) is returned. The actual frame pixel data is not
+ copied. Thus, it is much faster than read_video(...)
+ """
+ result = torch.ops.video_reader.read_video_from_file(
+ filename,
+ 0, # seek_frame_margin
+ 1, # getPtsOnly
+ 1, # read_video_stream
+ 0, # video_width
+ 0, # video_height
+ 0, # video_min_dimension
+ 0, # video_max_dimension
+ 0, # video_start_pts
+ -1, # video_end_pts
+ 0, # video_timebase_num
+ 1, # video_timebase_den
+ 1, # read_audio_stream
+ 0, # audio_samples
+ 0, # audio_channels
+ 0, # audio_start_pts
+ -1, # audio_end_pts
+ 0, # audio_timebase_num
+ 1, # audio_timebase_den
+ )
+ _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+
+ vframe_pts = vframe_pts.numpy().tolist()
+ aframe_pts = aframe_pts.numpy().tolist()
+ return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_file(filename: str) -> VideoMetaData:
+ """
+ Probe a video file and return VideoMetaData with info about the video
+ """
+ result = torch.ops.video_reader.probe_video_from_file(filename)
+ vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+ return info
+
+
+def _read_video_from_memory(
+ video_data: torch.Tensor,
+ seek_frame_margin: float = 0.25,
+ read_video_stream: int = 1,
+ video_width: int = 0,
+ video_height: int = 0,
+ video_min_dimension: int = 0,
+ video_max_dimension: int = 0,
+ video_pts_range: Tuple[int, int] = (0, -1),
+ video_timebase_numerator: int = 0,
+ video_timebase_denominator: int = 1,
+ read_audio_stream: int = 1,
+ audio_samples: int = 0,
+ audio_channels: int = 0,
+ audio_pts_range: Tuple[int, int] = (0, -1),
+ audio_timebase_numerator: int = 0,
+ audio_timebase_denominator: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Reads a video from memory, returning both the video frames as the audio frames
+ This function is torchscriptable.
+
+ Args:
+ video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
+ compressed video content stored in either 1) torch.Tensor 2) python bytes
+ seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
+ Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
+ read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
+ video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
+ the size of decoded frames:
+
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
+ and video_max_dimension = 0, keep the original frame resolution
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
+ and video_max_dimension = 0, keep the aspect ratio and resize the
+ frame so that shorter edge size is video_min_dimension
+ - When video_width = 0, video_height = 0, video_min_dimension = 0,
+ and video_max_dimension != 0, keep the aspect ratio and resize
+ the frame so that longer edge size is video_max_dimension
+ - When video_width = 0, video_height = 0, video_min_dimension != 0,
+ and video_max_dimension != 0, resize the frame so that shorter
+ edge size is video_min_dimension, and longer edge size is
+ video_max_dimension. The aspect ratio may not be preserved
+ - When video_width = 0, video_height != 0, video_min_dimension = 0,
+ and video_max_dimension = 0, keep the aspect ratio and resize
+ the frame so that frame video_height is $video_height
+ - When video_width != 0, video_height == 0, video_min_dimension = 0,
+ and video_max_dimension = 0, keep the aspect ratio and resize
+ the frame so that frame video_width is $video_width
+ - When video_width != 0, video_height != 0, video_min_dimension = 0,
+ and video_max_dimension = 0, resize the frame so that frame
+ video_width and video_height are set to $video_width and
+ $video_height, respectively
+ video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
+ video_timebase_numerator / video_timebase_denominator (float, optional): a rational
+ number which denotes timebase in video stream
+ read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
+ audio_samples (int, optional): audio sampling rate
+ audio_channels (int optional): audio audio_channels
+ audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
+ audio_timebase_numerator / audio_timebase_denominator (float, optional):
+ a rational number which denotes time base in audio stream
+
+ Returns:
+ vframes (Tensor[T, H, W, C]): the `T` video frames
+ aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
+ `K` is the number of channels
+ """
+
+ _validate_pts(video_pts_range)
+ _validate_pts(audio_pts_range)
+
+ if not isinstance(video_data, torch.Tensor):
+ with warnings.catch_warnings():
+ # Ignore the warning because we actually don't modify the buffer in this function
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
+ video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+
+ result = torch.ops.video_reader.read_video_from_memory(
+ video_data,
+ seek_frame_margin,
+ 0, # getPtsOnly
+ read_video_stream,
+ video_width,
+ video_height,
+ video_min_dimension,
+ video_max_dimension,
+ video_pts_range[0],
+ video_pts_range[1],
+ video_timebase_numerator,
+ video_timebase_denominator,
+ read_audio_stream,
+ audio_samples,
+ audio_channels,
+ audio_pts_range[0],
+ audio_pts_range[1],
+ audio_timebase_numerator,
+ audio_timebase_denominator,
+ )
+
+ vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
+
+ if aframes.numel() > 0:
+ # when audio stream is found
+ aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
+
+ return vframes, aframes
+
+
+def _read_video_timestamps_from_memory(
+ video_data: torch.Tensor,
+) -> Tuple[List[int], List[int], VideoMetaData]:
+ """
+ Decode all frames in the video. Only pts (presentation timestamp) is returned.
+ The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
+ is much faster than read_video(...)
+ """
+ if not isinstance(video_data, torch.Tensor):
+ with warnings.catch_warnings():
+ # Ignore the warning because we actually don't modify the buffer in this function
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
+ video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+ result = torch.ops.video_reader.read_video_from_memory(
+ video_data,
+ 0, # seek_frame_margin
+ 1, # getPtsOnly
+ 1, # read_video_stream
+ 0, # video_width
+ 0, # video_height
+ 0, # video_min_dimension
+ 0, # video_max_dimension
+ 0, # video_start_pts
+ -1, # video_end_pts
+ 0, # video_timebase_num
+ 1, # video_timebase_den
+ 1, # read_audio_stream
+ 0, # audio_samples
+ 0, # audio_channels
+ 0, # audio_start_pts
+ -1, # audio_end_pts
+ 0, # audio_timebase_num
+ 1, # audio_timebase_den
+ )
+ _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+
+ vframe_pts = vframe_pts.numpy().tolist()
+ aframe_pts = aframe_pts.numpy().tolist()
+ return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_memory(
+ video_data: torch.Tensor,
+) -> VideoMetaData:
+ """
+ Probe a video in memory and return VideoMetaData with info about the video
+ This function is torchscriptable
+ """
+ if not isinstance(video_data, torch.Tensor):
+ with warnings.catch_warnings():
+ # Ignore the warning because we actually don't modify the buffer in this function
+ warnings.filterwarnings("ignore", message="The given buffer is not writable")
+ video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+ result = torch.ops.video_reader.probe_video_from_memory(video_data)
+ vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+ info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+ return info
+
+
+def _read_video(
+ filename: str,
+ start_pts: Union[float, Fraction] = 0,
+ end_pts: Optional[Union[float, Fraction]] = None,
+ pts_unit: str = "pts",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
+ if end_pts is None:
+ end_pts = float("inf")
+
+ if pts_unit == "pts":
+ warnings.warn(
+ "The pts_unit 'pts' gives wrong results and will be removed in a "
+ + "follow-up version. Please use pts_unit 'sec'."
+ )
+
+ info = _probe_video_from_file(filename)
+
+ has_video = info.has_video
+ has_audio = info.has_audio
+
+ def get_pts(time_base):
+ start_offset = start_pts
+ end_offset = end_pts
+ if pts_unit == "sec":
+ start_offset = int(math.floor(start_pts * (1 / time_base)))
+ if end_offset != float("inf"):
+ end_offset = int(math.ceil(end_pts * (1 / time_base)))
+ if end_offset == float("inf"):
+ end_offset = -1
+ return start_offset, end_offset
+
+ video_pts_range = (0, -1)
+ video_timebase = default_timebase
+ if has_video:
+ video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
+ video_pts_range = get_pts(video_timebase)
+
+ audio_pts_range = (0, -1)
+ audio_timebase = default_timebase
+ if has_audio:
+ audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+ audio_pts_range = get_pts(audio_timebase)
+
+ vframes, aframes, info = _read_video_from_file(
+ filename,
+ read_video_stream=True,
+ video_pts_range=video_pts_range,
+ video_timebase=video_timebase,
+ read_audio_stream=True,
+ audio_pts_range=audio_pts_range,
+ audio_timebase=audio_timebase,
+ )
+ _info = {}
+ if has_video:
+ _info["video_fps"] = info.video_fps
+ if has_audio:
+ _info["audio_fps"] = info.audio_sample_rate
+
+ return vframes, aframes, _info
+
+
+def _read_video_timestamps(
+ filename: str, pts_unit: str = "pts"
+) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
+ if pts_unit == "pts":
+ warnings.warn(
+ "The pts_unit 'pts' gives wrong results and will be removed in a "
+ + "follow-up version. Please use pts_unit 'sec'."
+ )
+
+ pts: Union[List[int], List[Fraction]]
+ pts, _, info = _read_video_timestamps_from_file(filename)
+
+ if pts_unit == "sec":
+ video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
+ pts = [x * video_time_base for x in pts]
+
+ video_fps = info.video_fps if info.has_video else None
+
+ return pts, video_fps
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/image.py b/.venv/lib/python3.11/site-packages/torchvision/io/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb48d0e6816060e7e31e84c6864053b147070f95
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/io/image.py
@@ -0,0 +1,436 @@
+from enum import Enum
+from typing import List, Union
+from warnings import warn
+
+import torch
+
+from ..extension import _load_library
+from ..utils import _log_api_usage_once
+
+
+try:
+ _load_library("image")
+except (ImportError, OSError) as e:
+ warn(
+ f"Failed to load image Python extension: '{e}'"
+ f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
+ f"Otherwise, there might be something wrong with your environment. "
+ f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
+ )
+
+
+class ImageReadMode(Enum):
+ """Allow automatic conversion to RGB, RGBA, etc while decoding.
+
+ .. note::
+
+ You don't need to use this struct, you can just pass strings to all
+ ``mode`` parameters, e.g. ``mode="RGB"``.
+
+ The different available modes are the following.
+
+ - UNCHANGED: loads the image as-is
+ - RGB: converts to RGB
+ - RGBA: converts to RGB with transparency (also aliased as RGB_ALPHA)
+ - GRAY: converts to grayscale
+ - GRAY_ALPHA: converts to grayscale with transparency
+
+ .. note::
+
+ Some decoders won't support all possible values, e.g. GRAY and
+ GRAY_ALPHA are only supported for PNG and JPEG images.
+ """
+
+ UNCHANGED = 0
+ GRAY = 1
+ GRAY_ALPHA = 2
+ RGB = 3
+ RGB_ALPHA = 4
+ RGBA = RGB_ALPHA # Alias for convenience
+
+
+def read_file(path: str) -> torch.Tensor:
+ """
+ Return the bytes contents of a file as a uint8 1D Tensor.
+
+ Args:
+ path (str or ``pathlib.Path``): the path to the file to be read
+
+ Returns:
+ data (Tensor)
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(read_file)
+ data = torch.ops.image.read_file(str(path))
+ return data
+
+
+def write_file(filename: str, data: torch.Tensor) -> None:
+ """
+ Write the content of an uint8 1D tensor to a file.
+
+ Args:
+ filename (str or ``pathlib.Path``): the path to the file to be written
+ data (Tensor): the contents to be written to the output file
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(write_file)
+ torch.ops.image.write_file(str(filename), data)
+
+
+def decode_png(
+ input: torch.Tensor,
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
+ apply_exif_orientation: bool = False,
+) -> torch.Tensor:
+ """
+ Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
+
+ The values of the output tensor are in uint8 in [0, 255] for most cases. If
+ the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
+ (supported from torchvision ``0.21``). Since uint16 support is limited in
+ pytorch, we recommend calling
+ :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+ after this function to convert the decoded image into a uint8 or float
+ tensor.
+
+ Args:
+ input (Tensor[1]): a one dimensional uint8 tensor containing
+ the raw bytes of the PNG image.
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
+ for available modes.
+ apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
+ Default: False.
+
+ Returns:
+ output (Tensor[image_channels, image_height, image_width])
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(decode_png)
+ if isinstance(mode, str):
+ mode = ImageReadMode[mode.upper()]
+ output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
+ return output
+
+
+def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
+ """
+ Takes an input tensor in CHW layout and returns a buffer with the contents
+ of its corresponding PNG file.
+
+ Args:
+ input (Tensor[channels, image_height, image_width]): int8 image tensor of
+ ``c`` channels, where ``c`` must 3 or 1.
+ compression_level (int): Compression factor for the resulting file, it must be a number
+ between 0 and 9. Default: 6
+
+ Returns:
+ Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
+ PNG file.
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(encode_png)
+ output = torch.ops.image.encode_png(input, compression_level)
+ return output
+
+
+def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
+ """
+ Takes an input tensor in CHW layout (or HW in the case of grayscale images)
+ and saves it in a PNG file.
+
+ Args:
+ input (Tensor[channels, image_height, image_width]): int8 image tensor of
+ ``c`` channels, where ``c`` must be 1 or 3.
+ filename (str or ``pathlib.Path``): Path to save the image.
+ compression_level (int): Compression factor for the resulting file, it must be a number
+ between 0 and 9. Default: 6
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(write_png)
+ output = encode_png(input, compression_level)
+ write_file(filename, output)
+
+
+def decode_jpeg(
+ input: Union[torch.Tensor, List[torch.Tensor]],
+ mode: ImageReadMode = ImageReadMode.UNCHANGED,
+ device: Union[str, torch.device] = "cpu",
+ apply_exif_orientation: bool = False,
+) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """Decode JPEG image(s) into 3D RGB or grayscale Tensor(s), on CPU or CUDA.
+
+ The values of the output tensor are uint8 between 0 and 255.
+
+ .. note::
+ When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``.
+ When using CPU the performance is equivalent.
+ The CUDA version of this function has explicitly been designed with thread-safety in mind.
+ This function does not return partial results in case of an error.
+
+ Args:
+ input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
+ the raw bytes of the JPEG image. The tensor(s) must be on CPU,
+ regardless of the ``device`` parameter.
+ mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
+ Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
+ for available modes.
+ device (str or torch.device): The device on which the decoded image will
+ be stored. If a cuda device is specified, the image will be decoded
+ with `nvjpeg `_. This is only
+ supported for CUDA version >= 10.1
+
+ .. betastatus:: device parameter
+
+ .. warning::
+ There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
+ Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
+ apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
+ Default: False. Only implemented for JPEG format on CPU.
+
+ Returns:
+ output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
+ The values of the output tensor(s) are uint8 between 0 and 255.
+ ``output.device`` will be set to the specified ``device``
+
+
+ """
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ _log_api_usage_once(decode_jpeg)
+ if isinstance(device, str):
+ device = torch.device(device)
+ if isinstance(mode, str):
+ mode = ImageReadMode[mode.upper()]
+
+ if isinstance(input, list):
+ if len(input) == 0:
+ raise ValueError("Input list must contain at least one element")
+ if not all(isinstance(t, torch.Tensor) for t in input):
+ raise ValueError("All elements of the input list must be tensors.")
+ if not all(t.device.type == "cpu" for t in input):
+ raise ValueError("Input list must contain tensors on CPU.")
+ if device.type == "cuda":
+ return torch.ops.image.decode_jpegs_cuda(input, mode.value, device)
+ else:
+ return [torch.ops.image.decode_jpeg(img, mode.value, apply_exif_orientation) for img in input]
+
+ else: # input is tensor
+ if input.device.type != "cpu":
+ raise ValueError("Input tensor must be a CPU tensor")
+ if device.type == "cuda":
+ return torch.ops.image.decode_jpegs_cuda([input], mode.value, device)[0]
+ else:
+ return torch.ops.image.decode_jpeg(input, mode.value, apply_exif_orientation)
+
+
def encode_jpeg(
    input: Union[torch.Tensor, List[torch.Tensor]], quality: int = 75
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Encode RGB tensor(s) into raw encoded jpeg bytes, on CPU or CUDA.

    .. note::
        Passing a list of CUDA tensors is more efficient than repeated individual calls to ``encode_jpeg``.
        For CPU tensors the performance is equivalent.

    Args:
        input (Tensor[channels, image_height, image_width] or List[Tensor[channels, image_height, image_width]]):
            (list of) uint8 image tensor(s) of ``c`` channels, where ``c`` must be 1 or 3
        quality (int): Quality of the resulting JPEG file(s). Must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1] or list[Tensor[1]]): A (list of) one dimensional uint8 tensor(s) that contain the raw bytes of the JPEG file.

    Raises:
        ValueError: if ``quality`` is out of range or an empty list is passed.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(encode_jpeg)
    if quality > 100 or quality < 1:
        raise ValueError("Image quality should be a positive number between 1 and 100")
    if isinstance(input, list):
        if not input:
            raise ValueError("encode_jpeg requires at least one input tensor when a list is passed")
        # CUDA batches go through one fused kernel call; CPU tensors are encoded one by one.
        if input[0].device.type == "cuda":
            return torch.ops.image.encode_jpegs_cuda(input, quality)
        return [torch.ops.image.encode_jpeg(img, quality) for img in input]
    # Single tensor: reuse the batched CUDA op and unwrap its one-element result.
    if input.device.type == "cuda":
        return torch.ops.image.encode_jpegs_cuda([input], quality)[0]
    return torch.ops.image.encode_jpeg(input, quality)
+
+
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str or ``pathlib.Path``): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(write_jpeg)
    encoded = encode_jpeg(input, quality)
    assert isinstance(encoded, torch.Tensor)  # Needed for torchscript
    write_file(filename, encoded)
+
+
def decode_image(
    input: Union[torch.Tensor, str],
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """Decode an image into a uint8 tensor, from a path or from raw encoded bytes.

    Currently supported image formats are jpeg, png, gif and webp.

    The values of the output tensor are in uint8 in [0, 255] for most cases.

    If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
    (supported from torchvision ``0.21``). Since uint16 support is limited in
    pytorch, we recommend calling
    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
    after this function to convert the decoded image into a uint8 or float
    tensor.

    Args:
        input (Tensor or str or ``pathlib.Path``): The image to decode. If a
            tensor is passed, it must be one dimensional uint8 tensor containing
            the raw bytes of the image. Otherwise, this must be a path to the image file.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.
        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
            Only applies to JPEG and PNG images. Default: False.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(decode_image)
    # A non-tensor input is treated as a path and read from disk first.
    data = input if isinstance(input, torch.Tensor) else read_file(str(input))
    # Accept both string names ("RGB") and ImageReadMode enum members.
    resolved_mode = ImageReadMode[mode.upper()] if isinstance(mode, str) else mode
    return torch.ops.image.decode_image(data, resolved_mode.value, apply_exif_orientation)
+
+
def read_image(
    path: str,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
    apply_exif_orientation: bool = False,
) -> torch.Tensor:
    """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(read_image)
    # Thin wrapper kept for backward compatibility: read bytes, then decode.
    return decode_image(read_file(path), mode, apply_exif_orientation=apply_exif_orientation)
+
+
def decode_gif(input: torch.Tensor) -> torch.Tensor:
    """
    Decode a GIF image into a 3 or 4 dimensional RGB Tensor.

    The values of the output tensor are uint8 between 0 and 255.
    The output tensor has shape ``(C, H, W)`` if there is only one image in the
    GIF, and ``(N, C, H, W)`` if there are ``N`` images.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the GIF image.

    Returns:
        output (Tensor[image_channels, image_height, image_width] or Tensor[num_images, image_channels, image_height, image_width])
    """
    # API-usage logging is only possible in eager mode.
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(decode_gif)
    return torch.ops.image.decode_gif(input)
+
+
def decode_webp(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode a WEBP image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the WEBP image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(decode_webp)
    # Accept both string names ("RGB") and ImageReadMode enum members.
    resolved = ImageReadMode[mode.upper()] if isinstance(mode, str) else mode
    return torch.ops.image.decode_webp(input, resolved.value)
+
+
def _decode_avif(
    input: torch.Tensor,
    mode: ImageReadMode = ImageReadMode.UNCHANGED,
) -> torch.Tensor:
    """
    Decode an AVIF image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the AVIF image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(_decode_avif)
    # Accept both string names ("RGB") and ImageReadMode enum members.
    resolved = ImageReadMode[mode.upper()] if isinstance(mode, str) else mode
    return torch.ops.image.decode_avif(input, resolved.value)
+
+
def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Decode an HEIC image into a 3 dimensional RGB[A] Tensor.

    The values of the output tensor are in uint8 in [0, 255] for most images. If
    the image has a bit-depth of more than 8, then the output tensor is uint16
    in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
    calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
    ``scale=True`` after this function to convert the decoded image into a uint8
    or float tensor.

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the HEIC image.
        mode (str or ImageReadMode): The mode to convert the image to, e.g. "RGB".
            Default is "UNCHANGED". See :class:`~torchvision.io.ImageReadMode`
            for available modes.

    Returns:
        Decoded image (Tensor[image_channels, image_height, image_width])
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(_decode_heic)
    # Accept both string names ("RGB") and ImageReadMode enum members.
    resolved = ImageReadMode[mode.upper()] if isinstance(mode, str) else mode
    return torch.ops.image.decode_heic(input, resolved.value)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/video.py b/.venv/lib/python3.11/site-packages/torchvision/io/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c97f37e2952db23f231c7f39e7f8a9a257622f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/io/video.py
@@ -0,0 +1,438 @@
+import gc
+import math
+import os
+import re
+import warnings
+from fractions import Fraction
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..utils import _log_api_usage_once
+from . import _video_opt
+
# PyAV is an optional dependency: if the import fails, or the installed
# version is too old to expose VideoFrame.pict_type, the module-level name
# `av` is rebound to an ImportError instance. _check_av_available() below
# re-raises that stored error lazily, at the first actual use of PyAV.
try:
    import av

    # Silence PyAV's own log output below ERROR level.
    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )
+
+
def _check_av_available() -> None:
    """Raise the import error stored in ``av`` if PyAV is not usable."""
    # `av` is rebound to an ImportError instance by the import guard above
    # when PyAV is missing or too old.
    err = av if isinstance(av, Exception) else None
    if err is not None:
        raise err
+
+
def _av_available() -> bool:
    """Tell whether PyAV imported successfully (see the import guard above)."""
    if isinstance(av, Exception):
        return False
    return True
+
+
# PyAV has some reference cycles
# _read_from_stream() bumps _CALLED_TIMES on every call and triggers an
# explicit gc.collect() every _GC_COLLECTION_INTERVAL-th call so those
# cycles do not accumulate across many reads.
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10
+
+
def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    Args:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): video frames per second
        video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
        options (Dict): dictionary containing options to be passed into the PyAV video stream
        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
            and N is the number of samples
        audio_fps (Number): audio sample rate, typically 44100 or 48000
        audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        # libx264rgb encodes RGB directly; every other codec gets YUV 4:2:0.
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # Map FFmpeg sample-format names to the numpy dtype the samples
            # must be converted to before building the AudioFrame.
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            # The sample format is chosen by PyAV from the codec; read it back
            # so the numpy conversion below matches what the encoder expects.
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)

            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush any buffered audio packets.
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            # Let the encoder pick the picture type itself.
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
+
+
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """Decode all frames of ``stream`` whose pts fall in [start_offset, end_offset].

    ``start_offset``/``end_offset`` are interpreted in seconds when
    ``pts_unit == "sec"`` (and converted to stream pts here), otherwise as raw
    pts values. Returns the frames sorted by pts; on decode/seek errors an
    empty or partial list is returned rather than raising.
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    # PyAV leaves reference cycles behind; collect periodically (see module
    # counters above).
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    # Keyed by frame pts so duplicate pts overwrite instead of duplicating.
    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # Keep decoding a few extra frames when packed B-frames may
                # arrive out of order; stop once the buffer budget is spent.
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result
+
+
+def _align_audio_frames(
+ aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
+) -> torch.Tensor:
+ start, end = audio_frames[0].pts, audio_frames[-1].pts
+ total_aframes = aframes.shape[1]
+ step_per_aframe = (end - start + 1) / total_aframes
+ s_idx = 0
+ e_idx = total_aframes
+ if start < ref_start:
+ s_idx = int((ref_start - start) / step_per_aframe)
+ if end > ref_end:
+ e_idx = int((ref_end - end) / step_per_aframe)
+ return aframes[:, s_idx:e_idx]
+
+
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    Args:
        filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        ValueError: if ``output_format`` is invalid or ``end_pts < start_pts``.
        RuntimeError: if the file does not exist (non-pyav backends only).
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        if not os.path.exists(filename):
            # BUGFIX: the message previously lost the filename interpolation.
            raise RuntimeError(f"File not found: {filename}")
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except av.AVError:
            # TODO raise a warning?
            pass

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            # Empty sentinel with a valid (0, 1, 1, 3) shape so permute below works.
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                # Convert the requested window to audio pts before trimming.
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
+
+
+def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
+ extradata = container.streams[0].codec_context.extradata
+ if extradata is None:
+ return False
+ if b"Lavc" in extradata:
+ return True
+ return False
+
+
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """Return the pts of every video frame, skipping entries without a pts."""
    if _can_read_timestamps_from_packets(container):
        # fast path: packet pts are reliable, no need to decode frames
        source = container.demux(video=0)
    else:
        source = container.decode(video=0)
    return [item.pts for item in source if item.pts is not None]
+
+
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frames timestamps.

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    # BUGFIX: the message previously lost the filename interpolation.
                    warnings.warn(f"Failed decoding frames for file {filename}")
                video_fps = float(video_stream.average_rate)
    except av.AVError as e:
        # BUGFIX: the message previously lost the filename interpolation.
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        # If no video stream was found, pts is empty and this comprehension
        # never evaluates video_time_base.
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
diff --git a/.venv/lib/python3.11/site-packages/torchvision/io/video_reader.py b/.venv/lib/python3.11/site-packages/torchvision/io/video_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf319fe288e1e1a67cc94cb798ab8a9212671101
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/io/video_reader.py
@@ -0,0 +1,294 @@
+import io
+import warnings
+
+from typing import Any, Dict, Iterator
+
+import torch
+
+from ..utils import _log_api_usage_once
+
+from ._video_opt import _HAS_CPU_VIDEO_DECODER
+
# _has_video_opt() reports at runtime whether the compiled CPU video decoder
# extension was loaded (flag imported from ._video_opt above).
if _HAS_CPU_VIDEO_DECODER:

    def _has_video_opt() -> bool:
        return True

else:

    def _has_video_opt() -> bool:
        return False


# PyAV is optional: on import failure (or a too-old version) the name `av`
# is rebound to an ImportError instance; using the pyav backend later will
# then fail with this stored error.
try:
    import av

    # Silence PyAV's own log output below ERROR level.
    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )
+
+
class VideoReader:
    """
    Fine-grained video-reading API.
    Supports frame-by-frame reading of various streams from a single video
    container. Much like previous video_reader API it supports the following
    backends: video_reader, pyav, and cuda.
    Backends can be set via `torchvision.set_video_backend` function.

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    .. betastatus:: VideoReader class

    Example:
        The following example creates a :mod:`VideoReader` object, seeks into 2s
        point, and returns a single frame::

            import torchvision
            video_path = "path_to_a_test_video"
            reader = torchvision.io.VideoReader(video_path, "video")
            reader.seek(2.0)
            frame = next(reader)

    :mod:`VideoReader` implements the iterable API, which makes it suitable to
    using it in conjunction with :mod:`itertools` for more advanced reading.
    As such, we can use a :mod:`VideoReader` instance inside for loops::

        reader.seek(2)
        for frame in reader:
            frames.append(frame['data'])
        # additionally, `seek` implements a fluent API, so we can do
        for frame in reader.seek(2):
            frames.append(frame['data'])

    With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
    following code::

        for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
            frames.append(frame['data'])

    and similarly, reading 10 frames after the 2s timestamp can be achieved
    as follows::

        for frame in itertools.islice(reader.seek(2), 10):
            frames.append(frame['data'])

    .. note::

        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
        a unique stream id (which are determined by the video encoding).
        In this way, if the video container contains multiple
        streams of the same type, users can access the one they want.
        If only stream type is passed, the decoder auto-detects first stream of that type.

    Args:
        src (string, bytes object, or tensor): The media source.
            If string-type, it must be a file path supported by FFMPEG.
            If bytes, should be an in-memory representation of a file supported by FFMPEG.
            If Tensor, it is interpreted internally as byte buffer.
            It must be one-dimensional, of type ``torch.uint8``.

        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
            Currently available options include ``['video', 'audio']``

        num_threads (int, optional): number of threads used by the codec to decode video.
            Default value (0) enables multithreading with codec-dependent heuristic. The performance
            will depend on the version of FFMPEG codecs supported.
    """

    def __init__(
        self,
        src: str,
        stream: str = "video",
        num_threads: int = 0,
    ) -> None:
        _log_api_usage_once(self)
        from .. import get_video_backend

        self.backend = get_video_backend()
        # Normalize `src` into the representation the selected backend accepts.
        if isinstance(src, str):
            if not src:
                raise ValueError("src cannot be empty")
        elif isinstance(src, bytes):
            # NOTE(review): only the cuda backend is rejected here, but the
            # error message also mentions pyav — pyav bytes are actually
            # supported via the BytesIO wrapper below; confirm intended wording.
            if self.backend in ["cuda"]:
                raise RuntimeError(
                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
                )
            elif self.backend == "pyav":
                src = io.BytesIO(src)
            else:
                with warnings.catch_warnings():
                    # Ignore the warning because we actually don't modify the buffer in this function
                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
                    src = torch.frombuffer(src, dtype=torch.uint8)
        elif isinstance(src, torch.Tensor):
            if self.backend in ["cuda", "pyav"]:
                raise RuntimeError(
                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
                )
        else:
            raise ValueError(f"src must be either string, Tensor or bytes object. Got {type(src)}")

        # Backend dispatch: `self._c` holds the backend-specific decoder object
        # (GPUDecoder, torchvision Video, or a PyAV frame generator).
        if self.backend == "cuda":
            device = torch.device("cuda")
            self._c = torch.classes.torchvision.GPUDecoder(src, device)

        elif self.backend == "video_reader":
            if isinstance(src, str):
                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
            elif isinstance(src, torch.Tensor):
                # In-memory decoding: construct an empty Video and feed it the buffer.
                self._c = torch.classes.torchvision.Video("", "", 0)
                self._c.init_from_memory(src, stream, num_threads)

        elif self.backend == "pyav":
            self.container = av.open(src, metadata_errors="ignore")
            # TODO: load metadata
            # Parse "type:id" (id defaults to 0) into kwargs for container.decode().
            stream_type = stream.split(":")[0]
            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
            self.pyav_stream = {stream_type: stream_id}
            self._c = self.container.decode(**self.pyav_stream)

            # TODO: add extradata exception

        else:
            raise RuntimeError("Unknown video backend: {}".format(self.backend))

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
        Frames are encoded as a dict with mandatory
        data and pts fields, where data is a tensor, and pts is a
        presentation timestamp of the frame expressed in seconds
        as a float.

        Returns:
            (dict): a dictionary and containing decoded frame (``data``)
            and corresponding timestamp (``pts``) in seconds

        """
        if self.backend == "cuda":
            frame = self._c.next()
            if frame.numel() == 0:
                raise StopIteration
            # The GPU decoder does not expose timestamps; pts is always None.
            return {"data": frame, "pts": None}
        elif self.backend == "video_reader":
            frame, pts = self._c.next()
        else:
            try:
                frame = next(self._c)
                pts = float(frame.pts * frame.time_base)
                # Convert PyAV frames to tensors: video -> CHW, audio -> (channels, samples).
                if "video" in self.pyav_stream:
                    frame = torch.as_tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
                elif "audio" in self.pyav_stream:
                    frame = torch.as_tensor(frame.to_ndarray()).permute(1, 0)
                else:
                    frame = None
            except av.error.EOFError:
                raise StopIteration

        if frame.numel() == 0:
            raise StopIteration

        return {"data": frame, "pts": pts}

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
        """Seek within current stream.

        Args:
            time_s (float): seek time in seconds
            keyframes_only (bool): allow to seek only to keyframes

        .. note::
            Current implementation is the so-called precise seek. This
            means following seek, call to :mod:`next()` will return the
            frame with the exact timestamp if it exists or
            the first frame with timestamp larger than ``time_s``.
        """
        if self.backend in ["cuda", "video_reader"]:
            self._c.seek(time_s, keyframes_only)
        else:
            # handle special case as pyav doesn't catch it
            if time_s < 0:
                time_s = 0
            temp_str = self.container.streams.get(**self.pyav_stream)[0]
            # Convert seconds to stream pts for container.seek().
            offset = int(round(time_s / temp_str.time_base))
            if not keyframes_only:
                warnings.warn("Accurate seek is not implemented for pyav backend")
            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
            # Restart the frame generator from the new position.
            self._c = self.container.decode(**self.pyav_stream)
        return self

    def get_metadata(self) -> Dict[str, Any]:
        """Returns video metadata

        Returns:
            (dict): dictionary containing duration and frame rate for every stream
        """
        if self.backend == "pyav":
            metadata = {}  # type: Dict[str, Any]
            for stream in self.container.streams:
                if stream.type not in metadata:
                    # Video streams report "fps", other streams "framerate".
                    if stream.type == "video":
                        rate_n = "fps"
                    else:
                        rate_n = "framerate"
                    metadata[stream.type] = {rate_n: [], "duration": []}

                # average_rate may be None/0 for non-video streams; fall back
                # to the audio sample_rate.
                rate = getattr(stream, "average_rate", None) or stream.sample_rate

                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
                metadata[stream.type][rate_n].append(float(rate))
            return metadata
        return self._c.get_metadata()

    def set_current_stream(self, stream: str) -> bool:
        """Set current stream.
        Explicitly define the stream we are operating on.

        Args:
            stream (string): descriptor of the required stream. Defaults to ``"video:0"``
                Currently available stream types include ``['video', 'audio']``.
                Each descriptor consists of two parts: stream type (e.g. 'video') and
                a unique stream id (which are determined by video encoding).
                In this way, if the video container contains multiple
                streams of the same type, users can access the one they want.
                If only stream type is passed, the decoder auto-detects first stream
                of that type and returns it.

        Returns:
            (bool): True on success, False otherwise
        """
        if self.backend == "cuda":
            warnings.warn("GPU decoding only works with video stream.")
        if self.backend == "pyav":
            # Re-parse the descriptor and restart decoding on the new stream.
            stream_type = stream.split(":")[0]
            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
            self.pyav_stream = {stream_type: stream_id}
            self._c = self.container.decode(**self.pyav_stream)
            return True
        return self._c.set_current_stream(stream)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4146651c737971cc5a883b6750f2ded3051bc8ea
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__init__.py
@@ -0,0 +1,7 @@
+from .faster_rcnn import *
+from .fcos import *
+from .keypoint_rcnn import *
+from .mask_rcnn import *
+from .retinanet import *
+from .ssd import *
+from .ssdlite import *
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53c14c903980f452c1b2dbdf0374e2cab6b96615
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..55a8116c2291f01b23840183dc11306ed842b0a0
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..484f2e1e6aa74a510cf4722d9e908df7585d7c8f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/anchor_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62e6d17251b09ac643c061512175f67ec1298df6
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/backbone_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5dfcf9f1c0e7f80a006ae550791468d9fc5d113e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/faster_rcnn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fcf2808f7879cf5d1e6aabad0566c2b3234be3d
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/fcos.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f75d0847f50abc76f152d305e0e5e4565b059920
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/generalized_rcnn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd15614c40d2c33c2be82d5aa58fb975c1f5a17b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/image_list.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9a327c7be918ad99861b0b296e79f6a7fd71b216
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/keypoint_rcnn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab261a5d3c821ff91556fe4def744b4be5d9ef54
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/mask_rcnn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..366d64ea4e8400316e146bdcc28e75e45fec654b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/retinanet.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a95b09155442bea8f901927e6e92f1d07a120ef4
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/roi_heads.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..082abcea74273703272541d364a08359be789341
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/rpn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..762f4f218164e13237d267e431b507cc42680ed2
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssd.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..639278d40f3f1a584a51fc5f670cf74921e4b78b
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/ssdlite.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dadc3025edc58a8708f92b4bf35c8ecc53084eba
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/detection/__pycache__/transform.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..559db858ac32f3b9f157aff3c22da83abece2a73
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/_utils.py
@@ -0,0 +1,540 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+ """
+ This class samples batches, ensuring that they contain a fixed proportion of positives
+ """
+
+ def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+ """
+ Args:
+ batch_size_per_image (int): number of elements to be selected per image
+ positive_fraction (float): percentage of positive elements per batch
+ """
+ self.batch_size_per_image = batch_size_per_image
+ self.positive_fraction = positive_fraction
+
+ def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+ """
+ Args:
+ matched_idxs: list of tensors containing -1, 0 or positive values.
+ Each tensor corresponds to a specific image.
+ -1 values are ignored, 0 are considered as negatives and > 0 as
+ positives.
+
+ Returns:
+ pos_idx (list[tensor])
+ neg_idx (list[tensor])
+
+ Returns two lists of binary masks for each image.
+ The first list contains the positive elements that were selected,
+ and the second list the negative example.
+ """
+ pos_idx = []
+ neg_idx = []
+ for matched_idxs_per_image in matched_idxs:
+ positive = torch.where(matched_idxs_per_image >= 1)[0]
+ negative = torch.where(matched_idxs_per_image == 0)[0]
+
+ num_pos = int(self.batch_size_per_image * self.positive_fraction)
+ # protect against not enough positive examples
+ num_pos = min(positive.numel(), num_pos)
+ num_neg = self.batch_size_per_image - num_pos
+ # protect against not enough negative examples
+ num_neg = min(negative.numel(), num_neg)
+
+ # randomly select positive and negative examples
+ perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+ perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+ pos_idx_per_image = positive[perm1]
+ neg_idx_per_image = negative[perm2]
+
+ # create binary mask from indices
+ pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+ neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+ pos_idx_per_image_mask[pos_idx_per_image] = 1
+ neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+ pos_idx.append(pos_idx_per_image_mask)
+ neg_idx.append(neg_idx_per_image_mask)
+
+ return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+ """
+ Encode a set of proposals with respect to some
+ reference boxes
+
+ Args:
+ reference_boxes (Tensor): reference boxes
+ proposals (Tensor): boxes to be encoded
+ weights (Tensor[4]): the weights for ``(x, y, w, h)``
+ """
+
+ # perform some unpacking to make it JIT-fusion friendly
+ wx = weights[0]
+ wy = weights[1]
+ ww = weights[2]
+ wh = weights[3]
+
+ proposals_x1 = proposals[:, 0].unsqueeze(1)
+ proposals_y1 = proposals[:, 1].unsqueeze(1)
+ proposals_x2 = proposals[:, 2].unsqueeze(1)
+ proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+ reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+ reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+ reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+ reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+ # implementation starts here
+ ex_widths = proposals_x2 - proposals_x1
+ ex_heights = proposals_y2 - proposals_y1
+ ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+ ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+ gt_widths = reference_boxes_x2 - reference_boxes_x1
+ gt_heights = reference_boxes_y2 - reference_boxes_y1
+ gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+ gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+ targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+ targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+ targets_dw = ww * torch.log(gt_widths / ex_widths)
+ targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+ targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+ return targets
+
+
+class BoxCoder:
+ """
+ This class encodes and decodes a set of bounding boxes into
+ the representation used for training the regressors.
+ """
+
+ def __init__(
+ self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+ ) -> None:
+ """
+ Args:
+ weights (4-element tuple)
+ bbox_xform_clip (float)
+ """
+ self.weights = weights
+ self.bbox_xform_clip = bbox_xform_clip
+
+ def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+ boxes_per_image = [len(b) for b in reference_boxes]
+ reference_boxes = torch.cat(reference_boxes, dim=0)
+ proposals = torch.cat(proposals, dim=0)
+ targets = self.encode_single(reference_boxes, proposals)
+ return targets.split(boxes_per_image, 0)
+
+ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+ """
+ Encode a set of proposals with respect to some
+ reference boxes
+
+ Args:
+ reference_boxes (Tensor): reference boxes
+ proposals (Tensor): boxes to be encoded
+ """
+ dtype = reference_boxes.dtype
+ device = reference_boxes.device
+ weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+ targets = encode_boxes(reference_boxes, proposals, weights)
+
+ return targets
+
+ def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+ torch._assert(
+ isinstance(boxes, (list, tuple)),
+ "This function expects boxes of type list or tuple.",
+ )
+ torch._assert(
+ isinstance(rel_codes, torch.Tensor),
+ "This function expects rel_codes of type torch.Tensor.",
+ )
+ boxes_per_image = [b.size(0) for b in boxes]
+ concat_boxes = torch.cat(boxes, dim=0)
+ box_sum = 0
+ for val in boxes_per_image:
+ box_sum += val
+ if box_sum > 0:
+ rel_codes = rel_codes.reshape(box_sum, -1)
+ pred_boxes = self.decode_single(rel_codes, concat_boxes)
+ if box_sum > 0:
+ pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+ return pred_boxes
+
+ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+ """
+ From a set of original boxes and encoded relative box offsets,
+ get the decoded boxes.
+
+ Args:
+ rel_codes (Tensor): encoded boxes
+ boxes (Tensor): reference boxes.
+ """
+
+ boxes = boxes.to(rel_codes.dtype)
+
+ widths = boxes[:, 2] - boxes[:, 0]
+ heights = boxes[:, 3] - boxes[:, 1]
+ ctr_x = boxes[:, 0] + 0.5 * widths
+ ctr_y = boxes[:, 1] + 0.5 * heights
+
+ wx, wy, ww, wh = self.weights
+ dx = rel_codes[:, 0::4] / wx
+ dy = rel_codes[:, 1::4] / wy
+ dw = rel_codes[:, 2::4] / ww
+ dh = rel_codes[:, 3::4] / wh
+
+ # Prevent sending too large values into torch.exp()
+ dw = torch.clamp(dw, max=self.bbox_xform_clip)
+ dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+ pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+ pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+ pred_w = torch.exp(dw) * widths[:, None]
+ pred_h = torch.exp(dh) * heights[:, None]
+
+ # Distance from center to box's corner.
+ c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+ c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+ pred_boxes1 = pred_ctr_x - c_to_c_w
+ pred_boxes2 = pred_ctr_y - c_to_c_h
+ pred_boxes3 = pred_ctr_x + c_to_c_w
+ pred_boxes4 = pred_ctr_y + c_to_c_h
+ pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+ return pred_boxes
+
+
+class BoxLinearCoder:
+ """
+ The linear box-to-box transform defined in FCOS. The transformation is parameterized
+ by the distance from the center of (square) src box to 4 edges of the target box.
+ """
+
+ def __init__(self, normalize_by_size: bool = True) -> None:
+ """
+ Args:
+ normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+ """
+ self.normalize_by_size = normalize_by_size
+
+ def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+ """
+ Encode a set of proposals with respect to some reference boxes
+
+ Args:
+ reference_boxes (Tensor): reference boxes
+ proposals (Tensor): boxes to be encoded
+
+ Returns:
+ Tensor: the encoded relative box offsets that can be used to
+ decode the boxes.
+
+ """
+
+ # get the center of reference_boxes
+ reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+ reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+ # get box regression transformation deltas
+ target_l = reference_boxes_ctr_x - proposals[..., 0]
+ target_t = reference_boxes_ctr_y - proposals[..., 1]
+ target_r = proposals[..., 2] - reference_boxes_ctr_x
+ target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+ targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+ if self.normalize_by_size:
+ reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+ reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+ reference_boxes_size = torch.stack(
+ (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+ )
+ targets = targets / reference_boxes_size
+ return targets
+
+ def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+ """
+ From a set of original boxes and encoded relative box offsets,
+ get the decoded boxes.
+
+ Args:
+ rel_codes (Tensor): encoded boxes
+ boxes (Tensor): reference boxes.
+
+ Returns:
+ Tensor: the predicted boxes with the encoded relative box offsets.
+
+ .. note::
+ This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+ """
+
+ boxes = boxes.to(dtype=rel_codes.dtype)
+
+ ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+ ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+ if self.normalize_by_size:
+ boxes_w = boxes[..., 2] - boxes[..., 0]
+ boxes_h = boxes[..., 3] - boxes[..., 1]
+
+ list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+ rel_codes = rel_codes * list_box_size
+
+ pred_boxes1 = ctr_x - rel_codes[..., 0]
+ pred_boxes2 = ctr_y - rel_codes[..., 1]
+ pred_boxes3 = ctr_x + rel_codes[..., 2]
+ pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+ pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+ return pred_boxes
+
+
+class Matcher:
+ """
+ This class assigns to each predicted "element" (e.g., a box) a ground-truth
+ element. Each predicted element will have exactly zero or one matches; each
+ ground-truth element may be assigned to zero or more predicted elements.
+
+ Matching is based on the MxN match_quality_matrix, that characterizes how well
+ each (ground-truth, predicted)-pair match. For example, if the elements are
+ boxes, the matrix may contain box IoU overlap values.
+
+ The matcher returns a tensor of size N containing the index of the ground-truth
+ element m that matches to prediction n. If there is no match, a negative value
+ is returned.
+ """
+
+ BELOW_LOW_THRESHOLD = -1
+ BETWEEN_THRESHOLDS = -2
+
+ __annotations__ = {
+ "BELOW_LOW_THRESHOLD": int,
+ "BETWEEN_THRESHOLDS": int,
+ }
+
+ def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+ """
+ Args:
+ high_threshold (float): quality values greater than or equal to
+ this value are candidate matches.
+ low_threshold (float): a lower quality threshold used to stratify
+ matches into three levels:
+ 1) matches >= high_threshold
+ 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+ 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+ allow_low_quality_matches (bool): if True, produce additional matches
+ for predictions that have only low-quality match candidates. See
+ set_low_quality_matches_ for more details.
+ """
+ self.BELOW_LOW_THRESHOLD = -1
+ self.BETWEEN_THRESHOLDS = -2
+ torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+ self.high_threshold = high_threshold
+ self.low_threshold = low_threshold
+ self.allow_low_quality_matches = allow_low_quality_matches
+
+ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+ """
+ Args:
+ match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+ pairwise quality between M ground-truth elements and N predicted elements.
+
+ Returns:
+ matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+ [0, M - 1] or a negative value indicating that prediction i could not
+ be matched.
+ """
+ if match_quality_matrix.numel() == 0:
+ # empty targets or proposals not supported during training
+ if match_quality_matrix.shape[0] == 0:
+ raise ValueError("No ground-truth boxes available for one of the images during training")
+ else:
+ raise ValueError("No proposal boxes available for one of the images during training")
+
+ # match_quality_matrix is M (gt) x N (predicted)
+ # Max over gt elements (dim 0) to find best gt candidate for each prediction
+ matched_vals, matches = match_quality_matrix.max(dim=0)
+ if self.allow_low_quality_matches:
+ all_matches = matches.clone()
+ else:
+ all_matches = None # type: ignore[assignment]
+
+ # Assign candidate matches with low quality to negative (unassigned) values
+ below_low_threshold = matched_vals < self.low_threshold
+ between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+ matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+ matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+ if self.allow_low_quality_matches:
+ if all_matches is None:
+ torch._assert(False, "all_matches should not be None")
+ else:
+ self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+ return matches
+
+ def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+ """
+ Produce additional matches for predictions that have only low-quality matches.
+ Specifically, for each ground-truth find the set of predictions that have
+ maximum overlap with it (including ties); for each prediction in that set, if
+ it is unmatched, then match it to the ground-truth with which it has the highest
+ quality value.
+ """
+ # For each gt, find the prediction with which it has the highest quality
+ highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+ # Find the highest quality match available, even if it is low, including ties
+ gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+ # Example gt_pred_pairs_of_highest_quality:
+ # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
+ # tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
+ # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
+ # Note how gt items 1, 2, 3, and 5 each have two ties
+
+ pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+ matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+class SSDMatcher(Matcher):
+ def __init__(self, threshold: float) -> None:
+ super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+ matches = super().__call__(match_quality_matrix)
+
+ # For each gt, find the prediction with which it has the highest quality
+ _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+ matches[highest_quality_pred_foreach_gt] = torch.arange(
+ highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+ )
+
+ return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+ """
+ This method overwrites the default eps values of all the
+ FrozenBatchNorm2d layers of the model with the provided value.
+ This is necessary to address the BC-breaking change introduced
+ by the bug-fix at pytorch/vision#2933. The overwrite is applied
+ only when the pretrained weights are loaded to maintain compatibility
+ with previous versions.
+
+ Args:
+ model (nn.Module): The model on which we perform the overwrite.
+ eps (float): The new value of eps.
+ """
+ for module in model.modules():
+ if isinstance(module, FrozenBatchNorm2d):
+ module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+ """
+ This method retrieves the number of output channels of a specific model.
+
+ Args:
+ model (nn.Module): The model for which we estimate the out_channels.
+ It should return a single Tensor or an OrderedDict[Tensor].
+ size (Tuple[int, int]): The size (wxh) of the input.
+
+ Returns:
+ out_channels (List[int]): A list of the output channels of the model.
+ """
+ in_training = model.training
+ model.eval()
+
+ with torch.no_grad():
+ # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+ device = next(model.parameters()).device
+ tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+ features = model(tmp_img)
+ if isinstance(features, torch.Tensor):
+ features = OrderedDict([("0", features)])
+ out_channels = [x.size(1) for x in features.values()]
+
+ if in_training:
+ model.train()
+
+ return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+ return v # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+ """
+ ONNX spec requires the k-value to be less than or equal to the number of inputs along
+ provided dim. Certain models use the number of elements along a particular axis instead of K
+ if K exceeds the number of elements along that axis. Previously, python's min() function was
+ used to determine whether to use the provided k-value or the specified dim axis value.
+
+ However, in cases where the model is being exported in tracing mode, python min() is
+ static causing the model to be traced incorrectly and eventually fail at the topk node.
+ In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+ Args:
+ input (Tensor): The original input tensor.
+ orig_kval (int): The provided k-value.
+ axis(int): Axis along which we retrieve the input size.
+
+ Returns:
+ min_kval (int): Appropriately selected k-value.
+ """
+ if not torch.jit.is_tracing():
+ return min(orig_kval, input.size(axis))
+ axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+ min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+ return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+ type: str,
+ box_coder: BoxCoder,
+ anchors_per_image: Tensor,
+ matched_gt_boxes_per_image: Tensor,
+ bbox_regression_per_image: Tensor,
+ cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+ torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+ if type == "l1":
+ target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+ return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+ elif type == "smooth_l1":
+ target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+ beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+ return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+ else:
+ bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+ eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+ if type == "ciou":
+ return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+ if type == "diou":
+ return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+ # otherwise giou
+ return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..253f6502a9b6344f5a3da239f2394179a256424e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/anchor_utils.py
@@ -0,0 +1,268 @@
+import math
+from typing import List, Optional
+
+import torch
+from torch import nn, Tensor
+
+from .image_list import ImageList
+
+
class AnchorGenerator(nn.Module):
    """
    Generates anchor boxes for a set of feature maps and image sizes.

    Multiple anchor sizes and aspect ratios may be requested per feature map;
    aspect ratio is interpreted as height / width. ``sizes`` and
    ``aspect_ratios`` must have one entry per feature map, and feature map
    ``i`` gets ``len(sizes[i]) * len(aspect_ratios[i])`` anchors per spatial
    location.

    Args:
        sizes (Tuple[Tuple[int]]):
        aspect_ratios (Tuple[Tuple[float]]):
    """

    __annotations__ = {
        "cell_anchors": List[torch.Tensor],
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super().__init__()

        # Normalize flat inputs into the per-feature-map nested form.
        if not isinstance(sizes[0], (list, tuple)):
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = [
            self.generate_anchors(per_map_sizes, per_map_ratios)
            for per_map_sizes, per_map_ratios in zip(sizes, aspect_ratios)
        ]

    # For every (scale, aspect ratio) combination, produce one zero-centered
    # anchor in (x1, y1, x2, y2) form. Aspect ratio = height / width.
    def generate_anchors(
        self,
        scales: List[int],
        aspect_ratios: List[float],
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        # Outer product of ratios and scales, flattened to one row per anchor.
        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        # Half-extents on each side of the origin, rounded to integer coords.
        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        # Move cached base anchors to the dtype/device of the feature maps.
        self.cell_anchors = [anchor.to(dtype=dtype, device=device) for anchor in self.cell_anchors]

    def num_anchors_per_location(self) -> List[int]:
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # Tile each feature map's base anchors across its grid: every grid cell
    # gets a copy of the zero-centered anchors shifted by the cell's offset.
    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
        anchors = []
        cell_anchors = self.cell_anchors
        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
        torch._assert(
            len(grid_sizes) == len(strides) == len(cell_anchors),
            "Anchors should be Tuple[Tuple[int]] because each feature "
            "map could potentially have different sizes and aspect ratios. "
            "There needs to be a match between the number of "
            "feature maps passed and the number of sizes / aspect ratios specified.",
        )

        for level_size, level_stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            grid_height, grid_width = level_size
            stride_height, stride_width = level_stride
            device = base_anchors.device

            # Per-cell centers, replicated into (x, y, x, y) shift vectors.
            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # Broadcast add: (cells, 1, 4) + (1, anchors, 4) -> all anchors for the level.
            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [fm.shape[-2:] for fm in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        # Stride of each feature map w.r.t. the (padded) input, as scalar tensors.
        strides = [
            [
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        # Every image in the batch shares the same anchor set.
        for _ in range(len(image_list.image_sizes)):
            anchors.append([per_level for per_level in anchors_over_all_feature_maps])
        return [torch.cat(anchors_per_image) for anchors_per_image in anchors]
+
+
class DefaultBoxGenerator(nn.Module):
    r"""
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
    """

    def __init__(
        self,
        aspect_ratios: List[List[int]],
        min_ratio: float = 0.15,
        max_ratio: float = 0.9,
        scales: Optional[List[float]] = None,
        steps: Optional[List[int]] = None,
        clip: bool = True,
    ):
        super().__init__()
        if steps is not None and len(aspect_ratios) != len(steps):
            raise ValueError("aspect_ratios and steps should have the same length")
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales
        # Linear interpolation between min_ratio and max_ratio, one scale per
        # feature map, plus a trailing 1.0 used for the s'_k geometric mean.
        if scales is None:
            if num_outputs > 1:
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales

        self._wh_pairs = self._generate_wh_pairs(num_outputs)

    def _generate_wh_pairs(
        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
    ) -> List[Tensor]:
        """Precompute, per feature map, the (w, h) pairs of all default boxes."""
        _wh_pairs: List[Tensor] = []
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([[w, h], [h, w]])

            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    # Default Boxes calculation based on page 6 of SSD paper
    def _grid_default_boxes(
        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
    ) -> Tensor:
        """Return all default boxes in normalized (cx, cy, w, h) format, concatenated over feature maps."""
        default_boxes = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
            if self.steps is not None:
                x_f_k = image_size[1] / self.steps[k]
                y_f_k = image_size[0] / self.steps[k]
            else:
                y_f_k, x_f_k = f_k

            # Cell centers at (i + 0.5) / f_k, normalized to [0, 1].
            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)

            # Repeat each center once per (w, h) pair of this feature map.
            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

            default_box = torch.cat((shifts, wh_pairs), dim=1)

            default_boxes.append(default_box)

        return torch.cat(default_boxes, dim=0)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"aspect_ratios={self.aspect_ratios}"
            f", clip={self.clip}"
            f", scales={self.scales}"
            f", steps={self.steps}"
            ")"
        )
        return s

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        """Return, for each image in the batch, all default boxes in absolute (x1, y1, x2, y2) pixels."""
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
        default_boxes = default_boxes.to(device)

        dboxes = []
        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
        for _ in image_list.image_sizes:
            dboxes_in_image = default_boxes
            # Convert (cx, cy, w, h) in [0, 1] to corner coordinates scaled by image size.
            dboxes_in_image = torch.cat(
                [
                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                ],
                -1,
            )
            dboxes.append(dboxes_in_image)
        return dboxes
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..668e6b31696eb949513d07878eada9d468dc99cd
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/backbone_utils.py
@@ -0,0 +1,244 @@
+import warnings
+from typing import Callable, Dict, List, Optional, Union
+
+from torch import nn, Tensor
+from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
+
+from .. import mobilenet, resnet
+from .._api import _get_enum_from_fn, WeightsEnum
+from .._utils import handle_legacy_interface, IntermediateLayerGetter
+
+
class BackboneWithFPN(nn.Module):
    """
    Wraps a backbone with a Feature Pyramid Network on top.

    Uses ``torchvision.models._utils.IntermediateLayerGetter`` to extract the
    feature maps named in ``return_layers`` and feeds them through an FPN; the
    limitations of ``IntermediateLayerGetter`` therefore apply here as well.

    Args:
        backbone (nn.Module)
        return_layers (Dict[name, new_name]): maps module names in the backbone
            to the names under which their activations are returned.
        in_channels_list (List[int]): channels of each returned feature map, in
            the order they appear in the OrderedDict.
        out_channels (int): number of channels in the FPN.
        extra_blocks (ExtraFPNBlock, optional): extra FPN operation;
            defaults to ``LastLevelMaxPool``.
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None

    Attributes:
        out_channels (int): the number of channels in the FPN
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool() if extra_blocks is None else extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        features = self.body(x)
        return self.fpn(features)
+
+
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """
    Build a ResNet backbone topped with an FPN, freezing the requested number of layers.

    Args:
        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
            'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
        weights (WeightsEnum, optional): The pretrained weights for the model
        norm_layer (callable): it is recommended to use the default value. For details visit:
            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
            By default, all layers are returned.
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names. By
            default, a ``LastLevelMaxPool`` is used.

    Example::

        >>> import torch
        >>> from torchvision.models import ResNet50_Weights
        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
        >>> x = torch.rand(1,3,64,64)
        >>> output = backbone(x)
        >>> # output maps '0'..'3' are 256-channel pyramid levels, plus a 'pool' level

    Returns:
        A ``BackboneWithFPN`` wrapping the requested ResNet.
    """
    # Instantiate the requested ResNet, then delegate the freezing + FPN wiring.
    net = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(net, trainable_layers, returned_layers, extra_blocks)
+
+
def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze the requested ResNet layers and wrap the backbone with an FPN.

    Args:
        backbone: the ResNet whose intermediate layer outputs feed the FPN.
        trainable_layers: number of non-frozen layer groups counted from the
            final block; must be in ``[0, 5]``. With 5, ``bn1`` is also trained.
        returned_layers: which ``layerN`` outputs to expose; each entry must be
            in ``[1, 4]``. Defaults to all four.
        extra_blocks: optional extra FPN operation; defaults to ``LastLevelMaxPool``.
        norm_layer: optional normalization layer for the FPN convolutions.

    Returns:
        A ``BackboneWithFPN`` with ``out_channels = 256``.

    Raises:
        ValueError: if ``trainable_layers`` or ``returned_layers`` is out of range.
    """
    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        # Generator (not a list) inside all(): avoids materializing a throwaway
        # list for every parameter and short-circuits on the first match.
        if all(not name.startswith(layer) for layer in layers_to_train):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    # layer1 of a ResNet has inplanes // 8 output channels; each deeper layer doubles it.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
+
+
+def _validate_trainable_layers(
+ is_trained: bool,
+ trainable_backbone_layers: Optional[int],
+ max_value: int,
+ default_value: int,
+) -> int:
+ # don't freeze any layers if pretrained model or backbone is not used
+ if not is_trained:
+ if trainable_backbone_layers is not None:
+ warnings.warn(
+ "Changing trainable_backbone_layers has no effect if "
+ "neither pretrained nor pretrained_backbone have been set to True, "
+ f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
+ )
+ trainable_backbone_layers = max_value
+
+ # by default freeze first blocks
+ if trainable_backbone_layers is None:
+ trainable_backbone_layers = default_value
+ if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
+ raise ValueError(
+ f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
+ )
+ return trainable_backbone_layers
+
+
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def mobilenet_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    fpn: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """Build a MobileNet feature extractor, optionally topped with an FPN.

    Args:
        backbone_name: a mobilenet architecture name looked up in ``torchvision.models.mobilenet``.
        weights: the pretrained weights for the backbone, or ``None``.
        fpn: whether to attach a Feature Pyramid Network on top.
        norm_layer: normalization layer for the backbone (frozen batch norm by default).
        trainable_layers: number of non-frozen stages counted from the final block.
        returned_layers: which stages to expose when ``fpn`` is True; defaults to the last two.
        extra_blocks: optional extra FPN operation; defaults to ``LastLevelMaxPool``.

    Returns:
        The feature-extractor module with an ``out_channels`` attribute.
    """
    # Instantiate the requested MobileNet, then delegate freezing and head wiring.
    net = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _mobilenet_extractor(net, fpn, trainable_layers, returned_layers, extra_blocks)
+
+
def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    fpn: bool,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
    """Freeze the requested MobileNet stages and build the detection feature extractor.

    Returns either a ``BackboneWithFPN`` (when ``fpn`` is True) or the backbone's
    ``features`` followed by a 1x1 conv projecting to 256 channels. In both cases
    the result carries an ``out_channels`` attribute of 256.

    Raises:
        ValueError: if ``trainable_layers`` or ``returned_layers`` is out of range.
    """
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    # NOTE(review): relies on the `_is_cn` attribute set by the mobilenet block
    # implementations; blocks without it are treated as non-strided.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
    # trainable_layers == 0 freezes everything (freeze_before past the last block).
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    out_channels = 256
    if fpn:
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        # Default: expose the two deepest stages to the FPN.
        if returned_layers is None:
            returned_layers = [num_stages - 2, num_stages - 1]
        if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
            raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
        return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
        return BackboneWithFPN(
            backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
        )
    else:
        m = nn.Sequential(
            backbone,
            # depthwise linear combination of channels to reduce their size
            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
        )
        m.out_channels = out_channels  # type: ignore[assignment]
        return m
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..18474ee84f4539cfec99d24534acb1e1e74a14b3
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/faster_rcnn.py
@@ -0,0 +1,846 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+from .generalized_rcnn import GeneralizedRCNN
+from .roi_heads import RoIHeads
+from .rpn import RegionProposalNetwork, RPNHead
+from .transform import GeneralizedRCNNTransform
+
+
# Public API of this module: the model class, the weight enums, and the builder functions.
__all__ = [
    "FasterRCNN",
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_ResNet50_FPN_V2_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]
+
+
def _default_anchorgen():
    """Return the default RPN anchor generator: one size per FPN level, three aspect ratios each."""
    sizes = ((32,), (64,), (128,), (256,), (512,))
    ratios = tuple((0.5, 1.0, 2.0) for _ in sizes)
    return AnchorGenerator(sizes, ratios)
+
+
class FasterRCNN(GeneralizedRCNN):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):

        # --- argument validation ---
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        # num_classes and box_predictor are mutually exclusive: exactly one must be given.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        # --- RPN: anchor generator, head, and proposal machinery ---
        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        # --- second stage: RoI pooling, box head, and predictor ---
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        # --- input normalization/resizing transform (ImageNet statistics by default) ---
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        """Map pooled ROI features to a flat representation via two ReLU-activated FC layers."""
+        # Flatten everything after the batch dimension, e.g. (N, C, H, W) -> (N, C*H*W).
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class FastRCNNConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        # Build conv stack first, then flatten, then the FC stack.
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        # After flattening, the feature count is channels * spatial area
+        # (Conv2dNormActivation preserves H and W here).
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        # Kaiming-initialize the conv weights only; Linear layers keep
+        # PyTorch's default initialization.
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class FastRCNNPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        # 4 box-regression deltas per class.
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        """Return (class scores, per-class box deltas) for flattened ROI features.
+
+        Accepts either an already-flat (N, C) tensor or a 4D (N, C, 1, 1) tensor.
+        """
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+
+# Metadata shared by all Faster R-CNN weight enums defined below.
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    """COCO-trained weights for :func:`fasterrcnn_resnet50_fpn`."""
+
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 41755286,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.0,
+                }
+            },
+            "_ops": 134.38,
+            "_file_size": 159.743,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    """COCO-trained weights for :func:`fasterrcnn_resnet50_fpn_v2` (enhanced recipe)."""
+
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 43712278,
+            "recipe": "https://github.com/pytorch/vision/pull/5763",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 46.7,
+                }
+            },
+            "_ops": 280.371,
+            "_file_size": 167.104,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
+    """COCO-trained weights for :func:`fasterrcnn_mobilenet_v3_large_fpn`."""
+
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 32.8,
+                }
+            },
+            "_ops": 4.494,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
+    """COCO-trained weights for the low-resolution :func:`fasterrcnn_mobilenet_v3_large_320_fpn`."""
+
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 22.8,
+                }
+            },
+            "_ops": 0.719,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
+    Detection with Region Proposal Networks `__
+    paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> # For training
+        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
+        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
+        >>> labels = torch.randint(1, 91, (4, 11))
+        >>> images = list(image for image in images)
+        >>> targets = []
+        >>> for i in range(len(images)):
+        >>>     d = {}
+        >>>     d['boxes'] = boxes[i]
+        >>>     d['labels'] = labels[i]
+        >>>     targets.append(d)
+        >>> output = model(images, targets)
+        >>> # For inference
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            `_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        # Full-detector weights already include the backbone and pin
+        # num_classes to the checkpoint's category count.
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    # Freeze BatchNorm statistics whenever any pretrained weights are used.
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
+            # Set BatchNorm eps to 0.0 — presumably to reproduce the numerics
+            # the legacy COCO_V1 checkpoint was trained with.
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
+    Transfer Learning with Vision Transformers `__ paper.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            `_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        # Full-detector weights already include the backbone and pin
+        # num_classes to the checkpoint's category count.
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    # v2 recipe: unlike the v1 builder, BatchNorm is not frozen, the RPN head
+    # is deepened (conv_depth=2), and the box head is a conv+FC stack with
+    # BatchNorm instead of the plain TwoMLPHead.
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = FasterRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+def _fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
+    progress: bool,
+    num_classes: Optional[int],
+    weights_backbone: Optional[MobileNet_V3_Large_Weights],
+    trainable_backbone_layers: Optional[int],
+    **kwargs: Any,
+) -> FasterRCNN:
+    """Shared builder for the two MobileNetV3-Large FPN Faster R-CNN variants.
+
+    The high-resolution and 320px variants differ only in the transform/RPN
+    defaults supplied by their public wrappers via ``kwargs``.
+    """
+    if weights is not None:
+        # Full-detector weights already include the backbone and pin
+        # num_classes to the checkpoint's category count.
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    # MobileNetV3 exposes 6 freezable stages (vs 5 for ResNet builders).
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
+    # Same 5 anchor sizes on each of the 3 feature maps, 3 aspect ratios each.
+    anchor_sizes = (
+        (
+            32,
+            64,
+            128,
+            256,
+            512,
+        ),
+    ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    model = FasterRCNN(
+        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_320_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            `_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    # Low-resolution/mobile defaults: small input range and trimmed RPN
+    # proposal counts. Caller-supplied kwargs take precedence in the merge.
+    defaults = {
+        "min_size": 320,
+        "max_size": 640,
+        "rpn_pre_nms_top_n_test": 150,
+        "rpn_post_nms_top_n_test": 150,
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            `_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    # High-resolution variant only lowers the RPN score threshold;
+    # caller-supplied kwargs take precedence in the merge.
+    defaults = {
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a86ad2f424c32bd1cf951d474d3ef14bd1bddbb7
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/fcos.py
@@ -0,0 +1,775 @@
+import math
+import warnings
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss
+from ...ops.feature_pyramid_network import LastLevelP6P7
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from . import _utils as det_utils
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+ "FCOS",
+ "FCOS_ResNet50_FPN_Weights",
+ "fcos_resnet50_fpn",
+]
+
+
+class FCOSHead(nn.Module):
+    """
+    A regression and classification head for use in FCOS.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_classes (int): number of classes to be predicted
+        num_convs (Optional[int]): number of conv layer of head. Default: 4.
+    """
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxLinearCoder,
+    }
+
+    def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
+        super().__init__()
+        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
+        self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
+        self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)
+
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        matched_idxs: List[Tensor],
+    ) -> Dict[str, Tensor]:
+        """Compute classification (focal), box-regression (GIoU) and
+        center-ness (BCE) losses, each normalized by the foreground count."""
+
+        cls_logits = head_outputs["cls_logits"]  # [N, HWA, C]
+        bbox_regression = head_outputs["bbox_regression"]  # [N, HWA, 4]
+        bbox_ctrness = head_outputs["bbox_ctrness"]  # [N, HWA, 1]
+
+        # Gather per-anchor GT labels/boxes; a matched index < 0 marks background.
+        all_gt_classes_targets = []
+        all_gt_boxes_targets = []
+        for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
+            if len(targets_per_image["labels"]) == 0:
+                # No GT objects in this image: every anchor gets a zero target.
+                gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
+                gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
+            else:
+                # clip(min=0) makes negative (background) indices safe to gather;
+                # their labels are then overwritten with -1 below.
+                gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
+                gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
+            gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
+            all_gt_classes_targets.append(gt_classes_targets)
+            all_gt_boxes_targets.append(gt_boxes_targets)
+
+        # List[Tensor] to Tensor conversion of `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
+        all_gt_boxes_targets, all_gt_classes_targets, anchors = (
+            torch.stack(all_gt_boxes_targets),
+            torch.stack(all_gt_classes_targets),
+            torch.stack(anchors),
+        )
+
+        # compute foreground mask (NOTE: upstream variable name misspells "foreground")
+        foregroud_mask = all_gt_classes_targets >= 0
+        num_foreground = foregroud_mask.sum().item()
+
+        # classification loss: one-hot targets over foreground anchors only
+        gt_classes_targets = torch.zeros_like(cls_logits)
+        gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
+        loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
+
+        # amp issue: pred_boxes need to convert float
+        pred_boxes = self.box_coder.decode(bbox_regression, anchors)
+
+        # regression loss: GIoU loss
+        loss_bbox_reg = generalized_box_iou_loss(
+            pred_boxes[foregroud_mask],
+            all_gt_boxes_targets[foregroud_mask],
+            reduction="sum",
+        )
+
+        # ctrness loss
+
+        bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)
+
+        if len(bbox_reg_targets) == 0:
+            gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
+        else:
+            # Center-ness target: sqrt((min_lr/max_lr) * (min_tb/max_tb)) per
+            # the FCOS paper, from the left/right and top/bottom distances.
+            left_right = bbox_reg_targets[:, :, [0, 2]]
+            top_bottom = bbox_reg_targets[:, :, [1, 3]]
+            gt_ctrness_targets = torch.sqrt(
+                (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
+                * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+            )
+        pred_centerness = bbox_ctrness.squeeze(dim=2)
+        loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
+            pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
+        )
+
+        # max(1, ...) guards against division by zero when there is no foreground.
+        return {
+            "classification": loss_cls / max(1, num_foreground),
+            "bbox_regression": loss_bbox_reg / max(1, num_foreground),
+            "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
+        }
+
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
+        """Run both sub-heads over the feature maps and collect their outputs."""
+        cls_logits = self.classification_head(x)
+        bbox_regression, bbox_ctrness = self.regression_head(x)
+        return {
+            "cls_logits": cls_logits,
+            "bbox_regression": bbox_regression,
+            "bbox_ctrness": bbox_ctrness,
+        }
+
+
+class FCOSClassificationHead(nn.Module):
+    """
+    A classification head for use in FCOS.
+
+    Args:
+        in_channels (int): number of channels of the input feature.
+        num_anchors (int): number of anchors to be predicted.
+        num_classes (int): number of classes to be predicted.
+        num_convs (Optional[int]): number of conv layer. Default: 4.
+        prior_probability (Optional[float]): probability of prior. Default: 0.01.
+        norm_layer: Module specifying the normalization layer to use.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        num_anchors: int,
+        num_classes: int,
+        num_convs: int = 4,
+        prior_probability: float = 0.01,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+
+        # Default normalization: GroupNorm with 32 groups.
+        if norm_layer is None:
+            norm_layer = partial(nn.GroupNorm, 32)
+
+        conv = []
+        for _ in range(num_convs):
+            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
+            conv.append(norm_layer(in_channels))
+            conv.append(nn.ReLU())
+        self.conv = nn.Sequential(*conv)
+
+        for layer in self.conv.children():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)
+                torch.nn.init.constant_(layer.bias, 0)
+
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
+        # Bias prior so initial sigmoid outputs equal `prior_probability`
+        # (focal-loss initialization).
+        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
+
+    def forward(self, x: List[Tensor]) -> Tensor:
+        """Return class logits for every feature level, concatenated to (N, sum(HWA), K)."""
+        all_cls_logits = []
+
+        for features in x:
+            cls_logits = self.conv(features)
+            cls_logits = self.cls_logits(cls_logits)
+
+            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = cls_logits.shape
+            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
+            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
+            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)
+
+            all_cls_logits.append(cls_logits)
+
+        return torch.cat(all_cls_logits, dim=1)
+
+
+class FCOSRegressionHead(nn.Module):
+    """
+    A regression head for use in FCOS, which combines regression branch and center-ness branch.
+    This can obtain better performance.
+
+    Reference: `FCOS: A simple and strong anchor-free object detector `_.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_convs (Optional[int]): number of conv layer. Default: 4.
+        norm_layer: Module specifying the normalization layer to use.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        num_anchors: int,
+        num_convs: int = 4,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        super().__init__()
+
+        # Default normalization: GroupNorm with 32 groups.
+        if norm_layer is None:
+            norm_layer = partial(nn.GroupNorm, 32)
+
+        # Shared conv trunk feeding both the box and center-ness branches.
+        conv = []
+        for _ in range(num_convs):
+            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
+            conv.append(norm_layer(in_channels))
+            conv.append(nn.ReLU())
+        self.conv = nn.Sequential(*conv)
+
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
+        self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
+        for layer in [self.bbox_reg, self.bbox_ctrness]:
+            torch.nn.init.normal_(layer.weight, std=0.01)
+            torch.nn.init.zeros_(layer.bias)
+
+        for layer in self.conv.children():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)
+                torch.nn.init.zeros_(layer.bias)
+
+    def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
+        """Return (box regression (N, sum(HWA), 4), center-ness (N, sum(HWA), 1)) over all levels."""
+        all_bbox_regression = []
+        all_bbox_ctrness = []
+
+        for features in x:
+            bbox_feature = self.conv(features)
+            # ReLU keeps the predicted (l, t, r, b) distances non-negative.
+            bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
+            bbox_ctrness = self.bbox_ctrness(bbox_feature)
+
+            # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
+            N, _, H, W = bbox_regression.shape
+            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
+            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
+            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+            all_bbox_regression.append(bbox_regression)
+
+            # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
+            bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
+            bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
+            bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
+            all_bbox_ctrness.append(bbox_ctrness)
+
+        return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)
+
+
class FCOS(nn.Module):
    """
    Implements FCOS.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification, regression
    and centerness losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps. For FCOS, set only one anchor per position of each level, the width and height equal to
            the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point
            in FCOS paper.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        center_sampling_radius (int): radius of the "center" of a groundtruth box,
            within which all anchor points are labeled positive.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import FCOS
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FCOS needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((8,), (16,), (32,), (64,), (128,)),
        >>>     aspect_ratios=((1.0,),)
        >>> )
        >>>
        >>> # put the pieces together inside a FCOS model
        >>> model = FCOS(
        >>>     backbone,
        >>>     num_classes=80,
        >>>     anchor_generator=anchor_generator,
        >>> )
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        # transform parameters
        min_size: int = 800,
        max_size: int = 1333,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Anchor parameters
        anchor_generator: Optional[AnchorGenerator] = None,
        head: Optional[nn.Module] = None,
        center_sampling_radius: float = 1.5,
        score_thresh: float = 0.2,
        nms_thresh: float = 0.6,
        detections_per_img: int = 100,
        topk_candidates: int = 1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator
        # FCOS is anchor-free: exactly one "anchor point" per spatial location.
        if self.anchor_generator.num_anchors_per_location()[0] != 1:
            raise ValueError(
                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
            )

        if head is None:
            head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)

        # Default normalization constants (ImageNet statistics).
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.center_sampling_radius = center_sampling_radius
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        # In eager mode, return only the losses while training and only the
        # detections during inference (scripted mode returns both; see forward).
        if self.training:
            return losses

        return detections

    def compute_loss(
        self,
        targets: List[Dict[str, Tensor]],
        head_outputs: Dict[str, Tensor],
        anchors: List[Tensor],
        num_anchors_per_level: List[int],
    ) -> Dict[str, Tensor]:
        """
        Match anchor points to ground-truth boxes and delegate loss computation to the head.

        Args:
            targets: per-image ground truth with "boxes" and "labels".
            head_outputs: raw head predictions for the whole batch.
            anchors: per-image anchor boxes.
            num_anchors_per_level: anchor count for each pyramid level, used to
                bound the per-level regression ranges.

        Returns:
            Dict[str, Tensor]: losses as produced by ``self.head.compute_loss``.
        """
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            # No ground truth in this image: every anchor is background (-1).
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            gt_boxes = targets_per_image["boxes"]
            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # Mx2
            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # Nx2
            # anchor width == the level's stride (anchors are square), used as the sampling scale
            anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
            # center sampling: anchor point must be close enough to gt center.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]
            # compute pairwise distance between N points and M boxes
            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M, 4)

            # anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # each anchor is only responsible for certain scale range.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
            pairwise_dist = pairwise_dist.max(dim=2).values
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])

            # match the GT box with minimum area, if there are multiple GT matches
            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # M
            # Encode each candidate as (1e8 - area) so that taking the max over
            # GTs selects the smallest-area match; non-matches become 0.
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
            matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1

            matched_idxs.append(matched_idx)

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(
        self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
    ) -> List[Dict[str, Tensor]]:
        """
        Turn per-level head outputs into final per-image detections
        (score filtering, top-k selection, box decoding, clipping and NMS).
        """
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]
        box_ctrness = head_outputs["bbox_ctrness"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            box_ctrness_per_image = [bc[index] for bc in box_ctrness]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                # score = sqrt(classification_prob * centerness_prob), the
                # geometric mean of the two sigmoid outputs
                scores_per_level = torch.sqrt(
                    torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
                ).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # flattened index -> (anchor index, class label)
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        if self.training:

            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )

        # Record original sizes so detections can be mapped back after resizing.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features = list(features.values())

        # compute the fcos heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)
        # recover level sizes
        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
        else:
            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        # TorchScript cannot return different types per mode, so a scripted
        # model always returns the (losses, detections) tuple.
        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
+
+
class FCOS_ResNet50_FPN_Weights(WeightsEnum):
    # Checkpoint trained on COCO train2017; metadata reports 39.2 box mAP
    # on COCO val2017 (see "_metrics" below).
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 32269600,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 39.2,
                }
            },
            "_ops": 128.207,
            "_file_size": 123.608,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    # DEFAULT is what callers get with weights="DEFAULT".
    DEFAULT = COCO_V1
+
+
@register_model()
@handle_legacy_interface(
    weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcos_resnet50_fpn(
    *,
    weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FCOS:
    """
    Constructs a FCOS model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `FCOS: Fully Convolutional One-Stage Object Detection <https://arxiv.org/abs/1904.01355>`_.
    `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example:

        >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
            from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
        **kwargs: parameters passed to the ``torchvision.models.detection.FCOS``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/fcos.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights
    :members:
    """
    weights = FCOS_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    # Full-model weights take precedence: they fix num_classes and make
    # separate backbone weights redundant.
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Frozen BN for pretrained weights so running stats are not perturbed.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = FCOS(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b481265077fb5a582402d81aeb3516ffca063653
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/generalized_rcnn.py
@@ -0,0 +1,118 @@
+"""
+Implements the Generalized R-CNN framework
+"""
+
+import warnings
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+
+from ...utils import _log_api_usage_once
+
+
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module):
        rpn (nn.Module):
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        # In eager mode, return only the losses while training and only the
        # detections during inference (scripted mode returns both; see forward).
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Record original sizes so detections can be mapped back after resizing.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # backbone -> RPN proposals -> RoI heads -> map boxes back to original scale
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Wrap a single feature map so downstream modules always see a dict.
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        # TorchScript cannot return different types per mode, so a scripted
        # model always returns the (losses, detections) tuple.
        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..583866557e4c9ec178e7cc268272db3de1698e41
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/image_list.py
@@ -0,0 +1,25 @@
+from typing import List, Tuple
+
+import torch
+from torch import Tensor
+
+
class ImageList:
    """
    Batches images of possibly varying sizes as a single padded tensor.

    The padded batch lives in ``tensors`` while ``image_sizes`` keeps the
    original size of every image, so padding can later be undone.

    Args:
        tensors (tensor): Tensor containing images.
        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
        self.image_sizes = image_sizes
        self.tensors = tensors

    def to(self, device: torch.device) -> "ImageList":
        # Only the batched tensor needs moving; the sizes are plain Python ints.
        return ImageList(self.tensors.to(device), self.image_sizes)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d7ff0ea433a681064a11a22c3e276e253997772
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/keypoint_rcnn.py
@@ -0,0 +1,474 @@
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .faster_rcnn import FasterRCNN
+
+
+__all__ = [
+ "KeypointRCNN",
+ "KeypointRCNN_ResNet50_FPN_Weights",
+ "keypointrcnn_resnet50_fpn",
+]
+
+
class KeypointRCNN(FasterRCNN):
    """
    Implements Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the keypoint head.
        keypoint_head (nn.Module): module that takes the cropped feature maps as input
        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
            heatmap logits

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=None,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # keypoint parameters
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        num_keypoints=None,
        **kwargs,
    ):

        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            # BUGFIX: message was missing the f-prefix, so the {type(...)}
            # placeholder was emitted literally (cf. the same check in MaskRCNN).
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            # Multi-scale training defaults used by the reference recipe.
            min_size = (640, 672, 704, 736, 768, 800)

        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            # COCO person keypoints.
            num_keypoints = 17

        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        # Attach the keypoint branch to the shared RoI heads.
        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor
+
+
class KeypointRCNNHeads(nn.Sequential):
    """Keypoint head: a stack of 3x3 conv + ReLU pairs at constant resolution."""

    def __init__(self, in_channels, layers):
        modules = []
        prev_channels = in_channels
        for num_channels in layers:
            modules.append(nn.Conv2d(prev_channels, num_channels, 3, stride=1, padding=1))
            modules.append(nn.ReLU(inplace=True))
            prev_channels = num_channels
        super().__init__(*modules)
        # Kaiming initialization for every conv; biases start at zero.
        for module in self.children():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(module.bias, 0)
+
+
class KeypointRCNNPredictor(nn.Module):
    """Keypoint predictor: one deconv doubling resolution, then fixed 2x bilinear upsampling."""

    def __init__(self, in_channels, num_keypoints):
        super().__init__()
        deconv_kernel = 4
        # stride 2 with padding k/2 - 1 exactly doubles the spatial size.
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        lowres = self.kps_score_lowres(x)
        # A second 2x enlargement is done with non-learned bilinear interpolation.
        return torch.nn.functional.interpolate(
            lowres,
            scale_factor=float(self.up_scale),
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
+
+
# Metadata shared by all KeypointRCNN weight entries below.
_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    # Smallest (height, width) the model accepts.
    "min_size": (1, 1),
}
+
+
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
    """Pretrained weight entries (URL + preprocessing + metadata) for keypointrcnn_resnet50_fpn."""

    # Early-epoch checkpoint kept for backward compatibility with pretrained="legacy".
    COCO_LEGACY = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/issues/1606",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 50.6,
                    "kp_map": 61.1,
                }
            },
            "_ops": 133.924,
            "_file_size": 226.054,
            "_docs": """
                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
                from an early epoch.
            """,
        },
    )
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 59137258,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 54.6,
                    "kp_map": 65.0,
                }
            },
            "_ops": 137.42,
            "_file_size": 226.054,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
+
+
@register_model()
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
        if kwargs["pretrained"] == "legacy"
        else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
    ),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
    *,
    weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    num_keypoints: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> KeypointRCNN:
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each instance
        - scores (``Tensor[N]``): the scores for each instance
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        num_keypoints (int, optional): number of keypoints
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.

    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights subsume the backbone weights; class/keypoint
        # counts must match the checkpoint's metadata.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
    else:
        if num_classes is None:
            num_classes = 2
        if num_keypoints is None:
            num_keypoints = 17

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Frozen BN when starting from pretrained weights; regular BN when training from scratch.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            # The COCO_V1 checkpoint was trained with a different FrozenBatchNorm eps.
            overwrite_eps(model, 0.0)

    return model
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdabbfd26ca8bbefaefdb6fb8b098afac217b595
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/mask_rcnn.py
@@ -0,0 +1,590 @@
+from collections import OrderedDict
+from typing import Any, Callable, Optional
+
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
+
+
+__all__ = [
+ "MaskRCNN",
+ "MaskRCNN_ResNet50_FPN_Weights",
+ "MaskRCNN_ResNet50_FPN_V2_Weights",
+ "maskrcnn_resnet50_fpn",
+ "maskrcnn_resnet50_fpn_v2",
+]
+
+
class MaskRCNN(FasterRCNN):
    """
    Implements Mask R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the mask loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction
        - masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
          obtain the final segmentation masks, the soft masks can be thresholded, generally
          with a value of 0.5 (mask >= 0.5)

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the mask head.
        mask_head (nn.Module): module that takes the cropped feature maps as input
        mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
            segmentation mask logits

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import MaskRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # MaskRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here,
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                      output_size=14,
        >>>                                                      sampling_ratio=2)
        >>> # put the pieces together inside a MaskRCNN model
        >>> model = MaskRCNN(backbone,
        >>>                  num_classes=2,
        >>>                  rpn_anchor_generator=anchor_generator,
        >>>                  box_roi_pool=roi_pooler,
        >>>                  mask_roi_pool=mask_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # Mask parameters
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        **kwargs,
    ):

        if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
            )

        if num_classes is not None:
            if mask_predictor is not None:
                raise ValueError("num_classes should be None when mask_predictor is specified")

        out_channels = backbone.out_channels

        # Build the default mask branch for any component the caller did not supply.
        if mask_roi_pool is None:
            mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if mask_head is None:
            mask_layers = (256, 256, 256, 256)
            mask_dilation = 1
            mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)

        if mask_predictor is None:
            mask_predictor_in_channels = 256  # == mask_layers[-1]
            mask_dim_reduced = 256
            mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)

        # Everything below is forwarded positionally to FasterRCNN.__init__;
        # the argument order must match its signature exactly.
        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        # Attach the mask branch to the shared RoI heads.
        self.roi_heads.mask_roi_pool = mask_roi_pool
        self.roi_heads.mask_head = mask_head
        self.roi_heads.mask_predictor = mask_predictor
+
+
class MaskRCNNHeads(nn.Sequential):
    """FCN mask head: a stack of Conv2dNormActivation blocks at constant resolution.

    ``_version`` is 2 because the submodules were renamed from ``mask_fcn{i}``
    to ``{i}.0``; ``_load_from_state_dict`` migrates pre-v2 checkpoints.
    """

    _version = 2

    def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None):
        """
        Args:
            in_channels (int): number of input channels
            layers (list): feature dimensions of each FCN layer
            dilation (int): dilation rate of kernel
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        blocks = []
        next_feature = in_channels
        for layer_features in layers:
            # padding == dilation keeps the spatial size unchanged for a 3x3 kernel.
            blocks.append(
                misc_nn_ops.Conv2dNormActivation(
                    next_feature,
                    layer_features,
                    kernel_size=3,
                    stride=1,
                    padding=dilation,
                    dilation=dilation,
                    norm_layer=norm_layer,
                )
            )
            next_feature = layer_features

        super().__init__(*blocks)
        # Kaiming init for every conv; bias may be absent when a norm layer is used.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Remap pre-v2 parameter names ("mask_fcn{i}.weight/bias") to the
            # current "{i}.0.*" naming before delegating to the normal loader.
            num_blocks = len(self)
            for i in range(num_blocks):
                # Renamed from `type` to avoid shadowing the builtin.
                for param_name in ["weight", "bias"]:
                    old_key = f"{prefix}mask_fcn{i+1}.{param_name}"
                    new_key = f"{prefix}{i}.0.{param_name}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
+
+
class MaskRCNNPredictor(nn.Sequential):
    """Predict per-class mask logits from mask features.

    Upsamples the pooled mask features 2x with a stride-2 transposed convolution,
    applies a ReLU, then projects to one logit map per class with a 1x1 conv.
    """

    def __init__(self, in_channels, dim_reduced, num_classes):
        layers = OrderedDict()
        layers["conv5_mask"] = nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)
        layers["relu"] = nn.ReLU(inplace=True)
        layers["mask_fcn_logits"] = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)
        super().__init__(layers)

        for name, param in self.named_parameters():
            if "weight" not in name:
                continue  # biases keep their default initialization
            nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+
+
# Metadata shared by every Mask R-CNN weight entry defined below.
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}
+
+
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
    """COCO-trained weights for :func:`maskrcnn_resnet50_fpn`."""

    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 44401393,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.9,
                    "mask_map": 34.6,
                }
            },
            "_ops": 134.38,
            "_file_size": 169.84,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
+
+
class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
    """COCO-trained weights for :func:`maskrcnn_resnet50_fpn_v2` (improved recipe)."""

    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 46359409,
            "recipe": "https://github.com/pytorch/vision/pull/5773",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 47.4,
                    "mask_map": 41.8,
                }
            },
            "_ops": 333.577,
            "_file_size": 177.219,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1
+
+
@register_model()
@handle_legacy_interface(
    weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn(
    *,
    weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> MaskRCNN:
    """Mask R-CNN model with a ResNet-50-FPN backbone from the `Mask R-CNN
    <https://arxiv.org/abs/1703.06870>`_ paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

    - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
      ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
    - labels (``Int64Tensor[N]``): the class label for each ground-truth box
    - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the mask loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

    - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
      ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
    - labels (``Int64Tensor[N]``): the predicted labels for each instance
    - scores (``Tensor[N]``): the scores for each instance
    - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
      obtain the final segmentation masks, the soft masks can be thresholded, generally
      with a value of 0.5 (``mask >= 0.5``)

    For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.

    Mask R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights already contain backbone weights and fix the class count.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91  # COCO's 90 categories + background

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Frozen BN keeps pretrained statistics intact when fine-tuning from trained weights.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
            # The COCO_V1 checkpoint was trained with a different FrozenBatchNorm eps.
            overwrite_eps(model, 0.0)

    return model
+
+
@register_model()
@handle_legacy_interface(
    weights=("pretrained", MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> MaskRCNN:
    """Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer
    Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`_ paper.

    .. betastatus:: detection module

    :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details.

    Args:
        weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights make separate backbone weights redundant and fix the class count.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91  # COCO's 90 categories + background

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    # The v2 recipe trains with regular (non-frozen) BatchNorm throughout.
    backbone = _resnet_fpn_extractor(
        resnet50(weights=weights_backbone, progress=progress),
        trainable_backbone_layers,
        norm_layer=nn.BatchNorm2d,
    )
    anchor_generator = _default_anchorgen()
    model = MaskRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        rpn_head=RPNHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], conv_depth=2),
        box_head=FastRCNNConvFCHead(
            (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
        ),
        mask_head=MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d),
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8cc7755014b6010965108a46c080f71b2d609db
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/retinanet.py
@@ -0,0 +1,903 @@
+import math
+import warnings
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
+from ...ops.feature_pyramid_network import LastLevelP6P7
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from . import _utils as det_utils
+from ._utils import _box_loss, overwrite_eps
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
+
+
# Public API of this module.
__all__ = [
    "RetinaNet",
    "RetinaNet_ResNet50_FPN_Weights",
    "RetinaNet_ResNet50_FPN_V2_Weights",
    "retinanet_resnet50_fpn",
    "retinanet_resnet50_fpn_v2",
]
+
+
def _sum(x: List[Tensor]) -> Tensor:
    """Return the elementwise sum of a non-empty list of tensors."""
    # NOTE(review): manual accumulation instead of the ``sum`` builtin — presumably
    # required for TorchScript compatibility; confirm before restyling.
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res
+
+
+def _v1_to_v2_weights(state_dict, prefix):
+ for i in range(4):
+ for type in ["weight", "bias"]:
+ old_key = f"{prefix}conv.{2*i}.{type}"
+ new_key = f"{prefix}conv.{i}.0.{type}"
+ if old_key in state_dict:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
def _default_anchorgen():
    """Build the standard RetinaNet anchor generator: five pyramid levels, each with
    three octave scales (2^0, 2^(1/3), 2^(2/3)) and three aspect ratios (0.5, 1.0, 2.0)."""
    octave_scales = (2 ** (1.0 / 3), 2 ** (2.0 / 3))
    anchor_sizes = []
    for base in (32, 64, 128, 256, 512):
        anchor_sizes.append((base, int(base * octave_scales[0]), int(base * octave_scales[1])))
    anchor_sizes = tuple(anchor_sizes)
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        self.classification_head = RetinaNetClassificationHead(
            in_channels, num_anchors, num_classes, norm_layer=norm_layer
        )
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        """Delegate loss computation to the two sub-heads and collect the results."""
        cls_loss = self.classification_head.compute_loss(targets, head_outputs, matched_idxs)
        reg_loss = self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs)
        return {"classification": cls_loss, "bbox_regression": reg_loss}

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        """Run both sub-heads over the feature maps and collect their outputs."""
        outputs = {"cls_logits": self.classification_head(x)}
        outputs["bbox_regression"] = self.regression_head(x)
        return outputs
+
+
class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    # Bumped when the conv stack layout changed from interleaved conv/ReLU
    # ("conv.{2i}") to wrapped blocks ("conv.{i}.0"); see _load_from_state_dict.
    _version = 2

    def __init__(
        self,
        in_channels,
        num_anchors,
        num_classes,
        prior_probability=0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        # Focal-loss prior initialization: bias set so every anchor starts out
        # predicting foreground with probability ``prior_probability``.
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
        # TorchScript doesn't support class attributes.
        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        # Pre-v2 checkpoints use the interleaved conv/ReLU key layout; migrate in place.
        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        """Sigmoid focal loss over all valid anchors, normalized per image by the
        number of foreground anchors, then averaged over the batch."""
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            # one-hot target: 1.0 at the matched GT label of each foreground anchor
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            # (anchors whose IoU falls between the bg and fg thresholds)
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / len(targets)

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        """Apply the head to each feature level and concatenate to (N, HWA, K)."""
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)
+
+
class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    # Bumped when the conv stack layout changed from interleaved conv/ReLU
    # ("conv.{2i}") to wrapped blocks ("conv.{i}.0"); see _load_from_state_dict.
    _version = 2

    # Declared for TorchScript, which needs attribute types known up front.
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        # 4 regression deltas (dx, dy, dw, dh) per anchor.
        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self._loss_type = "l1"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        # Pre-v2 checkpoints use the interleaved conv/ReLU key layout; migrate in place.
        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        """Box regression loss over foreground anchors only, normalized per image by the
        number of foreground anchors, then averaged over the batch."""
        losses = []

        bbox_regression = head_outputs["bbox_regression"]

        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
            targets, bbox_regression, anchors, matched_idxs
        ):
            # determine only the foreground indices, ignore the rest
            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            num_foreground = foreground_idxs_per_image.numel()

            # select only the foreground boxes
            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

            # compute the loss
            losses.append(
                _box_loss(
                    self._loss_type,
                    self.box_coder,
                    anchors_per_image,
                    matched_gt_boxes_per_image,
                    bbox_regression_per_image,
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        """Apply the head to each feature level and concatenate to (N, HWA, 4)."""
        all_bbox_regression = []

        for features in x:
            bbox_regression = self.conv(features)
            bbox_regression = self.bbox_reg(bbox_regression)

            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            N, _, H, W = bbox_regression.shape
            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

            all_bbox_regression.append(bbox_regression)

        return torch.cat(all_bbox_regression, dim=1)
+
+
class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge beeing lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    # Declared for TorchScript, which needs attribute types known up front.
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        # Fall back to the standard RetinaNet anchors / head when not provided.
        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        # allow_low_quality_matches keeps at least one anchor per GT box even if
        # its IoU is below fg_iou_thresh.
        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # ImageNet normalization defaults.
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """Return losses in training mode and detections in eval mode (eager only)."""
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        """Match anchors to ground-truth boxes per image, then delegate to the head's loss."""
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            # No GT boxes: every anchor is background (-1).
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        """Decode per-level head outputs into final boxes/scores/labels per image:
        score filtering, per-level top-k, box decoding and clipping, then batched NMS."""
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # indices are over the flattened (anchor, class) grid; split them back out
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes
        # (needed to map detections back after the transform resizes the inputs)
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            # (head outputs are concatenated across levels; recompute the per-level split)
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)
+
+
# Metadata shared by every RetinaNet weight enum below; each enum entry merges
# this dict with its own model-specific fields (num_params, metrics, ...).
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}
+
+
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    """Pretrained weights for :func:`retinanet_resnet50_fpn` (COCO-trained, 36.4 box mAP)."""

    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 36.4,
                }
            },
            "_ops": 151.54,
            "_file_size": 130.267,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
+
+
class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    """Pretrained weights for :func:`retinanet_resnet50_fpn_v2` (COCO-trained, 41.5 box mAP)."""

    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 41.5,
                }
            },
            "_ops": 152.238,
            "_file_size": 146.037,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1
+
+
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Focal Loss for Dense Object Detection `_.

    The model consumes a list of ``[C, H, W]`` tensors with values in the ``0-1`` range;
    the images may have different sizes.

    In training mode it expects targets alongside the input tensors — one dictionary
    per image containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    and it returns a ``Dict[Tensor]`` holding the classification and regression losses.

    In inference mode only the input tensors are needed; the model returns the
    post-processed predictions as a ``List[Dict[Tensor]]``, one entry per image, where
    each dict carries ``boxes`` (``FloatTensor[N, 4]`` in ``[x1, y1, x2, y2]`` format),
    ``labels`` (``Int64Tensor[N]``) and ``scores`` (``Tensor[N]``) for its ``N`` detections.

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional):
            pretrained weights to load. By default no pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): pretrained
            weights for the backbone.
        trainable_backbone_layers (int, optional): number of backbone layers, counted from the
            final block, left unfrozen. Must be in ``[0, 5]``; ``None`` (the default) selects 3.
        **kwargs: forwarded to the ``torchvision.models.detection.RetinaNet`` base class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights fix the class count and make separate backbone weights redundant.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91  # COCO default (80 classes + background + unused ids)

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Freeze batch-norm statistics whenever any pretrained weights are in play.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    # skip P2 because it generates too many anchors (according to their paper)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = RetinaNet(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
+
+
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn_v2(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
    `_.

    See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for details on the
    input/output contract, which is identical for this variant.

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional):
            pretrained weights to load. By default no pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): pretrained
            weights for the backbone.
        trainable_backbone_layers (int, optional): number of backbone layers, counted from the
            final block, left unfrozen. Must be in ``[0, 5]``; ``None`` (the default) selects 3.
        **kwargs: forwarded to the ``torchvision.models.detection.RetinaNet`` base class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights determine the class count and supersede backbone weights.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91  # COCO default (80 classes + background + unused ids)

    pretrained_any = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(pretrained_any, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
    )
    anchor_generator = _default_anchorgen()
    # v2 head: GroupNorm-equipped head with GIoU regression loss.
    head = RetinaNetHead(
        backbone.out_channels,
        anchor_generator.num_anchors_per_location()[0],
        num_classes,
        norm_layer=partial(nn.GroupNorm, 32),
    )
    head.regression_head._loss_type = "giou"
    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/roi_heads.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/roi_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b210cb6f368c1f4914ffe99287efef6057cba4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/roi_heads.py
@@ -0,0 +1,876 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Compute the Faster R-CNN box-head losses.

    Args:
        class_logits (Tensor): per-proposal class scores, shape [R, num_classes]
        box_regression (Tensor): per-proposal, per-class box deltas, shape [R, num_classes * 4]
        labels (list[Tensor]): per-image ground-truth class for each proposal
        regression_targets (list[Tensor]): per-image encoded box targets

    Returns:
        classification_loss (Tensor): cross-entropy over all proposals
        box_loss (Tensor): smooth-L1 over positive proposals, normalized by the
            total number of proposals
    """
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # Regression is only supervised on foreground proposals, and only for the
    # box deltas of each proposal's ground-truth class.
    pos_inds = torch.where(labels > 0)[0]
    pos_labels = labels[pos_inds]
    num_proposals = class_logits.shape[0]
    box_regression = box_regression.reshape(num_proposals, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[pos_inds, pos_labels],
        regression_targets[pos_inds],
        beta=1 / 9,
        reduction="sum",
    )
    # Normalize by the full proposal count (not just positives), as in the paper.
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss
+
+
def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    Post-process mask logits: keep, for every detection, only the mask channel
    of its predicted class, as a probability map.

    Args:
        x (Tensor): mask logits, shape [num_detections, num_classes, M, M]
        labels (list[Tensor]): predicted class per detection, one tensor per image

    Returns:
        per-image tensors of shape [num_boxes, 1, M, M] with sigmoid probabilities
    """
    probs = x.sigmoid()

    boxes_per_image = [lbl.shape[0] for lbl in labels]
    flat_labels = torch.cat(labels)
    # Pick the channel matching each detection's class, keeping a channel dim.
    row_index = torch.arange(x.shape[0], device=flat_labels.device)
    selected = probs[row_index, flat_labels][:, None]

    # Regroup by image.
    return selected.split(boxes_per_image, dim=0)
+
+
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Crop and resize ground-truth masks to M x M patches at the given box
    locations, producing the targets for the mask loss.
    """
    # roi_align expects rois as [batch_index, x1, y1, x2, y2]; here the "batch
    # index" selects which ground-truth mask each box crops from.
    batch_idx = matched_idxs.to(boxes)
    rois = torch.cat([batch_idx[:, None], boxes], dim=1)
    masks = gt_masks[:, None].to(rois)
    return roi_align(masks, rois, (M, M), 1.0)[:, 0]
+
+
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Binary cross-entropy mask loss over the mask channel of each proposal's
    ground-truth class.

    Args:
        mask_logits (Tensor): predicted masks, shape [R, num_classes, M, M]
        proposals (list[Tensor]): per-image proposal boxes
        gt_masks (list[Tensor]): per-image ground-truth masks
        gt_labels (list[Tensor]): per-image ground-truth classes
        mask_matched_idxs (list[Tensor]): per-image matched GT index per proposal

    Returns:
        mask_loss (Tensor): scalar tensor containing the loss
    """
    M = mask_logits.shape[-1]
    labels = torch.cat([gt_lbl[idxs] for gt_lbl, idxs in zip(gt_labels, mask_matched_idxs)], dim=0)
    mask_targets = torch.cat(
        [project_masks_on_boxes(mask, prop, idxs, M) for mask, prop, idxs in zip(gt_masks, proposals, mask_matched_idxs)],
        dim=0,
    )

    # binary_cross_entropy_with_logits cannot reduce an empty batch, so return a
    # zero that still participates in the autograd graph.
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    row_index = torch.arange(labels.shape[0], device=labels.device)
    return F.binary_cross_entropy_with_logits(mask_logits[row_index, labels], mask_targets)
+
+
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    """
    Discretize keypoint coordinates into per-ROI heatmap bin indices.

    Args:
        keypoints (Tensor): [num_rois, K, 3] as (x, y, visibility)
        rois (Tensor): [num_rois, 4] boxes in (x1, y1, x2, y2)
        heatmap_size (int): side length of the square heatmap

    Returns:
        heatmaps (Tensor): [num_rois, K] flattened bin index (y * size + x), zeroed where invalid
        valid (Tensor): [num_rois, K] 1 where the keypoint is visible and inside the heatmap
    """
    x0 = rois[:, 0][:, None]
    y0 = rois[:, 1][:, None]
    sx = (heatmap_size / (rois[:, 2] - rois[:, 0]))[:, None]
    sy = (heatmap_size / (rois[:, 3] - rois[:, 1]))[:, None]

    kx = keypoints[..., 0]
    ky = keypoints[..., 1]

    # Keypoints sitting exactly on the right/bottom ROI edge belong to the last bin.
    on_right_edge = kx == rois[:, 2][:, None]
    on_bottom_edge = ky == rois[:, 3][:, None]

    kx = ((kx - x0) * sx).floor().long()
    ky = ((ky - y0) * sy).floor().long()
    kx[on_right_edge] = heatmap_size - 1
    ky[on_bottom_edge] = heatmap_size - 1

    in_bounds = (kx >= 0) & (ky >= 0) & (kx < heatmap_size) & (ky < heatmap_size)
    visible = keypoints[..., 2] > 0
    valid = (in_bounds & visible).long()

    heatmaps = (ky * heatmap_size + kx) * valid
    return heatmaps, valid
+
+
def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    # ONNX-traceable single-ROI counterpart of heatmaps_to_keypoints: converts one
    # ROI's keypoint heatmaps into (x, y, score) predictions in image coordinates.
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    # Factors mapping resized-heatmap bins back to original image pixels.
    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    # Upsample this ROI's heatmaps to the (rounded-up) box size.
    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    # argmax over the flattened spatial dims gives the peak bin for each keypoint
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    # +0.5 converts a discrete bin index to a continuous coordinate.
    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    # Rows: x, y, and a constant visibility of 1.
    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    # The chained index_selects build a K x K x K tensor; `ind` then picks its
    # diagonal entries [k, k, k], i.e. the score of keypoint k at its own peak.
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i
+
+
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    # ONNX-friendly loop over ROIs: accumulates per-ROI keypoint predictions by
    # concatenation instead of in-place indexing (which tracing cannot capture).
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for idx in range(int(rois.size(0))):
        xy_i, scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[idx], widths_ceil[idx], heights_ceil[idx], widths[idx], heights[idx], offset_x[idx], offset_y[idx]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores
+
+
def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps. Output has shape
    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
    for each keypoint.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # Clamp to >= 1 so empty/degenerate boxes still produce a valid resize target.
    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    # Under ONNX tracing, delegate to the trace-compatible concat-based loop.
    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        # Upsample each ROI's heatmaps to the box size before taking the argmax.
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        # Peak bin per keypoint over the flattened spatial dims.
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        # Bin index -> continuous coordinate (+0.5), then back to image scale.
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores
+
+
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Cross-entropy loss over per-keypoint heatmap bins: each visible keypoint's
    target is the flattened bin index of its ground-truth location.
    """
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )

    heatmaps = []
    valid = []
    for proposals_in_image, gt_kp_in_image, matched in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        hm, v = keypoints_to_heatmap(gt_kp_in_image[matched], proposals_in_image, H)
        heatmaps.append(hm.view(-1))
        valid.append(v.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.where(torch.cat(valid, dim=0).to(dtype=torch.uint8))[0]

    # cross_entropy cannot reduce an empty selection; return a zero that still
    # participates in the autograd graph.
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    flattened_logits = keypoint_logits.view(N * K, H * W)
    return F.cross_entropy(flattened_logits[valid], keypoint_targets[valid])
+
+
def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """
    Convert keypoint heatmaps into per-image keypoint coordinates and scores.
    """
    boxes_per_image = [b.size(0) for b in boxes]

    kp_probs = []
    kp_scores = []
    # Regroup the flat heatmap batch by image, then decode each image's heatmaps.
    for heatmaps, image_boxes in zip(x.split(boxes_per_image, dim=0), boxes):
        coords, scores = heatmaps_to_keypoints(heatmaps, image_boxes)
        kp_probs.append(coords)
        kp_scores.append(scores)

    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+ # type: (Tensor, float) -> Tensor
+ w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+ h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+ x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+ y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+ w_half = w_half.to(dtype=torch.float32) * scale
+ h_half = h_half.to(dtype=torch.float32) * scale
+
+ boxes_exp0 = x_c - w_half
+ boxes_exp1 = y_c - h_half
+ boxes_exp2 = x_c + w_half
+ boxes_exp3 = y_c + h_half
+ boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+ return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    """Scale each box about its own center by `scale`, keeping the [x1, y1, x2, y2] layout."""
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)

    half_w = (boxes[:, 2] - boxes[:, 0]) * 0.5 * scale
    half_h = (boxes[:, 3] - boxes[:, 1]) * 0.5 * scale
    center_x = (boxes[:, 2] + boxes[:, 0]) * 0.5
    center_y = (boxes[:, 3] + boxes[:, 1]) * 0.5

    expanded = torch.zeros_like(boxes)
    expanded[:, 0] = center_x - half_w
    expanded[:, 2] = center_x + half_w
    expanded[:, 1] = center_y - half_h
    expanded[:, 3] = center_y + half_h
    return expanded
+
+
@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    """Tracing-safe ratio of the padded mask size to the original size."""
    padded_size = torch.tensor(M + 2 * padding).to(torch.float32)
    return padded_size / torch.tensor(M).to(torch.float32)
+
+
def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    """Zero-pad masks on all four spatial sides and report the resulting size ratio."""
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        # tensor scale keeps the value symbolic for the trace
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded = F.pad(mask, (padding, padding, padding, padding))
    return padded, scale
+
+
def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    """Resize a fixed-size mask to its box and paste it into a zeroed full-image canvas."""
    TO_REMOVE = 1
    box_w = max(int(box[2] - box[0] + TO_REMOVE), 1)
    box_h = max(int(box[3] - box[1] + TO_REMOVE), 1)

    # interpolate expects NCHW, so lift the mask to [1, 1, H, W] first.
    resized = F.interpolate(
        mask.expand((1, 1, -1, -1)), size=(box_h, box_w), mode="bilinear", align_corners=False
    )[0][0]

    canvas = torch.zeros((im_h, im_w), dtype=resized.dtype, device=resized.device)
    # Clip the paste region to the image bounds.
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    canvas[y_0:y_1, x_0:x_1] = resized[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
    return canvas
+
+
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    # ONNX-traceable variant of paste_mask_in_image: uses only tensor ops
    # (no Python int conversion of data-dependent values, no slice assignment).
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    # Box size with the +1 "TO_REMOVE" convention, clamped to at least 1.
    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    # Paste region clipped to the image bounds.
    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX

    # Emulate slice assignment by concatenating zero borders around the pasted
    # region, then trimming back to the image size.
    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask
+
+
@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    # ONNX-friendly batched pasting: grows the result by concatenation instead of
    # stacking, which keeps the loop traceable.
    result = torch.zeros(0, im_h, im_w)
    for idx in range(masks.size(0)):
        pasted = _onnx_paste_mask_in_image(masks[idx][0], boxes[idx], im_h, im_w)
        result = torch.cat((result, pasted.unsqueeze(0)))
    return result
+
+
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    """
    Paste each fixed-size mask into a full-image canvas at its (padded, expanded)
    box location. Returns a tensor of shape [num_masks, 1, im_h, im_w].
    """
    masks, scale = expand_masks(masks, padding=padding)
    # Expand the boxes by the same ratio the masks were padded by.
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]

    pasted = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if pasted:
        return torch.stack(pasted, dim=0)[:, None]
    return masks.new_empty((0, 1, im_h, im_w))
+
+
+class RoIHeads(nn.Module):
+ __annotations__ = {
+ "box_coder": det_utils.BoxCoder,
+ "proposal_matcher": det_utils.Matcher,
+ "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+ }
+
    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
    ):
        # Assembles the second-stage heads of a two-stage detector: the box head is
        # mandatory; mask and keypoint heads are optional (see has_mask/has_keypoint).
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        # Samples a balanced positive/negative subset of proposals per image.
        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            # Default Faster R-CNN box-encoding weights (dx, dy, dw, dh).
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        # Box branch (always present).
        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        # Inference-time filtering parameters.
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        # Optional mask branch (all three must be set for it to be active).
        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        # Optional keypoint branch (all three must be set for it to be active).
        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor
+
+ def has_mask(self):
+ if self.mask_roi_pool is None:
+ return False
+ if self.mask_head is None:
+ return False
+ if self.mask_predictor is None:
+ return False
+ return True
+
+ def has_keypoint(self):
+ if self.keypoint_roi_pool is None:
+ return False
+ if self.keypoint_head is None:
+ return False
+ if self.keypoint_predictor is None:
+ return False
+ return True
+
    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Match each proposal to a ground-truth box and derive its training label.

        Returns per-image lists of:
            matched_idxs: GT index per proposal (clamped to 0, so always a valid index)
            labels: GT class per proposal, 0 for background, -1 for ignored proposals
        """
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):

            if gt_boxes_in_image.numel() == 0:
                # Background image: every proposal is labelled background (0) and
                # points at the dummy GT index 0.
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                # Matcher returns negative sentinels for unmatched proposals; clamp
                # so the tensor can still be used for indexing.
                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels
+
+ def subsample(self, labels):
+ # type: (List[Tensor]) -> List[Tensor]
+ sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+ sampled_inds = []
+ for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+ img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+ sampled_inds.append(img_sampled_inds)
+ return sampled_inds
+
+ def add_gt_proposals(self, proposals, gt_boxes):
+ # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+ proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+ return proposals
+
+ def check_targets(self, targets):
+ # type: (Optional[List[Dict[str, Tensor]]]) -> None
+ if targets is None:
+ raise ValueError("targets should not be None")
+ if not all(["boxes" in t for t in targets]):
+ raise ValueError("Every element of targets should have a boxes key")
+ if not all(["labels" in t for t in targets]):
+ raise ValueError("Every element of targets should have a labels key")
+ if self.has_mask():
+ if not all(["masks" in t for t in targets]):
+ raise ValueError("Every element of targets should have a masks key")
+
    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """
        Prepare the proposals used for box-head training: append GT boxes to the
        proposals, match every proposal to a GT, subsample a balanced set of
        positives/negatives per image, and encode box-regression targets.

        Returns:
            proposals, matched_idxs, labels, regression_targets (per-image lists;
            the first three are filtered down to the sampled indices)
        """
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals so every GT box is a candidate
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            # Keep only the sampled proposals (in-place per-image filtering).
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                # Background-only image: use a dummy box so indexing stays valid.
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets
+
    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        """Turn raw box-head outputs into final per-image detections.

        Decodes the regression deltas against the proposals and applies a
        softmax to the class logits; then, per image: clips boxes to the
        image, drops the background column, score-thresholds, removes tiny
        boxes, runs per-class NMS, and keeps the top ``detections_per_img``.

        Returns:
            (all_boxes, all_scores, all_labels), each a per-image list.
        """
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        # split the flat batch back into per-image chunks
        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label (column 0)
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels
+
+ def forward(
+ self,
+ features, # type: Dict[str, Tensor]
+ proposals, # type: List[Tensor]
+ image_shapes, # type: List[Tuple[int, int]]
+ targets=None, # type: Optional[List[Dict[str, Tensor]]]
+ ):
+ # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+ """
+ Args:
+ features (List[Tensor])
+ proposals (List[Tensor[N, 4]])
+ image_shapes (List[Tuple[H, W]])
+ targets (List[Dict])
+ """
+ if targets is not None:
+ for t in targets:
+ # TODO: https://github.com/pytorch/pytorch/issues/26731
+ floating_point_types = (torch.float, torch.double, torch.half)
+ if not t["boxes"].dtype in floating_point_types:
+ raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+ if not t["labels"].dtype == torch.int64:
+ raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+ if self.has_keypoint():
+ if not t["keypoints"].dtype == torch.float32:
+ raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+ if self.training:
+ proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+ else:
+ labels = None
+ regression_targets = None
+ matched_idxs = None
+
+ box_features = self.box_roi_pool(features, proposals, image_shapes)
+ box_features = self.box_head(box_features)
+ class_logits, box_regression = self.box_predictor(box_features)
+
+ result: List[Dict[str, torch.Tensor]] = []
+ losses = {}
+ if self.training:
+ if labels is None:
+ raise ValueError("labels cannot be None")
+ if regression_targets is None:
+ raise ValueError("regression_targets cannot be None")
+ loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+ losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+ else:
+ boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+ num_images = len(boxes)
+ for i in range(num_images):
+ result.append(
+ {
+ "boxes": boxes[i],
+ "labels": labels[i],
+ "scores": scores[i],
+ }
+ )
+
+ if self.has_mask():
+ mask_proposals = [p["boxes"] for p in result]
+ if self.training:
+ if matched_idxs is None:
+ raise ValueError("if in training, matched_idxs should not be None")
+
+ # during training, only focus on positive boxes
+ num_images = len(proposals)
+ mask_proposals = []
+ pos_matched_idxs = []
+ for img_id in range(num_images):
+ pos = torch.where(labels[img_id] > 0)[0]
+ mask_proposals.append(proposals[img_id][pos])
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
+ else:
+ pos_matched_idxs = None
+
+ if self.mask_roi_pool is not None:
+ mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+ mask_features = self.mask_head(mask_features)
+ mask_logits = self.mask_predictor(mask_features)
+ else:
+ raise Exception("Expected mask_roi_pool to be not None")
+
+ loss_mask = {}
+ if self.training:
+ if targets is None or pos_matched_idxs is None or mask_logits is None:
+ raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+ gt_masks = [t["masks"] for t in targets]
+ gt_labels = [t["labels"] for t in targets]
+ rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+ loss_mask = {"loss_mask": rcnn_loss_mask}
+ else:
+ labels = [r["labels"] for r in result]
+ masks_probs = maskrcnn_inference(mask_logits, labels)
+ for mask_prob, r in zip(masks_probs, result):
+ r["masks"] = mask_prob
+
+ losses.update(loss_mask)
+
+ # keep none checks in if conditional so torchscript will conditionally
+ # compile each branch
+ if (
+ self.keypoint_roi_pool is not None
+ and self.keypoint_head is not None
+ and self.keypoint_predictor is not None
+ ):
+ keypoint_proposals = [p["boxes"] for p in result]
+ if self.training:
+ # during training, only focus on positive boxes
+ num_images = len(proposals)
+ keypoint_proposals = []
+ pos_matched_idxs = []
+ if matched_idxs is None:
+ raise ValueError("if in trainning, matched_idxs should not be None")
+
+ for img_id in range(num_images):
+ pos = torch.where(labels[img_id] > 0)[0]
+ keypoint_proposals.append(proposals[img_id][pos])
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
+ else:
+ pos_matched_idxs = None
+
+ keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+ keypoint_features = self.keypoint_head(keypoint_features)
+ keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+ loss_keypoint = {}
+ if self.training:
+ if targets is None or pos_matched_idxs is None:
+ raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+ gt_keypoints = [t["keypoints"] for t in targets]
+ rcnn_loss_keypoint = keypointrcnn_loss(
+ keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+ )
+ loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+ else:
+ if keypoint_logits is None or keypoint_proposals is None:
+ raise ValueError(
+ "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+ )
+
+ keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+ for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+ r["keypoints"] = keypoint_prob
+ r["keypoints_scores"] = kps
+ losses.update(loss_keypoint)
+
+ return result, losses
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/rpn.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/rpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f103181e4c6cba48c1a3b4c97583c5fb6785a8c4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/rpn.py
@@ -0,0 +1,388 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import boxes as box_ops, Conv2dNormActivation
+
+from . import _utils as det_utils
+
+# Import AnchorGenerator to keep compatibility.
+from .anchor_utils import AnchorGenerator # noqa: 401
+from .image_list import ImageList
+
+
class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of convolutions
    """

    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        # one objectness logit and four box deltas per anchor, per location
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Migrate pre-version-2 checkpoints, where the single conv lived at
        # "conv.*" instead of "conv.0.0.*".
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # renamed loop variable: ``type`` shadowed the builtin
            for param_name in ["weight", "bias"]:
                old_key = f"{prefix}conv.{param_name}"
                new_key = f"{prefix}conv.0.0.{param_name}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Return per-level objectness logits and box deltas for each feature map."""
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
+
+
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    """Reshape a head output from (N, A*C, H, W) to (N, H*W*A, C)."""
    # split the fused anchor/class channel, bring the spatial dims forward,
    # then collapse (H, W, A) into a single anchor axis
    per_anchor = layer.unflatten(1, (-1, C))
    spatial_first = per_anchor.movedim((3, 4), (1, 2))
    return spatial_first.flatten(1, 3)
+
+
def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
    """Flatten per-level head outputs into single (sum_HWA, C) and (sum_HWA, 4) tensors.

    For each feature level, permute the outputs to the same format as the
    labels. The labels are computed with all feature levels concatenated, so
    the objectness and box regression are brought into that same layout.
    """
    flattened_cls = []
    flattened_reg = []
    for cls_per_level, reg_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = cls_per_level.shape
        A = reg_per_level.shape[1] // 4
        C = AxC // A
        flattened_cls.append(permute_and_flatten(cls_per_level, N, A, C, H, W))
        flattened_reg.append(permute_and_flatten(reg_per_level, N, A, 4, H, W))
    # concatenate along the anchor dimension (across feature levels), then
    # drop the batch structure to match the concatenated labels
    merged_cls = torch.cat(flattened_cls, dim=1).flatten(0, -2)
    merged_reg = torch.cat(flattened_reg, dim=1).reshape(-1, 4)
    return merged_cls, merged_reg
+
+
class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        score_thresh (float): only return proposals with an objectness score greater than score_thresh

    """

    # Explicit annotations so torchscript can type the attributes.
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        # proposals smaller than this (in either dimension) are discarded
        self.min_size = 1e-3

    def pre_nms_top_n(self) -> int:
        # Number of proposals kept before NMS, for the current train/eval mode.
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        # Number of proposals kept after NMS, for the current train/eval mode.
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Label each anchor (1.0 fg / 0.0 bg / -1.0 ignore) and select its matched GT box."""

        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
        """Return flat indices of the top-scoring anchors, chosen independently per level."""
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            # clamp top-n to the number of anchors actually present at this level
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Select final proposals: per-level top-n, clip, prune, per-level NMS, global top-n."""

        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        # level id of every anchor, so NMS can be run per level
        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for Backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # box loss on positives only, normalized by the total sample count
        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss

    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:

        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                correspond to different feature levels
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/ssd.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/ssd.py
new file mode 100644
index 0000000000000000000000000000000000000000..87062d2bc88a5bf17625e0530116aba22941c538
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/ssd.py
@@ -0,0 +1,682 @@
+import warnings
+from collections import OrderedDict
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from ...ops import boxes as box_ops
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..vgg import VGG, vgg16, VGG16_Weights
+from . import _utils as det_utils
+from .anchor_utils import DefaultBoxGenerator
+from .backbone_utils import _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+ "SSD300_VGG16_Weights",
+ "ssd300_vgg16",
+]
+
+
class SSD300_VGG16_Weights(WeightsEnum):
    # Pretrained weights for SSD300 with a VGG16 backbone; metadata reports
    # COCO-val2017 box mAP 25.1.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 35641826,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 25.1,
                }
            },
            "_ops": 34.858,
            "_file_size": 135.988,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
+
+
+def _xavier_init(conv: nn.Module):
+ for layer in conv.modules():
+ if isinstance(layer, nn.Conv2d):
+ torch.nn.init.xavier_uniform_(layer.weight)
+ if layer.bias is not None:
+ torch.nn.init.constant_(layer.bias, 0.0)
+
+
class SSDHead(nn.Module):
    """Pairs an SSD classification head with an SSD box-regression head."""

    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
        super().__init__()
        self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes)
        self.regression_head = SSDRegressionHead(in_channels, num_anchors)

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        # run both sub-heads over the same feature maps
        bbox_regression = self.regression_head(x)
        cls_logits = self.classification_head(x)
        return {"bbox_regression": bbox_regression, "cls_logits": cls_logits}
+
+
class SSDScoringHead(nn.Module):
    """Applies one conv module per feature level and concatenates per-anchor scores.

    ``num_columns`` is the number of values predicted per anchor (K): the
    number of classes for classification, 4 for box regression.
    """

    def __init__(self, module_list: nn.ModuleList, num_columns: int):
        super().__init__()
        self.module_list = module_list
        self.num_columns = num_columns

    def _get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.module_list[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.module_list)
        if idx < 0:
            idx += num_blocks
        out = x
        # deliberate full scan instead of indexing: keeps the module torchscriptable
        for i, module in enumerate(self.module_list):
            if i == idx:
                out = module(x)
        return out

    def forward(self, x: List[Tensor]) -> Tensor:
        all_results = []

        # the i-th feature map is scored by the i-th conv module
        for i, features in enumerate(x):
            results = self._get_result_from_module_list(features, i)

            # Permute output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = results.shape
            results = results.view(N, -1, self.num_columns, H, W)
            results = results.permute(0, 3, 4, 1, 2)
            results = results.reshape(N, -1, self.num_columns)  # Size=(N, HWA, K)

            all_results.append(results)

        return torch.cat(all_results, dim=1)
+
+
class SSDClassificationHead(SSDScoringHead):
    """SSD scoring head predicting ``num_classes`` logits per anchor."""

    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
        # one 3x3 conv per feature level, emitting classes*anchors channels
        cls_logits = nn.ModuleList(
            nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1)
            for channels, anchors in zip(in_channels, num_anchors)
        )
        _xavier_init(cls_logits)
        super().__init__(cls_logits, num_classes)
+
+
class SSDRegressionHead(SSDScoringHead):
    """SSD scoring head predicting 4 box-regression deltas per anchor."""

    def __init__(self, in_channels: List[int], num_anchors: List[int]):
        # one 3x3 conv per feature level, emitting 4*anchors channels
        bbox_reg = nn.ModuleList(
            nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1)
            for channels, anchors in zip(in_channels, num_anchors)
        )
        _xavier_init(bbox_reg)
        super().__init__(bbox_reg, 4)
+
+
+class SSD(nn.Module):
+ """
+ Implements SSD architecture from `"SSD: Single Shot MultiBox Detector" `_.
+
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+ image, and should be in 0-1 range. Different images can have different sizes, but they will be resized
+ to a fixed size before passing it to the backbone.
+
+ The behavior of the model changes depending on if it is in training or evaluation mode.
+
+ During training, the model expects both the input tensors and targets (list of dictionary),
+ containing:
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+ The model returns a Dict[Tensor] during training, containing the classification and regression
+ losses.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+ follows, where ``N`` is the number of detections:
+
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+ - labels (Int64Tensor[N]): the predicted labels for each detection
+ - scores (Tensor[N]): the scores for each detection
+
+ Args:
+ backbone (nn.Module): the network used to compute the features for the model.
+ It should contain an out_channels attribute with the list of the output channels of
+ each feature map. The backbone should return a single Tensor or an OrderedDict[Tensor].
+ anchor_generator (DefaultBoxGenerator): module that generates the default boxes for a
+ set of feature maps.
+ size (Tuple[int, int]): the width and height to which images will be rescaled before feeding them
+ to the backbone.
+ num_classes (int): number of output classes of the model (including the background).
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
+ They are generally the mean values of the dataset on which the backbone has been trained
+ on
+ image_std (Tuple[float, float, float]): std values used for input normalization.
+ They are generally the std values of the dataset on which the backbone has been trained on
+ head (nn.Module, optional): Module run on top of the backbone features. Defaults to a module containing
+ a classification and regression module.
+ score_thresh (float): Score threshold used for postprocessing the detections.
+ nms_thresh (float): NMS threshold used for postprocessing the detections.
+ detections_per_img (int): Number of best detections to keep after NMS.
+ iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training.
+ topk_candidates (int): Number of best detections to keep before NMS.
+ positive_fraction (float): a number between 0 and 1 which indicates the proportion of positive
+ proposals used during the training of the classification head. It is used to estimate the negative to
+ positive ratio.
+ """
+
+ __annotations__ = {
+ "box_coder": det_utils.BoxCoder,
+ "proposal_matcher": det_utils.Matcher,
+ }
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        anchor_generator: DefaultBoxGenerator,
+        size: Tuple[int, int],
+        num_classes: int,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        head: Optional[nn.Module] = None,
+        score_thresh: float = 0.01,
+        nms_thresh: float = 0.45,
+        detections_per_img: int = 200,
+        iou_thresh: float = 0.5,
+        topk_candidates: int = 400,
+        positive_fraction: float = 0.25,
+        **kwargs: Any,
+    ):
+        """Build the SSD detector from a backbone, anchor generator and head.
+
+        Args:
+            backbone: feature extraction network; must either expose an
+                ``out_channels`` attribute or be probeable via
+                ``det_utils.retrieve_out_channels``.
+            anchor_generator: module producing the default (anchor) boxes.
+            size: fixed (H, W) every input image is resized to by the transform.
+            num_classes: number of output classes (including background).
+            image_mean: per-channel normalization mean; ImageNet default if None.
+            image_std: per-channel normalization std; ImageNet default if None.
+            head: optional custom detection head; an ``SSDHead`` is built when omitted.
+            score_thresh: score threshold applied during postprocessing.
+            nms_thresh: IoU threshold for NMS during postprocessing.
+            detections_per_img: maximum detections kept per image.
+            iou_thresh: IoU threshold used by the anchor/ground-truth matcher.
+            topk_candidates: number of best candidates kept per class before NMS.
+            positive_fraction: fraction of positives used to derive the
+                negative-to-positive sampling ratio for the classification loss.
+            **kwargs: forwarded to ``GeneralizedRCNNTransform``.
+        """
+        super().__init__()
+        _log_api_usage_once(self)
+
+        self.backbone = backbone
+
+        self.anchor_generator = anchor_generator
+
+        # Weights presumably scale the (cx, cy, w, h) regression targets --
+        # see det_utils.BoxCoder for the exact semantics.
+        self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
+
+        if head is None:
+            if hasattr(backbone, "out_channels"):
+                out_channels = backbone.out_channels
+            else:
+                # Backbone does not declare its channels; probe them
+                # (presumably via a dummy forward pass -- see det_utils).
+                out_channels = det_utils.retrieve_out_channels(backbone, size)
+
+            # One channel count is expected per anchor-generator level.
+            if len(out_channels) != len(anchor_generator.aspect_ratios):
+                raise ValueError(
+                    f"The length of the output channels from the backbone ({len(out_channels)}) do not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})"
+                )
+
+            num_anchors = self.anchor_generator.num_anchors_per_location()
+            head = SSDHead(out_channels, num_anchors, num_classes)
+        self.head = head
+
+        self.proposal_matcher = det_utils.SSDMatcher(iou_thresh)
+
+        # Standard ImageNet normalization statistics.
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        # fixed_size with size_divisible=1: images are resized to exactly
+        # `size` rather than padded to a stride multiple (behavior of
+        # GeneralizedRCNNTransform -- confirm there).
+        self.transform = GeneralizedRCNNTransform(
+            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
+        )
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+        # e.g. positive_fraction=0.25 yields a 3:1 negative-to-positive ratio.
+        self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
+
+        # used only on torchscript mode
+        self._has_warned = False
+
+ @torch.jit.unused
+ def eager_outputs(
+ self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
+ ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+ if self.training:
+ return losses
+
+ return detections
+
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        matched_idxs: List[Tensor],
+    ) -> Dict[str, Tensor]:
+        """Compute the SSD training losses.
+
+        Args:
+            targets: per-image dicts with "boxes" and "labels" ground truth.
+            head_outputs: dict with batched "bbox_regression" and "cls_logits".
+            anchors: per-image default boxes.
+            matched_idxs: per-image tensor mapping each anchor to the index of
+                its matched ground-truth box (negative values = background).
+
+        Returns:
+            Dict with "bbox_regression" (smooth-L1 on foreground anchors) and
+            "classification" (cross-entropy with hard negative mining), both
+            normalized by the total number of foreground anchors.
+        """
+        bbox_regression = head_outputs["bbox_regression"]
+        cls_logits = head_outputs["cls_logits"]
+
+        # Match original targets with default boxes
+        num_foreground = 0
+        bbox_loss = []
+        cls_targets = []
+        for (
+            targets_per_image,
+            bbox_regression_per_image,
+            cls_logits_per_image,
+            anchors_per_image,
+            matched_idxs_per_image,
+        ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
+            # produce the matching between boxes and targets
+            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
+            foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image]
+            num_foreground += foreground_matched_idxs_per_image.numel()
+
+            # Calculate regression loss (only on anchors matched to a GT box)
+            matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
+            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+            bbox_loss.append(
+                torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+            )
+
+            # Estimate ground truth for class targets
+            # (anchors not assigned below stay 0, i.e. the background class)
+            gt_classes_target = torch.zeros(
+                (cls_logits_per_image.size(0),),
+                dtype=targets_per_image["labels"].dtype,
+                device=targets_per_image["labels"].device,
+            )
+            gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][
+                foreground_matched_idxs_per_image
+            ]
+            cls_targets.append(gt_classes_target)
+
+        bbox_loss = torch.stack(bbox_loss)
+        cls_targets = torch.stack(cls_targets)
+
+        # Calculate classification loss (per-anchor, unreduced)
+        num_classes = cls_logits.size(-1)
+        cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view(
+            cls_targets.size()
+        )
+
+        # Hard Negative Sampling: keep neg_to_pos_ratio negatives per positive,
+        # per image, picking the negatives with the highest loss.
+        foreground_idxs = cls_targets > 0
+        num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)
+        # num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio
+        negative_loss = cls_loss.clone()
+        negative_loss[foreground_idxs] = -float("inf")  # use -inf to detect positive values that creeped in the sample
+        values, idx = negative_loss.sort(1, descending=True)
+        # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values))
+        # idx.sort(1)[1] recovers each anchor's rank in the descending-loss
+        # order; ranks below num_negative select the hardest negatives.
+        background_idxs = idx.sort(1)[1] < num_negative
+
+        # Avoid division by zero when there are no foreground anchors.
+        N = max(1, num_foreground)
+        return {
+            "bbox_regression": bbox_loss.sum() / N,
+            "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
+        }
+
+    def forward(
+        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+        """Run the detector on a batch of images.
+
+        Args:
+            images: list of CHW tensors (sizes may differ; the transform
+                resizes them).
+            targets: per-image dicts with "boxes" ([N, 4], xyxy) and "labels";
+                required in training mode.
+
+        Returns:
+            In scripting mode, always a (losses, detections) tuple; otherwise
+            the losses dict when training and the detections list when
+            evaluating (see ``eager_outputs``).
+        """
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                # Validate target box tensors before any processing.
+                for target in targets:
+                    boxes = target["boxes"]
+                    if isinstance(boxes, torch.Tensor):
+                        torch._assert(
+                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                        )
+                    else:
+                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+
+        # get the original image sizes (needed to map detections back later)
+        original_image_sizes: List[Tuple[int, int]] = []
+        for img in images:
+            val = img.shape[-2:]
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
+            original_image_sizes.append((val[0], val[1]))
+
+        # transform the input (resize + normalize; boxes are adjusted too)
+        images, targets = self.transform(images, targets)
+
+        # Check for degenerate boxes (x2 <= x1 or y2 <= y1)
+        if targets is not None:
+            for target_idx, target in enumerate(targets):
+                boxes = target["boxes"]
+                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+                if degenerate_boxes.any():
+                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
+                    degen_bb: List[float] = boxes[bb_idx].tolist()
+                    torch._assert(
+                        False,
+                        "All bounding boxes should have positive height and width."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
+                    )
+
+        # get the features from the backbone
+        features = self.backbone(images.tensors)
+        if isinstance(features, torch.Tensor):
+            # single feature map: wrap into the multi-level dict format
+            features = OrderedDict([("0", features)])
+
+        features = list(features.values())
+
+        # compute the ssd heads outputs using the features
+        head_outputs = self.head(features)
+
+        # create the set of anchors
+        anchors = self.anchor_generator(images, features)
+
+        losses = {}
+        detections: List[Dict[str, Tensor]] = []
+        if self.training:
+            matched_idxs = []
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for anchors_per_image, targets_per_image in zip(anchors, targets):
+                    # image without GT boxes: mark every anchor as background (-1)
+                    if targets_per_image["boxes"].numel() == 0:
+                        matched_idxs.append(
+                            torch.full(
+                                (anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device
+                            )
+                        )
+                        continue
+
+                    match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
+                    matched_idxs.append(self.proposal_matcher(match_quality_matrix))
+
+            losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs)
+        else:
+            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("SSD always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return losses, detections
+        return self.eager_outputs(losses, detections)
+
+    def postprocess_detections(
+        self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]]
+    ) -> List[Dict[str, Tensor]]:
+        """Decode box regressions, filter by score and run per-class NMS.
+
+        Args:
+            head_outputs: dict with batched "bbox_regression" and "cls_logits".
+            image_anchors: per-image default boxes.
+            image_shapes: per-image (H, W) after the transform.
+
+        Returns:
+            Per-image dicts with "boxes", "scores" and "labels", holding at
+            most ``self.detections_per_img`` entries each.
+        """
+        bbox_regression = head_outputs["bbox_regression"]
+        pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)
+
+        num_classes = pred_scores.size(-1)
+        device = pred_scores.device
+
+        detections: List[Dict[str, Tensor]] = []
+
+        for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_shapes):
+            boxes = self.box_coder.decode_single(boxes, anchors)
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            image_boxes = []
+            image_scores = []
+            image_labels = []
+            # label 0 is skipped: it is the background class.
+            for label in range(1, num_classes):
+                score = scores[:, label]
+
+                keep_idxs = score > self.score_thresh
+                score = score[keep_idxs]
+                box = boxes[keep_idxs]
+
+                # keep only topk scoring predictions
+                num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
+                score, idxs = score.topk(num_topk)
+                box = box[idxs]
+
+                image_boxes.append(box)
+                image_scores.append(score)
+                image_labels.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device))
+
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
+
+            # non-maximum suppression (batched per class via the labels)
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[: self.detections_per_img]
+
+            detections.append(
+                {
+                    "boxes": image_boxes[keep],
+                    "scores": image_scores[keep],
+                    "labels": image_labels[keep],
+                }
+            )
+        return detections
+
+
+class SSDFeatureExtractorVGG(nn.Module):
+    """Multi-scale feature extractor built from a VGG ``features`` stack.
+
+    Emits the conv4_3 feature map (L2-normalized and rescaled) followed by
+    the outputs of the extra SSD blocks, as an ordered dict keyed by
+    stringified index ("0", "1", ...).
+    """
+
+    def __init__(self, backbone: nn.Module, highres: bool):
+        super().__init__()
+
+        # Positions of the MaxPool2d layers; unpacking into five names
+        # requires the backbone to contain exactly five maxpools (VGG16).
+        _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d))
+
+        # Patch ceil_mode for maxpool3 to get the same WxH output sizes as the paper
+        backbone[maxpool3_pos].ceil_mode = True
+
+        # parameters used for L2 regularization + rescaling
+        self.scale_weight = nn.Parameter(torch.ones(512) * 20)
+
+        # Multiple Feature maps - page 4, Fig 2 of SSD paper
+        self.features = nn.Sequential(*backbone[:maxpool4_pos])  # until conv4_3
+
+        # SSD300 case - page 4, Fig 2 of SSD paper
+        extra = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(1024, 256, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2),  # conv8_2
+                    nn.ReLU(inplace=True),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(512, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),  # conv9_2
+                    nn.ReLU(inplace=True),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(256, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=3),  # conv10_2
+                    nn.ReLU(inplace=True),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(256, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=3),  # conv11_2
+                    nn.ReLU(inplace=True),
+                ),
+            ]
+        )
+        if highres:
+            # Additional layers for the SSD512 case. See page 11, footernote 5.
+            extra.append(
+                nn.Sequential(
+                    nn.Conv2d(256, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=4),  # conv12_2
+                    nn.ReLU(inplace=True),
+                )
+            )
+        _xavier_init(extra)
+
+        fc = nn.Sequential(
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False),  # add modified maxpool5
+            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6),  # FC6 with atrous
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),  # FC7
+            nn.ReLU(inplace=True),
+        )
+        _xavier_init(fc)
+        # Prepend the conv5 stage + converted FC layers as the first extra block.
+        extra.insert(
+            0,
+            nn.Sequential(
+                *backbone[maxpool4_pos:-1],  # until conv5_3, skip maxpool5
+                fc,
+            ),
+        )
+        self.extra = extra
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        # L2 regularization + Rescaling of 1st block's feature map
+        # (F.normalize defaults to L2 normalization along dim 1, the channels)
+        x = self.features(x)
+        rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)
+        output = [rescaled]
+
+        # Calculating Feature maps for the rest blocks
+        for block in self.extra:
+            x = block(x)
+            output.append(x)
+
+        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
+
+
+def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
+    """Convert a torchvision VGG into an SSD feature extractor.
+
+    Freezes all backbone stages except the last ``trainable_layers`` ones
+    (stage boundaries are the maxpool positions), then wraps the result in
+    :class:`SSDFeatureExtractorVGG`.
+    """
+    backbone = backbone.features
+    # Gather the indices of maxpools. These are the locations of output blocks.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    torch._assert(
+        0 <= trainable_layers <= num_stages,
+        f"trainable_layers should be in the range [0, {num_stages}]. Instead got {trainable_layers}",
+    )
+    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    # Disable gradients for every layer before the cut point.
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    return SSDFeatureExtractorVGG(backbone, highres)
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
+)
+def ssd300_vgg16(
+    *,
+    weights: Optional[SSD300_VGG16_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> SSD:
+    """The SSD300 model is based on the `SSD: Single Shot MultiBox Detector
+    <https://arxiv.org/abs/1512.02325>`_ paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes, but they will be resized
+    to a fixed size before passing it to the backbone.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+    - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+      ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+    - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows, where ``N`` is the number of detections:
+
+    - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+      ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+    - labels (Int64Tensor[N]): the predicted labels for each detection
+    - scores (Tensor[N]): the scores for each detection
+
+    Example:
+
+        >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.SSD300_VGG16_Weights`, optional): The pretrained
+            weights to use. See
+            :class:`~torchvision.models.detection.SSD300_VGG16_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr
+            Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.VGG16_Weights`, optional): The pretrained weights for the
+            backbone
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 4.
+        **kwargs: parameters passed to the ``torchvision.models.detection.SSD``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.SSD300_VGG16_Weights
+        :members:
+    """
+    weights = SSD300_VGG16_Weights.verify(weights)
+    weights_backbone = VGG16_Weights.verify(weights_backbone)
+
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the parameter.")
+
+    if weights is not None:
+        # Full pretrained weights override the backbone weights and class count.
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
+    )
+
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg16(weights=weights_backbone, progress=progress)
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
+    anchor_generator = DefaultBoxGenerator(
+        [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+        scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
+        steps=[8, 16, 32, 64, 100, 300],
+    )
+
+    defaults = {
+        # Rescale the input in a way compatible to the backbone
+        "image_mean": [0.48235, 0.45882, 0.40784],
+        "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
+    }
+    # Explicit caller kwargs take precedence over the defaults above.
+    kwargs: Any = {**defaults, **kwargs}
+    model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/ssdlite.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/ssdlite.py
new file mode 100644
index 0000000000000000000000000000000000000000..eda21bf941ef0d4a9051312ebdba6911c6760e8d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/ssdlite.py
@@ -0,0 +1,331 @@
+import warnings
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from torch import nn, Tensor
+
+from ...ops.misc import Conv2dNormActivation
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .. import mobilenet
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from . import _utils as det_utils
+from .anchor_utils import DefaultBoxGenerator
+from .backbone_utils import _validate_trainable_layers
+from .ssd import SSD, SSDScoringHead
+
+
+__all__ = [
+ "SSDLite320_MobileNet_V3_Large_Weights",
+ "ssdlite320_mobilenet_v3_large",
+]
+
+
+# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
+def _prediction_block(
+    in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
+) -> nn.Sequential:
+    """SSDlite predictor block: depthwise conv followed by a 1x1 projection
+    (section 6.2 of the MobileNetV2 paper)."""
+    return nn.Sequential(
+        # depthwise convolution with stride 1 (kernel_size x kernel_size)
+        Conv2dNormActivation(
+            in_channels,
+            in_channels,
+            kernel_size=kernel_size,
+            groups=in_channels,
+            norm_layer=norm_layer,
+            activation_layer=nn.ReLU6,
+        ),
+        # 1x1 projection to output channels
+        nn.Conv2d(in_channels, out_channels, 1),
+    )
+
+
+def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
+    """SSDlite extra downsampling block: 1x1 reduce -> 3x3/stride-2 depthwise -> 1x1 expand."""
+    activation = nn.ReLU6
+    intermediate_channels = out_channels // 2
+    return nn.Sequential(
+        # 1x1 projection to half output channels
+        Conv2dNormActivation(
+            in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
+        ),
+        # 3x3 depthwise with stride 2 and padding 1
+        Conv2dNormActivation(
+            intermediate_channels,
+            intermediate_channels,
+            kernel_size=3,
+            stride=2,
+            groups=intermediate_channels,
+            norm_layer=norm_layer,
+            activation_layer=activation,
+        ),
+        # 1x1 projection to output channels
+        Conv2dNormActivation(
+            intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
+        ),
+    )
+
+
+def _normal_init(conv: nn.Module):
+ for layer in conv.modules():
+ if isinstance(layer, nn.Conv2d):
+ torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
+ if layer.bias is not None:
+ torch.nn.init.constant_(layer.bias, 0.0)
+
+
+class SSDLiteHead(nn.Module):
+    """SSDlite detection head pairing a classification and a regression head."""
+
+    def __init__(
+        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
+    ):
+        super().__init__()
+        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
+        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)
+
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
+        # One tensor per feature map goes in; batched box deltas and class
+        # logits come out under the keys expected by SSD.compute_loss.
+        return {
+            "bbox_regression": self.regression_head(x),
+            "cls_logits": self.classification_head(x),
+        }
+
+
+class SSDLiteClassificationHead(SSDScoringHead):
+    """Per-level depthwise+1x1 predictors emitting ``num_classes`` logits per anchor."""
+
+    def __init__(
+        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
+    ):
+        cls_logits = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
+        _normal_init(cls_logits)
+        super().__init__(cls_logits, num_classes)
+
+
+class SSDLiteRegressionHead(SSDScoringHead):
+    """Per-level depthwise+1x1 predictors emitting 4 box deltas per anchor."""
+
+    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
+        bbox_reg = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
+        _normal_init(bbox_reg)
+        super().__init__(bbox_reg, 4)
+
+
+class SSDLiteFeatureExtractorMobileNet(nn.Module):
+    """Multi-scale feature extractor wrapping a MobileNet ``features`` stack.
+
+    Splits the backbone at the C4 block's expansion layer (section 6.3 of the
+    MobileNetV3 paper) and appends four extra downsampling blocks, yielding
+    six feature maps keyed "0".."5".
+    """
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        c4_pos: int,
+        norm_layer: Callable[..., nn.Module],
+        width_mult: float = 1.0,
+        min_depth: int = 16,
+    ):
+        super().__init__()
+        _log_api_usage_once(self)
+
+        # The split below assumes the C4 block has no residual connection.
+        if backbone[c4_pos].use_res_connect:
+            raise ValueError("backbone[c4_pos].use_res_connect should be False")
+
+        self.features = nn.Sequential(
+            # As described in section 6.3 of MobileNetV3 paper
+            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
+            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]),  # from C4 depthwise until end
+        )
+
+        # Channel widths scale with width_mult but never drop below min_depth.
+        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
+        extra = nn.ModuleList(
+            [
+                _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
+                _extra_block(get_depth(512), get_depth(256), norm_layer),
+                _extra_block(get_depth(256), get_depth(256), norm_layer),
+                _extra_block(get_depth(256), get_depth(128), norm_layer),
+            ]
+        )
+        _normal_init(extra)
+
+        self.extra = extra
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
+        output = []
+        for block in self.features:
+            x = block(x)
+            output.append(x)
+
+        for block in self.extra:
+            x = block(x)
+            output.append(x)
+
+        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
+
+
+def _mobilenet_extractor(
+ backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+ trainable_layers: int,
+ norm_layer: Callable[..., nn.Module],
+):
+ backbone = backbone.features
+ # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+ # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+ stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+ num_stages = len(stage_indices)
+
+ # find the index of the layer from which we won't freeze
+ if not 0 <= trainable_layers <= num_stages:
+ raise ValueError("trainable_layers should be in the range [0, {num_stages}], instead got {trainable_layers}")
+ freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+ for b in backbone[:freeze_before]:
+ for parameter in b.parameters():
+ parameter.requires_grad_(False)
+
+ return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
+
+
+class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
+    # Pretrained COCO weights; DEFAULT aliases the only available entry.
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
+        transforms=ObjectDetection,
+        meta={
+            "num_params": 3440060,
+            "categories": _COCO_CATEGORIES,
+            "min_size": (1, 1),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 21.3,
+                }
+            },
+            "_ops": 0.583,
+            "_file_size": 13.418,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def ssdlite320_mobilenet_v3_large(
+    *,
+    weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
+    **kwargs: Any,
+) -> SSD:
+    """SSDlite model architecture with input size 320x320 and a MobileNetV3 Large backbone, as
+    described at `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__ and
+    `MobileNetV2: Inverted Residuals and Linear Bottlenecks <https://arxiv.org/abs/1801.04381>`__.
+
+    .. betastatus:: detection module
+
+    See :func:`~torchvision.models.detection.ssd300_vgg16` for more details.
+
+    Example:
+
+        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model
+            (including the background).
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained
+            weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers
+            starting from final block. Valid values are between 0 and 6, with 6 meaning all
+            backbone layers are trainable. If ``None`` is passed (the default) this value is
+            set to 6.
+        norm_layer (callable, optional): Module specifying the normalization layer to use.
+        **kwargs: parameters passed to the ``torchvision.models.detection.ssd.SSD``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssdlite.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights
+        :members:
+    """
+
+    weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the parameter.")
+
+    if weights is not None:
+        # Full pretrained weights override the backbone weights and class count.
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
+    )
+
+    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
+    reduce_tail = weights_backbone is None
+
+    if norm_layer is None:
+        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
+
+    backbone = mobilenet_v3_large(
+        weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
+    )
+    if weights_backbone is None:
+        # Change the default initialization scheme if not pretrained
+        _normal_init(backbone)
+    backbone = _mobilenet_extractor(
+        backbone,
+        trainable_backbone_layers,
+        norm_layer,
+    )
+
+    size = (320, 320)
+    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
+    out_channels = det_utils.retrieve_out_channels(backbone, size)
+    num_anchors = anchor_generator.num_anchors_per_location()
+    if len(out_channels) != len(anchor_generator.aspect_ratios):
+        raise ValueError(
+            f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}"
+        )
+
+    defaults = {
+        "score_thresh": 0.001,
+        "nms_thresh": 0.55,
+        "detections_per_img": 300,
+        "topk_candidates": 300,
+        # Rescale the input in a way compatible to the backbone:
+        # The following mean/std rescale the data from [0, 1] to [-1, 1]
+        "image_mean": [0.5, 0.5, 0.5],
+        "image_std": [0.5, 0.5, 0.5],
+    }
+    # Explicit caller kwargs take precedence over the defaults above.
+    kwargs: Any = {**defaults, **kwargs}
+    model = SSD(
+        backbone,
+        anchor_generator,
+        size,
+        num_classes,
+        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/detection/transform.py b/.venv/lib/python3.11/site-packages/torchvision/models/detection/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c569b0aafb0c5464815654c0f343d7fb927dc6c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/detection/transform.py
@@ -0,0 +1,319 @@
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchvision
+from torch import nn, Tensor
+
+from .image_list import ImageList
+from .roi_heads import paste_masks_in_image
+
+
+@torch.jit.unused
+def _get_shape_onnx(image: Tensor) -> Tensor:
+    """Return the spatial (H, W) of ``image`` as a Tensor using ONNX-traceable ops."""
+    # Imported lazily so torch.onnx is only pulled in when tracing is active.
+    from torch.onnx import operators
+
+    return operators.shape_as_tensor(image)[-2:]
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> float:
+    """Pretend to cast a 0-dim Tensor to float so JIT type-checking is satisfied during ONNX export."""
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
+
+
+def _resize_image_and_masks(
+    image: Tensor,
+    self_min_size: int,
+    self_max_size: int,
+    target: Optional[Dict[str, Tensor]] = None,
+    fixed_size: Optional[Tuple[int, int]] = None,
+) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    """Resize ``image`` (and any ``target["masks"]``) either to ``fixed_size`` or so the
+    shorter side is ~``self_min_size`` without the longer side exceeding ``self_max_size``.
+
+    Returns the resized image and the (possibly updated) target. Boxes/keypoints in
+    ``target`` are NOT rescaled here; callers must do that separately.
+    """
+    # Shape extraction differs per execution mode: ONNX tracing needs tensor ops,
+    # scripting needs a concrete tensor, and eager can use the plain size tuple.
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    elif torch.jit.is_scripting():
+        im_shape = torch.tensor(image.shape[-2:])
+    else:
+        im_shape = image.shape[-2:]
+
+    # Exactly one of `size` (fixed target shape) or `scale_factor` is set below.
+    size: Optional[List[int]] = None
+    scale_factor: Optional[float] = None
+    recompute_scale_factor: Optional[bool] = None
+    if fixed_size is not None:
+        # fixed_size is (w, h); interpolate expects (h, w).
+        size = [fixed_size[1], fixed_size[0]]
+    else:
+        if torch.jit.is_scripting() or torchvision._is_tracing():
+            min_size = torch.min(im_shape).to(dtype=torch.float32)
+            max_size = torch.max(im_shape).to(dtype=torch.float32)
+            self_min_size_f = float(self_min_size)
+            self_max_size_f = float(self_max_size)
+            # Scale so min side reaches self_min_size, capped so max side stays <= self_max_size.
+            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
+
+            if torchvision._is_tracing():
+                scale_factor = _fake_cast_onnx(scale)
+            else:
+                scale_factor = scale.item()
+
+        else:
+            # Do it the normal way
+            min_size = min(im_shape)
+            max_size = max(im_shape)
+            scale_factor = min(self_min_size / min_size, self_max_size / max_size)
+
+        recompute_scale_factor = True
+
+    image = torch.nn.functional.interpolate(
+        image[None],
+        size=size,
+        scale_factor=scale_factor,
+        mode="bilinear",
+        recompute_scale_factor=recompute_scale_factor,
+        align_corners=False,
+    )[0]
+
+    if target is None:
+        return image, target
+
+    if "masks" in target:
+        mask = target["masks"]
+        # Nearest-style resize of binary masks: interpolate in float, then re-binarize via byte().
+        mask = torch.nn.functional.interpolate(
+            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
+        )[:, 0].byte()
+        target["masks"] = mask
+    return image, target
+
+
+class GeneralizedRCNNTransform(nn.Module):
+    """
+    Performs input / target transformation before feeding the data to a GeneralizedRCNN
+    model.
+
+    The transformations it performs are:
+        - input normalization (mean subtraction and std division)
+        - input / target resizing to match min_size / max_size
+
+    It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
+    """
+
+    def __init__(
+        self,
+        min_size: int,
+        max_size: int,
+        image_mean: List[float],
+        image_std: List[float],
+        size_divisible: int = 32,
+        fixed_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        # min_size may be a single int or a collection; normalize to a tuple so
+        # training can randomly pick one of several scales (see torch_choice).
+        if not isinstance(min_size, (list, tuple)):
+            min_size = (min_size,)
+        self.min_size = min_size
+        self.max_size = max_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size_divisible = size_divisible
+        self.fixed_size = fixed_size
+        # Private escape hatch: skip resizing entirely during training when set.
+        self._skip_resize = kwargs.pop("_skip_resize", False)
+
+    def forward(
+        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
+    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
+        """Normalize, resize and batch ``images`` (and resize matching ``targets``).
+
+        Returns an ImageList carrying both the padded batch tensor and each image's
+        pre-padding size, plus the transformed targets (or None).
+        """
+        images = [img for img in images]
+        if targets is not None:
+            # make a copy of targets to avoid modifying it in-place
+            # once torchscript supports dict comprehension
+            # this can be simplified as follows
+            # targets = [{k: v for k,v in t.items()} for t in targets]
+            targets_copy: List[Dict[str, Tensor]] = []
+            for t in targets:
+                data: Dict[str, Tensor] = {}
+                for k, v in t.items():
+                    data[k] = v
+                targets_copy.append(data)
+            targets = targets_copy
+        for i in range(len(images)):
+            image = images[i]
+            target_index = targets[i] if targets is not None else None
+
+            if image.dim() != 3:
+                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
+            image = self.normalize(image)
+            image, target_index = self.resize(image, target_index)
+            images[i] = image
+            if targets is not None and target_index is not None:
+                targets[i] = target_index
+
+        # Record per-image sizes BEFORE padding so postprocess can undo the resize.
+        image_sizes = [img.shape[-2:] for img in images]
+        images = self.batch_images(images, size_divisible=self.size_divisible)
+        image_sizes_list: List[Tuple[int, int]] = []
+        for image_size in image_sizes:
+            torch._assert(
+                len(image_size) == 2,
+                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
+            )
+            image_sizes_list.append((image_size[0], image_size[1]))
+
+        image_list = ImageList(images, image_sizes_list)
+        return image_list, targets
+
+    def normalize(self, image: Tensor) -> Tensor:
+        """Per-channel normalize: ``(image - mean) / std``; requires a floating image."""
+        if not image.is_floating_point():
+            raise TypeError(
+                f"Expected input images to be of floating type (in range [0, 1]), "
+                f"but found type {image.dtype} instead"
+            )
+        dtype, device = image.dtype, image.device
+        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
+        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
+        return (image - mean[:, None, None]) / std[:, None, None]
+
+    def torch_choice(self, k: List[int]) -> int:
+        """
+        Implements `random.choice` via torch ops, so it can be compiled with
+        TorchScript and we use PyTorch's RNG (not native RNG)
+        """
+        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
+        return k[index]
+
+    def resize(
+        self,
+        image: Tensor,
+        target: Optional[Dict[str, Tensor]] = None,
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        """Resize ``image`` (and its boxes/keypoints/masks in ``target``) to the
+        configured scale: a random min_size in training, the last one in eval.
+        """
+        h, w = image.shape[-2:]
+        if self.training:
+            if self._skip_resize:
+                return image, target
+            size = self.torch_choice(self.min_size)
+        else:
+            size = self.min_size[-1]
+        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
+
+        if target is None:
+            return image, target
+
+        # Boxes and keypoints are rescaled here; masks were handled inside
+        # _resize_image_and_masks.
+        bbox = target["boxes"]
+        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+        target["boxes"] = bbox
+
+        if "keypoints" in target:
+            keypoints = target["keypoints"]
+            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
+            target["keypoints"] = keypoints
+        return image, target
+
+    # _onnx_batch_images() is an implementation of
+    # batch_images() that is supported by ONNX tracing.
+    @torch.jit.unused
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        max_size = []
+        for i in range(images[0].dim()):
+            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
+            max_size.append(max_size_i)
+        stride = size_divisible
+        # Round H (index 1) and W (index 2) up to a multiple of stride.
+        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size = tuple(max_size)
+
+        # work around for
+        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        # which is not yet supported in onnx
+        padded_imgs = []
+        for img in images:
+            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+            padded_imgs.append(padded_img)
+
+        return torch.stack(padded_imgs)
+
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
+        """Element-wise maximum across equal-length lists of ints."""
+        maxes = the_list[0]
+        for sublist in the_list[1:]:
+            for index, item in enumerate(sublist):
+                maxes[index] = max(maxes[index], item)
+        return maxes
+
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        """Zero-pad each image to the batch's max H/W (rounded up to a multiple of
+        ``size_divisible``) and stack into one [N, C, H, W] tensor.
+        """
+        if torchvision._is_tracing():
+            # batch_images() does not export well to ONNX
+            # call _onnx_batch_images() instead
+            return self._onnx_batch_images(images, size_divisible)
+
+        max_size = self.max_by_axis([list(img.shape) for img in images])
+        stride = float(size_divisible)
+        max_size = list(max_size)
+        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
+        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
+
+        batch_shape = [len(images)] + max_size
+        # new_full(..., 0): padding pixels are zeros (i.e. the normalized mean).
+        batched_imgs = images[0].new_full(batch_shape, 0)
+        for i in range(batched_imgs.shape[0]):
+            img = images[i]
+            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+
+        return batched_imgs
+
+    def postprocess(
+        self,
+        result: List[Dict[str, Tensor]],
+        image_shapes: List[Tuple[int, int]],
+        original_image_sizes: List[Tuple[int, int]],
+    ) -> List[Dict[str, Tensor]]:
+        """Map predictions back to original image coordinates (boxes, masks,
+        keypoints). No-op while training, since losses are computed in resized space.
+        """
+        if self.training:
+            return result
+        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
+            boxes = pred["boxes"]
+            boxes = resize_boxes(boxes, im_s, o_im_s)
+            result[i]["boxes"] = boxes
+            if "masks" in pred:
+                masks = pred["masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["masks"] = masks
+            if "keypoints" in pred:
+                keypoints = pred["keypoints"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["keypoints"] = keypoints
+        return result
+
+    def __repr__(self) -> str:
+        # Summarize only the user-visible transforms (normalize + resize).
+        format_string = f"{self.__class__.__name__}("
+        _indent = "\n    "
+        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
+        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
+        format_string += "\n)"
+        return format_string
+
+
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    """Rescale keypoint (x, y) coordinates from ``original_size`` to ``new_size``
+    (both (h, w)); the third channel (visibility) is left untouched.
+    """
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_h, ratio_w = ratios
+    resized_data = keypoints.clone()
+    if torch._C._get_tracing_state():
+        # Tracing path: in-place ops don't export, so rebuild via stack.
+        resized_data_0 = resized_data[:, :, 0] * ratio_w
+        resized_data_1 = resized_data[:, :, 1] * ratio_h
+        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
+    else:
+        resized_data[..., 0] *= ratio_w
+        resized_data[..., 1] *= ratio_h
+    return resized_data
+
+
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    """Rescale XYXY ``boxes`` from ``original_size`` to ``new_size`` (both (h, w))."""
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=boxes.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_height, ratio_width = ratios
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+
+    xmin = xmin * ratio_width
+    xmax = xmax * ratio_width
+    ymin = ymin * ratio_height
+    ymax = ymax * ratio_height
+    return torch.stack((xmin, ymin, xmax, ymax), dim=1)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61f2ec8f0853ada7794b8423cec17d90837fa698
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c056438e224ca73c2af55a40d84525701725e865
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/_utils.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/deeplabv3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/deeplabv3.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af0d3f582f45dd856d7cf47dd4515827d9960175
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/deeplabv3.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/fcn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/fcn.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d24670b96200ca7d9ee936af87fdf787024af68
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/fcn.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/lraspp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/lraspp.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dca2479612599f454fa2d7f2ed39e42ee771139e
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/__pycache__/lraspp.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/_utils.py b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56560e9dab5c143699c918fa28236a902e530daf
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/models/segmentation/_utils.py
@@ -0,0 +1,37 @@
+from collections import OrderedDict
+from typing import Dict, Optional
+
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from ...utils import _log_api_usage_once
+
+
+class _SimpleSegmentationModel(nn.Module):
+    """Backbone + classifier segmentation model that upsamples logits to input size.
+
+    The backbone must return a dict with key ``"out"`` (and ``"aux"`` when an
+    auxiliary classifier is used).
+    """
+
+    __constants__ = ["aux_classifier"]
+
+    def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        self.backbone = backbone
+        self.classifier = classifier
+        self.aux_classifier = aux_classifier
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        """Return ``{"out": logits}`` (plus ``"aux"`` if configured), each bilinearly
+        upsampled back to the input's spatial size.
+        """
+        input_shape = x.shape[-2:]
+        # contract: features is a dict of tensors
+        features = self.backbone(x)
+
+        result = OrderedDict()
+        x = features["out"]
+        x = self.classifier(x)
+        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
+        result["out"] = x
+
+        if self.aux_classifier is not None:
+            x = features["aux"]
+            x = self.aux_classifier(x)
+            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
+            result["aux"] = x
+
+        return result
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba47f60a36ad3be8cb2f557adb57f1d2f1ba470
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__init__.py
@@ -0,0 +1,35 @@
+import torch
+
+from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
+from ._image import Image
+from ._mask import Mask
+from ._torch_function_helpers import set_return_type
+from ._tv_tensor import TVTensor
+from ._video import Video
+
+
+# TODO: Fix this. We skip this method as it leads to
+# RecursionError: maximum recursion depth exceeded while calling a Python object
+# Until `disable` is removed, there will be graph breaks after all calls to functional transforms
+@torch.compiler.disable
+def wrap(wrappee, *, like, **kwargs):
+    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
+
+    If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
+    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
+
+    Args:
+        wrappee (Tensor): The tensor to convert.
+        like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
+            ``wrappee`` will be converted into the same subclass as ``like``.
+        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`.
+            Ignored otherwise.
+    """
+    if isinstance(like, BoundingBoxes):
+        # BoundingBoxes carries metadata (format, canvas_size) that a plain
+        # as_subclass() would lose, so go through its _wrap classmethod.
+        return BoundingBoxes._wrap(
+            wrappee,
+            format=kwargs.get("format", like.format),
+            canvas_size=kwargs.get("canvas_size", like.canvas_size),
+        )
+    else:
+        return wrappee.as_subclass(type(like))
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8950f6eb15f829886be77489cee406591d0ca3da
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/__init__.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_bounding_boxes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_bounding_boxes.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3edd85ce75eb37447bf4bfc5283d4ff7bad22326
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_bounding_boxes.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_dataset_wrapper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_dataset_wrapper.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ca4c9ed23c846f580a80f135f1ea566139819b8
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_dataset_wrapper.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_image.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_image.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66debf6424c18c26c1b8856ae3a3f79c1fa2abff
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_image.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_mask.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_mask.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dee2497536e90a8d717783abce212aa4637b3342
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_mask.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_torch_function_helpers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_torch_function_helpers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f1c21741782f48884a1f03d7c5447372d0abba1
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_torch_function_helpers.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_tv_tensor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_tv_tensor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96fc566bec631fd216116fae857c07aa97491a8f
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_tv_tensor.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_video.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_video.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2435b6ef3e5f302f157ed85e16ef3172a6c20d3a
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/__pycache__/_video.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_bounding_boxes.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_bounding_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea02fa3dc7b3e25ba7f545ca64288fc241d47f2a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_bounding_boxes.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, Mapping, Optional, Sequence, Tuple, Union
+
+import torch
+from torch.utils._pytree import tree_flatten
+
+from ._tv_tensor import TVTensor
+
+
+class BoundingBoxFormat(Enum):
+    """Coordinate format of a bounding box.
+
+    Available formats are
+
+    * ``XYXY``
+    * ``XYWH``
+    * ``CXCYWH``
+    """
+
+    # Member values equal their names so str(format.value) round-trips with
+    # the string form accepted by BoundingBoxes._wrap.
+    XYXY = "XYXY"
+    XYWH = "XYWH"
+    CXCYWH = "CXCYWH"
+
+
+class BoundingBoxes(TVTensor):
+    """:class:`torch.Tensor` subclass for bounding boxes with shape ``[N, 4]``.
+
+    .. note::
+        There should be only one :class:`~torchvision.tv_tensors.BoundingBoxes`
+        instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
+        although one :class:`~torchvision.tv_tensors.BoundingBoxes` object can
+        contain multiple bounding boxes.
+
+    Args:
+        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
+        format (BoundingBoxFormat, str): Format of the bounding box.
+        canvas_size (two-tuple of ints): Height and width of the corresponding image or video.
+        dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from
+            ``data``.
+        device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a
+            :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU.
+        requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and
+            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
+    """
+
+    # Metadata attached per-instance by _wrap(); lost by plain tensor ops and
+    # restored in _wrap_output().
+    format: BoundingBoxFormat
+    canvas_size: Tuple[int, int]
+
+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes:  # type: ignore[override]
+        """Attach ``format``/``canvas_size`` metadata to ``tensor`` and view it as this class."""
+        if check_dims:
+            # Promote a single box [4] to a batch of one [1, 4].
+            if tensor.ndim == 1:
+                tensor = tensor.unsqueeze(0)
+            elif tensor.ndim != 2:
+                raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
+        if isinstance(format, str):
+            format = BoundingBoxFormat[format.upper()]
+        bounding_boxes = tensor.as_subclass(cls)
+        bounding_boxes.format = format
+        bounding_boxes.canvas_size = canvas_size
+        return bounding_boxes
+
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        format: Union[BoundingBoxFormat, str],
+        canvas_size: Tuple[int, int],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: Optional[bool] = None,
+    ) -> BoundingBoxes:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor, format=format, canvas_size=canvas_size)
+
+    @classmethod
+    def _wrap_output(
+        cls,
+        output: torch.Tensor,
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> BoundingBoxes:
+        """Re-attach BoundingBoxes metadata to tensor op outputs, sourcing it from
+        the first BoundingBoxes found among the op's arguments.
+        """
+        # If there are BoundingBoxes instances in the output, their metadata got lost when we called
+        # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
+        # the metadata from the first bbox in the parameters.
+        # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
+        # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
+        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
+        first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes))
+        format, canvas_size = first_bbox_from_args.format, first_bbox_from_args.canvas_size
+
+        if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
+            output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False)
+        elif isinstance(output, (tuple, list)):
+            output = type(output)(
+                BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output
+            )
+        return output
+
+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(format=self.format, canvas_size=self.canvas_size)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_dataset_wrapper.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_dataset_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..23683221f6005a9ce6a55e785e59409a649d7928
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_dataset_wrapper.py
@@ -0,0 +1,666 @@
+# type: ignore
+
+from __future__ import annotations
+
+import collections.abc
+
+import contextlib
+from collections import defaultdict
+from copy import copy
+
+import torch
+
+from torchvision import datasets, tv_tensors
+from torchvision.transforms.v2 import functional as F
+
+__all__ = ["wrap_dataset_for_transforms_v2"]
+
+
def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
    """Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

    Example:
        >>> dataset = torchvision.datasets.CocoDetection(...)
        >>> dataset = wrap_dataset_for_transforms_v2(dataset)

    .. note::

       For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
       configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
       to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

    The dataset samples are wrapped according to the description below.

    Special cases:

        * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
          returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
          ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.tv_tensors``.
          The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
          ``"image_id"``, ``"boxes"``, and ``"labels"``.
        * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
          the target and wrap the data in the corresponding ``torchvision.tv_tensors``. The original keys are
          preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
        * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts, the wrapper returns a
          dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
          in the corresponding ``torchvision.tv_tensors``. The original keys are preserved. If ``target_keys`` is
          omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
          :class:`~torchvision.tv_tensors.Mask` tv_tensor.
        * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
          :class:`~torchvision.tv_tensors.Mask` tv_tensor. The target for ``target_type="instance"`` is *replaced* by
          a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.tv_tensors.Mask` tv_tensor) and
          ``"labels"``.
        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.

    Image classification datasets

    This wrapper is a no-op for image classification datasets, since they were already fully supported by
    :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

    Segmentation datasets

    Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
    :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
    segmentation mask into a :class:`~torchvision.tv_tensors.Mask` (second item).

    Video classification datasets

    Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
    :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
    :class:`~torchvision.tv_tensors.Video` while leaving the other items as is.

    .. note::

        Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
        ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.

    Args:
        dataset: the dataset instance to wrap for compatibility with transforms v2.
        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
            fine grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
            :class:`~torchvision.datasets.WIDERFace`. See above for details.
    """
    target_keys_is_valid = (
        target_keys is None
        or target_keys == "all"
        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
    )
    if not target_keys_is_valid:
        raise ValueError(
            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
            f"but got {target_keys}"
        )

    # Build a throwaway subclass, e.g. "WrappedImageNet", that inherits from both the wrapper mixin and the
    # original dataset class. Regular isinstance(wrapped_dataset, datasets.ImageNet) checks therefore still hold,
    # while the wrapper controls construction and __getitem__.
    dataset_cls = type(dataset)
    wrapped_dataset_cls = type(f"Wrapped{dataset_cls.__name__}", (VisionDatasetTVTensorWrapper, dataset_cls), {})
    # VisionDatasetTVTensorWrapper precedes the dataset class in the MRO, so calling the new class runs
    # VisionDatasetTVTensorWrapper.__init__, which never chains into the dataset constructor. That is deliberate:
    # the existing dataset instance is stored as an attribute rather than being re-created.
    return wrapped_dataset_cls(dataset, target_keys)
+
+
class WrapperFactories(dict):
    """Registry mapping a dataset class to the factory that produces its sample wrapper."""

    def register(self, dataset_cls):
        """Return a decorator that registers the decorated factory for ``dataset_cls``."""

        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
# dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
# provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
# we have access to the dataset instance.
WRAPPER_FACTORIES = WrapperFactories()
+
+
class VisionDatasetTVTensorWrapper:
    """Mixin injected by :func:`wrap_dataset_for_transforms_v2`.

    Holds the original dataset instance in ``self._dataset``, disables the dataset's own transforms,
    converts each raw sample to tv_tensors via a dataset-specific wrapper, and then applies the user
    transforms itself in :meth:`__getitem__`.
    """

    def __init__(self, dataset, target_keys):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead.\n"
                f"For an example of how to perform the wrapping for custom datasets, see\n\n"
                "https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
            )

        # Walk the MRO to find the most specific dataset class that has a registered wrapper factory.
        for cls in dataset_cls.mro():
            if cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[cls]
                if target_keys is not None and cls not in {
                    datasets.CocoDetection,
                    datasets.VOCDetection,
                    datasets.Kitti,
                    datasets.WIDERFace,
                }:
                    raise ValueError(
                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
                        f"and `WIDERFace`, but got {cls.__name__}."
                    )
                break
            elif cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg = (
                        f"{msg} If an automated wrapper for this dataset would be useful for you, "
                        f"please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._target_keys = target_keys
        self._wrapper = wrapper_factory(dataset, target_keys)

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
        # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
        # `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
        # disable all three here to be able to extract the untransformed sample to wrap.
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        # Any attribute not found on the wrapper itself is delegated to the wrapped dataset.
        with contextlib.suppress(AttributeError):
            return object.__getattribute__(self, item)

        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
        # of this class
        sample = self._dataset[idx]

        sample = self._wrapper(idx, sample)

        # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
        # or joint (`transforms`), we can access the full functionality through `transforms`
        if self.transforms is not None:
            sample = self.transforms(*sample)

        return sample

    def __len__(self):
        return len(self._dataset)

    # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
    def __reduce__(self):
        # __reduce__ gets called when we try to pickle the dataset.
        # In a DataLoader with spawn context, this gets called `num_workers` times from the main process.

        # We have to reset the [target_]transform[s] attributes of the dataset
        # to their original values, because we previously set them to None in __init__().
        dataset = copy(self._dataset)
        dataset.transform = self.transform
        dataset.transforms = self.transforms
        dataset.target_transform = self.target_transform

        return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)
+
+
def raise_not_supported(description):
    """Raise a ``RuntimeError`` pointing users at the issue tracker for unsupported configurations."""
    message = (
        f"{description} is currently not supported by this wrapper. "
        f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
    )
    raise RuntimeError(message)
+
+
def identity(item):
    """Return *item* unchanged; used as the default per-type wrapper."""
    return item
+
+
def identity_wrapper_factory(dataset, target_keys):
    """Factory for datasets whose samples need no tv_tensor wrapping at all."""

    def wrapper(idx, sample):
        # Hand the raw sample straight back.
        return sample

    return wrapper
+
+
def pil_image_to_mask(pil_image):
    # Wrap a PIL segmentation image into a Mask tv_tensor.
    return tv_tensors.Mask(pil_image)
+
+
def parse_target_keys(target_keys, *, available, default):
    """Normalize ``target_keys`` to a set, falling back to ``default`` and validating against ``available``.

    ``"all"`` (either passed explicitly or coming from ``default``) selects every available key.
    Raises ``ValueError`` if any requested key is not in ``available``.
    """
    if target_keys is None:
        target_keys = default
    if target_keys == "all":
        return available
    requested = set(target_keys)
    extra = requested - available
    if extra:
        raise ValueError(f"Target keys {sorted(extra)} are not available")
    return requested
+
+
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose ``[{k: v}, ...]`` into ``{k: [v, ...]}``, keeping keys in order of first appearance."""
    transposed = {}
    for dct in list_of_dicts:
        for key, value in dct.items():
            transposed.setdefault(key, []).append(value)
    return transposed
+
+
def wrap_target_by_type(target, *, target_types, type_wrappers):
    """Apply the wrapper registered for each target type; unregistered types pass through unchanged.

    A non-sequence target is treated as a single-element target, and a single wrapped item is
    returned bare rather than as a 1-tuple.
    """
    items = target if isinstance(target, (tuple, list)) else [target]

    wrapped = tuple(
        type_wrappers[target_type](item) if target_type in type_wrappers else item
        for target_type, item in zip(target_types, items)
    )

    return wrapped[0] if len(wrapped) == 1 else wrapped
+
+
def classification_wrapper_factory(dataset, target_keys):
    """Classification samples ``(image, label)`` already work with transforms v2, so pass them through."""

    def wrapper(idx, sample):
        return sample

    return wrapper
+
+
# Plain classification datasets need no tv_tensor wrapping; register the pass-through factory for each.
for dataset_cls in [
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
    datasets.Imagenette,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
+
+
def segmentation_wrapper_factory(dataset, target_keys):
    """Wrap ``(image, segmentation)`` samples by converting the PIL segmentation into a Mask tv_tensor."""

    def wrapper(idx, sample):
        image, segmentation = sample
        return image, pil_image_to_mask(segmentation)

    return wrapper
+
+
# Segmentation datasets: wrap the PIL segmentation target into a Mask tv_tensor.
for dataset_cls in [
    datasets.VOCSegmentation,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
+
+
def video_classification_wrapper_factory(dataset, target_keys):
    """Wrap ``(video, audio, label)`` samples by turning the video tensor into a Video tv_tensor.

    Rejects datasets constructed with ``output_format="THWC"``, which transforms v2 cannot handle.
    """
    if dataset.video_clips.output_format == "THWC":
        raise RuntimeError(
            f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
            f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
        )

    def wrapper(idx, sample):
        video, audio, label = sample
        return tv_tensors.Video(video), audio, label

    return wrapper
+
+
# Video classification datasets: wrap the video tensor into a Video tv_tensor.
for dataset_cls in [
    datasets.HMDB51,
    datasets.Kinetics,
    datasets.UCF101,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)
+
+
@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset, target_keys):
    """Caltech101 behaves like a classification dataset unless bounding-box annotations were requested."""
    if "annotation" in dataset.target_type:
        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
    return classification_wrapper_factory(dataset, target_keys)
+
+
@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_dectection_wrapper_factory(dataset, target_keys):
    """Wrap CocoDetection samples.

    The native list-of-dicts target is transposed into a dict of lists, and the key-value-pairs
    ``"boxes"`` (XYXY), ``"masks"``, and ``"labels"`` are added as tv_tensors, filtered by ``target_keys``.
    """
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "segmentation",
            "area",
            "iscrowd",
            "image_id",
            "bbox",
            "category_id",
            # added by the wrapper
            "boxes",
            "masks",
            "labels",
        },
        default={"image_id", "boxes", "labels"},
    )

    def segmentation_to_mask(segmentation, *, canvas_size):
        # Local import: pycocotools is an optional dependency, only needed when masks are requested.
        from pycocotools import mask

        if isinstance(segmentation, dict):
            # if counts is a string, it is already an encoded RLE mask
            if not isinstance(segmentation["counts"], str):
                segmentation = mask.frPyObjects(segmentation, *canvas_size)
        elif isinstance(segmentation, list):
            # Polygon format: convert each polygon to RLE and merge them into a single mask.
            segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
        else:
            raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")
        return torch.from_numpy(mask.decode(segmentation))

    def wrapper(idx, sample):
        image_id = dataset.ids[idx]

        image, target = sample

        # Images without annotations yield a minimal target containing only the image id.
        if not target:
            return image, dict(image_id=image_id)

        canvas_size = tuple(F.get_size(image))

        # Transpose the per-annotation dicts into one dict of per-key lists.
        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "image_id" in target_keys:
            target["image_id"] = image_id

        if "boxes" in target_keys:
            # COCO stores boxes as XYWH; convert to the XYXY format expected by detection transforms.
            target["boxes"] = F.convert_bounding_box_format(
                tv_tensors.BoundingBoxes(
                    batched_target["bbox"],
                    format=tv_tensors.BoundingBoxFormat.XYWH,
                    canvas_size=canvas_size,
                ),
                new_format=tv_tensors.BoundingBoxFormat.XYXY,
            )

        if "masks" in target_keys:
            target["masks"] = tv_tensors.Mask(
                torch.stack(
                    [
                        segmentation_to_mask(segmentation, canvas_size=canvas_size)
                        for segmentation in batched_target["segmentation"]
                    ]
                ),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(batched_target["category_id"])

        # Any remaining requested native keys are passed through untouched (as lists).
        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper
+
+
+WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
+
+
# The 20 PASCAL VOC object categories plus a background placeholder at index 0.
# The list order defines the integer label assignment in VOC_DETECTION_CATEGORY_TO_IDX below.
VOC_DETECTION_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))
+
+
@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset, target_keys):
    """Add XYXY ``"boxes"`` and integer ``"labels"`` entries to VOCDetection's annotation target."""
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "annotation",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        # Transpose the per-object annotation dicts into one dict of per-key lists.
        instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        if "annotation" not in target_keys:
            # The native annotation was not requested: build the target from scratch.
            target = {}

        if "boxes" in target_keys:
            boxes = [
                [int(bndbox[coord]) for coord in ("xmin", "ymin", "xmax", "ymax")]
                for bndbox in instances["bndbox"]
            ]
            target["boxes"] = tv_tensors.BoundingBoxes(
                boxes,
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            labels = [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in instances["name"]]
            target["labels"] = torch.tensor(labels)

        return image, target

    return wrapper
+
+
@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset, target_keys):
    """SBDataset in segmentation mode is handled like any other segmentation dataset."""
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")
    return segmentation_wrapper_factory(dataset, target_keys)
+
+
@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset, target_keys):
    """Convert CelebA ``"bbox"`` targets from XYWH to XYXY BoundingBoxes; attr/landmarks are unsupported."""
    if {"attr", "landmarks"}.intersection(dataset.target_type):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrap_bbox(item, *, image):
        # CelebA stores boxes as XYWH; convert to XYXY for the v2 transforms.
        boxes = tv_tensors.BoundingBoxes(
            item,
            format=tv_tensors.BoundingBoxFormat.XYWH,
            canvas_size=(image.height, image.width),
        )
        return F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY)

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={"bbox": lambda item: wrap_bbox(item, image=image)},
        )

        return image, target

    return wrapper
+
+
# Kitti category names; the index in this list becomes the integer class label via the mapping below.
KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))
+
+
@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset, target_keys):
    """Re-shape Kitti's list-of-dicts target into a dict with ``"boxes"`` and ``"labels"`` tv_tensors."""
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "type",
            "truncated",
            "occluded",
            "alpha",
            "bbox",
            "dimensions",
            "location",
            "rotation_y",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        # Samples without annotations are passed through untouched.
        if target is None:
            return image, target

        columns = list_of_dicts_to_dict_of_lists(target)
        wrapped_target = {}

        if "boxes" in target_keys:
            wrapped_target["boxes"] = tv_tensors.BoundingBoxes(
                columns["bbox"],
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            labels = [KITTI_CATEGORY_TO_IDX[category] for category in columns["type"]]
            wrapped_target["labels"] = torch.tensor(labels)

        # Any remaining requested native keys are passed through as lists.
        for key in target_keys - {"boxes", "labels"}:
            wrapped_target[key] = columns[key]

        return image, wrapped_target

    return wrapper
+
+
@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factor(dataset, target_keys):
    """Wrap OxfordIIITPet ``"segmentation"`` targets into Mask tv_tensors; other target types pass through."""

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        wrapped = wrap_target_by_type(
            target,
            target_types=dataset._target_types,
            type_wrappers={"segmentation": pil_image_to_mask},
        )
        return image, wrapped

    return wrapper
+
+
@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset, target_keys):
    """Wrap Cityscapes ``"semantic"`` targets into Masks; ``"instance"`` targets become masks + labels dicts."""
    if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
        data = pil_image_to_mask(mask)
        masks = []
        labels = []
        for obj_id in data.unique():
            masks.append(data == obj_id)
            # Per the cityscapesScripts encoding above, ids >= 1000 are label * 1000 + instance number;
            # recover the label part, smaller ids are already plain labels.
            labels.append(obj_id // 1_000 if obj_id >= 1_000 else obj_id)
        return dict(masks=tv_tensors.Mask(torch.stack(masks)), labels=torch.stack(labels))

    def wrapper(idx, sample):
        image, target = sample

        wrapped = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )

        return image, wrapped

    return wrapper
+
+
@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset, target_keys):
    """Convert WIDERFace ``"bbox"`` targets from XYWH to XYXY BoundingBoxes, keeping only requested keys."""
    target_keys = parse_target_keys(
        target_keys,
        available={
            "bbox",
            "blur",
            "expression",
            "illumination",
            "occlusion",
            "pose",
            "invalid",
        },
        default="all",
    )

    def wrapper(idx, sample):
        image, target = sample

        # Samples without annotations are passed through untouched.
        if target is None:
            return image, target

        selected = {key: target[key] for key in target_keys}

        if "bbox" in selected:
            boxes = tv_tensors.BoundingBoxes(
                selected["bbox"], format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
            )
            selected["bbox"] = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY)

        return image, selected

    return wrapper
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_image.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a0a2ec720966f849f8d832a1b9f2e640ba7dc2c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_image.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import Any, Optional, Union
+
+import PIL.Image
+import torch
+
+from ._tv_tensor import TVTensor
+
+
class Image(TVTensor):
    """:class:`torch.Tensor` subclass for images with shape ``[..., C, H, W]``.

    .. note::

        In the :ref:`transforms `, ``Image`` instances are largely
        interchangeable with pure :class:`torch.Tensor`. See
        :ref:`this note ` for more details.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the image is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
        if isinstance(data, PIL.Image.Image):
            # Imported lazily to avoid a circular import between tv_tensors and transforms.v2.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            # Fix: raise with an informative message instead of a bare `raise ValueError`,
            # so users can tell why construction failed. Still a ValueError, so callers are unaffected.
            raise ValueError(f"Image tensors should have at least 2 dimensions, but got {tensor.ndim}")
        elif tensor.ndim == 2:
            # Promote a 2D [H, W] input to a single-channel [1, H, W] image.
            tensor = tensor.unsqueeze(0)

        return tensor.as_subclass(cls)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_mask.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef9d96159fb5e0ac64bb99774ae65be9e11325ae
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_mask.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from typing import Any, Optional, Union
+
+import PIL.Image
+import torch
+
+from ._tv_tensor import TVTensor
+
+
class Mask(TVTensor):
    """:class:`torch.Tensor` subclass for segmentation and detection masks with shape ``[..., H, W]``.

    Args:
        data (tensor-like, PIL.Image.Image): Any data that can be turned into a tensor with :func:`torch.as_tensor` as
            well as PIL images.
        dtype (torch.dtype, optional): Desired data type. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
        if isinstance(data, PIL.Image.Image):
            # Imported lazily to avoid a circular import between tv_tensors and transforms.v2.
            from torchvision.transforms.v2 import functional as F

            data = F.pil_to_tensor(data)

        return cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad).as_subclass(cls)
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_torch_function_helpers.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_torch_function_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ea5fddf35b4b585cb2527863fc2db1ba6c049b
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_torch_function_helpers.py
@@ -0,0 +1,72 @@
+import torch
+
# Process-wide flag toggled by `set_return_type`: when True, torch ops on TVTensors return TVTensors.
_TORCHFUNCTION_SUBCLASS = False
+
+
+class _ReturnTypeCM:
+ def __init__(self, to_restore):
+ self.to_restore = to_restore
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ global _TORCHFUNCTION_SUBCLASS
+ _TORCHFUNCTION_SUBCLASS = self.to_restore
+
+
def set_return_type(return_type: str):
    """Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.

    This only affects the behaviour of torch operations. It has no effect on
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

    .. warning::

        We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
        the end of your transform pipelines if you use
        ``set_return_type("TVTensor")``. This will avoid the
        ``__torch_function__`` overhead in the models ``forward()``.

    Can be used as a global flag for the entire program:

    .. code:: python

        img = tv_tensors.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor (default behaviour)

        set_return_type("TVTensor")
        img + 2  # This is an Image

    or as a context manager to restrict the scope:

    .. code:: python

        img = tv_tensors.Image(torch.rand(3, 5, 5))
        img + 2  # This is a pure Tensor
        with set_return_type("TVTensor"):
            img + 2  # This is an Image
        img + 2  # This is a pure Tensor

    Args:
        return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
            Default is "Tensor" (i.e. pure :class:`torch.Tensor`).
    """
    global _TORCHFUNCTION_SUBCLASS
    to_restore = _TORCHFUNCTION_SUBCLASS

    normalized = return_type.lower()
    if normalized == "tensor":
        _TORCHFUNCTION_SUBCLASS = False
    elif normalized == "tvtensor":
        _TORCHFUNCTION_SUBCLASS = True
    else:
        raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}")

    # Returned object restores the previous value when used as a context manager; ignoring it
    # makes the call act as a global switch.
    return _ReturnTypeCM(to_restore)
+
+
def _must_return_subclass():
    # Read accessor for the module-level flag; queried by TVTensor.__torch_function__.
    return _TORCHFUNCTION_SUBCLASS
+
+
# For those ops we always want to preserve the original subclass instead of returning a pure Tensor,
# regardless of the _TORCHFUNCTION_SUBCLASS flag (e.g. `img.clone()` should stay an Image).
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_tv_tensor.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_tv_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..508e73724beb8e5ce1202bc35e079b5542149feb
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_tv_tensor.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+
+import torch
+from torch._C import DisableTorchFunctionSubclass
+from torch.types import _device, _dtype, _size
+
+from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
+
+
# TypeVar bound to TVTensor so methods like __deepcopy__ can declare "returns the same subclass".
D = TypeVar("D", bound="TVTensor")
+
+
class TVTensor(torch.Tensor):
    """Base class for all TVTensors.

    You probably don't want to use this class unless you're defining your own
    custom TVTensors. See
    :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for details.
    """

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
        """Convert arbitrary tensor-like data to a plain tensor, inheriting requires_grad from tensor inputs."""
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def _wrap_output(
        cls,
        output: torch.Tensor,
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """Re-wrap tensor output(s) of a torch op as this subclass; subclasses with metadata override this."""
        # Same as torch._tensor._convert
        if isinstance(output, torch.Tensor) and not isinstance(output, cls):
            output = output.as_subclass(cls)

        if isinstance(output, (tuple, list)):
            # Also handles things like namedtuples
            output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
        return output

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the __torch_function__ protocol works,
        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        Why do we override this? Because the base implementation in torch.Tensor would preserve the TVTensor type
        of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
        "TVTensors FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).

        Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
        """
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Like in the base Tensor.__torch_function__ implementation, it's easier to always use
        # DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

        must_return_subclass = _must_return_subclass()
        if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
            # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
            # in test_to_tv_tensor_reference().
            # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
            # the computation by walking the MRO upwards. For example,
            # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
            # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
            # be wrapped into an `Image`.
            return cls._wrap_output(output, args, kwargs)

        if not must_return_subclass and isinstance(output, cls):
            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
            # so for those, the output is still a TVTensor. Thus, we need to manually unwrap.
            return output.as_subclass(torch.Tensor)

        return output

    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"

    # Add properties for common attributes like shape, dtype, device, ndim etc
    # this way we return the result without passing into __torch_function__
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype

    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
        # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
        # attribute is cleared, so we need to refill it before we return.
        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
        # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
        # `BoundingBoxes.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
diff --git a/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_video.py b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa923e781ef0438e88c672f4f46d884689e206af
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/torchvision/tv_tensors/_video.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import Any, Optional, Union
+
+import torch
+
+from ._tv_tensor import TVTensor
+
+
class Video(TVTensor):
    """:class:`torch.Tensor` subclass for videos with shape ``[..., T, C, H, W]``.

    Args:
        data (tensor-like): Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        dtype (torch.dtype, optional): Desired data type. If omitted, will be inferred from
            ``data``.
        device (torch.device, optional): Desired device. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the video is constructed on the CPU.
        requires_grad (bool, optional): Whether autograd should record operations. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    def __new__(
        cls,
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Video:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        # Fix: validate the converted tensor, not the raw input. `data` may be a list or other
        # tensor-like without an `.ndim` attribute, which previously crashed with AttributeError
        # instead of raising the intended ValueError. Also raise with an informative message
        # instead of a bare `raise ValueError`.
        if tensor.ndim < 4:
            raise ValueError(f"Video tensors should have at least 4 dimensions, but got {tensor.ndim}")
        return tensor.as_subclass(cls)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr()