diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..669d6e86ef482b43e05c2d31220421b4b3d80c15 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py @@ -0,0 +1,146 @@ +from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel +from ._stereo_matching import ( + CarlaStereo, + CREStereo, + ETH3DStereo, + FallingThingsStereo, + InStereo2k, + Kitti2012Stereo, + Kitti2015Stereo, + Middlebury2014Stereo, + SceneFlowStereo, + SintelStereo, +) +from .caltech import Caltech101, Caltech256 +from .celeba import CelebA +from .cifar import CIFAR10, CIFAR100 +from .cityscapes import Cityscapes +from .clevr import CLEVRClassification +from .coco import CocoCaptions, CocoDetection +from .country211 import Country211 +from .dtd import DTD +from .eurosat import EuroSAT +from .fakedata import FakeData +from .fer2013 import FER2013 +from .fgvc_aircraft import FGVCAircraft +from .flickr import Flickr30k, Flickr8k +from .flowers102 import Flowers102 +from .folder import DatasetFolder, ImageFolder +from .food101 import Food101 +from .gtsrb import GTSRB +from .hmdb51 import HMDB51 +from .imagenet import ImageNet +from .imagenette import Imagenette +from .inaturalist import INaturalist +from .kinetics import Kinetics +from .kitti import Kitti +from .lfw import LFWPairs, LFWPeople +from .lsun import LSUN, LSUNClass +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST +from .moving_mnist import MovingMNIST +from .omniglot import Omniglot +from .oxford_iiit_pet import OxfordIIITPet +from .pcam import PCAM +from .phototour import PhotoTour +from .places365 import Places365 +from .rendered_sst2 import RenderedSST2 +from .sbd import SBDataset +from .sbu import SBU +from .semeion import SEMEION +from .stanford_cars import StanfordCars +from .stl10 import STL10 +from 
.sun397 import SUN397 +from .svhn import SVHN +from .ucf101 import UCF101 +from .usps import USPS +from .vision import VisionDataset +from .voc import VOCDetection, VOCSegmentation +from .widerface import WIDERFace + +__all__ = ( + "LSUN", + "LSUNClass", + "ImageFolder", + "DatasetFolder", + "FakeData", + "CocoCaptions", + "CocoDetection", + "CIFAR10", + "CIFAR100", + "EMNIST", + "FashionMNIST", + "QMNIST", + "MNIST", + "KMNIST", + "StanfordCars", + "STL10", + "SUN397", + "SVHN", + "PhotoTour", + "SEMEION", + "Omniglot", + "SBU", + "Flickr8k", + "Flickr30k", + "Flowers102", + "VOCSegmentation", + "VOCDetection", + "Cityscapes", + "ImageNet", + "Caltech101", + "Caltech256", + "CelebA", + "WIDERFace", + "SBDataset", + "VisionDataset", + "USPS", + "Kinetics", + "HMDB51", + "UCF101", + "Places365", + "Kitti", + "INaturalist", + "LFWPeople", + "LFWPairs", + "KittiFlow", + "Sintel", + "FlyingChairs", + "FlyingThings3D", + "HD1K", + "Food101", + "DTD", + "FER2013", + "GTSRB", + "CLEVRClassification", + "OxfordIIITPet", + "PCAM", + "Country211", + "FGVCAircraft", + "EuroSAT", + "RenderedSST2", + "Kitti2012Stereo", + "Kitti2015Stereo", + "CarlaStereo", + "Middlebury2014Stereo", + "CREStereo", + "FallingThingsStereo", + "SceneFlowStereo", + "SintelStereo", + "InStereo2k", + "ETH3DStereo", + "wrap_dataset_for_transforms_v2", + "Imagenette", +) + + +# We override current module's attributes to handle the import: +# from torchvision.datasets import wrap_dataset_for_transforms_v2 +# without a cyclic error. 
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
    """Module-level attribute hook (PEP 562) for lazily provided names.

    The only lazily provided attribute is ``wrap_dataset_for_transforms_v2``.
    Its import is deferred to the first lookup so that importing this package
    does not trigger a cyclic import.

    Args:
        name: Attribute name that was not found through normal module lookup.

    Returns:
        The ``wrap_dataset_for_transforms_v2`` callable.

    Raises:
        AttributeError: If ``name`` is not a lazily provided attribute.
    """
    lazy_names = ("wrap_dataset_for_transforms_v2",)

    # Guard clause: anything we do not provide lazily is a genuine miss.
    if name not in lazy_names:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Deferred import: resolving it at module import time would be cyclic.
    from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2 as _wrapper

    return _wrapper
and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cifar.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..388f42f1dbc0e915e2bd7cdeaf37f312163d55be Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e697ab15e33778aa0844fff3210c1daaf1084c9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61727b6bdc800f19eb6479b15fb361c6803f36e1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0936c741662817d24a3a9b854240a321e528d86 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..327d29f641b83821ea2e6dd36a6bb7d63cf6fae4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2798788debf77c97561cf5552f39074a7a3ff81 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec153a7d655f29bad6c55a46252b1981619b6b36 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9f197d5375ec80ab7d098900b6046ffc0554f0c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e93aafcbe4b0b96c185fe57174283ee5c9af131b Binary files /dev/null 
and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a2e01d157a7c3f5b6486d3058d39077cd1e8bc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f697fb733196e01dfaeba8a22574db3476d2cf86 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8281a75cbed7ad956909b0f7663b9792ad23c453 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94bc83e4e859430056441d4be9ab9eb316c7a82c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3d787218f526a25db081cf0b796b7f2f912619 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86bc698f9cd086595217d452eb4e3db337500921 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bcdc46f3c941845fed9f1c8928632cd6f8ebd76 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5330f67cd8b12f46488fc280ccee4bb5993a9076 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a78e8769351fd78bb8258d287d584f24eca87a6b Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a1bbb31219d4318d36d1305365b545ac58b041a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..461a2f742d28a1393b2122989fccea51d43ee89b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..776e662d5b4dbbcbbd6a0e5e66e79a244ef009fe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56aa5947e6bd41173b7ed1aa61047c82eddf22e1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d98a74bd15b9ce49723a6b20be1f2aedf75a29b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2542501e0f6668de4ebbf4fb8e84368b5ab31a1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bebea1cf5894bd127f52b7a5dbfba504710a4377 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46adab4189665ef6b1f717389aa73fb8f13fcb37 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8566c02e1fd3bb7b83d9e134e2b80fcd30e7e41e Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a253537bafbf3bd63fa97a352473da341f3a95dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8bd4074512de593d88e4a76ff00029ffadeb90e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5491c2f7ac2ad17e42f0570976eb08b74ea0e4bd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..1deaab7e2f38b072d94b251a5d61321e5430d3f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py @@ -0,0 +1,1224 @@ +import functools +import json +import os +import random +import shutil +from abc import ABC, abstractmethod +from glob import glob +from pathlib import Path +from typing import Callable, cast, List, Optional, 
# Sample layouts returned by ``__getitem__``: with and without a validity mask.
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]

__all__ = ()

# PFM reader pre-configured with ``slice_channels=1``.
_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)


