| import io |
| import itertools |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import pandas as pd |
| import pyarrow as pa |
| import pyarrow.json as paj |
|
|
| import datasets |
| import datasets.config |
| from datasets.table import table_cast |
| from datasets.utils.file_utils import readline |
|
|
|
|
| logger = datasets.utils.logging.get_logger(__name__) |
|
|
|
|
def ujson_dumps(*args, **kwargs):
    """Serialize an object to a JSON string using pandas' bundled ujson.

    Dispatches to ``pd.io.json.ujson_dumps`` (pandas >= 2.0) and falls back
    to the older ``pd.io.json.dumps`` name on pandas < 2.0.
    """
    try:
        dumps = pd.io.json.ujson_dumps
    except AttributeError:
        # pandas < 2.0 exposed the ujson serializer under the plain name
        dumps = pd.io.json.dumps
    return dumps(*args, **kwargs)
|
|
|
def ujson_loads(*args, **kwargs):
    """Deserialize a JSON string using pandas' bundled ujson.

    Dispatches to ``pd.io.json.ujson_loads`` (pandas >= 2.0) and falls back
    to the older ``pd.io.json.loads`` name on pandas < 2.0.
    """
    try:
        loads = pd.io.json.ujson_loads
    except AttributeError:
        # pandas < 2.0 exposed the ujson deserializer under the plain name
        loads = pd.io.json.loads
    return loads(*args, **kwargs)
|
|
|
def pandas_read_json(path_or_buf, **kwargs):
    """Read JSON with ``pd.read_json``, requesting pyarrow-backed dtypes on pandas >= 2.

    Forwards ``path_or_buf`` and all keyword arguments to ``pd.read_json``;
    on pandas 2.x the ``dtype_backend`` is forced to ``"pyarrow"``.
    """
    if datasets.config.PANDAS_VERSION.major >= 2:
        # pandas 2.x supports Arrow-backed columns; force them regardless of caller input
        kwargs = {**kwargs, "dtype_backend": "pyarrow"}
    return pd.read_json(path_or_buf, **kwargs)
|
|
|
|
@dataclass
class JsonConfig(datasets.BuilderConfig):
    """BuilderConfig for JSON."""

    # Optional schema; when set, generated tables are cast to it (see `Json._cast_table`).
    features: Optional[datasets.Features] = None
    # Text encoding used when opening the data files.
    encoding: str = "utf-8"
    # Error handling scheme passed to `open` (e.g. "strict", "ignore"); None means "strict".
    encoding_errors: Optional[str] = None
    # If set, read the records from this top-level key of a single JSON document.
    field: Optional[str] = None
    # Deprecated: no longer has any effect (a warning is emitted in `Json._info`).
    use_threads: bool = True
    # Deprecated: alias for `chunksize`; when set it overrides `chunksize` with a warning.
    block_size: Optional[int] = None
    # Number of bytes read per batch when streaming JSON Lines files (default 10 MiB).
    chunksize: int = 10 << 20
    # Removed option: setting it raises a ValueError in `Json._info`.
    newlines_in_values: Optional[bool] = None

    def __post_init__(self):
        super().__post_init__()
