| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | ImageFolder / DatasetFolder datasets and image-loading helpers. |
| | |
| | Mostly copy-paste from torchvision references |
| | """ |
| | import os |
| | import os.path |
| | from typing import Any, Callable, Dict, List, Optional, Tuple, cast |
| |
|
| | from PIL import Image |
| | from torchvision.datasets.vision import VisionDataset |
| |
|
| |
|
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Tell whether ``filename`` ends with one of ``extensions``.

    Only the filename is lower-cased before the comparison, so the
    extensions themselves are expected to be given in lower case.

    Args:
        filename (string): path or name of the file to test.
        extensions (tuple of strings): accepted suffixes, e.g. ``(".jpg", ".png")``.

    Returns:
        bool: True when the lower-cased filename ends with any of the suffixes.
    """
    lowered_name = filename.lower()
    # str.endswith accepts a tuple and checks every suffix in one call.
    matches = lowered_name.endswith(extensions)
    return matches
| |
|
| |
|
def is_image_file(filename: str) -> bool:
    """Report whether ``filename`` has one of the suffixes in ``IMG_EXTENSIONS``.

    The filename is compared case-insensitively (it is lower-cased first).

    Args:
        filename (string): path or name of the file to test.

    Returns:
        bool: True for a recognised image suffix, False otherwise.
    """
    return has_file_allowed_extension(filename=filename, extensions=IMG_EXTENSIONS)
| |
|
| |
|
def find_classes(
    directory: str, class_num: Optional[int]
) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset and assigns them indices.

    Class folders are the immediate sub-directories of ``directory``. They are
    sorted alphabetically and only the first ``class_num`` of them are kept,
    so the returned indices are contiguous ``0 .. len(classes) - 1``.

    Args:
        directory (str): root directory whose sub-directories are the classes.
        class_num (int or None): maximum number of classes to keep (the first
            ``class_num`` names in sorted order). ``None`` keeps every class.

    Returns:
        (Tuple[List[str], Dict[str, int]]): the kept class names and a mapping
        from each class name to its index.

    Raises:
        ValueError: if ``class_num`` is not None and smaller than 1.
        FileNotFoundError: if ``directory`` contains no sub-directories.

    See :class:`DatasetFolder` for details.
    """
    if class_num is not None and class_num < 1:
        # A non-positive slice bound would silently keep no classes (0) or
        # drop classes from the end of the sorted list (negative values).
        raise ValueError(
            f"class_num must be a positive integer or None, got {class_num}."
        )
    # Use the iterator as a context manager so the directory handle is
    # released deterministically.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    classes = classes[:class_num]
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
| |
|
| |
|
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    class_num=10,
) -> List[Tuple[str, int]]:
    """Collect ``(path_to_sample, class_index)`` pairs for every accepted file.

    Each class name in ``class_to_idx`` is treated as a sub-directory of
    ``directory`` and walked recursively (following symlinks). Classes are
    visited in sorted name order, and within each class the walked
    directories and their files are visited in sorted order, so the result is
    deterministic for a fixed directory tree.

    Args:
        directory: dataset root; a leading ``~`` is expanded.
        class_to_idx: mapping of class folder name to class index. When
            omitted it is built with ``find_classes`` using ``class_num``.
        extensions: accepted filename suffixes. Exactly one of
            ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file: predicate deciding whether a file path is a sample.
        class_num: number of classes to keep; only used when
            ``class_to_idx`` is omitted.

    Returns:
        List of ``(path_to_sample, class_index)`` tuples.

    Raises:
        ValueError: ``class_to_idx`` was given but is empty, or
            ``extensions`` and ``is_valid_file`` are both None / both set.
        FileNotFoundError: at least one class in ``class_to_idx`` yielded no
            valid file (this includes a class whose folder does not exist).

    See :class:`DatasetFolder` for details.
    """
    root_dir = os.path.expanduser(directory)

    if class_to_idx is None:
        class_to_idx = find_classes(root_dir, class_num)[1]
    elif not class_to_idx:
        raise ValueError(
            "'class_to_index' must have at least one entry to collect any samples."
        )

    # Exactly one filter must be supplied: reject "both None" and "both set".
    if (extensions is None) == (is_valid_file is None):
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:
        suffixes = cast(Tuple[str, ...], extensions)

        def accept(candidate_path: str) -> bool:
            return has_file_allowed_extension(candidate_path, suffixes)

    else:
        accept = cast(Callable[[str], bool], is_valid_file)

    samples = []
    populated = set()
    for class_name in sorted(class_to_idx):
        label = class_to_idx[class_name]
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            # Missing folders are reported below as classes without files.
            continue
        for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
            for filename in sorted(filenames):
                candidate = os.path.join(dirpath, filename)
                if accept(candidate):
                    samples.append((candidate, label))
                    populated.add(class_name)

    missing = set(class_to_idx) - populated
    if missing:
        msg = f"Found no valid file for the classes {', '.join(sorted(missing))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return samples
| |
|
| |
|
class DatasetFolder(VisionDataset):
    """A generic data loader for a ``root/<class_name>/.../<file>`` layout.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A tuple of allowed (lower-case) extensions.
            Exactly one of ``extensions`` and ``is_valid_file`` must be given.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): Forwarded to ``VisionDataset``
            but NOT applied by :meth:`__getitem__`, which returns only the
            (transformed) sample.
        is_valid_file (callable, optional): A function that takes the path of
            a file and checks whether it is a valid sample (e.g. to skip
            corrupt files). Exactly one of ``extensions`` and
            ``is_valid_file`` must be given.
        class_num (int): how many classes to load, namely the first
            ``class_num`` class folders in alphabetical order.

    Attributes:
        classes (list): List of the kept class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples.
        targets (list): The class_index value for each sample in the dataset.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num: int = 10,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self.find_classes(self.root, class_num=class_num)
        samples = self.make_dataset(
            self.root, class_to_idx, extensions, is_valid_file, class_num=class_num
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        # Parallel to ``samples``: the class index of each sample.
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num: int = 10,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of the form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file
        instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A tuple of allowed extensions.
                Exactly one of extensions and is_valid_file must be passed.
                Defaults to None.
            is_valid_file (optional): A function that takes the path of a file
                and checks whether it is a valid sample (e.g. to skip corrupt
                files). Exactly one of extensions and is_valid_file must be
                passed. Defaults to None.
            class_num: how many classes will be loaded; forwarded to the
                module-level ``make_dataset``.

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are both
                None or both not None.
            FileNotFoundError: In case at least one class has no valid file.

        Returns:
            List[Tuple[str, int]]: samples of the form (path_to_sample, class)
        """
        if class_to_idx is None:
            # Passing None to the module-level make_dataset would make it
            # rebuild the class mapping itself, bypassing any overridden
            # ``find_classes``, so it is rejected here.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory,
            class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            class_num=class_num,
        )

    def find_classes(
        self, directory: str, class_num: int
    ) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            |-- class_x
            |   |-- xxx.ext
            |   |-- xxy.ext
            |   `-- ...
            |       `-- xxz.ext
            `-- class_y
                |-- 123.ext
                |-- nsdf3.ext
                `-- ...
                    `-- asd932_.ext

        This method can be overridden to only consider a subset of classes,
        or to adapt to a different dataset directory structure.

        Args:
            directory (str): Root directory path, corresponding to ``self.root``.
            class_num (int): how many classes to keep, namely the first
                ``class_num`` class folders in alphabetical order.

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of the kept classes and a
            dictionary mapping each class to an index.
        """
        return find_classes(directory, class_num=class_num)

    def __getitem__(self, index: int) -> Any:
        """Load and transform the sample at ``index``.

        Args:
            index (int): Index into ``self.samples``.

        Returns:
            The loaded sample after ``self.transform`` (if set). Only the
            sample is returned. The class index is available as
            ``self.targets[index]``, and ``target_transform`` is not applied.
        """
        # NOTE(review): torchvision's DatasetFolder returns (sample, target);
        # this copy returns only the sample, presumably on purpose. Confirm
        # with callers before changing the return shape.
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self) -> int:
        """Number of collected samples."""
        return len(self.samples)
