Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import os.path | |
| import pickle | |
| import random | |
| from copy import deepcopy | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, cast | |
| import numpy as np | |
| from torchvision.datasets.vision import VisionDataset | |
| from fourm.data.modality_transforms import AbstractTransform, get_transform_key | |
# Lowercase file suffixes treated as image-like inputs. NumPy array files
# (.npy / .npz) are accepted alongside the standard raster formats.
IMG_EXTENSIONS = (
    '.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
    '.tif', '.tiff', '.webp', '.jpx', '.npy', '.npz',
)
# Superset used by MultiModalDatasetFolder when no is_valid_file predicate is
# given: everything in IMG_EXTENSIONS plus JSON / text payloads.
UNIFIED_EXTENSIONS = (*IMG_EXTENSIONS, '.json', '.txt', '.json.gz')
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Return whether ``filename`` ends with any of the given extensions.

    Only the filename is lowercased before comparison, so the entries in
    ``extensions`` must already be lowercase to match.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    lowered_name = filename.lower()
    return lowered_name.endswith(extensions)
def is_image_file(filename: str) -> bool:
    """Return whether ``filename`` carries one of the suffixes in ``IMG_EXTENSIONS``.

    The check is case-insensitive on the filename.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    allowed = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, extensions=allowed)
def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    cache_path: Optional[str] = None,
) -> List[Tuple[str, int]]:
    """Collect ``(path, class_index)`` pairs for every valid file under ``directory``.

    Files are gathered from ``directory/<class_name>/`` for every class in
    ``class_to_idx``. The walk is recursive and follows symlinks. Classes and
    files are visited in sorted order, so the result is deterministic. Class
    names without a matching sub-directory are skipped.

    Args:
        directory: Root directory; ``~`` is expanded.
        class_to_idx: Mapping from class sub-directory name to class index.
        extensions: Allowed (lowercase) file extensions. Exactly one of
            ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file: Predicate deciding whether a file path is kept.
            Exactly one of ``extensions`` and ``is_valid_file`` must be given.
        cache_path: Optional pickle file. If it exists, its contents are
            returned as-is without re-scanning. Otherwise the scan result is
            written there; missing parent directories are created.

    Returns:
        List of ``(path, class_index)`` tuples.

    Raises:
        ValueError: If both or neither of ``extensions`` / ``is_valid_file``
            are given.
    """
    # Validate first, so bad arguments are reported even when a cache exists.
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if cache_path is not None and os.path.exists(cache_path):
        # Load cached file paths from disk if they exist.
        # NOTE: the cache is neither keyed on nor validated against directory,
        # class_to_idx or the file filter; delete it after changing any of them.
        # Unpickling can execute arbitrary code, so cache_path must be trusted.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    directory = os.path.expanduser(directory)
    if extensions is not None:
        allowed = cast(Tuple[str, ...], extensions)

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, allowed)
    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances: List[Tuple[str, int]] = []
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))

    if cache_path is not None:
        # Cache all file paths s.t. setting up the dataloader is instant in the future.
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:  # os.makedirs('') raises for a bare filename
            os.makedirs(cache_dir, exist_ok=True)
        # Write to a process-unique temp file, then atomically swap it in.
        # A concurrent or interrupted writer can then never leave a truncated
        # pickle at cache_path.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(instances, f)
        os.replace(tmp_path, cache_path)
    return instances
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            Exactly one of extensions and is_valid_file must be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and checks if the file is a valid file (used to check for corrupt files).
            Exactly one of extensions and is_valid_file must be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
        max_load_retries (int): Class-level cap on how many random replacement
            samples ``__getitem__`` tries after a load failure before giving up.
    """

    # Without a cap, a loader that fails for every sample would make
    # __getitem__ spin forever.
    max_load_retries: int = 100

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if not samples:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        # Only the immediate sub-directories of `dir` become classes.
        classes = sorted(d.name for d in os.scandir(dir) if d.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load, transform and return the sample at ``index``.

        If the loader raises, the error is printed and a uniformly random index
        is tried instead. At most ``max_load_retries`` replacements are attempted.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.

        Raises:
            IndexError: If ``index`` is out of range.
            RuntimeError: If no sample could be loaded within the retry budget
                (chained from the last loader error).
        """
        last_error: Optional[Exception] = None
        for _ in range(self.max_load_retries + 1):
            # The lookup sits outside the try, so an out-of-range index raises
            # IndexError instead of being silently swapped for a random sample.
            path, target = self.samples[index]
            try:
                sample = self.loader(path)
                break
            except Exception as e:
                # Best-effort: skip unreadable files by substituting a random sample.
                print(e)
                last_error = e
                index = random.randint(0, len(self.samples) - 1)
        else:
            raise RuntimeError(
                f"Could not load a sample after {self.max_load_retries + 1} attempts"
            ) from last_error

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)
class MultiModalDatasetFolder(VisionDataset):
    """A generic multi-modal dataset loader where the samples are arranged in this way: ::

        root/modality_a/class_x/xxx.ext
        root/modality_a/class_y/xxy.ext
        root/modality_a/class_z/xxz.ext

        root/modality_b/class_x/xxx.ext
        root/modality_b/class_y/xxy.ext
        root/modality_b/class_z/xxz.ext

    Args:
        root (string): Root directory path.
        modalities (list): List of modalities as strings. Must not be empty.
        modality_paths (dict): Dict of paths (relative to root) for modalities.
            Modalities missing from it (or all of them, if it is None) default to
            a sub-directory named after the modality. The dict is copied, not mutated.
        modality_transforms (dict): Dict of transforms for each modality; the
            transform's ``load`` method is used to read a sample from its path.
        transform (callable, optional): A function/transform that takes in
            the dict of samples and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and checks if the file is a valid file (used to check for corrupt files).
            If None, files are filtered by UNIFIED_EXTENSIONS instead.
        max_samples (int, optional): Maximum number of samples to load. If None, all samples are loaded.
            The subset is drawn with a fixed seed (0).
        pre_shuffle (bool, optional): Whether to shuffle the samples during the init (fixed seed 100).
        cache (bool, optional): Whether to cache the discovered file lists on disk under
            ``root/dataloader_cache/<modality_path>.pkl``. Later runs then skip the
            directory scan.
        return_path (bool, optional): Whether to add ``class_id`` and ``file_name``
            entries, derived from the sample path, to each returned dict.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (dict): Dict mapping each modality to a list of (sample path, class_index) tuples.
        imgs (dict): Alias of ``samples``.

    NOTE(review): max_samples / pre_shuffle apply one index permutation to every
    modality, so all modalities are assumed to list the same files in the same
    order. This is not checked.
    """

    def __init__(
        self,
        root: str,
        modalities: List[str],
        modality_paths: Optional[Dict[str, str]],
        modality_transforms: Dict[str, AbstractTransform],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        max_samples: Optional[int] = None,
        pre_shuffle: bool = False,
        cache: bool = False,
        return_path: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        if not modalities:
            raise ValueError("modalities must contain at least one modality")
        self.modalities = modalities
        # Copy so the caller's dict is never mutated. Modalities without an
        # explicit path default to a sub-directory named after the modality.
        self.modality_paths = dict(modality_paths or {})
        for mod in self.modalities:
            self.modality_paths.setdefault(mod, mod)
        self.modality_transforms = modality_transforms
        self.return_path = return_path

        # Class folders are discovered from the first requested modality; the
        # other modalities are expected to use the same class sub-directories.
        first_dir = os.path.join(self.root, f'{self.modality_paths[self.modalities[0]]}')
        classes, class_to_idx = self._find_classes(first_dir)
        extensions = UNIFIED_EXTENSIONS if is_valid_file is None else None

        samples = {}
        for mod in self.modalities:
            mod_path = f'{self.modality_paths[mod]}'
            mod_dir = os.path.join(self.root, mod_path)
            cache_path = os.path.join(self.root, 'dataloader_cache', f'{mod_path}.pkl') if cache else None
            mod_samples = make_dataset(mod_dir, class_to_idx, extensions, is_valid_file, cache_path=cache_path)
            if not mod_samples:
                msg = "Found 0 files in subfolders of: {}\n".format(mod_dir)
                if extensions is not None:
                    msg += "Supported extensions are: {}".format(",".join(extensions))
                raise RuntimeError(msg)
            samples[mod] = mod_samples

        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples

        # A private RandomState yields exactly the permutation that
        # np.random.seed(s) + np.random.permutation would. It does so without
        # resetting the global NumPy RNG as a side effect.
        # Select random subset of dataset if so specified.
        if isinstance(max_samples, int):
            total_samples = len(next(iter(self.samples.values())))
            permutation = np.random.RandomState(0).permutation(total_samples)[:max_samples]
            self.samples = {
                mod: [mod_samples[i] for i in permutation]
                for mod, mod_samples in self.samples.items()
            }

        if pre_shuffle:
            total_samples = len(next(iter(self.samples.values())))
            permutation = np.random.RandomState(100).permutation(total_samples)
            self.samples = {
                mod: [mod_samples[i] for i in permutation]
                for mod, mod_samples in self.samples.items()
            }

        # Per-index cache of loaded (sample_dict, target) pairs. Nothing in this
        # class populates it, so lookups in __getitem__ always miss unless a
        # subclass fills it.
        self.cache = {}
        self.imgs = self.samples

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        # Only the immediate sub-directories of `dir` become classes.
        classes = sorted(d.name for d in os.scandir(dir) if d.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def get_class_and_file(self, path: str) -> Tuple[str, str]:
        """Extract the parent-directory name and the file stem from a path.

        The parent directory is the class folder only for files directly inside
        it; make_dataset also walks nested sub-directories. The stem is
        everything before the first '.', so 'a.json.gz' -> 'a'.
        """
        class_id = os.path.basename(os.path.dirname(path))
        file_name = os.path.basename(path).split('.')[0]
        return class_id, file_name

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """
        Args:
            index (int): Index

        Returns:
            dict: One entry per modality holding the loaded (and transformed)
            sample, plus ``class_idx`` (the target). When ``return_path`` is set
            and the sample was not served from ``self.cache``, it also holds
            ``class_id`` and ``file_name``.
        """
        from_cache = index in self.cache
        if from_cache:
            sample_dict, target = deepcopy(self.cache[index])
        else:
            sample_dict = {}
            for mod in self.modalities:
                path, target = self.samples[mod][index]
                sample_dict[mod] = self.modality_transforms[get_transform_key(mod)].load(path)

        if self.transform is not None:
            sample_dict = self.transform(sample_dict)
        if self.target_transform is not None:
            target = self.target_transform(target)

        sample_dict['class_idx'] = target

        # `path` is only bound on the non-cached branch (last modality's path).
        if self.return_path and not from_cache:
            class_id, file_name = self.get_class_and_file(path)
            sample_dict['class_id'] = class_id
            sample_dict['file_name'] = file_name

        return sample_dict

    def __len__(self) -> int:
        # All modalities are assumed to have the same length; the first is used.
        return len(next(iter(self.samples.values())))