| import copy
|
| import http.client
|
| import inspect
|
| import warnings
|
| from collections.abc import Sequence
|
| from typing import Any, Optional, Union, cast
|
|
|
| from fastapi import routing
|
| from fastapi._compat import (
|
| ModelField,
|
| Undefined,
|
| get_definitions,
|
| get_flat_models_from_fields,
|
| get_model_name_map,
|
| get_schema_from_model_field,
|
| lenient_issubclass,
|
| )
|
| from fastapi.datastructures import DefaultPlaceholder
|
| from fastapi.dependencies.models import Dependant
|
| from fastapi.dependencies.utils import (
|
| _get_flat_fields_from_params,
|
| get_flat_dependant,
|
| get_flat_params,
|
| get_validation_alias,
|
| )
|
| from fastapi.encoders import jsonable_encoder
|
| from fastapi.exceptions import FastAPIDeprecationWarning
|
| from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
|
| from fastapi.openapi.models import OpenAPI
|
| from fastapi.params import Body, ParamTypes
|
| from fastapi.responses import Response
|
| from fastapi.types import ModelNameMap
|
| from fastapi.utils import (
|
| deep_dict_update,
|
| generate_operation_id_for_path,
|
| is_body_allowed_for_status_code,
|
| )
|
| from pydantic import BaseModel
|
| from starlette.responses import JSONResponse
|
| from starlette.routing import BaseRoute
|
| from typing_extensions import Literal
|
|
|
# JSON Schema of a single validation error: the item type of the ``detail``
# array in ``validation_error_response_definition``.
validation_error_definition = dict(
    title="ValidationError",
    type="object",
    properties=dict(
        loc=dict(
            title="Location",
            type="array",
            # Each location segment is either a field name or a list index.
            items=dict(anyOf=[dict(type="string"), dict(type="integer")]),
        ),
        msg=dict(title="Message", type="string"),
        type=dict(title="Error Type", type="string"),
        input=dict(title="Input"),
        ctx=dict(title="Context", type="object"),
    ),
    required=["loc", "msg", "type"],
)
|
|
|
# JSON Schema of the default 422 response body: an object whose ``detail``
# field is a list of ``ValidationError`` entries (referenced by component name).
validation_error_response_definition = dict(
    title="HTTPValidationError",
    type="object",
    properties=dict(
        detail=dict(
            title="Detail",
            type="array",
            items={"$ref": REF_PREFIX + "ValidationError"},
        ),
    ),
)
|
|
|
# Fallback descriptions for status-code *range* keys (``1XX`` .. ``5XX``) and
# for ``DEFAULT`` when an additional response does not supply its own.
status_code_ranges: dict[str, str] = dict(
    [
        ("1XX", "Information"),
        ("2XX", "Success"),
        ("3XX", "Redirection"),
        ("4XX", "Client Error"),
        ("5XX", "Server Error"),
        ("DEFAULT", "Default Response"),
    ]
)
|
|
|
|
|
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Collect security schemes and the operation-level security requirements.

    Args:
        flat_dependant: An already-flattened dependant whose
            ``_security_dependencies`` are inspected.

    Returns:
        A pair ``(definitions, requirements)``:

        * ``definitions`` maps each scheme name to its JSON-encoded scheme model
          (by alias, ``None`` values excluded).
        * ``requirements`` holds one ``{scheme_name: scopes}`` dict per scheme,
          in first-seen order. The OAuth scopes of every dependency that uses
          the same scheme are merged, without duplicates, in first-seen order.
    """
    scheme_models: dict[str, Any] = {}
    scopes_by_scheme: dict[str, list[str]] = {}
    for dependency in flat_dependant._security_dependencies:
        scheme = dependency._security_scheme
        encoded_model = jsonable_encoder(
            scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        scheme_name = scheme.scheme_name
        scheme_models[scheme_name] = encoded_model
        # One requirement entry per scheme; scopes from repeated uses are merged.
        merged_scopes = scopes_by_scheme.setdefault(scheme_name, [])
        for scope in dependency.oauth_scopes or []:
            if scope not in merged_scopes:
                merged_scopes.append(scope)
    requirements = [{name: scopes} for name, scopes in scopes_by_scheme.items()]
    return scheme_models, requirements
|
|
|
|
|
def _get_openapi_operation_parameters(
    *,
    dependant: Dependant,
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
    ],
    separate_input_output_schemas: bool = True,
) -> list[dict[str, Any]]:
    """Build the OpenAPI ``parameters`` entries for a dependant.

    The dependant is flattened (repeated sub-dependencies skipped), and its
    path, query, header and cookie parameters are emitted in that order.
    Parameters whose field info has a false ``include_in_schema`` are omitted.

    Header names that are not aliased are converted from ``snake_case`` to
    ``kebab-case`` unless the field disables ``convert_underscores``. When the
    only declared header parameter is annotated with a Pydantic model, that
    parameter's ``convert_underscores`` setting is the default for header
    fields that do not define their own.

    Args:
        dependant: The dependant whose parameters are documented.
        model_name_map: Passed through to schema generation.
        field_mapping: Pre-computed schemas, passed through to schema generation.
        separate_input_output_schemas: Passed through to schema generation.

    Returns:
        A list of OpenAPI Parameter Object dicts with ``name``, ``in``,
        ``required`` and ``schema``, plus ``description``,
        ``examples``/``example`` and ``deprecated`` when set.
    """
    parameters: list[dict[str, Any]] = []
    flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
    path_params = _get_flat_fields_from_params(flat_dependant.path_params)
    query_params = _get_flat_fields_from_params(flat_dependant.query_params)
    header_params = _get_flat_fields_from_params(flat_dependant.header_params)
    cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
    parameter_groups = [
        (ParamTypes.path, path_params),
        (ParamTypes.query, query_params),
        (ParamTypes.header, header_params),
        (ParamTypes.cookie, cookie_params),
    ]
    # A single header parameter annotated with a Pydantic model supplies the
    # default ``convert_underscores`` setting for the header fields.
    default_convert_underscores = True
    if len(flat_dependant.header_params) == 1:
        first_field = flat_dependant.header_params[0]
        if lenient_issubclass(first_field.field_info.annotation, BaseModel):
            default_convert_underscores = getattr(
                first_field.field_info, "convert_underscores", True
            )
    for param_type, param_group in parameter_groups:
        for param in param_group:
            field_info = param.field_info
            if not getattr(field_info, "include_in_schema", True):
                continue
            param_schema = get_schema_from_model_field(
                field=param,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            name = get_validation_alias(param)
            convert_underscores = getattr(
                param.field_info,
                "convert_underscores",
                default_convert_underscores,
            )
            # Only convert names that were not given an explicit alias.
            if (
                param_type == ParamTypes.header
                and name == param.name
                and convert_underscores
            ):
                name = param.name.replace("_", "-")

            parameter: dict[str, Any] = {
                "name": name,
                "in": param_type.value,
                "required": param.field_info.is_required(),
                "schema": param_schema,
            }
            if field_info.description:
                parameter["description"] = field_info.description
            openapi_examples = getattr(field_info, "openapi_examples", None)
            # ``Undefined`` is the "no example" sentinel compared against below.
            # Defaulting to ``None`` made every field info without an
            # ``example`` attribute emit ``"example": None``.
            example = getattr(field_info, "example", Undefined)
            if openapi_examples:
                parameter["examples"] = jsonable_encoder(openapi_examples)
            elif example != Undefined:
                parameter["example"] = jsonable_encoder(example)
            if getattr(field_info, "deprecated", None):
                parameter["deprecated"] = True
            parameters.append(parameter)
    return parameters
|
|
|
|
|
def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
    ],
    separate_input_output_schemas: bool = True,
) -> Optional[dict[str, Any]]:
    """Build the OpenAPI ``requestBody`` object for a route's body field.

    Returns ``None`` when ``body_field`` is falsy. Otherwise returns a dict
    that holds the ``required`` flag (only when the body is required) and a
    ``content`` mapping. That mapping goes from the body's media type to its
    schema plus, when present, the body's ``openapi_examples`` (as
    ``examples``, preferred) or its single ``example``.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema = get_schema_from_model_field(
        field=body_field,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    body_info = cast(Body, body_field.field_info)
    media_type = body_info.media_type
    is_required = body_field.field_info.is_required()

    request_body: dict[str, Any] = {"required": is_required} if is_required else {}

    media_object: dict[str, Any] = {"schema": schema}
    if body_info.openapi_examples:
        media_object["examples"] = jsonable_encoder(body_info.openapi_examples)
    elif body_info.example != Undefined:
        media_object["example"] = jsonable_encoder(body_info.example)

    request_body["content"] = {media_type: media_object}
    return request_body
|
|
|
|
|
def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:
    """Return the operation ID for ``route`` and ``method`` (deprecated).

    Every call emits a ``FastAPIDeprecationWarning``. Returns
    ``route.operation_id`` when it is set (truthy). Otherwise the ID is derived
    from the route name, the route's ``path_format`` and ``method``.
    """
    warnings.warn(
        message=(
            "fastapi.openapi.utils.generate_operation_id() was deprecated, "
            "it is not used internally, and will be removed soon"
        ),
        category=FastAPIDeprecationWarning,
        stacklevel=2,
    )
    explicit_id = route.operation_id
    if explicit_id:
        return explicit_id
    return generate_operation_id_for_path(
        name=route.name, path=route.path_format, method=method
    )
|
|
|
|
|
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the OpenAPI ``summary`` for ``route``.

    A non-empty ``route.summary`` is returned as-is. Otherwise the route name is
    title-cased with underscores replaced by spaces, e.g. ``read_items`` becomes
    ``Read Items``. ``method`` is accepted but not used.
    """
    return route.summary or route.name.replace("_", " ").title()
|
|
|
|
|
def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: set[str]
) -> dict[str, Any]:
    """Build the descriptive part of an OpenAPI operation for ``route``.

    The result has ``tags`` and ``description`` when the route sets them, plus
    ``summary``, ``operationId``, and ``deprecated`` when the route is
    deprecated. The operation ID is ``route.operation_id``, falling back to
    ``route.unique_id``.

    ``operation_ids`` is updated in place. If the ID was already present, a
    warning naming the endpoint function (and its source file, if known) is
    emitted; the ID is still used.
    """
    operation: dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        operation["description"] = route.description

    op_id = route.operation_id or route.unique_id
    if op_id in operation_ids:
        endpoint_name = route.endpoint.__name__
        source_file = getattr(route.endpoint, "__globals__", {}).get("__file__")
        location = f" at {source_file}" if source_file else ""
        warnings.warn(
            f"Duplicate Operation ID {op_id} for function "
            f"{endpoint_name}{location}",
            stacklevel=1,
        )
    operation_ids.add(op_id)
    operation["operationId"] = op_id

    if route.deprecated:
        operation["deprecated"] = route.deprecated
    return operation
|
|
|
|
|
def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: set[str],
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
    ],
    separate_input_output_schemas: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Build the OpenAPI path-item operations for a single ``APIRoute``.

    One operation is generated per method in ``route.methods``, keyed by the
    lower-cased method name. A route with a false ``include_in_schema`` yields
    an empty path item.

    Args:
        route: The route to document.
        operation_ids: Operation IDs already emitted. Updated in place and used
            to warn about duplicates.
        model_name_map: Passed through to schema generation.
        field_mapping: Pre-computed schemas, passed through to schema generation.
        separate_input_output_schemas: Passed through to schema generation.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple:

        * ``path``: the path item, keyed by lower-cased method.
        * ``security_schemes``: the security schemes its operations and their
          callbacks reference.
        * ``definitions``: extra component schemas (the validation-error
          schemas) that those operations reference.
    """
    path = {}
    security_schemes: dict[str, Any] = {}
    definitions: dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            parameters: list[dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            operation_parameters = _get_openapi_operation_parameters(
                dependant=route.dependant,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            parameters.extend(operation_parameters)
            if parameters:
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # When the same (location, name) pair is declared more than once,
                # a required declaration takes precedence over an optional one.
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                        # Callback operations can reference security schemes and
                        # the validation-error schemas. Propagate them; otherwise
                        # those references dangle when no other route brings
                        # them in.
                        security_schemes.update(cb_security_schemes)
                        definitions.update(cb_definitions)
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # No explicit status code on the route: use the default of the
                # ``status_code`` parameter of the response class's ``__init__``.
                # NOTE(review): if that default is missing or not an int,
                # ``status_code`` stays unbound and the next statement raises.
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                # Non-JSON responses are documented as plain strings; JSON
                # responses use the response model schema, or an empty schema
                # when there is no response model.
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema = get_schema_from_model_field(
                            field=route.response_field,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    # Work on a deep copy so the user's ``responses`` mapping is
                    # never mutated by the ``setdefault`` calls below.
                    process_response = copy.deepcopy(additional_response)
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(process_response, dict), (
                        "An additional response must be a dict"
                    )
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[dict[str, Any]] = None
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = "422"
            all_route_params = get_flat_params(route.dependant)
            # Add the default 422 response only when the route validates input
            # and the user did not already document 422, 4XX or a default.
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions
|
|
|
|
|
def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> list[ModelField]:
    """Collect every model field used by the schema-visible API routes.

    Only routes that have a truthy ``include_in_schema`` and are
    ``routing.APIRoute`` instances contribute. Callbacks are processed
    recursively.

    Returns:
        One list in this order: callback fields, request body fields, response
        fields (main response, then additional responses), and finally the
        flattened dependency parameter fields.
    """
    callback_fields: list[ModelField] = []
    body_fields: list[ModelField] = []
    response_fields: list[ModelField] = []
    param_fields: list[ModelField] = []
    for route in routes:
        included = getattr(route, "include_in_schema", None)
        if not (included and isinstance(route, routing.APIRoute)):
            continue
        body = route.body_field
        if body:
            assert isinstance(body, ModelField), (
                "A request body must be a Pydantic Field"
            )
            body_fields.append(body)
        if route.response_field:
            response_fields.append(route.response_field)
        if route.response_fields:
            response_fields.extend(route.response_fields.values())
        if route.callbacks:
            callback_fields.extend(get_fields_from_routes(route.callbacks))
        param_fields.extend(get_flat_params(route.dependant))
    return [*callback_fields, *body_fields, *response_fields, *param_fields]
|
|
|
|
|
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: Optional[str] = None,
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    webhooks: Optional[Sequence[BaseRoute]] = None,
    tags: Optional[list[dict[str, Any]]] = None,
    servers: Optional[list[dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[dict[str, Union[str, Any]]] = None,
    license_info: Optional[dict[str, Union[str, Any]]] = None,
    separate_input_output_schemas: bool = True,
    external_docs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Generate the OpenAPI document for ``routes`` and ``webhooks``.

    Only ``routing.APIRoute`` instances are documented. Regular routes go under
    ``paths`` and webhooks under ``webhooks``. Optional metadata (summary,
    description, terms of service, contact, license, servers, tags, external
    docs) is included only when truthy. Component schemas are sorted by name.

    Returns:
        The document as JSON-compatible data: built through the ``OpenAPI``
        model, then encoded by alias with ``None`` values removed.
    """
    info: dict[str, Any] = {"title": title, "version": version}
    optional_info_fields = (
        ("summary", summary),
        ("description", description),
        ("termsOfService", terms_of_service),
        ("contact", contact),
        ("license", license_info),
    )
    info.update((key, value) for key, value in optional_info_fields if value)

    output: dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    components: dict[str, dict[str, Any]] = {}
    paths: dict[str, dict[str, Any]] = {}
    webhook_paths: dict[str, dict[str, Any]] = {}
    operation_ids: set[str] = set()

    # Schemas are generated once, up front, for every field used by either
    # regular routes or webhooks.
    all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
    model_name_map = get_model_name_map(
        get_flat_models_from_fields(all_fields, known_models=set())
    )
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )

    # Routes are processed before webhooks so that operation-ID bookkeeping
    # (and its duplicate warnings) happens in the same order as before.
    for source_routes, target_paths in (
        (routes or [], paths),
        (webhooks or [], webhook_paths),
    ):
        for api_route in source_routes:
            if not isinstance(api_route, routing.APIRoute):
                continue
            path_item, security_schemes, path_definitions = get_openapi_path(
                route=api_route,
                operation_ids=operation_ids,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if path_item:
                target_paths.setdefault(api_route.path_format, {}).update(path_item)
            if security_schemes:
                components.setdefault("securitySchemes", {}).update(security_schemes)
            if path_definitions:
                definitions.update(path_definitions)

    if definitions:
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    if external_docs:
        output["externalDocs"] = external_docs
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
|
|
|