| import functools |
| from collections.abc import Coroutine |
| from types import TracebackType |
| from typing import ( |
| Any, |
| Callable, |
| ClassVar, |
| Optional, |
| TypeVar, |
| Union, |
| ) |
|
|
| from httpx import ( |
| URL, |
| AsyncClient, |
| Cookies, |
| HTTPError, |
| HTTPStatusError, |
| Request, |
| Response, |
| ResponseNotRead, |
| TransportError, |
| ) |
|
|
| from .decorators import Retry, TimeIt |
| from .exceptions import UpstreamAPIException |
| from .log import logger |
|
|
# Type variable for "an async callable": any callable whose call produces a
# coroutine. Decorators such as ``catch_network_error`` are annotated with it
# so type checkers see the decorated function keep its original type.
AsyncCallable_T = TypeVar(
    "AsyncCallable_T",
    bound=Callable[..., Coroutine],
)
|
|
|
|
class AsyncHTTPClient(AsyncClient):
    """``httpx.AsyncClient`` that logs traffic and retries transport errors.

    Debug-level log lines are emitted for every outgoing request and every
    received response via httpx event hooks. ``request`` is wrapped with
    ``Retry(exceptions=[TransportError])``.
    """

    # Back-reference to the owning net client; assigned by
    # ``BaseNetClient.create_client`` right after construction.
    net_client: "BaseNetClient"

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Install the logging hooks once, ahead of any hooks the caller passed
        # via ``event_hooks=``. They used to be re-assigned on every request,
        # which silently threw away caller-provided hooks.
        existing = self.event_hooks
        self.event_hooks = {
            "request": [self._log_request, *existing.get("request", [])],
            "response": [self._log_response, *existing.get("response", [])],
        }

    @staticmethod
    async def _log_request(request: Request) -> None:
        """Event hook: log the method and URL of an outgoing request."""
        method, url = request.method, request.url
        logger.debug(
            f"Network request <y>sent</y>: <b><e>{method}</e> <u>{url}</u></b>"
        )

    @staticmethod
    async def _log_response(response: Response) -> None:
        """Event hook: log method, URL, status code and body length.

        The hook may run before the body has been read. In that case
        ``response.content`` raises ``ResponseNotRead`` and the length is
        logged as ``-1``.
        """
        method, url = response.request.method, response.url
        code = response.status_code
        try:
            length = len(response.content)
        except ResponseNotRead:
            length = -1
        logger.debug(
            f"Network request <g>finished</g>: <b><e>{method}</e> "
            f"<u>{url}</u> <m>{code}</m></b> <m>{length}</m>"
        )

    @Retry(exceptions=[TransportError])
    async def request(self, method: str, url: Union[URL, str], **kwargs):
        """Send a request, retrying when it fails with ``TransportError``.

        The retry policy (attempts, delays) is whatever ``Retry`` from
        ``.decorators`` implements.
        """
        return await super().request(method, url, **kwargs)
|
|
|
|
class BaseNetClient:
    """Owner of a reusable ``AsyncHTTPClient``.

    Use as ``async with net_client as client: ...``. On a normal exit the
    underlying client is kept open for reuse. If the ``async with`` body
    raises, the client is closed and detached. A fresh one is created on the
    next entry.
    """

    # Count of currently active ``async with`` blocks. It is tracked on the
    # class via ``self.__class__``, not per instance.
    connections: ClassVar[int] = 0
    # Registry of clients created through ``create_client``. Closed clients
    # are pruned each time a new one is registered, so the list does not grow
    # without bound as clients are recycled.
    clients: ClassVar[list[AsyncHTTPClient]] = []

    client: Optional[AsyncHTTPClient] = None

    def __init__(
        self,
        headers: Optional[dict[str, Any]] = None,
        cookies: Optional[Cookies] = None,
        proxies: Optional[dict[str, str]] = None,
        client_class: type[AsyncHTTPClient] = AsyncHTTPClient,
    ):
        """Store client settings and eagerly create the first client.

        Args:
            headers: Default headers for every client created by this object.
            cookies: Cookie jar passed to each new client.
            proxies: Proxy mapping passed to each new client.
            client_class: Client type to instantiate (``AsyncHTTPClient`` or
                a subclass).
        """
        # ``is None`` rather than ``or``: a caller-supplied but *empty*
        # dict / Cookies jar is kept instead of being swapped for a new object.
        self.cookies = Cookies() if cookies is None else cookies
        self.client_class = client_class
        self.headers: dict[str, Any] = {} if headers is None else headers
        self.proxies: Any = {} if proxies is None else proxies

        self.create_client()

    def create_client(self) -> AsyncHTTPClient:
        """Build a new HTTP/2, redirect-following client, register and return it."""
        client = self.client_class(
            headers=self.headers,
            proxies=self.proxies,
            cookies=self.cookies,
            http2=True,
            follow_redirects=True,
        )
        client.net_client = self
        registry = BaseNetClient.clients
        # Drop closed clients so recycled ones don't accumulate forever.
        # The update is in place so every holder of the list sees the change.
        registry[:] = [c for c in registry if not c.is_closed]
        registry.append(client)
        self.client = client
        return client

    async def __aenter__(self) -> AsyncHTTPClient:
        """Return the current client, (re)creating and opening one if needed."""
        if self.client is None or self.client.is_closed:
            self.client = await self.create_client().__aenter__()

        self.__class__.connections += 1
        return self.client

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Leave the context; on error, close and detach the current client.

        Exceptions are never suppressed (returns ``None``).
        """
        self.__class__.connections -= 1

        if exc_type is None:
            # Normal exit: keep the client open for the next ``async with``.
            return
        client = self.client
        if client is not None and not client.is_closed:
            # Detach before awaiting the close, so a coroutine that enters
            # meanwhile creates a new client instead of getting this one.
            # NOTE(review): other coroutines may still hold this client from
            # their own ``async with``; ``connections`` is class-wide, so it
            # cannot tell whether this instance's client is still in use.
            self.client = None
            await client.__aexit__(exc_type, exc_value, traceback)
|
|
|
|
def catch_network_error(function: AsyncCallable_T) -> AsyncCallable_T:
    """Wrap an async function so httpx errors surface as ``UpstreamAPIException``.

    The function is additionally wrapped with ``TimeIt``.

    Raises:
        UpstreamAPIException: when the call raises ``httpx.HTTPStatusError``
            (with the response body as ``detail`` when the body is available)
            or any other ``httpx.HTTPError``. The original error is chained as
            ``__cause__``.
    """
    timed_func = TimeIt(function)

    # Copy metadata (name, doc, signature via __wrapped__) from the decorated
    # function itself. The TimeIt wrapper is not guaranteed to carry it.
    @functools.wraps(function)
    async def wrapper(*args, **kwargs):
        try:
            return await timed_func(*args, **kwargs)
        except HTTPStatusError as e:
            try:
                detail = e.response.text
            except ResponseNotRead:
                # Streamed responses may not have their body loaded. Fall back
                # to the detail-less form instead of leaking ResponseNotRead.
                raise UpstreamAPIException from e
            raise UpstreamAPIException(detail=detail) from e
        except HTTPError as e:
            raise UpstreamAPIException from e

    return wrapper
|
|