| from typing import Optional | |
| import pyspark | |
| from .. import Features, NamedSplit | |
| from ..download import DownloadMode | |
| from ..packaged_modules.spark.spark import Spark | |
| from .abc import AbstractDatasetReader | |
class SparkDatasetReader(AbstractDatasetReader):
    """A dataset reader that reads from a Spark DataFrame.

    The DataFrame is wrapped in a ``Spark`` builder. ``read`` then either
    returns the builder's streaming dataset (``streaming=True``, the default)
    or prepares the data through the builder and returns the resulting
    dataset.

    When caching, cache materialization is parallelized over Spark; an NFS
    that is accessible to the driver must be provided.
    """

    def __init__(
        self,
        df: pyspark.sql.DataFrame,
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        streaming: bool = True,
        cache_dir: Optional[str] = None,
        keep_in_memory: bool = False,
        working_dir: Optional[str] = None,
        load_from_cache_file: bool = True,
        file_format: str = "arrow",
        **kwargs,
    ):
        """Create the reader and the underlying ``Spark`` builder.

        Args:
            df: Spark DataFrame to read from.
            split: Split forwarded to the base reader and later used by
                ``read``.
            features: Features forwarded to both the base reader and the
                ``Spark`` builder.
            streaming: If true (the default), ``read`` returns the builder's
                streaming dataset instead of preparing the data first.
            cache_dir: Cache directory forwarded to both the base reader and
                the ``Spark`` builder.
            keep_in_memory: Forwarded to the base reader.
            working_dir: Forwarded to the ``Spark`` builder. Presumably the
                shared (NFS) location mentioned in the class docstring;
                confirm against the builder.
            load_from_cache_file: If false, the non-streaming ``read`` path
                prepares the data with ``DownloadMode.FORCE_REDOWNLOAD``.
            file_format: File format passed to the builder's
                ``download_and_prepare`` on the non-streaming path.
            **kwargs: Extra keyword arguments, forwarded unchanged to BOTH the
                base reader and the ``Spark`` builder.
        """
        super().__init__(
            split=split,
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            **kwargs,
        )
        self._load_from_cache_file = load_from_cache_file
        self._file_format = file_format
        self.builder = Spark(
            df=df,
            features=features,
            cache_dir=cache_dir,
            working_dir=working_dir,
            **kwargs,
        )

    def read(self):
        """Return the dataset backed by the wrapped Spark DataFrame.

        Returns:
            The result of ``builder.as_streaming_dataset`` when streaming is
            enabled; otherwise the result of ``builder.as_dataset`` after the
            builder has prepared the data.
        """
        if self.streaming:
            # Streaming path: nothing is prepared up front.
            return self.builder.as_streaming_dataset(split=self.split)

        # ``None`` leaves the choice of download mode to the builder; a forced
        # re-download is requested only when the caller opted out of the cache.
        download_mode = None if self._load_from_cache_file else DownloadMode.FORCE_REDOWNLOAD
        self.builder.download_and_prepare(
            download_mode=download_mode,
            file_format=self._file_format,
        )
        return self.builder.as_dataset(split=self.split)