|
|
|
|
class Json(datasets.ArrowBasedBuilder):
    """Arrow-based builder that loads JSON and JSON Lines files into Arrow tables."""

    BUILDER_CONFIG_CLASS = JsonConfig

    def _info(self):
        """Validate deprecated/removed config options and return the dataset info.

        Raises:
            ValueError: if the removed `newlines_in_values` option is set.
        """
        if self.config.block_size is not None:
            logger.warning("The JSON loader parameter `block_size` is deprecated. Please use `chunksize` instead")
            # `block_size` is the legacy name: let it override `chunksize`
            self.config.chunksize = self.config.block_size
        if self.config.use_threads is not True:
            logger.warning(
                "The JSON loader parameter `use_threads` is deprecated and doesn't have any effect anymore."
            )
        if self.config.newlines_in_values is not None:
            raise ValueError("The JSON loader parameter `newlines_in_values` is no longer supported")
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in datafiles"""
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        dl_manager.download_config.extract_on_the_fly = True
        data_files = dl_manager.download_and_extract(self.config.data_files)
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                # normalize a single path into a one-element list
                files = [files]
            files = [dl_manager.iter_files(file) for file in files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        """Cast `pa_table` to the user-provided features schema, if one was given."""
        if self.config.features is not None:
            # add missing columns (filled with nulls) so the cast below can succeed
            for column_name in set(self.config.features) - set(pa_table.column_names):
                type = self.config.features.arrow_schema.field(column_name).type
                pa_table = pa_table.append_column(column_name, pa.array([None] * len(pa_table), type=type))
            # a struct column that the user requested as Value("string") is
            # re-serialized back to one JSON string per row
            for i, column_name in enumerate(pa_table.column_names):
                if pa.types.is_struct(pa_table[column_name].type) and self.config.features.get(
                    column_name, None
                ) == datasets.Value("string"):
                    # dump the column as JSON Lines: one record per line
                    jsonl = (
                        pa_table[column_name]
                        .to_pandas(types_mapper=pd.ArrowDtype)
                        .to_json(orient="records", lines=True)
                    )
                    # split the JSON Lines text back into per-row strings; the
                    # split on "\n{" eats each record's leading brace, so it is
                    # re-added, and rstrip drops the trailing newline
                    string_array = pa.array(
                        ("{" + x.rstrip() for x in ("\n" + jsonl).split("\n{") if x), type=pa.string()
                    )
                    pa_table = pa_table.set_column(i, column_name, string_array)

            # finally cast everything to the requested schema
            pa_table = table_cast(pa_table, self.config.features.arrow_schema)
        return pa_table

    def _generate_tables(self, files):
        """Yield (key, pyarrow.Table) pairs read from JSON / JSON Lines files."""
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            # If `field` is set, the file is one JSON document and the records
            # live under that top-level key: parse it fully, then go through pandas.
            if self.config.field is not None:
                with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
                    dataset = ujson_loads(f.read())
                # keep only the field we are interested in
                dataset = dataset[self.config.field]
                df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
                if df.columns.tolist() == [0]:
                    # a bare list of scalars yields a single unnamed column: name it
                    df.columns = list(self.config.features) if self.config.features else ["text"]
                pa_table = pa.Table.from_pandas(df, preserve_index=False)
                yield file_idx, self._cast_table(pa_table)

            # Otherwise stream the file in chunks and parse each chunk with
            # pyarrow's JSON reader (fast path for JSON Lines files).
            else:
                with open(file, "rb") as f:
                    batch_idx = 0
                    # pyarrow parse block size; doubled below whenever a single
                    # JSON object straddles a block boundary
                    block_size = max(self.config.chunksize // 32, 16 << 10)
                    encoding_errors = (
                        self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
                    )
                    while True:
                        batch = f.read(self.config.chunksize)
                        if not batch:
                            break
                        # extend the chunk to the end of the current line so it
                        # ends on a record boundary
                        try:
                            batch += f.readline()
                        except (AttributeError, io.UnsupportedOperation):
                            # file-like objects without readline (e.g. streamed files)
                            batch += readline(f)
                        # pyarrow's JSON reader only accepts utf-8: transcode first
                        if self.config.encoding != "utf-8":
                            batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
                        try:
                            while True:
                                try:
                                    pa_table = paj.read_json(
                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
                                    )
                                    break
                                except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
                                    if (
                                        isinstance(e, pa.ArrowInvalid)
                                        and "straddling" not in str(e)
                                        or block_size > len(batch)
                                    ):
                                        # genuine parse error, or block size already
                                        # exceeds the batch: give up on the fast path
                                        raise
                                    else:
                                        # a record straddled the block boundary:
                                        # retry with a doubled block size
                                        logger.debug(
                                            f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
                                        )
                                        block_size *= 2
                        except pa.ArrowInvalid as e:
                            # not valid JSON Lines: fall back to pandas reading the
                            # whole file as a single JSON document
                            try:
                                with open(
                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
                                ) as f:
                                    df = pandas_read_json(f)
                            except ValueError:
                                # pandas failed too: surface the original pyarrow error
                                logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
                                raise e
                            if df.columns.tolist() == [0]:
                                # a bare list of scalars yields a single unnamed column: name it
                                df.columns = list(self.config.features) if self.config.features else ["text"]
                            try:
                                pa_table = pa.Table.from_pandas(df, preserve_index=False)
                            except pa.ArrowInvalid as e:
                                logger.error(
                                    f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
                                )
                                raise ValueError(
                                    f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
                                ) from None
                            # the pandas fallback consumed the whole file: emit one
                            # table and stop chunking this file
                            yield file_idx, self._cast_table(pa_table)
                            break
                        yield (file_idx, batch_idx), self._cast_table(pa_table)
                        batch_idx += 1
|
|