| import os |
| import torch |
|
|
| import abc |
| import os |
| from typing import Any, Callable, Dict, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch import Tensor |
| from torch.utils.data import Dataset |
| from torchvision.datasets import ImageFolder |
| from torchvision.datasets.folder import default_loader as pil_loader |
|
|
|
|
| |
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
    """Abstract interface for datasets that carry no geospatial metadata.

    Intended for datasets that ship as pre-defined image chips. Concrete
    subclasses have to provide ``__getitem__`` and ``__len__``.
    """

    @abc.abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Fetch the sample stored at ``index``.

        Args:
            index: position of the sample to fetch

        Returns:
            data and labels at that index

        Raises:
            IndexError: if index is out of range of the dataset
        """

    @abc.abstractmethod
    def __len__(self) -> int:
        """Report how many samples the dataset holds.

        Returns:
            length of the dataset
        """

    def __str__(self) -> str:
        """Build the informal, human-readable description of the dataset.

        Returns:
            informal string representation
        """
        # NOTE(review): the reviewed copy of this file had lost its leading
        # whitespace, so any indentation of the "type"/"size" lines could not
        # be recovered -- confirm against the intended output.
        header = f"{self.__class__.__name__} Dataset"
        details = ("type: VisionDataset", f"size: {len(self)}")
        return "\n".join((header, *details))
|
|
|
|
class VisionClassificationDataset(VisionDataset, ImageFolder):
    """Abstract base class for classification datasets lacking geospatial information.

    This base class is designed for datasets with pre-defined image chips which
    are separated into separate folders per class.

    Samples are returned as ``(image, label)`` tuples, not as dictionaries.
    """

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable[[Any], Any]] = None,
        loader: Optional[Callable[[str], Any]] = pil_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        """Initialize a new VisionClassificationDataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes the loaded image and
                returns a transformed version; the label is not passed to it
            loader: a callable function which takes as input a path to an image and
                returns the loaded image; ``None`` selects the default loader
            is_valid_file: A function that takes the path of an Image file and checks if
                the file is a valid file
        """
        # The annotation allows None, but forwarding None would leave the
        # dataset without a usable loader, so fall back to the default.
        if loader is None:
            loader = pil_loader

        # ImageFolder's own transform hooks stay disabled; the image transform
        # is applied in __getitem__ of this class instead.
        super().__init__(
            root=root,
            transform=None,
            target_transform=None,
            loader=loader,
            is_valid_file=is_valid_file,
        )

        self.transforms = transforms

    def __getitem__(self, index: int) -> Tuple[Any, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            a tuple ``(image, label)``; ``image`` is the loader output, passed
            through ``transforms`` when one was given, and ``label`` is the
            class index as a tensor
        """
        image, label = self._load_image(index)

        if self.transforms is not None:
            # Only the image is transformed; the label is returned untouched.
            return self.transforms(image), label

        return image, label

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.imgs)

    def _load_image(self, index: int) -> Tuple[Any, Tensor]:
        """Load a single image and its class label.

        Args:
            index: index to return

        Returns:
            the image as produced by the loader
            the image class label as a tensor
        """
        # Call ImageFolder's implementation explicitly: ``self[index]`` would
        # dispatch to this class's __getitem__, which calls back into here.
        img, label = ImageFolder.__getitem__(self, index)
        label = torch.tensor(label)
        return img, label
