|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
from functools import partial |
|
|
from typing import TYPE_CHECKING, Optional |
|
|
|
|
|
import pyarrow as pa |
|
|
|
|
|
from .. import config |
|
|
from ..features import Features |
|
|
from ..features.features import decode_nested_example |
|
|
from ..utils.py_utils import no_op_if_value_is_null |
|
|
from .formatting import BaseArrowExtractor, TableFormatter |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
import polars as pl |
|
|
|
|
|
|
|
|
class PolarsArrowExtractor(BaseArrowExtractor["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
    """Convert PyArrow tables into Polars objects.

    Rows and batches are returned as ``pl.DataFrame`` and columns as ``pl.Series``.
    Every method raises ``ValueError`` when ``config.POLARS_AVAILABLE`` is false.
    """

    @staticmethod
    def _get_polars():
        """Return the ``polars`` module, importing it only if it is not loaded yet.

        Raises:
            ValueError: If ``config.POLARS_AVAILABLE`` is false.
        """
        if not config.POLARS_AVAILABLE:
            raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
        # Reuse the module object already registered in sys.modules when there is one.
        if "polars" in sys.modules:
            return sys.modules["polars"]
        import polars

        return polars

    def extract_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Return the leading slice of length 1 of ``pa_table`` as a Polars DataFrame."""
        polars = self._get_polars()
        return polars.from_arrow(pa_table.slice(length=1))

    def extract_column(self, pa_table: pa.Table) -> "pl.Series":
        """Return the first column of ``pa_table`` as a Polars Series."""
        polars = self._get_polars()
        # Convert only the first column, then pick it out of the resulting frame by name.
        first_column_frame = polars.from_arrow(pa_table.select([0]))
        return first_column_frame[pa_table.column_names[0]]

    def extract_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Return the whole ``pa_table`` as a Polars DataFrame."""
        polars = self._get_polars()
        return polars.from_arrow(pa_table)
|
|
|
|
|
|
|
|
class PolarsFeaturesDecoder:
    """Apply feature decoding to Polars rows, columns and batches.

    Only columns flagged in ``features._column_requires_decoding`` are touched;
    when ``features`` is falsy every method returns its input unchanged.
    """

    def __init__(self, features: Optional[Features]):
        """Store ``features`` and import polars.

        Args:
            features: Feature definitions used to decode values, or ``None``.
        """
        self.features = features
        # The module is not used here: the import only loads polars (raising
        # ImportError if it is missing) at the time the decoder is created.
        import polars as pl  # noqa: F401

    def decode_row(self, row: "pl.DataFrame") -> "pl.DataFrame":
        """Return ``row`` with every column that requires decoding decoded.

        Args:
            row: DataFrame whose columns may need decoding.

        Returns:
            A DataFrame in which each such column present in ``row`` has been
            replaced by its decoded values; ``row`` itself when nothing needs
            decoding.
        """
        if not self.features:
            return row
        requires_decoding = self.features._column_requires_decoding
        # Decode column by column with Series.map_elements, exactly as
        # decode_column does. DataFrame.map_rows expects a single callable over
        # row tuples, so it cannot be handed a per-column mapping of decoders.
        decoded_columns = [
            row[column_name]
            .map_elements(no_op_if_value_is_null(partial(decode_nested_example, feature)))
            .alias(column_name)
            for column_name, feature in self.features.items()
            # Features may describe columns that are not part of this frame; skip those.
            if requires_decoding[column_name] and column_name in row.columns
        ]
        if decoded_columns:
            # with_columns replaces existing columns that share a name with the new Series.
            row = row.with_columns(decoded_columns)
        return row

    def decode_column(self, column: "pl.Series", column_name: str) -> "pl.Series":
        """Return ``column`` decoded with the feature registered under ``column_name``.

        The column is returned unchanged when there are no features, when
        ``column_name`` is not one of them, or when it does not require decoding.
        """
        needs_decoding = (
            self.features
            and column_name in self.features
            and self.features._column_requires_decoding[column_name]
        )
        if not needs_decoding:
            return column
        decode = no_op_if_value_is_null(partial(decode_nested_example, self.features[column_name]))
        return column.map_elements(decode)

    def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
        """Decode a batch; a batch is handled exactly like a row."""
        return self.decode_row(batch)
|
|
|
|
|
|
|
|
class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
    """Format PyArrow tables as Polars objects.

    Each ``format_*`` method first extracts a Polars object from the table with
    a fresh ``polars_arrow_extractor`` instance, then passes it through
    ``polars_features_decoder``.
    """

    table_type = "polars dataframe"
    column_type = "polars series"

    def __init__(self, features=None, **np_array_kwargs):
        """Set up the extractor class and the features decoder.

        Args:
            features: Forwarded to the parent initializer and to the decoder.
            **np_array_kwargs: Stored on the instance as ``np_array_kwargs``.
        """
        super().__init__(features=features)
        self.np_array_kwargs = np_array_kwargs
        # The class itself is stored; a new extractor is instantiated on each format call.
        self.polars_arrow_extractor = PolarsArrowExtractor
        self.polars_features_decoder = PolarsFeaturesDecoder(features)
        # Unused here: the import only loads polars (or raises ImportError) at construction.
        import polars as pl  # noqa: F401

    def format_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract a row from ``pa_table`` and decode it."""
        extracted = self.polars_arrow_extractor().extract_row(pa_table)
        return self.polars_features_decoder.decode_row(extracted)

    def format_column(self, pa_table: pa.Table) -> "pl.Series":
        """Extract the first column of ``pa_table`` and decode it under its column name."""
        extracted = self.polars_arrow_extractor().extract_column(pa_table)
        first_name = pa_table.column_names[0]
        return self.polars_features_decoder.decode_column(extracted, first_name)

    def format_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract ``pa_table`` as a batch and decode it."""
        extracted = self.polars_arrow_extractor().extract_batch(pa_table)
        return self.polars_features_decoder.decode_batch(extracted)
|
|
|