File size: 8,484 Bytes
import csv
import os
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import PIL
import PIL.Image
import torch
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import (download_and_extract_archive,
                                        verify_str_arg)
from torchvision.datasets.vision import VisionDataset
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Discover class folders: every immediate sub-directory of ``directory`` is one class.

    See torchvision's ``DatasetFolder`` for the folder-per-class convention.

    Args:
        directory: Folder whose immediate sub-directories name the classes.

    Returns:
        ``(classes, class_to_idx)``: ``classes`` is the alphabetically sorted
        list of sub-directory names; ``class_to_idx`` maps each name to its
        position in that list (0-based).

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directories.
    """
    with os.scandir(directory) as it:
        subdirs = [entry.name for entry in it if entry.is_dir()]
    subdirs.sort()

    if not subdirs:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    positions = dict(zip(subdirs, range(len(subdirs))))
    return subdirs, positions
class PyTorchGTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.

    Files are kept under ``root/gtsrb``. The training split is read from one
    sub-folder per class; the test split is read from a flat image folder plus
    the ground-truth file ``GT-final_test.csv``.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. For the ``"test"`` split, a partial download (images
            present but ground-truth CSV missing, or vice versa) is completed.

    Raises:
        RuntimeError: If the files needed for ``split`` are not present (after
            downloading, when ``download`` is True).
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "gtsrb"
        self._target_folder = (
            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
        )
        # Labels for the test split come from this CSV, which ships in an
        # archive separate from the test images.
        self._test_gt_file = self._base_folder / "GT-final_test.csv"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        if self._split == "train":
            # Training images are organised as one sub-folder per class.
            _, class_to_idx = find_classes(str(self._target_folder))
            samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
        else:
            # newline="" is the mode the csv module expects for files it parses.
            with open(self._test_gt_file, newline="") as csv_file:
                samples = [
                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                ]

        # List of (image_path, class_index) pairs.
        self._samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair at ``index`` with any transforms applied.

        The image is loaded as an RGB PIL image before ``transform`` is applied.
        """
        path, target = self._samples[index]
        # Open the file ourselves so the handle is closed deterministically
        # rather than whenever the PIL image object is garbage-collected.
        with open(path, "rb") as f:
            sample = PIL.Image.open(f).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        """Return True if every file the selected split needs is on disk."""
        if self._split == "train":
            return self._target_folder.is_dir()
        return self._target_folder.is_dir() and self._test_gt_file.is_file()

    def download(self) -> None:
        """Download and extract whatever the selected split is missing into ``root/gtsrb``."""
        if self._check_exists():
            return

        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"

        if self._split == "train":
            download_and_extract_archive(
                f"{base_url}GTSRB-Training_fixed.zip",
                download_root=str(self._base_folder),
                md5="513f3c79a4c5141765e10e952eaa2478",
            )
            return

        # The test split needs two archives (images + ground-truth CSV). Fetch
        # each one only if its extracted content is missing, so an interrupted
        # earlier download is completed instead of being treated as finished.
        if not self._target_folder.is_dir():
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_Images.zip",
                download_root=str(self._base_folder),
                md5="c7e4e6327067d32654124b0fe9e82185",
            )
        if not self._test_gt_file.is_file():
            download_and_extract_archive(
                f"{base_url}GTSRB_Final_Test_GT.zip",
                download_root=str(self._base_folder),
                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
            )
class GTSRB:
    """GTSRB train/test datasets, their DataLoaders, and per-class text descriptions.

    Both splits are created with ``download=True`` under ``location`` and share
    the same ``preprocess`` transform. The training loader shuffles; the test
    loader keeps dataset order.

    Attributes:
        train_dataset, train_loader: the ``"train"`` split and its shuffled loader.
        test_dataset, test_loader: the ``"test"`` split and its unshuffled loader.
        classnames: 43 class descriptions, one per class; presumably ordered by
            the integer class label (TODO confirm against the dataset's ClassId).
    """

    # Taken from
    # https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md
    # Spellings such as "horizonal" and "bicyle" are kept as-is; presumably they
    # match that source, so do not "correct" them without checking.
    _CLASSNAMES = (
        'red and white circle 20 kph speed limit',
        'red and white circle 30 kph speed limit',
        'red and white circle 50 kph speed limit',
        'red and white circle 60 kph speed limit',
        'red and white circle 70 kph speed limit',
        'red and white circle 80 kph speed limit',
        'end / de-restriction of 80 kph speed limit',
        'red and white circle 100 kph speed limit',
        'red and white circle 120 kph speed limit',
        'red and white circle red car and black car no passing',
        'red and white circle red truck and black car no passing',
        'red and white triangle road intersection warning',
        'white and yellow diamond priority road',
        'red and white upside down triangle yield right-of-way',
        'stop',
        'empty red and white circle',
        'red and white circle no truck entry',
        'red circle with white horizonal stripe no entry',
        'red and white triangle with exclamation mark warning',
        'red and white triangle with black left curve approaching warning',
        'red and white triangle with black right curve approaching warning',
        'red and white triangle with black double curve approaching warning',
        'red and white triangle rough / bumpy road warning',
        'red and white triangle car skidding / slipping warning',
        'red and white triangle with merging / narrow lanes warning',
        'red and white triangle with person digging / construction / road work warning',
        'red and white triangle with traffic light approaching warning',
        'red and white triangle with person walking warning',
        'red and white triangle with child and person walking warning',
        'red and white triangle with bicyle warning',
        'red and white triangle with snowflake / ice warning',
        'red and white triangle with deer warning',
        'white circle with gray strike bar no speed limit',
        'blue circle with white right turn arrow mandatory',
        'blue circle with white left turn arrow mandatory',
        'blue circle with white forward arrow mandatory',
        'blue circle with white forward or right turn arrow mandatory',
        'blue circle with white forward or left turn arrow mandatory',
        'blue circle with white keep right arrow mandatory',
        'blue circle with white keep left arrow mandatory',
        'blue circle with white arrows indicating a traffic circle',
        'white circle with gray strike bar indicating no passing for cars has ended',
        'white circle with gray strike bar indicating no passing for trucks has ended',
    )

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):
        def build(split, shuffle):
            # One split: the dataset (downloaded if needed) plus its loader.
            dataset = PyTorchGTSRB(
                root=location,
                download=True,
                split=split,
                transform=preprocess,
            )
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )
            return dataset, loader

        # The training split is built (and possibly downloaded) before the test split.
        self.train_dataset, self.train_loader = build('train', True)
        self.test_dataset, self.test_loader = build('test', False)

        # A fresh list per instance, so callers may mutate it safely.
        self.classnames = list(self._CLASSNAMES)
|