| |
|
| |
|
# Lower-case filename suffixes recognised as images. Filenames are lower-cased
# before being matched against these with str.endswith.
IMG_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
    ".pgm", ".tif", ".tiff", ".webp",
)
| |
|
| |
|
def pil_loader(path: str) -> Image.Image:
    """Read the image file at ``path`` with PIL and return it in RGB mode."""
    # The file is opened explicitly so the handle is closed deterministically.
    # convert() is called inside the with-block so the lazily decoded pixel
    # data is read before the file is closed.
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")
| |
|
| |
|
| | |
def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, using ``pil_loader`` if accimage raises IOError."""
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # accimage could not read/decode the file, so fall back to PIL.
        image = pil_loader(path)
    return image
| |
|
| |
|
def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's configured image backend.

    Uses ``accimage_loader`` when the backend is ``"accimage"`` and
    ``pil_loader`` for any other backend.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == "accimage"
    return accimage_loader(path) if use_accimage else pil_loader(path)
| |
|
| |
|
class ImageFolder(DatasetFolder):
    """Image dataset whose files are arranged by default as follows: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    Being a subclass of :class:`DatasetFolder`, it can be customised by
    overriding the same methods.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a
            loaded image and returns a transformed version. E.g,
            ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes
            in the target; forwarded to :class:`DatasetFolder`.
        loader (callable, optional): A function to load an image given its
            path. Defaults to ``default_loader``.
        is_valid_file (callable, optional): A function that takes the path of
            an image file and checks whether it is a valid file (e.g. to skip
            corrupt files). When given, it replaces the ``IMG_EXTENSIONS``
            suffix filter.
        class_num: how many classes will be loaded.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples; the same list
            object as ``samples``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num=10,
    ):
        # Filter by IMG_EXTENSIONS only when no custom check is supplied,
        # because DatasetFolder rejects receiving both filters.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            root,
            loader,
            extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            class_num=class_num,
        )
        # ``imgs`` aliases ``samples`` (same list object).
        self.imgs = self.samples
| |
|