Spaces:
Running
Running
| import functools | |
| from collections.abc import Coroutine | |
| from types import TracebackType | |
| from typing import ( | |
| Any, | |
| Callable, | |
| ClassVar, | |
| Optional, | |
| TypeVar, | |
| Union, | |
| ) | |
| from httpx import ( | |
| URL, | |
| AsyncClient, | |
| Cookies, | |
| HTTPError, | |
| HTTPStatusError, | |
| Request, | |
| Response, | |
| ResponseNotRead, | |
| TransportError, | |
| ) | |
| from .decorators import Retry, TimeIt | |
| from .exceptions import UpstreamAPIException | |
| from .log import logger | |
# Type variable bound to callables that return a Coroutine; lets a decorator
# declare that it returns the same callable type it was given.
AsyncCallable_T = TypeVar("AsyncCallable_T", bound=Callable[..., Coroutine])
class AsyncHTTPClient(AsyncClient):
    """``httpx.AsyncClient`` subclass that debug-logs each request and response.

    The two logging coroutines are installed as the client's ``event_hooks``
    every time :meth:`request` is called.
    """

    # Back-reference to the owning BaseNetClient; BaseNetClient.create_client()
    # assigns it right after constructing the client.
    net_client: "BaseNetClient"

    # The hooks are looked up through ``self`` in request() and httpx invokes
    # each hook with exactly one argument, so they must be static: as plain
    # functions they would be bound methods and receive an extra ``self``,
    # failing with TypeError on the first request.
    @staticmethod
    async def _log_request(request: Request) -> None:
        """Log the HTTP method and URL of an outgoing request."""
        method, url = request.method, request.url
        logger.debug(
            f"Network request <y>sent</y>: <b><e>{method}</e> <u>{url}</u></b>"
        )

    @staticmethod
    async def _log_response(response: Response) -> None:
        """Log method, URL, status code and body length of a response.

        The length is reported as ``-1`` when the body has not been read yet
        (accessing ``response.content`` raises ``ResponseNotRead``).
        """
        method, url = response.request.method, response.url
        try:
            length, code = len(response.content), response.status_code
        except ResponseNotRead:
            length, code = -1, response.status_code
        logger.debug(
            f"Network request <g>finished</g>: <b><e>{method}</e> "
            f"<u>{url}</u> <m>{code}</m></b> <m>{length}</m>"
        )

    async def request(self, method: str, url: Union[URL, str], **kwargs):
        """Send a request with the logging hooks installed.

        Args:
            method: HTTP method name.
            url: Target URL.
            **kwargs: Forwarded unchanged to ``AsyncClient.request``.

        Returns:
            Whatever ``AsyncClient.request`` returns for this call.
        """
        # NOTE: this replaces any previously configured event hooks on the
        # client with just the two logging hooks.
        self.event_hooks = {
            "request": [self._log_request],
            "response": [self._log_response],
        }
        return await super().request(method, url, **kwargs)
class BaseNetClient:
    """Holder of an ``AsyncHTTPClient`` that can be used with ``async with``.

    A client is built eagerly in ``__init__`` and rebuilt on entry whenever
    the current one is missing or closed. Leaving the ``async with`` block
    cleanly keeps the client open; leaving it with an exception closes it.
    """

    # Incremented in __aenter__ and decremented in __aexit__, on the
    # instance's own class.
    connections: ClassVar[int] = 0
    # Every client built by create_client() is appended here; nothing in this
    # class removes entries.
    clients: ClassVar[list[AsyncHTTPClient]] = []
    # The client currently owned by this instance, if any.
    client: Optional[AsyncHTTPClient] = None

    def __init__(
        self,
        headers: Optional[dict[str, Any]] = None,
        cookies: Optional[Cookies] = None,
        proxies: Optional[dict[str, str]] = None,
        client_class: type[AsyncHTTPClient] = AsyncHTTPClient,
    ):
        """Store the connection settings and build the first client.

        Args:
            headers: Default headers; an empty dict when omitted or falsy.
            cookies: Cookie jar; a fresh ``Cookies()`` when omitted or falsy.
            proxies: Proxy mapping; an empty dict when omitted or falsy.
            client_class: Class instantiated by :meth:`create_client`.
        """
        self.cookies = cookies or Cookies()
        self.client_class = client_class
        self.headers: dict[str, Any] = headers or {}
        self.proxies: Any = proxies or {}  # Bypass type checker
        self.create_client()

    def create_client(self):
        """Build a new client from the stored settings and register it.

        The new client becomes ``self.client``, gets a back-reference to this
        instance as ``net_client`` and is appended to ``BaseNetClient.clients``.

        Returns:
            The newly created client.
        """
        fresh = self.client_class(
            headers=self.headers,
            proxies=self.proxies,
            cookies=self.cookies,
            http2=True,
            follow_redirects=True,
        )
        self.client = fresh
        fresh.net_client = self
        BaseNetClient.clients.append(fresh)
        return fresh

    async def __aenter__(self):
        """Return a usable client, creating and entering a new one if needed."""
        current = self.client
        if not current or current.is_closed:
            replacement = self.create_client()
            self.client = await replacement.__aenter__()
        owner = self.__class__
        owner.connections += 1
        return self.client

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ):
        """Release one connection slot; close the client only after an error."""
        owner = self.__class__
        owner.connections -= 1

        # Without full exception info the client is left open for reuse.
        if not exc_type or not exc_value or not traceback:
            return

        failed = self.client
        if not failed or failed.is_closed:
            return

        # Detach before closing so the next __aenter__ builds a new client.
        self.client = None
        await failed.__aexit__(exc_type, exc_value, traceback)
        return
def catch_network_error(function: AsyncCallable_T) -> AsyncCallable_T:
    """Decorate an async callable so httpx errors become UpstreamAPIException.

    The callable is first wrapped with ``TimeIt``; the returned wrapper awaits
    that timed callable and translates httpx failures.

    Args:
        function: The async callable to protect.

    Returns:
        An async wrapper carrying the name, docstring and signature metadata
        of ``function``.

    Raises:
        UpstreamAPIException: Raised by the wrapper. For ``HTTPStatusError``
            it carries the upstream response text as ``detail``; for any
            other ``HTTPError`` it is raised without detail. The original
            exception is chained as the cause in both cases.
    """
    timed_func = TimeIt(function)

    # NOTE(review): ``Retry`` and ``TransportError`` are imported by this
    # module but are not used in the code visible here - confirm whether a
    # retry on transport errors was meant to be applied to this wrapper.

    # Preserve the decorated callable's metadata (``__name__``, ``__doc__``,
    # ``__wrapped__``) so introspection does not see a bare
    # ``wrapper(*args, **kwargs)``.
    @functools.wraps(function)
    async def wrapper(*args, **kwargs):
        try:
            return await timed_func(*args, **kwargs)
        except HTTPStatusError as e:
            # HTTPStatusError is a subclass of HTTPError in httpx, so it has
            # to be matched first to attach the upstream response body.
            raise UpstreamAPIException(detail=e.response.text) from e
        except HTTPError as e:
            raise UpstreamAPIException from e

    return wrapper  # type:ignore