class StereoMatchingDataset(ABC, VisionDataset):
    """Abstract base class shared by the stereo matching datasets.

    Subclasses fill ``self._images`` and ``self._disparities`` with
    ``(left, right)`` path pairs and implement ``_read_disparity``.
    """

    # Subclasses set this to True to always return a 4-tuple from ``__getitem__``.
    _has_built_in_disparity_mask = False

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
        """
        Args:
            root(str): Root directory of the dataset.
            transforms(callable, optional): A function/transform that takes in Tuples of
                (images, disparities, valid_masks) and returns a transformed version of each of them.
                images is a Tuple of (``PIL.Image``, ``PIL.Image``)
                disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
                valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
                In some cases, when a dataset does not provide disparities, the ``disparities`` and
                ``valid_masks`` can be Tuples containing None values.
                For training splits generally the datasets provide a minimal guarantee of
                images: (``PIL.Image``, ``PIL.Image``)
                disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
                Optionally, based on the dataset, it can return a ``mask`` as well:
                valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
                For some test splits, the datasets provides outputs that look like:
                images: (``PIL.Image``, ``PIL.Image``)
                disparities: (``None``, ``None``)
                Optionally, based on the dataset, it can return a ``mask`` as well:
                valid_masks: (``None``, ``None``)
        """
        super().__init__(root=root)
        # Stored after the parent constructor so the joint transform is what this class uses.
        self.transforms = transforms

        self._images = []  # type: ignore
        self._disparities = []  # type: ignore

    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
        """Open ``file_path`` and return it as an RGB ``PIL.Image``."""
        image = Image.open(file_path)
        if image.mode == "RGB":
            return image
        return image.convert("RGB")  # type: ignore [return-value]

    def _scan_pairs(
        self,
        paths_left_pattern: str,
        paths_right_pattern: Optional[str] = None,
    ) -> List[Tuple[str, Optional[str]]]:
        """Glob the left/right patterns and zip the sorted matches into pairs.

        When no right pattern is given, every left path is paired with ``None``.

        Raises:
            FileNotFoundError: If either pattern matches no files.
            ValueError: If the two patterns match a different number of files.
        """
        left_paths = sorted(glob(paths_left_pattern))

        right_paths: List[Union[None, str]]
        if paths_right_pattern:
            right_paths = sorted(glob(paths_right_pattern))
        else:
            right_paths = [None] * len(left_paths)

        # The left side is validated first so its pattern is the one reported.
        if not left_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")

        if not right_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")

        if len(left_paths) != len(right_paths):
            raise ValueError(
                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
                f"left pattern: {paths_left_pattern}\n"
                f"right pattern: {paths_right_pattern}\n"
            )

        return list(zip(left_paths, right_paths))

    @abstractmethod
    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        # Implementations return a ``(disparity_map, valid_mask)`` pair for one file.
        pass

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: ``(img_left, img_right, disparity)`` or
            ``(img_left, img_right, disparity, valid_mask)``. The 4-tuple is returned when
            the dataset class declares a built-in disparity mask or when the left valid
            mask is not ``None`` after the transforms ran. The images are PIL images.
            ``disparity`` may be ``None`` for splits without annotations.
        """
        # Images are read before the disparities, left before right.
        image_entry = self._images[index]
        img_left = self._read_img(image_entry[0])
        img_right = self._read_img(image_entry[1])

        disparity_entry = self._disparities[index]
        dsp_map_left, valid_mask_left = self._read_disparity(disparity_entry[0])
        dsp_map_right, valid_mask_right = self._read_disparity(disparity_entry[1])

        imgs = (img_left, img_right)
        dsp_maps = (dsp_map_left, dsp_map_right)
        valid_masks = (valid_mask_left, valid_mask_right)

        if self.transforms is not None:
            imgs, dsp_maps, valid_masks = self.transforms(imgs, dsp_maps, valid_masks)

        # Only the left-view disparity and mask are returned.
        if not self._has_built_in_disparity_mask and valid_masks[0] is None:
            return imgs[0], imgs[1], dsp_maps[0]
        return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])

    def __len__(self) -> int:
        return len(self._images)


class CarlaStereo(StereoMatchingDataset):
    """
    Carla simulator data linked in the CREStereo GitHub repository.

    The dataset is expected to have the following structure: ::

        root
            carla-highres
                trainingF
                    scene1
                        im0.png
                        im1.png
                        disp0GT.pfm
                        disp1GT.pfm
                        calib.txt
                    scene2
                        im0.png
                        im1.png
                        disp0GT.pfm
                        disp1GT.pfm
                        calib.txt
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        # Every scene directory sits directly below ``trainingF``.
        scene_glob = Path(root) / "carla-highres" / "trainingF" / "*"

        self._images = self._scan_pairs(
            str(scene_glob / "im0.png"),
            str(scene_glob / "im1.png"),
        )
        self._disparities = self._scan_pairs(
            str(scene_glob / "disp0GT.pfm"),
            str(scene_glob / "disp1GT.pfm"),
        )

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Read a PFM disparity file; no validity mask is produced."""
        # Absolute value guarantees a positive disparity.
        disparity_map = np.abs(_read_pfm_file(file_path))
        return disparity_map, None

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        sample = super().__getitem__(index)
        return cast(T1, sample)


class Kitti2012Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the 2012 stereo evaluation benchmark.
    Uses the RGB images for consistency with KITTI 2015.

    The dataset is expected to have the following structure: ::

        root
            Kitti2012
                testing
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                training
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                    disp_noc
                        1.png
                        2.png
                        ...
                    calib

    Args:
        root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # On-disk folders are named "training" / "testing".
        split_dir = Path(root) / "Kitti2012" / (split + "ing")

        self._images = self._scan_pairs(
            str(split_dir / "colored_0" / "*_10.png"),
            str(split_dir / "colored_1" / "*_10.png"),
        )

        if split != "train":
            # No ground truth is scanned for the test split.
            self._disparities = [(None, None) for _ in self._images]
        else:
            # Only one disparity pattern is scanned; the right entry of each pair is None.
            self._disparities = self._scan_pairs(str(split_dir / "disp_noc" / "*.png"), None)

    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
        """Load a disparity PNG scaled by 1/256 with a leading channel axis.

        Returns ``(None, None)`` when ``file_path`` is ``None``.
        """
        if file_path is None:
            return None, None

        raw = np.asarray(Image.open(file_path))
        # Index with a leading ``None`` to obtain (C, H, W) layout.
        disparity_map = (raw / 256.0)[None, :, :]
        return disparity_map, None

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        sample = super().__getitem__(index)
        return cast(T1, sample)


class Kitti2015Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the 2015 stereo evaluation benchmark.

    The dataset is expected to have the following structure: ::

        root
            Kitti2015
                testing
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                training
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                    disp_occ_0
                        img1.png
                        img2.png
                        ...
                    disp_occ_1
                        img1.png
                        img2.png
                        ...
                    calib

    Args:
        root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # On-disk folders are named "training" / "testing".
        split_dir = Path(root) / "Kitti2015" / (split + "ing")

        self._images = self._scan_pairs(
            str(split_dir / "image_2" / "*.png"),
            str(split_dir / "image_3" / "*.png"),
        )

        if split != "train":
            # No ground truth is scanned for the test split.
            self._disparities = [(None, None) for _ in self._images]
        else:
            self._disparities = self._scan_pairs(
                str(split_dir / "disp_occ_0" / "*.png"),
                str(split_dir / "disp_occ_1" / "*.png"),
            )

    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
        """Load a disparity PNG scaled by 1/256 with a leading channel axis.

        Returns ``(None, None)`` when ``file_path`` is ``None``.
        """
        if file_path is None:
            return None, None

        raw = np.asarray(Image.open(file_path))
        # Index with a leading ``None`` to obtain (C, H, W) layout.
        disparity_map = (raw / 256.0)[None, :, :]
        return disparity_map, None

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        sample = super().__getitem__(index)
        return cast(T1, sample)
+ download (boolean, optional): Whether or not to download the dataset in the ``root`` directory. + """ + + splits = { + "train": [ + "Adirondack", + "Jadeplant", + "Motorcycle", + "Piano", + "Pipes", + "Playroom", + "Playtable", + "Recycle", + "Shelves", + "Vintage", + ], + "additional": [ + "Backpack", + "Bicycle1", + "Cable", + "Classroom1", + "Couch", + "Flowers", + "Mask", + "Shopvac", + "Sticks", + "Storage", + "Sword1", + "Sword2", + "Umbrella", + ], + "test": [ + "Plants", + "Classroom2E", + "Classroom2", + "Australia", + "DjembeL", + "CrusadeP", + "Crusade", + "Hoops", + "Bicycle2", + "Staircase", + "Newkuba", + "AustraliaP", + "Djembe", + "Livingroom", + "Computer", + ], + } + + _has_built_in_disparity_mask = True + + def __init__( + self, + root: Union[str, Path], + split: str = "train", + calibration: Optional[str] = "perfect", + use_ambient_views: bool = False, + transforms: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test", "additional")) + self.split = split + + if calibration: + verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None)) # type: ignore + if split == "test": + raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.") + else: + if split != "test": + raise ValueError( + f"Split '{split}' has calibration settings, however None was provided as an argument." + f"\nSetting calibration to 'perfect' for split '{split}'. 
Available calibration settings are: 'perfect', 'imperfect', 'both'.", + ) + + if download: + self._download_dataset(root) + + root = Path(root) / "Middlebury2014" + + if not os.path.exists(root / split): + raise FileNotFoundError(f"The {split} directory was not found in the provided root directory") + + split_scenes = self.splits[split] + # check that the provided root folder contains the scene splits + if not any( + # using startswith to account for perfect / imperfect calibrartion + scene.startswith(s) + for scene in os.listdir(root / split) + for s in split_scenes + ): + raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.") + + calibrartion_suffixes = { + None: [""], + "perfect": ["-perfect"], + "imperfect": ["-imperfect"], + "both": ["-perfect", "-imperfect"], + }[calibration] + + for calibration_suffix in calibrartion_suffixes: + scene_pattern = "*" + calibration_suffix + left_img_pattern = str(root / split / scene_pattern / "im0.png") + right_img_pattern = str(root / split / scene_pattern / "im1.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "test": + self._disparities = list((None, None) for _ in self._images) + else: + left_dispartity_pattern = str(root / split / scene_pattern / "disp0.pfm") + right_dispartity_pattern = str(root / split / scene_pattern / "disp1.pfm") + self._disparities += self._scan_pairs(left_dispartity_pattern, right_dispartity_pattern) + + self.use_ambient_views = use_ambient_views + + def _read_img(self, file_path: Union[str, Path]) -> Image.Image: + """ + Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True. + When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]`` + as the right image. 
+ """ + ambient_file_paths: List[Union[str, Path]] # make mypy happy + + if not isinstance(file_path, Path): + file_path = Path(file_path) + + if file_path.name == "im1.png" and self.use_ambient_views: + base_path = file_path.parent + # initialize sampleable container + ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"]) + # double check that we're not going to try to read from an invalid file path + ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths)) + # keep the original image as an option as well for uniform sampling between base views + ambient_file_paths.append(file_path) + file_path = random.choice(ambient_file_paths) # type: ignore + return super()._read_img(file_path) + + def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: + # test split has not disparity maps + if file_path is None: + return None, None + + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + disparity_map[disparity_map == np.inf] = 0 # remove infinite disparities + valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities + return disparity_map, valid_mask + + def _download_dataset(self, root: Union[str, Path]) -> None: + base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" + # train and additional splits have 2 different calibration settings + root = Path(root) / "Middlebury2014" + split_name = self.split + + if split_name != "test": + for split_scene in self.splits[split_name]: + split_root = root / split_name + for calibration in ["perfect", "imperfect"]: + scene_name = f"{split_scene}-{calibration}" + scene_url = f"{base_url}/{scene_name}.zip" + print(f"Downloading {scene_url}") + # download the scene only if it doesn't exist + if not (split_root / scene_name).exists(): + download_and_extract_archive( + url=scene_url, + filename=f"{scene_name}.zip", + 
download_root=str(split_root), + remove_finished=True, + ) + else: + os.makedirs(root / "test") + if any(s not in os.listdir(root / "test") for s in self.splits["test"]): + # test split is downloaded from a different location + test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip" + # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF + # we want to move the contents from testF into the directory + download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True) + for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")): + for scene in scene_names: + scene_dst_dir = root / "test" + scene_src_dir = Path(scene_dir) / scene + os.makedirs(scene_dst_dir, exist_ok=True) + shutil.move(str(scene_src_dir), str(scene_dst_dir)) + + # cleanup MiddEval3 directory + shutil.rmtree(str(root / "MiddEval3")) + + def __getitem__(self, index: int) -> T2: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` for `split=test`. + """ + return cast(T2, super().__getitem__(index)) + + +class CREStereo(StereoMatchingDataset): + """Synthetic dataset used in training the `CREStereo `_ architecture. + Dataset details on the official paper `repo `_. + + The dataset is expected to have the following structure: :: + + root + CREStereo + tree + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + img2_left.jpg + img2_right.jpg + img2_left.disp.jpg + img2_right.disp.jpg + ... + shapenet + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + reflective + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... 
+ hole + img1_left.jpg + img1_right.jpg + img1_left.disp.jpg + img1_right.disp.jpg + ... + + Args: + root (str): Root directory of the dataset. + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__( + self, + root: Union[str, Path], + transforms: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms) + + root = Path(root) / "CREStereo" + + dirs = ["shapenet", "reflective", "tree", "hole"] + + for s in dirs: + left_image_pattern = str(root / s / "*_left.jpg") + right_image_pattern = str(root / s / "*_right.jpg") + imgs = self._scan_pairs(left_image_pattern, right_image_pattern) + self._images += imgs + + left_disparity_pattern = str(root / s / "*_left.disp.png") + right_disparity_pattern = str(root / s / "*_right.disp.png") + disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + self._disparities += disparities + + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + # unsqueeze the disparity map into (C, H, W) format + disparity_map = disparity_map[None, :, :] / 32.0 + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> T1: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + """ + return cast(T1, super().__getitem__(index)) + + +class FallingThingsStereo(StereoMatchingDataset): + """`FallingThings `_ dataset. 
+ + The dataset is expected to have the following structure: :: + + root + FallingThings + single + dir1 + scene1 + _object_settings.json + _camera_settings.json + image1.left.depth.png + image1.right.depth.png + image1.left.jpg + image1.right.jpg + image2.left.depth.png + image2.right.depth.png + image2.left.jpg + image2.right + ... + scene2 + ... + mixed + scene1 + _object_settings.json + _camera_settings.json + image1.left.depth.png + image1.right.depth.png + image1.left.jpg + image1.right.jpg + image2.left.depth.png + image2.right.depth.png + image2.left.jpg + image2.right + ... + scene2 + ... + + Args: + root (str or ``pathlib.Path``): Root directory where FallingThings is located. + variant (string): Which variant to use. Either "single", "mixed", or "both". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None: + super().__init__(root, transforms) + + root = Path(root) / "FallingThings" + + verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both")) + + variants = { + "single": ["single"], + "mixed": ["mixed"], + "both": ["single", "mixed"], + }[variant] + + split_prefix = { + "single": Path("*") / "*", + "mixed": Path("*"), + } + + for s in variants: + left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg") + right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png") + right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: + # (H, W) image + depth = np.asarray(Image.open(file_path)) + # as per 
https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt + # in order to extract disparity from depth maps + camera_settings_path = Path(file_path).parent / "_camera_settings.json" + with open(camera_settings_path, "r") as f: + # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant) + intrinsics = json.load(f) + focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"] + baseline, pixel_constant = 6, 100 # pixel constant is inverted + disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32) + # unsqueeze disparity to (C, H, W) + disparity_map = disparity_map[None, :, :] + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> T1: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return cast(T1, super().__getitem__(index)) + + +class SceneFlowStereo(StereoMatchingDataset): + """Dataset interface for `Scene Flow `_ datasets. + This interface provides access to the `FlyingThings3D, `Monkaa` and `Driving` datasets. + + The dataset is expected to have the following structure: :: + + root + SceneFlow + Monkaa + frames_cleanpass + scene1 + left + img1.png + img2.png + right + img1.png + img2.png + scene2 + left + img1.png + img2.png + right + img1.png + img2.png + frames_finalpass + scene1 + left + img1.png + img2.png + right + img1.png + img2.png + ... + ... + disparity + scene1 + left + img1.pfm + img2.pfm + right + img1.pfm + img2.pfm + FlyingThings3D + ... + ... + + Args: + root (str or ``pathlib.Path``): Root directory where SceneFlow is located. 
+ variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving". + pass_name (string): Which pass to use, "clean" (default), "final" or "both". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + + """ + + def __init__( + self, + root: Union[str, Path], + variant: str = "FlyingThings3D", + pass_name: str = "clean", + transforms: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms) + + root = Path(root) / "SceneFlow" + + verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa")) + verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both")) + + passes = { + "clean": ["frames_cleanpass"], + "final": ["frames_finalpass"], + "both": ["frames_cleanpass", "frames_finalpass"], + }[pass_name] + + root = root / variant + + prefix_directories = { + "Monkaa": Path("*"), + "FlyingThings3D": Path("*") / "*" / "*", + "Driving": Path("*") / "*" / "*", + } + + for p in passes: + left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png") + right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png") + self._images += self._scan_pairs(left_image_pattern, right_image_pattern) + + left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm") + right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm") + self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> T1: + """Return example at given index. 
+ + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return cast(T1, super().__getitem__(index)) + + +class SintelStereo(StereoMatchingDataset): + """Sintel `Stereo Dataset `_. + + The dataset is expected to have the following structure: :: + + root + Sintel + training + final_left + scene1 + img1.png + img2.png + ... + ... + final_right + scene2 + img1.png + img2.png + ... + ... + disparities + scene1 + img1.png + img2.png + ... + ... + occlusions + scene1 + img1.png + img2.png + ... + ... + outofframe + scene1 + img1.png + img2.png + ... + ... + + Args: + root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located. + pass_name (string): The name of the pass to use, either "final", "clean" or "both". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + _has_built_in_disparity_mask = True + + def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None: + super().__init__(root, transforms) + + verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) + + root = Path(root) / "Sintel" + pass_names = { + "final": ["final"], + "clean": ["clean"], + "both": ["final", "clean"], + }[pass_name] + + for p in pass_names: + left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png") + right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png") + self._images += self._scan_pairs(left_img_pattern, right_img_pattern) + + disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png") + self._disparities += self._scan_pairs(disparity_pattern, None) + + def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]: + # helper function to get the occlusion mask paths + # a path will look like .../.../.../training/disparities/scene1/img1.png + # we want to get something like .../.../.../training/occlusions/scene1/img1.png + fpath = Path(file_path) + basename = fpath.name + scenedir = fpath.parent + # the parent of the scenedir is actually the disparity dir + sampledir = scenedir.parent.parent + + occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename) + outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename) + + if not os.path.exists(occlusion_path): + raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist") + + if not os.path.exists(outofframe_path): + raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist") + + return occlusion_path, outofframe_path + + def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: + if file_path is None: + return None, None + + # disparity decoding as per Sintel instructions in the README provided with the dataset + disparity_map = 
np.asarray(Image.open(file_path), dtype=np.float32) + r, g, b = np.split(disparity_map, 3, axis=-1) + disparity_map = r * 4 + g / (2**6) + b / (2**14) + # reshape into (C, H, W) format + disparity_map = np.transpose(disparity_map, (2, 0, 1)) + # find the appropriate file paths + occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path) + # occlusion masks + valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0 + # out of frame masks + off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0 + # combine the masks together + valid_mask = np.logical_and(off_mask, valid_mask) + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> T2: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst + the valid_mask is a numpy array of shape (H, W). + """ + return cast(T2, super().__getitem__(index)) + + +class InStereo2k(StereoMatchingDataset): + """`InStereo2k `_ dataset. + + The dataset is expected to have the following structure: :: + + root + InStereo2k + train + scene1 + left.png + right.png + left_disp.png + right_disp.png + ... + scene2 + ... + test + scene1 + left.png + right.png + left_disp.png + right_disp.png + ... + scene2 + ... + + Args: + root (str or ``pathlib.Path``): Root directory where InStereo2k is located. + split (string): Either "train" or "test". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. 
+ """ + + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: + super().__init__(root, transforms) + + root = Path(root) / "InStereo2k" / split + + verify_str_arg(split, "split", valid_values=("train", "test")) + + left_img_pattern = str(root / "*" / "left.png") + right_img_pattern = str(root / "*" / "right.png") + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) + + left_disparity_pattern = str(root / "*" / "left_disp.png") + right_disparity_pattern = str(root / "*" / "right_disp.png") + self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) + + def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]: + disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) + # unsqueeze disparity to (C, H, W) + disparity_map = disparity_map[None, :, :] / 1024.0 + valid_mask = None + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> T1: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 3-tuple with ``(img_left, img_right, disparity)``. + The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + If a ``valid_mask`` is generated within the ``transforms`` parameter, + a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. + """ + return cast(T1, super().__getitem__(index)) + + +class ETH3DStereo(StereoMatchingDataset): + """ETH3D `Low-Res Two-View `_ dataset. + + The dataset is expected to have the following structure: :: + + root + ETH3D + two_view_training + scene1 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + scene2 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + ... + two_view_training_gt + scene1 + disp0GT.pfm + mask0nocc.png + scene2 + disp0GT.pfm + mask0nocc.png + ... 
+ two_view_testing + scene1 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + scene2 + im1.png + im0.png + images.txt + cameras.txt + calib.txt + ... + + Args: + root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset. + split (string, optional): The dataset split of scenes, either "train" (default) or "test". + transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. + """ + + _has_built_in_disparity_mask = True + + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: + super().__init__(root, transforms) + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "ETH3D" + + img_dir = "two_view_training" if split == "train" else "two_view_test" + anot_dir = "two_view_training_gt" + + left_img_pattern = str(root / img_dir / "*" / "im0.png") + right_img_pattern = str(root / img_dir / "*" / "im1.png") + self._images = self._scan_pairs(left_img_pattern, right_img_pattern) + + if split == "test": + self._disparities = list((None, None) for _ in self._images) + else: + disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm") + self._disparities = self._scan_pairs(disparity_pattern, None) + + def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: + # test split has no disparity maps + if file_path is None: + return None, None + + disparity_map = _read_pfm_file(file_path) + disparity_map = np.abs(disparity_map) # ensure that the disparity is positive + mask_path = Path(file_path).parent / "mask0nocc.png" + valid_mask = Image.open(mask_path) + valid_mask = np.asarray(valid_mask).astype(bool) + return disparity_map, valid_mask + + def __getitem__(self, index: int) -> T2: + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``. 
+ The disparity is a numpy array of shape (1, H, W) and the images are PIL images. + ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not + generate a valid mask. + Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. + """ + return cast(T2, super().__getitem__(index)) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4f0fad208c7678b1461c2fba71d599ba65e2bb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py @@ -0,0 +1,242 @@ +import os +import os.path +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +from PIL import Image + +from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + + +class Caltech101(VisionDataset): + """`Caltech 101 `_ Dataset. + + .. warning:: + + This class needs `scipy `_ to load target files from `.mat` format. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``caltech101`` exists or will be saved to if download is set to True. + target_type (string or list, optional): Type of target to use, ``category`` or + ``annotation``. Can also be a list to output a tuple with all specified + target types. ``category`` represents the target class, and + ``annotation`` is a list of points from a hand-generated outline. + Defaults to ``category``. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + .. 
warning:: + + To download the dataset `gdown `_ is required. + """ + + def __init__( + self, + root: Union[str, Path], + target_type: Union[List[str], str] = "category", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform) + os.makedirs(self.root, exist_ok=True) + if isinstance(target_type, str): + target_type = [target_type] + self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type] + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") + + self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) + self.categories.remove("BACKGROUND_Google") # this is not a real class + + # For some reason, the category names in "101_ObjectCategories" and + # "Annotations" do not always match. This is a manual map between the + # two. Defaults to using same name, since most names are fine. + name_map = { + "Faces": "Faces_2", + "Faces_easy": "Faces_3", + "Motorbikes": "Motorbikes_16", + "airplanes": "Airplanes_Side_2", + } + self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) + + self.index: List[int] = [] + self.y = [] + for (i, c) in enumerate(self.categories): + n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c))) + self.index.extend(range(1, n + 1)) + self.y.extend(n * [i]) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where the type of target specified by target_type. 
+ """ + import scipy.io + + img = Image.open( + os.path.join( + self.root, + "101_ObjectCategories", + self.categories[self.y[index]], + f"image_{self.index[index]:04d}.jpg", + ) + ) + + target: Any = [] + for t in self.target_type: + if t == "category": + target.append(self.y[index]) + elif t == "annotation": + data = scipy.io.loadmat( + os.path.join( + self.root, + "Annotations", + self.annotation_categories[self.y[index]], + f"annotation_{self.index[index]:04d}.mat", + ) + ) + target.append(data["obj_contour"]) + target = tuple(target) if len(target) > 1 else target[0] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def _check_integrity(self) -> bool: + # can be more robust and check hash of files + return os.path.exists(os.path.join(self.root, "101_ObjectCategories")) + + def __len__(self) -> int: + return len(self.index) + + def download(self) -> None: + if self._check_integrity(): + print("Files already downloaded and verified") + return + + download_and_extract_archive( + "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp", + self.root, + filename="101_ObjectCategories.tar.gz", + md5="b224c7392d521a49829488ab0f1120d9", + ) + download_and_extract_archive( + "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m", + self.root, + filename="Annotations.tar", + md5="6f83eeb1f24d99cab4eb377263132c91", + ) + + def extra_repr(self) -> str: + return "Target type: {target_type}".format(**self.__dict__) + + +class Caltech256(VisionDataset): + """`Caltech 256 `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``caltech256`` exists or will be saved to if download is set to True. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. 
E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform) + os.makedirs(self.root, exist_ok=True) + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") + + self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) + self.index: List[int] = [] + self.y = [] + for (i, c) in enumerate(self.categories): + n = len( + [ + item + for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c)) + if item.endswith(".jpg") + ] + ) + self.index.extend(range(1, n + 1)) + self.y.extend(n * [i]) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + img = Image.open( + os.path.join( + self.root, + "256_ObjectCategories", + self.categories[self.y[index]], + f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg", + ) + ) + + target = self.y[index] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def _check_integrity(self) -> bool: + # can be more robust and check hash of files + return os.path.exists(os.path.join(self.root, "256_ObjectCategories")) + + def __len__(self) -> int: + return len(self.index) + + def download(self) -> None: + if self._check_integrity(): + print("Files already downloaded and verified") + return + + download_and_extract_archive( + "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK", + self.root, + filename="256_ObjectCategories.tar", + md5="67b4f42ca05d46448c6bb8ecd2220f6d", + ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py new file mode 100644 index 0000000000000000000000000000000000000000..147597d3ab3596106c26b2501d2a6fc4042b2daf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py @@ -0,0 +1,194 @@ +import csv +import os +from collections import namedtuple +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +import PIL +import torch + +from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg +from .vision import VisionDataset + +CSV = namedtuple("CSV", ["header", "index", "data"]) + + +class CelebA(VisionDataset): + """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory where images are downloaded to. + split (string): One of {'train', 'valid', 'test', 'all'}. + Accordingly dataset is selected. 
+ target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, + or ``landmarks``. Can also be a list to output a tuple with all specified target types. + The targets represent: + + - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes + - ``identity`` (int): label for each person (data points with the same identity are the same person) + - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height) + - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, + righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) + + Defaults to ``attr``. If empty, ``None`` will be returned as target. + + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.PILToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + .. warning:: + + To download the dataset `gdown `_ is required. + """ + + base_folder = "celeba" + # There currently does not appear to be an easy way to extract 7z in python (without introducing additional + # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available + # right now. 
+ file_list = [ + # File ID MD5 Hash Filename + ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), + # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), + # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), + ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), + ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), + ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), + ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), + # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), + ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), + ] + + def __init__( + self, + root: Union[str, Path], + split: str = "train", + target_type: Union[List[str], str] = "attr", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.split = split + if isinstance(target_type, list): + self.target_type = target_type + else: + self.target_type = [target_type] + + if not self.target_type and self.target_transform is not None: + raise RuntimeError("target_transform is specified but target_type is empty") + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") + + split_map = { + "train": 0, + "valid": 1, + "test": 2, + "all": None, + } + split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))] + splits = self._load_csv("list_eval_partition.txt") + identity = self._load_csv("identity_CelebA.txt") + bbox = self._load_csv("list_bbox_celeba.txt", header=1) + landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1) + attr = self._load_csv("list_attr_celeba.txt", header=1) + + mask = slice(None) if split_ is None else (splits.data == split_).squeeze() + + if mask == slice(None): # if split == "all" + self.filename = splits.index + else: + self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))] + self.identity = identity.data[mask] + self.bbox = bbox.data[mask] + self.landmarks_align = landmarks_align.data[mask] + self.attr = attr.data[mask] + # map from {-1, 1} to {0, 1} + self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor") + self.attr_names = attr.header + + def _load_csv( + self, + filename: str, + header: Optional[int] = None, + ) -> CSV: + with open(os.path.join(self.root, self.base_folder, filename)) as csv_file: + data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True)) + + if header is not None: + headers = data[header] + data = data[header + 1 :] + else: + headers = [] + + indices = [row[0] for row in data] + data = [row[1:] for row in data] + data_int = [list(map(int, i)) for i in data] + + return CSV(headers, indices, torch.tensor(data_int)) + + def _check_integrity(self) -> bool: + for (_, md5, filename) in self.file_list: + fpath = os.path.join(self.root, self.base_folder, filename) + _, ext = os.path.splitext(filename) + # Allow original archive to be deleted (zip and 7z) + # Only need the extracted images + if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): + return False + + # Should check a hash of the images + return 
os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) + + def download(self) -> None: + if self._check_integrity(): + print("Files already downloaded and verified") + return + + for (file_id, md5, filename) in self.file_list: + download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) + + extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) + + target: Any = [] + for t in self.target_type: + if t == "attr": + target.append(self.attr[index, :]) + elif t == "identity": + target.append(self.identity[index, 0]) + elif t == "bbox": + target.append(self.bbox[index, :]) + elif t == "landmarks": + target.append(self.landmarks_align[index, :]) + else: + # TODO: refactor with utils.verify_str_arg + raise ValueError(f'Target type "{t}" is not recognized.') + + if self.transform is not None: + X = self.transform(X) + + if target: + target = tuple(target) if len(target) > 1 else target[0] + + if self.target_transform is not None: + target = self.target_transform(target) + else: + target = None + + return X, target + + def __len__(self) -> int: + return len(self.attr) + + def extra_repr(self) -> str: + lines = ["Target type: {target_type}", "Split: {split}"] + return "\n".join(lines).format(**self.__dict__) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py new file mode 100644 index 0000000000000000000000000000000000000000..1637670ab91010db4d56a9eafc15a674d0a2eca3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py @@ -0,0 +1,168 @@ +import os.path +import pickle +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union + +import numpy as np +from PIL import Image + +from 
.utils import check_integrity, download_and_extract_archive +from .vision import VisionDataset + + +class CIFAR10(VisionDataset): + """`CIFAR10 `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``cifar-10-batches-py`` exists or will be saved to if download is set to True. + train (bool, optional): If True, creates dataset from training set, otherwise + creates from test set. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + """ + + base_folder = "cifar-10-batches-py" + url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + filename = "cifar-10-python.tar.gz" + tgz_md5 = "c58f30108f718f92721af3b95e74349a" + train_list = [ + ["data_batch_1", "c99cafc152244af753f735de768cd75f"], + ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], + ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], + ["data_batch_4", "634d18415352ddfa80567beed471001a"], + ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], + ] + + test_list = [ + ["test_batch", "40351d587109b95175f43aff81a1287e"], + ] + meta = { + "filename": "batches.meta", + "key": "label_names", + "md5": "5ff9c542aee3614f3951f8cda6e48888", + } + + def __init__( + self, + root: Union[str, Path], + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + + super().__init__(root, transform=transform, target_transform=target_transform) + + self.train = train # training set or test set + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError("Dataset not found or 
corrupted. You can use download=True to download it") + + if self.train: + downloaded_list = self.train_list + else: + downloaded_list = self.test_list + + self.data: Any = [] + self.targets = [] + + # now load the picked numpy arrays + for file_name, checksum in downloaded_list: + file_path = os.path.join(self.root, self.base_folder, file_name) + with open(file_path, "rb") as f: + entry = pickle.load(f, encoding="latin1") + self.data.append(entry["data"]) + if "labels" in entry: + self.targets.extend(entry["labels"]) + else: + self.targets.extend(entry["fine_labels"]) + + self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + self._load_meta() + + def _load_meta(self) -> None: + path = os.path.join(self.root, self.base_folder, self.meta["filename"]) + if not check_integrity(path, self.meta["md5"]): + raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it") + with open(path, "rb") as infile: + data = pickle.load(infile, encoding="latin1") + self.classes = data[self.meta["key"]] + self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + img, target = self.data[index], self.targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + def _check_integrity(self) -> bool: + for filename, md5 in self.train_list + self.test_list: + fpath = os.path.join(self.root, self.base_folder, filename) + if not check_integrity(fpath, md5): + return False + return True + + def download(self) -> None: + if self._check_integrity(): + print("Files already downloaded and verified") + return + download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) + + def extra_repr(self) -> str: + split = "Train" if self.train is True else "Test" + return f"Split: {split}" + + +class CIFAR100(CIFAR10): + """`CIFAR100 `_ Dataset. + + This is a subclass of the `CIFAR10` Dataset. 
+ """ + + base_folder = "cifar-100-python" + url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" + filename = "cifar-100-python.tar.gz" + tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" + train_list = [ + ["train", "16019d7e3df5f24257cddd939b257f8d"], + ] + + test_list = [ + ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], + ] + meta = { + "filename": "meta", + "key": "fine_label_names", + "md5": "7973b15100ade9c7d40fb424638fde48", + } diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..969642553a1d95324b59769cabb5186b274dac42 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py @@ -0,0 +1,222 @@ +import json +import os +from collections import namedtuple +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from PIL import Image + +from .utils import extract_archive, iterable_to_str, verify_str_arg +from .vision import VisionDataset + + +class Cityscapes(VisionDataset): + """`Cityscapes `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit`` + and ``gtFine`` or ``gtCoarse`` are located. + split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine" + otherwise ``train``, ``train_extra`` or ``val`` + mode (string, optional): The quality mode to use, ``fine`` or ``coarse`` + target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon`` + or ``color``. Can also be a list to output a tuple with all specified target types. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. 
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. + + Examples: + + Get semantic segmentation target + + .. code-block:: python + + dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', + target_type='semantic') + + img, smnt = dataset[0] + + Get multiple targets + + .. code-block:: python + + dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', + target_type=['instance', 'color', 'polygon']) + + img, (inst, col, poly) = dataset[0] + + Validate on the "coarse" set + + .. code-block:: python + + dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse', + target_type='semantic') + + img, smnt = dataset[0] + """ + + # Based on https://github.com/mcordts/cityscapesScripts + CityscapesClass = namedtuple( + "CityscapesClass", + ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"], + ) + + classes = [ + CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)), + CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)), + CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)), + CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)), + CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)), + CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)), + CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)), + CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 
156)), + CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)), + CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)), + CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)), + CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)), + CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)), + CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)), + CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)), + CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)), + CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)), + CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)), + CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)), + CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)), + CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)), + CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)), + CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)), + CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)), + CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)), + CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)), + CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)), + CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)), + CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)), + CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)), + ] + + def __init__( + self, + root: Union[str, Path], + split: str = "train", + mode: str = "fine", + target_type: Union[List[str], str] = "instance", + transform: Optional[Callable] = None, + target_transform: 
Optional[Callable] = None, + transforms: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self.mode = "gtFine" if mode == "fine" else "gtCoarse" + self.images_dir = os.path.join(self.root, "leftImg8bit", split) + self.targets_dir = os.path.join(self.root, self.mode, split) + self.target_type = target_type + self.split = split + self.images = [] + self.targets = [] + + verify_str_arg(mode, "mode", ("fine", "coarse")) + if mode == "fine": + valid_modes = ("train", "test", "val") + else: + valid_modes = ("train", "train_extra", "val") + msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}." + msg = msg.format(split, mode, iterable_to_str(valid_modes)) + verify_str_arg(split, "split", valid_modes, msg) + + if not isinstance(target_type, list): + self.target_type = [target_type] + [ + verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) + for value in self.target_type + ] + + if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): + + if split == "train_extra": + image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip") + else: + image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip") + + if self.mode == "gtFine": + target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip") + elif self.mode == "gtCoarse": + target_dir_zip = os.path.join(self.root, f"{self.mode}.zip") + + if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): + extract_archive(from_path=image_dir_zip, to_path=self.root) + extract_archive(from_path=target_dir_zip, to_path=self.root) + else: + raise RuntimeError( + "Dataset not found or incomplete. 
Please make sure all required folders for the" + ' specified "split" and "mode" are inside the "root" directory' + ) + + for city in os.listdir(self.images_dir): + img_dir = os.path.join(self.images_dir, city) + target_dir = os.path.join(self.targets_dir, city) + for file_name in os.listdir(img_dir): + target_types = [] + for t in self.target_type: + target_name = "{}_{}".format( + file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t) + ) + target_types.append(os.path.join(target_dir, target_name)) + + self.images.append(os.path.join(img_dir, file_name)) + self.targets.append(target_types) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is a tuple of all target types if target_type is a list with more + than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation. + """ + + image = Image.open(self.images[index]).convert("RGB") + + targets: Any = [] + for i, t in enumerate(self.target_type): + if t == "polygon": + target = self._load_json(self.targets[index][i]) + else: + target = Image.open(self.targets[index][i]) # type: ignore[assignment] + + targets.append(target) + + target = tuple(targets) if len(targets) > 1 else targets[0] + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + return len(self.images) + + def extra_repr(self) -> str: + lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] + return "\n".join(lines).format(**self.__dict__) + + def _load_json(self, path: str) -> Dict[str, Any]: + with open(path) as file: + data = json.load(file) + return data + + def _get_target_suffix(self, mode: str, target_type: str) -> str: + if target_type == "instance": + return f"{mode}_instanceIds.png" + elif target_type == "semantic": + return f"{mode}_labelIds.png" + elif target_type == "color": + return 
f"{mode}_color.png" + else: + return f"{mode}_polygons.json" diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py new file mode 100644 index 0000000000000000000000000000000000000000..328eb7d79da70c3607b86ded512021f901119d1b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py @@ -0,0 +1,88 @@ +import json +import pathlib +from typing import Any, Callable, List, Optional, Tuple, Union +from urllib.parse import urlparse + +from PIL import Image + +from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + + +class CLEVRClassification(VisionDataset): + """`CLEVR `_ classification dataset. + + The number of objects in a scene is used as the label. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is + set to True. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If + dataset is already downloaded, it is not downloaded again. 
+ """ + + _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip" + _MD5 = "b11922020e72d0cd9154779b2d3d07d2" + + def __init__( + self, + root: Union[str, pathlib.Path], + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + self._split = verify_str_arg(split, "split", ("train", "val", "test")) + super().__init__(root, transform=transform, target_transform=target_transform) + self._base_folder = pathlib.Path(self.root) / "clevr" + self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") + + self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*")) + + self._labels: List[Optional[int]] + if self._split != "test": + with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file: + content = json.load(file) + num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]} + self._labels = [num_objects[image_file.name] for image_file in self._image_files] + else: + self._labels = [None] * len(self._image_files) + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image_file = self._image_files[idx] + label = self._labels[idx] + + image = Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def _check_exists(self) -> bool: + return self._data_folder.exists() and self._data_folder.is_dir() + + def _download(self) -> None: + if self._check_exists(): + return + + download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5) + + def extra_repr(self) -> str: + return 
f"split={self._split}" diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/coco.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b7be798b2457de56416f230f82a49373fbe941 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/coco.py @@ -0,0 +1,109 @@ +import os.path +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +from PIL import Image + +from .vision import VisionDataset + + +class CocoDetection(VisionDataset): + """`MS Coco Detection `_ Dataset. + + It requires the `COCO API to be installed `_. + + Args: + root (str or ``pathlib.Path``): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.PILToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. 
+ """ + + def __init__( + self, + root: Union[str, Path], + annFile: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + from pycocotools.coco import COCO + + self.coco = COCO(annFile) + self.ids = list(sorted(self.coco.imgs.keys())) + + def _load_image(self, id: int) -> Image.Image: + path = self.coco.loadImgs(id)[0]["file_name"] + return Image.open(os.path.join(self.root, path)).convert("RGB") + + def _load_target(self, id: int) -> List[Any]: + return self.coco.loadAnns(self.coco.getAnnIds(id)) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + + if not isinstance(index, int): + raise ValueError(f"Index must be of type integer, got {type(index)} instead.") + + id = self.ids[index] + image = self._load_image(id) + target = self._load_target(id) + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + return len(self.ids) + + +class CocoCaptions(CocoDetection): + """`MS Coco Captions `_ Dataset. + + It requires the `COCO API to be installed `_. + + Args: + root (str or ``pathlib.Path``): Root directory where images are downloaded to. + annFile (string): Path to json annotation file. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.PILToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry + and returns a transformed version. + + Example: + + .. 
code:: python + + import torchvision.datasets as dset + import torchvision.transforms as transforms + cap = dset.CocoCaptions(root = 'dir where images are', + annFile = 'json annotation file', + transform=transforms.PILToTensor()) + + print('Number of samples: ', len(cap)) + img, target = cap[3] # load 4th sample + + print("Image Size: ", img.size()) + print(target) + + Output: :: + + Number of samples: 82783 + Image Size: (3L, 427L, 640L) + [u'A plane emitting smoke stream flying over a mountain.', + u'A plane darts across a bright blue sky behind a mountain covered in snow', + u'A plane leaves a contrail above the snowy mountain top.', + u'A mountain that has a plane flying overheard in the distance.', + u'A mountain view with a plume of smoke in the background'] + + """ + + def _load_target(self, id: int) -> List[str]: + return [ann["caption"] for ann in super()._load_target(id)] diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py new file mode 100644 index 0000000000000000000000000000000000000000..71c556bd201b37b0622df050e8c9dadd5f32f4e0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py @@ -0,0 +1,100 @@ +import os +import pathlib +from typing import Any, Callable, Optional, Tuple, Union + +import PIL.Image + +from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + + +class DTD(VisionDataset): + """`Describable Textures Dataset (DTD) `_. + + Args: + root (str or ``pathlib.Path``): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. + partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``. + + .. note:: + + The partition only changes which split each image belongs to. Thus, regardless of the selected + partition, combining all splits will result in all images. 
+ + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. + """ + + _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" + _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1" + + def __init__( + self, + root: Union[str, pathlib.Path], + split: str = "train", + partition: int = 1, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + self._split = verify_str_arg(split, "split", ("train", "val", "test")) + if not isinstance(partition, int) and not (1 <= partition <= 10): + raise ValueError( + f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, " + f"but got {partition} instead" + ) + self._partition = partition + + super().__init__(root, transform=transform, target_transform=target_transform) + self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower() + self._data_folder = self._base_folder / "dtd" + self._meta_folder = self._data_folder / "labels" + self._images_folder = self._data_folder / "images" + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. 
You can use download=True to download it") + + self._image_files = [] + classes = [] + with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file: + for line in file: + cls, name = line.strip().split("/") + self._image_files.append(self._images_folder.joinpath(cls, name)) + classes.append(cls) + + self.classes = sorted(set(classes)) + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + self._labels = [self.class_to_idx[cls] for cls in classes] + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def extra_repr(self) -> str: + return f"split={self._split}, partition={self._partition}" + + def _check_exists(self) -> bool: + return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder) + + def _download(self) -> None: + if self._check_exists(): + return + download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py new file mode 100644 index 0000000000000000000000000000000000000000..c6571d2abab11d457f0c0b9ff910c7a58efbcf87 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py @@ -0,0 +1,62 @@ +import os +from pathlib import Path +from typing import Callable, Optional, Union + +from .folder import ImageFolder +from .utils import download_and_extract_archive + + +class EuroSAT(ImageFolder): + """RGB version of the `EuroSAT `_ Dataset. + + For the MS version of the dataset, see + `TorchGeo `__. 
+ + Args: + root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. + """ + + def __init__( + self, + root: Union[str, Path], + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + self.root = os.path.expanduser(root) + self._base_folder = os.path.join(self.root, "eurosat") + self._data_folder = os.path.join(self._base_folder, "2750") + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. 
class FakeData(VisionDataset):
    """A fake dataset that returns randomly generated images as PIL images.

    Args:
        size (int, optional): Size of the dataset. Default: 1000 images
        image_size(tuple, optional): Size of the returned images. Default: (3, 224, 224)
        num_classes(int, optional): Number of classes in the dataset. Default: 10
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        random_offset (int): Offsets the index-based random seed used to
            generate each image. Default: 0

    """

    def __init__(
        self,
        size: int = 1000,
        image_size: Tuple[int, int, int] = (3, 224, 224),
        num_classes: int = 10,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        random_offset: int = 0,
    ) -> None:
        super().__init__(transform=transform, target_transform=target_transform)
        self.size = size
        self.num_classes = num_classes
        self.image_size = image_size
        self.random_offset = random_offset

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
            A single-element tensor target is unwrapped to a Python number; anything
            else produced by ``target_transform`` is returned unchanged.

        Raises:
            IndexError: If ``index`` is not smaller than ``len(self)``.
        """
        if index >= len(self):
            raise IndexError(f"{self.__class__.__name__} index out of range")

        # Create a random image that is consistent with the index id: seed with the
        # index, sample, then put the caller's RNG state back. ``finally`` guarantees
        # the restore even when sampling raises (e.g. for an invalid ``image_size``).
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(index + self.random_offset)
            img = torch.randn(*self.image_size)
            target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
        finally:
            torch.set_rng_state(rng_state)

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        # ``target_transform`` may return something that is not a tensor (an int, a
        # one-hot list, ...); only unwrap what ``.item()`` can actually handle.
        if isinstance(target, torch.Tensor) and target.numel() == 1:
            target = target.item()
        return img, target

    def __len__(self) -> int:
        return self.size
note:: + This dataset can return test labels only if ``fer2013.csv`` OR + ``icml_face_data.csv`` are present in ``root/fer2013/``. If only + ``train.csv`` and ``test.csv`` are present, the test labels are set to + ``None``. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``root/fer2013`` exists. This directory may contain either + ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and + ``test.csv``. Precendence is given in that order, i.e. if + ``fer2013.csv`` is present then the rest of the files will be + ignored. All these (combinations of) files contain the same data and + are supported for convenience, but only ``fer2013.csv`` and + ``icml_face_data.csv`` are able to return non-None test labels. + split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + """ + + _RESOURCES = { + "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"), + "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"), + # The fer2013.csv and icml_face_data.csv files contain both train and + # tests instances, and unlike test.csv they contain the labels for the + # test instances. We give these 2 files precedence over train.csv and + # test.csv. 
And yes, they both contain the same data, but with different + # column names (note the spaces) and ordering: + # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv + # ==> fer2013.csv <== + # emotion,pixels,Usage + # + # ==> icml_face_data.csv <== + # emotion, Usage, pixels + # + # ==> train.csv <== + # emotion,pixels + # + # ==> test.csv <== + # pixels + "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"), + "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"), + } + + def __init__( + self, + root: Union[str, pathlib.Path], + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + self._split = verify_str_arg(split, "split", ("train", "test")) + super().__init__(root, transform=transform, target_transform=target_transform) + + base_folder = pathlib.Path(self.root) / "fer2013" + use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists() + use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists() + file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split] + data_file = base_folder / file_name + if not check_integrity(str(data_file), md5=md5): + raise RuntimeError( + f"{file_name} not found in {base_folder} or corrupted. 
" + f"You can download it from " + f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" + ) + + pixels_key = " pixels" if use_icml_file else "pixels" + usage_key = " Usage" if use_icml_file else "Usage" + + def get_img(row): + return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48) + + def get_label(row): + if use_fer_file or use_icml_file or self._split == "train": + return int(row["emotion"]) + else: + return None + + with open(data_file, "r", newline="") as file: + rows = (row for row in csv.DictReader(file)) + + if use_fer_file or use_icml_file: + valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest") + rows = (row for row in rows if row[usage_key] in valid_keys) + + self._samples = [(get_img(row), get_label(row)) for row in rows] + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image_tensor, target = self._samples[idx] + image = Image.fromarray(image_tensor.numpy()) + + if self.transform is not None: + image = self.transform(image) + + if self.target_transform is not None: + target = self.target_transform(target) + + return image, target + + def extra_repr(self) -> str: + return f"split={self._split}" diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf4e970a787556e634f6c2eeb64ed4cd706fa2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union + +import PIL.Image + +from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + + +class FGVCAircraft(VisionDataset): 
+ """`FGVC Aircraft `_ Dataset. + + The dataset contains 10,000 images of aircraft, with 100 images for each of 100 + different aircraft model variants, most of which are airplanes. + Aircraft models are organized in a three-levels hierarchy. The three levels, from + finer to coarser, are: + + - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually + indistinguishable into one class. The dataset comprises 100 different variants. + - ``family``, e.g. Boeing 737. The dataset comprises 70 different families. + - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers. + + Args: + root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset. + split (string, optional): The dataset split, supports ``train``, ``val``, + ``trainval`` and ``test``. + annotation_level (str, optional): The annotation level, supports ``variant``, + ``family`` and ``manufacturer``. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ """ + + _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz" + + def __init__( + self, + root: Union[str, Path], + split: str = "trainval", + annotation_level: str = "variant", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test")) + self._annotation_level = verify_str_arg( + annotation_level, "annotation_level", ("variant", "family", "manufacturer") + ) + + self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b") + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + annotation_file = os.path.join( + self._data_path, + "data", + { + "variant": "variants.txt", + "family": "families.txt", + "manufacturer": "manufacturers.txt", + }[self._annotation_level], + ) + with open(annotation_file, "r") as f: + self.classes = [line.strip() for line in f] + + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + + image_data_folder = os.path.join(self._data_path, "data", "images") + labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt") + + self._image_files = [] + self._labels = [] + + with open(labels_file, "r") as f: + for line in f: + image_name, label_name = line.strip().split(" ", 1) + self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg")) + self._labels.append(self.class_to_idx[label_name]) + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if 
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Only text inside a ``<table>`` element is considered. The text of an ``<a>``
    element selects the current image, and each following ``<li>`` text is stored
    as one caption of that image in :attr:`annotations`.
    """

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Maps the path of an image file under ``root`` to its list of captions
        self.annotations: Dict[str, List[str]] = {}

        # State variables
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        self.current_tag = tag

        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        self.current_tag = None

        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        if not self.in_table:
            return

        if data == "Image Not Found":
            # Captions that follow belong to no image and are dropped.
            self.current_img = None
        elif self.current_tag == "a":
            image_path = self._find_image_file(data)
            self.current_img = image_path
            self.annotations[image_path] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())

    def _find_image_file(self, link_text: str) -> str:
        """Return the first file under ``root`` named ``<id>_*.jpg`` for the linked image.

        The image id is the second to last ``/``-separated component of ``link_text``.

        Raises:
            IndexError: If ``link_text`` has no such component or no file matches.
        """
        image_key = link_text.split("/")[-2]
        # Escape the literal parts so that glob metacharacters ("[", "*", "?") in the
        # root directory or in the id are matched verbatim; only "_*" is a wildcard.
        pattern = os.path.join(glob.escape(os.fspath(self.root)), glob.escape(image_key) + "_*.jpg")
        matches = glob.glob(pattern)
        if not matches:
            # Same exception type as indexing the empty result, with a usable message.
            raise IndexError(f"no image file matching {pattern!r} for link text {link_text!r}")
        return matches[0]
+ transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.PILToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__( + self, + root: str, + ann_file: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.ann_file = os.path.expanduser(ann_file) + + # Read annotations and store in a dict + self.annotations = defaultdict(list) + with open(self.ann_file) as fh: + for line in fh: + img_id, caption = line.strip().split("\t") + self.annotations[img_id[:-2]].append(caption) + + self.ids = list(sorted(self.annotations.keys())) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target). target is a list of captions for the image. 
class Flowers102(VisionDataset):
    """Oxford 102 Flower Dataset.

    .. warning::

        This class needs scipy to load target files from `.mat` format.

    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
    between 40 and 258 images.

    The images have large scale, pose and light variations. In addition, there are categories that
    have large variations within the category, and several very similar categories.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    _file_dict = {  # filename, md5
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Name of the variable inside setid.mat that lists the image ids of each split
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Imported here so that scipy is only required when this dataset is used.
        from scipy.io import loadmat

        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
        image_ids = set_ids[self._splits_map[self._split]].tolist()

        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
        # Image ids start at 1; the stored labels are shifted down by one.
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))

        self._labels = []
        self._image_files = []
        for image_id in image_ids:
            self._labels.append(image_id_to_label[image_id])
            self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_integrity(self) -> bool:
        """Return True if the image folder exists and both ``.mat`` files pass the integrity check."""
        if not (self._images_folder.exists() and self._images_folder.is_dir()):
            return False

        for key in ("label", "setid"):
            filename, md5 = self._file_dict[key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self) -> None:
        """Download the image archive and the two ``.mat`` files unless the data is already in place."""
        if self._check_integrity():
            return
        download_and_extract_archive(
            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
            str(self._base_folder),
            md5=self._file_dict["image"][1],
        )
        for key in ("label", "setid"):
            filename, md5 = self._file_dict[key]
            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
class Food101(VisionDataset):
    """The Food-101 Data Set.

    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.


    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._labels = []
        self._image_files = []
        # Build the file name from the validated split, not from the raw argument.
        with open(self._meta_folder / f"{self._split}.json") as f:
            metadata = json.load(f)

        self.classes = sorted(metadata.keys())
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

        # ``metadata`` maps a class name to "/"-separated image paths without extension.
        for class_label, im_rel_paths in metadata.items():
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            self._image_files += [
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
            ]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True if both the ``meta`` and the ``images`` folders exist."""
        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the data is already present."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
+ transform (callable, optional): A function/transform that takes in a TxHxWxC video + and returns a transformed version. + output_format (str, optional): The format of the output video tensors (before transforms). + Can be either "THWC" (default) or "TCHW". + + Returns: + tuple: A 3-tuple with the following entries: + + - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames + - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels + and `L` is the number of points + - label (int): class of the video clip + """ + + data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar" + splits = { + "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar", + "md5": "15e67781e70dcfbdce2d7dbb9b3344b5", + } + TRAIN_TAG = 1 + TEST_TAG = 2 + + def __init__( + self, + root: Union[str, Path], + annotation_path: str, + frames_per_clip: int, + step_between_clips: int = 1, + frame_rate: Optional[int] = None, + fold: int = 1, + train: bool = True, + transform: Optional[Callable] = None, + _precomputed_metadata: Optional[Dict[str, Any]] = None, + num_workers: int = 1, + _video_width: int = 0, + _video_height: int = 0, + _video_min_dimension: int = 0, + _audio_samples: int = 0, + output_format: str = "THWC", + ) -> None: + super().__init__(root) + if fold not in (1, 2, 3): + raise ValueError(f"fold should be between 1 and 3, got {fold}") + + extensions = ("avi",) + self.classes, class_to_idx = find_classes(self.root) + self.samples = make_dataset( + self.root, + class_to_idx, + extensions, + ) + + video_paths = [path for (path, _) in self.samples] + video_clips = VideoClips( + video_paths, + frames_per_clip, + step_between_clips, + frame_rate, + _precomputed_metadata, + num_workers=num_workers, + _video_width=_video_width, + _video_height=_video_height, + _video_min_dimension=_video_min_dimension, + _audio_samples=_audio_samples, + output_format=output_format, + ) + # we bookkeep the full 
version of video clips because we want to be able + # to return the metadata of full version rather than the subset version of + # video clips + self.full_video_clips = video_clips + self.fold = fold + self.train = train + self.indices = self._select_fold(video_paths, annotation_path, fold, train) + self.video_clips = video_clips.subset(self.indices) + self.transform = transform + + @property + def metadata(self) -> Dict[str, Any]: + return self.full_video_clips.metadata + + def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]: + target_tag = self.TRAIN_TAG if train else self.TEST_TAG + split_pattern_name = f"*test_split{fold}.txt" + split_pattern_path = os.path.join(annotations_dir, split_pattern_name) + annotation_paths = glob.glob(split_pattern_path) + selected_files = set() + for filepath in annotation_paths: + with open(filepath) as fid: + lines = fid.readlines() + for line in lines: + video_filename, tag_string = line.split() + tag = int(tag_string) + if tag == target_tag: + selected_files.add(video_filename) + + indices = [] + for video_index, video_path in enumerate(video_list): + if os.path.basename(video_path) in selected_files: + indices.append(video_index) + + return indices + + def __len__(self) -> int: + return self.video_clips.num_clips() + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: + video, audio, _, video_idx = self.video_clips.get_clip(idx) + sample_index = self.indices[video_idx] + _, class_index = self.samples[sample_index] + + if self.transform is not None: + video = self.transform(video) + + return video, audio, class_index diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/imagenet.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..d7caf328d2bb1adc149548daf83ee49c2a8459c3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/imagenet.py @@ 
class ImageNet(ImageFolder):
    """ImageNet 2012 Classification Dataset.

    .. note::
        The ImageNet 2012 dataset has to be downloaded manually before this class can
        be used. Place ``ILSVRC2012_devkit_t12.tar.gz`` together with
        ``ILSVRC2012_img_train.tar`` or ``ILSVRC2012_img_val.tar`` (depending on
        ``split``) in the root directory.

    Args:
        root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
        expanded_root = os.path.expanduser(root)
        # ``root`` and ``split`` must be set first: ``parse_archives`` and the
        # ``split_folder`` property read them.
        self.root = expanded_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        wnid_to_classes = load_meta_file(self.root)[0]

        # The parent constructor is handed the split folder as its root, so the
        # dataset root is assigned again afterwards.
        super().__init__(self.split_folder, **kwargs)
        self.root = expanded_root

        # The folder names found by the parent are WordNet IDs; keep them under
        # dedicated names and expose the human-readable class names instead.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]

        name_to_idx = {}
        for class_index, class_names in enumerate(self.classes):
            for class_name in class_names:
                name_to_idx[class_name] = class_index
        self.class_to_idx = name_to_idx

    def parse_archives(self) -> None:
        """Create the meta file and the split folder from the archives if they are missing."""
        meta_path = os.path.join(self.root, META_FILE)
        if not check_integrity(meta_path):
            parse_devkit_archive(self.root)

        if os.path.isdir(self.split_folder):
            return

        if self.split == "train":
            parse_train_archive(self.root)
        elif self.split == "val":
            parse_val_archive(self.root)

    @property
    def split_folder(self) -> str:
        """Directory inside ``root`` that holds the images of the selected split."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        return "Split: {split}".format(split=self.split)
def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    The archive is verified, extracted into a temporary directory, and the tuple
    ``(wnid_to_classes, val_wnids)`` is written with ``torch.save`` to ``META_FILE``
    inside ``root``. The temporary directory is removed afterwards.

    Args:
        root (str or ``pathlib.Path``): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'

    Raises:
        RuntimeError: If the archive is not present in ``root`` or fails the MD5 check.
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
        # Read the synset table and keep only the entries without children.
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
        idcs, wnids, classes = list(zip(*meta))[:3]
        # A class entry is a comma-separated list of synonyms.
        classes = [tuple(clss.split(", ")) for clss in classes]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, classes))
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        # One integer class index per line, in validation-image order.
        groundtruth_path = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(groundtruth_path) as txtfh:
            return [int(val_idx) for val_idx in txtfh]

    archive_meta = ARCHIVE_META["devkit"]
    if file is None:
        file = archive_meta[0]
    # The expected checksum is always the one of the official devkit archive,
    # even when a custom file name is given.
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    # TemporaryDirectory removes the extracted devkit on exit, also when parsing fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
class Imagenette(VisionDataset):
    """`Imagenette <https://github.com/fastai/imagenette>`_ image classification dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the Imagenette dataset.
        split (string, optional): The dataset split. Supports ``"train"`` (default), and ``"val"``.
        size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``.
            If the extracted dataset directory already exists, nothing is downloaded.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class name, class index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (WordNet ID, class index).

    Raises:
        RuntimeError: If ``download`` is ``False`` and the dataset directory does not exist.
    """

    # size -> (archive URL, archive MD5)
    _ARCHIVES = {
        "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
        "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
        "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
    }
    # WordNet ID -> tuple of class names
    _WNID_TO_CLASS = {
        "n01440764": ("tench", "Tinca tinca"),
        "n02102040": ("English springer", "English springer spaniel"),
        "n02979186": ("cassette player",),
        "n03000684": ("chain saw", "chainsaw"),
        "n03028079": ("church", "church building"),
        "n03394916": ("French horn", "horn"),
        "n03417042": ("garbage truck", "dustcart"),
        "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
        "n03445777": ("golf ball",),
        "n03888257": ("parachute", "chute"),
    }

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        size: str = "full",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ["train", "val"])
        self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])

        self._url, self._md5 = self._ARCHIVES[self._size]
        # The dataset directory is named after the archive file without its extension.
        self._size_root = Path(self.root) / Path(self._url).stem
        self._image_root = str(self._size_root / self._split)

        if download:
            self._download()
        elif not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        self.wnids, self.wnid_to_idx = find_classes(self._image_root)
        self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
        # Every alias of a class maps to the index of its WordNet ID.
        self.class_to_idx = {
            class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
        }
        self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")

    def _check_exists(self) -> bool:
        """Return ``True`` if the directory for the selected size exists."""
        return self._size_root.exists()

    def _download(self) -> None:
        """Download and extract the archive unless the dataset directory already exists."""
        if self._check_exists():
            # Already available: make ``download=True`` safe to pass repeatedly
            # instead of failing on the second run.
            return

        download_and_extract_archive(self._url, self.root, md5=self._md5)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, label)`` pair at position ``idx`` with the transforms applied."""
        path, label = self._samples[idx]
        image = Image.open(path).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self) -> int:
        return len(self._samples)
