|
|
""" |
|
|
Utility routines |
|
|
""" |
|
|
from collections.abc import Mapping, MutableMapping |
|
|
from copy import deepcopy |
|
|
import json |
|
|
import itertools |
|
|
import re |
|
|
import sys |
|
|
import traceback |
|
|
import warnings |
|
|
from typing import ( |
|
|
Callable, |
|
|
TypeVar, |
|
|
Any, |
|
|
Union, |
|
|
Dict, |
|
|
Optional, |
|
|
Tuple, |
|
|
Sequence, |
|
|
Type, |
|
|
cast, |
|
|
) |
|
|
from types import ModuleType |
|
|
|
|
|
import jsonschema |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from pandas.api.types import infer_dtype |
|
|
|
|
|
from altair.utils.schemapi import SchemaBase |
|
|
from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame |
|
|
|
|
|
if sys.version_info >= (3, 10): |
|
|
from typing import ParamSpec |
|
|
else: |
|
|
from typing_extensions import ParamSpec |
|
|
|
|
|
from typing import Literal, Protocol, TYPE_CHECKING |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn |
|
|
|
|
|
V = TypeVar("V") |
|
|
P = ParamSpec("P") |
|
|
|
|
|
|
|
|
class DataFrameLike(Protocol):
    """Structural type for any object implementing the DataFrame Interchange
    Protocol (i.e. exposing a ``__dataframe__`` method)."""

    def __dataframe__(self, *args, **kwargs) -> DfiDataFrame:
        ...
|
|
|
|
|
|
|
|
# Mapping from Vega-Lite encoding type names to their one-letter shorthand
# codes, as used in strings like "column:Q".
TYPECODE_MAP = {
    "ordinal": "O",
    "nominal": "N",
    "quantitative": "Q",
    "temporal": "T",
    "geojson": "G",
}

