|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Optional, Type |
|
|
|
|
|
from .. import config |
|
|
from ..utils import logging |
|
|
from .formatting import ( |
|
|
ArrowFormatter, |
|
|
CustomFormatter, |
|
|
Formatter, |
|
|
PandasFormatter, |
|
|
PythonFormatter, |
|
|
TableFormatter, |
|
|
TensorFormatter, |
|
|
format_table, |
|
|
query_table, |
|
|
) |
|
|
from .np_formatter import NumpyFormatter |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
_FORMAT_TYPES: dict[Optional[str], type[Formatter]] = {} |
|
|
_FORMAT_TYPES_ALIASES: dict[Optional[str], str] = {} |
|
|
_FORMAT_TYPES_ALIASES_UNAVAILABLE: dict[Optional[str], Exception] = {} |
|
|
|
|
|
|
|
|
def _register_formatter( |
|
|
formatter_cls: type, |
|
|
format_type: Optional[str], |
|
|
aliases: Optional[list[str]] = None, |
|
|
): |
|
|
""" |
|
|
Register a Formatter object using a name and optional aliases. |
|
|
This function must be used on a Formatter class. |
|
|
""" |
|
|
aliases = aliases if aliases is not None else [] |
|
|
if format_type in _FORMAT_TYPES: |
|
|
logger.warning( |
|
|
f"Overwriting format type '{format_type}' ({_FORMAT_TYPES[format_type].__name__} -> {formatter_cls.__name__})" |
|
|
) |
|
|
_FORMAT_TYPES[format_type] = formatter_cls |
|
|
for alias in set(aliases + [format_type]): |
|
|
if alias in _FORMAT_TYPES_ALIASES: |
|
|
logger.warning( |
|
|
f"Overwriting format type alias '{alias}' ({_FORMAT_TYPES_ALIASES[alias]} -> {format_type})" |
|
|
) |
|
|
_FORMAT_TYPES_ALIASES[alias] = format_type |
|
|
|
|
|
|
|
|
def _register_unavailable_formatter( |
|
|
unavailable_error: Exception, format_type: Optional[str], aliases: Optional[list[str]] = None |
|
|
): |
|
|
""" |
|
|
Register an unavailable Formatter object using a name and optional aliases. |
|
|
This function must be used on an Exception object that is raised when trying to get the unavailable formatter. |
|
|
""" |
|
|
aliases = aliases if aliases is not None else [] |
|
|
for alias in set(aliases + [format_type]): |
|
|
_FORMAT_TYPES_ALIASES_UNAVAILABLE[alias] = unavailable_error |
|
|
|
|
|
|
|
|
|
|
|
_register_formatter(PythonFormatter, None, aliases=["python"]) |
|
|
_register_formatter(ArrowFormatter, "arrow", aliases=["pa", "pyarrow"]) |
|
|
_register_formatter(NumpyFormatter, "numpy", aliases=["np"]) |
|
|
_register_formatter(PandasFormatter, "pandas", aliases=["pd"]) |
|
|
_register_formatter(CustomFormatter, "custom") |
|
|
|
|
|
if config.POLARS_AVAILABLE: |
|
|
from .polars_formatter import PolarsFormatter |
|
|
|
|
|
_register_formatter(PolarsFormatter, "polars", aliases=["pl"]) |
|
|
else: |
|
|
_polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.") |
|
|
_register_unavailable_formatter(_polars_error, "polars", aliases=["pl"]) |
|
|
|
|
|
if config.TORCH_AVAILABLE: |
|
|
from .torch_formatter import TorchFormatter |
|
|
|
|
|
_register_formatter(TorchFormatter, "torch", aliases=["pt", "pytorch"]) |
|
|
else: |
|
|
_torch_error = ValueError("PyTorch needs to be installed to be able to return PyTorch tensors.") |
|
|
_register_unavailable_formatter(_torch_error, "torch", aliases=["pt", "pytorch"]) |
|
|
|
|
|
if config.TF_AVAILABLE: |
|
|
from .tf_formatter import TFFormatter |
|
|
|
|
|
_register_formatter(TFFormatter, "tensorflow", aliases=["tf"]) |
|
|
else: |
|
|
_tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.") |
|
|
_register_unavailable_formatter(_tf_error, "tensorflow", aliases=["tf"]) |
|
|
|
|
|
if config.JAX_AVAILABLE: |
|
|
from .jax_formatter import JaxFormatter |
|
|
|
|
|
_register_formatter(JaxFormatter, "jax", aliases=[]) |
|
|
else: |
|
|
_jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.") |
|
|
_register_unavailable_formatter(_jax_error, "jax", aliases=[]) |
|
|
|
|
|
|
|
|
def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]: |
|
|
"""If the given format type is a known alias, then return its main type name. Otherwise return the type with no change.""" |
|
|
if format_type in _FORMAT_TYPES_ALIASES: |
|
|
return _FORMAT_TYPES_ALIASES[format_type] |
|
|
else: |
|
|
return format_type |
|
|
|
|
|
|
|
|
def get_formatter(format_type: Optional[str], **format_kwargs) -> Formatter: |
|
|
""" |
|
|
Factory function to get a Formatter given its type name and keyword arguments. |
|
|
A formatter is an object that extracts and formats data from pyarrow table. |
|
|
It defines the formatting for rows, columns and batches. |
|
|
If the formatter for a given type name doesn't exist or is not available, an error is raised. |
|
|
""" |
|
|
format_type = get_format_type_from_alias(format_type) |
|
|
if format_type in _FORMAT_TYPES: |
|
|
return _FORMAT_TYPES[format_type](**format_kwargs) |
|
|
if format_type in _FORMAT_TYPES_ALIASES_UNAVAILABLE: |
|
|
raise _FORMAT_TYPES_ALIASES_UNAVAILABLE[format_type] |
|
|
else: |
|
|
raise ValueError(f"Format type should be one of {list(_FORMAT_TYPES.keys())}, but got '{format_type}'") |
|
|
|