videopath: Union[str, Path], line: str) -> None: + download_and_extract_archive(line, tarpath, videopath) + + +class Kinetics(VisionDataset): + """`Generic Kinetics `_ + dataset. + + Kinetics-400/600/700 are action recognition video datasets. + This dataset consider every video as a collection of video clips of fixed size, specified + by ``frames_per_clip``, where the step in frames between each clip is given by + ``step_between_clips``. + + To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5`` + and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two + elements will come from video 1, and the next three elements from video 2. + Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all + frames in a video might be present. + + Args: + root (str or ``pathlib.Path``): Root directory of the Kinetics Dataset. + Directory should be structured as follows: + .. code:: + + root/ + ├── split + │ ├── class1 + │ │ ├── vid1.mp4 + │ │ ├── vid2.mp4 + │ │ ├── vid3.mp4 + │ │ ├── ... + │ ├── class2 + │ │ ├── vidx.mp4 + │ │ └── ... + + Note: split is appended automatically using the split argument. + frames_per_clip (int): number of frames in a clip + num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700 + split (str): split of the dataset to consider; supports ``"train"`` (default) ``"val"`` ``"test"`` + frame_rate (float): If omitted, interpolate different frame rate for each clip. + step_between_clips (int): number of frames between each clip + transform (callable, optional): A function/transform that takes in a TxHxWxC video + and returns a transformed version. + download (bool): Download the official version of the dataset to root folder. + num_workers (int): Use multiple workers for VideoClips creation + num_download_workers (int): Use multiprocessing in order to speed up download. 
def __init__(
    self,
    root: Union[str, Path],
    frames_per_clip: int,
    num_classes: str = "400",
    split: str = "train",
    frame_rate: Optional[int] = None,
    step_between_clips: int = 1,
    transform: Optional[Callable] = None,
    extensions: Tuple[str, ...] = ("avi", "mp4"),
    download: bool = False,
    num_download_workers: int = 1,
    num_workers: int = 1,
    _precomputed_metadata: Optional[Dict[str, Any]] = None,
    _video_width: int = 0,
    _video_height: int = 0,
    _video_min_dimension: int = 0,
    _audio_samples: int = 0,
    _audio_channels: int = 0,
    _legacy: bool = False,
    output_format: str = "TCHW",
) -> None:
    """Validate the arguments, optionally download the videos and index the clips.

    In legacy mode ``root`` itself is used as the video folder, the split is
    recorded as ``"unknown"``, the output format is forced to ``"THWC"`` and
    downloading is rejected with a ``ValueError``.
    """
    # TODO: support test
    self.num_classes = verify_str_arg(num_classes, arg="num_classes", valid_values=["400", "600", "700"])
    self.extensions = extensions
    self.num_download_workers = num_download_workers

    self.root = root
    self._legacy = _legacy

    if not _legacy:
        self.split_folder = path.join(root, split)
        self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"])
    else:
        print("Using legacy structure")
        self.split_folder = root
        self.split = "unknown"
        output_format = "THWC"
        if download:
            raise ValueError("Cannot download the videos using legacy_structure.")

    # Downloading happens before the parent constructor runs and before the
    # split folder is scanned for classes.
    if download:
        self.download_and_process_videos()

    super().__init__(self.root)

    found_classes, class_to_idx = find_classes(self.split_folder)
    self.classes = found_classes
    self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None)

    clip_options = dict(
        num_workers=num_workers,
        _video_width=_video_width,
        _video_height=_video_height,
        _video_min_dimension=_video_min_dimension,
        _audio_samples=_audio_samples,
        _audio_channels=_audio_channels,
        output_format=output_format,
    )
    self.video_clips = VideoClips(
        [video_path for video_path, _ in self.samples],
        frames_per_clip,
        step_between_clips,
        frame_rate,
        _precomputed_metadata,
        **clip_options,
    )
    self.transform = transform
def _download_videos(self) -> None:
    """Download the tarballs containing the videos to the "tars" folder and extract them
    into the _split_ folder, where split is one of the official dataset splits.

    The list of tarball URLs is read from a text file that is stored in the "files"
    folder; it is downloaded first unless it is already present.

    Raises:
        RuntimeError: if the split folder exists, break to prevent downloading the entire dataset again.
    """
    # ``import urllib`` alone does not guarantee that the ``parse`` submodule is loaded.
    import urllib.parse

    if path.exists(self.split_folder):
        raise RuntimeError(
            f"The directory {self.split_folder} already exists. "
            f"If you want to re-download or re-extract the images, delete the directory."
        )
    tar_path = path.join(self.root, "tars")
    file_list_path = path.join(self.root, "files")

    split_url = self._TAR_URLS[self.num_classes].format(split=self.split)
    split_url_filepath = path.join(file_list_path, path.basename(split_url))
    if not check_integrity(split_url_filepath):
        download_url(split_url, file_list_path)
    with open(split_url_filepath) as file:
        # Percent-encode each listed URL, leaving "/", "," and ":" untouched.
        list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()]

    if self.num_download_workers == 1:
        for line in list_video_urls:
            download_and_extract_archive(line, tar_path, self.split_folder)
    else:
        part = partial(_dl_wrap, tar_path, self.split_folder)
        # ``map`` blocks until every download has finished, so the worker
        # processes can be released as soon as the ``with`` block is left.
        with Pool(self.num_download_workers) as poolproc:
            poolproc.map(part, list_video_urls)
class Kitti(VisionDataset):
    """KITTI Dataset.

    It corresponds to the "left color images of object" dataset, for object detection.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                root
                    └── Kitti
                        └─ raw
                            ├── training
                            |   ├── image_2
                            |   └── label_2
                            └── testing
                                └── image_2
        train (bool, optional): Use ``train`` split if true, else ``test`` split.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the expected folders are not found under ``root``.
    """

    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
    resources = [
        "data_object_image_2.zip",
        "data_object_label_2.zip",
    ]
    image_dir_name = "image_2"
    labels_dir_name = "label_2"

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        self.images = []
        self.targets = []
        self.train = train
        self._location = "training" if self.train else "testing"

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found. You may use download=True to download it.")

        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
        if self.train:
            labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
        for img_file in os.listdir(image_dir):
            self.images.append(os.path.join(image_dir, img_file))
            if self.train:
                # The label file shares the image's name up to the first dot.
                self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Get item at a given index.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target), where
            target is ``None`` for the test split and otherwise a list of
            dictionaries with the following keys:

            - type: str
            - truncated: float
            - occluded: int
            - alpha: float
            - bbox: float[4]
            - dimensions: float[3]
            - location: float[3]
            - rotation_y: float

        """
        image = Image.open(self.images[index])
        target = self._parse_target(index) if self.train else None
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, target

    def _parse_target(self, index: int) -> List:
        """Read the space-delimited label file of sample ``index`` into a list of dicts."""
        target = []
        with open(self.targets[index]) as inp:
            content = csv.reader(inp, delimiter=" ")
            for line in content:
                if not line:
                    # csv.reader yields an empty list for a blank line; such rows
                    # hold no object and would fail on the field access below.
                    continue
                target.append(
                    {
                        "type": line[0],
                        "truncated": float(line[1]),
                        "occluded": int(line[2]),
                        "alpha": float(line[3]),
                        "bbox": [float(x) for x in line[4:8]],
                        "dimensions": [float(x) for x in line[8:11]],
                        "location": [float(x) for x in line[11:14]],
                        "rotation_y": float(line[14]),
                    }
                )
        return target

    def __len__(self) -> int:
        return len(self.images)

    @property
    def _raw_folder(self) -> str:
        """Folder ``<root>/<class name>/raw`` that holds the extracted archives."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    def _check_exists(self) -> bool:
        """Check if the data directory exists."""
        folders = [self.image_dir_name]
        if self.train:
            # Labels are only required for the training split.
            folders.append(self.labels_dir_name)
        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)

    def download(self) -> None:
        """Download the KITTI data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self._raw_folder, exist_ok=True)

        # download files
        for fname in self.resources:
            download_and_extract_archive(
                url=f"{self.data_url}{fname}",
                download_root=self._raw_folder,
                filename=fname,
            )
