| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Dict, List, Optional, Type |
|
|
| from .. import config |
| from ..utils import logging |
| from .formatting import ( |
| ArrowFormatter, |
| CustomFormatter, |
| Formatter, |
| PandasFormatter, |
| PythonFormatter, |
| TableFormatter, |
| TensorFormatter, |
| format_table, |
| query_table, |
| ) |
| from .np_formatter import NumpyFormatter |
|
|
|
|
# Module-level logger, used below to warn when a registration overwrites an existing entry.
logger = logging.get_logger(__name__)


# Main format type name -> class registered for it (filled by `_register_formatter`).
# `None` is a valid key: it is the main name under which `PythonFormatter` is registered.
_FORMAT_TYPES: dict[Optional[str], type[Formatter]] = {}
# Every accepted name (main names and aliases) -> main format type name.
# Values can be `None` because the main name of the Python format is `None`.
_FORMAT_TYPES_ALIASES: dict[Optional[str], Optional[str]] = {}
# Names and aliases of formats that cannot be used -> exception raised by
# `get_formatter` when one of them is requested.
_FORMAT_TYPES_ALIASES_UNAVAILABLE: dict[Optional[str], Exception] = {}
|
|
|
|
def _register_formatter(
    formatter_cls: type,
    format_type: Optional[str],
    aliases: Optional[list[str]] = None,
):
    """
    Store `formatter_cls` in the module registries under `format_type` and its aliases.

    Args:
        formatter_cls: class to register; it must be a Formatter class.
        format_type: main name of the format (`None` is accepted as a name).
        aliases: optional additional names that resolve to `format_type`.

    Registering a name or an alias that already exists replaces the previous
    entry and logs a warning.
    """
    if aliases is None:
        aliases = []

    # Warn before replacing, while the previous class is still reachable.
    if format_type in _FORMAT_TYPES:
        overwrite_message = f"Overwriting format type '{format_type}' ({_FORMAT_TYPES[format_type].__name__} -> {formatter_cls.__name__})"
        logger.warning(overwrite_message)
    _FORMAT_TYPES[format_type] = formatter_cls

    # The main name is also recorded as an alias of itself; the set removes duplicates.
    all_names = set(aliases + [format_type])
    for alias in all_names:
        if alias in _FORMAT_TYPES_ALIASES:
            alias_message = f"Overwriting format type alias '{alias}' ({_FORMAT_TYPES_ALIASES[alias]} -> {format_type})"
            logger.warning(alias_message)
        _FORMAT_TYPES_ALIASES[alias] = format_type
|
|
|
|
def _register_unavailable_formatter(
    unavailable_error: Exception, format_type: Optional[str], aliases: Optional[list[str]] = None
):
    """
    Mark a format type and its aliases as unavailable.

    Args:
        unavailable_error: exception to raise when the unavailable formatter is requested.
        format_type: main name of the unavailable format.
        aliases: optional additional names of the unavailable format.
    """
    if aliases is None:
        aliases = []
    # Main name and aliases, deduplicated, all map to the same exception object.
    unavailable_names = set(aliases + [format_type])
    _FORMAT_TYPES_ALIASES_UNAVAILABLE.update(dict.fromkeys(unavailable_names, unavailable_error))
|
|
|
|
| |
# Formatters registered unconditionally.
# `None` is the main name of the Python format; "python" is its alias.
_register_formatter(formatter_cls=PythonFormatter, format_type=None, aliases=["python"])
_register_formatter(formatter_cls=ArrowFormatter, format_type="arrow", aliases=["pa", "pyarrow"])
_register_formatter(formatter_cls=NumpyFormatter, format_type="numpy", aliases=["np"])
_register_formatter(formatter_cls=PandasFormatter, format_type="pandas", aliases=["pd"])
_register_formatter(formatter_cls=CustomFormatter, format_type="custom")


# Formatters gated on a `config.*_AVAILABLE` flag: the formatter module is only
# imported when the flag is set; otherwise the names are registered together
# with the error to raise when the format is requested.
if config.POLARS_AVAILABLE:
    from .polars_formatter import PolarsFormatter

    _register_formatter(formatter_cls=PolarsFormatter, format_type="polars", aliases=["pl"])
else:
    _polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    _register_unavailable_formatter(unavailable_error=_polars_error, format_type="polars", aliases=["pl"])


if config.TORCH_AVAILABLE:
    from .torch_formatter import TorchFormatter

    _register_formatter(formatter_cls=TorchFormatter, format_type="torch", aliases=["pt", "pytorch"])
else:
    _torch_error = ValueError("PyTorch needs to be installed to be able to return PyTorch tensors.")
    _register_unavailable_formatter(unavailable_error=_torch_error, format_type="torch", aliases=["pt", "pytorch"])


if config.TF_AVAILABLE:
    from .tf_formatter import TFFormatter

    _register_formatter(formatter_cls=TFFormatter, format_type="tensorflow", aliases=["tf"])
else:
    _tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.")
    _register_unavailable_formatter(unavailable_error=_tf_error, format_type="tensorflow", aliases=["tf"])


if config.JAX_AVAILABLE:
    from .jax_formatter import JaxFormatter

    _register_formatter(formatter_cls=JaxFormatter, format_type="jax", aliases=[])
else:
    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
    _register_unavailable_formatter(unavailable_error=_jax_error, format_type="jax", aliases=[])
|
|
|
|
def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]:
    """Resolve a known alias to its main format type name; any other value is returned unchanged."""
    # Falling back to the input itself covers names that were never registered.
    return _FORMAT_TYPES_ALIASES.get(format_type, format_type)
|
|
|
|
def get_formatter(format_type: Optional[str], **format_kwargs) -> Formatter:
    """
    Factory function to get a Formatter given its type name and keyword arguments.
    A formatter is an object that extracts and formats data from pyarrow table.
    It defines the formatting for rows, columns and batches.

    Args:
        format_type: main name or alias of the format (`None` is a registered name).
        **format_kwargs: keyword arguments forwarded to the formatter class.

    Returns:
        A new instance of the class registered for `format_type`.

    Raises:
        Exception: the error registered through `_register_unavailable_formatter`
            when the format is known but not available.
        ValueError: when the format type is not registered at all.
    """
    format_type = get_format_type_from_alias(format_type)
    if format_type in _FORMAT_TYPES:
        return _FORMAT_TYPES[format_type](**format_kwargs)
    if format_type in _FORMAT_TYPES_ALIASES_UNAVAILABLE:
        # The registered error is a single object shared by every call. Raising an
        # instance keeps its existing `__traceback__` and extends it, so without a
        # reset the chain would grow on each call and keep old frames (and their
        # locals) alive. Clear it so each raise starts from a fresh traceback.
        raise _FORMAT_TYPES_ALIASES_UNAVAILABLE[format_type].with_traceback(None)
    raise ValueError(f"Format type should be one of {list(_FORMAT_TYPES.keys())}, but got '{format_type}'")
|
|