| import csv |
| import os |
| import pathlib |
| from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import PIL |
| import torch |
| from torchvision.datasets.folder import make_dataset |
| from torchvision.datasets.utils import (download_and_extract_archive, |
| verify_str_arg) |
| from torchvision.datasets.vision import VisionDataset |
|
|
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Discover the class sub-directories directly under *directory*.

    Each immediate sub-directory is treated as one class. The class names are
    sorted, and every name is assigned its position in that sorted order as
    its integer index.

    Args:
        directory: Path of the folder to scan.

    Returns:
        A ``(classes, class_to_idx)`` pair: the sorted list of class names and
        a dict mapping each name to its index.

    Raises:
        FileNotFoundError: If *directory* contains no sub-directories.
    """
    class_names = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    class_names.sort()

    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    return class_names, dict(zip(class_names, range(len(class_names))))
|
|
class PyTorchGTSRB(VisionDataset):
    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.

    Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.

    Files live under ``root/gtsrb``. The train split is read from the per-class
    sub-folders of ``GTSRB/Training``. The test split is read from the flat
    folder ``GTSRB/Final_Test/Images``, with labels taken from
    ``root/gtsrb/GT-final_test.csv``.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the files required by the requested split are not on
            disk (after downloading, when ``download=True``).
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "gtsrb"
        self._target_folder = (
            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
        )

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        if self._split == "train":
            # Training images are grouped into one sub-folder per class; the sorted
            # folder names (via find_classes) define the integer labels.
            _, class_to_idx = find_classes(str(self._target_folder))
            samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
        else:
            # Test images sit in a flat folder; labels come from the ';'-separated CSV.
            # newline="" is how the csv module expects its input files to be opened.
            with open(self._base_folder / "GT-final_test.csv", newline="", encoding="utf-8") as csv_file:
                samples = [
                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
                ]

        # Sequence of (image_path, class_index) pairs for the selected split.
        self._samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(sample, target)`` pair at *index*.

        The image is loaded as RGB, then ``transform`` (if set) is applied to
        it and ``target_transform`` (if set) is applied to the integer label.
        """
        # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
        from PIL import Image

        path, target = self._samples[index]
        # Image.open is lazy and keeps the file open; the context manager closes
        # it as soon as convert() has read the pixels. Without it, handles leak
        # until GC, which adds up quickly across DataLoader workers.
        with Image.open(path) as img:
            sample = img.convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def _check_exists(self) -> bool:
        """Return True if every file the selected split needs is on disk."""
        if not self._target_folder.is_dir():
            return False
        if self._split == "test":
            # Test labels come from a separate archive (see download()); the
            # images alone are not enough to build the sample list.
            return (self._base_folder / "GT-final_test.csv").is_file()
        return True

    def download(self) -> None:
        """Download and extract the archives for the selected split into ``root/gtsrb``.

        Does nothing if the split's files are already present.
        """
        if self._check_exists():
            return

        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"

        # (archive file name, expected md5) for each archive the split needs.
        if self._split == "train":
            archives = [("GTSRB-Training_fixed.zip", "513f3c79a4c5141765e10e952eaa2478")]
        else:
            archives = [
                ("GTSRB_Final_Test_Images.zip", "c7e4e6327067d32654124b0fe9e82185"),
                ("GTSRB_Final_Test_GT.zip", "fe31e9c9270bbcd7b84b7f21a9d9d9e5"),
            ]

        for filename, md5 in archives:
            download_and_extract_archive(
                f"{base_url}{filename}",
                download_root=str(self._base_folder),
                md5=md5,
            )
|
|
|
|
class GTSRB:
    """GTSRB train/test datasets and DataLoaders sharing one preprocessing transform.

    Both splits are built with ``download=True``, so missing files are fetched
    into ``location``.

    Args:
        preprocess: Transform applied to every image of both splits.
        location: Root directory passed to ``PyTorchGTSRB``.
        batch_size: Batch size used by both DataLoaders.
        num_workers: Value passed as ``num_workers`` to both DataLoaders.

    Attributes:
        train_dataset, train_loader: The training split and its shuffling loader.
        test_dataset, test_loader: The test split and its non-shuffling loader.
        classnames: Descriptive names for the 43 traffic-sign classes.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16):
        def make_split(split, shuffle):
            # Build one split's dataset together with the DataLoader over it.
            dataset = PyTorchGTSRB(
                root=location,
                download=True,
                split=split,
                transform=preprocess,
            )
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )
            return dataset, loader

        # Build the train split first, then the test split.
        self.train_dataset, self.train_loader = make_split('train', shuffle=True)
        self.test_dataset, self.test_loader = make_split('test', shuffle=False)

        # Descriptive names for the 43 classes, presumably listed in ClassId
        # order (0-42) -- TODO confirm against the dataset's label mapping.
        # NOTE(review): 'horizonal' and 'bicyle' are misspelled. They are left
        # as-is because any text prompts derived from these strings would change.
        self.classnames = [
            'red and white circle 20 kph speed limit',
            'red and white circle 30 kph speed limit',
            'red and white circle 50 kph speed limit',
            'red and white circle 60 kph speed limit',
            'red and white circle 70 kph speed limit',
            'red and white circle 80 kph speed limit',
            'end / de-restriction of 80 kph speed limit',
            'red and white circle 100 kph speed limit',
            'red and white circle 120 kph speed limit',
            'red and white circle red car and black car no passing',
            'red and white circle red truck and black car no passing',
            'red and white triangle road intersection warning',
            'white and yellow diamond priority road',
            'red and white upside down triangle yield right-of-way',
            'stop',
            'empty red and white circle',
            'red and white circle no truck entry',
            'red circle with white horizonal stripe no entry',
            'red and white triangle with exclamation mark warning',
            'red and white triangle with black left curve approaching warning',
            'red and white triangle with black right curve approaching warning',
            'red and white triangle with black double curve approaching warning',
            'red and white triangle rough / bumpy road warning',
            'red and white triangle car skidding / slipping warning',
            'red and white triangle with merging / narrow lanes warning',
            'red and white triangle with person digging / construction / road work warning',
            'red and white triangle with traffic light approaching warning',
            'red and white triangle with person walking warning',
            'red and white triangle with child and person walking warning',
            'red and white triangle with bicyle warning',
            'red and white triangle with snowflake / ice warning',
            'red and white triangle with deer warning',
            'white circle with gray strike bar no speed limit',
            'blue circle with white right turn arrow mandatory',
            'blue circle with white left turn arrow mandatory',
            'blue circle with white forward arrow mandatory',
            'blue circle with white forward or right turn arrow mandatory',
            'blue circle with white forward or left turn arrow mandatory',
            'blue circle with white keep right arrow mandatory',
            'blue circle with white keep left arrow mandatory',
            'blue circle with white arrows indicating a traffic circle',
            'white circle with gray strike bar indicating no passing for cars has ended',
            'white circle with gray strike bar indicating no passing for trucks has ended',
        ]
|
|