| | from __future__ import annotations |
| |
|
| | import re |
| | import textwrap |
| | from collections.abc import Iterable |
| | from typing import Any, Optional, Callable |
| |
|
| | from . import inspect as mi, to_builtins |
| |
|
| | __all__ = ("schema", "schema_components") |
| |
|
| |
|
def schema(
    type: Any, *, schema_hook: Optional[Callable[[type], dict[str, Any]]] = None
) -> dict[str, Any]:
    """Generate a JSON Schema for a given type.

    Any schemas for (potentially) shared components are extracted and stored in
    a top-level ``"$defs"`` field.

    If you want to generate schemas for multiple types, or to have more control
    over the generated schema you may want to use ``schema_components`` instead.

    Parameters
    ----------
    type : type
        The type to generate the schema for.
    schema_hook : callable, optional
        An optional callback to use for generating JSON schemas of custom
        types. Will be called with the custom type, and should return a dict
        representation of the JSON schema for that type.

    Returns
    -------
    schema : dict
        The generated JSON Schema.

    See Also
    --------
    schema_components
    """
    # Delegate to the multi-type entrypoint with a single-element tuple,
    # then fold any shared component definitions back in under "$defs".
    (result,), defs = schema_components((type,), schema_hook=schema_hook)
    if defs:
        result["$defs"] = defs
    return result
| |
|
| |
|
def schema_components(
    types: Iterable[Any],
    *,
    schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
    ref_template: str = "#/$defs/{name}",
) -> tuple[tuple[dict[str, Any], ...], dict[str, Any]]:
    """Generate JSON Schemas for one or more types.

    Any schemas for (potentially) shared components are extracted and returned
    in a separate ``components`` dict.

    Parameters
    ----------
    types : Iterable[type]
        An iterable of one or more types to generate schemas for.
    schema_hook : callable, optional
        An optional callback to use for generating JSON schemas of custom
        types. Will be called with the custom type, and should return a dict
        representation of the JSON schema for that type.
    ref_template : str, optional
        A template to use when generating ``"$ref"`` fields. This template is
        formatted with the type name as ``template.format(name=name)``. This
        can be useful if you intend to store the ``components`` mapping
        somewhere other than a top-level ``"$defs"`` field. For example, you
        might use ``ref_template="#/components/{name}"`` if generating an
        OpenAPI schema.

    Returns
    -------
    schemas : tuple[dict]
        A tuple of JSON Schemas, one for each type in ``types``.
    components : dict
        A mapping of name to schema for any shared components used by
        ``schemas``.

    See Also
    --------
    schema
    """
    # Introspect all requested types in one pass.
    infos = mi.multi_type_info(types)

    # Find nameable subcomponents (structs, enums, ...) shared by the schemas,
    # and assign each a unique name.
    subcomponents = _collect_component_types(infos)
    names = _build_name_map(subcomponents)

    generator = _SchemaGenerator(names, schema_hook, ref_template)

    # Top-level schemas reference components via "$ref".
    schemas = tuple(generator.to_schema(info) for info in infos)

    # Component definitions are generated inline (check_ref=False) so they
    # don't just self-reference.
    components: dict[str, Any] = {}
    for cls, info in subcomponents.items():
        components[names[cls]] = generator.to_schema(info, False)
    return schemas, components
| |
|
| |
|
def _collect_component_types(type_infos: Iterable[mi.Type]) -> dict[Any, mi.Type]:
    """Find all types in the type tree that are "nameable" and worthy of being
    extracted out into a shared top-level components mapping.

    Currently this looks for Struct, Dataclass, NamedTuple, TypedDict, and Enum
    types.
    """
    found: dict[Any, mi.Type] = {}

    # Depth-first, pre-order walk; insertion order into `found` determines
    # the ordering of the generated components mapping.
    def visit(node):
        if isinstance(
            node,
            (mi.StructType, mi.TypedDictType, mi.DataclassType, mi.NamedTupleType),
        ):
            # Only recurse into a structured type's fields the first time we
            # see it, guarding against infinite recursion on cyclic types.
            if node.cls in found:
                return
            found[node.cls] = node
            for field in node.fields:
                visit(field.type)
        elif isinstance(node, mi.EnumType):
            found[node.cls] = node
        elif isinstance(node, mi.Metadata):
            visit(node.type)
        elif isinstance(node, mi.CollectionType):
            visit(node.item_type)
        elif isinstance(node, mi.TupleType):
            for item in node.item_types:
                visit(item)
        elif isinstance(node, mi.DictType):
            visit(node.key_type)
            visit(node.value_type)
        elif isinstance(node, mi.UnionType):
            for member in node.types:
                visit(member)

    for info in type_infos:
        visit(info)

    return found
