# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copied from jam_data.
"""
import inspect
from typing import Any, Callable, Dict

import torch
from torch.utils.data import Dataset

# Default virtual length for datasets with no natural size (2**15 = 32768).
MAX_LENGTH = 1 << 15
class LambdaDataset(Dataset):
    """
    A dataset that generates items by calling a function.

    This allows for creating dynamic datasets where the items are the result of
    function calls. The function can optionally accept an ``index`` argument so
    that items may depend on their position.

    Attributes:
        length (int): The total number of items in the dataset.
        fn (Callable): The function used to generate dataset items.
        is_index_in_params (bool): Whether ``fn`` accepts an ``index`` parameter
            and should therefore receive the item index on each call.
    """

    def __init__(self, fn: Callable, length: int = MAX_LENGTH) -> None:
        """
        Initializes the LambdaDataset with a generator function and a length.

        Args:
            fn (Callable): A function that returns a dataset item. It can
                optionally accept an ``index`` argument to generate data items
                based on their index.
            length (int): The total number of items in the dataset, defaults
                to MAX_LENGTH.
        """
        self.length = length
        self.fn = fn
        try:
            # Inspect the signature to decide whether 'index' should be passed.
            # inspect.signature raises ValueError when no signature can be
            # provided and TypeError for unsupported callables; treat both as
            # "fn does not take an index".
            signature = inspect.signature(fn)
            self.is_index_in_params = "index" in signature.parameters
        except (TypeError, ValueError):
            self.is_index_in_params = False

    def __len__(self) -> int:
        """
        Returns the total length of the dataset.

        Returns:
            int: The number of items in the dataset.
        """
        return self.length

    def __getitem__(self, index: int) -> Any:
        """
        Retrieves an item at a specific index by calling the function ``fn``.

        The index is forwarded to ``fn`` only when its signature accepts an
        ``index`` parameter.

        Args:
            index (int): The index of the item to retrieve.

        Returns:
            Any: The item returned by the function ``fn``.
        """
        if self.is_index_in_params:
            return self.fn(index)  # fn consumes the index.
        return self.fn()  # fn takes no index; call it bare.
class RepeatDataset(Dataset):
    """
    Wraps a dataset so that a small pool of its items is served repeatedly.

    Indices are folded modulo ``num_item``, so only the first ``num_item``
    entries of the wrapped dataset are ever fetched, while the dataset still
    advertises ``length`` items. Each fetched item is memoized so the wrapped
    dataset is touched at most once per slot.

    Attributes:
        length (int): The total length this dataset advertises.
        dataset (Dataset): The wrapped dataset.
        num_item (int): Number of distinct items cycled through.
        cache_item (dict): Memoization cache of already-fetched items.
    """

    def __init__(self, dataset: Dataset, length: int = MAX_LENGTH, num_item: int = 1) -> None:
        """
        Initializes the RepeatDataset.

        Args:
            dataset (Dataset): The dataset to wrap.
            length (int): The total length to advertise. Defaults to MAX_LENGTH.
            num_item (int): Number of distinct items to repeat. Defaults to 1.
        """
        self.length = length
        self.dataset = dataset
        self.num_item = num_item
        self.cache_item = {}

    def __len__(self) -> int:
        """Returns the advertised dataset length."""
        return self.length

    def __getitem__(self, index: int) -> Any:
        """Returns the item for ``index % num_item``, fetching it on first use."""
        slot = index % self.num_item
        try:
            # Fast path: the slot was already materialized.
            return self.cache_item[slot]
        except KeyError:
            item = self.dataset[slot]
            self.cache_item[slot] = item
            return item
class CombinedDictDataset(Dataset):
    """
    A dataset that wraps multiple PyTorch datasets and returns a dictionary of
    data items from each dataset for a given index.

    This dataset ensures that all constituent datasets are indexed safely by
    setting its length to the minimum length of the datasets provided.

    Parameters:
    -----------
    **datasets : Dataset
        Keyword arguments where keys are string identifiers for the datasets
        and values are the dataset instances themselves.

    Attributes:
    -----------
    datasets : Dict[str, Dataset]
        Stores the input datasets.
    max_length : int
        The minimum length among all provided datasets, determining the length
        of this combined dataset.

    Examples:
    ---------
    >>> dataset1 = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32))
    >>> dataset2 = torch.utils.data.TensorDataset(torch.randn(100, 3, 32, 32))
    >>> combined_dataset = CombinedDictDataset(dataset1=dataset1, dataset2=dataset2)
    >>> print(len(combined_dataset))
    100
    >>> data = combined_dataset[50]
    >>> print(data.keys())
    dict_keys(['dataset1', 'dataset2'])
    """

    def __init__(self, **datasets: Dataset) -> None:
        """
        Initializes the CombinedDictDataset with multiple datasets.

        Args:
            **datasets (Dataset): Key-value pairs where keys are dataset names
                and values are dataset instances.

        Raises:
            ValueError: If no dataset is provided (min() over an empty sequence
                would otherwise raise a cryptic error).
        """
        if not datasets:
            raise ValueError("CombinedDictDataset requires at least one dataset.")
        self.datasets = datasets
        # Clamp to the shortest dataset so every key can be indexed safely.
        self.max_length = min(len(dataset) for dataset in datasets.values())

    def __len__(self) -> int:
        """Returns the length of the shortest constituent dataset."""
        return self.max_length

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """
        Retrieves an item from each dataset at the specified index and combines
        them into a dictionary keyed by the dataset names provided at
        initialization.

        Args:
            index (int): The index of the items to retrieve across all datasets.

        Returns:
            Dict[str, Any]: A mapping from each dataset name to that dataset's
                item at the given index.
        """
        return {key: dataset[index] for key, dataset in self.datasets.items()}