| import abc |
| import dataclasses |
| import functools |
| import inspect |
| import sys |
| from dataclasses import Field, fields |
| from typing import Any, Callable, Dict, Optional, Tuple, Union, Type, get_type_hints |
| from enum import Enum |
|
|
| from marshmallow.exceptions import ValidationError |
|
|
| from dataclasses_json.utils import CatchAllVar |
|
|
KnownParameters = Dict[str, Any]
UnknownParameters = Dict[str, Any]


class _UndefinedParameterAction(abc.ABC):
    """
    Base strategy describing how parameters that do not correspond to a
    dataclass field are handled when loading, dumping and initializing.
    """

    @staticmethod
    @abc.abstractmethod
    def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
        """
        Return the parameters to initialize the class with.
        """
        pass

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Return the parameters that will be written to the output dict.

        The default implementation returns ``kvs`` unchanged.
        """
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the parameters that will be added to the schema dump.

        The default implementation adds nothing.
        """
        return {}

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Return the initializer to use for ``obj``; the default keeps
        ``obj.__init__`` as it is.
        """
        return obj.__init__

    @staticmethod
    def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
            Tuple[KnownParameters, UnknownParameters]:
        """
        Return two dictionaries: the defined and the undefined parameters.

        A key is "defined" when it is the name of a field of ``cls`` (as
        reported by ``dataclasses.fields``). Both returned dicts keep the
        iteration order of ``kvs``.
        """
        # A set gives O(1) membership tests instead of scanning a list per key.
        field_names = {field.name for field in fields(cls)}
        known_given_parameters: KnownParameters = {}
        unknown_given_parameters: UnknownParameters = {}
        # Single pass over kvs, partitioning each entry.
        for key, value in kvs.items():
            if key in field_names:
                known_given_parameters[key] = value
            else:
                unknown_given_parameters[key] = value
        return known_given_parameters, unknown_given_parameters
|
|
|
|
class _RaiseUndefinedParameters(_UndefinedParameterAction):
    """
    Strict strategy: any input key that is not a field of the target
    dataclass causes an UndefinedParameterError instead of being accepted.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the entries of ``kvs`` that name fields of ``cls``.

        Raises:
            UndefinedParameterError: if ``kvs`` contains any other key.
        """
        split = _UndefinedParameterAction._separate_defined_undefined_kvs
        defined, unknown = split(cls=cls, kvs=kvs)
        if not unknown:
            return defined
        raise UndefinedParameterError(
            f"Received undefined initialization arguments {unknown}")
|
|
|
|
# Annotation for the field that collects undefined parameters
# (Union[X, None] is the same type object as Optional[X]).
CatchAll = Union[CatchAllVar, None]
|
|
|
|
class _IgnoreUndefinedParameters(_UndefinedParameterAction):
    """
    This action does nothing when it encounters undefined parameters.
    The undefined parameters can not be retrieved after the class has been
    created.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """Return only the entries of ``kvs`` that name a field of ``cls``."""
        known_given_parameters, _ = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        return known_given_parameters

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Wrap ``obj.__init__`` so that surplus positional arguments and
        keyword arguments the initializer cannot accept are dropped.

        Accepted names come from the initializer's own signature rather than
        from ``dataclasses.fields``: the signature also lists init-only
        parameters (e.g. ``InitVar``), which ``fields`` omits, and does not
        list ``init=False`` fields, which the initializer would reject.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)
        # Skip the first parameter (the instance, bound explicitly below) and
        # any *args/**kwargs catch-alls; only named parameters are forwarded.
        accepted_names = {
            param.name
            for param in list(init_signature.parameters.values())[1:]
            if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                              inspect.Parameter.KEYWORD_ONLY)
        }

        @functools.wraps(obj.__init__)
        def _ignore_init(self, *args, **kwargs):
            known_kwargs = {key: value for key, value in kwargs.items()
                            if key in accepted_names}
            num_params_takeable = len(init_signature.parameters) - 1
            # Positional slots left once keyword arguments are placed; any
            # positional arguments beyond that are undefined and dropped.
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args = args[:num_args_takeable]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)
            bound_parameters.apply_defaults()

            # Forward every named parameter except the instance slot itself
            # (its name is not in accepted_names, even if it is not "self").
            final_parameters = {
                key: value
                for key, value in bound_parameters.arguments.items()
                if key in accepted_names
            }
            original_init(self, **final_parameters)

        return _ignore_init
