| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from openai import NOT_GIVEN |
| from typing_extensions import TypeGuard |
|
|
| from .exceptions import UserError |
|
|
# Strict-mode schema used when the caller supplies an empty schema (`{}`).
# Treat this as a read-only template: it is module-level state, so it must
# never be handed out directly to code that may mutate it.
_EMPTY_SCHEMA = {
    "additionalProperties": False,
    "type": "object",
    "properties": {},
    "required": [],
}


def ensure_strict_json_schema(
    schema: dict[str, Any],
) -> dict[str, Any]:
    """Make a JSON schema conform to the `strict` standard the OpenAI API expects.

    A non-empty ``schema`` is mutated in place (and also returned) by
    ``_ensure_strict_json_schema``. An empty ``schema`` is left untouched and a
    fresh strict "object with no properties" schema is returned instead.

    Args:
        schema: The JSON schema to normalise.

    Returns:
        The normalised schema. For non-empty input this is the same dict object
        that was passed in; for ``{}`` it is a new, independent dict.

    Raises:
        TypeError: If a (sub-)schema that must be a dict is not one.
        UserError: If an object schema explicitly allows additional properties.
        ValueError: If a ``$ref`` that must be inlined cannot be resolved to a dict.
    """
    if schema == {}:
        # Local import keeps this block self-contained; the cost is negligible.
        import copy

        # Return a deep copy so callers that mutate the result (for example by
        # adding properties) cannot corrupt the shared module-level template.
        return copy.deepcopy(_EMPTY_SCHEMA)
    return _ensure_strict_json_schema(schema, path=(), root=schema)
|
|
|
|
| |
def _ensure_strict_json_schema(
    json_schema: object,
    *,
    path: tuple[str, ...],
    root: dict[str, object],
) -> dict[str, Any]:
    """Recursively rewrite ``json_schema`` in place into its strict form.

    The function walks the schema and applies these rewrites:

    * every entry under ``$defs`` / ``definitions`` is normalised;
    * ``type: object`` schemas get ``additionalProperties: False`` when the key
      is absent, and are rejected when it is present and truthy;
    * every key in ``properties`` is listed in ``required`` and each property
      schema is normalised;
    * ``items``, each ``anyOf`` variant and each ``allOf`` entry are normalised;
      a single-entry ``allOf`` is merged into the parent and removed;
    * a ``default`` that is exactly ``None`` is dropped;
    * a ``$ref`` that has sibling keys is resolved against ``root``, inlined,
      and the merged schema is normalised again.

    Args:
        json_schema: The (sub-)schema to normalise; must be a dict.
        path: Location of ``json_schema`` relative to ``root``. It is only used
            in the ``TypeError`` message and is extended on each recursive call.
        root: The top-level schema, used to resolve ``$ref`` pointers.

    Returns:
        The same dict object that was passed as ``json_schema``, after mutation.

    Raises:
        TypeError: If ``json_schema`` is not a dict.
        UserError: If an object schema sets a truthy ``additionalProperties``.
        ValueError: If a ``$ref`` does not resolve to a dict (or, via
            ``resolve_ref``, does not start with ``#/``).
        AssertionError: If a ``$ref`` that must be inlined is not a string.
    """
    if not is_dict(json_schema):
        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")

    # Shared definitions: the return value is ignored because each definition
    # dict is mutated in place.
    defs = json_schema.get("$defs")
    if is_dict(defs):
        for def_name, def_schema in defs.items():
            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)

    # Same treatment for the alternative `definitions` keyword.
    definitions = json_schema.get("definitions")
    if is_dict(definitions):
        for definition_name, definition_schema in definitions.items():
            _ensure_strict_json_schema(
                definition_schema, path=(*path, "definitions", definition_name), root=root
            )

    # Object schemas must not allow additional properties: default the key to
    # False when it is missing, and refuse schemas that set it to a truthy value.
    # An explicit falsy value (e.g. False) is left as-is.
    typ = json_schema.get("type")
    if typ == "object" and "additionalProperties" not in json_schema:
        json_schema["additionalProperties"] = False
    elif (
        typ == "object"
        and "additionalProperties" in json_schema
        and json_schema["additionalProperties"]
    ):
        raise UserError(
            "additionalProperties should not be set for object types. This could be because "
            "you're using an older version of Pydantic, or because you configured additional "
            "properties to be allowed. If you really need this, update the function or output tool "
            "to not use a strict schema."
        )

    # object types
    # { 'type': 'object', 'properties': { 'a':  {...} } }
    # Every declared property becomes required, then each property schema is
    # normalised recursively.
    properties = json_schema.get("properties")
    if is_dict(properties):
        json_schema["required"] = list(properties.keys())
        json_schema["properties"] = {
            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
            for key, prop_schema in properties.items()
        }

    # arrays
    # { 'type': 'array', 'items': {...} }
    # Only a dict-valued `items` is handled; other shapes are left untouched.
    items = json_schema.get("items")
    if is_dict(items):
        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)

    # unions: normalise every `anyOf` variant.
    any_of = json_schema.get("anyOf")
    if is_list(any_of):
        json_schema["anyOf"] = [
            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
            for i, variant in enumerate(any_of)
        ]

    # intersections: a single-entry `allOf` is flattened into this schema (the
    # entry's keys overwrite existing ones via `update`) and the `allOf` key is
    # removed; with several entries each one is normalised and kept.
    all_of = json_schema.get("allOf")
    if is_list(all_of):
        if len(all_of) == 1:
            json_schema.update(
                _ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root)
            )
            json_schema.pop("allOf")
        else:
            json_schema["allOf"] = [
                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
                for i, entry in enumerate(all_of)
            ]

    # Strip `None` defaults. NOT_GIVEN is used as the "key absent" sentinel so
    # that a missing `default` is distinguishable from an explicit `None`; only
    # the explicit `None` is removed.
    if json_schema.get("default", NOT_GIVEN) is None:
        json_schema.pop("default")

    # A `$ref` accompanied by other keys, e.g.
    # `{"$ref": "...", "description": "my description"}`, is unravelled into the
    # referenced schema plus the local keys, e.g.
    # `{"type": "string", "description": "my description"}`.
    # NOTE(review): presumably strict mode does not accept `$ref` with sibling
    # keywords - confirm against the OpenAI structured-outputs documentation.
    ref = json_schema.get("$ref")
    if ref and has_more_than_n_keys(json_schema, 1):
        assert isinstance(ref, str), f"Received non-string $ref - {ref}"

        resolved = resolve_ref(root=root, ref=ref)
        if not is_dict(resolved):
            raise ValueError(
                f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}"
            )

        # Keys already on this schema take priority over the ones coming from
        # the referenced schema. The merge is shallow, so nested values are
        # shared with the referenced definition.
        json_schema.update({**resolved, **json_schema})
        json_schema.pop("$ref")
        # The inlined keys may introduce constructs this pass has already moved
        # past (e.g. `type: object` without `additionalProperties`), so run the
        # normalisation again on the merged schema.
        return _ensure_strict_json_schema(json_schema, path=path, root=root)

    return json_schema
