# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * Copyright (c) Soumith Chintala 2016, All rights reserved., with a BSD-3 license
# *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import PIL.Image
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
    verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset


class Flowers102(VisionDataset):
    """Oxford 102 Flowers image-classification dataset.

    Files are looked up (and, on request, downloaded) under
    ``<root>/flowers-102``: the extracted ``jpg`` image folder plus the
    ``imagelabels.mat`` and ``setid.mat`` annotation files.

    Args:
        root: Root directory that contains (or will receive) ``flowers-102``.
        split: Which split to load; one of ``"train"``, ``"val"``, ``"test"``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the integer label.
        download: If True, fetch the dataset files when the integrity check
            does not already pass.

    Raises:
        RuntimeError: If the dataset files are missing or fail the integrity
            check after the optional download step.
    """

    # Taken from https://github.com/pytorch/vision/blob/HEAD/torchvision/datasets/flowers102.py
    # Added for compatibility with old torchvision versions

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    _file_dict = {  # filename, md5
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Public split name -> variable name inside setid.mat.
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
    # Keys of _file_dict for the annotation files handled individually
    # (the image archive is downloaded and extracted separately).
    _annotation_keys = ("label", "setid")

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # NOTE: the dataset lives in a "flowers-102" sub-folder of root,
        # not directly in root.
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        # Imported lazily so scipy is only needed when the dataset is built.
        from scipy.io import loadmat

        set_ids = loadmat(
            self._base_folder / self._file_dict["setid"][0], squeeze_me=True
        )
        image_ids = set_ids[self._splits_map[self._split]].tolist()

        labels = loadmat(
            self._base_folder / self._file_dict["label"][0], squeeze_me=True
        )
        # Image ids are numbered from 1; stored labels are shifted by -1 so
        # the labels exposed by this dataset start at 0.
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))

        self._labels = [image_id_to_label[image_id] for image_id in image_ids]
        self._image_files = [
            self._images_folder / f"image_{image_id:05d}.jpg"
            for image_id in image_ids
        ]
        # Distinct labels present in the selected split.
        self.classes = set(self._labels)

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is loaded from disk, converted to RGB and passed through
        ``transform`` when one was given; the label is passed through
        ``target_transform`` when one was given.
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None so a callable that happens to evaluate as
        # falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description shown in the dataset's repr."""
        return f"split={self._split}"

    def _check_integrity(self) -> bool:
        """Return True if the image folder exists and both annotation files match their MD5."""
        # is_dir() is already False for a missing path.
        if not self._images_folder.is_dir():
            return False

        for key in self._annotation_keys:
            filename, md5 = self._file_dict[key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self) -> None:
        """Download the image archive and annotation files unless already present and valid."""
        if self._check_integrity():
            return

        archive_name, archive_md5 = self._file_dict["image"]
        download_and_extract_archive(
            f"{self._download_url_prefix}{archive_name}",
            str(self._base_folder),
            md5=archive_md5,
        )
        for key in self._annotation_keys:
            filename, md5 = self._file_dict[key]
            download_url(
                f"{self._download_url_prefix}{filename}",
                str(self._base_folder),
                md5=md5,
            )