| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| | import json |
| | import dataclasses |
| | import numpy as np |
| | from dataclasses import Field, MISSING |
| | from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple |
| |
|
# Type variable tying `load_dataclass`'s return type to its `cls` argument.
_X = TypeVar("_X")
| |
|
| |
|
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
    """
    Load a @dataclass, or a collection hierarchy containing dataclasses, from JSON.

    Call it like ``load_dataclass(f, typing.List[MyDataclass])`` or
    ``load_dataclass(f, MyDataclass)``. Nested dataclasses, NamedTuples,
    lists, tuples, dicts and Optional fields are converted recursively.
    JSON keys that do not name a dataclass field are ignored; fields absent
    from the JSON take their default value, or None.

    Args:
        f: A file object opened for reading, in text mode, or in binary mode
            when `binary` is True.
        cls: The type to load, e.g. a dataclass or ``List[SomeDataclass]``.
        binary: Set to True if `f` was opened in binary mode; its content is
            then decoded as UTF-8.

    Returns:
        The loaded object. When the JSON document is an array and `cls` has
        type arguments (e.g. ``List[T]``), a list of converted ``T`` elements.
    """
    if binary:
        asdict = json.loads(f.read().decode("utf8"))
    else:
        asdict = json.load(f)

    type_args = get_args(cls)
    if isinstance(asdict, list) and type_args:
        # A JSON array loaded as a collection: convert all elements in one
        # batched ("vectorised") call using the element type.
        return list(_dataclass_list_from_dict_list(asdict, type_args[0]))
    # A single value (a JSON object, or an array loaded into a non-generic type
    # such as a NamedTuple): convert it as a batch of one.
    return _dataclass_list_from_dict_list([asdict], cls)[0]