|
|
|
|
class RESISC45Dataset(VisionClassificationDataset):
    """RESISC45 dataset.

    The `RESISC45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_
    dataset is a dataset for remote sensing image scene classification.

    Dataset features:

    * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
    * three spectral bands - RGB
    * 45 scene classes, 700 images per class
    * images extracted from Google Earth from over 100 countries
    * images conditions with high variability (resolution, weather, illumination)

    Dataset format:

    * images are three-channel jpgs

    Dataset classes:

    0. airplane
    1. airport
    2. baseball_diamond
    3. basketball_court
    4. beach
    5. bridge
    6. chaparral
    7. church
    8. circular_farmland
    9. cloud
    10. commercial_area
    11. dense_residential
    12. desert
    13. forest
    14. freeway
    15. golf_course
    16. ground_track_field
    17. harbor
    18. industrial_area
    19. intersection
    20. island
    21. lake
    22. meadow
    23. medium_residential
    24. mobile_home_park
    25. mountain
    26. overpass
    27. palace
    28. parking_lot
    29. railway
    30. railway_station
    31. rectangular_farmland
    32. river
    33. roundabout
    34. runway
    35. sea_ice
    36. ship
    37. snowberg
    38. sparse_residential
    39. stadium
    40. storage_tank
    41. tennis_court
    42. terrace
    43. thermal_power_station
    44. wetland

    This dataset uses the train/val/test splits defined in the "In-domain representation
    learning for remote sensing" paper:

    * https://arxiv.org/abs/1911.06721

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1109/jproc.2017.2675998
    """

    # Image folder, relative to the dataset root, holding one sub-folder per class.
    directory = "resisc45/NWPU-RESISC45"

    # Accepted values for the ``split`` argument of ``__init__``.
    splits = ["train", "val", "test"]
    # Download location and MD5 checksum of each split's filename list.
    # Neither mapping is read by the code in this class.
    split_urls = {
        "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt",
        "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt",
        "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt",
    }
    split_md5s = {
        "train": "b5a4c05a37de15e4ca886696a85c403e",
        "val": "a0770cee4c5ca20b8c32bbd61e114805",
        "test": "3dda9e4988b47eb1de9f07993653eb08",
    }
    # Class names, listed in the same order as the class docstring.
    classes = [
        "airplane",
        "airport",
        "baseball_diamond",
        "basketball_court",
        "beach",
        "bridge",
        "chaparral",
        "church",
        "circular_farmland",
        "cloud",
        "commercial_area",
        "dense_residential",
        "desert",
        "forest",
        "freeway",
        "golf_course",
        "ground_track_field",
        "harbor",
        "industrial_area",
        "intersection",
        "island",
        "lake",
        "meadow",
        "medium_residential",
        "mobile_home_park",
        "mountain",
        "overpass",
        "palace",
        "parking_lot",
        "railway",
        "railway_station",
        "rectangular_farmland",
        "river",
        "roundabout",
        "runway",
        "sea_ice",
        "ship",
        "snowberg",
        "sparse_residential",
        "stadium",
        "storage_tank",
        "tennis_court",
        "terrace",
        "thermal_power_station",
        "wetland",
    ]

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Initialize a new RESISC45 dataset instance.

        The split's filename list is read from
        ``<root>/resisc45/resisc45-<split>.txt`` (one file name per line) and
        only images whose base name appears in it are kept.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform forwarded unchanged to the parent
                class constructor

        Raises:
            AssertionError: if ``split`` is not one of the known splits
            OSError: if the split file cannot be opened
        """
        assert split in self.splits, (
            f"split must be one of {self.splits}, got {split!r}"
        )
        self.root = root

        split_file = os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")
        # Explicit encoding keeps the read independent of the platform locale.
        with open(split_file, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            # Blank lines carry no file name, so they are left out of the set.
            valid_fns = {name for name in stripped if name}

        def is_in_split(path: str) -> bool:
            """Tell whether the file at ``path`` belongs to the chosen split."""
            return os.path.basename(path) in valid_fns

        super().__init__(
            root=os.path.join(root, self.directory),
            transforms=transforms,
            is_valid_file=is_in_split,
        )
|
|
|
|
|
|
class RESISC45:
    """Bundle the RESISC45 train and test splits with their data loaders.

    After construction the instance exposes ``train_dataset``,
    ``train_loader``, ``test_dataset``, ``test_loader`` and ``classnames``.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=32,
                 num_workers=16):
        """Create both datasets and wrap each one in a ``DataLoader``.

        Args:
            preprocess: passed to both datasets as their ``transforms``
            location: dataset root directory handed to ``RESISC45Dataset``
            batch_size: batch size used by both loaders
            num_workers: worker count used by both loaders
        """
        # Settings common to the train and test loaders.
        loader_options = dict(batch_size=batch_size, num_workers=num_workers)

        self.train_dataset = RESISC45Dataset(
            root=location, split='train', transforms=preprocess
        )
        # Only the training loader shuffles.
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, shuffle=True, **loader_options
        )

        self.test_dataset = RESISC45Dataset(
            root=location, split='test', transforms=preprocess
        )
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, **loader_options
        )

        # Class names with underscores replaced by spaces.
        self.classnames = [
            name.replace('_', ' ') for name in RESISC45Dataset.classes
        ]
|
|