# File size: 5,321 Bytes
# bfed822
from __future__ import annotations as _annotations
import functools
import inspect
from collections.abc import Awaitable
from functools import partial
from typing import Any, Callable
import pydantic_core
from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from ._config import ConfigWrapper
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
    """Return a display name for a `ValidateCallSupportedTypes` object.

    For a `functools.partial` object the name of the function it wraps is used,
    formatted as `partial(<name>)`. Any other object must expose `__name__`,
    which is returned unchanged.
    """
    if not isinstance(func, functools.partial):
        return func.__name__

    inner_name = func.func.__name__
    return f'partial({inner_name})'
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
    """Return a qualified display name for a `ValidateCallSupportedTypes` object.

    For a `functools.partial` object the qualname of the function it wraps is used,
    formatted as `partial(<qualname>)`. Any other object must expose `__qualname__`,
    which is returned unchanged.
    """
    if not isinstance(func, functools.partial):
        return func.__qualname__

    inner_qualname = func.func.__qualname__
    return f'partial({inner_qualname})'
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
    """Create a function that forwards to `wrapper` while carrying the metadata of `wrapped`.

    Args:
        wrapped: The original callable whose attributes should be exposed.
        wrapper: The callable that is actually invoked with the call arguments.

    Returns:
        A new function (a coroutine function when `wrapped` is one) that delegates every
        call to `wrapper`. Its `__name__` and `__qualname__` come from
        `extract_function_name` / `extract_function_qualname`, and `wrapped` itself is
        stored on it as the `raw_function` attribute.
    """
    copy_metadata = functools.wraps(wrapped)

    if inspect.iscoroutinefunction(wrapped):
        # Keep the result a coroutine function so callers can still `await` it.
        @copy_metadata
        async def wrapper_function(*args, **kwargs):  # type: ignore
            return await wrapper(*args, **kwargs)

    else:

        @copy_metadata
        def wrapper_function(*args, **kwargs):
            return wrapper(*args, **kwargs)

    # `functools.wraps` cannot supply these for a `partial` object, which has neither
    # `__name__` nor `__qualname__`, so they are always set explicitly.
    wrapper_function.__name__ = extract_function_name(wrapped)
    wrapper_function.__qualname__ = extract_function_qualname(wrapped)

    # Expose the undecorated callable to users of the wrapper.
    wrapper_function.raw_function = wrapped  # type: ignore
    return wrapper_function
class ValidateCallWrapper:
    """Callable wrapper that validates the arguments of a function and, optionally, its return value.

    The validators are built in `__init__`, unless the config's `defer_build` is set,
    in which case they are built on the first call.
    """

    __slots__ = (
        # The callable handed to the wrapper (may be a `functools.partial`).
        'function',
        # Whether the return value is validated as well.
        'validate_return',
        # `function`, or the function wrapped by it when `function` is a partial.
        'schema_type',
        # `__module__` of `schema_type`.
        'module',
        # Result of `extract_function_qualname(function)`.
        'qualname',
        'ns_resolver',
        'config_wrapper',
        # True once `_create_validators` has run to completion.
        '__pydantic_complete__',
        # Validator applied to the call arguments.
        '__pydantic_validator__',
        # Callable applied to the call result, or None when return validation is off.
        '__return_pydantic_validator__',
    )

    def __init__(
        self,
        function: ValidateCallSupportedTypes,
        config: ConfigDict | None,
        validate_return: bool,
        parent_namespace: MappingNamespace | None,
    ) -> None:
        """Store the call target and its configuration, building validators unless deferred.

        Args:
            function: The callable to wrap; may be a `functools.partial`.
            config: Configuration handed to `ConfigWrapper`, or None.
            validate_return: Whether the return value should be validated too.
            parent_namespace: Namespace forwarded to `ns_for_function`.
        """
        self.function = function
        self.validate_return = validate_return

        # For a partial, the underlying function supplies the schema type and module.
        schema_target = function.func if isinstance(function, partial) else function
        self.schema_type = schema_target
        self.module = schema_target.__module__
        self.qualname = extract_function_qualname(function)

        function_namespaces = ns_for_function(schema_target, parent_namespace=parent_namespace)
        self.ns_resolver = NsResolver(namespaces_tuple=function_namespaces)
        self.config_wrapper = ConfigWrapper(config)

        if self.config_wrapper.defer_build:
            # Validators are created lazily by `__call__`.
            self.__pydantic_complete__ = False
        else:
            self._create_validators()

    def _create_validators(self) -> None:
        """Build the argument validator and, if requested, the return-value validator.

        Sets `__pydantic_validator__` and `__return_pydantic_validator__`, then marks the
        wrapper complete by setting `__pydantic_complete__` to True.
        """
        core_config = self.config_wrapper.core_config(title=self.qualname)

        def build_validator(core_schema):
            # Both validators share every argument except the schema itself.
            return create_schema_validator(
                core_schema,
                self.schema_type,
                self.module,
                self.qualname,
                'validate_call',
                core_config,
                self.config_wrapper.plugin_settings,
            )

        args_generator = GenerateSchema(self.config_wrapper, self.ns_resolver)
        args_schema = args_generator.clean_schema(args_generator.generate_schema(self.function))
        self.__pydantic_validator__ = build_validator(args_schema)

        return_handler = None
        if self.validate_return:
            declared_return = inspect.signature(self.function).return_annotation
            # A missing return annotation is treated as `Any`.
            return_type = Any if declared_return is inspect.Signature.empty else declared_return

            # A fresh generator is used for the return type, as for the arguments.
            return_generator = GenerateSchema(self.config_wrapper, self.ns_resolver)
            return_schema = return_generator.clean_schema(return_generator.generate_schema(return_type))
            validator = build_validator(return_schema)

            if inspect.iscoroutinefunction(self.function):
                # The call result is an awaitable here, so it is awaited before validation.
                # NOTE(review): annotated `-> None` although the validated value is returned.
                async def return_val_wrapper(aw: Awaitable[Any]) -> None:
                    return validator.validate_python(await aw)

                return_handler = return_val_wrapper
            else:
                return_handler = validator.validate_python

        self.__return_pydantic_validator__ = return_handler
        self.__pydantic_complete__ = True

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Validate the call arguments and return the (optionally validated) result.

        Builds the validators first if that was deferred at construction time.
        """
        if not self.__pydantic_complete__:
            self._create_validators()

        call_arguments = pydantic_core.ArgsKwargs(args, kwargs)
        outcome = self.__pydantic_validator__.validate_python(call_arguments)

        return_handler = self.__return_pydantic_validator__
        if return_handler:
            return return_handler(outcome)
        return outcome
|