Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import glob | |
| import os | |
| from typing import Callable, Iterator | |
| from torch import Tensor | |
| from torch.utils.data import DataLoader, Dataset, IterableDataset | |
class CustomIterableDataset(IterableDataset):
    r"""
    An auxiliary class for iterating through a dataset stored on disk.

    If `path` is a directory, every entry directly inside it is passed to
    `transform_filename_to_tensor` and the results are yielded one by one.
    Otherwise `path` itself is passed to `transform_filename_to_tensor` and
    the returned object is iterated.
    """

    def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None:
        r"""
        Args:
            transform_filename_to_tensor (callable): Function to read a data
                        file from path and return a tensor from that file.
            path (str): Path to dataset files. This can be either a path to a
                        directory or a file where input examples are stored.
        """
        self.file_itr = None
        self.path = path
        if os.path.isdir(self.path):
            # Build the pattern with os.path.join so that "dir" and "dir/"
            # behave the same. Plain concatenation (path + "*") would glob
            # "dir*" and match the directory itself plus any sibling whose
            # name shares the prefix. glob.escape keeps directory names that
            # contain "[", "?" or "*" from being treated as patterns.
            pattern = os.path.join(glob.escape(self.path), "*")
            # Sorted so that iteration order does not depend on the
            # filesystem's directory listing order.
            self.file_itr = sorted(glob.glob(pattern))
        self.transform_filename_to_tensor = transform_filename_to_tensor

    def __iter__(self) -> Iterator[Tensor]:
        r"""
        Returns:
            iter (Iterator[Tensor]): For a directory, a lazy map that applies
                        `transform_filename_to_tensor` to each file path. For
                        a single file, an iterator over the object returned by
                        `transform_filename_to_tensor(path)` (for a tensor
                        this iterates along its first dimension).
        """
        if self.file_itr is not None:
            return map(self.transform_filename_to_tensor, self.file_itr)
        # `__iter__` must return an iterator. Returning the transform's result
        # directly (e.g. a Tensor or list) makes `iter(dataset)` raise
        # TypeError. iter() is a no-op when the result is already an iterator.
        return iter(self.transform_filename_to_tensor(self.path))
def dataset_to_dataloader(dataset: Dataset, batch_size: int = 64) -> DataLoader:
    r"""
    Wrap a torch Dataset in a torch DataLoader that batches its examples.

    Args:
        dataset (Dataset): The torch dataset to iterate over.
        batch_size (int, optional): Number of examples per batch.
                    Default: 64

    Returns:
        dataloader_iter (DataLoader): A DataLoader over `dataset` that uses
                    `batch_size` and otherwise the DataLoader defaults.
    """
    loader = DataLoader(dataset, batch_size=batch_size)
    return loader