|
|
|
|
class _CatchAllUndefinedParameters(_UndefinedParameterAction):
    """
    This class allows to add a field of type utils.CatchAll which acts as a
    dictionary into which all undefined parameters will be written.
    These parameters are not affected by LetterCase.
    If no undefined parameters are given, this dictionary will be empty.
    """

    class _SentinelNoDefault:
        # Marker returned by _get_default when the catch-all field declares
        # neither a default nor a default_factory.
        pass

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the init parameters for ``cls``: the entries of ``kvs`` that
        name fields, with the catch-all field set to the undefined entries.

        If ``kvs`` already carries a value under the catch-all name:
        it is kept when it equals the field default and nothing is undefined;
        it is replaced by the undefined entries when it equals the default;
        it is merged with the undefined entries (undefined keys win) when it
        is a dict. The mapping supplied by the caller is never mutated.

        Raises:
            UndefinedParameterError: if the catch-all name carries a value
                that is neither the field default nor a dict.
        """
        known, unknown = _UndefinedParameterAction \
            ._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=cls)

        if catch_all_field.name in known:
            given_value = known[catch_all_field.name]
            already_parsed = isinstance(given_value, dict)
            default_value = _CatchAllUndefinedParameters._get_default(
                catch_all_field=catch_all_field)
            received_default = default_value == given_value

            value_to_write: Any
            if received_default and len(unknown) == 0:
                value_to_write = default_value
            elif received_default and len(unknown) > 0:
                value_to_write = unknown
            elif already_parsed:
                # Did not receive the default: merge into a new dict rather
                # than calling .update() on the caller's mapping.
                value_to_write = given_value
                if len(unknown) > 0:
                    value_to_write = {**given_value, **unknown}
            else:
                error_message = f"Received input field with " \
                                f"same name as catch-all field: " \
                                f"'{catch_all_field.name}': " \
                                f"'{known[catch_all_field.name]}'"
                raise UndefinedParameterError(error_message)
        else:
            value_to_write = unknown

        known[catch_all_field.name] = value_to_write
        return known

    @staticmethod
    def _get_default(catch_all_field: Field) -> Any:
        """
        Return the default of ``catch_all_field``: its ``default`` if set,
        else a freshly built ``default_factory()`` value, else the
        ``_SentinelNoDefault`` class.
        """
        # noinspection PyProtectedMember
        has_default = not isinstance(catch_all_field.default,
                                     dataclasses._MISSING_TYPE)
        # noinspection PyProtectedMember
        has_default_factory = not isinstance(catch_all_field.default_factory,
                                             dataclasses._MISSING_TYPE)
        default_value: Union[
            Type[_CatchAllUndefinedParameters._SentinelNoDefault], Any] = \
            _CatchAllUndefinedParameters._SentinelNoDefault

        if has_default:
            default_value = catch_all_field.default
        elif has_default_factory:
            # The factory is invoked again here purely to compare against the
            # received value, so an expensive factory costs an extra call.
            default_value = catch_all_field.default_factory()  # type: ignore

        return default_value

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Remove the catch-all entry from ``kvs`` and, if it is a dict, splice
        its items into ``kvs`` at the top level. ``kvs`` is modified in place
        and returned.
        """
        catch_all_field = \
            _CatchAllUndefinedParameters._get_catch_all_field(obj.__class__)
        undefined_parameters = kvs.pop(catch_all_field.name)
        if isinstance(undefined_parameters, dict):
            # Keys are written as-is (no letter-case conversion).
            kvs.update(undefined_parameters)
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the catch-all mapping stored on ``obj`` so it can be added to
        the schema dump.
        """
        # Resolve the field on the class, as handle_to_dict does:
        # typing.get_type_hints only merges annotations across the MRO for
        # classes, so looking them up via an instance can miss inherited
        # fields and make _get_catch_all_field fail with a KeyError.
        cls = obj if isinstance(obj, type) else obj.__class__
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=cls)
        return getattr(obj, catch_all_field.name)

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Wrap ``obj.__init__`` so that undefined arguments are collected into
        the catch-all field instead of being rejected.

        Surplus positional arguments are stored under the keys
        ``_UNKNOWN0``, ``_UNKNOWN1``, ...; undefined keyword arguments keep
        their own names.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        @functools.wraps(obj.__init__)
        def _catch_all_init(self, *args, **kwargs):
            known_kwargs, unknown_kwargs = \
                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
                    obj, kwargs)
            # Parameters that may still be filled positionally: every init
            # parameter except ``self`` and, unless it was passed by keyword,
            # the catch-all field itself.
            num_params_takeable = len(init_signature.parameters) - 1
            if _CatchAllUndefinedParameters._get_catch_all_field(
                    obj).name not in known_kwargs:
                num_params_takeable -= 1
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args, unknown_args = (args[:num_args_takeable],
                                  args[num_args_takeable:])
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)

            unknown_args = {f"_UNKNOWN{i}": v for i, v in
                            enumerate(unknown_args)}
            arguments = bound_parameters.arguments
            arguments.update(unknown_args)
            arguments.update(unknown_kwargs)
            arguments.pop("self", None)
            # NOTE(review): names are classified via dataclasses.fields(), so
            # init-only parameters that are not fields (e.g. InitVar) appear
            # to be routed into the catch-all instead of original_init —
            # confirm whether that is intended.
            final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
                obj, arguments)
            original_init(self, **final_parameters)

        return _catch_all_init

    @staticmethod
    def _get_catch_all_field(cls) -> Field:
        """
        Return the single field of ``cls`` annotated as ``CatchAll``
        (i.e. ``Optional[CatchAllVar]``).

        Raises:
            UndefinedParameterError: if there is no such field or more
                than one.
        """
        # Evaluate (possibly string) annotations in the defining module.
        cls_globals = vars(sys.modules[cls.__module__])
        types = get_type_hints(cls, globalns=cls_globals)
        catch_all_fields = [
            field for field in fields(cls)
            if types[field.name] == Optional[CatchAllVar]
        ]
        number_of_catch_all_fields = len(catch_all_fields)
        if number_of_catch_all_fields == 0:
            raise UndefinedParameterError(
                "No field of type dataclasses_json.CatchAll defined")
        elif number_of_catch_all_fields > 1:
            raise UndefinedParameterError(
                f"Multiple catch-all fields supplied: "
                f"{number_of_catch_all_fields}.")
        else:
            return catch_all_fields[0]
|
|
|
|
class Undefined(Enum):
    """
    Choose what happens when an undefined parameter is encountered during
    class initialization.

    Each member's value is the strategy class implementing that behavior.
    """
    # Collect undefined parameters into the class's CatchAll field.
    INCLUDE = _CatchAllUndefinedParameters
    # Raise UndefinedParameterError when any undefined parameter is present.
    RAISE = _RaiseUndefinedParameters
    # Silently drop undefined parameters.
    EXCLUDE = _IgnoreUndefinedParameters
|
|
|
|
class UndefinedParameterError(ValidationError):
    """
    Signals a failure while handling undefined parameters.

    Examples include an undefined parameter under the raising strategy, a
    missing or duplicated catch-all field, or an input value that clashes
    with the catch-all field's name.
    """
|
|