| |
|
| |
|
| | def _resolve_optional(type_: Any) -> Tuple[bool, Any]: |
| | """Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" |
| | if get_origin(type_) is Union: |
| | args = get_args(type_) |
| | if len(args) == 2 and args[1] == type(None): |
| | return True, args[0] |
| | if type_ is Any: |
| | return True, Any |
| |
|
| | return False, type_ |
| |
|
| |
|
| | def _unwrap_type(tp): |
| | |
| | if get_origin(tp) is Union: |
| | args = get_args(tp) |
| | if len(args) == 2 and any(a is type(None) for a in args): |
| | |
| | return args[0] if args[1] is type(None) else args[1] |
| | return tp |
| |
|
| |
|
| | def _get_dataclass_field_default(field: Field) -> Any: |
| | if field.default_factory is not MISSING: |
| | |
| | |
| | return field.default_factory() |
| | elif field.default is not MISSING: |
| | return field.default |
| | else: |
| | return None |
| |
|
| |
|
def _dataclass_list_from_dict_list(dlist, typeannot):
    """
    Batched ("vectorised") conversion of JSON-decoded objects to `typeannot`.

    Every element of `dlist` (dicts, lists, scalars or None) is converted to
    `typeannot`, recursing into dataclass fields, NamedTuple / tuple / list
    elements and dict keys / values. Values sharing a static type are
    gathered so that each nested type is converted with one recursive call.

    Args:
        dlist: sequence of objects to convert.
        typeannot: type of each of those objects.

    Returns:
        A list of converted objects of the same length as `dlist`, or `dlist`
        itself when no conversion applies (`Any`, plain scalar types, or
        annotations that are not classes, such as non-Optional unions).

    Notes:
        None entries stay None wherever they occur. Dict keys that do not
        name a dataclass field are ignored. A field missing from an object
        takes the field's default (a fresh `default_factory()` result per
        object) or None; defaults are used as-is, not re-converted.
    """
    if typeannot is Any:
        return dlist
    if all(obj is None for obj in dlist):
        return dlist
    if any(obj is None for obj in dlist):
        # Convert only the non-None entries, then scatter them back in place.
        idx = [i for i, obj in enumerate(dlist) if obj is not None]
        converted = _dataclass_list_from_dict_list([dlist[i] for i in idx], typeannot)
        res = [None] * len(dlist)
        for i, obj in zip(idx, converted):
            res[i] = obj
        return res

    is_optional, contained_type = _resolve_optional(typeannot)
    if is_optional:
        return _dataclass_list_from_dict_list(dlist, contained_type)

    cls = get_origin(typeannot) or typeannot
    if not isinstance(cls, type):
        # E.g. a non-Optional Union, Literal or TypeVar: nothing to construct,
        # and `issubclass` below would raise TypeError.
        return dlist

    if issubclass(cls, tuple) and hasattr(cls, "_fields"):
        return _namedtuples_from_lists(dlist, cls)
    if issubclass(cls, (list, tuple)):
        return _sequences_from_lists(dlist, cls, get_args(typeannot))
    if issubclass(cls, dict):
        return _dicts_from_dicts(dlist, cls, get_args(typeannot))
    if dataclasses.is_dataclass(cls):
        return _dataclasses_from_dicts(dlist, cls)
    return dlist


def _split_by_lengths(items, lengths):
    """Split the flat sequence `items` into consecutive lists of the given `lengths`."""
    items = list(items)
    chunks = []
    start = 0
    for n in lengths:
        chunks.append(items[start:start + n])
        start += n
    return chunks


def _namedtuples_from_lists(dlist, cls):
    """Build NamedTuple `cls` instances from positional rows, converting each column."""
    # Plain `collections.namedtuple` classes have no annotations: treat as Any.
    hints = getattr(cls, "__annotations__", None) or {}
    field_types = [hints.get(name, Any) for name in cls._fields]
    columns = [
        _dataclass_list_from_dict_list(list(column), tp)
        for column, tp in zip(zip(*dlist), field_types)
    ]
    # Without columns `zip(*columns)` would be empty; keep one row per object.
    rows = zip(*columns) if columns else [()] * len(dlist)
    return [cls(*row) for row in rows]


def _sequences_from_lists(dlist, cls, type_args):
    """Convert lists to `cls` (a list type or plain tuple type) given its type arguments."""
    if issubclass(cls, tuple) and len(type_args) > 1 and type_args[-1] is not Ellipsis:
        # Fixed-length Tuple[A, B, ...]: convert each position with its own type.
        columns = [
            _dataclass_list_from_dict_list(list(column), tp)
            for column, tp in zip(zip(*dlist), type_args)
        ]
        return list(zip(*columns)) if columns else [()] * len(dlist)

    # Homogeneous List[T], Tuple[T, ...], Tuple[T] or unparametrised list/tuple:
    # convert every element in one batch, then regroup by the original lengths,
    # so sequences of differing lengths are preserved intact.
    elem_type = type_args[0] if type_args else Any
    lengths = [len(seq) for seq in dlist]
    flat = [item for seq in dlist for item in seq]
    chunks = _split_by_lengths(_dataclass_list_from_dict_list(flat, elem_type), lengths)
    if issubclass(cls, tuple):
        return [tuple(chunk) for chunk in chunks]
    return [cls(chunk) for chunk in chunks]


def _dicts_from_dicts(dlist, cls, type_args):
    """Convert dicts to `cls`, batching the conversion of all keys and of all values."""
    # Bare `dict` / `Dict` annotations carry no type arguments.
    key_t, val_t = type_args if len(type_args) == 2 else (Any, Any)
    lengths = [len(obj) for obj in dlist]
    keys = _dataclass_list_from_dict_list([k for obj in dlist for k in obj], key_t)
    vals = _dataclass_list_from_dict_list(
        [v for obj in dlist for v in obj.values()], val_t
    )
    return [
        cls(zip(obj_keys, obj_vals))
        for obj_keys, obj_vals in zip(
            _split_by_lengths(keys, lengths), _split_by_lengths(vals, lengths)
        )
    ]


def _dataclasses_from_dicts(dlist, cls):
    """Construct dataclass `cls` instances from dicts, converting each field column in one batch."""
    names = []
    columns = []
    for fld in dataclasses.fields(cls):
        if not fld.init:
            # Fields excluded from __init__ cannot be passed to the constructor.
            continue
        present = [obj[fld.name] for obj in dlist if fld.name in obj]
        converted = iter(_dataclass_list_from_dict_list(present, _unwrap_type(fld.type)))
        column = []
        for obj in dlist:
            if fld.name in obj:
                column.append(next(converted))
            else:
                # Evaluated per object so `default_factory` results are not shared.
                column.append(_get_dataclass_field_default(fld))
        names.append(fld.name)
        columns.append(column)
    # A dataclass without init fields still yields one instance per object.
    rows = zip(*columns) if columns else [()] * len(dlist)
    return [cls(**dict(zip(names, row))) for row in rows]
| |
|