# FBAGSTM's picture
# STM32 AI Experimentation Hub
# 747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * Copyright (c) Soumith Chintala 2016, All rights reserved., with a BSD-3 license
# *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
from pathlib import Path
import PIL.Image
from torchvision.datasets.utils import (
check_integrity,
download_and_extract_archive,
download_url,
verify_str_arg,
)
from torchvision.datasets.vision import VisionDataset
class Flowers102(VisionDataset):
    """Oxford 102 Flowers image-classification dataset.

    Taken from
    https://github.com/pytorch/vision/blob/HEAD/torchvision/datasets/flowers102.py
    and added for compatibility with old torchvision versions.

    All files live under ``<root>/flowers-102``: the ``jpg`` image folder
    (extracted from ``102flowers.tgz``) plus ``imagelabels.mat`` and
    ``setid.mat``.

    Args:
        root: Directory that contains (or will receive) the ``flowers-102``
            folder.
        split: Which subset to load; validated by ``verify_str_arg`` against
            ``"train"``, ``"val"`` and ``"test"``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the integer label.
        download: When true, fetch the dataset files before loading unless
            they already pass the integrity check.

    Raises:
        RuntimeError: If the image folder is missing or one of the ``.mat``
            files fails its MD5 check after the optional download.
    """

    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
    # key -> (filename, md5)
    _file_dict = {
        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
    }
    # Public split name -> variable name inside ``setid.mat``.
    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform=None,
        target_transform=None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # NOTE: data is expected under <root>/flowers-102, not directly in <root>.
        self._base_folder = Path(self.root) / "flowers-102"
        self._images_folder = self._base_folder / "jpg"

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        # Imported lazily so scipy is only needed when the dataset is built.
        from scipy.io import loadmat

        set_ids = loadmat(
            self._base_folder / self._file_dict["setid"][0], squeeze_me=True
        )
        image_ids = set_ids[self._splits_map[self._split]].tolist()

        labels = loadmat(
            self._base_folder / self._file_dict["label"][0], squeeze_me=True
        )
        # Image ids are 1-based; stored labels are shifted down by one so that
        # the resulting class indices start at 0.
        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))

        self._labels = [image_id_to_label[image_id] for image_id in image_ids]
        self._image_files = [
            self._images_folder / f"image_{image_id:05d}.jpg"
            for image_id in image_ids
        ]
        # Distinct label indices present in the selected split.
        self.classes = set(self._labels)

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is opened from disk, converted to RGB and passed through
        ``transform`` when one is set; the label is passed through
        ``target_transform`` when one is set.
        """
        image_file, label = self._image_files[idx], self._labels[idx]

        # The context manager closes the underlying file deterministically;
        # ``convert`` returns an independent, fully loaded image.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description as ``split=<name>``."""
        return f"split={self._split}"

    def _check_integrity(self) -> bool:
        """Return True when the image folder exists and both ``.mat`` files match their MD5.

        Only the existence of the ``jpg`` directory is checked for the images;
        the individual image files are not verified here.
        """
        if not (self._images_folder.exists() and self._images_folder.is_dir()):
            return False

        for file_key in ["label", "setid"]:
            filename, md5 = self._file_dict[file_key]
            if not check_integrity(str(self._base_folder / filename), md5):
                return False
        return True

    def download(self) -> None:
        """Download the image archive and the ``.mat`` files into the base folder.

        Does nothing when the files on disk already pass ``_check_integrity``.
        """
        if self._check_integrity():
            return

        download_and_extract_archive(
            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
            str(self._base_folder),
            md5=self._file_dict["image"][1],
        )
        for file_key in ["label", "setid"]:
            filename, md5 = self._file_dict[file_key]
            download_url(
                self._download_url_prefix + filename, str(self._base_folder), md5=md5
            )