| |
|
|
| import typing |
| import warnings |
| import sys |
| from copy import deepcopy |
|
|
| from dataclasses import MISSING, is_dataclass, fields as dc_fields |
| from datetime import datetime |
| from decimal import Decimal |
| from uuid import UUID |
| from enum import Enum |
|
|
| from typing_inspect import is_union_type |
|
|
| from marshmallow import fields, Schema, post_load |
| from marshmallow.exceptions import ValidationError |
|
|
| from dataclasses_json.core import (_is_supported_generic, _decode_dataclass, |
| _ExtendedEncoder, _user_overrides_or_exts) |
| from dataclasses_json.utils import (_is_collection, _is_optional, |
| _issubclass_safe, _timestamp_to_dt_aware, |
| _is_new_type, _get_type_origin, |
| _handle_undefined_parameters_safe, |
| CatchAllVar) |
|
|
|
|
class _TimestampField(fields.Field):
    """Field that dumps a datetime-like value through its ``timestamp()`` method.

    ``None`` is passed through unchanged unless the field is required, in
    which case the field's "required" error is raised.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value.timestamp()``; handle ``None`` per ``self.required``."""
        if value is not None:
            return value.timestamp()
        if self.required:
            raise ValidationError(self.default_error_messages["required"])
        return None

    def _deserialize(self, value, attr, data, **kwargs):
        """Convert ``value`` with ``_timestamp_to_dt_aware``; handle ``None``
        per ``self.required``."""
        if value is not None:
            return _timestamp_to_dt_aware(value)
        if self.required:
            raise ValidationError(self.default_error_messages["required"])
        return None
|
|
|
|
class _IsoField(fields.Field):
    """Field that dumps a datetime via ``isoformat()`` and loads it back with
    ``datetime.fromisoformat``.

    ``None`` is passed through unchanged unless the field is required, in
    which case the field's "required" error is raised.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value.isoformat()``; handle ``None`` per ``self.required``."""
        if value is not None:
            return value.isoformat()
        if self.required:
            raise ValidationError(self.default_error_messages["required"])
        return None

    def _deserialize(self, value, attr, data, **kwargs):
        """Parse ``value`` with ``datetime.fromisoformat``; handle ``None``
        per ``self.required``."""
        if value is not None:
            return datetime.fromisoformat(value)
        if self.required:
            raise ValidationError(self.default_error_messages["required"])
        return None
|
|
|
|
class _UnionField(fields.Field):
    """Field for a ``typing.Union`` that dispatches to a per-member sub-field.

    :param desc: mapping of each candidate type of the union to the field
        object used to (de)serialize values of that type.
    :param cls: the dataclass owning the field; only its ``__name__`` is
        used, to build warning messages.
    :param field: the dataclass field being described; only its ``name`` is
        used, to build warning messages.
    """

    def __init__(self, desc, cls, field, *args, **kwargs):
        self.desc = desc
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialize ``value`` with the sub-field of the first matching type.

        Dataclass values get an extra ``'__type'`` key holding the matched
        type's name so that ``_deserialize`` can select the same member.
        A warning is emitted when no member of the union matches.
        """
        if self.allow_none and value is None:
            return None
        for type_, schema_ in self.desc.items():
            if _issubclass_safe(type(value), type_):
                if is_dataclass(value):
                    res = schema_._serialize(value, attr, obj, **kwargs)
                    res['__type'] = str(type_.__name__)
                    return res
                # The type matched but the value is not a dataclass: leave
                # the loop without warning and use the base serialization.
                break
            elif isinstance(value, _get_type_origin(type_)):
                return schema_._serialize(value, attr, obj, **kwargs)
        else:
            warnings.warn(
                f'The type "{type(value).__name__}" (value: "{value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value cannot be serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value`` with the sub-field selected by ``'__type'``
        (for dicts) or by ``isinstance`` against the member types.

        Falls back to the base field's deserialization, with a warning, when
        the member cannot be determined.
        """
        # Work on a copy so that removing '__type' never mutates the input.
        tmp_value = deepcopy(value)
        if isinstance(tmp_value, dict) and '__type' in tmp_value:
            dc_name = tmp_value['__type']
            for type_, schema_ in self.desc.items():
                if is_dataclass(type_) and type_.__name__ == dc_name:
                    del tmp_value['__type']
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
        elif isinstance(tmp_value, dict):
            # The message pieces previously lacked the closing quote after
            # the value and the separating spaces between sentences.
            warnings.warn(
                f'Attempting to deserialize "dict" (value: "{tmp_value}") '
                f'that does not have a "__type" type specifier field into '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Deserialization may fail, or deserialization to wrong type may occur.'
            )
            return super()._deserialize(tmp_value, attr, data, **kwargs)
        else:
            for type_, schema_ in self.desc.items():
                if isinstance(tmp_value, _get_type_origin(type_)):
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
            else:
                warnings.warn(
                    f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
                    f'is not in the list of possible types of typing.Union '
                    f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                    f'Value cannot be deserialized properly.')
        return super()._deserialize(tmp_value, attr, data, **kwargs)