| |
|
| |
|
| | def _type_repr(obj): |
| | return obj.__name__ if isinstance(obj, type) else repr(obj) |
| |
|
| |
|
| | def _get_class_name(cls: Any) -> str: |
| | if hasattr(cls, "__origin__"): |
| | name = cls.__origin__.__name__ |
| | args = ", ".join(_type_repr(a) for a in cls.__args__) |
| | return f"{name}[{args}]" |
| | return cls.__name__ |
| |
|
| |
|
def _get_doc(t: mi.Type) -> str:
    """Extract a class docstring for use as a schema ``description``.

    Returns ``""`` when the class has no docstring, or when the docstring is
    one of the auto-generated defaults (enum / namedtuple / dataclass
    boilerplate) that would add no useful information.
    """
    assert hasattr(t, "cls")
    cls = getattr(t.cls, "__origin__", t.cls)
    raw = getattr(cls, "__doc__", "")
    if not raw:
        return ""
    doc = textwrap.dedent(raw).strip("\r\n")
    # Suppress the stock docstrings Python generates automatically.
    if isinstance(t, mi.EnumType) and doc == "An enumeration.":
        return ""
    if isinstance(t, (mi.NamedTupleType, mi.DataclassType)):
        # namedtuple/dataclass default docstrings look like "Name(field, ...)"
        if doc.startswith(f"{cls.__name__}(") and doc.endswith(")"):
            return ""
    return doc
| |
|
| |
|
def _build_name_map(component_types: dict[Any, mi.Type]) -> dict[Any, str]:
    """A mapping from nameable subcomponents to a generated name.

    The generated name is usually a normalized version of the class name. In
    the case of conflicts, the name will be expanded to also include the full
    import path.
    """

    def sanitize(raw: str) -> str:
        # Replace any character not valid in a schema name with "_".
        return re.sub(r"[^a-zA-Z0-9.\-_]", "_", raw)

    def qualified(cls) -> str:
        return sanitize(f"{cls.__module__}.{cls.__qualname__}")

    collided: set[str] = set()
    claimed: dict[str, Any] = {}

    for cls in component_types:
        short = sanitize(_get_class_name(cls))
        if short in claimed:
            # First collision on this short name: evict the previous owner
            # and re-register it under its fully-qualified name.
            previous = claimed.pop(short)
            collided.add(short)
            claimed[qualified(previous)] = previous
        if short in collided:
            claimed[qualified(cls)] = cls
        else:
            claimed[short] = cls
    # Invert name -> cls into cls -> name for lookup during generation.
    return {cls: name for name, cls in claimed.items()}