# Inverse mapping: one-letter shorthand code -> full encoding type name.
INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()}
|
|
|
|
|
|
|
|
|
|
|
# Aggregate operation names recognized by parse_shorthand,
# e.g. "mean(field)" or "sum(field):Q".
AGGREGATES = [
    "argmax",
    "argmin",
    "average",
    "count",
    "distinct",
    "max",
    "mean",
    "median",
    "min",
    "missing",
    "product",
    "q1",
    "q3",
    "ci0",
    "ci1",
    "stderr",
    "stdev",
    "stdevp",
    "sum",
    "valid",
    "values",
    "variance",
    "variancep",
]
|
|
|
|
|
|
|
|
# Window-only operation names; parse_shorthand combines these with AGGREGATES
# when called with parse_window_ops=True.
WINDOW_AGGREGATES = [
    "row_number",
    "rank",
    "dense_rank",
    "percent_rank",
    "cume_dist",
    "ntile",
    "lag",
    "lead",
    "first_value",
    "last_value",
    "nth_value",
]
|
|
|
|
|
|
|
|
# Time-unit names recognized by parse_shorthand, e.g. "month(field)".
# The first half are local time units; the second half are the same units
# prefixed with "utc".
TIMEUNITS = [
    "year",
    "quarter",
    "month",
    "week",
    "day",
    "dayofyear",
    "date",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "yearquarter",
    "yearquartermonth",
    "yearmonth",
    "yearmonthdate",
    "yearmonthdatehours",
    "yearmonthdatehoursminutes",
    "yearmonthdatehoursminutesseconds",
    "yearweek",
    "yearweekday",
    "yearweekdayhours",
    "yearweekdayhoursminutes",
    "yearweekdayhoursminutesseconds",
    "yeardayofyear",
    "quartermonth",
    "monthdate",
    "monthdatehours",
    "monthdatehoursminutes",
    "monthdatehoursminutesseconds",
    "weekday",
    # NOTE(review): "weeksdayhours" looks like a typo for "weekdayhours"
    # (compare "weekdayhoursminutes" below), but it may mirror the bundled
    # Vega-Lite schema — verify against the schema before changing.
    "weeksdayhours",
    "weekdayhoursminutes",
    "weekdayhoursminutesseconds",
    "dayhours",
    "dayhoursminutes",
    "dayhoursminutesseconds",
    "hoursminutes",
    "hoursminutesseconds",
    "minutesseconds",
    "secondsmilliseconds",
    "utcyear",
    "utcquarter",
    "utcmonth",
    "utcweek",
    "utcday",
    "utcdayofyear",
    "utcdate",
    "utchours",
    "utcminutes",
    "utcseconds",
    "utcmilliseconds",
    "utcyearquarter",
    "utcyearquartermonth",
    "utcyearmonth",
    "utcyearmonthdate",
    "utcyearmonthdatehours",
    "utcyearmonthdatehoursminutes",
    "utcyearmonthdatehoursminutesseconds",
    "utcyearweek",
    "utcyearweekday",
    "utcyearweekdayhours",
    "utcyearweekdayhoursminutes",
    "utcyearweekdayhoursminutesseconds",
    "utcyeardayofyear",
    "utcquartermonth",
    "utcmonthdate",
    "utcmonthdatehours",
    "utcmonthdatehoursminutes",
    "utcmonthdatehoursminutesseconds",
    "utcweekday",
    # NOTE(review): same suspected typo as "weeksdayhours" above.
    "utcweeksdayhours",
    "utcweekdayhoursminutes",
    "utcweekdayhoursminutesseconds",
    "utcdayhours",
    "utcdayhoursminutes",
    "utcdayhoursminutesseconds",
    "utchoursminutes",
    "utchoursminutesseconds",
    "utcminutesseconds",
    "utcsecondsmilliseconds",
]
|
|
|
|
|
|
|
|
# The subset of Vega-Lite encoding types that can be inferred from data.
InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]
|
|
|
|
|
|
|
|
def infer_vegalite_type(
    data: object,
) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
    """
    From an array-like input, infer the correct vega typecode
    ('ordinal', 'nominal', 'quantitative', or 'temporal')

    Parameters
    ----------
    data: object
        Array-like input accepted by ``pandas.api.types.infer_dtype``.

    Returns
    -------
    type : str or (str, list)
        The inferred Vega-Lite type name. For ordered categorical input,
        a ``("ordinal", categories)`` tuple is returned, where
        ``categories`` is the ordered category list.
    """
    typ = infer_dtype(data, skipna=False)

    # Sets (rather than lists) for O(1) membership tests; the labels mirror
    # the strings returned by pandas' infer_dtype.
    if typ in {
        "floating",
        "mixed-integer-float",
        "integer",
        "mixed-integer",
        "complex",
    }:
        return "quantitative"
    elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered:
        # Ordered categoricals carry an explicit category order, returned so
        # callers can use it as a sort order.
        return ("ordinal", data.cat.categories.tolist())
    elif typ in {"string", "bytes", "categorical", "boolean", "mixed", "unicode"}:
        return "nominal"
    elif typ in {
        "datetime",
        "datetime64",
        "timedelta",
        "timedelta64",
        "date",
        "time",
        "period",
    }:
        return "temporal"
    else:
        # Unknown dtype label: warn and fall back to the safest encoding type.
        warnings.warn(
            f"I don't know how to infer vegalite type from '{typ}'. "
            "Defaulting to nominal.",
            stacklevel=1,
        )
        return "nominal"
|
|
|
|
|
|
|
|
def merge_props_geom(feat: dict) -> dict:
    """Merge a GeoJSON feature's properties with its geometry.

    Returns the feature's ``properties`` dict with ``type`` and ``geometry``
    merged in (overwriting any existing entries of those names). If the
    feature has no usable ``properties``, the type/geometry pair alone is
    returned.
    """
    geom = {"type": feat["type"], "geometry": feat["geometry"]}
    try:
        merged = feat["properties"]
        merged.update(geom)
    except (AttributeError, KeyError):
        # "properties" missing (KeyError) or not a dict, e.g. None
        # (AttributeError): fall back to just type + geometry.
        merged = geom
    return merged
