|
|
import inspect |
|
|
from collections.abc import Mapping |
|
|
from contextvars import ContextVar |
|
|
from enum import Enum |
|
|
from fnmatch import fnmatch |
|
|
from functools import wraps |
|
|
from typing import Annotated, Any, Callable, Literal, Optional |
|
|
from urllib.parse import ParseResult, urlparse |
|
|
|
|
|
from fastapi import Depends, Request |
|
|
from fastapi.routing import APIRouter |
|
|
from httpx import URL |
|
|
from pydantic import AnyHttpUrl |
|
|
from pydantic.errors import UrlHostError |
|
|
from starlette.datastructures import Headers, MutableHeaders |
|
|
|
|
|
from hibiapi.utils.cache import endpoint_cache |
|
|
from hibiapi.utils.net import AsyncCallable_T, AsyncHTTPClient, BaseNetClient |
|
|
|
|
|
# Name of the attribute that ``dont_route`` sets on a coroutine function;
# ``EndpointMeta`` skips any member where this attribute is truthy.
DONT_ROUTE_KEY = "_dont_route"


def dont_route(func: AsyncCallable_T) -> AsyncCallable_T:
    """Decorator: keep *func* out of the routes collected by ``EndpointMeta``.

    Sets the ``DONT_ROUTE_KEY`` attribute on *func* to ``True`` and returns
    the very same object; no wrapper function is created.
    """
    marked = func
    setattr(marked, DONT_ROUTE_KEY, True)
    return marked
|
|
|
|
|
|
|
|
class EndpointMeta(type):
    """Metaclass that discovers the routable coroutine methods of a class.

    A member is "routable" when it is a coroutine function, its name does not
    start with ``_`` and it is not marked with ``dont_route``. When a class is
    created, the routable members of its own body are wrapped with
    ``endpoint_cache`` unless the class keyword ``cache_endpoints=False`` is
    given.
    """

    @staticmethod
    def _list_router_function(members: dict[str, Any]):
        """Return a new dict with only the routable entries of *members*.

        Insertion order of *members* is preserved.
        """
        routes = {}
        for member_name, member in members.items():
            if not inspect.iscoroutinefunction(member):
                continue
            if member_name.startswith("_"):
                continue
            if getattr(member, DONT_ROUTE_KEY, False):
                continue
            routes[member_name] = member
        return routes

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        *,
        cache_endpoints: bool = True,
        **kwargs,
    ):
        # Only the class body (``namespace``) is inspected here; inherited
        # members are not re-wrapped.
        routes = cls._list_router_function(namespace)
        if cache_endpoints:
            for route_name, route in routes.items():
                namespace[route_name] = endpoint_cache(route)
        # ``cache_endpoints`` is consumed above; remaining keywords go on to
        # ``type.__new__``.
        return super().__new__(cls, name, bases, namespace, **kwargs)

    @property
    def router_functions(self):
        """Routable members of the class, including inherited ones."""
        all_members = dict(inspect.getmembers(self))
        return self._list_router_function(all_members)
|
|
|
|
|
|
|
|
class BaseEndpoint(metaclass=EndpointMeta, cache_endpoints=False):
    """Base class for endpoint groups; stores the HTTP client they use.

    Created with ``cache_endpoints=False``, so ``EndpointMeta`` does not wrap
    any member of this base class with ``endpoint_cache``.
    """

    def __init__(self, client: AsyncHTTPClient):
        self.client = client

    @staticmethod
    def _join(base: str, endpoint: str, params: dict[str, Any]) -> URL:
        """Build a request URL from *base*'s origin and an *endpoint* template.

        ``None`` values are dropped from *params* and ``Enum`` members are
        replaced by their ``.value``. The cleaned mapping both fills the
        ``str.format`` placeholders in *endpoint* and is sent as the query
        string, so a key used in the path also appears in the query.
        Only the scheme and netloc of *base* are kept; its path, params,
        query and fragment are discarded.
        """
        cleaned: dict[str, Any] = {}
        for key, value in params.items():
            if value is None:
                continue
            cleaned[key] = value.value if isinstance(value, Enum) else value

        target = urlparse(base)._replace(
            path=endpoint.format(**cleaned),
            params="",
            query="",
            fragment="",
        )
        return URL(url=target.geturl(), params=cleaned)
|
|
|
|
|
|
|
|
class SlashRouter(APIRouter):
    """``APIRouter`` whose ``api_route`` accepts paths without a leading slash."""

    def api_route(self, path: str, **kwargs):
        """Register a route, prefixing *path* with ``/`` when it has none."""
        if not path.startswith("/"):
            path = "/" + path
        return super().api_route(path, **kwargs)