|
|
|
|
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    """Resolve a local ``$ref`` pointer (``#/a/b/...``) against ``root``.

    Each ``/``-separated reference token is looked up as a dict key, starting
    from ``root``. A token is first tried verbatim; if that key does not exist,
    its RFC 6901 (JSON Pointer) decoded form is tried, where ``~1`` stands for
    ``/`` and ``~0`` stands for ``~``. Trying the literal token first keeps
    every reference that resolved before resolving to the same value.

    Args:
        root: The top-level schema the pointer is relative to.
        ref: The reference string; must start with ``#/``.

    Returns:
        The dict found at the referenced location (the object stored in
        ``root``, not a copy).

    Raises:
        ValueError: If ``ref`` does not start with ``#/``.
        KeyError: If a reference token does not exist at its level.
        AssertionError: If a value along the path is not a dict.
    """
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")

    resolved = root
    for token in ref[2:].split("/"):
        key = token
        if key not in resolved:
            # RFC 6901 decoding: "~1" must be replaced before "~0" so that the
            # escaped sequence "~01" decodes to "~1" rather than to "/".
            key = token.replace("~1", "/").replace("~0", "~")
        value = resolved[key]
        assert isinstance(value, dict), (
            f"encountered non-dictionary entry while resolving {ref} - {resolved}"
        )
        resolved = value

    return resolved
|
|
|
|
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
    """Report whether ``obj`` is a ``dict`` (instances of subclasses included).

    Only the container type is checked. The ``str`` key type promised by the
    ``TypeGuard`` annotation is assumed rather than verified, so no per-key
    work is done at runtime.
    """
    return isinstance(obj, dict)
|
|
|
|
def is_list(obj: object) -> TypeGuard[list[object]]:
    """Report whether ``obj`` is a ``list`` (instances of subclasses included).

    Other sequence types such as ``tuple`` or ``str`` do not qualify.
    """
    return isinstance(obj, list)
|
|
|
|
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    """Return True when ``obj`` holds strictly more than ``n`` keys.

    ``len()`` on a dict is a constant-time operation, so no manual counting
    loop is needed. Comparing the length directly also gives the
    mathematically consistent answer for a negative ``n`` on an empty dict
    (``0 > -1`` is True).

    Args:
        obj: The dict whose keys are counted.
        n: The threshold that the key count must exceed.

    Returns:
        ``True`` if ``len(obj) > n``, otherwise ``False``.
    """
    return len(obj) > n
|
|