|
|
|
|
|
|
|
|
def sanitize_geo_interface(geo: MutableMapping) -> dict:
    """Santize a geo_interface to prepare it for serialization.

    * Make a copy
    * Convert type array or _Array to list
    * Convert tuples to lists (using json.loads/dumps)
    * Merge properties with geometry
    """
    # Work on a copy so the caller's object is never mutated.
    geo = deepcopy(geo)

    # Convert array-typed values (type names beginning with "_Array" or
    # "array") into plain lists.
    for key in geo.keys():
        if type(geo[key]).__name__.startswith(("_Array", "array")):
            geo[key] = geo[key].tolist()

    # Round-trip through JSON to turn tuples (and other JSON-compatible
    # containers) into plain lists/dicts.
    cleaned: dict = json.loads(json.dumps(geo))

    kind = cleaned["type"]
    if kind == "FeatureCollection":
        features = cleaned["features"]
        for idx, feature in enumerate(features):
            features[idx] = merge_props_geom(feature)
        cleaned = features
    elif kind == "Feature":
        cleaned = merge_props_geom(cleaned)
    else:
        # A bare geometry: wrap it in a Feature.
        cleaned = {"type": "Feature", "geometry": cleaned}

    return cleaned
|
|
|
|
|
|
|
|
def numpy_is_subtype(dtype: Any, subtype: Any) -> bool:
    """Return whether ``dtype`` is a numpy sub-dtype of ``subtype``.

    Inputs that numpy cannot interpret yield False instead of raising.
    """
    try:
        result = np.issubdtype(dtype, subtype)
    except (NotImplementedError, TypeError):
        # Not something numpy can compare (e.g. an extension dtype).
        return False
    return result