|
|
|
|
|
|
|
|
class EndpointRouter(SlashRouter):
    """Router that exposes a ``BaseEndpoint`` subclass's methods as GET routes.

    ``include_endpoint`` registers one ``GET /<method>`` route per router
    function of the endpoint class and, optionally, a deprecated ``GET /``
    route that selects the function through a ``type`` query parameter.
    """

    @staticmethod
    def _exclude_params(func: Callable, params: Mapping[str, Any]) -> dict[str, Any]:
        """Return a new dict with the items of *params* that name a parameter of *func*."""
        func_params = inspect.signature(func).parameters
        return {k: v for k, v in params.items() if k in func_params}

    @staticmethod
    def _positional_target(func: Callable) -> Optional[str]:
        """Return the name bound by *func*'s first positional argument, or ``None``.

        Positional parameters always come first in a signature, so only the
        first parameter has to be examined.
        """
        first = next(iter(inspect.signature(func).parameters.values()), None)
        if first is not None and first.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            return first.name
        return None

    @staticmethod
    def _router_signature_convert(
        func,
        endpoint_class: type["BaseEndpoint"],
        request_client: Callable,
        method_name: Optional[str] = None,
    ):
        """Wrap the endpoint method *func* in a route function for FastAPI.

        The returned coroutine function receives an endpoint instance through
        ``Depends(request_client)`` and awaits the attribute *method_name*
        (default: ``func.__name__``) of that instance with the remaining
        keyword arguments. Its ``__signature__`` is rewritten to an
        ``endpoint`` parameter followed by *func*'s keyword-only parameters
        only; any other parameter of *func* (such as ``self``) is hidden.
        """

        @wraps(func)
        async def route_func(endpoint: endpoint_class, **kwargs):
            endpoint_method = getattr(endpoint, method_name or func.__name__)
            return await endpoint_method(**kwargs)

        # ``inspect.signature`` follows ``__wrapped__`` (set by ``wraps``), so
        # this starts from *func*'s signature and replaces its parameters.
        route_func.__signature__ = inspect.signature(route_func).replace(
            parameters=[
                inspect.Parameter(
                    name="endpoint",
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=endpoint_class,
                    default=Depends(request_client),
                ),
                *(
                    param
                    for param in inspect.signature(func).parameters.values()
                    if param.kind == inspect.Parameter.KEYWORD_ONLY
                ),
            ]
        )
        return route_func

    def include_endpoint(
        self,
        endpoint_class: type[BaseEndpoint],
        net_client: BaseNetClient,
        add_match_all: bool = True,
    ):
        """Register *endpoint_class*'s router functions on this router.

        Args:
            endpoint_class: class whose ``router_functions`` become
                ``GET /<name>`` routes.
            net_client: entered with ``async with`` once per request; the
                value it yields is passed to ``endpoint_class(...)``.
            add_match_all: also register the deprecated ``GET /`` route that
                picks the function named by the ``type`` query parameter.
        """
        router_functions = endpoint_class.router_functions

        async def request_client():
            # Dependency that yields one endpoint instance bound to a client
            # opened via ``async with net_client``; the client context is
            # exited when the dependency generator is finalized.
            async with net_client as client:
                yield endpoint_class(client)

        for func_name, func in router_functions.items():
            self.add_api_route(
                path=f"/{func_name}",
                endpoint=self._router_signature_convert(
                    func,
                    endpoint_class=endpoint_class,
                    request_client=request_client,
                    method_name=func_name,
                ),
                methods=["GET"],
            )

        if not add_match_all:
            return

        @self.get("/", description="JournalAD style API routing", deprecated=True)
        async def match_all(
            endpoint: Annotated[endpoint_class, Depends(request_client)],
            request: Request,
            # Shadows the builtin, but ``type`` is the public query-parameter name.
            type: Literal[tuple(router_functions.keys())],
        ):
            func = router_functions[type]
            # Query values are forwarded as received (no type conversion).
            params = self._exclude_params(func, request.query_params)
            # ``endpoint`` is passed positionally and binds to func's first
            # parameter (normally ``self``). A query parameter with that same
            # name (e.g. ``?type=...&self=1``) would otherwise make the call
            # fail with "got multiple values for argument" and surface as a
            # server error, so it is dropped here.
            bound_name = self._positional_target(func)
            if bound_name is not None:
                params.pop(bound_name, None)
            return await func(endpoint, **params)
|
|
|
|
|
|
|
|
class BaseHostUrl(AnyHttpUrl):
    """``AnyHttpUrl`` that also restricts the host to an allow-list.

    Subclasses set ``allowed_hosts`` to ``fnmatch``-style patterns such as
    ``"*.example.com"``. A URL whose host matches none of them is rejected
    with ``UrlHostError``.
    """

    # fnmatch patterns checked against the host; an empty list rejects every host.
    allowed_hosts: list[str] = []

    @classmethod
    def validate_host(cls, parts) -> tuple[str, Optional[str], str, bool]:
        """Run the parent host validation, then enforce ``allowed_hosts``.

        Returns:
            The parent's ``(host, tld, host_type, rebuild)`` tuple unchanged.

        Raises:
            UrlHostError: the host matches no pattern in ``allowed_hosts``.
        """
        host, tld, host_type, rebuild = super().validate_host(parts)
        if not cls._check_domain(host):
            raise UrlHostError(allowed=cls.allowed_hosts)
        return host, tld, host_type, rebuild

    @classmethod
    def _check_domain(cls, host: str) -> bool:
        """Return whether *host* matches at least one ``allowed_hosts`` pattern.

        Host names are case-insensitive, but ``fnmatch`` is case-sensitive on
        POSIX (and case-insensitive on Windows). Host and patterns are
        therefore lower-cased so matching is case-insensitive on every
        platform. A missing or empty host never matches.
        """
        if not host:
            return False
        normalized = host.lower()
        return any(
            fnmatch(normalized, pattern.lower()) for pattern in cls.allowed_hosts
        )
|
|
|
|
|
|
|
|
# Context variables for the headers of the current request and of the
# response being built. They are only declared here (no default value).
request_headers: ContextVar[Headers] = ContextVar("request_headers")
response_headers: ContextVar[MutableHeaders] = ContextVar("response_headers")
|
|
|