koichi12 commited on
Commit
208efc9
·
verified ·
1 Parent(s): 04b7ba0

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py +146 -0
  2. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/celeba.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cifar.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py +1224 -0
  37. .venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py +242 -0
  38. .venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py +194 -0
  39. .venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py +168 -0
  40. .venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py +222 -0
  41. .venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py +88 -0
  42. .venv/lib/python3.11/site-packages/torchvision/datasets/coco.py +109 -0
  43. .venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py +100 -0
  44. .venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py +62 -0
  45. .venv/lib/python3.11/site-packages/torchvision/datasets/fakedata.py +67 -0
  46. .venv/lib/python3.11/site-packages/torchvision/datasets/fer2013.py +120 -0
  47. .venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py +115 -0
  48. .venv/lib/python3.11/site-packages/torchvision/datasets/flickr.py +167 -0
  49. .venv/lib/python3.11/site-packages/torchvision/datasets/flowers102.py +114 -0
  50. .venv/lib/python3.11/site-packages/torchvision/datasets/food101.py +93 -0
.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
2
+ from ._stereo_matching import (
3
+ CarlaStereo,
4
+ CREStereo,
5
+ ETH3DStereo,
6
+ FallingThingsStereo,
7
+ InStereo2k,
8
+ Kitti2012Stereo,
9
+ Kitti2015Stereo,
10
+ Middlebury2014Stereo,
11
+ SceneFlowStereo,
12
+ SintelStereo,
13
+ )
14
+ from .caltech import Caltech101, Caltech256
15
+ from .celeba import CelebA
16
+ from .cifar import CIFAR10, CIFAR100
17
+ from .cityscapes import Cityscapes
18
+ from .clevr import CLEVRClassification
19
+ from .coco import CocoCaptions, CocoDetection
20
+ from .country211 import Country211
21
+ from .dtd import DTD
22
+ from .eurosat import EuroSAT
23
+ from .fakedata import FakeData
24
+ from .fer2013 import FER2013
25
+ from .fgvc_aircraft import FGVCAircraft
26
+ from .flickr import Flickr30k, Flickr8k
27
+ from .flowers102 import Flowers102
28
+ from .folder import DatasetFolder, ImageFolder
29
+ from .food101 import Food101
30
+ from .gtsrb import GTSRB
31
+ from .hmdb51 import HMDB51
32
+ from .imagenet import ImageNet
33
+ from .imagenette import Imagenette
34
+ from .inaturalist import INaturalist
35
+ from .kinetics import Kinetics
36
+ from .kitti import Kitti
37
+ from .lfw import LFWPairs, LFWPeople
38
+ from .lsun import LSUN, LSUNClass
39
+ from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
40
+ from .moving_mnist import MovingMNIST
41
+ from .omniglot import Omniglot
42
+ from .oxford_iiit_pet import OxfordIIITPet
43
+ from .pcam import PCAM
44
+ from .phototour import PhotoTour
45
+ from .places365 import Places365
46
+ from .rendered_sst2 import RenderedSST2
47
+ from .sbd import SBDataset
48
+ from .sbu import SBU
49
+ from .semeion import SEMEION
50
+ from .stanford_cars import StanfordCars
51
+ from .stl10 import STL10
52
+ from .sun397 import SUN397
53
+ from .svhn import SVHN
54
+ from .ucf101 import UCF101
55
+ from .usps import USPS
56
+ from .vision import VisionDataset
57
+ from .voc import VOCDetection, VOCSegmentation
58
+ from .widerface import WIDERFace
59
+
60
# Public API of ``torchvision.datasets``: the names exported by
# ``from torchvision.datasets import *``.  Note that
# ``wrap_dataset_for_transforms_v2`` is not imported above — it is resolved
# lazily by the module-level ``__getattr__`` below.
__all__ = (
    "LSUN",
    "LSUNClass",
    "ImageFolder",
    "DatasetFolder",
    "FakeData",
    "CocoCaptions",
    "CocoDetection",
    "CIFAR10",
    "CIFAR100",
    "EMNIST",
    "FashionMNIST",
    "QMNIST",
    "MNIST",
    "KMNIST",
    "StanfordCars",
    "STL10",
    "SUN397",
    "SVHN",
    "PhotoTour",
    "SEMEION",
    "Omniglot",
    "SBU",
    "Flickr8k",
    "Flickr30k",
    "Flowers102",
    "VOCSegmentation",
    "VOCDetection",
    "Cityscapes",
    "ImageNet",
    "Caltech101",
    "Caltech256",
    "CelebA",
    "WIDERFace",
    "SBDataset",
    "VisionDataset",
    "USPS",
    "Kinetics",
    "HMDB51",
    "UCF101",
    "Places365",
    "Kitti",
    "INaturalist",
    "LFWPeople",
    "LFWPairs",
    "KittiFlow",
    "Sintel",
    "FlyingChairs",
    "FlyingThings3D",
    "HD1K",
    "Food101",
    "DTD",
    "FER2013",
    "GTSRB",
    "CLEVRClassification",
    "OxfordIIITPet",
    "PCAM",
    "Country211",
    "FGVCAircraft",
    "EuroSAT",
    "RenderedSST2",
    "Kitti2012Stereo",
    "Kitti2015Stereo",
    "CarlaStereo",
    "Middlebury2014Stereo",
    "CREStereo",
    "FallingThingsStereo",
    "SceneFlowStereo",
    "SintelStereo",
    "InStereo2k",
    "ETH3DStereo",
    "wrap_dataset_for_transforms_v2",
    "Imagenette",
)
134
+
135
+
136
# Lazily resolve ``wrap_dataset_for_transforms_v2`` on module attribute
# access so that
#   from torchvision.datasets import wrap_dataset_for_transforms_v2
# works without a cyclic import against ``torchvision.tv_tensors``.
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
    if name == "wrap_dataset_for_transforms_v2":
        # Deferred import: pulling this in at module load time would
        # create an import cycle.
        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2

        return wrap_dataset_for_transforms_v2

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_optical_flow.cpython-311.pyc ADDED
Binary file (27.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/_stereo_matching.cpython-311.pyc ADDED
Binary file (59.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/celeba.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cifar.cpython-311.pyc ADDED
Binary file (9.05 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/cityscapes.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/clevr.cpython-311.pyc ADDED
Binary file (6.58 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/dtd.cpython-311.pyc ADDED
Binary file (7.21 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/eurosat.cpython-311.pyc ADDED
Binary file (4 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fer2013.cpython-311.pyc ADDED
Binary file (7.66 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/fgvc_aircraft.cpython-311.pyc ADDED
Binary file (7.66 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/flowers102.cpython-311.pyc ADDED
Binary file (7.26 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/folder.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/food101.cpython-311.pyc ADDED
Binary file (6.97 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/gtsrb.cpython-311.pyc ADDED
Binary file (6.04 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/imagenette.cpython-311.pyc ADDED
Binary file (7.02 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/inaturalist.cpython-311.pyc ADDED
Binary file (14.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/kitti.cpython-311.pyc ADDED
Binary file (9.38 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lfw.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/lsun.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/moving_mnist.cpython-311.pyc ADDED
Binary file (6.03 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/omniglot.cpython-311.pyc ADDED
Binary file (7.32 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/oxford_iiit_pet.cpython-311.pyc ADDED
Binary file (10 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/pcam.cpython-311.pyc ADDED
Binary file (7.86 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/phototour.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/places365.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/rendered_sst2.cpython-311.pyc ADDED
Binary file (5.83 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/sbd.cpython-311.pyc ADDED
Binary file (9.53 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/semeion.cpython-311.pyc ADDED
Binary file (5.17 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stanford_cars.cpython-311.pyc ADDED
Binary file (6.91 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/stl10.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/svhn.cpython-311.pyc ADDED
Binary file (6.68 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/ucf101.cpython-311.pyc ADDED
Binary file (8.49 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/usps.cpython-311.pyc ADDED
Binary file (6.15 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/__pycache__/vision.cpython-311.pyc ADDED
Binary file (7.58 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/datasets/_stereo_matching.py ADDED
@@ -0,0 +1,1224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import json
3
+ import os
4
+ import random
5
+ import shutil
6
+ from abc import ABC, abstractmethod
7
+ from glob import glob
8
+ from pathlib import Path
9
+ from typing import Callable, cast, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ from PIL import Image
13
+
14
+ from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
15
+ from .vision import VisionDataset
16
+
17
# Sample type including a validity mask:
# (left image, right image, disparity map, valid mask).
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
# Sample type without a validity mask:
# (left image, right image, disparity map).
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]

# Nothing is public here; the dataset classes are re-exported by
# ``torchvision.datasets.__init__``.
__all__ = ()

# Reader for ``.pfm`` disparity files that keeps a single channel.
_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
23
+
24
+
25
class StereoMatchingDataset(ABC, VisionDataset):
    """Base interface for Stereo matching datasets"""

    # Subclasses whose samples always carry a validity mask set this to True;
    # it forces __getitem__ to return the 4-tuple variant.
    _has_built_in_disparity_mask = False

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
        """
        Args:
            root(str): Root directory of the dataset.
            transforms(callable, optional): A function/transform that takes in Tuples of
                (images, disparities, valid_masks) and returns a transformed version of each of them.
                images is a Tuple of (``PIL.Image``, ``PIL.Image``)
                disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
                valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
                In some cases, when a dataset does not provide disparities, the ``disparities`` and
                ``valid_masks`` can be Tuples containing None values.
                For training splits generally the datasets provide a minimal guarantee of
                images: (``PIL.Image``, ``PIL.Image``)
                disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
                Optionally, based on the dataset, it can return a ``mask`` as well:
                valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
                For some test splits, the datasets provide outputs that look like:
                images: (``PIL.Image``, ``PIL.Image``)
                disparities: (``None``, ``None``)
                Optionally, based on the dataset, it can return a ``mask`` as well:
                valid_masks: (``None``, ``None``)
        """
        super().__init__(root=root)
        self.transforms = transforms

        # Populated by subclasses with (left, right) path pairs.
        self._images = []  # type: ignore
        self._disparities = []  # type: ignore

    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
        """Open ``file_path`` and return it as an RGB ``PIL.Image``."""
        img = Image.open(file_path)
        if img.mode != "RGB":
            img = img.convert("RGB")  # type: ignore [assignment]
        return img

    def _scan_pairs(
        self,
        paths_left_pattern: str,
        paths_right_pattern: Optional[str] = None,
    ) -> List[Tuple[str, Optional[str]]]:
        """Glob both patterns and pair the sorted results positionally.

        When ``paths_right_pattern`` is falsy, the right-hand entries are
        ``None`` (used by datasets that only provide left-view files).

        Raises:
            FileNotFoundError: If a pattern matches no files.
            ValueError: If the two patterns match different numbers of files.
        """
        # sorted() already returns a list; no need to wrap it in list().
        left_paths = sorted(glob(paths_left_pattern))

        right_paths: List[Union[None, str]]
        if paths_right_pattern:
            right_paths = sorted(glob(paths_right_pattern))
        else:
            right_paths = [None] * len(left_paths)

        if not left_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")

        if not right_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")

        if len(left_paths) != len(right_paths):
            raise ValueError(
                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
                f"left pattern: {paths_left_pattern}\n"
                f"right pattern: {paths_right_pattern}\n"
            )

        return list(zip(left_paths, right_paths))

    @abstractmethod
    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        # function that returns a disparity map and an occlusion map
        pass

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
                can be a numpy boolean mask of shape (H, W) if the dataset provides a file
                indicating which disparity pixels are valid. The disparity is a numpy array of
                shape (1, H, W) and the images are PIL images. ``disparity`` is None for
                datasets on which for ``split="test"`` the authors did not provide annotations.
        """
        img_left = self._read_img(self._images[index][0])
        img_right = self._read_img(self._images[index][1])

        dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
        dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])

        imgs = (img_left, img_right)
        dsp_maps = (dsp_map_left, dsp_map_right)
        valid_masks = (valid_mask_left, valid_mask_right)

        if self.transforms is not None:
            (
                imgs,
                dsp_maps,
                valid_masks,
            ) = self.transforms(imgs, dsp_maps, valid_masks)

        # Return the 4-tuple when the dataset (or a transform) supplies a mask.
        if self._has_built_in_disparity_mask or valid_masks[0] is not None:
            return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
        else:
            return imgs[0], imgs[1], dsp_maps[0]

    def __len__(self) -> int:
        return len(self._images)
136
+
137
+
138
class CarlaStereo(StereoMatchingDataset):
    """
    Carla simulator data linked in the `CREStereo github repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            carla-highres
                trainingF
                    scene1
                        im0.png
                        im1.png
                        disp0GT.pfm
                        disp1GT.pfm
                        calib.txt
                    scene2
                        im0.png
                        im1.png
                        disp0GT.pfm
                        disp1GT.pfm
                        calib.txt
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "carla-highres"

        # Images and disparities live side by side in each scene directory.
        # NOTE: file names are im0.png/im1.png (the docstring above matches).
        left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
        right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
        imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
        self._images = imgs

        left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
        right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
        disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
        self._disparities = disparities

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Read a ``.pfm`` disparity map; CARLA provides no validity mask."""
        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
                The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
                If a ``valid_mask`` is generated within the ``transforms`` parameter,
                a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return cast(T1, super().__getitem__(index))
200
+
201
+
202
class Kitti2012Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the `2012 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php>`_.
    Uses the RGB images for consistency with KITTI 2015.

    Expected directory layout: ::

        root
            Kitti2012
                testing
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                training
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                    disp_noc
                        1.png
                        2.png
                        ...
                    calib

    Args:
        root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # Samples always reserve the valid-mask slot, so __getitem__ yields 4-tuples.
    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # The on-disk folders are named "training" / "testing".
        base = Path(root) / "Kitti2012" / (split + "ing")

        self._images = self._scan_pairs(
            str(base / "colored_0" / "*_10.png"),
            str(base / "colored_1" / "*_10.png"),
        )

        if split == "train":
            self._disparities = self._scan_pairs(str(base / "disp_noc" / "*.png"), None)
        else:
            # The test split ships no ground-truth disparities.
            self._disparities = [(None, None)] * len(self._images)

    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
        if file_path is None:
            # Test split: nothing to read.
            return None, None

        # Disparities are stored as PNGs scaled by 256; add a leading
        # channel axis to obtain the (C, H, W) layout.
        disparity = np.asarray(Image.open(file_path)) / 256.0
        return disparity[None, :, :], None

    def __getitem__(self, index: int) -> T1:
        """Fetch the sample at ``index``.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple ``(img_left, img_right, disparity, valid_mask)``.
                The images are PIL images and ``disparity`` is a numpy array of
                shape (1, H, W). ``valid_mask`` stays ``None`` unless the
                ``transforms`` parameter generates one, and both ``disparity``
                and ``valid_mask`` are ``None`` for the test split.
        """
        return cast(T1, super().__getitem__(index))
285
+
286
+
287
class Kitti2015Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the `2015 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php>`_.

    Expected directory layout: ::

        root
            Kitti2015
                testing
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                training
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                    disp_occ_0
                        img1.png
                        img2.png
                        ...
                    disp_occ_1
                        img1.png
                        img2.png
                        ...
                    calib

    Args:
        root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # Samples always reserve the valid-mask slot, so __getitem__ yields 4-tuples.
    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # The on-disk folders are named "training" / "testing".
        root = Path(root) / "Kitti2015" / (split + "ing")

        left_img_pattern = str(root / "image_2" / "*.png")
        right_img_pattern = str(root / "image_3" / "*.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        if split != "train":
            # The test split ships no ground-truth disparities.
            self._disparities = [(None, None)] * len(self._images)
        else:
            self._disparities = self._scan_pairs(
                str(root / "disp_occ_0" / "*.png"),
                str(root / "disp_occ_1" / "*.png"),
            )

    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
        if file_path is None:
            # Test split: nothing to read.
            return None, None

        # Disparities are stored as PNGs scaled by 256; add a leading
        # channel axis to obtain the (C, H, W) layout.
        disparity = np.asarray(Image.open(file_path)) / 256.0
        return disparity[None, :, :], None

    def __getitem__(self, index: int) -> T1:
        """Fetch the sample at ``index``.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple ``(img_left, img_right, disparity, valid_mask)``.
                The images are PIL images and ``disparity`` is a numpy array of
                shape (1, H, W). ``valid_mask`` stays ``None`` unless the
                ``transforms`` parameter generates one, and both ``disparity``
                and ``valid_mask`` are ``None`` for the test split.
        """
        return cast(T1, super().__getitem__(index))
373
+
374
+
375
+ class Middlebury2014Stereo(StereoMatchingDataset):
376
+ """Publicly available scenes from the Middlebury dataset `2014 version <https://vision.middlebury.edu/stereo/data/scenes2014/>`.
377
+
378
+ The dataset mostly follows the original format, without containing the ambient subdirectories. : ::
379
+
380
+ root
381
+ Middlebury2014
382
+ train
383
+ scene1-{perfect,imperfect}
384
+ calib.txt
385
+ im{0,1}.png
386
+ im1E.png
387
+ im1L.png
388
+ disp{0,1}.pfm
389
+ disp{0,1}-n.png
390
+ disp{0,1}-sd.pfm
391
+ disp{0,1}y.pfm
392
+ scene2-{perfect,imperfect}
393
+ calib.txt
394
+ im{0,1}.png
395
+ im1E.png
396
+ im1L.png
397
+ disp{0,1}.pfm
398
+ disp{0,1}-n.png
399
+ disp{0,1}-sd.pfm
400
+ disp{0,1}y.pfm
401
+ ...
402
+ additional
403
+ scene1-{perfect,imperfect}
404
+ calib.txt
405
+ im{0,1}.png
406
+ im1E.png
407
+ im1L.png
408
+ disp{0,1}.pfm
409
+ disp{0,1}-n.png
410
+ disp{0,1}-sd.pfm
411
+ disp{0,1}y.pfm
412
+ ...
413
+ test
414
+ scene1
415
+ calib.txt
416
+ im{0,1}.png
417
+ scene2
418
+ calib.txt
419
+ im{0,1}.png
420
+ ...
421
+
422
+ Args:
423
+ root (str or ``pathlib.Path``): Root directory of the Middleburry 2014 Dataset.
424
+ split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
425
+ use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
426
+ The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
427
+ calibration (string, optional): Whether or not to use the calibrated (default) or uncalibrated scenes.
428
+ transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
429
+ download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
430
+ """
431
+
432
+ splits = {
433
+ "train": [
434
+ "Adirondack",
435
+ "Jadeplant",
436
+ "Motorcycle",
437
+ "Piano",
438
+ "Pipes",
439
+ "Playroom",
440
+ "Playtable",
441
+ "Recycle",
442
+ "Shelves",
443
+ "Vintage",
444
+ ],
445
+ "additional": [
446
+ "Backpack",
447
+ "Bicycle1",
448
+ "Cable",
449
+ "Classroom1",
450
+ "Couch",
451
+ "Flowers",
452
+ "Mask",
453
+ "Shopvac",
454
+ "Sticks",
455
+ "Storage",
456
+ "Sword1",
457
+ "Sword2",
458
+ "Umbrella",
459
+ ],
460
+ "test": [
461
+ "Plants",
462
+ "Classroom2E",
463
+ "Classroom2",
464
+ "Australia",
465
+ "DjembeL",
466
+ "CrusadeP",
467
+ "Crusade",
468
+ "Hoops",
469
+ "Bicycle2",
470
+ "Staircase",
471
+ "Newkuba",
472
+ "AustraliaP",
473
+ "Djembe",
474
+ "Livingroom",
475
+ "Computer",
476
+ ],
477
+ }
478
+
479
+ _has_built_in_disparity_mask = True
480
+
481
+ def __init__(
482
+ self,
483
+ root: Union[str, Path],
484
+ split: str = "train",
485
+ calibration: Optional[str] = "perfect",
486
+ use_ambient_views: bool = False,
487
+ transforms: Optional[Callable] = None,
488
+ download: bool = False,
489
+ ) -> None:
490
+ super().__init__(root, transforms)
491
+
492
+ verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
493
+ self.split = split
494
+
495
+ if calibration:
496
+ verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None)) # type: ignore
497
+ if split == "test":
498
+ raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.")
499
+ else:
500
+ if split != "test":
501
+ raise ValueError(
502
+ f"Split '{split}' has calibration settings, however None was provided as an argument."
503
+ f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.",
504
+ )
505
+
506
+ if download:
507
+ self._download_dataset(root)
508
+
509
+ root = Path(root) / "Middlebury2014"
510
+
511
+ if not os.path.exists(root / split):
512
+ raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
513
+
514
+ split_scenes = self.splits[split]
515
+ # check that the provided root folder contains the scene splits
516
+ if not any(
517
+ # using startswith to account for perfect / imperfect calibrartion
518
+ scene.startswith(s)
519
+ for scene in os.listdir(root / split)
520
+ for s in split_scenes
521
+ ):
522
+ raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")
523
+
524
+ calibrartion_suffixes = {
525
+ None: [""],
526
+ "perfect": ["-perfect"],
527
+ "imperfect": ["-imperfect"],
528
+ "both": ["-perfect", "-imperfect"],
529
+ }[calibration]
530
+
531
+ for calibration_suffix in calibrartion_suffixes:
532
+ scene_pattern = "*" + calibration_suffix
533
+ left_img_pattern = str(root / split / scene_pattern / "im0.png")
534
+ right_img_pattern = str(root / split / scene_pattern / "im1.png")
535
+ self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
536
+
537
+ if split == "test":
538
+ self._disparities = list((None, None) for _ in self._images)
539
+ else:
540
+ left_dispartity_pattern = str(root / split / scene_pattern / "disp0.pfm")
541
+ right_dispartity_pattern = str(root / split / scene_pattern / "disp1.pfm")
542
+ self._disparities += self._scan_pairs(left_dispartity_pattern, right_dispartity_pattern)
543
+
544
+ self.use_ambient_views = use_ambient_views
545
+
546
+ def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
547
+ """
548
+ Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
549
+ When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
550
+ as the right image.
551
+ """
552
+ ambient_file_paths: List[Union[str, Path]] # make mypy happy
553
+
554
+ if not isinstance(file_path, Path):
555
+ file_path = Path(file_path)
556
+
557
+ if file_path.name == "im1.png" and self.use_ambient_views:
558
+ base_path = file_path.parent
559
+ # initialize sampleable container
560
+ ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
561
+ # double check that we're not going to try to read from an invalid file path
562
+ ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
563
+ # keep the original image as an option as well for uniform sampling between base views
564
+ ambient_file_paths.append(file_path)
565
+ file_path = random.choice(ambient_file_paths) # type: ignore
566
+ return super()._read_img(file_path)
567
+
568
+ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
569
+ # test split has not disparity maps
570
+ if file_path is None:
571
+ return None, None
572
+
573
+ disparity_map = _read_pfm_file(file_path)
574
+ disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
575
+ disparity_map[disparity_map == np.inf] = 0 # remove infinite disparities
576
+ valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
577
+ return disparity_map, valid_mask
578
+
579
    def _download_dataset(self, root: Union[str, Path]) -> None:
        """Download the scenes of ``self.split`` into ``root / "Middlebury2014"``.

        Train / additional scenes are fetched as one zip per (scene, calibration)
        pair; the test scenes come from the single MiddEval3 archive and are moved
        into ``root / "Middlebury2014" / "test"``.
        """
        base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
        # train and additional splits have 2 different calibration settings
        root = Path(root) / "Middlebury2014"
        split_name = self.split

        if split_name != "test":
            for split_scene in self.splits[split_name]:
                split_root = root / split_name
                for calibration in ["perfect", "imperfect"]:
                    scene_name = f"{split_scene}-{calibration}"
                    scene_url = f"{base_url}/{scene_name}.zip"
                    print(f"Downloading {scene_url}")
                    # download the scene only if it doesn't exist
                    if not (split_root / scene_name).exists():
                        download_and_extract_archive(
                            url=scene_url,
                            filename=f"{scene_name}.zip",
                            download_root=str(split_root),
                            remove_finished=True,
                        )
        else:
            # NOTE(review): this raises FileExistsError when `root / "test"` already
            # exists (no exist_ok) — confirm callers always start from a fresh root.
            os.makedirs(root / "test")
            if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
                # test split is downloaded from a different location
                test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
                # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
                # we want to move the contents from testF into the directory
                download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
                for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
                    for scene in scene_names:
                        scene_dst_dir = root / "test"
                        scene_src_dir = Path(scene_dir) / scene
                        os.makedirs(scene_dst_dir, exist_ok=True)
                        shutil.move(str(scene_src_dir), str(scene_dst_dir))

                # cleanup MiddEval3 directory
                shutil.rmtree(str(root / "MiddEval3"))
617
+
618
    def __getitem__(self, index: int) -> T2:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` for ``split="test"``.
        """
        return cast(T2, super().__getitem__(index))
630
+
631
+
632
class CREStereo(StereoMatchingDataset):
    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            CREStereo
                tree
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    img2_left.jpg
                    img2_right.jpg
                    img2_left.disp.png
                    img2_right.disp.png
                    ...
                shapenet
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                reflective
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                hole
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...

    Args:
        root (str): Root directory of the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # NOTE(review): read by the StereoMatchingDataset base class — presumably marks that
    # this dataset supplies its own disparity/valid-mask handling; confirm in base class.
    _has_built_in_disparity_mask = True

    def __init__(
        self,
        root: Union[str, Path],
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "CREStereo"

        # the four synthetic scene categories shipped with the dataset
        dirs = ["shapenet", "reflective", "tree", "hole"]

        for s in dirs:
            left_image_pattern = str(root / s / "*_left.jpg")
            right_image_pattern = str(root / s / "*_right.jpg")
            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
            self._images += imgs

            left_disparity_pattern = str(root / s / "*_left.disp.png")
            right_disparity_pattern = str(root / s / "*_right.disp.png")
            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Decode a disparity PNG into a (1, H, W) float32 map (stored values are scaled by 32)."""
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze the disparity map into (C, H, W) format
        disparity_map = disparity_map[None, :, :] / 32.0
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
        """
        return cast(T1, super().__getitem__(index))
718
+
719
+
720
class FallingThingsStereo(StereoMatchingDataset):
    """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.

    The dataset is expected to have the following structure: ::

        root
            FallingThings
                single
                    dir1
                        scene1
                            _object_settings.json
                            _camera_settings.json
                            image1.left.depth.png
                            image1.right.depth.png
                            image1.left.jpg
                            image1.right.jpg
                            image2.left.depth.png
                            image2.right.depth.png
                            image2.left.jpg
                            image2.right.jpg
                            ...
                        scene2
                    ...
                mixed
                    scene1
                        _object_settings.json
                        _camera_settings.json
                        image1.left.depth.png
                        image1.right.depth.png
                        image1.left.jpg
                        image1.right.jpg
                        image2.left.depth.png
                        image2.right.depth.png
                        image2.left.jpg
                        image2.right.jpg
                        ...
                    scene2
                ...

    Args:
        root (str or ``pathlib.Path``): Root directory where FallingThings is located.
        variant (string): Which variant to use. Either "single", "mixed", or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "FallingThings"

        verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both"))

        # "both" expands to the two concrete variants
        selected_variants = ["single", "mixed"] if variant == "both" else [variant]

        # "single" nests its scenes one directory level deeper than "mixed"
        split_prefix = {
            "single": Path("*") / "*",
            "mixed": Path("*"),
        }

        for name in selected_variants:
            scene_glob = root / name / split_prefix[name]
            self._images += self._scan_pairs(str(scene_glob / "*.left.jpg"), str(scene_glob / "*.right.jpg"))
            self._disparities += self._scan_pairs(
                str(scene_glob / "*.left.depth.png"), str(scene_glob / "*.right.depth.png")
            )

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Convert a stored (H, W) depth map into a (1, H, W) float32 disparity map."""
        depth = np.asarray(Image.open(file_path))
        # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
        # disparity is recovered by inverting depth = (baseline * focal) / (disparity * pixel_constant)
        camera_settings_path = Path(file_path).parent / "_camera_settings.json"
        with open(camera_settings_path, "r") as f:
            intrinsics = json.load(f)
        focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
        baseline, pixel_constant = 6, 100  # pixel constant is inverted
        disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32)
        # unsqueeze disparity to (C, H, W)
        return disparity_map[None, :, :], None

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return cast(T1, super().__getitem__(index))
822
+
823
+
824
class SceneFlowStereo(StereoMatchingDataset):
    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
    This interface provides access to the ``FlyingThings3D``, ``Monkaa`` and ``Driving`` datasets.

    The dataset is expected to have the following structure: ::

        root
            SceneFlow
                Monkaa
                    frames_cleanpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        scene2
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                    frames_finalpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        ...
                        ...
                    disparity
                        scene1
                            left
                                img1.pfm
                                img2.pfm
                            right
                                img1.pfm
                                img2.pfm
                FlyingThings3D
                    ...
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.

    """

    def __init__(
        self,
        root: Union[str, Path],
        variant: str = "FlyingThings3D",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms)

        root = Path(root) / "SceneFlow"

        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))

        # map the user-facing pass name onto the on-disk directory names
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        root = root / variant

        # each variant nests its scenes at a different depth below the pass directory
        prefix_directories = {
            "Monkaa": Path("*"),
            "FlyingThings3D": Path("*") / "*" / "*",
            "Driving": Path("*") / "*" / "*",
        }

        for p in passes:
            left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
            right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)

            # disparities are shared between passes, so they are re-scanned per pass to
            # keep `_disparities` aligned one-to-one with `_images`
            left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
            right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Decode a PFM file into a positive disparity map; no validity mask is provided."""
        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return cast(T1, super().__getitem__(index))
934
+
935
+
936
class SintelStereo(StereoMatchingDataset):
    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                training
                    final_left
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    final_right
                        scene2
                            img1.png
                            img2.png
                            ...
                        ...
                    disparities
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    occlusions
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...
                    outofframe
                        scene1
                            img1.png
                            img2.png
                            ...
                        ...

    Args:
        root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
        pass_name (string): The name of the pass to use, either "final", "clean" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # the occlusion / out-of-frame maps shipped with the dataset are combined into a valid mask
    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))

        root = Path(root) / "Sintel"
        # map the user-facing pass name onto the on-disk directory prefixes
        pass_names = {
            "final": ["final"],
            "clean": ["clean"],
            "both": ["final", "clean"],
        }[pass_name]

        for p in pass_names:
            left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png")
            right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png")
            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)

            # only left-view ground truth exists; scanning per pass keeps
            # `_disparities` aligned one-to-one with `_images`
            disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
            self._disparities += self._scan_pairs(disparity_pattern, None)

    def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
        # helper function to get the occlusion mask paths
        # a path will look like .../.../.../training/disparities/scene1/img1.png
        # we want to get something like .../.../.../training/occlusions/scene1/img1.png
        fpath = Path(file_path)
        basename = fpath.name
        scenedir = fpath.parent
        # the parent of the scenedir is actually the disparity dir
        sampledir = scenedir.parent.parent

        occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
        outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)

        if not os.path.exists(occlusion_path):
            raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")

        if not os.path.exists(outofframe_path):
            raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")

        return occlusion_path, outofframe_path

    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
        """Decode an RGB-encoded disparity PNG and build a valid mask from the occlusion maps."""
        if file_path is None:
            return None, None

        # disparity decoding as per Sintel instructions in the README provided with the dataset
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        r, g, b = np.split(disparity_map, 3, axis=-1)
        disparity_map = r * 4 + g / (2**6) + b / (2**14)
        # reshape into (C, H, W) format
        disparity_map = np.transpose(disparity_map, (2, 0, 1))
        # find the appropriate file paths
        occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
        # occlusion masks: a zero pixel means "not occluded"
        valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
        # out of frame masks: a zero pixel means "inside the frame"
        off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
        # a pixel is valid only if it is neither occluded nor out of frame
        valid_mask = np.logical_and(off_mask, valid_mask)
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T2:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
            the valid_mask is a numpy array of shape (H, W).
        """
        return cast(T2, super().__getitem__(index))
1056
+
1057
+
1058
class InStereo2k(StereoMatchingDataset):
    """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.

    The dataset is expected to have the following structure: ::

        root
            InStereo2k
                train
                    scene1
                        left.png
                        right.png
                        left_disp.png
                        right_disp.png
                        ...
                    scene2
                ...
                test
                    scene1
                        left.png
                        right.png
                        left_disp.png
                        right_disp.png
                        ...
                    scene2
                ...

    Args:
        root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
        split (string): Either "train" or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        # validate the split before it is used to build any filesystem paths
        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "InStereo2k" / split

        left_img_pattern = str(root / "*" / "left.png")
        right_img_pattern = str(root / "*" / "right.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        left_disparity_pattern = str(root / "*" / "left_disp.png")
        right_disparity_pattern = str(root / "*" / "right_disp.png")
        self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Decode a disparity PNG into a (1, H, W) float32 map; no validity mask is provided."""
        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
        # unsqueeze disparity to (C, H, W); stored values are scaled by 1024
        disparity_map = disparity_map[None, :, :] / 1024.0
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T1:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return cast(T1, super().__getitem__(index))
1125
+
1126
+
1127
class ETH3DStereo(StereoMatchingDataset):
    """ETH3D `Low-Res Two-View <https://www.eth3d.net/datasets>`_ dataset.

    The dataset is expected to have the following structure: ::

        root
            ETH3D
                two_view_training
                    scene1
                        im1.png
                        im0.png
                        images.txt
                        cameras.txt
                        calib.txt
                    scene2
                        im1.png
                        im0.png
                        images.txt
                        cameras.txt
                        calib.txt
                    ...
                two_view_training_gt
                    scene1
                        disp0GT.pfm
                        mask0nocc.png
                    scene2
                        disp0GT.pfm
                        mask0nocc.png
                    ...
                two_view_test
                    scene1
                        im1.png
                        im0.png
                        images.txt
                        cameras.txt
                        calib.txt
                    scene2
                        im1.png
                        im0.png
                        images.txt
                        cameras.txt
                        calib.txt
                    ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    # valid masks are derived from the dataset's own `mask0nocc.png` occlusion maps
    _has_built_in_disparity_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "ETH3D"

        img_dir = "two_view_training" if split == "train" else "two_view_test"
        anot_dir = "two_view_training_gt"

        left_img_pattern = str(root / img_dir / "*" / "im0.png")
        right_img_pattern = str(root / img_dir / "*" / "im1.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        if split == "test":
            # the test split has no ground truth at all
            self._disparities = list((None, None) for _ in self._images)
        else:
            disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
            self._disparities = self._scan_pairs(disparity_pattern, None)

    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
        """Decode a PFM disparity map and the matching non-occlusion mask."""
        # test split has no disparity maps
        if file_path is None:
            return None, None

        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        # the mask lives next to the disparity file in the ground-truth directory
        mask_path = Path(file_path).parent / "mask0nocc.png"
        valid_mask = Image.open(mask_path)
        valid_mask = np.asarray(valid_mask).astype(bool)
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> T2:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        return cast(T2, super().__getitem__(index))
.venv/lib/python3.11/site-packages/torchvision/datasets/caltech.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+ from pathlib import Path
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ from PIL import Image
7
+
8
+ from .utils import download_and_extract_archive, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+
12
class Caltech101(VisionDataset):
    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified
            target types. ``category`` represents the target class, and
            ``annotation`` is a list of points from a hand-generated outline.
            Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    .. warning::

        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    def __init__(
        self,
        root: Union[str, Path],
        target_type: Union[List[str], str] = "category",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)
        # a single string is normalized to a one-element list of target types
        if isinstance(target_type, str):
            target_type = [target_type]
        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {
            "Faces": "Faces_2",
            "Faces_easy": "Faces_3",
            "Motorbikes": "Motorbikes_16",
            "airplanes": "Airplanes_Side_2",
        }
        self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))

        # `index[k]` is the 1-based image number within its category and `y[k]` the
        # category label for the k-th sample of the flattened dataset
        self.index: List[int] = []
        self.y: List[int] = []
        for (i, c) in enumerate(self.categories):
            n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """
        # imported lazily so that scipy is only required when samples are accessed
        import scipy.io

        img = Image.open(
            os.path.join(
                self.root,
                "101_ObjectCategories",
                self.categories[self.y[index]],
                f"image_{self.index[index]:04d}.jpg",
            )
        )

        # collect one target per requested target type, in the order they were given
        target: Any = []
        for t in self.target_type:
            if t == "category":
                target.append(self.y[index])
            elif t == "annotation":
                data = scipy.io.loadmat(
                    os.path.join(
                        self.root,
                        "Annotations",
                        self.annotation_categories[self.y[index]],
                        f"annotation_{self.index[index]:04d}.mat",
                    )
                )
                target.append(data["obj_contour"])
        # a single target type is returned unwrapped rather than as a 1-tuple
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self) -> int:
        return len(self.index)

    def download(self) -> None:
        """Download and extract the images and annotations archives from Google Drive."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
            filename="101_ObjectCategories.tar.gz",
            md5="b224c7392d521a49829488ab0f1120d9",
        )
        download_and_extract_archive(
            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
            self.root,
            filename="Annotations.tar",
            md5="6f83eeb1f24d99cab4eb377263132c91",
        )

    def extra_repr(self) -> str:
        return "Target type: {target_type}".format(**self.__dict__)
151
+
152
+
153
class Caltech256(VisionDataset):
    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
        os.makedirs(self.root, exist_ok=True)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
        # Flat sample table: self.index holds the per-category 1-based image number,
        # self.y the matching category label — one entry each per image.
        self.index: List[int] = []
        self.y = []
        for label, category in enumerate(self.categories):
            category_dir = os.path.join(self.root, "256_ObjectCategories", category)
            num_images = len([entry for entry in os.listdir(category_dir) if entry.endswith(".jpg")])
            self.index.extend(range(1, num_images + 1))
            self.y.extend(num_images * [label])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        label = self.y[index]
        # File names look like "<category number>_<image number>.jpg", both 1-based.
        img = Image.open(
            os.path.join(
                self.root,
                "256_ObjectCategories",
                self.categories[label],
                f"{label + 1:03d}_{self.index[index]:04d}.jpg",
            )
        )

        target = label
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self) -> bool:
        """Return True when the extracted image directory is present (no hash check)."""
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

    def __len__(self) -> int:
        """Total number of images across all categories."""
        return len(self.index)

    def download(self) -> None:
        """Download and extract the image archive into ``self.root`` (no-op when present)."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
            self.root,
            filename="256_ObjectCategories.tar",
            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
        )
.venv/lib/python3.11/site-packages/torchvision/datasets/celeba.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from collections import namedtuple
4
+ from pathlib import Path
5
+ from typing import Any, Callable, List, Optional, Tuple, Union
6
+
7
+ import PIL
8
+ import torch
9
+
10
+ from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
11
+ from .vision import VisionDataset
12
+
13
# Parsed annotation file: column names (empty when headerless), the per-row
# image-file index, and the numeric payload as a tensor.
CSV = namedtuple("CSV", "header index data")
14
+
15
+
16
class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (int): label for each person (data points with the same identity are the same person)
                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    .. warning::

        To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                                  MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.split = split
        # Normalize target_type to a list so __getitem__ can iterate uniformly.
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        if not self.target_type and self.target_transform is not None:
            raise RuntimeError("target_transform is specified but target_type is empty")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Partition codes used by list_eval_partition.txt; None selects every row.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        # mask is either a full slice (split == "all") or a boolean tensor over rows.
        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()

        # NOTE(review): comparing a bool Tensor against a slice relies on
        # Tensor.__eq__ falling back to False for non-tensor operands — confirm
        # this holds for the torch versions supported here.
        if mask == slice(None):  # if split == "all"
            self.filename = splits.index
        else:
            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse one space-separated annotation file into a ``CSV`` tuple.

        ``header`` is the row index of the column-name row; rows above and
        including it are stripped from the data. The first column of every data
        row is treated as the index (image file name), the rest as integers.
        """
        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1 :]
        else:
            headers = []

        indices = [row[0] for row in data]
        data = [row[1:] for row in data]
        data_int = [list(map(int, i)) for i in data]

        return CSV(headers, indices, torch.tensor(data_int))

    def _check_integrity(self) -> bool:
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self) -> None:
        """Fetch all annotation files and the aligned-image archive, then extract it."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        # Collect one entry per requested target type, in the order given.
        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        if self.transform is not None:
            X = self.transform(X)

        if target:
            # Single target type unwraps to a bare value; multiple become a tuple.
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        return len(self.attr)

    def extra_repr(self) -> str:
        lines = ["Target type: {target_type}", "Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)
.venv/lib/python3.11/site-packages/torchvision/datasets/cifar.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import pickle
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from .utils import check_integrity, download_and_extract_archive
10
+ from .vision import VisionDataset
11
+
12
+
13
class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    # Each entry pairs a batch file name with the md5 checksum of its contents.
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        batch_files = self.train_list if self.train else self.test_list

        self.data: Any = []
        self.targets = []

        # Each batch is a latin1-pickled dict holding raw image bytes and labels.
        for batch_name, _checksum in batch_files:
            batch_path = os.path.join(self.root, self.base_folder, batch_name)
            with open(batch_path, "rb") as fh:
                entry = pickle.load(fh, encoding="latin1")
            self.data.append(entry["data"])
            # CIFAR-10 batches store "labels"; CIFAR-100 batches store "fine_labels".
            label_key = "labels" if "labels" in entry else "fine_labels"
            self.targets.extend(entry[label_key])

        # Stack into one (N, 32, 32, 3) array, converting from CHW to HWC layout.
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))

        self._load_meta()

    def _load_meta(self) -> None:
        """Load class names from the meta file and build the name -> index map."""
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
        with open(path, "rb") as fh:
            meta_data = pickle.load(fh, encoding="latin1")
        self.classes = meta_data[self.meta["key"]]
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # Return a PIL Image for consistency with all the other datasets.
        img = Image.fromarray(self.data[index])
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        """True iff every train and test batch is present with a matching md5."""
        return all(
            check_integrity(os.path.join(self.root, self.base_folder, name), md5)
            for name, md5 in self.train_list + self.test_list
        )

    def download(self) -> None:
        """Download and extract the archive into ``self.root`` (no-op when verified)."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        return f"Split: {'Train' if self.train is True else 'Test'}"
145
+
146
+
147
class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset.
    """

    # Only archive/file metadata differs from CIFAR10; all loading logic is inherited.
    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    # (file name, md5 checksum) pairs, same shape as CIFAR10.train_list / test_list.
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]

    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]
    # "fine_label_names" selects the 100 fine-grained class names from the meta file.
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }
.venv/lib/python3.11/site-packages/torchvision/datasets/cityscapes.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from collections import namedtuple
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ from PIL import Image
8
+
9
+ from .utils import extract_archive, iterable_to_str, verify_str_arg
10
+ from .vision import VisionDataset
11
+
12
+
13
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:

        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    classes = [
        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
    ]

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        mode: str = "fine",
        target_type: Union[List[str], str] = "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Public mode string maps to the on-disk annotation folder name.
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.target_type = target_type
        self.split = split
        self.images = []
        self.targets = []

        # Valid splits depend on the mode: "test" only exists for fine annotations,
        # "train_extra" only for coarse.
        verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_modes = ("train", "test", "val")
        else:
            valid_modes = ("train", "train_extra", "val")
        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
        msg = msg.format(split, mode, iterable_to_str(valid_modes))
        verify_str_arg(split, "split", valid_modes, msg)

        # Normalize target_type to a list, then validate every entry.
        if not isinstance(target_type, list):
            self.target_type = [target_type]
        [
            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
            for value in self.target_type
        ]

        # If the extracted folders are missing, try to extract the official zip
        # archives from root before giving up.
        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

            if split == "train_extra":
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
            else:
                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

            if self.mode == "gtFine":
                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
            elif self.mode == "gtCoarse":
                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                extract_archive(from_path=image_dir_zip, to_path=self.root)
                extract_archive(from_path=target_dir_zip, to_path=self.root)
            else:
                raise RuntimeError(
                    "Dataset not found or incomplete. Please make sure all required folders for the"
                    ' specified "split" and "mode" are inside the "root" directory'
                )

        # Index every image and, per image, one target path per requested target type.
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_types = []
                for t in self.target_type:
                    # Image "<prefix>_leftImg8bit.png" pairs with "<prefix>_<suffix>".
                    target_name = "{}_{}".format(
                        file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
                    )
                    target_types.append(os.path.join(target_dir, target_name))

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_types)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
        """

        image = Image.open(self.images[index]).convert("RGB")

        targets: Any = []
        for i, t in enumerate(self.target_type):
            if t == "polygon":
                target = self._load_json(self.targets[index][i])
            else:
                target = Image.open(self.targets[index][i])  # type: ignore[assignment]

            targets.append(target)

        target = tuple(targets) if len(targets) > 1 else targets[0]

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.images)

    def extra_repr(self) -> str:
        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
        return "\n".join(lines).format(**self.__dict__)

    def _load_json(self, path: str) -> Dict[str, Any]:
        """Load one polygon annotation file as a plain dict."""
        with open(path) as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        """Map a target type to its annotation-file suffix for the given mode folder."""
        if target_type == "instance":
            return f"{mode}_instanceIds.png"
        elif target_type == "semantic":
            return f"{mode}_labelIds.png"
        elif target_type == "color":
            return f"{mode}_color.png"
        else:
            return f"{mode}_polygons.json"
.venv/lib/python3.11/site-packages/torchvision/datasets/clevr.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pathlib
3
+ from typing import Any, Callable, List, Optional, Tuple, Union
4
+ from urllib.parse import urlparse
5
+
6
+ from PIL import Image
7
+
8
+ from .utils import download_and_extract_archive, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+
12
class CLEVRClassification(VisionDataset):
    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.

    The number of objects in a scene are used as label.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
            set to True.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in them target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
            dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "clevr"
        # e.g. ".../clevr/CLEVR_v1.0" — derived from the archive name in the URL.
        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))

        self._labels: List[Optional[int]]
        if self._split == "test":
            # The test split ships without scene annotations, so labels are unknown.
            self._labels = [None] * len(self._image_files)
        else:
            scenes_path = self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json"
            with open(scenes_path) as fh:
                scenes = json.load(fh)["scenes"]
            objects_per_image = {scene["image_filename"]: len(scene["objects"]) for scene in scenes}
            self._labels = [objects_per_image[path.name] for path in self._image_files]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image = Image.open(self._image_files[idx]).convert("RGB")
        label = self._labels[idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _check_exists(self) -> bool:
        """True when the extracted CLEVR_v1.0 directory is present."""
        folder = self._data_folder
        return folder.exists() and folder.is_dir()

    def _download(self) -> None:
        """Download and extract the archive unless the data is already present."""
        if not self._check_exists():
            download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)

    def extra_repr(self) -> str:
        return f"split={self._split}"
.venv/lib/python3.11/site-packages/torchvision/datasets/coco.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from pathlib import Path
3
+ from typing import Any, Callable, List, Optional, Tuple, Union
4
+
5
+ from PIL import Image
6
+
7
+ from .vision import VisionDataset
8
+
9
+
10
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        # Imported lazily so that merely importing torchvision does not require
        # pycocotools to be installed.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted so that iteration order is deterministic across runs.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        # COCO stores file names relative to ``root``.
        info = self.coco.loadImgs(id)[0]
        return Image.open(os.path.join(self.root, info["file_name"])).convert("RGB")

    def _load_target(self, id: int) -> List[Any]:
        # Every annotation dict attached to the given image id.
        ann_ids = self.coco.getAnnIds(id)
        return self.coco.loadAnns(ann_ids)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        image_id = self.ids[index]
        image = self._load_image(image_id)
        target = self._load_target(image_id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.ids)
63
+
64
+
65
class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

    .. code:: python

        import torchvision.datasets as dset
        import torchvision.transforms as transforms
        cap = dset.CocoCaptions(root = 'dir where images are',
                                annFile = 'json annotation file',
                                transform=transforms.PILToTensor())

        print('Number of samples: ', len(cap))
        img, target = cap[3]  # load 4th sample

        print("Image Size: ", img.size())
        print(target)

    Output: ::

        Number of samples: 82783
        Image Size: (3L, 427L, 640L)
        [u'A plane emitting smoke stream flying over a mountain.',
        u'A plane darts across a bright blue sky behind a mountain covered in snow',
        u'A plane leaves a contrail above the snowy mountain top.',
        u'A mountain that has a plane flying overheard in the distance.',
        u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        # Unlike detection, the target is just the list of caption strings
        # extracted from the raw annotation dicts.
        captions = []
        for annotation in super()._load_target(id):
            captions.append(annotation["caption"])
        return captions
.venv/lib/python3.11/site-packages/torchvision/datasets/dtd.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ import PIL.Image
6
+
7
+ from .utils import download_and_extract_archive, verify_str_arg
8
+ from .vision import VisionDataset
9
+
10
+
11
class DTD(VisionDataset):
    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.

            .. note::

                The partition only changes which split each image belongs to. Thus, regardless of the selected
                partition, combining all splits will result in all images.

        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.

    Raises:
        ValueError: If ``partition`` is not an int in ``[1, 10]``.
        RuntimeError: If the dataset is not present and ``download`` is False.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        partition: int = 1,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # BUG FIX: this condition previously joined the two checks with ``and``,
        # which let an out-of-range int (e.g. 11) and an in-range non-int
        # (e.g. 2.5) slip through validation. The argument is invalid when it is
        # not an int OR when it lies outside [1, 10]; ``or`` short-circuits, so
        # the range check never runs on a non-int.
        if not isinstance(partition, int) or not (1 <= partition <= 10):
            raise ValueError(
                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
                f"but got {partition} instead"
            )
        self._partition = partition

        super().__init__(root, transform=transform, target_transform=target_transform)
        # Layout on disk: <root>/dtd/dtd/{labels,images}/...
        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
        self._data_folder = self._base_folder / "dtd"
        self._meta_folder = self._data_folder / "labels"
        self._images_folder = self._data_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each line of the split file is "<class>/<filename>".
        self._image_files = []
        classes = []
        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
            for line in file:
                cls, name = line.strip().split("/")
                self._image_files.append(self._images_folder.joinpath(cls, name))
                classes.append(cls)

        self.classes = sorted(set(classes))
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
        self._labels = [self.class_to_idx[cls] for cls in classes]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}, partition={self._partition}"

    def _check_exists(self) -> bool:
        # Presence of the extracted top-level folder is taken as "downloaded".
        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
.venv/lib/python3.11/site-packages/torchvision/datasets/eurosat.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Callable, Optional, Union
4
+
5
+ from .folder import ImageFolder
6
+ from .utils import download_and_extract_archive
7
+
8
+
9
class EuroSAT(ImageFolder):
    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.

    For the MS version of the dataset, see
    `TorchGeo <https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat>`__.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
    """

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        expanded_root = os.path.expanduser(root)
        # Paths must exist before ImageFolder.__init__ scans the directory,
        # hence the download/existence check happens first.
        self.root = expanded_root
        self._base_folder = os.path.join(expanded_root, "eurosat")
        self._data_folder = os.path.join(self._base_folder, "2750")

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
        # ImageFolder.__init__ overwrites ``self.root`` with the data folder;
        # restore the user-supplied root so ``repr`` and callers see it.
        self.root = expanded_root

    def __len__(self) -> int:
        return len(self.samples)

    def _check_exists(self) -> bool:
        return os.path.exists(self._data_folder)

    def download(self) -> None:
        # No-op when the extracted data folder is already present.
        if self._check_exists():
            return

        os.makedirs(self._base_folder, exist_ok=True)
        download_and_extract_archive(
            "https://huggingface.co/datasets/torchgeo/eurosat/resolve/c877bcd43f099cd0196738f714544e355477f3fd/EuroSAT.zip",
            download_root=self._base_folder,
            md5="c8fa014336c82ac7804f0398fcb19387",
        )
.venv/lib/python3.11/site-packages/torchvision/datasets/fakedata.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from .. import transforms
6
+ from .vision import VisionDataset
7
+
8
+
9
class FakeData(VisionDataset):
    """A fake dataset that returns randomly generated images and returns them as PIL images

    Args:
        size (int, optional): Size of the dataset. Default: 1000 images
        image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
        num_classes(int, optional): Number of classes in the dataset. Default: 10
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        random_offset (int): Offsets the index-based random seed used to
            generate each image. Default: 0

    """

    def __init__(
        self,
        size: int = 1000,
        image_size: Tuple[int, int, int] = (3, 224, 224),
        num_classes: int = 10,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        random_offset: int = 0,
    ) -> None:
        super().__init__(transform=transform, target_transform=target_transform)
        self.size = size
        self.num_classes = num_classes
        self.image_size = image_size
        self.random_offset = random_offset

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if index >= len(self):
            raise IndexError(f"{self.__class__.__name__} index out of range")
        # Seed deterministically per index so the same index always yields the
        # same sample, then restore the global RNG state so callers' random
        # streams are unaffected.
        saved_rng_state = torch.get_rng_state()
        torch.manual_seed(index + self.random_offset)
        image_tensor = torch.randn(*self.image_size)
        label = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
        torch.set_rng_state(saved_rng_state)

        # Hand the caller a PIL image, as documented.
        image = transforms.ToPILImage()(image_tensor)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            # NOTE: the transform receives the 0-dim tensor; its result must
            # still support ``.item()``.
            label = self.target_transform(label)

        return image, label.item()

    def __len__(self) -> int:
        return self.size
.venv/lib/python3.11/site-packages/torchvision/datasets/fer2013.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import pathlib
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from PIL import Image
7
+
8
+ from .utils import check_integrity, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+
12
class FER2013(VisionDataset):
    """`FER2013
    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

    .. note::
        This dataset can return test labels only if ``fer2013.csv`` OR
        ``icml_face_data.csv`` are present in ``root/fer2013/``. If only
        ``train.csv`` and ``test.csv`` are present, the test labels are set to
        ``None``.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``root/fer2013`` exists. This directory may contain either
            ``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
            ``test.csv``. Precendence is given in that order, i.e. if
            ``fer2013.csv`` is present then the rest of the files will be
            ignored. All these (combinations of) files contain the same data and
            are supported for convenience, but only ``fer2013.csv`` and
            ``icml_face_data.csv`` are able to return non-None test labels.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """

    # Maps a source-file key to (filename, expected md5). "fer"/"icml" take
    # precedence over the per-split files (see the comment below).
    _RESOURCES = {
        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
        # The fer2013.csv and icml_face_data.csv files contain both train and
        # tests instances, and unlike test.csv they contain the labels for the
        # test instances. We give these 2 files precedence over train.csv and
        # test.csv. And yes, they both contain the same data, but with different
        # column names (note the spaces) and ordering:
        # $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
        # ==> fer2013.csv <==
        # emotion,pixels,Usage
        #
        # ==> icml_face_data.csv <==
        # emotion, Usage, pixels
        #
        # ==> train.csv <==
        # emotion,pixels
        #
        # ==> test.csv <==
        # pixels
        "fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
        "icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
    }

    def __init__(
        self,
        root: Union[str, pathlib.Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        # Validate the split before any filesystem work.
        self._split = verify_str_arg(split, "split", ("train", "test"))
        super().__init__(root, transform=transform, target_transform=target_transform)

        base_folder = pathlib.Path(self.root) / "fer2013"
        # File precedence: fer2013.csv > icml_face_data.csv > {train,test}.csv.
        use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
        use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
        file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
        data_file = base_folder / file_name
        if not check_integrity(str(data_file), md5=md5):
            raise RuntimeError(
                f"{file_name} not found in {base_folder} or corrupted. "
                f"You can download it from "
                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
            )

        # The icml file's header has a leading space in these column names.
        pixels_key = " pixels" if use_icml_file else "pixels"
        usage_key = " Usage" if use_icml_file else "Usage"

        def get_img(row):
            # Pixels are stored as a space-separated string of 48*48 grayscale values.
            return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)

        def get_label(row):
            # test.csv has no "emotion" column, so test labels are None unless
            # one of the combined files is used.
            if use_fer_file or use_icml_file or self._split == "train":
                return int(row["emotion"])
            else:
                return None

        with open(data_file, "r", newline="") as file:
            # NOTE: these generators read lazily, so the list comprehension
            # below must stay inside this ``with`` block while the file is open.
            rows = (row for row in csv.DictReader(file))

            if use_fer_file or use_icml_file:
                # The combined files mix train and test rows; filter by Usage.
                valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
                rows = (row for row in rows if row[usage_key] in valid_keys)

            self._samples = [(get_img(row), get_label(row)) for row in rows]

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        # Samples are pre-parsed tensors; convert back to PIL on access.
        image_tensor, target = self._samples[idx]
        image = Image.fromarray(image_tensor.numpy())

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def extra_repr(self) -> str:
        return f"split={self._split}"
.venv/lib/python3.11/site-packages/torchvision/datasets/fgvc_aircraft.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Optional, Tuple, Union
6
+
7
+ import PIL.Image
8
+
9
+ from .utils import download_and_extract_archive, verify_str_arg
10
+ from .vision import VisionDataset
11
+
12
+
13
class FGVCAircraft(VisionDataset):
    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    The dataset contains 10,000 images of aircraft, with 100 images for each of 100
    different aircraft model variants, most of which are airplanes.
    Aircraft models are organized in a three-levels hierarchy. The three levels, from
    finer to coarser, are:

    - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
      indistinguishable into one class. The dataset comprises 100 different variants.
    - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
    - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.

    Args:
        root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset.
        split (string, optional): The dataset split, supports ``train``, ``val``,
            ``trainval`` and ``test``.
        annotation_level (str, optional): The annotation level, supports ``variant``,
            ``family`` and ``manufacturer``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "trainval",
        annotation_level: str = "variant",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
        self._annotation_level = verify_str_arg(
            annotation_level, "annotation_level", ("variant", "family", "manufacturer")
        )

        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # Each annotation level has its own list of class names on disk.
        level_to_file = {
            "variant": "variants.txt",
            "family": "families.txt",
            "manufacturer": "manufacturers.txt",
        }
        annotation_file = os.path.join(self._data_path, "data", level_to_file[self._annotation_level])
        with open(annotation_file, "r") as f:
            self.classes = [line.strip() for line in f]

        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        image_data_folder = os.path.join(self._data_path, "data", "images")
        labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")

        # Each line is "<image id> <class name>"; class names may contain spaces,
        # hence the single split.
        self._image_files = []
        self._labels = []
        with open(labels_file, "r") as f:
            for line in f:
                image_name, label_name = line.strip().split(" ", 1)
                self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
                self._labels.append(self.class_to_idx[label_name])

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def _download(self) -> None:
        """
        Download the FGVC Aircraft dataset archive and extract it under root.
        """
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, self.root)

    def _check_exists(self) -> bool:
        return os.path.exists(self._data_path) and os.path.isdir(self._data_path)
.venv/lib/python3.11/site-packages/torchvision/datasets/flickr.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from collections import defaultdict
4
+ from html.parser import HTMLParser
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7
+
8
+ from PIL import Image
9
+
10
+ from .vision import VisionDataset
11
+
12
+
13
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page."""

    def __init__(self, root: Union[str, Path]) -> None:
        super().__init__()

        self.root = root

        # image path -> list of caption strings
        self.annotations: Dict[str, List[str]] = {}

        # Parsing state: are we inside the captions <table>, which tag is
        # currently open, and which image do captions currently belong to.
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        self.current_tag = tag
        if tag == "table":
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        self.current_tag = None
        if tag == "table":
            self.in_table = False

    def handle_data(self, data: str) -> None:
        if not self.in_table:
            return
        if data == "Image Not Found":
            # Drop captions until the next valid image link.
            self.current_img = None
        elif self.current_tag == "a":
            # Anchor text holds the image id; resolve it to a file under root.
            pattern = os.path.join(self.root, data.split("/")[-2] + "_*.jpg")
            matched = glob.glob(pattern)[0]
            self.current_img = matched
            self.annotations[matched] = []
        elif self.current_tag == "li" and self.current_img:
            # List items inside the table are the captions themselves.
            self.annotations[self.current_img].append(data.strip())
+
55
+
56
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Parse the HTML annotation page once up front and keep the
        # image-path -> captions mapping.
        parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            parser.feed(fh.read())
        self.annotations = parser.annotations

        # Deterministic sample ordering.
        self.ids = sorted(self.annotations.keys())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # The ids are full image paths.
        image = Image.open(img_id).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        captions = self.annotations[img_id]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self) -> int:
        return len(self.ids)
+ return len(self.ids)
110
+
111
+
112
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
        self,
        # Consistency fix: annotated ``str`` only, although ``root`` is
        # documented (and used) as str-or-path like in the sibling ``Flickr8k``.
        root: Union[str, Path],
        ann_file: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Each line is "<image>#<caption index>\t<caption>"; strip the trailing
        # "#<n>" (always 2 chars) to group all captions per image.
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                img_id, caption = line.strip().split("\t")
                self.annotations[img_id[:-2]].append(caption)

        # Deterministic sample ordering.
        self.ids = list(sorted(self.annotations.keys()))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.ids)
+ return len(self.ids)
.venv/lib/python3.11/site-packages/torchvision/datasets/flowers102.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Any, Callable, Optional, Tuple, Union
3
+
4
+ import PIL.Image
5
+
6
+ from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
7
+ from .vision import VisionDataset
8
+
9
+
10
class Flowers102(VisionDataset):
    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.

    .. warning::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
    between 40 and 258 images.

    The images have large scale, pose and light variations. In addition, there are categories that
    have large variations within the category, and several very similar categories.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    _file_dict = {  # filename, md5
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Split name -> variable name inside setid.mat.
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # Deferred import: scipy is only required when actually using the class.
        from scipy.io import loadmat

        # setid.mat lists which 1-based image ids belong to each split.
        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
        image_ids = set_ids[self._splits_map[self._split]].tolist()

        # imagelabels.mat holds one 1-based label per image; shift to 0-based.
        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))

        self._labels = [image_id_to_label[image_id] for image_id in image_ids]
        self._image_files = [self._images_folder / f"image_{image_id:05d}.jpg" for image_id in image_ids]

    def __len__(self) -> int:
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_file, label = self._image_files[idx], self._labels[idx]
        image = PIL.Image.open(image_file).convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        return f"split={self._split}"

    def _check_integrity(self):
        # Images folder must exist, and both .mat files must match their md5.
        if not self._images_folder.is_dir():
            return False

        for key in ("label", "setid"):
            filename, md5 = self._file_dict[key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            return
        download_and_extract_archive(
            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
            str(self._base_folder),
            md5=self._file_dict["image"][1],
        )
        for key in ("label", "setid"):
            filename, md5 = self._file_dict[key]
            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
.venv/lib/python3.11/site-packages/torchvision/datasets/food101.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Optional, Tuple, Union
4
+
5
+ import PIL.Image
6
+
7
+ from .utils import download_and_extract_archive, verify_str_arg
8
+ from .vision import VisionDataset
9
+
10
+
11
class Food101(VisionDataset):
    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.

    Food-101 covers 101 food categories with 101,000 images in total: 750
    training and 250 manually reviewed test images per class. The training
    images were deliberately left uncleaned and therefore contain some noise,
    mostly in the form of intense colors and occasionally wrong labels. All
    images were rescaled to a maximum side length of 512 pixels.

    Args:
        root (str or ``pathlib.Path``): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and
            transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it in
            the root directory. If the dataset is already downloaded, it is not downloaded again.
            Default is False.
    """

    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
    _MD5 = "85eeb15f3717b99a5da872d97d918f87"

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = Path(self.root) / "food-101"
        self._meta_folder = self._base_folder / "meta"
        self._images_folder = self._base_folder / "images"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # The split metadata maps each class name to the '/'-separated relative
        # paths of that class's images.
        with open(self._meta_folder / f"{split}.json") as f:
            metadata = json.load(f)

        self.classes = sorted(metadata.keys())
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}

        self._labels = []
        self._image_files = []
        for class_name, rel_paths in metadata.items():
            class_index = self.class_to_idx[class_name]
            self._labels.extend([class_index] * len(rel_paths))
            # Splitting on '/' before joinpath keeps the paths correct on
            # every platform, regardless of the OS path separator.
            self._image_files.extend(
                self._images_folder.joinpath(*f"{rel_path}.jpg".split("/")) for rel_path in rel_paths
            )

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample at ``idx`` and return an ``(image, target)`` pair."""
        label = self._labels[idx]
        image = PIL.Image.open(self._image_files[idx]).convert("RGB")

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Expose the active split in the dataset's ``repr``."""
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True when both the metadata and image folders are present."""
        # Path.is_dir() is False for nonexistent paths, so no exists() check is needed.
        return all(folder.is_dir() for folder in (self._meta_folder, self._images_folder))

    def _download(self) -> None:
        """Download and extract the archive unless the dataset is already in place."""
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)