Spaces:
Running
Running
File size: 1,852 Bytes
0f8b3a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import dataclasses
import types
from typing import Any, get_args, get_origin, get_type_hints, Union
def resolve_type_hints(cls: type) -> dict[str, Any]:
    """Return the evaluated type hints of dataclass *cls*.

    Hints are resolved with ``include_extras=True`` so ``typing.Annotated``
    wrappers are kept. If that call raises ``TypeError``, resolution is
    retried once without ``include_extras``.
    NOTE(review): the retry looks like a compatibility fallback; confirm
    it is still needed for the supported Python versions.

    Args:
        cls: A dataclass type.

    Returns:
        Mapping of attribute name to resolved type hint.

    Raises:
        TypeError: if the hints cannot be resolved (the underlying error is
            chained as ``__cause__``), if *cls* is not a dataclass (raised by
            ``dataclasses.fields``), or if a dataclass field has no hint.
    """
    try:
        try:
            hints = get_type_hints(cls, include_extras=True)
        except TypeError:
            # An exception raised inside this handler would NOT be caught by
            # a sibling ``except`` clause, so the retry sits inside the outer
            # try to make sure its failures are wrapped as well.
            hints = get_type_hints(cls)
    except Exception as e:
        raise TypeError(f"Failed to resolve type hints for {cls.__name__}: {e}") from e
    for field in dataclasses.fields(cls):
        if field.name not in hints:
            raise TypeError(f"{cls.__name__}.{field.name} has no type hint. All dataclass fields must be annotated.")
    return hints
def _is_union_origin(origin: Any) -> bool:
return origin is Union or origin is types.UnionType
def normalize_type_for_deserialization(tp: Any) -> type:
    """Reduce a type hint to the runtime class used when deserializing.

    - ``Annotated[T, ...]`` is unwrapped to ``T``; the metadata is ignored.
    - ``Optional[T]`` / ``T | None`` is unwrapped to ``T``.
    - Any other hint with a ``get_origin`` result (e.g. ``list[int]``)
      reduces to that origin (``list``).
    - A plain class is returned unchanged.

    Args:
        tp: A type hint, e.g. a value returned by ``resolve_type_hints``.

    Returns:
        The class (or generic origin) to deserialize *tp* into.

    Raises:
        TypeError: for unions other than ``Optional[T]``, and for hints
            that are neither classes nor have a ``get_origin`` result.
    """
    from typing import Annotated

    origin = get_origin(tp)
    if origin is Annotated:
        # get_origin() reports Annotated itself, not the wrapped type, so
        # strip the metadata and normalize the underlying hint instead.
        return normalize_type_for_deserialization(get_args(tp)[0])
    if origin is None:
        # Not a generic alias or union: only a real class is acceptable.
        if isinstance(tp, type):
            return tp
        raise TypeError(f"Unsupported type annotation {tp!r}. Use a concrete runtime type.")
    if _is_union_origin(origin):
        args = [a for a in get_args(tp) if a is not type(None)]
        if len(args) == 1:
            return normalize_type_for_deserialization(args[0])
        raise TypeError(f"Unsupported Union type {tp}. Only Optional[T] or T | None are supported.")
    return origin
def normalize_type_for_serialization(tp: Any) -> type:
    """Reduce a type hint to the runtime class used when serializing.

    - ``Annotated[T, ...]`` is unwrapped to ``T``; the metadata is ignored.
    - ``Optional[T]`` / ``T | None`` is unwrapped to ``T``.
    - Any other hint with a ``get_origin`` result (e.g. ``list[int]``)
      reduces to that origin (``list``).
    - A plain class is returned unchanged.

    Args:
        tp: A type hint, e.g. a value returned by ``resolve_type_hints``.

    Returns:
        The class (or generic origin) that drives serialization of *tp*.

    Raises:
        TypeError: for unions other than ``Optional[T]``, and for hints
            that are neither classes nor have a ``get_origin`` result.
    """
    from typing import Annotated

    origin = get_origin(tp)
    if origin is Annotated:
        # get_origin() reports Annotated itself, not the wrapped type, so
        # strip the metadata and normalize the underlying hint instead.
        return normalize_type_for_serialization(get_args(tp)[0])
    if origin is None:
        # Only accept tp as a class once it is known not to be a generic
        # alias: on Python 3.9/3.10 isinstance(list[int], type) is True,
        # which would otherwise return the alias instead of ``list``.
        if isinstance(tp, type):
            return tp
        raise TypeError(f"Unsupported type annotation {tp!r} for serialization. Use a concrete runtime type.")
    if _is_union_origin(origin):
        args = [a for a in get_args(tp) if a is not type(None)]
        if len(args) == 1:
            return normalize_type_for_serialization(args[0])
        raise TypeError(f"Unsupported Union type {tp}. Only Optional[T] or T | None are supported.")
    return origin
|