| |
|
| |
|
class _SchemaGenerator:
    """Converts introspected type nodes into JSON Schema fragments.

    Holds the shared state needed across conversions: the component name map,
    an optional hook for custom types, and the ``"$ref"`` template.
    """

    def __init__(
        self,
        name_map: dict[Any, str],
        schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
        ref_template: str = "#/$defs/{name}",
    ):
        # Mapping of nameable component classes -> generated schema name.
        self.name_map = name_map
        # Optional callback producing schemas for custom (unsupported) types.
        self.schema_hook = schema_hook
        # Template rendered with `format(name=...)` to produce "$ref" values.
        self.ref_template = ref_template

    def to_schema(self, t: mi.Type, check_ref: bool = True) -> dict[str, Any]:
        """Converts a Type to a json-schema."""
        schema: dict[str, Any] = {}

        # Unwrap Metadata wrappers, merging any user-supplied
        # `extra_json_schema` fields into the output as we go.
        while isinstance(t, mi.Metadata):
            schema = mi._merge_json(schema, t.extra_json_schema)
            t = t.type

        # If this type is a shared, named component, emit a "$ref" to it
        # rather than inlining its definition. `check_ref=False` is used when
        # generating the component definitions themselves.
        if check_ref and hasattr(t, "cls"):
            if name := self.name_map.get(t.cls):
                schema["$ref"] = self.ref_template.format(name=name)
                return schema

        if isinstance(t, (mi.AnyType, mi.RawType)):
            # Any/Raw impose no constraints: empty schema.
            pass
        elif isinstance(t, mi.NoneType):
            schema["type"] = "null"
        elif isinstance(t, mi.BoolType):
            schema["type"] = "boolean"
        elif isinstance(t, (mi.IntType, mi.FloatType)):
            schema["type"] = "integer" if isinstance(t, mi.IntType) else "number"
            # Numeric bound constraints map directly onto JSON Schema keywords.
            if t.ge is not None:
                schema["minimum"] = t.ge
            if t.gt is not None:
                schema["exclusiveMinimum"] = t.gt
            if t.le is not None:
                schema["maximum"] = t.le
            if t.lt is not None:
                schema["exclusiveMaximum"] = t.lt
            if t.multiple_of is not None:
                schema["multipleOf"] = t.multiple_of
        elif isinstance(t, mi.StrType):
            schema["type"] = "string"
            if t.max_length is not None:
                schema["maxLength"] = t.max_length
            if t.min_length is not None:
                schema["minLength"] = t.min_length
            if t.pattern is not None:
                schema["pattern"] = t.pattern
        elif isinstance(t, (mi.BytesType, mi.ByteArrayType, mi.MemoryViewType)):
            # Binary data is serialized as base64 text; length constraints are
            # scaled to the base64-encoded size (4 output chars per 3 input
            # bytes, rounded up to a full 4-char group).
            schema["type"] = "string"
            schema["contentEncoding"] = "base64"
            if t.max_length is not None:
                schema["maxLength"] = 4 * ((t.max_length + 2) // 3)
            if t.min_length is not None:
                schema["minLength"] = 4 * ((t.min_length + 2) // 3)
        elif isinstance(t, mi.DateTimeType):
            schema["type"] = "string"
            # Only tz-aware datetimes match the RFC 3339 "date-time" format;
            # tz-naive (or unconstrained) ones get no format keyword.
            if t.tz is True:
                schema["format"] = "date-time"
        elif isinstance(t, mi.TimeType):
            schema["type"] = "string"
            # tz True/False/None -> "time" / "partial-time" / no format.
            if t.tz is True:
                schema["format"] = "time"
            elif t.tz is False:
                schema["format"] = "partial-time"
        elif isinstance(t, mi.DateType):
            schema["type"] = "string"
            schema["format"] = "date"
        elif isinstance(t, mi.TimeDeltaType):
            schema["type"] = "string"
            schema["format"] = "duration"
        elif isinstance(t, mi.UUIDType):
            schema["type"] = "string"
            schema["format"] = "uuid"
        elif isinstance(t, mi.DecimalType):
            schema["type"] = "string"
            schema["format"] = "decimal"
        elif isinstance(t, mi.CollectionType):
            schema["type"] = "array"
            # Omit "items" for Any to keep the schema minimal.
            if not isinstance(t.item_type, mi.AnyType):
                schema["items"] = self.to_schema(t.item_type)
            if t.max_length is not None:
                schema["maxItems"] = t.max_length
            if t.min_length is not None:
                schema["minItems"] = t.min_length
        elif isinstance(t, mi.TupleType):
            # Heterogeneous fixed-length tuple: exact length, per-position
            # schemas, and no extra items allowed.
            schema["type"] = "array"
            schema["minItems"] = schema["maxItems"] = len(t.item_types)
            if t.item_types:
                schema["prefixItems"] = [self.to_schema(i) for i in t.item_types]
                schema["items"] = False
        elif isinstance(t, mi.DictType):
            schema["type"] = "object"
            # If the keys are constrained strings, express those constraints
            # via "propertyNames".
            if isinstance(key_type := t.key_type, mi.StrType):
                property_names: dict[str, Any] = {}
                if key_type.min_length is not None:
                    property_names["minLength"] = key_type.min_length
                if key_type.max_length is not None:
                    property_names["maxLength"] = key_type.max_length
                if key_type.pattern is not None:
                    property_names["pattern"] = key_type.pattern
                if property_names:
                    schema["propertyNames"] = property_names
            if not isinstance(t.value_type, mi.AnyType):
                schema["additionalProperties"] = self.to_schema(t.value_type)
            if t.max_length is not None:
                schema["maxProperties"] = t.max_length
            if t.min_length is not None:
                schema["minProperties"] = t.min_length
        elif isinstance(t, mi.UnionType):
            # Partition union members into tagged (non-array-like) structs,
            # which can use an OpenAPI-style discriminator, and everything
            # else, which gets a plain "anyOf".
            structs = {}
            other = []
            tag_field = None
            for subtype in t.types:
                real_type = subtype
                # Look through Metadata to classify, but keep the wrapped
                # subtype so its extra schema fields aren't lost.
                while isinstance(real_type, mi.Metadata):
                    real_type = real_type.type
                if isinstance(real_type, mi.StructType) and not real_type.array_like:
                    tag_field = real_type.tag_field
                    structs[real_type.tag] = real_type
                else:
                    other.append(subtype)

            options = [self.to_schema(a) for a in other]

            if len(structs) >= 2:
                # Two or more tagged structs: emit a discriminator mapping
                # from tag value -> component "$ref".
                mapping = {
                    k: self.ref_template.format(name=self.name_map[v.cls])
                    for k, v in structs.items()
                }
                struct_schema = {
                    "anyOf": [self.to_schema(v) for v in structs.values()],
                    "discriminator": {"propertyName": tag_field, "mapping": mapping},
                }
                if options:
                    options.append(struct_schema)
                    schema["anyOf"] = options
                else:
                    schema.update(struct_schema)
            elif len(structs) == 1:
                # A single struct doesn't need a discriminator.
                _, subtype = structs.popitem()
                options.append(self.to_schema(subtype))
                schema["anyOf"] = options
            else:
                schema["anyOf"] = options
        elif isinstance(t, mi.LiteralType):
            # Sorted for deterministic output.
            schema["enum"] = sorted(t.values)
        elif isinstance(t, mi.EnumType):
            # setdefault: user-supplied Meta(extra_json_schema=...) values
            # merged earlier take precedence.
            schema.setdefault("title", t.cls.__name__)
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            schema["enum"] = sorted(e.value for e in t.cls)
        elif isinstance(t, mi.StructType):
            schema.setdefault("title", _get_class_name(t.cls))
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            required = []
            names = []
            fields = []

            # A tagged struct carries its tag as an implicit required field
            # constrained to the single tag value.
            if t.tag_field is not None:
                required.append(t.tag_field)
                names.append(t.tag_field)
                fields.append({"enum": [t.tag]})

            for field in t.fields:
                field_schema = self.to_schema(field.type)
                if field.required:
                    required.append(field.encode_name)
                elif field.default is not mi.NODEFAULT:
                    field_schema["default"] = to_builtins(field.default, str_keys=True)
                elif field.default_factory in (list, dict, set, bytearray):
                    # Only safe for these known-empty builtin factories;
                    # arbitrary factories may not be representable.
                    field_schema["default"] = field.default_factory()
                names.append(field.encode_name)
                fields.append(field_schema)

            if t.array_like:
                # Count trailing optional fields: they may be omitted from
                # the encoded array, lowering "minItems".
                n_trailing_defaults = 0
                for n_trailing_defaults, f in enumerate(reversed(t.fields)):
                    if f.required:
                        break
                schema["type"] = "array"
                schema["prefixItems"] = fields
                schema["minItems"] = len(fields) - n_trailing_defaults
                if t.forbid_unknown_fields:
                    schema["maxItems"] = len(fields)
            else:
                schema["type"] = "object"
                schema["properties"] = dict(zip(names, fields))
                schema["required"] = required
                if t.forbid_unknown_fields:
                    schema["additionalProperties"] = False
        elif isinstance(t, (mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)):
            schema.setdefault("title", _get_class_name(t.cls))
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            names = []
            fields = []
            required = []
            for field in t.fields:
                field_schema = self.to_schema(field.type)
                if field.required:
                    required.append(field.encode_name)
                elif field.default is not mi.NODEFAULT:
                    field_schema["default"] = to_builtins(field.default, str_keys=True)
                names.append(field.encode_name)
                fields.append(field_schema)
            if isinstance(t, mi.NamedTupleType):
                # NamedTuples encode as arrays; optional fields only relax
                # the minimum length.
                schema["type"] = "array"
                schema["prefixItems"] = fields
                schema["minItems"] = len(required)
                schema["maxItems"] = len(fields)
            else:
                schema["type"] = "object"
                schema["properties"] = dict(zip(names, fields))
                schema["required"] = required
        elif isinstance(t, mi.ExtType):
            raise TypeError("json-schema doesn't support msgpack Ext types")
        elif isinstance(t, mi.CustomType):
            # Try the user hook first; a NotImplementedError means "no schema
            # from the hook", falling back to any Meta-provided schema.
            if self.schema_hook:
                try:
                    schema = mi._merge_json(self.schema_hook(t.cls), schema)
                except NotImplementedError:
                    pass
            if not schema:
                raise TypeError(
                    "Generating JSON schema for custom types requires either:\n"
                    "- specifying a `schema_hook`\n"
                    "- annotating the type with `Meta(extra_json_schema=...)`\n"
                    "\n"
                    f"type {t.cls!r} is not supported"
                )
        else:
            # fallthrough: an inspect node kind this generator doesn't handle
            raise TypeError(f"json-schema doesn't support type {t!r}")

        return schema
|