|
|
|
|
|
|
|
|
def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Sanitize a DataFrame to prepare it for serialization.

    * Make a copy
    * Convert RangeIndex columns to strings
    * Raise ValueError if column names are not strings
    * Raise ValueError if it has a hierarchical index.
    * Convert categoricals to strings.
    * Convert np.bool_ dtypes to Python bool objects
    * Convert np.int dtypes to Python int objects
    * Convert floats to objects and replace NaNs/infs with None.
    * Convert DateTime dtypes into appropriate string representations
    * Convert Nullable integers to objects and replace NaN with None
    * Convert Nullable boolean to objects and replace NaN with None
    * convert dedicated string column to objects and replace NaN with None
    * Raise a ValueError for TimeDelta dtypes
    """
    df = df.copy()

    if isinstance(df.columns, pd.RangeIndex):
        # Integer column labels from a RangeIndex become string names.
        df.columns = df.columns.astype(str)

    for col_name in df.columns:
        if not isinstance(col_name, str):
            raise ValueError(
                "Dataframe contains invalid column name: {0!r}. "
                "Column names must be strings".format(col_name)
            )

    if isinstance(df.index, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")
    if isinstance(df.columns, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")

    def to_list_if_array(val):
        # Nested numpy arrays inside object columns are not JSON
        # serializable; turn them into plain lists.
        if isinstance(val, np.ndarray):
            return val.tolist()
        else:
            return val

    # Dispatch per-column on the dtype's string name.
    for dtype_item in df.dtypes.items():
        # Column names are known to be strings from the check above.
        col_name = cast(str, dtype_item[0])
        dtype = dtype_item[1]
        dtype_name = str(dtype)
        if dtype_name == "category":
            # Categoricals become plain objects; missing entries become None.
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif dtype_name == "string":
            # Dedicated pandas string dtype: same object + None treatment.
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif dtype_name == "bool":
            # np.bool_ values -> Python bools via the object cast.
            df[col_name] = df[col_name].astype(object)
        elif dtype_name == "boolean":
            # Nullable boolean dtype: missing entries become None.
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif dtype_name.startswith("datetime") or dtype_name.startswith("timestamp"):
            # Render datetimes as ISO-8601 strings; NaT stringifies to "NaT",
            # which is then replaced with an empty string.
            df[col_name] = (
                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
            )
        elif dtype_name.startswith("timedelta"):
            raise ValueError(
                'Field "{col_name}" has type "{dtype}" which is '
                "not supported by Altair. Please convert to "
                "either a timestamp or a numerical value."
                "".format(col_name=col_name, dtype=dtype)
            )
        elif dtype_name.startswith("geometry"):
            # Geometry columns are left untouched here; presumably handled
            # elsewhere via __geo_interface__ — see sanitize_geo_interface.
            continue
        elif (
            dtype_name
            in {
                "Int8",
                "Int16",
                "Int32",
                "Int64",
                "UInt8",
                "UInt16",
                "UInt32",
                "UInt64",
                "Float32",
                "Float64",
            }
        ):
            # Nullable numeric extension dtypes: missing entries become None.
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif numpy_is_subtype(dtype, np.integer):
            # np.int* values -> Python ints via the object cast.
            df[col_name] = df[col_name].astype(object)
        elif numpy_is_subtype(dtype, np.floating):
            # Floats: NaN and +/-inf are not valid JSON; replace with None.
            col = df[col_name]
            bad_values = col.isnull() | np.isinf(col)
            df[col_name] = col.astype(object).where(~bad_values, None)
        elif dtype == object:
            # Object columns: unbox nested arrays, then None out missing.
            col = df[col_name].astype(object).apply(to_list_if_array)
            df[col_name] = col.where(col.notnull(), None)
    return df
|
|
|
|
|
|
|
|
def sanitize_arrow_table(pa_table):
    """Sanitize arrow table for JSON serialization"""
    import pyarrow as pa
    import pyarrow.compute as pc

    schema = pa_table.schema
    sanitized_arrays = []
    for name in schema.names:
        column = pa_table[name]
        dtype = schema.field(name).type
        dtype_str = str(dtype)
        if dtype_str.startswith("timestamp"):
            # Timestamps are rendered as formatted strings.
            sanitized_arrays.append(pc.strftime(column))
        elif dtype_str.startswith("duration"):
            raise ValueError(
                'Field "{col_name}" has type "{dtype}" which is '
                "not supported by Altair. Please convert to "
                "either a timestamp or a numerical value."
                "".format(col_name=name, dtype=dtype)
            )
        else:
            sanitized_arrays.append(column)

    return pa.Table.from_arrays(sanitized_arrays, names=schema.names)
|
|
|
|
|
|
|
|
def parse_shorthand(
    shorthand: Union[Dict[str, Any], str],
    data: Optional[Union[pd.DataFrame, DataFrameLike]] = None,
    parse_aggregates: bool = True,
    parse_window_ops: bool = False,
    parse_timeunits: bool = True,
    parse_types: bool = True,
) -> Dict[str, Any]:
    """General tool to parse shorthand values

    These are of the form:

    - "col_name"
    - "col_name:O"
    - "average(col_name)"
    - "average(col_name):O"

    Optionally, a dataframe may be supplied, from which the type
    will be inferred if not specified in the shorthand.

    Parameters
    ----------
    shorthand : dict or string
        The shorthand representation to be parsed
    data : DataFrame, optional
        If specified and of type DataFrame, then use these values to infer the
        column type if not provided by the shorthand.
    parse_aggregates : boolean
        If True (default), then parse aggregate functions within the shorthand.
    parse_window_ops : boolean
        If True then parse window operations within the shorthand (default:False)
    parse_timeunits : boolean
        If True (default), then parse timeUnits from within the shorthand
    parse_types : boolean
        If True (default), then parse typecodes within the shorthand

    Returns
    -------
    attrs : dict
        a dictionary of attributes extracted from the shorthand

    Examples
    --------
    >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
    ...                      'bar': [1, 2, 3, 4]})

    >>> parse_shorthand('name') == {'field': 'name'}
    True

    >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
    True

    >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
    True

    >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
    True

    >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
    True

    >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
    True

    >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
    True

    >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
    True

    >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
    True

    >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
    True
    """
    from altair.utils._importers import pyarrow_available

    if not shorthand:
        return {}

    valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

    # Named regex fragments for each component of the shorthand grammar.
    units = {
        "field": "(?P<field>.*)",
        "type": "(?P<type>{})".format("|".join(valid_typecodes)),
        "agg_count": "(?P<aggregate>count)",
        "op_count": "(?P<op>count)",
        "aggregate": "(?P<aggregate>{})".format("|".join(AGGREGATES)),
        "window_op": "(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
        "timeUnit": "(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
    }

    patterns = []

    if parse_aggregates:
        patterns.extend([r"{agg_count}\(\)"])
        patterns.extend([r"{aggregate}\({field}\)"])
    if parse_window_ops:
        patterns.extend([r"{op_count}\(\)"])
        patterns.extend([r"{window_op}\({field}\)"])
    if parse_timeunits:
        patterns.extend([r"{timeUnit}\({field}\)"])

    # The bare field pattern goes last so more specific forms match first.
    patterns.extend([r"{field}"])

    if parse_types:
        # For each pattern, try a variant with a ":<typecode>" suffix first,
        # then the untyped form.
        patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

    # Anchor patterns to the whole string; DOTALL allows newlines in field names.
    regexps = (
        re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
    )

    # A dict shorthand is taken as already-parsed attributes; a string is
    # matched against the first pattern that fits.
    if isinstance(shorthand, dict):
        attrs = shorthand
    else:
        attrs = next(
            exp.match(shorthand).groupdict()
            for exp in regexps
            if exp.match(shorthand) is not None
        )

    # Expand one-letter type codes (e.g. "Q") to their full names.
    if "type" in attrs:
        attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])

    # "count()" has no field; its type is always quantitative.
    if attrs == {"aggregate": "count"}:
        attrs["type"] = "quantitative"

    # A timeUnit implies a temporal type unless a type was given explicitly.
    if "timeUnit" in attrs and "type" not in attrs:
        attrs["type"] = "temporal"

    # No explicit type: try to infer it from the supplied data.
    if "type" not in attrs:
        if pyarrow_available() and data is not None and hasattr(data, "__dataframe__"):
            # Prefer the DataFrame Interchange Protocol when pyarrow is present.
            dfi = data.__dataframe__()
            if "field" in attrs:
                unescaped_field = attrs["field"].replace("\\", "")
                if unescaped_field in dfi.column_names():
                    column = dfi.get_column_by_name(unescaped_field)
                    try:
                        attrs["type"] = infer_vegalite_type_for_dfi_column(column)
                    except (NotImplementedError, AttributeError, ValueError):
                        # Interchange-based inference failed; fall back to
                        # pandas-based inference when the data is a pandas
                        # DataFrame, otherwise re-raise.
                        if isinstance(data, pd.DataFrame):
                            attrs["type"] = infer_vegalite_type(data[unescaped_field])
                        else:
                            raise

                    # Ordered categoricals come back as (type, categories);
                    # expose the categories as an explicit sort order.
                    if isinstance(attrs["type"], tuple):
                        attrs["sort"] = attrs["type"][1]
                        attrs["type"] = attrs["type"][0]
        elif isinstance(data, pd.DataFrame):
            # Pandas-only fallback (pyarrow unavailable). Backslash escapes
            # are stripped so escaped column names resolve correctly.
            if "field" in attrs and attrs["field"].replace("\\", "") in data.columns:
                attrs["type"] = infer_vegalite_type(
                    data[attrs["field"].replace("\\", "")]
                )

                # Same tuple convention as above for ordered categoricals.
                if isinstance(attrs["type"], tuple):
                    attrs["sort"] = attrs["type"][1]
                    attrs["type"] = attrs["type"][0]

    # An unescaped colon remaining in the field is most likely an invalid
    # type code (or an unescaped column name containing ":"): raise a
    # detailed error explaining both possibilities.
    if (
        "field" in attrs
        and ":" in attrs["field"]
        and attrs["field"][attrs["field"].rfind(":") - 1] != "\\"
    ):
        raise ValueError(
            '"{}" '.format(attrs["field"].split(":")[-1])
            + "is not one of the valid encoding data types: {}.".format(
                ", ".join(TYPECODE_MAP.values())
            )
            + "\nFor more details, see https://altair-viz.github.io/user_guide/encodings/index.html#encoding-data-types. "
            + "If you are trying to use a column name that contains a colon, "
            + 'prefix it with a backslash; for example "column\\:name" instead of "column:name".'
        )
    return attrs
|
|
|
|
|
|
|
|
def infer_vegalite_type_for_dfi_column(
    column: Union[Column, "PandasColumn"],
) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]:
    """Map a DataFrame Interchange Protocol column dtype to a Vega-Lite type.

    Ordered categoricals are returned as ``("ordinal", categories)``; all
    other recognized dtype kinds map to a plain type string.
    """
    from pyarrow.interchange.from_dataframe import column_to_array

    try:
        kind = column.dtype[0]
    except NotImplementedError as e:
        # Some interchange implementations raise for datetime-like dtypes
        # they cannot describe; if the message mentions a datetime/timestamp
        # dtype, treat the column as temporal.
        message = e.args[0]
        if "datetime64" in message or "timestamp" in message:
            return "temporal"
        raise e

    is_ordered_categorical = (
        kind == DtypeKind.CATEGORICAL
        and column.describe_categorical["is_ordered"]
        and column.describe_categorical["categories"] is not None
    )
    if is_ordered_categorical:
        # Materialize the category values so they can serve as a sort order.
        categories = column_to_array(column.describe_categorical["categories"])
        return "ordinal", categories.to_pylist()

    if kind in (DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL):
        return "nominal"
    if kind in (DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT):
        return "quantitative"
    if kind == DtypeKind.DATETIME:
        return "temporal"
    raise ValueError(f"Unexpected DtypeKind: {kind}")
|
|
|
|
|
|
|
|
def use_signature(Obj: Callable[P, Any]):
    """Apply call signature and documentation of Obj to the decorated method"""

    def decorate(f: Callable[..., V]) -> Callable[P, V]:
        # Point introspection tools at Obj's constructor signature, and keep
        # a reference to Obj itself for downstream consumers.
        f.__wrapped__ = Obj.__init__
        f._uses_signature = Obj

        if Obj.__doc__:
            # Rewrite the first docstring line to reference Obj, keeping the
            # remainder of Obj's documentation intact.
            lines = Obj.__doc__.splitlines()
            lines[0] = f"Refer to :class:`{Obj.__name__}`"

            if f.__doc__:
                merged = f.__doc__ + "\n".join(lines[1:])
            else:
                merged = "\n".join(lines)
            try:
                f.__doc__ = merged
            except AttributeError:
                # __doc__ is not writable on some callables; skip silently.
                pass

        return f

    return decorate
|
|
|
|
|
|
|
|
def update_nested(
    original: MutableMapping, update: Mapping, copy: bool = False
) -> MutableMapping:
    """Recursively merge the ``update`` mapping into ``original``.

    Parameters
    ----------
    original : MutableMapping
        the original (nested) dictionary, which will be updated in-place
    update : Mapping
        the nested dictionary of updates
    copy : bool, default False
        if True, then copy the original dictionary rather than modifying it

    Returns
    -------
    original : MutableMapping
        a reference to the (modified) original dict

    Examples
    --------
    >>> original = {'x': {'b': 2, 'c': 4}}
    >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
    >>> update_nested(original, update)  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    >>> original  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    """
    if copy:
        original = deepcopy(original)

    for key, new_val in update.items():
        if isinstance(new_val, Mapping):
            existing = original.get(key, {})
            if isinstance(existing, MutableMapping):
                # Both sides are mappings: descend and merge in place.
                original[key] = update_nested(existing, new_val)
                continue
        # Non-mapping update value, or the original value cannot be merged
        # into: plain overwrite.
        original[key] = new_val

    return original
|
|
|
|
|
|
|
|
def display_traceback(in_ipython: bool = True):
    """Print the current exception's traceback.

    Uses IPython's traceback display when requested and an IPython shell is
    active; otherwise falls back to the standard traceback module.
    """
    exc_info = sys.exc_info()

    ip = None
    if in_ipython:
        # Imported lazily so plain-Python callers never require IPython.
        from IPython.core.getipython import get_ipython

        ip = get_ipython()

    if ip is None:
        traceback.print_exception(*exc_info)
    else:
        ip.showtraceback(exc_info)
|
|
|
|
|
|
|
|
def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : Sequence
        Sequence of function args
    kwargs : MutableMapping
        Dict of function kwargs
    channels : ModuleType
        The module containing all altair encoding channel classes.

    Returns
    -------
    kwargs : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    # Build a mapping of channel class -> encoding name, restricted to the
    # SchemaBase subclasses exported by the channels module.
    channel_objs = (getattr(channels, name) for name in dir(channels))
    channel_objs = (
        c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
    )
    channel_to_name: Dict[Type[SchemaBase], str] = {
        c: c._encoding_name for c in channel_objs
    }
    # Invert into encoding name -> {"datum"|"value"|"field" -> channel class},
    # keyed by the class-name suffix convention (XDatum, XValue, plain X).
    name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {}
    for chan, name in channel_to_name.items():
        chans = name_to_channel.setdefault(name, {})
        if chan.__name__.endswith("Datum"):
            key = "datum"
        elif chan.__name__.endswith("Value"):
            key = "value"
        else:
            key = "field"
        chans[key] = chan

    # Positional args must be channel instances (or non-empty lists/tuples of
    # them); fold each into kwargs under its inferred encoding name.
    for arg in args:
        if isinstance(arg, (list, tuple)) and len(arg) > 0:
            # For sequences, infer the encoding from the first element.
            type_ = type(arg[0])
        else:
            type_ = type(arg)

        encoding = channel_to_name.get(type_, None)
        if encoding is None:
            raise NotImplementedError("positional of type {}" "".format(type_))
        if encoding in kwargs:
            raise ValueError("encoding {} specified twice.".format(encoding))
        kwargs[encoding] = arg

    def _wrap_in_channel_class(obj, encoding):
        # Already a schema/channel object: pass through unchanged.
        if isinstance(obj, SchemaBase):
            return obj

        # Bare strings are treated as field shorthand.
        if isinstance(obj, str):
            obj = {"shorthand": obj}

        # Wrap each element of a sequence individually.
        if isinstance(obj, (list, tuple)):
            return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

        if encoding not in name_to_channel:
            warnings.warn(
                "Unrecognized encoding channel '{}'".format(encoding), stacklevel=1
            )
            return obj

        # Choose the Value class when a "value" key is present, else the
        # field class.
        classes = name_to_channel[encoding]
        cls = classes["value"] if "value" in obj else classes["field"]

        try:
            # Validation is deferred (validate=False) so invalid specs
            # surface later with fuller context.
            return cls.from_dict(obj, validate=False)
        except jsonschema.ValidationError:
            # Could not be interpreted as a channel object; leave as-is.
            return obj

    return {
        encoding: _wrap_in_channel_class(obj, encoding)
        for encoding, obj in kwargs.items()
    }
|
|
|