|
|
|
|
class _TupleVarLen(fields.List):
    """
    List field whose deserialized result is returned as a tuple
    (variable-length homogeneous tuples).
    """

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize like ``fields.List`` and convert a non-``None`` result
        to a tuple."""
        items = super()._deserialize(value, attr, data, **kwargs)
        if items is None:
            return None
        return tuple(items)
|
|
|
|
# Lookup table from a Python / typing type (or generic origin) to the
# marshmallow field class used for it.
TYPES = {
    # typing generics
    typing.Mapping: fields.Mapping,
    typing.MutableMapping: fields.Mapping,
    typing.List: fields.List,
    typing.Dict: fields.Dict,
    typing.Tuple: fields.Tuple,
    typing.Callable: fields.Function,
    typing.Any: fields.Raw,
    # builtin containers
    dict: fields.Dict,
    list: fields.List,
    tuple: fields.Tuple,
    # scalars
    str: fields.Str,
    int: fields.Int,
    float: fields.Float,
    bool: fields.Bool,
    datetime: _TimestampField,
    UUID: fields.UUID,
    Decimal: fields.Decimal,
    # catch-all marker type
    CatchAllVar: fields.Dict,
}
|
|
# Type variable used by the generic schema helpers in this module.
A = typing.TypeVar('A')

# Raw JSON input types.
JsonData = typing.Union[str, bytes, bytearray]

# One encoded object: a mapping of string keys to arbitrary values.
TEncoded = typing.Dict[str, typing.Any]

# Either a single item or a list of items, decoded and encoded respectively.
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
|
|
if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
    class SchemaF(Schema, typing.Generic[A]):
        """Generic, typing-only view of ``Schema`` parameterised by ``A``."""

        def __init__(self, *args, **kwargs):
            """
            Always raises ``NotImplementedError`` after delegating to
            ``Schema.__init__``: this class is a helper only and is not
            meant to be inherited or instantiated.
            """
            super().__init__(*args, **kwargs)
            raise NotImplementedError()

        # The methods below are signature stubs only; their bodies do nothing.

        @typing.overload
        def dump(
                self,
                obj: typing.List[A],
                many: typing.Optional[bool] = None,
        ) -> typing.List[TEncoded]:
            ...

        @typing.overload
        def dump(
                self,
                obj: A,
                many: typing.Optional[bool] = None,
        ) -> TEncoded:
            ...

        def dump(
                self,
                obj: TOneOrMulti,
                many: typing.Optional[bool] = None,
        ) -> TOneOrMultiEncoded:
            ...

        @typing.overload
        def dumps(
                self,
                obj: typing.List[A],
                many: typing.Optional[bool] = None,
                *args,
                **kwargs
        ) -> str:
            ...

        @typing.overload
        def dumps(
                self,
                obj: A,
                many: typing.Optional[bool] = None,
                *args,
                **kwargs
        ) -> str:
            ...

        def dumps(
                self,
                obj: TOneOrMulti,
                many: typing.Optional[bool] = None,
                *args,
                **kwargs
        ) -> str:
            ...

        @typing.overload
        def load(
                self,
                data: typing.List[TEncoded],
                many: bool = True,
                partial: typing.Optional[bool] = None,
                unknown: typing.Optional[str] = None,
        ) -> typing.List[A]:
            ...

        @typing.overload
        def load(
                self,
                data: TEncoded,
                many: None = None,
                partial: typing.Optional[bool] = None,
                unknown: typing.Optional[str] = None,
        ) -> A:
            ...

        def load(
                self,
                data: TOneOrMultiEncoded,
                many: typing.Optional[bool] = None,
                partial: typing.Optional[bool] = None,
                unknown: typing.Optional[str] = None,
        ) -> TOneOrMulti:
            ...

        @typing.overload
        def loads(
                self,
                json_data: JsonData,
                many: typing.Optional[bool] = True,
                partial: typing.Optional[bool] = None,
                unknown: typing.Optional[str] = None,
                **kwargs
        ) -> typing.List[A]:
            ...

        def loads(
                self,
                json_data: JsonData,
                many: typing.Optional[bool] = None,
                partial: typing.Optional[bool] = None,
                unknown: typing.Optional[str] = None,
                **kwargs
        ) -> TOneOrMulti:
            ...

    SchemaType = SchemaF[A]
else:
    SchemaType = Schema
|
|
|
|
def build_type(type_, options, mixin, field, cls):
    """Return the marshmallow field object for ``type_``.

    :param type_: the (possibly generic) type to translate.
    :param options: keyword arguments for the outermost field; may be
        mutated (``field_many`` / ``allow_none`` are added when applicable).
    :param mixin: base class that marks a dataclass as providing ``schema()``.
    :param field: the dataclass field being translated (used for
        ``field.type`` and in warning messages).
    :param cls: the owning dataclass (used in warning messages).
    """

    def inner(type_, options):
        # Unwrap NewType layers down to the underlying type.
        while _is_new_type(type_):
            type_ = type_.__supertype__

        if is_dataclass(type_):
            if not _issubclass_safe(type_, mixin):
                warnings.warn(f"Nested dataclass field {field.name} of type "
                              f"{field.type} detected in "
                              f"{cls.__name__} that is not an instance of "
                              f"dataclass_json. Did you mean to recursively "
                              f"serialize this field? If so, make sure to "
                              f"augment {type_} with either the "
                              f"`dataclass_json` decorator or mixin.")
                return fields.Field(**options)
            is_many = (_is_supported_generic(field.type)
                       and _is_collection(field.type))
            options['field_many'] = bool(is_many)
            return fields.Nested(type_.schema(), **options)

        origin = getattr(type_, '__origin__', type_)
        # Recursively build a field for every type argument except NoneType;
        # nested fields start from an empty option set.
        args = [inner(arg, {})
                for arg in getattr(type_, '__args__', [])
                if arg is not type(None)]

        # ``...`` (as in Tuple[X, ...]) is passed through untouched.
        if type_ == Ellipsis:
            return type_

        if _is_optional(type_):
            options["allow_none"] = True

        if origin is tuple:
            if len(args) == 2 and args[1] == Ellipsis:
                return _TupleVarLen(args[0], **options)
            return fields.Tuple(args, **options)

        if origin in TYPES:
            field_cls = TYPES[origin]
            return field_cls(*args, **options)

        if _issubclass_safe(origin, Enum):
            return fields.Enum(*args, enum=origin, by_value=True, **options)

        if is_union_type(type_):
            # Same NoneType filter as for ``args`` so both lists line up.
            member_types = [arg
                            for arg in getattr(type_, '__args__', [])
                            if arg is not type(None)]
            union_desc = dict(zip(member_types, args))
            return _UnionField(union_desc, cls, field, **options)

        warnings.warn(
            f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
            f"It's advised to pass the correct marshmallow type to `mm_field`.")
        return fields.Field(**options)

    return inner(type_, options)
|
|
|
|
def schema(cls, mixin, infer_missing):
    """Build the ``{field name: marshmallow field}`` mapping for ``cls``.

    :param cls: the dataclass to describe.
    :param mixin: forwarded to ``build_type``.
    :param infer_missing: when true, defaults are passed to the fields under
        the ``'missing'`` option, otherwise under ``'default'``.
    :return: dict of field objects keyed by dataclass field name; the
        ``Optional[CatchAllVar]`` field is left out.
    """
    mm_fields = {}
    overrides = _user_overrides_or_exts(cls)
    missing_key = 'missing' if infer_missing else 'default'

    for field in dc_fields(cls):
        metadata = overrides[field.name]

        # A user-supplied marshmallow field always wins.
        if metadata.mm_field is not None:
            mm_fields[field.name] = metadata.mm_field
            continue

        type_ = field.type
        options: typing.Dict[str, typing.Any] = {}

        if field.default is not MISSING:
            options[missing_key] = field.default
        elif field.default_factory is not MISSING:
            # The factory is invoked once, here, and its result is stored.
            options[missing_key] = field.default_factory()
        else:
            options['required'] = True

        # ``...`` is the "not set" sentinel so an explicit None is detected.
        if options.get(missing_key, ...) is None:
            options['allow_none'] = True

        if _is_optional(type_):
            options.setdefault(missing_key, None)
            options['allow_none'] = True
            # Only unwrap when exactly one type sits next to None; optionals
            # with more type arguments keep their full type.
            if len(type_.__args__) == 2:
                non_none = [tp for tp in type_.__args__
                            if tp is not type(None)]
                type_ = non_none[0]

        if metadata.letter_case is not None:
            options['data_key'] = metadata.letter_case(field.name)

        t = build_type(type_, options, mixin, field, cls)

        if field.metadata.get('dataclasses_json', {}).get('decoder'):
            # A custom decoder is configured: make the field's own value
            # deserialization a pass-through.
            t._deserialize = lambda v, *_a, **_kw: v

        if field.type != typing.Optional[CatchAllVar]:
            mm_fields[field.name] = t

    return mm_fields
|
|
|
|
def build_schema(cls: typing.Type[A],
                 mixin,
                 infer_missing,
                 partial) -> typing.Type["SchemaType[A]"]:
    """Create the marshmallow ``Schema`` subclass for the dataclass ``cls``.

    :param cls: the dataclass to build a schema for.
    :param mixin: forwarded to ``schema``.
    :param infer_missing: forwarded to ``schema``.
    :param partial: forwarded to ``_decode_dataclass`` when loading.
    :return: a new ``Schema`` subclass named ``<Name>Schema`` whose
        ``post_load`` hook returns ``_decode_dataclass(cls, kvs, partial)``.
    """
    Meta = type('Meta',
                (),
                {'fields': tuple(field.name for field in dc_fields(cls)
                                 if field.name != 'dataclass_json_config'
                                 and field.type != typing.Optional[CatchAllVar]),
                 })

    @post_load
    def make_instance(self, kvs, **kwargs):
        return _decode_dataclass(cls, kvs, partial)

    def dumps(self, *args, **kwargs):
        # Use the extended JSON encoder unless the caller picked one.
        if 'cls' not in kwargs:
            kwargs['cls'] = _ExtendedEncoder

        return Schema.dumps(self, *args, **kwargs)

    def dump(self, obj, *, many=None):
        many = self.many if many is None else bool(many)
        # ``obj`` is walked twice below (once by Schema.dump, once to merge
        # the undefined-parameter values). Materialise one-shot iterables
        # first, otherwise the second pass would see nothing and those
        # values would be silently left out.
        if many and obj is not None and not isinstance(obj, (list, tuple)):
            obj = list(obj)
        dumped = Schema.dump(self, obj, many=many)
        # The catch-all field is not part of the static schema, so its
        # content is merged into the dumped dict(s) afterwards.
        if many:
            for i, _obj in enumerate(obj):
                dumped[i].update(
                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
                                                      usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
                                                            usage="dump"))
        return dumped

    schema_ = schema(cls, mixin, infer_missing)
    DataClassSchema: typing.Type["SchemaType[A]"] = type(
        f'{cls.__name__.capitalize()}Schema',
        (Schema,),
        {'Meta': Meta,
         f'make_{cls.__name__.lower()}': make_instance,
         'dumps': dumps,
         'dump': dump,
         **schema_})

    return DataClassSchema
|
|