class _LFW(VisionDataset):
    """Common base of the LFW datasets: argument validation, download and integrity checks."""

    base_folder = "lfw-py"
    download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"

    # image_set -> (image directory, archive file name, archive MD5)
    file_dict = {
        "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
        "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
        "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
    }
    # annotation file name -> MD5
    checksums = {
        "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
        "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
        "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
        "people.txt": "450f0863dd89e85e73936a6d71a3474b",
        "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
        "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
        "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
    }
    # split -> infix of the annotation file name
    annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
    names = "lfw-names.txt"

    def __init__(
        self,
        root: Union[str, Path],
        split: str,
        image_set: str,
        view: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        dataset_root = os.path.join(root, self.base_folder)
        super().__init__(dataset_root, transform=transform, target_transform=target_transform)

        self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
        archive_entry = self.file_dict[self.image_set]
        images_dir = archive_entry[0]
        self.filename = archive_entry[1]
        self.md5 = archive_entry[2]

        self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
        self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
        self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
        # Filled by the subclasses once the annotation file has been parsed.
        self.data: List[Any] = []

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.images_dir = os.path.join(self.root, images_dir)

    def _loader(self, path: str) -> Image.Image:
        """Open the image at ``path`` and return it converted to RGB."""
        with open(path, "rb") as stream:
            return Image.open(stream).convert("RGB")

    def _check_integrity(self) -> bool:
        """Verify the archive, the annotation file and, for the people view, the names file."""
        archive_ok = check_integrity(os.path.join(self.root, self.filename), self.md5)
        labels_ok = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
        if not (archive_ok and labels_ok):
            return False
        if self.view != "people":
            return True
        return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])

    def download(self) -> None:
        """Fetch the archive and the annotation files unless they are already present and valid."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        archive_url = f"{self.download_url_prefix}{self.filename}"
        download_and_extract_archive(archive_url, self.root, filename=self.filename, md5=self.md5)
        download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
        if self.view == "people":
            download_url(f"{self.download_url_prefix}{self.names}", self.root)

    def _get_path(self, identity: str, no: Union[int, str]) -> str:
        """Build the path of image number ``no`` of ``identity`` inside the image directory."""
        image_name = f"{identity}_{int(no):04d}.jpg"
        return os.path.join(self.images_dir, identity, image_name)

    def extra_repr(self) -> str:
        return f"Alignment: {self.image_set}\nSplit: {self.split}"

    def __len__(self) -> int:
        return len(self.data)
+ image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or + ``deepfunneled``. Defaults to ``funneled``. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomRotation`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + """ + + def __init__( + self, + root: str, + split: str = "10fold", + image_set: str = "funneled", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, split, image_set, "people", transform, target_transform, download) + + self.class_to_idx = self._get_classes() + self.data, self.targets = self._get_people() + + def _get_people(self) -> Tuple[List[str], List[int]]: + data, targets = [], [] + with open(os.path.join(self.root, self.labels_file)) as f: + lines = f.readlines() + n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0) + + for fold in range(n_folds): + n_lines = int(lines[s]) + people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]] + s += n_lines + 1 + for i, (identity, num_imgs) in enumerate(people): + for num in range(1, int(num_imgs) + 1): + img = self._get_path(identity, num) + data.append(img) + targets.append(self.class_to_idx[identity]) + + return data, targets + + def _get_classes(self) -> Dict[str, int]: + with open(os.path.join(self.root, self.names)) as f: + lines = f.readlines() + names = [line.strip().split()[0] for line in lines] + class_to_idx = {name: i for i, name in enumerate(names)} + return class_to_idx + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, 
target) where target is the identity of the person. + """ + img = self._loader(self.data[index]) + target = self.targets[index] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def extra_repr(self) -> str: + return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}" + + +class LFWPairs(_LFW): + """`LFW `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``lfw-py`` exists or will be saved to if download is set to True. + split (string, optional): The image split to use. Can be one of ``train``, ``test``, + ``10fold``. Defaults to ``10fold``. + image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or + ``deepfunneled``. Defaults to ``funneled``. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomRotation`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
    """Parse the pairs annotation file into identity names, image paths and labels.

    The file ``root/labels_file`` starts with a header line: for the
    ``10fold`` split it is ``n_folds<TAB>n_pairs``, otherwise it is just
    ``n_pairs`` and one fold is read. Each fold consists of ``n_pairs``
    matched rows (``name<TAB>no1<TAB>no2``) followed by ``n_pairs`` unmatched
    rows (``name1<TAB>no1<TAB>name2<TAB>no2``).

    Args:
        images_dir (str): Not used by this method; paths are built through
            ``_get_path``.

    Returns:
        tuple: ``(pair_names, data, targets)`` where ``pair_names`` holds the
        two identities of each pair, ``data`` the two image paths, and
        ``targets`` is ``1`` for matched rows and ``0`` for unmatched rows.
    """
    with open(os.path.join(self.root, self.labels_file)) as handle:
        rows = handle.readlines()

    if self.split == "10fold":
        folds_field, pairs_field = rows[0].split("\t")
        fold_count, pairs_per_half = int(folds_field), int(pairs_field)
    else:
        fold_count, pairs_per_half = 1, int(rows[0])

    names: List[Tuple[str, str]] = []
    paths: List[Tuple[str, str]] = []
    labels: List[int] = []

    cursor = 1  # row 0 is the header
    for _ in range(fold_count):
        middle = cursor + pairs_per_half
        end = cursor + 2 * pairs_per_half
        same_rows, different_rows = rows[cursor:middle], rows[middle:end]
        cursor = end

        for row in same_rows:
            fields = row.strip().split("\t")
            # Both images belong to the identity in the first column.
            first_path = self._get_path(fields[0], fields[1])
            second_path = self._get_path(fields[0], fields[2])
            names.append((fields[0], fields[0]))
            paths.append((first_path, second_path))
            labels.append(1)

        for row in different_rows:
            fields = row.strip().split("\t")
            first_path = self._get_path(fields[0], fields[1])
            second_path = self._get_path(fields[2], fields[3])
            names.append((fields[0], fields[2]))
            paths.append((first_path, second_path))
            labels.append(0)

    return names, paths, labels
class MNIST(VisionDataset):
    """MNIST handwritten-digit dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Base URLs tried in order for every resource.
    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    # (archive name, md5) pairs.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    # File names of the legacy cached tensors inside ``processed_folder``.
    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    @property
    def train_labels(self):
        """Old name for ``targets``; warns on access."""
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        """Old name for ``targets``; warns on access."""
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        """Old name for ``data``; warns on access."""
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        """Old name for ``data``; warns on access."""
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        if self._check_legacy_exist():
            # Legacy cached tensors take precedence; no download or raw-file check happens.
            loaded = self._load_legacy_data()
        else:
            if download:
                self.download()
            if not self._check_exists():
                raise RuntimeError("Dataset not found. You can use download=True to download it")
            loaded = self._load_data()

        self.data, self.targets = loaded

    def _check_legacy_exist(self):
        """Return True when the processed folder exists and both legacy files pass ``check_integrity``."""
        if not os.path.exists(self.processed_folder):
            return False

        for legacy_name in (self.training_file, self.test_file):
            if not check_integrity(os.path.join(self.processed_folder, legacy_name)):
                return False
        return True

    def _load_legacy_data(self):
        """Load the legacy cached file for the selected split with ``torch.load``."""
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        legacy_name = self.training_file if self.train else self.test_file
        legacy_path = os.path.join(self.processed_folder, legacy_name)
        return torch.load(legacy_path, weights_only=True)

    def _load_data(self):
        """Read images and labels for the selected split from the raw folder."""
        prefix = "train" if self.train else "t10k"
        images_path = os.path.join(self.raw_folder, f"{prefix}-images-idx3-ubyte")
        data = read_image_file(images_path)

        labels_path = os.path.join(self.raw_folder, f"{prefix}-labels-idx1-ubyte")
        targets = read_label_file(labels_path)

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img = self.data[index]
        target = int(self.targets[index])

        # Return a PIL Image so the output matches the other datasets.
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        """``<root>/<class name>/raw``."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        """``<root>/<class name>/processed``."""
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        """Mapping from each entry of ``classes`` to its position."""
        return {label: position for position, label in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True when every resource's extracted file passes ``check_integrity`` in the raw folder."""
        for url, _ in self.resources:
            # The extracted file is the archive name without its last extension.
            extracted_name = os.path.splitext(os.path.basename(url))[0]
            if not check_integrity(os.path.join(self.raw_folder, extracted_name)):
                return False
        return True

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # Try every mirror in order for each resource; the first success wins.
        for filename, md5 in self.resources:
            fetched = False
            for mirror in self.mirrors:
                url = f"{mirror}{filename}"
                try:
                    print(f"Downloading {url}")
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print(f"Failed to download (trying next):\n{error}")
                else:
                    fetched = True
                finally:
                    print()
                if fetched:
                    break
            if not fetched:
                raise RuntimeError(f"Error downloading {filename}")

    def extra_repr(self) -> str:
        if self.train is True:
            split = "Train"
        else:
            split = "Test"
        return f"Split: {split}"
+ train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), + ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), + ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), + ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), + ] + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] + + +class KMNIST(MNIST): + """`Kuzushiji-MNIST `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte`` + and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. 
class EMNIST(MNIST):
    """EMNIST dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
            and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    url = "https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip"
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
    # Merged Classes assumes Same structure for both uppercase and lowercase version
    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
    _all_classes = set(string.digits + string.ascii_letters)
    # Class names per split; selected in ``__init__`` from the verified ``split``.
    classes_split_dict = {
        "byclass": sorted(_all_classes),
        "bymerge": sorted(_all_classes - _merged_classes),
        "balanced": sorted(_all_classes - _merged_classes),
        "letters": ["N/A"] + list(string.ascii_lowercase),
        "digits": list(string.digits),
        "mnist": list(string.digits),
    }

    def __init__(self, root: Union[str, Path], split: str, **kwargs: Any) -> None:
        # These attributes are set before the parent initialiser is called.
        self.split = verify_str_arg(split, "split", self.splits)
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super().__init__(root, **kwargs)
        self.classes = self.classes_split_dict[self.split]

    @staticmethod
    def _training_file(split) -> str:
        """File name of the legacy training cache for ``split``."""
        return f"training_{split}.pt"

    @staticmethod
    def _test_file(split) -> str:
        """File name of the legacy test cache for ``split``."""
        return f"test_{split}.pt"

    @property
    def _file_prefix(self) -> str:
        """Common prefix of the raw image/label file names for the chosen split and subset."""
        subset = "train" if self.train else "test"
        return f"emnist-{self.split}-{subset}"

    @property
    def images_file(self) -> str:
        """Path of the raw image file inside ``raw_folder``."""
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        """Path of the raw label file inside ``raw_folder``."""
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        """Read the image and label files of the chosen split."""
        images = read_image_file(self.images_file)
        labels = read_label_file(self.labels_file)
        return images, labels

    def _check_exists(self) -> bool:
        """Return True when both the image and label file pass ``check_integrity``."""
        for path in (self.images_file, self.labels_file):
            if not check_integrity(path):
                return False
        return True

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)

        # Extract each ``.gz`` member of the ``gzip`` folder into the raw folder, then drop the folder.
        gzip_folder = os.path.join(self.raw_folder, "gzip")
        for entry in os.listdir(gzip_folder):
            if not entry.endswith(".gz"):
                continue
            extract_archive(os.path.join(gzip_folder, entry), self.raw_folder)
        shutil.rmtree(gzip_folder)
def __init__(
    self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
) -> None:
    """Select the QMNIST subset and delegate loading to the parent initialiser.

    Args:
        root: Dataset root directory, forwarded to the parent class.
        what: Subset name; must be one of the keys of ``subsets``. When
            ``None`` it is derived from ``train`` (``"train"`` or ``"test"``).
        compat: Stored on the instance as ``compat``.
        train: Used to pick the subset when ``what`` is ``None``; also
            forwarded to the parent class.
        **kwargs: Forwarded to the parent initialiser.
    """
    if what is None:
        if train:
            what = "train"
        else:
            what = "test"

    self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
    self.compat = compat

    # The same file name serves as data, training and test file for the parent class.
    cache_name = what + ".pt"
    self.data_file = cache_name
    self.training_file = cache_name
    self.test_file = cache_name

    super().__init__(root, train, **kwargs)
def _load_data(self):
    """Read and validate the QMNIST images and labels for the selected subset.

    Images are read from ``images_file`` and must be a 3-dimensional
    ``torch.uint8`` tensor. Labels are read from ``labels_file``, converted
    to ``long`` and must be 2-dimensional. For ``what == "test10k"`` only the
    first 10000 entries are kept; for ``what == "test50k"`` everything from
    index 10000 onwards is kept.

    Returns:
        tuple: ``(data, targets)`` tensors.

    Raises:
        TypeError: If the image tensor is not ``torch.uint8``.
        ValueError: If the image tensor is not 3-dimensional or the label
            tensor is not 2-dimensional.
    """
    data = read_sn3_pascalvincent_tensor(self.images_file)
    if data.dtype != torch.uint8:
        raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
    if data.ndimension() != 3:
        # Fixed: this message lacked the ``f`` prefix and printed the placeholder literally.
        raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")

    targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
    if targets.ndimension() != 2:
        raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

    # Both test subsets are slices of the same files; clone so each slice owns its storage.
    if self.what == "test10k":
        data = data[0:10000, :, :].clone()
        targets = targets[0:10000, :].clone()
    elif self.what == "test50k":
        data = data[10000:, :, :].clone()
        targets = targets[10000:, :].clone()

    return data, targets
def get_int(b: bytes) -> int:
    """Interpret ``b`` as an unsigned big-endian integer via its hex representation."""
    return int(b.hex(), 16)


# Maps the type code stored in the file header to the tensor dtype.
SN3_PASCALVINCENT_TYPEMAP = {
    8: torch.uint8,
    9: torch.int8,
    11: torch.int16,
    12: torch.int32,
    13: torch.float32,
    14: torch.float64,
}


def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    The file at ``path`` is opened in binary mode and read completely. The
    first four bytes carry the number of dimensions and the type code, and
    are followed by one 4-byte size per dimension and then the data.

    Args:
        path: Path of the file to read.
        strict: When True, assert that the number of parsed elements equals
            the product of the sizes from the header.

    Returns:
        A tensor viewed with the sizes read from the header.
    """
    # read
    with open(path, "rb") as stream:
        raw = stream.read()

    # parse
    host_is_little = sys.byteorder == "little"
    if host_is_little:
        ty, nd = divmod(get_int(raw[0:4]), 256)
    else:
        nd = get_int(raw[0:1])
        ty = get_int(raw[1:2]) + get_int(raw[2:3]) * 256 + get_int(raw[3:4]) * 256 * 256

    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]

    header_end = 4 * (nd + 1)
    sizes = [get_int(raw[start : start + 4]) for start in range(4, header_end, 4)]

    if sys.byteorder == "big":
        sizes = [
            int.from_bytes(size.to_bytes(4, byteorder="little"), byteorder="big", signed=False) for size in sizes
        ]

    parsed = torch.frombuffer(bytearray(raw), dtype=torch_type, offset=header_end)

    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
    # that is little endian and the dtype has more than one byte, we need to flip them.
    if host_is_little and parsed.element_size() > 1:
        parsed = _flip_byte_order(parsed)

    assert parsed.shape[0] == np.prod(sizes) or not strict
    return parsed.view(*sizes)


def read_label_file(path: str) -> torch.Tensor:
    """Read a label file: must parse to a 1-dimensional ``torch.uint8`` tensor; returned as ``long``."""
    labels = read_sn3_pascalvincent_tensor(path, strict=False)
    if labels.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {labels.dtype}")
    if labels.ndimension() != 1:
        raise ValueError(f"x should have 1 dimension instead of {labels.ndimension()}")
    return labels.long()


def read_image_file(path: str) -> torch.Tensor:
    """Read an image file: must parse to a 3-dimensional ``torch.uint8`` tensor, returned unchanged."""
    images = read_sn3_pascalvincent_tensor(path, strict=False)
    if images.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {images.dtype}")
    if images.ndimension() != 3:
        raise ValueError(f"x should have 3 dimension instead of {images.ndimension()}")
    return images
def __init__(
    self,
    root: Union[str, Path],
    split: Optional[str] = None,
    split_ratio: int = 10,
    download: bool = False,
    transform: Optional[Callable] = None,
) -> None:
    """Validate the arguments, optionally download, and load the frames.

    Args:
        root: Dataset root; the data file lives in ``<root>/<class name>/``.
        split: ``None``, ``"train"`` or ``"test"``.
        split_ratio: Integer in ``[1, 19]``; the index at which the loaded
            array is cut along its first axis for the two splits.
        download: When True, ``download()`` is called before the existence check.
        transform: Forwarded to the parent initialiser.

    Raises:
        TypeError: If ``split_ratio`` is not an ``int``.
        ValueError: If ``split_ratio`` is outside ``[1, 19]``.
        RuntimeError: If the data file is not found.
    """
    super().__init__(root, transform=transform)

    self._base_folder = os.path.join(self.root, self.__class__.__name__)
    # The file name is the last path component of the download URL.
    self._filename = self._URL.rsplit("/", 1)[-1]

    if split is not None:
        verify_str_arg(split, "split", ("train", "test"))
    self.split = split

    if not isinstance(split_ratio, int):
        raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
    if split_ratio < 1 or split_ratio > 19:
        raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
    self.split_ratio = split_ratio

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError("Dataset not found. You can use download=True to download it.")

    frames = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
    # The cut is applied on the first axis, before the first two axes are swapped below.
    if self.split == "train":
        frames = frames[: self.split_ratio]
    elif self.split == "test":
        frames = frames[self.split_ratio :]

    swapped = frames.transpose(0, 1)
    self.data = swapped.unsqueeze(2).contiguous()
class Omniglot(VisionDataset):
    """Omniglot dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.
    """

    folder = "omniglot-py"
    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
    # md5 of each zip, keyed by the target folder name (the zip name without ".zip").
    zips_md5 = {
        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
    }

    def __init__(
        self,
        root: Union[str, Path],
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)
        # Flatten with nested comprehensions: ``sum(lists, [])`` copies the growing result on every step.
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]
        # One inner list per character; each entry pairs an image file name with the character's index.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
            for idx, character in enumerate(self._characters)
        ]
        self._flat_character_images: List[Tuple[str, int]] = [
            entry for images in self._character_images for entry in images
        ]

    def __len__(self) -> int:
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        image = Image.open(image_path, mode="r").convert("L")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Return True when the zip of the selected set passes ``check_integrity`` against its md5."""
        zip_filename = self._get_target_folder()
        return bool(check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]))

    def download(self) -> None:
        """Download and extract the zip of the selected set unless it already verifies."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        filename = self._get_target_folder()
        zip_filename = filename + ".zip"
        url = self.download_url_prefix + "/" + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Name of the folder (and zip stem) for the selected set."""
        return "images_background" if self.background else "images_evaluation"
+ target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or + ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent: + + - ``category`` (int): Label for one of the 37 pet categories. + - ``binary-category`` (int): Binary label for cat or dog. + - ``segmentation`` (PIL image): Segmentation trimap of the image. + + If empty, ``None`` will be returned as target. + + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and puts it into + ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again. + """ + + _RESOURCES = ( + ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"), + ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"), + ) + _VALID_TARGET_TYPES = ("category", "binary-category", "segmentation") + + def __init__( + self, + root: Union[str, pathlib.Path], + split: str = "trainval", + target_types: Union[Sequence[str], str] = "category", + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ): + self._split = verify_str_arg(split, "split", ("trainval", "test")) + if isinstance(target_types, str): + target_types = [target_types] + self._target_types = [ + verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types + ] + + super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform) + self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet" + self._images_folder = 
self._base_folder / "images" + self._anns_folder = self._base_folder / "annotations" + self._segs_folder = self._anns_folder / "trimaps" + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + image_ids = [] + self._labels = [] + self._bin_labels = [] + with open(self._anns_folder / f"{self._split}.txt") as file: + for line in file: + image_id, label, bin_label, _ = line.strip().split() + image_ids.append(image_id) + self._labels.append(int(label) - 1) + self._bin_labels.append(int(bin_label) - 1) + + self.bin_classes = ["Cat", "Dog"] + self.classes = [ + " ".join(part.title() for part in raw_cls.split("_")) + for raw_cls, _ in sorted( + {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)}, + key=lambda image_id_and_label: image_id_and_label[1], + ) + ] + self.bin_class_to_idx = dict(zip(self.bin_classes, range(len(self.bin_classes)))) + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + + self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids] + self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids] + + def __len__(self) -> int: + return len(self._images) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image = Image.open(self._images[idx]).convert("RGB") + + target: Any = [] + for target_type in self._target_types: + if target_type == "category": + target.append(self._labels[idx]) + elif target_type == "binary-category": + target.append(self._bin_labels[idx]) + else: # target_type == "segmentation" + target.append(Image.open(self._segs[idx])) + + if not target: + target = None + elif len(target) == 1: + target = target[0] + else: + target = tuple(target) + + if self.transforms: + image, target = self.transforms(image, target) + + return image, target + + def _check_exists(self) -> bool: + for folder in (self._images_folder, self._anns_folder): + if 
not (os.path.exists(folder) and os.path.isdir(folder)): + return False + else: + return True + + def _download(self) -> None: + if self._check_exists(): + return + + for url, md5 in self._RESOURCES: + download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/pcam.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/pcam.py new file mode 100644 index 0000000000000000000000000000000000000000..8849e0ea39dd1bfb76920bc8afc7f6ced597cda4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/pcam.py @@ -0,0 +1,134 @@ +import pathlib +from typing import Any, Callable, Optional, Tuple, Union + +from PIL import Image + +from .utils import _decompress, download_file_from_google_drive, verify_str_arg +from .vision import VisionDataset + + +class PCAM(VisionDataset): + """`PCAM Dataset `_. + + The PatchCamelyon dataset is a binary classification dataset with 327,680 + color images (96px x 96px), extracted from histopathologic scans of lymph node + sections. Each image is annotated with a binary label indicating presence of + metastatic tissue. + + This dataset requires the ``h5py`` package which you can install with ``pip install h5py``. + + Args: + root (str or ``pathlib.Path``): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If + dataset is already downloaded, it is not downloaded again. + + .. warning:: + + To download the dataset `gdown `_ is required. 
class PCAM(VisionDataset):
    """PCAM Dataset.

    The PatchCamelyon dataset is a binary classification dataset with 327,680
    color images (96px x 96px), extracted from histopathologic scans of lymph node
    sections. Each image is annotated with a binary label indicating presence of
    metastatic tissue.

    This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
            dataset is already downloaded, it is not downloaded again.

            .. warning::

                To download the dataset, the ``gdown`` package is required.

    Raises:
        RuntimeError: If ``h5py`` cannot be imported, or if the split's files are missing after the
            optional download step.
    """

    # {split: {"images" | "targets": (data file name, Google Drive ID, md5 hash)}}
    _FILES = {
        "train": {
            "images": (
                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
                "1571f514728f59376b705fc836ff4b63",  # md5 hash
            ),
            "targets": (
                "camelyonpatch_level_2_split_train_y.h5",
                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                "35c2d7259d906cfc8143347bb8e05be7",
            ),
        },
        "test": {
            "images": (
                "camelyonpatch_level_2_split_test_x.h5",
                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                "d8c2d60d490dbd479f8199bdfa0cf6ec",
            ),
            "targets": (
                "camelyonpatch_level_2_split_test_y.h5",
                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                "60a7035772fbdb7f34eb86d4420cf66a",
            ),
        },
        "val": {
            "images": (
                "camelyonpatch_level_2_split_valid_x.h5",
                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                "d5b63470df7cfa627aeec8b9dc0c066e",
            ),
            "targets": (
                "camelyonpatch_level_2_split_valid_y.h5",
                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                "2b85f58b927af9964a4c15b8f7e8f179",
            ),
        },
    }

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        try:
            import h5py

            # Keep the module on the instance so the other methods need no import of their own.
            self.h5py = h5py
        except ImportError as err:
            # Chain the original error so the real import failure stays visible in the traceback.
            raise RuntimeError(
                "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
            ) from err

        self._split = verify_str_arg(split, "split", ("train", "test", "val"))

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "pcam"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

    def __len__(self) -> int:
        """Return the number of images, read from the first dimension of the ``x`` dataset."""
        images_file = self._FILES[self._split]["images"][0]
        # Open read-only explicitly: the data is never modified here.
        with self.h5py.File(self._base_folder / images_file, "r") as images_data:
            return images_data["x"].shape[0]

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for sample ``idx``; the HDF5 files are reopened on every call."""
        images_file = self._FILES[self._split]["images"][0]
        with self.h5py.File(self._base_folder / images_file, "r") as images_data:
            image = Image.fromarray(images_data["x"][idx]).convert("RGB")

        targets_file = self._FILES[self._split]["targets"][0]
        with self.h5py.File(self._base_folder / targets_file, "r") as targets_data:
            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target

    def _check_exists(self) -> bool:
        """Return True when both the image file and the target file of the split exist."""
        images_file = self._FILES[self._split]["images"][0]
        targets_file = self._FILES[self._split]["targets"][0]
        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))

    def _download(self) -> None:
        """Download and decompress the split's ``.gz`` archives unless the files already exist."""
        if self._check_exists():
            return

        for file_name, file_id, md5 in self._FILES[self._split].values():
            archive_name = file_name + ".gz"
            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
            _decompress(str(self._base_folder / archive_name))
class PhotoTour(VisionDataset):
    """Multi-view Stereo Correspondence Dataset.

    .. note::

        We only provide the newer version of the dataset, since the authors state that it

            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
            patches are centred on real interest point detections, rather than being projections of 3D points as is the
            case in the old dataset.

        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.


    Args:
        root (str or ``pathlib.Path``): Root directory where images are.
        name (string): Name of the dataset to load. Must be one of the keys of ``urls``.
        train (bool, optional): If True (default), items are single patches; otherwise items are
            ``(patch1, patch2, match)`` triples built from the matches file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        KeyError: If ``name`` is not a key of the per-dataset tables (``means``, ``stds``, ...).
    """

    # {name: [url, archive file name, md5]}
    urls = {
        "notredame_harris": [
            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
            "notredame_harris.zip",
            "69f8c90f78e171349abdf0307afefe4d",
        ],
        "yosemite_harris": [
            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
            "yosemite_harris.zip",
            "a73253d1c6fbd3ba2613c45065c00d46",
        ],
        "liberty_harris": [
            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
            "liberty_harris.zip",
            "c731fcfb3abb4091110d0ae8c7ba182c",
        ],
        "notredame": [
            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
            "notredame.zip",
            "509eda8535847b8c0a90bbb210c83484",
        ],
        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
    }
    means = {
        "notredame": 0.4854,
        "yosemite": 0.4844,
        "liberty": 0.4437,
        "notredame_harris": 0.4854,
        "yosemite_harris": 0.4844,
        "liberty_harris": 0.4437,
    }
    stds = {
        "notredame": 0.1864,
        "yosemite": 0.1818,
        "liberty": 0.2019,
        "notredame_harris": 0.1864,
        "yosemite_harris": 0.1818,
        "liberty_harris": 0.2019,
    }
    # Number of patches kept per dataset when the cache is built.
    lens = {
        "notredame": 468159,
        "yosemite": 633587,
        "liberty": 450092,
        "liberty_harris": 379587,
        "yosemite_harris": 450912,
        "notredame_harris": 325295,
    }
    image_ext = "bmp"
    info_file = "info.txt"
    matches_files = "m50_100000_100000_0.txt"

    def __init__(
        self,
        root: Union[str, Path],
        name: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, f"{name}.zip")
        self.data_file = os.path.join(self.root, f"{name}.pt")

        self.train = train
        self.mean = self.means[name]
        self.std = self.stds[name]

        if download:
            self.download()

        # Build the serialized cache from the extracted files on first use.
        if not self._check_datafile_exists():
            self.cache()

        # load the serialized data
        self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """
        Args:
            index (int): Index

        Returns:
            The (optionally transformed) patch when ``train`` is true, otherwise a tuple
            ``(data1, data2, matches)`` for the ``index``-th entry of the matches table.
        """
        if self.train:
            data = self.data[index]
            if self.transform is not None:
                data = self.transform(data)
            return data
        m = self.matches[index]
        data1, data2 = self.data[m[0]], self.data[m[1]]
        if self.transform is not None:
            data1 = self.transform(data1)
            data2 = self.transform(data2)
        return data1, data2, m[2]

    def __len__(self) -> int:
        """Return the number of patches (train) or of match entries (test)."""
        return len(self.data if self.train else self.matches)

    def _check_datafile_exists(self) -> bool:
        """Return True when the serialized ``.pt`` cache exists."""
        return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        """Return True when the extracted data directory exists."""
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        """Download and extract the archive unless the cache or the extracted directory exists."""
        if self._check_datafile_exists():
            print(f"# Found cached data {self.data_file}")
            return

        if not self._check_downloaded():
            # download files
            url = self.urls[self.name][0]
            filename = self.urls[self.name][1]
            md5 = self.urls[self.name][2]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print(f"# Extracting data {self.data_down}\n")

            import zipfile

            with zipfile.ZipFile(fpath, "r") as z:
                z.extractall(self.data_dir)

            # The archive is no longer needed once extracted.
            os.unlink(fpath)

    def cache(self) -> None:
        """Parse the extracted files and save ``(patches, labels, matches)`` to ``data_file``."""
        # process and save as torch files
        print(f"# Caching data {self.data_file}")

        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files),
        )

        with open(self.data_file, "wb") as f:
            torch.save(dataset, f)

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"


def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
    """Return a Tensor containing the patches.

    Every file in ``data_dir`` whose name ends with ``image_ext`` is cut into 64x64 tiles,
    scanning rows first; files are visited in sorted order and only the first ``n`` tiles
    are kept.
    """

    def PIL2array(_img: Image.Image) -> np.ndarray:
        """Convert PIL image type to numpy 2D array"""
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
        """Return a list with the file names of the images containing the patches"""
        # find those files with the specified extension
        files = [
            os.path.join(_data_dir, file_dir) for file_dir in os.listdir(_data_dir) if file_dir.endswith(_image_ext)
        ]
        return sorted(files)  # sort files in ascend order to keep relations

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in list_files:
        # Close each bitmap once its tiles are extracted so file handles do not pile up.
        with Image.open(fpath) as img:
            for y in range(0, img.height, 64):
                for x in range(0, img.width, 64):
                    patch = img.crop((x, y, x + 64, y + 64))
                    patches.append(PIL2array(patch))
    return torch.ByteTensor(np.array(patches[:n]))
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a Tensor containing the list of labels.

    Read the file and keep only the first whitespace-separated field of every line
    (the ID of the 3D point). Blank lines are ignored.

    Args:
        data_dir: Directory containing ``info_file``.
        info_file: Name of the info file inside ``data_dir``.

    Returns:
        A 1-D ``torch.long`` tensor with one entry per non-blank line.
    """
    with open(os.path.join(data_dir, info_file)) as f:
        labels = [int(fields[0]) for fields in map(str.split, f) if fields]
    return torch.tensor(labels, dtype=torch.long)


def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
    """Return a Tensor containing the ground truth matches.

    Each non-blank line yields ``[field 0, field 3, match]`` where ``match`` is 1 when
    fields 1 and 4 (the 3D point IDs) are equal and 0 otherwise. Blank lines are ignored.

    Args:
        data_dir: Directory containing ``matches_file``.
        matches_file: Name of the matches file inside ``data_dir``.

    Returns:
        A ``torch.long`` tensor of shape ``(num_lines, 3)``; ``(0, 3)`` for an empty file.
    """
    matches = []
    with open(os.path.join(data_dir, matches_file)) as f:
        for line in f:
            line_split = line.split()
            if not line_split:
                # A trailing empty line must not abort the whole parse.
                continue
            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
    # reshape keeps the column dimension even when no line was read.
    return torch.tensor(matches, dtype=torch.long).reshape(-1, 3)
class Places365(VisionDataset):
    r"""Places365 classification dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
            ``val``.
        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
            high resolution ones.
        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
            downloaded archives are not downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
        RuntimeError: If ``download is True`` and the image archive is already extracted.
    """
    _SPLITS = ("train-standard", "train-challenge", "val")
    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
    # {variant: (archive, md5)}
    _DEVKIT_META = {
        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
    }
    # (file, md5)
    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
    # {split: (file, md5)}
    _FILE_LIST_META = {
        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
    }
    # {(split, small): (file, md5)}
    _IMAGES_META = {
        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
    }

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train-standard",
        small: bool = False,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.split = self._verify_split(split)
        self.small = small
        self.loader = loader

        self.classes, self.class_to_idx = self.load_categories(download)
        self.imgs, self.targets = self.load_file_list(download)

        if download:
            self.download_images()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load the image at ``index`` with ``loader`` and return ``(image, target)``."""
        file, target = self.imgs[index]
        image = self.loader(file)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.imgs)

    @property
    def variant(self) -> str:
        """``"challenge"`` for the challenge split, ``"standard"`` for every other split."""
        return "challenge" if "challenge" in self.split else "standard"

    @property
    def images_dir(self) -> str:
        """Directory under ``root`` that holds the images of the current split and size."""
        size = "256" if self.small else "large"
        if self.split.startswith("train"):
            dirname = f"data_{size}_{self.variant}"
        else:
            dirname = f"{self.split}_{size}"
        return path.join(self.root, dirname)

    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
        """Read the categories file and return ``(sorted class names, class_to_idx)``.

        The devkit is downloaded first when the file fails the integrity check and
        ``download`` is true; with ``download`` false a ``RuntimeError`` is raised instead.
        """

        def process(line: str) -> Tuple[str, int]:
            cls, idx = line.split()
            return cls, int(idx)

        file, md5 = self._CATEGORIES_META
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        with open(file) as fh:
            # Blank lines carry no entry and would break the two-field unpack.
            class_to_idx = dict(process(line) for line in fh if line.strip())

        return sorted(class_to_idx), class_to_idx

    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
        """Read the split's file list and return ``(imgs, targets)``.

        ``imgs`` is a list of ``(absolute image path, class index)`` and ``targets`` the
        list of class indices in the same order. An empty list file yields two empty lists.
        """

        def process(line: str, sep="/") -> Tuple[str, int]:
            image, idx = line.split()
            # Paths in the list use "/" and may start with one; make them relative and OS-native.
            return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)

        file, md5 = self._FILE_LIST_META[self.split]
        file = path.join(self.root, file)
        if not self._check_integrity(file, md5, download):
            self.download_devkit()

        with open(file) as fh:
            images = [process(line) for line in fh if line.strip()]

        # A comprehension (rather than zip(*images)) also works when the list is empty.
        targets = [idx for _, idx in images]
        return images, targets

    def download_devkit(self) -> None:
        """Download and extract the file-list archive of the current variant into ``root``."""
        file, md5 = self._DEVKIT_META[self.variant]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

    def download_images(self) -> None:
        """Download and extract the image archive for the current split and size.

        Raises:
            RuntimeError: If ``images_dir`` already exists.
        """
        if path.exists(self.images_dir):
            raise RuntimeError(
                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
                f"delete the directory."
            )

        file, md5 = self._IMAGES_META[(self.split, self.small)]
        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)

        if self.split.startswith("train"):
            # Rename the directory named like images_dir minus its "_<variant>" suffix to images_dir.
            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)

    def extra_repr(self) -> str:
        return f"Split: {self.split}\nSmall: {self.small}"

    def _verify_split(self, split: str) -> str:
        return verify_str_arg(split, "split", self._SPLITS)

    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
        """Return the md5 check result; raise instead when it fails and downloading is disabled."""
        integrity = check_integrity(file, md5=md5)
        if not integrity and not download:
            raise RuntimeError(
                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
            )
        return integrity
class RenderedSST2(VisionDataset):
    """The Rendered SST2 Dataset.

    Rendered SST2 is an image classification dataset used to evaluate a model's capability on optical
    character recognition. This dataset was generated by rendering sentences in the Stanford Sentiment
    Treebank v2 dataset.

    This dataset contains two classes (positive and negative) and is divided in three splits: a train
    split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
    (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
    """

    _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
    _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # The validation split is stored in a folder named "valid".
        self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
        self._base_folder = Path(self.root) / "rendered-sst2"
        self.classes = ["negative", "positive"]
        self.class_to_idx = {"negative": 0, "positive": 1}

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        split_dir = self._base_folder / self._split_to_folder[self._split]
        self._samples = make_dataset(str(split_dir), extensions=("png",))

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the RGB image and its label for sample ``idx``, after the optional transforms."""
        image_file, label = self._samples[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True when the split folder has one sub-directory per class name."""
        split_dir = self._base_folder / self._split_to_folder[self._split]
        return all((split_dir / class_label).is_dir() for class_label in self.classes)

    def _download(self) -> None:
        """Download and extract the archive into ``root`` unless the split is already present."""
        if not self._check_exists():
            download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
class SBDataset(VisionDataset):
    """Semantic Boundaries Dataset

    The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.

    .. note ::

        Please note that the train and val splits included with this dataset are different from
        the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
        VOC2012 val.
        If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
        which excludes all val images.

    .. warning::

        This class needs ``scipy`` to load target files from `.mat` format.

    Args:
        root (str or ``pathlib.Path``): Root directory of the Semantic Boundaries Dataset
        image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
            Image set ``train_noval`` excludes VOC 2012 val images.
        mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
            In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
            where `num_classes=20`.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version. Input sample is PIL image and target is a numpy array
            if `mode='boundaries'` or PIL image if `mode='segmentation'`.

    Raises:
        RuntimeError: If ``scipy`` cannot be imported or ``root`` is not a directory.
    """

    url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
    md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
    filename = "benchmark.tgz"

    voc_train_url = "https://www.cs.cornell.edu/~bharathh/train_noval.txt"
    voc_split_filename = "train_noval.txt"
    voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"

    def __init__(
        self,
        root: Union[str, Path],
        image_set: str = "train",
        mode: str = "boundaries",
        download: bool = False,
        transforms: Optional[Callable] = None,
    ) -> None:

        try:
            from scipy.io import loadmat

            self._loadmat = loadmat
        except ImportError as err:
            # Chain the original error so the real import failure stays visible in the traceback.
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err

        super().__init__(root, transforms)
        self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
        self.num_classes = 20

        sbd_root = self.root
        image_dir = os.path.join(sbd_root, "img")
        mask_dir = os.path.join(sbd_root, "cls")

        if download:
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
            extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
            # Move the extracted dataset entries directly under the root directory.
            for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
                old_path = os.path.join(extracted_ds_root, f)
                shutil.move(old_path, sbd_root)
            if self.image_set == "train_noval":
                # Note: this is failing as of June 2024 https://github.com/pytorch/vision/issues/8471
                download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)

        if not os.path.isdir(sbd_root):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Use the verified split name rather than the raw argument.
        split_f = os.path.join(sbd_root, self.image_set + ".txt")

        with open(split_f) as fh:
            # Skip blank lines: they would otherwise turn into bogus ".jpg" / ".mat" paths.
            file_names = [name for name in (line.strip() for line in fh) if name]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]

        # Bind the target loader once so __getitem__ does not branch on the mode.
        self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target

    def _get_segmentation_target(self, filepath: str) -> Image.Image:
        """Load the ``Segmentation`` entry of the ``GTcls`` struct as a PIL image."""
        mat = self._loadmat(filepath)
        return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])

    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
        """Stack the per-class ``Boundaries`` entries of ``GTcls`` along a new first axis."""
        mat = self._loadmat(filepath)
        return np.concatenate(
            [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
            axis=0,
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for sample ``index``; the target type depends on ``mode``."""
        img = Image.open(self.images[index]).convert("RGB")
        target = self._get_target(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self) -> int:
        return len(self.images)

    def extra_repr(self) -> str:
        return f"Image set: {self.image_set}\nMode: {self.mode}"
+ + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``semeion.py`` exists. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + """ + url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data" + filename = "semeion.data" + md5_checksum = "cb545d371d2ce14ec121470795a77432" + + def __init__( + self, + root: Union[str, Path], + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") + + fp = os.path.join(self.root, self.filename) + data = np.loadtxt(fp) + # convert value to 8 bit unsigned integer + # color (white #255) the pixels + self.data = (data[:, :256] * 255).astype("uint8") + self.data = np.reshape(self.data, (-1, 16, 16)) + self.labels = np.nonzero(data[:, 256:])[1] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
def _check_integrity(self) -> bool:
    """Return ``True`` when ``self.filename`` under ``self.root`` passes the MD5 check."""
    data_path = os.path.join(self.root, self.filename)
    return bool(check_integrity(data_path, self.md5_checksum))
+ transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): This parameter exists for backward compatibility but it does not + download the dataset, since the original URL is not available anymore. The dataset + seems to be available on Kaggle so you can try to manually download it using + `these instructions `_. + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + + try: + import scipy.io as sio + except ImportError: + raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") + + super().__init__(root, transform=transform, target_transform=target_transform) + + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = pathlib.Path(root) / "stanford_cars" + devkit = self._base_folder / "devkit" + + if self._split == "train": + self._annotations_mat_path = devkit / "cars_train_annos.mat" + self._images_base_path = self._base_folder / "cars_train" + else: + self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" + self._images_base_path = self._base_folder / "cars_test" + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError( + "Dataset not found. Try to manually download following the instructions in " + "https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616." 
+ ) + + self._samples = [ + ( + str(self._images_base_path / annotation["fname"]), + annotation["class"] - 1, # Original target mapping starts from 1, hence -1 + ) + for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] + ] + + self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() + self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} + + def __len__(self) -> int: + return len(self._samples) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + """Returns pil_image and class_id for given index""" + image_path, target = self._samples[idx] + pil_image = Image.open(image_path).convert("RGB") + + if self.transform is not None: + pil_image = self.transform(pil_image) + if self.target_transform is not None: + target = self.target_transform(target) + return pil_image, target + + def _check_exists(self) -> bool: + if not (self._base_folder / "devkit").is_dir(): + return False + + return self._annotations_mat_path.exists() and self._images_base_path.is_dir() + + def download(self): + raise ValueError( + "The original URL is broken so the StanfordCars dataset is not available for automatic " + "download anymore. You can try to download it manually following " + "https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616, " + "and set download=False to avoid this error." 
+ ) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/stl10.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/stl10.py new file mode 100644 index 0000000000000000000000000000000000000000..90ff41738eb32eaf1489736beb861e58cfc8003c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/stl10.py @@ -0,0 +1,175 @@ +import os.path +from pathlib import Path +from typing import Any, Callable, cast, Optional, Tuple, Union + +import numpy as np +from PIL import Image + +from .utils import check_integrity, download_and_extract_archive, verify_str_arg +from .vision import VisionDataset + + +class STL10(VisionDataset): + """`STL10 `_ Dataset. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset where directory + ``stl10_binary`` exists. + split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}. + Accordingly, dataset is selected. + folds (int, optional): One of {0-9} or None. + For training, loads one of the 10 pre-defined folds of 1k samples for the + standard evaluation procedure. If no value is passed, loads the 5k samples. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ """ + + base_folder = "stl10_binary" + url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" + filename = "stl10_binary.tar.gz" + tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb" + class_names_file = "class_names.txt" + folds_list_file = "fold_indices.txt" + train_list = [ + ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"], + ["train_y.bin", "5a34089d4802c674881badbb80307741"], + ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"], + ] + + test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]] + splits = ("train", "train+unlabeled", "unlabeled", "test") + + def __init__( + self, + root: Union[str, Path], + split: str = "train", + folds: Optional[int] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.split = verify_str_arg(split, "split", self.splits) + self.folds = self._verify_folds(folds) + + if download: + self.download() + elif not self._check_integrity(): + raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") + + # now load the picked numpy arrays + self.labels: Optional[np.ndarray] + if self.split == "train": + self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) + self.labels = cast(np.ndarray, self.labels) + self.__load_folds(folds) + + elif self.split == "train+unlabeled": + self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) + self.labels = cast(np.ndarray, self.labels) + self.__load_folds(folds) + unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) + self.data = np.concatenate((self.data, unlabeled_data)) + self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0]))) + + elif self.split == "unlabeled": + self.data, _ = self.__loadfile(self.train_list[2][0]) + self.labels = np.asarray([-1] * self.data.shape[0]) + else: # self.split == 'test': + self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0]) + + class_file = os.path.join(self.root, self.base_folder, self.class_names_file) + if os.path.isfile(class_file): + with open(class_file) as f: + self.classes = f.read().splitlines() + + def _verify_folds(self, folds: Optional[int]) -> Optional[int]: + if folds is None: + return folds + elif isinstance(folds, int): + if folds in range(10): + return folds + msg = "Value for argument folds should be in the range [0, 10), but got {}." + raise ValueError(msg.format(folds)) + else: + msg = "Expected type None or int for argument folds, but got type {}." + raise ValueError(msg.format(type(folds))) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
def download(self) -> None:
    """Download and extract the archive unless the files are already valid.

    Raises:
        RuntimeError: If the extracted files still fail the integrity check
            after the download. Previously the result of that check was
            discarded, so a broken extraction only surfaced later when the
            binary files were read.
    """
    if self._check_integrity():
        print("Files already downloaded and verified")
        return
    download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
    if not self._check_integrity():
        raise RuntimeError("Dataset not found or corrupted after download and extraction.")
f.read().splitlines()[folds] + list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ") + self.data = self.data[list_idx, :, :, :] + if self.labels is not None: + self.labels = self.labels[list_idx] diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/sun397.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/sun397.py new file mode 100644 index 0000000000000000000000000000000000000000..4db0a3cf237376f9a59e06f770f90554abfd87e2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/sun397.py @@ -0,0 +1,76 @@ +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union + +import PIL.Image + +from .utils import download_and_extract_archive +from .vision import VisionDataset + + +class SUN397(VisionDataset): + """`The SUN397 Data Set `_. + + The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of + 397 categories with 108'754 images. + + Args: + root (str or ``pathlib.Path``): Root directory of the dataset. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. 
+ """ + + _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" + _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" + + def __init__( + self, + root: Union[str, Path], + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._data_dir = Path(self.root) / "SUN397" + + if download: + self._download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + with open(self._data_dir / "ClassName.txt") as f: + self.classes = [c[3:].strip() for c in f] + + self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) + self._image_files = list(self._data_dir.rglob("sun_*.jpg")) + + self._labels = [ + self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files + ] + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx: int) -> Tuple[Any, Any]: + image_file, label = self._image_files[idx], self._labels[idx] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def _check_exists(self) -> bool: + return self._data_dir.is_dir() + + def _download(self) -> None: + if self._check_exists(): + return + download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/ucf101.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/ucf101.py new file mode 100644 index 0000000000000000000000000000000000000000..935f8ad41c78e6086444b72ed1f5ae1f6de34bad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/ucf101.py @@ -0,0 +1,131 @@ +import os +from pathlib import Path +from typing import 
Any, Callable, Dict, List, Optional, Tuple, Union + +from torch import Tensor + +from .folder import find_classes, make_dataset +from .video_utils import VideoClips +from .vision import VisionDataset + + +class UCF101(VisionDataset): + """ + `UCF101 `_ dataset. + + UCF101 is an action recognition video dataset. + This dataset consider every video as a collection of video clips of fixed size, specified + by ``frames_per_clip``, where the step in frames between each clip is given by + ``step_between_clips``. The dataset itself can be downloaded from the dataset website; + annotations that ``annotation_path`` should be pointing to can be downloaded from `here + `_. + + To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5`` + and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two + elements will come from video 1, and the next three elements from video 2. + Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all + frames in a video might be present. + + Internally, it uses a VideoClips object to handle clip creation. + + Args: + root (str or ``pathlib.Path``): Root directory of the UCF101 Dataset. + annotation_path (str): path to the folder containing the split files; + see docstring above for download instructions of these files + frames_per_clip (int): number of frames in a clip. + step_between_clips (int, optional): number of frames between each clip. + fold (int, optional): which fold to use. Should be between 1 and 3. + train (bool, optional): if ``True``, creates a dataset from the train split, + otherwise from the ``test`` split. + transform (callable, optional): A function/transform that takes in a TxHxWxC video + and returns a transformed version. + output_format (str, optional): The format of the output video tensors (before transforms). + Can be either "THWC" (default) or "TCHW". 
+ + Returns: + tuple: A 3-tuple with the following entries: + + - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames + - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels + and `L` is the number of points + - label (int): class of the video clip + """ + + def __init__( + self, + root: Union[str, Path], + annotation_path: str, + frames_per_clip: int, + step_between_clips: int = 1, + frame_rate: Optional[int] = None, + fold: int = 1, + train: bool = True, + transform: Optional[Callable] = None, + _precomputed_metadata: Optional[Dict[str, Any]] = None, + num_workers: int = 1, + _video_width: int = 0, + _video_height: int = 0, + _video_min_dimension: int = 0, + _audio_samples: int = 0, + output_format: str = "THWC", + ) -> None: + super().__init__(root) + if not 1 <= fold <= 3: + raise ValueError(f"fold should be between 1 and 3, got {fold}") + + extensions = ("avi",) + self.fold = fold + self.train = train + + self.classes, class_to_idx = find_classes(self.root) + self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) + video_list = [x[0] for x in self.samples] + video_clips = VideoClips( + video_list, + frames_per_clip, + step_between_clips, + frame_rate, + _precomputed_metadata, + num_workers=num_workers, + _video_width=_video_width, + _video_height=_video_height, + _video_min_dimension=_video_min_dimension, + _audio_samples=_audio_samples, + output_format=output_format, + ) + # we bookkeep the full version of video clips because we want to be able + # to return the metadata of full version rather than the subset version of + # video clips + self.full_video_clips = video_clips + self.indices = self._select_fold(video_list, annotation_path, fold, train) + self.video_clips = video_clips.subset(self.indices) + self.transform = transform + + @property + def metadata(self) -> Dict[str, Any]: + return self.full_video_clips.metadata + + def _select_fold(self, video_list: List[str], annotation_path: 
str, fold: int, train: bool) -> List[int]: + name = "train" if train else "test" + name = f"{name}list{fold:02d}.txt" + f = os.path.join(annotation_path, name) + selected_files = set() + with open(f) as fid: + data = fid.readlines() + data = [x.strip().split(" ")[0] for x in data] + data = [os.path.join(self.root, *x.split("/")) for x in data] + selected_files.update(data) + indices = [i for i in range(len(video_list)) if video_list[i] in selected_files] + return indices + + def __len__(self) -> int: + return self.video_clips.num_clips() + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: + video, audio, info, video_idx = self.video_clips.get_clip(idx) + label = self.samples[self.indices[video_idx]][1] + + if self.transform is not None: + video = self.transform(video) + + return video, audio, label diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/usps.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/usps.py new file mode 100644 index 0000000000000000000000000000000000000000..9c681e79f6c3dc8fb6567f90b741ac7839298931 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/usps.py @@ -0,0 +1,96 @@ +import os +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union + +import numpy as np +from PIL import Image + +from .utils import download_url +from .vision import VisionDataset + + +class USPS(VisionDataset): + """`USPS `_ Dataset. + The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``. + The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]`` + and make pixel values in ``[0, 255]``. + + Args: + root (str or ``pathlib.Path``): Root directory of dataset to store``USPS`` data files. + train (bool, optional): If True, creates dataset from ``usps.bz2``, + otherwise from ``usps.t.bz2``. 
+ transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + """ + + split_list = { + "train": [ + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2", + "usps.bz2", + "ec16c51db3855ca6c91edd34d0e9b197", + ], + "test": [ + "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2", + "usps.t.bz2", + "8ea070ee2aca1ac39742fdd1ef5ed118", + ], + } + + def __init__( + self, + root: Union[str, Path], + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + split = "train" if train else "test" + url, filename, checksum = self.split_list[split] + full_path = os.path.join(self.root, filename) + + if download and not os.path.exists(full_path): + download_url(url, self.root, filename, md5=checksum) + + import bz2 + + with bz2.open(full_path) as fp: + raw_data = [line.decode().split() for line in fp.readlines()] + tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data] + imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) + imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) + targets = [int(d[0]) - 1 for d in raw_data] + + self.data = imgs + self.targets = targets + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + img, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img, mode="L") + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/utils.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f65eb53545931cca890f2e9f49b6082b284d3e3d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/utils.py @@ -0,0 +1,476 @@ +import bz2 +import gzip +import hashlib +import lzma +import os +import os.path +import pathlib +import re +import sys +import tarfile +import urllib +import urllib.error +import urllib.request +import zipfile +from typing import Any, Callable, Dict, IO, Iterable, List, Optional, Tuple, TypeVar, Union +from urllib.parse import urlparse + +import numpy as np +import torch +from torch.utils.model_zoo import tqdm + +from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available + +USER_AGENT = "pytorch/vision" + + +def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar: + while chunk := response.read(chunk_size): + fh.write(chunk) + pbar.update(len(chunk)) + + +def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str: + # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are + # not using the MD5 checksum for cryptography. 
This enables its usage in restricted environments like FIPS. Without + # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. + if sys.version_info >= (3, 9): + md5 = hashlib.md5(usedforsecurity=False) + else: + md5 = hashlib.md5() + with open(fpath, "rb") as f: + while chunk := f.read(chunk_size): + md5.update(chunk) + return md5.hexdigest() + + +def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool: + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool: + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def _get_redirect_url(url: str, max_hops: int = 3) -> str: + initial_url = url + headers = {"Method": "HEAD", "User-Agent": USER_AGENT} + + for _ in range(max_hops + 1): + with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: + if response.url == url or response.url is None: + return url + + url = response.url + else: + raise RecursionError( + f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}." + ) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def download_url( + url: str, + root: Union[str, pathlib.Path], + filename: Optional[Union[str, pathlib.Path]] = None, + md5: Optional[str] = None, + max_redirect_hops: int = 3, +) -> None: + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the basename of the URL + md5 (str, optional): MD5 checksum of the download. 
If None, do not check + max_redirect_hops (int, optional): Maximum number of redirect hops allowed + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.fspath(os.path.join(root, filename)) + + os.makedirs(root, exist_ok=True) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + if _is_remote_location_available(): + _download_file_from_remote_location(fpath, url) + else: + # expand redirect chain if needed + url = _get_redirect_url(url, max_hops=max_redirect_hops) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead. 
def list_files(
    root: Union[str, pathlib.Path], suffix: Union[str, Tuple[str, ...]], prefix: bool = False
) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It is passed directly to the Python "str.endswith" method, which is why a
            tuple of alternatives is accepted as well as a single string
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found

    Returns:
        The matching file names (or paths when ``prefix`` is true), in the order
        reported by ``os.listdir``. Directories are never included.
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    # Only the literal ``True`` enables prefixing; kept as-is for callers
    # that rely on it.
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files
+ md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + try: + import gdown + except ModuleNotFoundError: + raise RuntimeError( + "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'." + ) + + root = os.path.expanduser(root) + if not filename: + filename = file_id + fpath = os.fspath(os.path.join(root, filename)) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}") + return + + gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT) + + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def _extract_tar( + from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str] +) -> None: + with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: + tar.extractall(to_path) + + +_ZIP_COMPRESSION_MAP: Dict[str, int] = { + ".bz2": zipfile.ZIP_BZIP2, + ".xz": zipfile.ZIP_LZMA, +} + + +def _extract_zip( + from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str] +) -> None: + with zipfile.ZipFile( + from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED + ) as zip: + zip.extractall(to_path) + + +_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = { + ".tar": _extract_tar, + ".zip": _extract_zip, +} +_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = { + ".bz2": bz2.open, + ".gz": gzip.open, + ".xz": lzma.open, +} +_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = { + ".tbz": (".tar", ".bz2"), + ".tbz2": (".tar", ".bz2"), + ".tgz": (".tar", ".gz"), +} + + +def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]: + """Detect the archive type and/or compression of a file. 
def _decompress(
    from_path: Union[str, pathlib.Path],
    to_path: Optional[Union[str, pathlib.Path]] = None,
    remove_finished: bool = False,
) -> pathlib.Path:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (pathlib.Path): Path to the decompressed file.

    Raises:
        RuntimeError: If no compression can be detected from the suffix.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        source = os.fspath(from_path)
        replacement = archive_type if archive_type is not None else ""
        # Swap only the LAST occurrence of the suffix. A plain str.replace
        # would also rewrite earlier occurrences, e.g. a parent directory
        # named "dump.gz" in "dump.gz/data.gz".
        head, found, tail = source.rpartition(suffix)
        to_path = pathlib.Path(head + replacement + tail if found else source)

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        # Copy in bounded chunks so large files are never held in memory whole.
        while chunk := rfh.read(1024 * 1024):
            wfh.write(chunk)

    if remove_finished:
        os.remove(from_path)

    return pathlib.Path(to_path)
def iterable_to_str(iterable: Iterable) -> str:
    """Format the items of ``iterable`` as a comma-separated list of single-quoted strings.

    Every item is converted with ``str``. An empty iterable yields ``''`` (two quote characters).
    """
    joined = "', '".join(map(str, iterable))
    return f"'{joined}'"
def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.

    Returns:
        (np.ndarray): ``float32`` array laid out as (channels, height, width), holding at most
        ``slice_channels`` channels and flipped along the height axis.

    Raises:
        ValueError: If the magic header is neither ``PF`` nor ``Pf``, or the dimension line is malformed.
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            # ValueError matches the header check above; it is still an ``Exception``,
            # so callers that caught the previous generic exception keep working.
            raise ValueError("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        # Only the sign of the scale line is used here: negative means little-endian data.
        scale = float(f.readline().rstrip())
        endian = "<" if scale < 0 else ">"

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)
def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
    """
    Dilated counterpart of ``tensor.unfold`` restricted to 1d tensors.

    Produces a strided view with one row per window: each window has
    `size` elements spaced `dilation` apart, and consecutive windows
    start `step` elements apart. If the tensor is too short for a single
    window, the result has shape ``(0, size)``.
    """
    ndim = tensor.dim()
    if ndim != 1:
        raise ValueError(f"tensor should have 1 dimension instead of {ndim}")

    base_stride = tensor.stride(0)
    # number of source elements spanned by a single (dilated) window
    window_span = dilation * (size - 1) + 1
    num_windows = (tensor.numel() - window_span) // step + 1
    if num_windows < 1:
        num_windows = 0

    view_size = (num_windows, size)
    view_stride = (step * base_stride, dilation * base_stride)
    return torch.as_strided(tensor, view_size, view_stride)
+ """ + + def __init__(self, video_paths: List[str]) -> None: + self.video_paths = video_paths + + def __len__(self) -> int: + return len(self.video_paths) + + def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]: + return read_video_timestamps(self.video_paths[idx]) + + +def _collate_fn(x: T) -> T: + """ + Dummy collate function to be used with _VideoTimestampsDataset + """ + return x + + +class VideoClips: + """ + Given a list of video files, computes all consecutive subvideos of size + `clip_length_in_frames`, where the distance between each subvideo in the + same video is defined by `frames_between_clips`. + If `frame_rate` is specified, it will also resample all the videos to have + the same frame rate, and the clips will refer to this frame rate. + + Creating this instance the first time is time-consuming, as it needs to + decode all the videos in `video_paths`. It is recommended that you + cache the results after instantiation of the class. + + Recreating the clips for different clip lengths is fast, and can be done + with the `compute_clips` method. + + Args: + video_paths (List[str]): paths to the video files + clip_length_in_frames (int): size of a clip in number of frames + frames_between_clips (int): step (in frames) between each clip + frame_rate (float, optional): if specified, it will resample the video + so that it has `frame_rate`, and then the clips will be defined + on the resampled video + num_workers (int): how many subprocesses to use for data loading. + 0 means that the data will be loaded in the main process. (default: 0) + output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW". 
+ """ + + def __init__( + self, + video_paths: List[str], + clip_length_in_frames: int = 16, + frames_between_clips: int = 1, + frame_rate: Optional[float] = None, + _precomputed_metadata: Optional[Dict[str, Any]] = None, + num_workers: int = 0, + _video_width: int = 0, + _video_height: int = 0, + _video_min_dimension: int = 0, + _video_max_dimension: int = 0, + _audio_samples: int = 0, + _audio_channels: int = 0, + output_format: str = "THWC", + ) -> None: + + self.video_paths = video_paths + self.num_workers = num_workers + + # these options are not valid for pyav backend + self._video_width = _video_width + self._video_height = _video_height + self._video_min_dimension = _video_min_dimension + self._video_max_dimension = _video_max_dimension + self._audio_samples = _audio_samples + self._audio_channels = _audio_channels + self.output_format = output_format.upper() + if self.output_format not in ("THWC", "TCHW"): + raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.") + + if _precomputed_metadata is None: + self._compute_frame_pts() + else: + self._init_from_metadata(_precomputed_metadata) + self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate) + + def _compute_frame_pts(self) -> None: + self.video_pts = [] # len = num_videos. 
Each entry is a tensor of shape (num_frames_in_video,) + self.video_fps: List[float] = [] # len = num_videos + + # strategy: use a DataLoader to parallelize read_video_timestamps + # so need to create a dummy dataset first + import torch.utils.data + + dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader( + _VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type] + batch_size=16, + num_workers=self.num_workers, + collate_fn=_collate_fn, + ) + + with tqdm(total=len(dl)) as pbar: + for batch in dl: + pbar.update(1) + batch_pts, batch_fps = list(zip(*batch)) + # we need to specify dtype=torch.long because for empty list, + # torch.as_tensor will use torch.float as default dtype. This + # happens when decoding fails and no pts is returned in the list. + batch_pts = [torch.as_tensor(pts, dtype=torch.long) for pts in batch_pts] + self.video_pts.extend(batch_pts) + self.video_fps.extend(batch_fps) + + def _init_from_metadata(self, metadata: Dict[str, Any]) -> None: + self.video_paths = metadata["video_paths"] + assert len(self.video_paths) == len(metadata["video_pts"]) + self.video_pts = metadata["video_pts"] + assert len(self.video_paths) == len(metadata["video_fps"]) + self.video_fps = metadata["video_fps"] + + @property + def metadata(self) -> Dict[str, Any]: + _metadata = { + "video_paths": self.video_paths, + "video_pts": self.video_pts, + "video_fps": self.video_fps, + } + return _metadata + + def subset(self, indices: List[int]) -> "VideoClips": + video_paths = [self.video_paths[i] for i in indices] + video_pts = [self.video_pts[i] for i in indices] + video_fps = [self.video_fps[i] for i in indices] + metadata = { + "video_paths": video_paths, + "video_pts": video_pts, + "video_fps": video_fps, + } + return type(self)( + video_paths, + clip_length_in_frames=self.num_frames, + frames_between_clips=self.step, + frame_rate=self.frame_rate, + _precomputed_metadata=metadata, + num_workers=self.num_workers, + _video_width=self._video_width, + 
_video_height=self._video_height, + _video_min_dimension=self._video_min_dimension, + _video_max_dimension=self._video_max_dimension, + _audio_samples=self._audio_samples, + _audio_channels=self._audio_channels, + output_format=self.output_format, + ) + + @staticmethod + def compute_clips_for_video( + video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None + ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]: + if fps is None: + # if for some reason the video doesn't have fps (because doesn't have a video stream) + # set the fps to 1. The value doesn't matter, because video_pts is empty anyway + fps = 1 + if frame_rate is None: + frame_rate = fps + total_frames = len(video_pts) * frame_rate / fps + _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate) + video_pts = video_pts[_idxs] + clips = unfold(video_pts, num_frames, step) + if not clips.numel(): + warnings.warn( + "There aren't enough frames in the current video to get a clip for the given clip length and " + "frames between clips. The video (and potentially others) will be skipped." + ) + idxs: Union[List[slice], torch.Tensor] + if isinstance(_idxs, slice): + idxs = [_idxs] * len(clips) + else: + idxs = unfold(_idxs, num_frames, step) + return clips, idxs + + def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None: + """ + Compute all consecutive sequences of clips from video_pts. + Always returns clips of size `num_frames`, meaning that the + last few frames in a video can potentially be dropped. 
+ + Args: + num_frames (int): number of frames for the clip + step (int): distance between two clips + frame_rate (int, optional): The frame rate + """ + self.num_frames = num_frames + self.step = step + self.frame_rate = frame_rate + self.clips = [] + self.resampling_idxs = [] + for video_pts, fps in zip(self.video_pts, self.video_fps): + clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate) + self.clips.append(clips) + self.resampling_idxs.append(idxs) + clip_lengths = torch.as_tensor([len(v) for v in self.clips]) + self.cumulative_sizes = clip_lengths.cumsum(0).tolist() + + def __len__(self) -> int: + return self.num_clips() + + def num_videos(self) -> int: + return len(self.video_paths) + + def num_clips(self) -> int: + """ + Number of subclips that are available in the video list. + """ + return self.cumulative_sizes[-1] + + def get_clip_location(self, idx: int) -> Tuple[int, int]: + """ + Converts a flattened representation of the indices into a video_idx, clip_idx + representation. + """ + video_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if video_idx == 0: + clip_idx = idx + else: + clip_idx = idx - self.cumulative_sizes[video_idx - 1] + return video_idx, clip_idx + + @staticmethod + def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]: + step = original_fps / new_fps + if step.is_integer(): + # optimization: if step is integer, don't need to perform + # advanced indexing + step = int(step) + return slice(None, None, step) + idxs = torch.arange(num_frames, dtype=torch.float32) * step + idxs = idxs.floor().to(torch.int64) + return idxs + + def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]: + """ + Gets a subclip from a list of videos. + + Args: + idx (int): index of the subclip. Must be between 0 and num_clips(). 
+ + Returns: + video (Tensor) + audio (Tensor) + info (Dict) + video_idx (int): index of the video in `video_paths` + """ + if idx >= self.num_clips(): + raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)") + video_idx, clip_idx = self.get_clip_location(idx) + video_path = self.video_paths[video_idx] + clip_pts = self.clips[video_idx][clip_idx] + + from torchvision import get_video_backend + + backend = get_video_backend() + + if backend == "pyav": + # check for invalid options + if self._video_width != 0: + raise ValueError("pyav backend doesn't support _video_width != 0") + if self._video_height != 0: + raise ValueError("pyav backend doesn't support _video_height != 0") + if self._video_min_dimension != 0: + raise ValueError("pyav backend doesn't support _video_min_dimension != 0") + if self._video_max_dimension != 0: + raise ValueError("pyav backend doesn't support _video_max_dimension != 0") + if self._audio_samples != 0: + raise ValueError("pyav backend doesn't support _audio_samples != 0") + + if backend == "pyav": + start_pts = clip_pts[0].item() + end_pts = clip_pts[-1].item() + video, audio, info = read_video(video_path, start_pts, end_pts) + else: + _info = _probe_video_from_file(video_path) + video_fps = _info.video_fps + audio_fps = None + + video_start_pts = cast(int, clip_pts[0].item()) + video_end_pts = cast(int, clip_pts[-1].item()) + + audio_start_pts, audio_end_pts = 0, -1 + audio_timebase = Fraction(0, 1) + video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator) + if _info.has_audio: + audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator) + audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) + audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) + audio_fps = _info.audio_sample_rate + video, audio, _ = _read_video_from_file( + video_path, + video_width=self._video_width, + 
video_height=self._video_height, + video_min_dimension=self._video_min_dimension, + video_max_dimension=self._video_max_dimension, + video_pts_range=(video_start_pts, video_end_pts), + video_timebase=video_timebase, + audio_samples=self._audio_samples, + audio_channels=self._audio_channels, + audio_pts_range=(audio_start_pts, audio_end_pts), + audio_timebase=audio_timebase, + ) + + info = {"video_fps": video_fps} + if audio_fps is not None: + info["audio_fps"] = audio_fps + + if self.frame_rate is not None: + resampling_idx = self.resampling_idxs[video_idx][clip_idx] + if isinstance(resampling_idx, torch.Tensor): + resampling_idx = resampling_idx - resampling_idx[0] + video = video[resampling_idx] + info["video_fps"] = self.frame_rate + assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" + + if self.output_format == "TCHW": + # [T,H,W,C] --> [T,C,H,W] + video = video.permute(0, 3, 1, 2) + + return video, audio, info, video_idx + + def __getstate__(self) -> Dict[str, Any]: + video_pts_sizes = [len(v) for v in self.video_pts] + # To be back-compatible, we convert data to dtype torch.long as needed + # because for empty list, in legacy implementation, torch.as_tensor will + # use torch.float as default dtype. This happens when decoding fails and + # no pts is returned in the list. + video_pts = [x.to(torch.int64) for x in self.video_pts] + # video_pts can be an empty list if no frames have been decoded + if video_pts: + video_pts = torch.cat(video_pts) # type: ignore[assignment] + # avoid bug in https://github.com/pytorch/pytorch/issues/32351 + # TODO: Revert it once the bug is fixed. + video_pts = video_pts.numpy() # type: ignore[attr-defined] + + # make a copy of the fields of self + d = self.__dict__.copy() + d["video_pts_sizes"] = video_pts_sizes + d["video_pts"] = video_pts + # delete the following attributes to reduce the size of dictionary. 
They + # will be re-computed in "__setstate__()" + del d["clips"] + del d["resampling_idxs"] + del d["cumulative_sizes"] + + # for backwards-compatibility + d["_version"] = 2 + return d + + def __setstate__(self, d: Dict[str, Any]) -> None: + # for backwards-compatibility + if "_version" not in d: + self.__dict__ = d + return + + video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64) + video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0) + # don't need this info anymore + del d["video_pts_sizes"] + + d["video_pts"] = video_pts + self.__dict__ = d + # recompute attributes "clips", "resampling_idxs" and other derivative ones + self.compute_clips(self.num_frames, self.step, self.frame_rate) diff --git a/.venv/lib/python3.11/site-packages/torchvision/datasets/vision.py b/.venv/lib/python3.11/site-packages/torchvision/datasets/vision.py new file mode 100644 index 0000000000000000000000000000000000000000..e524c67e263abd9ec9ed5b86aa55203e1f3ba144 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchvision/datasets/vision.py @@ -0,0 +1,111 @@ +import os +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch.utils.data as data + +from ..utils import _log_api_usage_once + + +class VisionDataset(data.Dataset): + """ + Base Class For making datasets which are compatible with torchvision. + It is necessary to override the ``__getitem__`` and ``__len__`` method. + + Args: + root (string, optional): Root directory of dataset. Only used for `__repr__`. + transforms (callable, optional): A function/transforms that takes in + an image and a label and returns the transformed versions of both. + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + + .. 
note:: + + :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. + """ + + _repr_indent = 4 + + def __init__( + self, + root: Union[str, Path] = None, # type: ignore[assignment] + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + _log_api_usage_once(self) + if isinstance(root, str): + root = os.path.expanduser(root) + self.root = root + + has_transforms = transforms is not None + has_separate_transform = transform is not None or target_transform is not None + if has_transforms and has_separate_transform: + raise ValueError("Only transforms or transform/target_transform can be passed as argument") + + # for backwards-compatibility + self.transform = transform + self.target_transform = target_transform + + if has_separate_transform: + transforms = StandardTransform(transform, target_transform) + self.transforms = transforms + + def __getitem__(self, index: int) -> Any: + """ + Args: + index (int): Index + + Returns: + (Any): Sample and meta data, optionally transformed by the respective transforms. 
class StandardTransform:
    """Apply one callable to the input and another, independently, to the target.

    Args:
        transform (callable, optional): applied to the input when it is not ``None``.
        target_transform (callable, optional): applied to the target when it is not ``None``.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform, self.target_transform = transform, target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        # The input is transformed before the target, each only if its callable is set.
        new_input = input if self.transform is None else self.transform(input)
        new_target = target if self.target_transform is None else self.target_transform(target)
        return new_input, new_target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the first repr line with ``head`` and align the remaining lines under it."""
        lines = transform.__repr__().splitlines()
        padding = " " * len(head)
        formatted = [f"{head}{lines[0]}"]
        formatted.extend(f"{padding}{line}" for line in lines[1:])
        return formatted

    def __repr__(self) -> str:
        parts = [self.__class__.__name__]
        labelled = ((self.transform, "Transform: "), (self.target_transform, "Target transform: "))
        for fn, label in labelled:
            if fn is not None:
                parts += self._format_transform_repr(fn, label)
        return "\n".join(parts)
# Download metadata for each supported Pascal VOC release, keyed by year
# ("2007-test" holds the separate 2007 test archive).
#   url:      location of the archive
#   filename: name the archive is stored under in the dataset root
#   md5:      checksum used to verify the download
#   base_dir: dataset directory, relative to the dataset root
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Store the archive under its own name. It previously reused the 2012 file name,
        # so a 2008 download collided with a cached 2012 tarball in the same root.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
    "2007-test": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
        "filename": "VOCtest_06-Nov-2007.tar",
        "md5": "b6e924de25625d8de591ea690078ad9f",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
}
class VOCSegmentation(_VOCBase):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        # Alias: for segmentation the generic ``targets`` are the mask file paths.
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        image = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index])

        if self.transforms is None:
            return image, mask

        image, mask = self.transforms(image, mask)
        return image, mask
+ """ + img = Image.open(self.images[index]).convert("RGB") + target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot()) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + @staticmethod + def parse_voc_xml(node: ET_Element) -> Dict[str, Any]: + voc_dict: Dict[str, Any] = {} + children = list(node) + if children: + def_dic: Dict[str, Any] = collections.defaultdict(list) + for dc in map(VOCDetection.parse_voc_xml, children): + for ind, v in dc.items(): + def_dic[ind].append(v) + if node.tag == "annotation": + def_dic["object"] = [def_dic["object"]] + voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}} + if node.text: + text = node.text.strip() + if not children: + voc_dict[node.tag] = text + return voc_dict