import os import torch import abc import os from typing import Any, Callable, Dict, Optional, Tuple import numpy as np import torch from torch import Tensor from torch.utils.data import Dataset from torchvision.datasets import ImageFolder from torchvision.datasets.folder import default_loader as pil_loader # modified from: https://github.com/microsoft/torchgeo class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): """Abstract base class for datasets lacking geospatial information. This base class is designed for datasets with pre-defined image chips. """ @abc.abstractmethod def __getitem__(self, index: int) -> Dict[str, Any]: """Return an index within the dataset. Args: index: index to return Returns: data and labels at that index Raises: IndexError: if index is out of range of the dataset """ @abc.abstractmethod def __len__(self) -> int: """Return the length of the dataset. Returns: length of the dataset """ def __str__(self) -> str: """Return the informal string representation of the object. Returns: informal string representation """ return f"""\ {self.__class__.__name__} Dataset type: VisionDataset size: {len(self)}""" class VisionClassificationDataset(VisionDataset, ImageFolder): """Abstract base class for classification datasets lacking geospatial information. This base class is designed for datasets with pre-defined image chips which are separated into separate folders per class. """ def __init__( self, root: str, transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, loader: Optional[Callable[[str], Any]] = pil_loader, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: """Initialize a new VisionClassificationDataset instance. Args: root: root directory where dataset can be found transforms: a function/transform that takes input sample and its target as entry and returns a transformed version loader: a callable function which takes as input a path to an image and returns a PIL Image or numpy array is_valid_file: A function that takes the path of an Image file and checks if the file is a valid file """ # When transform & target_transform are None, ImageFolder.__getitem__(index) # returns a PIL.Image and int for image and label, respectively super().__init__( root=root, transform=None, target_transform=None, loader=loader, is_valid_file=is_valid_file, ) # Must be set after calling super().__init__() self.transforms = transforms def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. Args: index: index to return Returns: data and label at that index """ image, label = self._load_image(index) if self.transforms is not None: return self.transforms(image), label return image, label def __len__(self) -> int: """Return the number of data points in the dataset. Returns: length of the dataset """ return len(self.imgs) def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: """Load a single image and it's class label. Args: index: index to return Returns: the image the image class label """ img, label = ImageFolder.__getitem__(self, index) label = torch.tensor(label) return img, label class RESISC45Dataset(VisionClassificationDataset): """RESISC45 dataset. The `RESISC45 `_ dataset is a dataset for remote sensing image scene classification. Dataset features: * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) * three spectral bands - RGB * 45 scene classes, 700 images per class * images extracted from Google Earth from over 100 countries * images conditions with high variability (resolution, weather, illumination) Dataset format: * images are three-channel jpgs Dataset classes: 0. airplane 1. airport 2. baseball_diamond 3. basketball_court 4. beach 5. bridge 6. chaparral 7. church 8. circular_farmland 9. cloud 10. commercial_area 11. dense_residential 12. desert 13. forest 14. freeway 15. golf_course 16. ground_track_field 17. harbor 18. industrial_area 19. intersection 20. island 21. lake 22. meadow 23. medium_residential 24. mobile_home_park 25. mountain 26. overpass 27. palace 28. parking_lot 29. railway 30. railway_station 31. rectangular_farmland 32. river 33. roundabout 34. runway 35. sea_ice 36. ship 37. snowberg 38. sparse_residential 39. stadium 40. storage_tank 41. tennis_court 42. terrace 43. thermal_power_station 44. wetland This dataset uses the train/val/test splits defined in the "In-domain representation learning for remote sensing" paper: * https://arxiv.org/abs/1911.06721 If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/jproc.2017.2675998 """ # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" # md5 = "d824acb73957502b00efd559fc6cfbbb" # filename = "NWPU-RESISC45.rar" directory = "resisc45/NWPU-RESISC45" splits = ["train", "val", "test"] split_urls = { "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 } split_md5s = { "train": "b5a4c05a37de15e4ca886696a85c403e", "val": "a0770cee4c5ca20b8c32bbd61e114805", "test": "3dda9e4988b47eb1de9f07993653eb08", } classes = [ "airplane", "airport", "baseball_diamond", "basketball_court", "beach", "bridge", "chaparral", "church", "circular_farmland", "cloud", "commercial_area", "dense_residential", "desert", "forest", "freeway", "golf_course", "ground_track_field", "harbor", "industrial_area", "intersection", "island", "lake", "meadow", "medium_residential", "mobile_home_park", "mountain", "overpass", "palace", "parking_lot", "railway", "railway_station", "rectangular_farmland", "river", "roundabout", "runway", "sea_ice", "ship", "snowberg", "sparse_residential", "stadium", "storage_tank", "tennis_court", "terrace", "thermal_power_station", "wetland", ] def __init__( self, root: str = "data", split: str = "train", transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, ) -> None: """Initialize a new RESISC45 dataset instance. Args: root: root directory where dataset can be found split: one of "train", "val", or "test" transforms: a function/transform that takes input sample and its target as entry and returns a transformed version """ assert split in self.splits self.root = root valid_fns = set() with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: for fn in f: valid_fns.add(fn.strip()) is_in_split: Callable[[str], bool] = lambda x: os.path.basename( x) in valid_fns super().__init__( root=os.path.join(root, self.directory), transforms=transforms, is_valid_file=is_in_split, ) class RESISC45: def __init__(self, preprocess, location=os.path.expanduser('~/data'), batch_size=32, num_workers=16): self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) self.train_loader = torch.utils.data.DataLoader( self.train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, ) self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) self.test_loader = torch.utils.data.DataLoader( self.test_dataset, batch_size=batch_size, num_workers=num_workers ) # class names have _ so split on this for better zero-shot head self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes]