File size: 5,412 Bytes
8193465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Type
from .. import config
from ..utils import logging
from .formatting import (
ArrowFormatter,
CustomFormatter,
Formatter,
PandasFormatter,
PythonFormatter,
TableFormatter,
TensorFormatter,
format_table,
query_table,
)
from .np_formatter import NumpyFormatter
logger = logging.get_logger(__name__)
# Main format type name -> Formatter class; filled by `_register_formatter`.
_FORMAT_TYPES: dict[Optional[str], type[Formatter]] = {}
# Alias (each main name is also stored as its own alias) -> main format type name.
_FORMAT_TYPES_ALIASES: dict[Optional[str], str] = {}
# Name or alias of a formatter that could not be registered -> exception raised by `get_formatter`.
_FORMAT_TYPES_ALIASES_UNAVAILABLE: dict[Optional[str], Exception] = {}
def _register_formatter(
    formatter_cls: type,
    format_type: Optional[str],
    aliases: Optional[list[str]] = None,
):
    """
    Bind a Formatter class to a format type name and to any extra alias names.

    The format type is always recorded as an alias of itself, so it does not need to be
    listed in `aliases`. Registering a format type or an alias that already exists
    replaces the previous entry and logs a warning.
    """
    if aliases is None:
        aliases = []

    if format_type in _FORMAT_TYPES:
        replaced_name = _FORMAT_TYPES[format_type].__name__
        logger.warning(f"Overwriting format type '{format_type}' ({replaced_name} -> {formatter_cls.__name__})")
    _FORMAT_TYPES[format_type] = formatter_cls

    # Deduplicate in case the caller also listed the main name among the aliases.
    all_names = set(aliases + [format_type])
    for name in all_names:
        if name in _FORMAT_TYPES_ALIASES:
            previous_target = _FORMAT_TYPES_ALIASES[name]
            logger.warning(f"Overwriting format type alias '{name}' ({previous_target} -> {format_type})")
        _FORMAT_TYPES_ALIASES[name] = format_type
def _register_unavailable_formatter(
    unavailable_error: Exception, format_type: Optional[str], aliases: Optional[list[str]] = None
):
    """
    Record that a formatter cannot be used, under its name and optional aliases.

    `unavailable_error` is the exception stored for every one of those names; it is the
    error raised when the unavailable formatter is later requested.
    """
    if aliases is None:
        aliases = []
    # Every name (main type included) maps to the very same exception object.
    all_names = set(aliases + [format_type])
    _FORMAT_TYPES_ALIASES_UNAVAILABLE.update(dict.fromkeys(all_names, unavailable_error))
# Here we register all the formatters that can be selected by name (e.g. by `Dataset.set_format`).
# The formatters below are registered unconditionally; `None` is the main name of the Python
# formatter, with "python" as its alias.
_register_formatter(PythonFormatter, None, aliases=["python"])
_register_formatter(ArrowFormatter, "arrow", aliases=["pa", "pyarrow"])
_register_formatter(NumpyFormatter, "numpy", aliases=["np"])
_register_formatter(PandasFormatter, "pandas", aliases=["pd"])
_register_formatter(CustomFormatter, "custom")

# The formatters below are registered only when the matching `config.*_AVAILABLE` flag is true.
# The formatter module is imported inside the branch, so it is not imported when the flag is false.
# Otherwise an error is stored under the same names; `get_formatter` raises it on request.
if config.POLARS_AVAILABLE:
    from .polars_formatter import PolarsFormatter

    _register_formatter(PolarsFormatter, "polars", aliases=["pl"])
else:
    _polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    _register_unavailable_formatter(_polars_error, "polars", aliases=["pl"])

if config.TORCH_AVAILABLE:
    from .torch_formatter import TorchFormatter

    _register_formatter(TorchFormatter, "torch", aliases=["pt", "pytorch"])
else:
    _torch_error = ValueError("PyTorch needs to be installed to be able to return PyTorch tensors.")
    _register_unavailable_formatter(_torch_error, "torch", aliases=["pt", "pytorch"])

if config.TF_AVAILABLE:
    from .tf_formatter import TFFormatter

    _register_formatter(TFFormatter, "tensorflow", aliases=["tf"])
else:
    _tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.")
    _register_unavailable_formatter(_tf_error, "tensorflow", aliases=["tf"])

if config.JAX_AVAILABLE:
    from .jax_formatter import JaxFormatter

    _register_formatter(JaxFormatter, "jax", aliases=[])
else:
    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
    _register_unavailable_formatter(_jax_error, "jax", aliases=[])
def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]:
    """Resolve a registered alias to its main format type name; any other value is returned unchanged."""
    return _FORMAT_TYPES_ALIASES.get(format_type, format_type)
def get_formatter(format_type: Optional[str], **format_kwargs) -> Formatter:
    """
    Factory function to get a Formatter given its type name and keyword arguments.

    A formatter is an object that extracts and formats data from pyarrow table.
    It defines the formatting for rows, columns and batches.

    Args:
        format_type: Main format type name or one of its registered aliases.
        **format_kwargs: Keyword arguments forwarded to the Formatter class constructor.

    Returns:
        A new instance of the Formatter class registered for `format_type`.

    Raises:
        Exception: The error registered through `_register_unavailable_formatter` when the
            formatter is known but not available.
        ValueError: If `format_type` is neither a registered nor a known-unavailable format type.
    """
    format_type = get_format_type_from_alias(format_type)
    if format_type in _FORMAT_TYPES:
        return _FORMAT_TYPES[format_type](**format_kwargs)
    if format_type in _FORMAT_TYPES_ALIASES_UNAVAILABLE:
        # The stored exception is one shared instance. Raising it again would append
        # frames to the traceback left by earlier raises, so the traceback would keep
        # growing and keep old frames alive. Reset it before each raise.
        raise _FORMAT_TYPES_ALIASES_UNAVAILABLE[format_type].with_traceback(None)
    raise ValueError(f"Format type should be one of {list(_FORMAT_TYPES.keys())}, but got '{format_type}'")
|