Fill-Mask
Transformers
code
File size: 5,412 Bytes
8193465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Type

from .. import config
from ..utils import logging
from .formatting import (
    ArrowFormatter,
    CustomFormatter,
    Formatter,
    PandasFormatter,
    PythonFormatter,
    TableFormatter,
    TensorFormatter,
    format_table,
    query_table,
)
from .np_formatter import NumpyFormatter


logger = logging.get_logger(__name__)

_FORMAT_TYPES: dict[Optional[str], type[Formatter]] = {}
_FORMAT_TYPES_ALIASES: dict[Optional[str], str] = {}
_FORMAT_TYPES_ALIASES_UNAVAILABLE: dict[Optional[str], Exception] = {}


def _register_formatter(
    formatter_cls: type,
    format_type: Optional[str],
    aliases: Optional[list[str]] = None,
):
    """
    Register a Formatter object using a name and optional aliases.
    This function must be used on a Formatter class.
    """
    aliases = aliases if aliases is not None else []
    if format_type in _FORMAT_TYPES:
        logger.warning(
            f"Overwriting format type '{format_type}' ({_FORMAT_TYPES[format_type].__name__} -> {formatter_cls.__name__})"
        )
    _FORMAT_TYPES[format_type] = formatter_cls
    for alias in set(aliases + [format_type]):
        if alias in _FORMAT_TYPES_ALIASES:
            logger.warning(
                f"Overwriting format type alias '{alias}' ({_FORMAT_TYPES_ALIASES[alias]} -> {format_type})"
            )
        _FORMAT_TYPES_ALIASES[alias] = format_type


def _register_unavailable_formatter(
    unavailable_error: Exception, format_type: Optional[str], aliases: Optional[list[str]] = None
):
    """
    Register an unavailable Formatter object using a name and optional aliases.
    This function must be used on an Exception object that is raised when trying to get the unavailable formatter.
    """
    aliases = aliases if aliases is not None else []
    for alias in set(aliases + [format_type]):
        _FORMAT_TYPES_ALIASES_UNAVAILABLE[alias] = unavailable_error


# Here we define all the available formatting functions that can be used by `Dataset.set_format`
_register_formatter(PythonFormatter, None, aliases=["python"])
_register_formatter(ArrowFormatter, "arrow", aliases=["pa", "pyarrow"])
_register_formatter(NumpyFormatter, "numpy", aliases=["np"])
_register_formatter(PandasFormatter, "pandas", aliases=["pd"])
_register_formatter(CustomFormatter, "custom")

if config.POLARS_AVAILABLE:
    from .polars_formatter import PolarsFormatter

    _register_formatter(PolarsFormatter, "polars", aliases=["pl"])
else:
    _polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    _register_unavailable_formatter(_polars_error, "polars", aliases=["pl"])

if config.TORCH_AVAILABLE:
    from .torch_formatter import TorchFormatter

    _register_formatter(TorchFormatter, "torch", aliases=["pt", "pytorch"])
else:
    _torch_error = ValueError("PyTorch needs to be installed to be able to return PyTorch tensors.")
    _register_unavailable_formatter(_torch_error, "torch", aliases=["pt", "pytorch"])

if config.TF_AVAILABLE:
    from .tf_formatter import TFFormatter

    _register_formatter(TFFormatter, "tensorflow", aliases=["tf"])
else:
    _tf_error = ValueError("Tensorflow needs to be installed to be able to return Tensorflow tensors.")
    _register_unavailable_formatter(_tf_error, "tensorflow", aliases=["tf"])

if config.JAX_AVAILABLE:
    from .jax_formatter import JaxFormatter

    _register_formatter(JaxFormatter, "jax", aliases=[])
else:
    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
    _register_unavailable_formatter(_jax_error, "jax", aliases=[])


def get_format_type_from_alias(format_type: Optional[str]) -> Optional[str]:
    """If the given format type is a known alias, then return its main type name. Otherwise return the type with no change."""
    if format_type in _FORMAT_TYPES_ALIASES:
        return _FORMAT_TYPES_ALIASES[format_type]
    else:
        return format_type


def get_formatter(format_type: Optional[str], **format_kwargs) -> Formatter:
    """
    Factory function to get a Formatter given its type name and keyword arguments.
    A formatter is an object that extracts and formats data from pyarrow table.
    It defines the formatting for rows, columns and batches.
    If the formatter for a given type name doesn't exist or is not available, an error is raised.
    """
    format_type = get_format_type_from_alias(format_type)
    if format_type in _FORMAT_TYPES:
        return _FORMAT_TYPES[format_type](**format_kwargs)
    if format_type in _FORMAT_TYPES_ALIASES_UNAVAILABLE:
        raise _FORMAT_TYPES_ALIASES_UNAVAILABLE[format_type]
    else:
        raise ValueError(f"Format type should be one of {list(_FORMAT_TYPES.keys())}, but got '{format_type}'")