Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py +146 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/celeba.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cifar.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py +1224 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py +242 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py +194 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py +168 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py +222 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py +88 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/coco.py +109 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py +100 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py +62 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/fakedata.py +67 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/fer2013.py +120 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py +115 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/flickr.py +167 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/flowers102.py +114 -0
- .venv/lib/python3.11/site-packages/torchvision/datasets/food101.py +93 -0
.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
|
| 2 |
+
from ._stereo_matching import (
|
| 3 |
+
CarlaStereo,
|
| 4 |
+
CREStereo,
|
| 5 |
+
ETH3DStereo,
|
| 6 |
+
FallingThingsStereo,
|
| 7 |
+
InStereo2k,
|
| 8 |
+
Kitti2012Stereo,
|
| 9 |
+
Kitti2015Stereo,
|
| 10 |
+
Middlebury2014Stereo,
|
| 11 |
+
SceneFlowStereo,
|
| 12 |
+
SintelStereo,
|
| 13 |
+
)
|
| 14 |
+
from .caltech import Caltech101, Caltech256
|
| 15 |
+
from .celeba import CelebA
|
| 16 |
+
from .cifar import CIFAR10, CIFAR100
|
| 17 |
+
from .cityscapes import Cityscapes
|
| 18 |
+
from .clevr import CLEVRClassification
|
| 19 |
+
from .coco import CocoCaptions, CocoDetection
|
| 20 |
+
from .country211 import Country211
|
| 21 |
+
from .dtd import DTD
|
| 22 |
+
from .eurosat import EuroSAT
|
| 23 |
+
from .fakedata import FakeData
|
| 24 |
+
from .fer2013 import FER2013
|
| 25 |
+
from .fgvc_aircraft import FGVCAircraft
|
| 26 |
+
from .flickr import Flickr30k, Flickr8k
|
| 27 |
+
from .flowers102 import Flowers102
|
| 28 |
+
from .folder import DatasetFolder, ImageFolder
|
| 29 |
+
from .food101 import Food101
|
| 30 |
+
from .gtsrb import GTSRB
|
| 31 |
+
from .hmdb51 import HMDB51
|
| 32 |
+
from .imagenet import ImageNet
|
| 33 |
+
from .imagenette import Imagenette
|
| 34 |
+
from .inaturalist import INaturalist
|
| 35 |
+
from .kinetics import Kinetics
|
| 36 |
+
from .kitti import Kitti
|
| 37 |
+
from .lfw import LFWPairs, LFWPeople
|
| 38 |
+
from .lsun import LSUN, LSUNClass
|
| 39 |
+
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
|
| 40 |
+
from .moving_mnist import MovingMNIST
|
| 41 |
+
from .omniglot import Omniglot
|
| 42 |
+
from .oxford_iiit_pet import OxfordIIITPet
|
| 43 |
+
from .pcam import PCAM
|
| 44 |
+
from .phototour import PhotoTour
|
| 45 |
+
from .places365 import Places365
|
| 46 |
+
from .rendered_sst2 import RenderedSST2
|
| 47 |
+
from .sbd import SBDataset
|
| 48 |
+
from .sbu import SBU
|
| 49 |
+
from .semeion import SEMEION
|
| 50 |
+
from .stanford_cars import StanfordCars
|
| 51 |
+
from .stl10 import STL10
|
| 52 |
+
from .sun397 import SUN397
|
| 53 |
+
from .svhn import SVHN
|
| 54 |
+
from .ucf101 import UCF101
|
| 55 |
+
from .usps import USPS
|
| 56 |
+
from .vision import VisionDataset
|
| 57 |
+
from .voc import VOCDetection, VOCSegmentation
|
| 58 |
+
from .widerface import WIDERFace
|
| 59 |
+
|
| 60 |
+
__all__ = (
|
| 61 |
+
"LSUN",
|
| 62 |
+
"LSUNClass",
|
| 63 |
+
"ImageFolder",
|
| 64 |
+
"DatasetFolder",
|
| 65 |
+
"FakeData",
|
| 66 |
+
"CocoCaptions",
|
| 67 |
+
"CocoDetection",
|
| 68 |
+
"CIFAR10",
|
| 69 |
+
"CIFAR100",
|
| 70 |
+
"EMNIST",
|
| 71 |
+
"FashionMNIST",
|
| 72 |
+
"QMNIST",
|
| 73 |
+
"MNIST",
|
| 74 |
+
"KMNIST",
|
| 75 |
+
"StanfordCars",
|
| 76 |
+
"STL10",
|
| 77 |
+
"SUN397",
|
| 78 |
+
"SVHN",
|
| 79 |
+
"PhotoTour",
|
| 80 |
+
"SEMEION",
|
| 81 |
+
"Omniglot",
|
| 82 |
+
"SBU",
|
| 83 |
+
"Flickr8k",
|
| 84 |
+
"Flickr30k",
|
| 85 |
+
"Flowers102",
|
| 86 |
+
"VOCSegmentation",
|
| 87 |
+
"VOCDetection",
|
| 88 |
+
"Cityscapes",
|
| 89 |
+
"ImageNet",
|
| 90 |
+
"Caltech101",
|
| 91 |
+
"Caltech256",
|
| 92 |
+
"CelebA",
|
| 93 |
+
"WIDERFace",
|
| 94 |
+
"SBDataset",
|
| 95 |
+
"VisionDataset",
|
| 96 |
+
"USPS",
|
| 97 |
+
"Kinetics",
|
| 98 |
+
"HMDB51",
|
| 99 |
+
"UCF101",
|
| 100 |
+
"Places365",
|
| 101 |
+
"Kitti",
|
| 102 |
+
"INaturalist",
|
| 103 |
+
"LFWPeople",
|
| 104 |
+
"LFWPairs",
|
| 105 |
+
"KittiFlow",
|
| 106 |
+
"Sintel",
|
| 107 |
+
"FlyingChairs",
|
| 108 |
+
"FlyingThings3D",
|
| 109 |
+
"HD1K",
|
| 110 |
+
"Food101",
|
| 111 |
+
"DTD",
|
| 112 |
+
"FER2013",
|
| 113 |
+
"GTSRB",
|
| 114 |
+
"CLEVRClassification",
|
| 115 |
+
"OxfordIIITPet",
|
| 116 |
+
"PCAM",
|
| 117 |
+
"Country211",
|
| 118 |
+
"FGVCAircraft",
|
| 119 |
+
"EuroSAT",
|
| 120 |
+
"RenderedSST2",
|
| 121 |
+
"Kitti2012Stereo",
|
| 122 |
+
"Kitti2015Stereo",
|
| 123 |
+
"CarlaStereo",
|
| 124 |
+
"Middlebury2014Stereo",
|
| 125 |
+
"CREStereo",
|
| 126 |
+
"FallingThingsStereo",
|
| 127 |
+
"SceneFlowStereo",
|
| 128 |
+
"SintelStereo",
|
| 129 |
+
"InStereo2k",
|
| 130 |
+
"ETH3DStereo",
|
| 131 |
+
"wrap_dataset_for_transforms_v2",
|
| 132 |
+
"Imagenette",
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# We override current module's attributes to handle the import:
|
| 137 |
+
# from torchvision.datasets import wrap_dataset_for_transforms_v2
|
| 138 |
+
# without a cyclic error.
|
| 139 |
+
# Ref: https://peps.python.org/pep-0562/
|
| 140 |
+
def __getattr__(name):
|
| 141 |
+
if name in ("wrap_dataset_for_transforms_v2",):
|
| 142 |
+
from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
|
| 143 |
+
|
| 144 |
+
return wrap_dataset_for_transforms_v2
|
| 145 |
+
|
| 146 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-311.pyc
ADDED
|
Binary file (27.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-311.pyc
ADDED
|
Binary file (59.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/celeba.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cifar.cpython-311.pyc
ADDED
|
Binary file (9.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc
ADDED
|
Binary file (6.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc
ADDED
|
Binary file (7.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc
ADDED
|
Binary file (4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc
ADDED
|
Binary file (7.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc
ADDED
|
Binary file (7.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc
ADDED
|
Binary file (7.26 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc
ADDED
|
Binary file (6.97 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc
ADDED
|
Binary file (6.04 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc
ADDED
|
Binary file (7.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc
ADDED
|
Binary file (9.38 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc
ADDED
|
Binary file (6.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc
ADDED
|
Binary file (7.32 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc
ADDED
|
Binary file (7.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc
ADDED
|
Binary file (5.83 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc
ADDED
|
Binary file (9.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc
ADDED
|
Binary file (5.17 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc
ADDED
|
Binary file (6.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc
ADDED
|
Binary file (6.68 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc
ADDED
|
Binary file (8.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc
ADDED
|
Binary file (7.58 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import shutil
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from glob import glob
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Callable, cast, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
|
| 15 |
+
from .vision import VisionDataset
|
| 16 |
+
|
| 17 |
+
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
|
| 18 |
+
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
|
| 19 |
+
|
| 20 |
+
__all__ = ()
|
| 21 |
+
|
| 22 |
+
_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class StereoMatchingDataset(ABC, VisionDataset):
|
| 26 |
+
"""Base interface for Stereo matching datasets"""
|
| 27 |
+
|
| 28 |
+
_has_built_in_disparity_mask = False
|
| 29 |
+
|
| 30 |
+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
root(str): Root directory of the dataset.
|
| 34 |
+
transforms(callable, optional): A function/transform that takes in Tuples of
|
| 35 |
+
(images, disparities, valid_masks) and returns a transformed version of each of them.
|
| 36 |
+
images is a Tuple of (``PIL.Image``, ``PIL.Image``)
|
| 37 |
+
disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
|
| 38 |
+
valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
|
| 39 |
+
In some cases, when a dataset does not provide disparities, the ``disparities`` and
|
| 40 |
+
``valid_masks`` can be Tuples containing None values.
|
| 41 |
+
For training splits generally the datasets provide a minimal guarantee of
|
| 42 |
+
images: (``PIL.Image``, ``PIL.Image``)
|
| 43 |
+
disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
|
| 44 |
+
Optionally, based on the dataset, it can return a ``mask`` as well:
|
| 45 |
+
valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
|
| 46 |
+
For some test splits, the datasets provides outputs that look like:
|
| 47 |
+
imgaes: (``PIL.Image``, ``PIL.Image``)
|
| 48 |
+
disparities: (``None``, ``None``)
|
| 49 |
+
Optionally, based on the dataset, it can return a ``mask`` as well:
|
| 50 |
+
valid_masks: (``None``, ``None``)
|
| 51 |
+
"""
|
| 52 |
+
super().__init__(root=root)
|
| 53 |
+
self.transforms = transforms
|
| 54 |
+
|
| 55 |
+
self._images = [] # type: ignore
|
| 56 |
+
self._disparities = [] # type: ignore
|
| 57 |
+
|
| 58 |
+
def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
|
| 59 |
+
img = Image.open(file_path)
|
| 60 |
+
if img.mode != "RGB":
|
| 61 |
+
img = img.convert("RGB") # type: ignore [assignment]
|
| 62 |
+
return img
|
| 63 |
+
|
| 64 |
+
def _scan_pairs(
|
| 65 |
+
self,
|
| 66 |
+
paths_left_pattern: str,
|
| 67 |
+
paths_right_pattern: Optional[str] = None,
|
| 68 |
+
) -> List[Tuple[str, Optional[str]]]:
|
| 69 |
+
|
| 70 |
+
left_paths = list(sorted(glob(paths_left_pattern)))
|
| 71 |
+
|
| 72 |
+
right_paths: List[Union[None, str]]
|
| 73 |
+
if paths_right_pattern:
|
| 74 |
+
right_paths = list(sorted(glob(paths_right_pattern)))
|
| 75 |
+
else:
|
| 76 |
+
right_paths = list(None for _ in left_paths)
|
| 77 |
+
|
| 78 |
+
if not left_paths:
|
| 79 |
+
raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")
|
| 80 |
+
|
| 81 |
+
if not right_paths:
|
| 82 |
+
raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")
|
| 83 |
+
|
| 84 |
+
if len(left_paths) != len(right_paths):
|
| 85 |
+
raise ValueError(
|
| 86 |
+
f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
|
| 87 |
+
f"left pattern: {paths_left_pattern}\n"
|
| 88 |
+
f"right pattern: {paths_right_pattern}\n"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
paths = list((left, right) for left, right in zip(left_paths, right_paths))
|
| 92 |
+
return paths
|
| 93 |
+
|
| 94 |
+
@abstractmethod
|
| 95 |
+
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
| 96 |
+
# function that returns a disparity map and an occlusion map
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
def __getitem__(self, index: int) -> Union[T1, T2]:
|
| 100 |
+
"""Return example at given index.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
index(int): The index of the example to retrieve
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
|
| 107 |
+
can be a numpy boolean mask of shape (H, W) if the dataset provides a file
|
| 108 |
+
indicating which disparity pixels are valid. The disparity is a numpy array of
|
| 109 |
+
shape (1, H, W) and the images are PIL images. ``disparity`` is None for
|
| 110 |
+
datasets on which for ``split="test"`` the authors did not provide annotations.
|
| 111 |
+
"""
|
| 112 |
+
img_left = self._read_img(self._images[index][0])
|
| 113 |
+
img_right = self._read_img(self._images[index][1])
|
| 114 |
+
|
| 115 |
+
dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
|
| 116 |
+
dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])
|
| 117 |
+
|
| 118 |
+
imgs = (img_left, img_right)
|
| 119 |
+
dsp_maps = (dsp_map_left, dsp_map_right)
|
| 120 |
+
valid_masks = (valid_mask_left, valid_mask_right)
|
| 121 |
+
|
| 122 |
+
if self.transforms is not None:
|
| 123 |
+
(
|
| 124 |
+
imgs,
|
| 125 |
+
dsp_maps,
|
| 126 |
+
valid_masks,
|
| 127 |
+
) = self.transforms(imgs, dsp_maps, valid_masks)
|
| 128 |
+
|
| 129 |
+
if self._has_built_in_disparity_mask or valid_masks[0] is not None:
|
| 130 |
+
return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
|
| 131 |
+
else:
|
| 132 |
+
return imgs[0], imgs[1], dsp_maps[0]
|
| 133 |
+
|
| 134 |
+
def __len__(self) -> int:
|
| 135 |
+
return len(self._images)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class CarlaStereo(StereoMatchingDataset):
|
| 139 |
+
"""
|
| 140 |
+
Carla simulator data linked in the `CREStereo github repo <https://github.com/megvii-research/CREStereo>`_.
|
| 141 |
+
|
| 142 |
+
The dataset is expected to have the following structure: ::
|
| 143 |
+
|
| 144 |
+
root
|
| 145 |
+
carla-highres
|
| 146 |
+
trainingF
|
| 147 |
+
scene1
|
| 148 |
+
img0.png
|
| 149 |
+
img1.png
|
| 150 |
+
disp0GT.pfm
|
| 151 |
+
disp1GT.pfm
|
| 152 |
+
calib.txt
|
| 153 |
+
scene2
|
| 154 |
+
img0.png
|
| 155 |
+
img1.png
|
| 156 |
+
disp0GT.pfm
|
| 157 |
+
disp1GT.pfm
|
| 158 |
+
calib.txt
|
| 159 |
+
...
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
|
| 163 |
+
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
|
| 167 |
+
super().__init__(root, transforms)
|
| 168 |
+
|
| 169 |
+
root = Path(root) / "carla-highres"
|
| 170 |
+
|
| 171 |
+
left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
|
| 172 |
+
right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
|
| 173 |
+
imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
|
| 174 |
+
self._images = imgs
|
| 175 |
+
|
| 176 |
+
left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
|
| 177 |
+
right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
|
| 178 |
+
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
|
| 179 |
+
self._disparities = disparities
|
| 180 |
+
|
| 181 |
+
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
|
| 182 |
+
disparity_map = _read_pfm_file(file_path)
|
| 183 |
+
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
|
| 184 |
+
valid_mask = None
|
| 185 |
+
return disparity_map, valid_mask
|
| 186 |
+
|
| 187 |
+
def __getitem__(self, index: int) -> T1:
|
| 188 |
+
"""Return example at given index.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
index(int): The index of the example to retrieve
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
|
| 195 |
+
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
|
| 196 |
+
If a ``valid_mask`` is generated within the ``transforms`` parameter,
|
| 197 |
+
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
|
| 198 |
+
"""
|
| 199 |
+
return cast(T1, super().__getitem__(index))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class Kitti2012Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the `2012 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php>`_.
    Uses the RGB images for consistency with KITTI 2015.

    The dataset is expected to have the following structure: ::

        root
            Kitti2012
                testing
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                training
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                    disp_noc
                        1.png
                        2.png
                        ...
                    calib

    Args:
        root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # KITTI stores splits under "training" / "testing" directories.
        root = Path(root) / "Kitti2012" / (split + "ing")

        left_img_pattern = str(root / "colored_0" / "*_10.png")
        right_img_pattern = str(root / "colored_1" / "*_10.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        if split == "train":
            disparity_pattern = str(root / "disp_noc" / "*.png")
            self._disparities = self._scan_pairs(disparity_pattern, None)
        else:
            # The test split ships no ground truth; keep one (None, None)
            # placeholder per image pair so indices stay aligned.
            self._disparities = [(None, None) for _ in self._images]

    def _read_disparity(self, file_path: Optional[str]) -> Tuple[Optional[np.ndarray], None]:
        # test split has no disparity maps
        if file_path is None:
            return None, None

        # Disparities are stored as PNGs scaled by 256; undo the scaling.
        disparity_map = np.asarray(Image.open(file_path)) / 256.0
        # unsqueeze the disparity map into (C, H, W) format
        disparity_map = disparity_map[None, :, :]
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        return cast(T1, super().__getitem__(index))
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Kitti2015Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the `2015 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php>`_.

    The dataset is expected to have the following structure: ::

        root
            Kitti2015
                testing
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                training
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                    disp_occ_0
                        img1.png
                        img2.png
                        ...
                    disp_occ_1
                        img1.png
                        img2.png
                        ...
                    calib

    Args:
        root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # KITTI stores splits under "training" / "testing" directories.
        root = Path(root) / "Kitti2015" / (split + "ing")
        left_img_pattern = str(root / "image_2" / "*.png")
        right_img_pattern = str(root / "image_3" / "*.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        if split == "train":
            left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
            right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
            self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
        else:
            # The test split ships no ground truth; keep one (None, None)
            # placeholder per image pair so indices stay aligned.
            self._disparities = [(None, None) for _ in self._images]

    def _read_disparity(self, file_path: Optional[str]) -> Tuple[Optional[np.ndarray], None]:
        # test split has no disparity maps
        if file_path is None:
            return None, None

        # Disparities are stored as PNGs scaled by 256; undo the scaling.
        disparity_map = np.asarray(Image.open(file_path)) / 256.0
        # unsqueeze the disparity map into (C, H, W) format
        disparity_map = disparity_map[None, :, :]
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        return cast(T1, super().__getitem__(index))
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class Middlebury2014Stereo(StereoMatchingDataset):
    """Publicly available scenes from the Middlebury dataset `2014 version <https://vision.middlebury.edu/stereo/data/scenes2014/>`.

    The dataset mostly follows the original format, without containing the ambient subdirectories. : ::

        root
            Middlebury2014
                train
                    scene1-{perfect,imperfect}
                        calib.txt
                        im{0,1}.png
                        im1E.png
                        im1L.png
                        disp{0,1}.pfm
                        disp{0,1}-n.png
                        disp{0,1}-sd.pfm
                        disp{0,1}y.pfm
                    scene2-{perfect,imperfect}
                        calib.txt
                        im{0,1}.png
                        im1E.png
                        im1L.png
                        disp{0,1}.pfm
                        disp{0,1}-n.png
                        disp{0,1}-sd.pfm
                        disp{0,1}y.pfm
                    ...
                additional
                    scene1-{perfect,imperfect}
                        calib.txt
                        im{0,1}.png
                        im1E.png
                        im1L.png
                        disp{0,1}.pfm
                        disp{0,1}-n.png
                        disp{0,1}-sd.pfm
                        disp{0,1}y.pfm
                    ...
                test
                    scene1
                        calib.txt
                        im{0,1}.png
                    scene2
                        calib.txt
                        im{0,1}.png
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the Middleburry 2014 Dataset.
        split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
        use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
            The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
        calibration (string, optional): Whether or not to use the calibrated (default) or uncalibrated scenes.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
        download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
    """

    # scene names per split, used both for sanity checks and for downloading
    splits = {
        "train": [
            "Adirondack",
            "Jadeplant",
            "Motorcycle",
            "Piano",
            "Pipes",
            "Playroom",
            "Playtable",
            "Recycle",
            "Shelves",
            "Vintage",
        ],
        "additional": [
            "Backpack",
            "Bicycle1",
            "Cable",
            "Classroom1",
            "Couch",
            "Flowers",
            "Mask",
            "Shopvac",
            "Sticks",
            "Storage",
            "Sword1",
            "Sword2",
            "Umbrella",
        ],
        "test": [
            "Plants",
            "Classroom2E",
            "Classroom2",
            "Australia",
            "DjembeL",
            "CrusadeP",
            "Crusade",
            "Hoops",
            "Bicycle2",
            "Staircase",
            "Newkuba",
            "AustraliaP",
            "Djembe",
            "Livingroom",
            "Computer",
        ],
    }

    _has_built_in_disparity_mask = True

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        calibration: Optional[str] = "perfect",
        use_ambient_views: bool = False,
        transforms: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
        self.split = split

        # the "test" split has no calibration variants, so `calibration` and
        # `split` must agree with each other
        if calibration:
            verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None))  # type: ignore
            if split == "test":
                raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.")
        else:
            if split != "test":
                raise ValueError(
                    f"Split '{split}' has calibration settings, however None was provided as an argument."
                    f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.",
                )

        if download:
            self._download_dataset(root)

        root = Path(root) / "Middlebury2014"

        if not os.path.exists(root / split):
            raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")

        split_scenes = self.splits[split]
        # check that the provided root folder contains the scene splits
        if not any(
            # using startswith to account for perfect / imperfect calibration
            scene.startswith(s)
            for scene in os.listdir(root / split)
            for s in split_scenes
        ):
            raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")

        calibration_suffixes = {
            None: [""],
            "perfect": ["-perfect"],
            "imperfect": ["-imperfect"],
            "both": ["-perfect", "-imperfect"],
        }[calibration]

        for calibration_suffix in calibration_suffixes:
            scene_pattern = "*" + calibration_suffix
            left_img_pattern = str(root / split / scene_pattern / "im0.png")
            right_img_pattern = str(root / split / scene_pattern / "im1.png")
            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)

            if split == "test":
                # no ground truth for the test split; keep aligned placeholders
                self._disparities = [(None, None) for _ in self._images]
            else:
                left_disparity_pattern = str(root / split / scene_pattern / "disp0.pfm")
                right_disparity_pattern = str(root / split / scene_pattern / "disp1.pfm")
                self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

        self.use_ambient_views = use_ambient_views

    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
        """
        Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
        When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
        as the right image.
        """
        ambient_file_paths: List[Union[str, Path]]  # make mypy happy

        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        if file_path.name == "im1.png" and self.use_ambient_views:
            base_path = file_path.parent
            # initialize sampleable container
            ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
            # double check that we're not going to try to read from an invalid file path
            ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
            # keep the original image as an option as well for uniform sampling between base views
            ambient_file_paths.append(file_path)
            file_path = random.choice(ambient_file_paths)  # type: ignore
        return super()._read_img(file_path)

    def _read_disparity(self, file_path: Optional[str]) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
        # test split has no disparity maps
        if file_path is None:
            return None, None

        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        disparity_map[disparity_map == np.inf] = 0  # remove infinite disparities
        valid_mask = (disparity_map > 0).squeeze(0)  # mask out invalid disparities
        return disparity_map, valid_mask

    def _download_dataset(self, root: Union[str, Path]) -> None:
        """Download the scenes of ``self.split`` into ``root / Middlebury2014``, skipping existing ones."""
        base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
        # train and additional splits have 2 different calibration settings
        root = Path(root) / "Middlebury2014"
        split_name = self.split

        if split_name != "test":
            for split_scene in self.splits[split_name]:
                split_root = root / split_name
                for calibration in ["perfect", "imperfect"]:
                    scene_name = f"{split_scene}-{calibration}"
                    scene_url = f"{base_url}/{scene_name}.zip"
                    print(f"Downloading {scene_url}")
                    # download the scene only if it doesn't exist
                    if not (split_root / scene_name).exists():
                        download_and_extract_archive(
                            url=scene_url,
                            filename=f"{scene_name}.zip",
                            download_root=str(split_root),
                            remove_finished=True,
                        )
        else:
            # exist_ok so that a repeated download=True call does not fail with FileExistsError
            os.makedirs(root / "test", exist_ok=True)
            if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
                # test split is downloaded from a different location
                test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
                # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
                # we want to move the contents from testF into the directory
                download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
                for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
                    for scene in scene_names:
                        scene_dst_dir = root / "test"
                        scene_src_dir = Path(scene_dir) / scene
                        os.makedirs(scene_dst_dir, exist_ok=True)
                        shutil.move(str(scene_src_dir), str(scene_dst_dir))

                # cleanup MiddEval3 directory
                shutil.rmtree(str(root / "MiddEval3"))

    def __getitem__(self, index: int) -> T2:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` for `split=test`.
        """
        return cast(T2, super().__getitem__(index))
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class CREStereo(StereoMatchingDataset):
    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            CREStereo
                tree
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    img2_left.jpg
                    img2_right.jpg
                    img2_left.disp.png
                    img2_right.disp.png
                    ...
                shapenet
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                reflective
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                hole
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(
        self,
        root: Union[str, Path],
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "CREStereo"

        # the four synthetic sub-sets that make up the dataset
        dirs = ["shapenet", "reflective", "tree", "hole"]

        for s in dirs:
            left_image_pattern = str(root / s / "*_left.jpg")
            right_image_pattern = str(root / s / "*_right.jpg")
            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
            self._images += imgs

            left_disparity_pattern = str(root / s / "*_left.disp.png")
            right_disparity_pattern = str(root / s / "*_right.disp.png")
            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        # read the stored PNG as float32; values appear to be stored at 32x
        # scale (divided out below) — see the CREStereo repo for the encoding
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze the disparity map into (C, H, W) format
        disparity_map = disparity_map[None, :, :] / 32.0
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
        """
        return cast(T1, super().__getitem__(index))
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
class FallingThingsStereo(StereoMatchingDataset):
    """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.

    The dataset is expected to have the following structure: ::

        root
            FallingThings
                single
                    dir1
                        scene1
                            _object_settings.json
                            _camera_settings.json
                            image1.left.depth.png
                            image1.right.depth.png
                            image1.left.jpg
                            image1.right.jpg
                            image2.left.depth.png
                            image2.right.depth.png
                            image2.left.jpg
                            image2.right.jpg
                            ...
                        scene2
                    ...
                mixed
                    scene1
                        _object_settings.json
                        _camera_settings.json
                        image1.left.depth.png
                        image1.right.depth.png
                        image1.left.jpg
                        image1.right.jpg
                        image2.left.depth.png
                        image2.right.depth.png
                        image2.left.jpg
                        image2.right.jpg
                        ...
                    scene2
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory where FallingThings is located.
        variant (string): Which variant to use. Either "single", "mixed", or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "FallingThings"

        verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both"))

        # map the user-facing variant name to the directories to scan
        variants = {
            "single": ["single"],
            "mixed": ["mixed"],
            "both": ["single", "mixed"],
        }[variant]

        # "single" nests scenes one directory deeper than "mixed"
        split_prefix = {
            "single": Path("*") / "*",
            "mixed": Path("*"),
        }

        for s in variants:
            left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg")
            right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg")
            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)

            left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png")
            right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Convert a stored depth map into a disparity map using the scene's camera intrinsics."""
        # (H, W) image
        depth = np.asarray(Image.open(file_path))
        # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
        # in order to extract disparity from depth maps
        camera_settings_path = Path(file_path).parent / "_camera_settings.json"
        with open(camera_settings_path, "r") as f:
            # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
            intrinsics = json.load(f)
            focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
            # constants taken from the dataset readme linked above
            baseline, pixel_constant = 6, 100  # pixel constant is inverted
            disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32)
            # unsqueeze disparity to (C, H, W)
            disparity_map = disparity_map[None, :, :]
            valid_mask = None
            return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return cast(T1, super().__getitem__(index))
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
class SceneFlowStereo(StereoMatchingDataset):
    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.

    The dataset is expected to have the following structure: ::

        root
            SceneFlow
                Monkaa
                    frames_cleanpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        scene2
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                    frames_finalpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        ...
                        ...
                    disparity
                        scene1
                            left
                                img1.pfm
                                img2.pfm
                            right
                                img1.pfm
                                img2.pfm
                FlyingThings3D
                    ...
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.

    """

    def __init__(
        self,
        root: Union[str, Path],
        variant: str = "FlyingThings3D",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "SceneFlow"

        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))

        # map the user-facing pass name to the render-pass directories to scan
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        root = root / variant

        # each variant nests its scenes at a different depth
        prefix_directories = {
            "Monkaa": Path("*"),
            "FlyingThings3D": Path("*") / "*" / "*",
            "Driving": Path("*") / "*" / "*",
        }

        for p in passes:
            left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
            right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)

            # disparities are shared between passes, so they are re-scanned per
            # pass to stay index-aligned with the image list
            left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
            right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return cast(T1, super().__getitem__(index))
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
class SintelStereo(StereoMatchingDataset):
|
| 937 |
+
"""Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.
|
| 938 |
+
|
| 939 |
+
The dataset is expected to have the following structure: ::
|
| 940 |
+
|
| 941 |
+
root
|
| 942 |
+
Sintel
|
| 943 |
+
training
|
| 944 |
+
final_left
|
| 945 |
+
scene1
|
| 946 |
+
img1.png
|
| 947 |
+
img2.png
|
| 948 |
+
...
|
| 949 |
+
...
|
| 950 |
+
final_right
|
| 951 |
+
scene2
|
| 952 |
+
img1.png
|
| 953 |
+
img2.png
|
| 954 |
+
...
|
| 955 |
+
...
|
| 956 |
+
disparities
|
| 957 |
+
scene1
|
| 958 |
+
img1.png
|
| 959 |
+
img2.png
|
| 960 |
+
...
|
| 961 |
+
...
|
| 962 |
+
occlusions
|
| 963 |
+
scene1
|
| 964 |
+
img1.png
|
| 965 |
+
img2.png
|
| 966 |
+
...
|
| 967 |
+
...
|
| 968 |
+
outofframe
|
| 969 |
+
scene1
|
| 970 |
+
img1.png
|
| 971 |
+
img2.png
|
| 972 |
+
...
|
| 973 |
+
...
|
| 974 |
+
|
| 975 |
+
Args:
|
| 976 |
+
root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
|
| 977 |
+
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
|
| 978 |
+
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
|
| 979 |
+
"""
|
| 980 |
+
|
| 981 |
+
_has_built_in_disparity_mask = True
|
| 982 |
+
|
| 983 |
+
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
|
| 984 |
+
super().__init__(root, transforms)
|
| 985 |
+
|
| 986 |
+
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
|
| 987 |
+
|
| 988 |
+
root = Path(root) / "Sintel"
|
| 989 |
+
pass_names = {
|
| 990 |
+
"final": ["final"],
|
| 991 |
+
"clean": ["clean"],
|
| 992 |
+
"both": ["final", "clean"],
|
| 993 |
+
}[pass_name]
|
| 994 |
+
|
| 995 |
+
for p in pass_names:
|
| 996 |
+
left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png")
|
| 997 |
+
right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png")
|
| 998 |
+
self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
|
| 999 |
+
|
| 1000 |
+
disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
|
| 1001 |
+
self._disparities += self._scan_pairs(disparity_pattern, None)
|
| 1002 |
+
|
| 1003 |
+
def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
|
| 1004 |
+
# helper function to get the occlusion mask paths
|
| 1005 |
+
# a path will look like .../.../.../training/disparities/scene1/img1.png
|
| 1006 |
+
# we want to get something like .../.../.../training/occlusions/scene1/img1.png
|
| 1007 |
+
fpath = Path(file_path)
|
| 1008 |
+
basename = fpath.name
|
| 1009 |
+
scenedir = fpath.parent
|
| 1010 |
+
# the parent of the scenedir is actually the disparity dir
|
| 1011 |
+
sampledir = scenedir.parent.parent
|
| 1012 |
+
|
| 1013 |
+
occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
|
| 1014 |
+
outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)
|
| 1015 |
+
|
| 1016 |
+
if not os.path.exists(occlusion_path):
|
| 1017 |
+
raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")
|
| 1018 |
+
|
| 1019 |
+
if not os.path.exists(outofframe_path):
|
| 1020 |
+
raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")
|
| 1021 |
+
|
| 1022 |
+
return occlusion_path, outofframe_path
|
| 1023 |
+
|
| 1024 |
+
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
|
| 1025 |
+
if file_path is None:
|
| 1026 |
+
return None, None
|
| 1027 |
+
|
| 1028 |
+
# disparity decoding as per Sintel instructions in the README provided with the dataset
|
| 1029 |
+
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
|
| 1030 |
+
r, g, b = np.split(disparity_map, 3, axis=-1)
|
| 1031 |
+
disparity_map = r * 4 + g / (2**6) + b / (2**14)
|
| 1032 |
+
# reshape into (C, H, W) format
|
| 1033 |
+
disparity_map = np.transpose(disparity_map, (2, 0, 1))
|
| 1034 |
+
# find the appropriate file paths
|
| 1035 |
+
occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
|
| 1036 |
+
# occlusion masks
|
| 1037 |
+
valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
|
| 1038 |
+
# out of frame masks
|
| 1039 |
+
off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
|
| 1040 |
+
# combine the masks together
|
| 1041 |
+
valid_mask = np.logical_and(off_mask, valid_mask)
|
| 1042 |
+
return disparity_map, valid_mask
|
| 1043 |
+
|
| 1044 |
+
def __getitem__(self, index: int) -> T2:
|
| 1045 |
+
"""Return example at given index.
|
| 1046 |
+
|
| 1047 |
+
Args:
|
| 1048 |
+
index(int): The index of the example to retrieve
|
| 1049 |
+
|
| 1050 |
+
Returns:
|
| 1051 |
+
tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
|
| 1052 |
+
The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
|
| 1053 |
+
the valid_mask is a numpy array of shape (H, W).
|
| 1054 |
+
"""
|
| 1055 |
+
return cast(T2, super().__getitem__(index))
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
class InStereo2k(StereoMatchingDataset):
|
| 1059 |
+
"""`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.
|
| 1060 |
+
|
| 1061 |
+
The dataset is expected to have the following structure: ::
|
| 1062 |
+
|
| 1063 |
+
root
|
| 1064 |
+
InStereo2k
|
| 1065 |
+
train
|
| 1066 |
+
scene1
|
| 1067 |
+
left.png
|
| 1068 |
+
right.png
|
| 1069 |
+
left_disp.png
|
| 1070 |
+
right_disp.png
|
| 1071 |
+
...
|
| 1072 |
+
scene2
|
| 1073 |
+
...
|
| 1074 |
+
test
|
| 1075 |
+
scene1
|
| 1076 |
+
left.png
|
| 1077 |
+
right.png
|
| 1078 |
+
left_disp.png
|
| 1079 |
+
right_disp.png
|
| 1080 |
+
...
|
| 1081 |
+
scene2
|
| 1082 |
+
...
|
| 1083 |
+
|
| 1084 |
+
Args:
|
| 1085 |
+
root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
|
| 1086 |
+
split (string): Either "train" or "test".
|
| 1087 |
+
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
|
| 1088 |
+
"""
|
| 1089 |
+
|
| 1090 |
+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
|
| 1091 |
+
super().__init__(root, transforms)
|
| 1092 |
+
|
| 1093 |
+
root = Path(root) / "InStereo2k" / split
|
| 1094 |
+
|
| 1095 |
+
verify_str_arg(split, "split", valid_values=("train", "test"))
|
| 1096 |
+
|
| 1097 |
+
left_img_pattern = str(root / "*" / "left.png")
|
| 1098 |
+
right_img_pattern = str(root / "*" / "right.png")
|
| 1099 |
+
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
|
| 1100 |
+
|
| 1101 |
+
left_disparity_pattern = str(root / "*" / "left_disp.png")
|
| 1102 |
+
right_disparity_pattern = str(root / "*" / "right_disp.png")
|
| 1103 |
+
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
|
| 1104 |
+
|
| 1105 |
+
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
|
| 1106 |
+
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
|
| 1107 |
+
# unsqueeze disparity to (C, H, W)
|
| 1108 |
+
disparity_map = disparity_map[None, :, :] / 1024.0
|
| 1109 |
+
valid_mask = None
|
| 1110 |
+
return disparity_map, valid_mask
|
| 1111 |
+
|
| 1112 |
+
def __getitem__(self, index: int) -> T1:
|
| 1113 |
+
"""Return example at given index.
|
| 1114 |
+
|
| 1115 |
+
Args:
|
| 1116 |
+
index(int): The index of the example to retrieve
|
| 1117 |
+
|
| 1118 |
+
Returns:
|
| 1119 |
+
tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
|
| 1120 |
+
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
|
| 1121 |
+
If a ``valid_mask`` is generated within the ``transforms`` parameter,
|
| 1122 |
+
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
|
| 1123 |
+
"""
|
| 1124 |
+
return cast(T1, super().__getitem__(index))
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
class ETH3DStereo(StereoMatchingDataset):
|
| 1128 |
+
"""ETH3D `Low-Res Two-View <https://www.eth3d.net/datasets>`_ dataset.
|
| 1129 |
+
|
| 1130 |
+
The dataset is expected to have the following structure: ::
|
| 1131 |
+
|
| 1132 |
+
root
|
| 1133 |
+
ETH3D
|
| 1134 |
+
two_view_training
|
| 1135 |
+
scene1
|
| 1136 |
+
im1.png
|
| 1137 |
+
im0.png
|
| 1138 |
+
images.txt
|
| 1139 |
+
cameras.txt
|
| 1140 |
+
calib.txt
|
| 1141 |
+
scene2
|
| 1142 |
+
im1.png
|
| 1143 |
+
im0.png
|
| 1144 |
+
images.txt
|
| 1145 |
+
cameras.txt
|
| 1146 |
+
calib.txt
|
| 1147 |
+
...
|
| 1148 |
+
two_view_training_gt
|
| 1149 |
+
scene1
|
| 1150 |
+
disp0GT.pfm
|
| 1151 |
+
mask0nocc.png
|
| 1152 |
+
scene2
|
| 1153 |
+
disp0GT.pfm
|
| 1154 |
+
mask0nocc.png
|
| 1155 |
+
...
|
| 1156 |
+
two_view_testing
|
| 1157 |
+
scene1
|
| 1158 |
+
im1.png
|
| 1159 |
+
im0.png
|
| 1160 |
+
images.txt
|
| 1161 |
+
cameras.txt
|
| 1162 |
+
calib.txt
|
| 1163 |
+
scene2
|
| 1164 |
+
im1.png
|
| 1165 |
+
im0.png
|
| 1166 |
+
images.txt
|
| 1167 |
+
cameras.txt
|
| 1168 |
+
calib.txt
|
| 1169 |
+
...
|
| 1170 |
+
|
| 1171 |
+
Args:
|
| 1172 |
+
root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
|
| 1173 |
+
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
|
| 1174 |
+
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
|
| 1175 |
+
"""
|
| 1176 |
+
|
| 1177 |
+
_has_built_in_disparity_mask = True
|
| 1178 |
+
|
| 1179 |
+
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
|
| 1180 |
+
super().__init__(root, transforms)
|
| 1181 |
+
|
| 1182 |
+
verify_str_arg(split, "split", valid_values=("train", "test"))
|
| 1183 |
+
|
| 1184 |
+
root = Path(root) / "ETH3D"
|
| 1185 |
+
|
| 1186 |
+
img_dir = "two_view_training" if split == "train" else "two_view_test"
|
| 1187 |
+
anot_dir = "two_view_training_gt"
|
| 1188 |
+
|
| 1189 |
+
left_img_pattern = str(root / img_dir / "*" / "im0.png")
|
| 1190 |
+
right_img_pattern = str(root / img_dir / "*" / "im1.png")
|
| 1191 |
+
self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
|
| 1192 |
+
|
| 1193 |
+
if split == "test":
|
| 1194 |
+
self._disparities = list((None, None) for _ in self._images)
|
| 1195 |
+
else:
|
| 1196 |
+
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
|
| 1197 |
+
self._disparities = self._scan_pairs(disparity_pattern, None)
|
| 1198 |
+
|
| 1199 |
+
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
|
| 1200 |
+
# test split has no disparity maps
|
| 1201 |
+
if file_path is None:
|
| 1202 |
+
return None, None
|
| 1203 |
+
|
| 1204 |
+
disparity_map = _read_pfm_file(file_path)
|
| 1205 |
+
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
|
| 1206 |
+
mask_path = Path(file_path).parent / "mask0nocc.png"
|
| 1207 |
+
valid_mask = Image.open(mask_path)
|
| 1208 |
+
valid_mask = np.asarray(valid_mask).astype(bool)
|
| 1209 |
+
return disparity_map, valid_mask
|
| 1210 |
+
|
| 1211 |
+
def __getitem__(self, index: int) -> T2:
|
| 1212 |
+
"""Return example at given index.
|
| 1213 |
+
|
| 1214 |
+
Args:
|
| 1215 |
+
index(int): The index of the example to retrieve
|
| 1216 |
+
|
| 1217 |
+
Returns:
|
| 1218 |
+
tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
|
| 1219 |
+
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
|
| 1220 |
+
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
|
| 1221 |
+
generate a valid mask.
|
| 1222 |
+
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
|
| 1223 |
+
"""
|
| 1224 |
+
return cast(T2, super().__getitem__(index))
|
.venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from .utils import download_and_extract_archive, verify_str_arg
|
| 9 |
+
from .vision import VisionDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Caltech101(VisionDataset):
|
| 13 |
+
"""`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.
|
| 14 |
+
|
| 15 |
+
.. warning::
|
| 16 |
+
|
| 17 |
+
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
root (str or ``pathlib.Path``): Root directory of dataset where directory
|
| 21 |
+
``caltech101`` exists or will be saved to if download is set to True.
|
| 22 |
+
target_type (string or list, optional): Type of target to use, ``category`` or
|
| 23 |
+
``annotation``. Can also be a list to output a tuple with all specified
|
| 24 |
+
target types. ``category`` represents the target class, and
|
| 25 |
+
``annotation`` is a list of points from a hand-generated outline.
|
| 26 |
+
Defaults to ``category``.
|
| 27 |
+
transform (callable, optional): A function/transform that takes in a PIL image
|
| 28 |
+
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
| 29 |
+
target_transform (callable, optional): A function/transform that takes in the
|
| 30 |
+
target and transforms it.
|
| 31 |
+
download (bool, optional): If true, downloads the dataset from the internet and
|
| 32 |
+
puts it in root directory. If dataset is already downloaded, it is not
|
| 33 |
+
downloaded again.
|
| 34 |
+
|
| 35 |
+
.. warning::
|
| 36 |
+
|
| 37 |
+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
root: Union[str, Path],
|
| 43 |
+
target_type: Union[List[str], str] = "category",
|
| 44 |
+
transform: Optional[Callable] = None,
|
| 45 |
+
target_transform: Optional[Callable] = None,
|
| 46 |
+
download: bool = False,
|
| 47 |
+
) -> None:
|
| 48 |
+
super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
|
| 49 |
+
os.makedirs(self.root, exist_ok=True)
|
| 50 |
+
if isinstance(target_type, str):
|
| 51 |
+
target_type = [target_type]
|
| 52 |
+
self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
|
| 53 |
+
|
| 54 |
+
if download:
|
| 55 |
+
self.download()
|
| 56 |
+
|
| 57 |
+
if not self._check_integrity():
|
| 58 |
+
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
|
| 59 |
+
|
| 60 |
+
self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
|
| 61 |
+
self.categories.remove("BACKGROUND_Google") # this is not a real class
|
| 62 |
+
|
| 63 |
+
# For some reason, the category names in "101_ObjectCategories" and
|
| 64 |
+
# "Annotations" do not always match. This is a manual map between the
|
| 65 |
+
# two. Defaults to using same name, since most names are fine.
|
| 66 |
+
name_map = {
|
| 67 |
+
"Faces": "Faces_2",
|
| 68 |
+
"Faces_easy": "Faces_3",
|
| 69 |
+
"Motorbikes": "Motorbikes_16",
|
| 70 |
+
"airplanes": "Airplanes_Side_2",
|
| 71 |
+
}
|
| 72 |
+
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
|
| 73 |
+
|
| 74 |
+
self.index: List[int] = []
|
| 75 |
+
self.y = []
|
| 76 |
+
for (i, c) in enumerate(self.categories):
|
| 77 |
+
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
|
| 78 |
+
self.index.extend(range(1, n + 1))
|
| 79 |
+
self.y.extend(n * [i])
|
| 80 |
+
|
| 81 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
| 82 |
+
"""
|
| 83 |
+
Args:
|
| 84 |
+
index (int): Index
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
tuple: (image, target) where the type of target specified by target_type.
|
| 88 |
+
"""
|
| 89 |
+
import scipy.io
|
| 90 |
+
|
| 91 |
+
img = Image.open(
|
| 92 |
+
os.path.join(
|
| 93 |
+
self.root,
|
| 94 |
+
"101_ObjectCategories",
|
| 95 |
+
self.categories[self.y[index]],
|
| 96 |
+
f"image_{self.index[index]:04d}.jpg",
|
| 97 |
+
)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
target: Any = []
|
| 101 |
+
for t in self.target_type:
|
| 102 |
+
if t == "category":
|
| 103 |
+
target.append(self.y[index])
|
| 104 |
+
elif t == "annotation":
|
| 105 |
+
data = scipy.io.loadmat(
|
| 106 |
+
os.path.join(
|
| 107 |
+
self.root,
|
| 108 |
+
"Annotations",
|
| 109 |
+
self.annotation_categories[self.y[index]],
|
| 110 |
+
f"annotation_{self.index[index]:04d}.mat",
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
+
target.append(data["obj_contour"])
|
| 114 |
+
target = tuple(target) if len(target) > 1 else target[0]
|
| 115 |
+
|
| 116 |
+
if self.transform is not None:
|
| 117 |
+
img = self.transform(img)
|
| 118 |
+
|
| 119 |
+
if self.target_transform is not None:
|
| 120 |
+
target = self.target_transform(target)
|
| 121 |
+
|
| 122 |
+
return img, target
|
| 123 |
+
|
| 124 |
+
def _check_integrity(self) -> bool:
|
| 125 |
+
# can be more robust and check hash of files
|
| 126 |
+
return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
|
| 127 |
+
|
| 128 |
+
def __len__(self) -> int:
|
| 129 |
+
return len(self.index)
|
| 130 |
+
|
| 131 |
+
def download(self) -> None:
|
| 132 |
+
if self._check_integrity():
|
| 133 |
+
print("Files already downloaded and verified")
|
| 134 |
+
return
|
| 135 |
+
|
| 136 |
+
download_and_extract_archive(
|
| 137 |
+
"https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
|
| 138 |
+
self.root,
|
| 139 |
+
filename="101_ObjectCategories.tar.gz",
|
| 140 |
+
md5="b224c7392d521a49829488ab0f1120d9",
|
| 141 |
+
)
|
| 142 |
+
download_and_extract_archive(
|
| 143 |
+
"https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
|
| 144 |
+
self.root,
|
| 145 |
+
filename="Annotations.tar",
|
| 146 |
+
md5="6f83eeb1f24d99cab4eb377263132c91",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def extra_repr(self) -> str:
|
| 150 |
+
return "Target type: {target_type}".format(**self.__dict__)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class Caltech256(VisionDataset):
|
| 154 |
+
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
root (str or ``pathlib.Path``): Root directory of dataset where directory
|
| 158 |
+
``caltech256`` exists or will be saved to if download is set to True.
|
| 159 |
+
transform (callable, optional): A function/transform that takes in a PIL image
|
| 160 |
+
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
| 161 |
+
target_transform (callable, optional): A function/transform that takes in the
|
| 162 |
+
target and transforms it.
|
| 163 |
+
download (bool, optional): If true, downloads the dataset from the internet and
|
| 164 |
+
puts it in root directory. If dataset is already downloaded, it is not
|
| 165 |
+
downloaded again.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
def __init__(
|
| 169 |
+
self,
|
| 170 |
+
root: str,
|
| 171 |
+
transform: Optional[Callable] = None,
|
| 172 |
+
target_transform: Optional[Callable] = None,
|
| 173 |
+
download: bool = False,
|
| 174 |
+
) -> None:
|
| 175 |
+
super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
|
| 176 |
+
os.makedirs(self.root, exist_ok=True)
|
| 177 |
+
|
| 178 |
+
if download:
|
| 179 |
+
self.download()
|
| 180 |
+
|
| 181 |
+
if not self._check_integrity():
|
| 182 |
+
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
|
| 183 |
+
|
| 184 |
+
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
|
| 185 |
+
self.index: List[int] = []
|
| 186 |
+
self.y = []
|
| 187 |
+
for (i, c) in enumerate(self.categories):
|
| 188 |
+
n = len(
|
| 189 |
+
[
|
| 190 |
+
item
|
| 191 |
+
for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
|
| 192 |
+
if item.endswith(".jpg")
|
| 193 |
+
]
|
| 194 |
+
)
|
| 195 |
+
self.index.extend(range(1, n + 1))
|
| 196 |
+
self.y.extend(n * [i])
|
| 197 |
+
|
| 198 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
| 199 |
+
"""
|
| 200 |
+
Args:
|
| 201 |
+
index (int): Index
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
tuple: (image, target) where target is index of the target class.
|
| 205 |
+
"""
|
| 206 |
+
img = Image.open(
|
| 207 |
+
os.path.join(
|
| 208 |
+
self.root,
|
| 209 |
+
"256_ObjectCategories",
|
| 210 |
+
self.categories[self.y[index]],
|
| 211 |
+
f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
|
| 212 |
+
)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
target = self.y[index]
|
| 216 |
+
|
| 217 |
+
if self.transform is not None:
|
| 218 |
+
img = self.transform(img)
|
| 219 |
+
|
| 220 |
+
if self.target_transform is not None:
|
| 221 |
+
target = self.target_transform(target)
|
| 222 |
+
|
| 223 |
+
return img, target
|
| 224 |
+
|
| 225 |
+
def _check_integrity(self) -> bool:
|
| 226 |
+
# can be more robust and check hash of files
|
| 227 |
+
return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
|
| 228 |
+
|
| 229 |
+
def __len__(self) -> int:
|
| 230 |
+
return len(self.index)
|
| 231 |
+
|
| 232 |
+
def download(self) -> None:
|
| 233 |
+
if self._check_integrity():
|
| 234 |
+
print("Files already downloaded and verified")
|
| 235 |
+
return
|
| 236 |
+
|
| 237 |
+
download_and_extract_archive(
|
| 238 |
+
"https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
|
| 239 |
+
self.root,
|
| 240 |
+
filename="256_ObjectCategories.tar",
|
| 241 |
+
md5="67b4f42ca05d46448c6bb8ecd2220f6d",
|
| 242 |
+
)
|
.venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import PIL
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
|
| 11 |
+
from .vision import VisionDataset
|
| 12 |
+
|
| 13 |
+
CSV = namedtuple("CSV", ["header", "index", "data"])
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CelebA(VisionDataset):
|
| 17 |
+
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
|
| 21 |
+
split (string): One of {'train', 'valid', 'test', 'all'}.
|
| 22 |
+
Accordingly dataset is selected.
|
| 23 |
+
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
|
| 24 |
+
or ``landmarks``. Can also be a list to output a tuple with all specified target types.
|
| 25 |
+
The targets represent:
|
| 26 |
+
|
| 27 |
+
- ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
|
| 28 |
+
- ``identity`` (int): label for each person (data points with the same identity are the same person)
|
| 29 |
+
- ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
|
| 30 |
+
- ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
|
| 31 |
+
righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
|
| 32 |
+
|
| 33 |
+
Defaults to ``attr``. If empty, ``None`` will be returned as target.
|
| 34 |
+
|
| 35 |
+
transform (callable, optional): A function/transform that takes in a PIL image
|
| 36 |
+
and returns a transformed version. E.g, ``transforms.PILToTensor``
|
| 37 |
+
target_transform (callable, optional): A function/transform that takes in the
|
| 38 |
+
target and transforms it.
|
| 39 |
+
download (bool, optional): If true, downloads the dataset from the internet and
|
| 40 |
+
puts it in root directory. If dataset is already downloaded, it is not
|
| 41 |
+
downloaded again.
|
| 42 |
+
|
| 43 |
+
.. warning::
|
| 44 |
+
|
| 45 |
+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
base_folder = "celeba"
|
| 49 |
+
# There currently does not appear to be an easy way to extract 7z in python (without introducing additional
|
| 50 |
+
# dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
|
| 51 |
+
# right now.
|
| 52 |
+
file_list = [
|
| 53 |
+
# File ID MD5 Hash Filename
|
| 54 |
+
("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
|
| 55 |
+
# ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
|
| 56 |
+
# ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
|
| 57 |
+
("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
|
| 58 |
+
("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
|
| 59 |
+
("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
|
| 60 |
+
("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
|
| 61 |
+
# ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
|
| 62 |
+
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
root: Union[str, Path],
|
| 68 |
+
split: str = "train",
|
| 69 |
+
target_type: Union[List[str], str] = "attr",
|
| 70 |
+
transform: Optional[Callable] = None,
|
| 71 |
+
target_transform: Optional[Callable] = None,
|
| 72 |
+
download: bool = False,
|
| 73 |
+
) -> None:
|
| 74 |
+
super().__init__(root, transform=transform, target_transform=target_transform)
|
| 75 |
+
self.split = split
|
| 76 |
+
if isinstance(target_type, list):
|
| 77 |
+
self.target_type = target_type
|
| 78 |
+
else:
|
| 79 |
+
self.target_type = [target_type]
|
| 80 |
+
|
| 81 |
+
if not self.target_type and self.target_transform is not None:
|
| 82 |
+
raise RuntimeError("target_transform is specified but target_type is empty")
|
| 83 |
+
|
| 84 |
+
if download:
|
| 85 |
+
self.download()
|
| 86 |
+
|
| 87 |
+
if not self._check_integrity():
|
| 88 |
+
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
|
| 89 |
+
|
| 90 |
+
split_map = {
|
| 91 |
+
"train": 0,
|
| 92 |
+
"valid": 1,
|
| 93 |
+
"test": 2,
|
| 94 |
+
"all": None,
|
| 95 |
+
}
|
| 96 |
+
split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
|
| 97 |
+
splits = self._load_csv("list_eval_partition.txt")
|
| 98 |
+
identity = self._load_csv("identity_CelebA.txt")
|
| 99 |
+
bbox = self._load_csv("list_bbox_celeba.txt", header=1)
|
| 100 |
+
landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
|
| 101 |
+
attr = self._load_csv("list_attr_celeba.txt", header=1)
|
| 102 |
+
|
| 103 |
+
mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
|
| 104 |
+
|
| 105 |
+
if mask == slice(None): # if split == "all"
|
| 106 |
+
self.filename = splits.index
|
| 107 |
+
else:
|
| 108 |
+
self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
|
| 109 |
+
self.identity = identity.data[mask]
|
| 110 |
+
self.bbox = bbox.data[mask]
|
| 111 |
+
self.landmarks_align = landmarks_align.data[mask]
|
| 112 |
+
self.attr = attr.data[mask]
|
| 113 |
+
# map from {-1, 1} to {0, 1}
|
| 114 |
+
self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
|
| 115 |
+
self.attr_names = attr.header
|
| 116 |
+
|
| 117 |
+
def _load_csv(
|
| 118 |
+
self,
|
| 119 |
+
filename: str,
|
| 120 |
+
header: Optional[int] = None,
|
| 121 |
+
) -> CSV:
|
| 122 |
+
with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
|
| 123 |
+
data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))
|
| 124 |
+
|
| 125 |
+
if header is not None:
|
| 126 |
+
headers = data[header]
|
| 127 |
+
data = data[header + 1 :]
|
| 128 |
+
else:
|
| 129 |
+
headers = []
|
| 130 |
+
|
| 131 |
+
indices = [row[0] for row in data]
|
| 132 |
+
data = [row[1:] for row in data]
|
| 133 |
+
data_int = [list(map(int, i)) for i in data]
|
| 134 |
+
|
| 135 |
+
return CSV(headers, indices, torch.tensor(data_int))
|
| 136 |
+
|
| 137 |
+
def _check_integrity(self) -> bool:
|
| 138 |
+
for (_, md5, filename) in self.file_list:
|
| 139 |
+
fpath = os.path.join(self.root, self.base_folder, filename)
|
| 140 |
+
_, ext = os.path.splitext(filename)
|
| 141 |
+
# Allow original archive to be deleted (zip and 7z)
|
| 142 |
+
# Only need the extracted images
|
| 143 |
+
if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
|
| 144 |
+
return False
|
| 145 |
+
|
| 146 |
+
# Should check a hash of the images
|
| 147 |
+
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
|
| 148 |
+
|
| 149 |
+
def download(self) -> None:
|
| 150 |
+
if self._check_integrity():
|
| 151 |
+
print("Files already downloaded and verified")
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
for (file_id, md5, filename) in self.file_list:
|
| 155 |
+
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
|
| 156 |
+
|
| 157 |
+
extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
|
| 158 |
+
|
| 159 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
| 160 |
+
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
|
| 161 |
+
|
| 162 |
+
target: Any = []
|
| 163 |
+
for t in self.target_type:
|
| 164 |
+
if t == "attr":
|
| 165 |
+
target.append(self.attr[index, :])
|
| 166 |
+
elif t == "identity":
|
| 167 |
+
target.append(self.identity[index, 0])
|
| 168 |
+
elif t == "bbox":
|
| 169 |
+
target.append(self.bbox[index, :])
|
| 170 |
+
elif t == "landmarks":
|
| 171 |
+
target.append(self.landmarks_align[index, :])
|
| 172 |
+
else:
|
| 173 |
+
# TODO: refactor with utils.verify_str_arg
|
| 174 |
+
raise ValueError(f'Target type "{t}" is not recognized.')
|
| 175 |
+
|
| 176 |
+
if self.transform is not None:
|
| 177 |
+
X = self.transform(X)
|
| 178 |
+
|
| 179 |
+
if target:
|
| 180 |
+
target = tuple(target) if len(target) > 1 else target[0]
|
| 181 |
+
|
| 182 |
+
if self.target_transform is not None:
|
| 183 |
+
target = self.target_transform(target)
|
| 184 |
+
else:
|
| 185 |
+
target = None
|
| 186 |
+
|
| 187 |
+
return X, target
|
| 188 |
+
|
| 189 |
+
def __len__(self) -> int:
|
| 190 |
+
return len(self.attr)
|
| 191 |
+
|
| 192 |
+
def extra_repr(self) -> str:
|
| 193 |
+
lines = ["Target type: {target_type}", "Split: {split}"]
|
| 194 |
+
return "\n".join(lines).format(**self.__dict__)
|
.venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
import pickle
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from .utils import check_integrity, download_and_extract_archive
|
| 10 |
+
from .vision import VisionDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    # Archive layout and per-file MD5 checksums. CIFAR100 overrides all of
    # these class attributes and reuses every method below unchanged.
    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    # Each entry is [batch file name, expected MD5], consumed by _check_integrity().
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    # Metadata pickle that maps label indices to human-readable class names.
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        # Download (if requested) before the integrity check so a fresh
        # download can satisfy the check below.
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, "rb") as f:
                # latin1 keeps the Python-2-era pickles loadable under Python 3.
                entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                # CIFAR-10 batches store "labels"; CIFAR-100 stores "fine_labels".
                if "labels" in entry:
                    self.targets.extend(entry["labels"])
                else:
                    self.targets.extend(entry["fine_labels"])

        # Stack the flat per-batch arrays into (N, 3, 32, 32) ...
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        # Load class names from the metadata file and build class_to_idx.
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
        with open(path, "rb") as infile:
            data = pickle.load(infile, encoding="latin1")
            self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        # One image per row of the stacked data array.
        return len(self.data)

    def _check_integrity(self) -> bool:
        # Verify MD5 of every batch file (train AND test) under root/base_folder.
        for filename, md5 in self.train_list + self.test_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        # Skip the (slow) network fetch when all files already verify.
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        # Shown in the dataset's repr().
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset.
    """

    # Only the archive layout and checksums differ from CIFAR10; all loading
    # logic is inherited.  Note the metadata key is "fine_label_names" — the
    # 100 fine-grained classes, matching the "fine_labels" entry read by
    # CIFAR10.__init__.
    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]

    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }
|
.venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from .utils import extract_archive, iterable_to_str, verify_str_arg
|
| 10 |
+
from .vision import VisionDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    # Full Cityscapes label table: dataset id, training id (255 = ignored),
    # category grouping, instance/eval flags and the RGB color used for the
    # "color" target type.
    classes = [
        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # "fine" -> gtFine annotations, anything else -> gtCoarse; mode itself
        # is validated a few lines below.
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        self.images = []
        self.targets = []

        verify_str_arg(mode, "mode", ("fine", "coarse"))
        # The set of valid splits depends on the annotation mode.
        if mode == "fine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        # A bare string target_type is normalized to a one-element list.
        if not isinstance(target_type, list):
            self.target_type = [target_type]
        # List comprehension used purely for its validation side effect.
        [
            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
            for value in self.target_type
        ]

        # If the extracted folders are missing, try to extract them from the
        # official zip archives expected directly under root.
        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

            if split == "train_extra":
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
            else:
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

            if self.mode == "gtFine":
                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
            elif self.mode == "gtCoarse":
                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    "Dataset not found or incomplete. Please make sure all required folders for the"
                    ' specified "split" and "mode" are inside the "root" directory'
                )

        # Walk the per-city folders and pair each image with one target path
        # per requested target type.
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_types = []
                for t in self.target_type:
                    target_name = "{}_{}".format(
                        file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
                    )
                    target_types.append(os.path.join(target_dir, target_name))

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_types)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
        """

        image = Image.open(self.images[index]).convert("RGB")

        targets: Any = []
        for i, t in enumerate(self.target_type):
            # "polygon" targets are JSON annotation files; all others are images.
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])  # type: ignore[assignment]

            targets.append(target)

        # Single target type -> unwrap; multiple -> tuple of targets.
        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        # One sample per discovered left-camera image.
        return len(self.images)

    def extra_repr(self) -> str:
        # Shown in the dataset's repr().
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        return "\n".join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        # Read a polygon annotation file into a plain dict.
        with open(path) as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        # Map a target type to the Cityscapes file-name suffix for this mode.
        if target_type == "instance":
            return f"{mode}_instanceIds.png"
        elif target_type == "semantic":
            return f"{mode}_labelIds.png"
        elif target_type == "color":
            return f"{mode}_color.png"
        else:
            # Remaining valid value is "polygon" (validated in __init__).
            return f"{mode}_polygons.json"
|
.venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pathlib
|
| 3 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 4 |
+
from urllib.parse import urlparse
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from .utils import download_and_extract_archive, verify_str_arg
|
| 9 |
+
from .vision import VisionDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CLEVRClassification(VisionDataset):
    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.

    The number of objects in a scene are used as label.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
            set to True.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in them target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
            dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "clevr"
        # Data lands in root/clevr/CLEVR_v1.0 (the archive's stem).
        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Sorted so index -> file mapping is deterministic across runs.
        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))

        self._labels: List[Optional[int]]
        if self._split != "test":
            # The label for each image is the number of objects in its scene.
            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
                content = json.load(file)
            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
        else:
            # The test split ships without scene annotations.
            self._labels = [None] * len(self._image_files)

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file = self._image_files[idx]
        label = self._labels[idx]

        image = Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        # Existence of the extracted data directory is the only check; no
        # per-file checksum is verified after extraction.
        return self._data_folder.exists() and self._data_folder.is_dir()

    def _download(self) -> None:
        if self._check_exists():
            return

        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)

    def extra_repr(self) -> str:
        # Shown in the dataset's repr().
        return f"split={self._split}"
|
.venv/lib/python3.11/site-packages/torchvision/datasets/coco.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
from .vision import VisionDataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported lazily so the module is importable without pycocotools.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted image ids make index -> image mapping deterministic.
        self.ids = list(sorted(self.coco.imgs.keys()))

    def _load_image(self, id: int) -> Image.Image:
        # Resolve an image id to its file under root and load it as RGB.
        path = self.coco.loadImgs(id)[0]["file_name"]
        return Image.open(os.path.join(self.root, path)).convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        # All COCO annotation dicts attached to this image id.
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.ids)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        # Narrow the full annotation dicts from CocoDetection down to just
        # the caption strings.
        return [ann["caption"] for ann in super()._load_target(id)]
|
.venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import PIL.Image
|
| 6 |
+
|
| 7 |
+
from .utils import download_and_extract_archive, verify_str_arg
|
| 8 |
+
from .vision import VisionDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        partition: int = 1,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # BUGFIX: the original condition used `and`, which only rejected values
        # that were *both* non-int and out of range -- e.g. partition=11 slipped
        # through validation. `or` rejects either failure mode; short-circuit
        # evaluation also ensures the range comparison is never applied to a
        # non-int (which could raise TypeError instead of the intended ValueError).
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)
        # Layout under root: <root>/dtd/dtd/{labels,images}
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each line of the split file is "<class>/<filename>".
        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                cls, name = line.strip().split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        self.classes = sorted(set(classes))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        """Return the number of images in the selected split/partition."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, label)`` for index ``idx``, with transforms applied."""
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        # isdir alone would suffice, but keep the explicit exists+isdir pair.
        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

    def _download(self) -> None:
        """Download and extract the archive under ``root/dtd`` unless already present."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
|
.venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Callable, Optional, Union
|
| 4 |
+
|
| 5 |
+
from .folder import ImageFolder
|
| 6 |
+
from .utils import download_and_extract_archive
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EuroSAT(ImageFolder):
    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.

    For the MS version of the dataset, see
    `TorchGeo <https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat>`__.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        expanded_root = os.path.expanduser(root)
        self.root = expanded_root
        self._base_folder = os.path.join(self.root, "eurosat")
        self._data_folder = os.path.join(self._base_folder, "2750")

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
        # ImageFolder.__init__ sets self.root to the data folder; restore the
        # user-supplied root so the public attribute matches the argument.
        self.root = expanded_root

    def __len__(self) -> int:
        """Number of images discovered by ImageFolder."""
        return len(self.samples)

    def _check_exists(self) -> bool:
        # The extracted archive places all class folders under "2750".
        return os.path.exists(self._data_folder)

    def download(self) -> None:
        """Fetch and extract the EuroSAT archive into ``root/eurosat`` if missing."""
        if self._check_exists():
            return

        os.makedirs(self._base_folder, exist_ok=True)
        download_and_extract_archive(
            "https://huggingface.co/datasets/torchgeo/eurosat/resolve/c877bcd43f099cd0196738f714544e355477f3fd/EuroSAT.zip",
            download_root=self._base_folder,
            md5="c8fa014336c82ac7804f0398fcb19387",
        )
|
.venv/lib/python3.11/site-packages/torchvision/datasets/fakedata.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .. import transforms
|
| 6 |
+
from .vision import VisionDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FakeData(VisionDataset):
    """A fake dataset that returns randomly generated images and returns them as PIL images

    Args:
        size (int, optional): Size of the dataset. Default: 1000 images
        image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
        num_classes(int, optional): Number of classes in the dataset. Default: 10
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        random_offset (int): Offsets the index-based random seed used to
            generate each image. Default: 0
    """

    def __init__(
        self,
        size: int = 1000,
        image_size: Tuple[int, int, int] = (3, 224, 224),
        num_classes: int = 10,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        random_offset: int = 0,
    ) -> None:
        super().__init__(transform=transform, target_transform=target_transform)
        self.size = size
        self.image_size = image_size
        self.num_classes = num_classes
        self.random_offset = random_offset

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if index >= len(self):
            raise IndexError(f"{self.__class__.__name__} index out of range")

        # Seed the global RNG from the index so each item is deterministic
        # across calls, then restore the previous RNG state so callers are
        # unaffected.
        saved_state = torch.get_rng_state()
        torch.manual_seed(index + self.random_offset)
        img = torch.randn(*self.image_size)
        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
        torch.set_rng_state(saved_state)

        # Hand back a PIL image, applying the usual transform hooks.
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target.item()

    def __len__(self) -> int:
        """Return the configured dataset size."""
        return self.size
|
.venv/lib/python3.11/site-packages/torchvision/datasets/fer2013.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import pathlib
|
| 3 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from .utils import check_integrity, verify_str_arg
|
| 9 |
+
from .vision import VisionDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    .. note::
        This dataset can return test labels only if ``fer2013.csv`` OR
        ``icml_face_data.csv`` are present in ``root/fer2013/``. If only
        ``train.csv`` and ``test.csv`` are present, the test labels are set to
        ``None``.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``root/fer2013`` exists. This directory may contain either
            ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
            ``test.csv``. Precedence is given in that order, i.e. if
            ``fer2013.csv`` is present then the rest of the files will be
            ignored. All these (combinations of) files contain the same data and
            are supported for convenience, but only ``fer2013.csv`` and
            ``icml_face_data.csv`` are able to return non-None test labels.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # Map of logical resource name -> (filename, expected md5).
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
        # The fer2013.csv and icml_face_data.csv files contain both train and
        # tests instances, and unlike test.csv they contain the labels for the
        # test instances. We give these 2 files precedence over train.csv and
        # test.csv. And yes, they both contain the same data, but with different
        # column names (note the spaces) and ordering:
        # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
        # ==> fer2013.csv <==
        # emotion,pixels,Usage
        #
        # ==> icml_face_data.csv <==
        # emotion, Usage, pixels
        #
        # ==> train.csv <==
        # emotion,pixels
        #
        # ==> test.csv <==
        # pixels
        "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
        "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
    }

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Validate split before touching the filesystem.
        self._split = verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)

        base_folder = pathlib.Path(self.root) / "fer2013"
        # File precedence: fer2013.csv > icml_face_data.csv > train.csv/test.csv.
        use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
        use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
        file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
        data_file = base_folder / file_name
        if not check_integrity(str(data_file), md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # icml_face_data.csv uses column names with a leading space.
        pixels_key = " pixels" if use_icml_file else "pixels"
        usage_key = " Usage" if use_icml_file else "Usage"

        def get_img(row):
            # Pixels are a space-separated string of 48*48 grayscale byte values.
            return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)

        def get_label(row):
            if use_fer_file or use_icml_file or self._split == "train":
                return int(row["emotion"])
            else:
                # test.csv ships without labels (see class-level note).
                return None

        with open(data_file, "r", newline="") as file:
            rows = (row for row in csv.DictReader(file))

            if use_fer_file or use_icml_file:
                # The combined files tag every row with a usage column; filter
                # it down to the requested split.
                valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
                rows = (row for row in rows if row[usage_key] in valid_keys)

            # NOTE: the list must be materialized here, while the file is still
            # open, because `rows` is a lazy generator over the open handle.
            self._samples = [(get_img(row), get_label(row)) for row in rows]

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for index ``idx``; target may be ``None`` (see class note)."""
        image_tensor, target = self._samples[idx]
        image = Image.fromarray(image_tensor.numpy())

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def extra_repr(self) -> str:
        return f"split={self._split}"
|
.venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import PIL.Image
|
| 8 |
+
|
| 9 |
+
from .utils import download_and_extract_archive, verify_str_arg
|
| 10 |
+
from .vision import VisionDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FGVCAircraft(VisionDataset):
    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    The dataset contains 10,000 images of aircraft, with 100 images for each of 100
    different aircraft model variants, most of which are airplanes.
    Aircraft models are organized in a three-levels hierarchy. The three levels, from
    finer to coarser, are:

    - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
      indistinguishable into one class. The dataset comprises 100 different variants.
    - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
    - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.

    Args:
        root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset.
        split (string, optional): The dataset split, supports ``train``, ``val``,
            ``trainval`` and ``test``.
        annotation_level (str, optional): The annotation level, supports ``variant``,
            ``family`` and ``manufacturer``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "trainval",
        annotation_level: str = "variant",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
        self._annotation_level = verify_str_arg(
            annotation_level, "annotation_level", ("variant", "family", "manufacturer")
        )

        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each annotation level ships its class list in a dedicated text file.
        level_to_file = {
            "variant": "variants.txt",
            "family": "families.txt",
            "manufacturer": "manufacturers.txt",
        }
        annotation_file = os.path.join(self._data_path, "data", level_to_file[self._annotation_level])
        with open(annotation_file, "r") as fh:
            self.classes = [row.strip() for row in fh]

        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        image_data_folder = os.path.join(self._data_path, "data", "images")
        labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")

        # Each line is "<image id> <class name>"; class names may contain spaces,
        # hence the single-split on the first space only.
        with open(labels_file, "r") as fh:
            pairs = [entry.strip().split(" ", 1) for entry in fh]
        self._image_files = [os.path.join(image_data_folder, f"{image_name}.jpg") for image_name, _ in pairs]
        self._labels = [self.class_to_idx[label_name] for _, label_name in pairs]

    def __len__(self) -> int:
        """Number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, label)`` for index ``idx``, with transforms applied."""
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _download(self) -> None:
        """Download the FGVC Aircraft archive into root unless already extracted."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, self.root)

    def _check_exists(self) -> bool:
        # isdir is False for missing paths, so it covers the existence check too.
        return os.path.isdir(self._data_path)
|
.venv/lib/python3.11/site-packages/torchvision/datasets/flickr.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from html.parser import HTMLParser
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .vision import VisionDataset
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page."""

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # Maps a resolved image path to the list of captions seen for it.
        self.annotations: Dict[str, List[str]] = {}

        # Parser state: whether we are inside the caption table, the tag
        # currently open, and the image the upcoming captions belong to.
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        if tag == "table":
            self.in_table = True
        self.current_tag = tag

    def handle_endtag(self, tag: str) -> None:
        if tag == "table":
            self.in_table = False
        self.current_tag = None

    def handle_data(self, data: str) -> None:
        # Only text inside the caption table is meaningful.
        if not self.in_table:
            return
        if data == "Image Not Found":
            self.current_img = None
        elif self.current_tag == "a":
            # Anchor text looks like a URL whose second-to-last path segment
            # is the image id; resolve it to the actual file on disk.
            page_id = data.split("/")[-2]
            pattern = os.path.join(self.root, page_id + "_*.jpg")
            resolved = glob.glob(pattern)[0]
            self.current_img = resolved
            self.annotations[resolved] = []
        elif self.current_tag == "li" and self.current_img:
            self.annotations[self.current_img].append(data.strip())
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Parse the HTML annotation page into {image path: [captions]}.
        parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            parser.feed(fh.read())
        self.annotations = parser.annotations

        self.ids = sorted(self.annotations.keys())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        pil_img = Image.open(img_id).convert("RGB")
        img = self.transform(pil_img) if self.transform is not None else pil_img

        captions = self.annotations[img_id]
        target = self.target_transform(captions) if self.target_transform is not None else captions

        return img, target

    def __len__(self) -> int:
        """Number of annotated images."""
        return len(self.ids)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        # CONSISTENCY FIX: the annotation was `str`, contradicting the docstring
        # ("str or ``pathlib.Path``") and the sibling Flickr8k signature.
        # Widening it is backward compatible.
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and store in a dict mapping image id -> captions.
        # Each line is "<key>\t<caption>"; the key's last two characters are a
        # per-image caption suffix (presumably "#<n>" -- verify against the
        # token file format), so stripping them groups captions by image.
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                img_id, caption = line.strip().split("\t")
                self.annotations[img_id[:-2]].append(caption)

        self.ids = list(sorted(self.annotations.keys()))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Number of annotated images."""
        return len(self.ids)
|
.venv/lib/python3.11/site-packages/torchvision/datasets/flowers102.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import PIL.Image
|
| 5 |
+
|
| 6 |
+
from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
|
| 7 |
+
from .vision import VisionDataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Flowers102(VisionDataset):
    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
    between 40 and 258 images.

    The images have large scale, pose and light variations. In addition, there are categories that
    have large variations within the category, and several very similar categories.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    # Per-artifact (filename, md5) pairs, shared by download() and _check_integrity().
    _file_dict = {  # filename, md5
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Maps the public split name to the corresponding key inside setid.mat.
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Deferred import: scipy is an optional dependency, needed only to parse the .mat files.
        from scipy.io import loadmat

        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
        image_ids = set_ids[self._splits_map[self._split]].tolist()

        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
        # Image ids in the .mat files are 1-based; the raw labels are shifted to 0-based class indices.
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))

        self._labels = []
        self._image_files = []
        for image_id in image_ids:
            self._labels.append(image_id_to_label[image_id])
            self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``idx``, with any transforms applied."""
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_integrity(self) -> bool:
        """Return True if the image folder exists and both .mat metadata files pass their md5 checks."""
        if not (self._images_folder.exists() and self._images_folder.is_dir()):
            return False

        # NOTE: loop variable renamed from ``id`` to avoid shadowing the builtin.
        for key in ("label", "setid"):
            filename, md5 = self._file_dict[key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self) -> None:
        """Download and extract the image archive plus metadata files; no-op if already present and valid."""
        if self._check_integrity():
            return
        download_and_extract_archive(
            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
            str(self._base_folder),
            md5=self._file_dict["image"][1],
        )
        for key in ("label", "setid"):
            filename, md5 = self._file_dict[key]
            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
|
.venv/lib/python3.11/site-packages/torchvision/datasets/food101.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import PIL.Image
|
| 6 |
+
|
| 7 |
+
from .utils import download_and_extract_archive, verify_str_arg
|
| 8 |
+
from .vision import VisionDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
    For each class, 250 manually reviewed test images are provided as well as 750 training images.
    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
    rescaled to have a maximum side length of 512 pixels.


    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._labels = []
        self._image_files = []
        # meta/{split}.json maps class name -> list of image paths relative to the images folder.
        with open(self._meta_folder / f"{self._split}.json") as f:
            metadata = json.load(f)

        self.classes = sorted(metadata.keys())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        for class_label, im_rel_paths in metadata.items():
            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
            # Relative paths in the metadata use "/" separators; split so joinpath is OS-independent.
            self._image_files += [
                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
            ]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``idx``, with any transforms applied."""
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True if both the meta and images folders exist (no checksum verification)."""
        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive into ``root``; no-op if the folders already exist."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
|