Spaces:
Paused
Paused
| from __future__ import annotations | |
| import os | |
| import inspect | |
| import logging | |
| import datetime | |
| import functools | |
| from types import TracebackType | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Union, | |
| Generic, | |
| TypeVar, | |
| Callable, | |
| Iterator, | |
| AsyncIterator, | |
| cast, | |
| overload, | |
| ) | |
| from typing_extensions import Awaitable, ParamSpec, override, get_origin | |
| import anyio | |
| import httpx | |
| import pydantic | |
| from ._types import NoneType | |
| from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base | |
| from ._models import BaseModel, is_basemodel | |
| from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER | |
| from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type | |
| from ._exceptions import AnthropicError, APIResponseValidationError | |
| if TYPE_CHECKING: | |
| from ._models import FinalRequestOptions | |
| from ._base_client import BaseClient | |
# Type variables shared by the response classes and wrapper helpers below.
P = ParamSpec("P")  # parameter specification of a wrapped API method
R = TypeVar("R")  # parsed response type / return type of a wrapped API method
_T = TypeVar("_T")  # type requested through the `to` argument of `parse()`
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")

# Module-level logger named after this module.
log: logging.Logger = logging.getLogger(__name__)
class BaseAPIResponse(Generic[R]):
    """State and parsing logic shared by the sync and async response wrappers.

    Holds the raw `httpx.Response` together with everything `_parse()` needs
    to convert it: the target type, the client, whether the response is a
    stream, the stream class and the request options.
    """

    _cast_to: type[R]
    _client: BaseClient[Any, Any]
    _parsed_by_type: dict[type[Any], Any]
    _is_sse_stream: bool
    _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
    _options: FinalRequestOptions

    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw: httpx.Response,
        cast_to: type[R],
        client: BaseClient[Any, Any],
        stream: bool,
        stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
        options: FinalRequestOptions,
    ) -> None:
        self._cast_to = cast_to
        self._client = client
        # cache of already parsed results, keyed by the type they were parsed into
        self._parsed_by_type = {}
        self._is_sse_stream = stream
        self._stream_cls = stream_cls
        self._options = options
        self.http_response = raw

    # NOTE: the accessors below are properties. `method` and `__repr__` read
    # `self.http_request.method` / `self.status_code` as plain attributes,
    # which only works when these are not regular methods.

    @property
    def headers(self) -> httpx.Headers:
        """The headers of the underlying HTTP response."""
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        """Returns the httpx Request instance associated with the current response."""
        return self.http_response.request

    @property
    def status_code(self) -> int:
        """The status code of the underlying HTTP response."""
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        """Returns the URL for which the request was made."""
        return self.http_response.url

    @property
    def method(self) -> str:
        """The HTTP method of the request that produced this response."""
        return self.http_request.method

    @property
    def http_version(self) -> str:
        """The HTTP version reported by the underlying response."""
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        """The time taken for the complete request/response cycle to complete."""
        return self.http_response.elapsed

    @property
    def is_closed(self) -> bool:
        """Whether or not the response body has been closed.

        If this is False then there is response data that has not been read yet.
        You must either fully consume the response body or call `.close()`
        before discarding the response to prevent resource leaks.
        """
        return self.http_response.is_closed

    @override
    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
        )

    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Convert the HTTP response into `to`, or into `cast_to` when `to` is not given.

        For streamed responses a stream object is constructed (from `to`, the
        configured `stream_cls`, or the client's default stream class, in that
        order). Otherwise the body is converted according to the target type.

        Raises:
            TypeError: `to` is not a stream class for a streamed response, or the
                target is a pydantic model that does not subclass our `BaseModel`.
            MissingStreamClassError: the response is streamed but no stream class is available.
            ValueError: the target is a subclass of `httpx.Response`.
            RuntimeError: the target is `APIResponse` or another unsupported type.
            APIResponseValidationError: the content type is not JSON and the client
                uses strict response validation.
        """
        # unwrap `Annotated[T, ...]` -> `T`
        if to and is_annotated_type(to):
            to = extract_type_arg(to, 0)

        if self._is_sse_stream:
            if to:
                if not is_stream_class_type(to):
                    raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")

                return cast(
                    _T,
                    to(
                        cast_to=extract_stream_chunk_type(
                            to,
                            failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
                        ),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            if self._stream_cls:
                return cast(
                    R,
                    self._stream_cls(
                        cast_to=extract_stream_chunk_type(self._stream_cls),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
            if stream_cls is None:
                raise MissingStreamClassError()

            return cast(
                R,
                stream_cls(
                    cast_to=self._cast_to,
                    response=self.http_response,
                    client=cast(Any, self._client),
                ),
            )

        cast_to = to if to is not None else self._cast_to

        # unwrap `Annotated[T, ...]` -> `T`
        if is_annotated_type(cast_to):
            cast_to = extract_type_arg(cast_to, 0)

        if cast_to is NoneType:
            return cast(R, None)

        response = self.http_response
        if cast_to == str:
            return cast(R, response.text)

        if cast_to == bytes:
            return cast(R, response.content)

        if cast_to == int:
            return cast(R, int(response.text))

        if cast_to == float:
            return cast(R, float(response.text))

        origin = get_origin(cast_to) or cast_to

        # handle the legacy binary response case
        if inspect.isclass(cast_to) and cast_to.__name__ == "HttpxBinaryResponseContent":
            return cast(R, cast_to(response))  # type: ignore

        if origin == APIResponse:
            raise RuntimeError("Unexpected state - cast_to is `APIResponse`")

        if inspect.isclass(origin) and issubclass(origin, httpx.Response):
            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
            # and pass that class to our request functions. We cannot change the variance to be either
            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
            # the response class ourselves but that is something that should be supported directly in httpx
            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
            if cast_to != httpx.Response:
                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_to`")
            return cast(R, response)

        if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
            raise TypeError("Pydantic models must subclass our base model type, e.g. `from anthropic import BaseModel`")

        if (
            cast_to is not object
            and origin is not list
            and origin is not dict
            and origin is not Union
            and not issubclass(origin, BaseModel)
        ):
            raise RuntimeError(
                f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
            )

        # split is required to handle cases where additional information is included
        # in the response, e.g. application/json; charset=utf-8
        content_type, *_ = response.headers.get("content-type", "*").split(";")
        if content_type != "application/json":
            if is_basemodel(cast_to):
                # best effort: the body may still be JSON even though the header says otherwise
                try:
                    data = response.json()
                except Exception as exc:
                    log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
                else:
                    return self._client._process_response_data(
                        data=data,
                        cast_to=cast_to,  # type: ignore
                        response=response,
                    )

            if self._client._strict_response_validation:
                raise APIResponseValidationError(
                    response=response,
                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
                    body=response.text,
                )

            # If the API responds with content that isn't JSON then we just return
            # the (decoded) text without performing any parsing so that you can still
            # handle the response however you need to.
            return response.text  # type: ignore

        data = response.json()

        return self._client._process_response_data(
            data=data,
            cast_to=cast_to,  # type: ignore
            response=response,
        )
class APIResponse(BaseAPIResponse[R]):
    """Synchronous response wrapper with helpers to read, iterate and parse the body."""

    @overload
    def parse(self, *, to: type[_T]) -> _T:
        ...

    @overload
    def parse(self) -> R:
        ...

    def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customise the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from anthropic import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
          - `BaseModel`
          - `dict`
          - `list`
          - `Union`
          - `str`
          - `bytes`
          - `int`
          - `float`
          - `httpx.Response`
        """
        cache_key = to if to is not None else self._cast_to
        # NOTE: a cached `None` is indistinguishable from a miss here, so
        # responses that parse to `None` are re-parsed on every call.
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        # streamed responses are consumed lazily by the stream object instead
        if not self._is_sse_stream:
            self.read()

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        self._parsed_by_type[cache_key] = parsed
        return parsed

    def read(self) -> bytes:
        """Read and return the binary response content.

        Raises:
            StreamAlreadyConsumed: the response content has already been streamed.
        """
        try:
            return self.http_response.read()
        except httpx.StreamConsumed as exc:
            # The default error raised by httpx isn't very
            # helpful in our case so we re-raise it with
            # a different error message.
            raise StreamAlreadyConsumed() from exc

    def text(self) -> str:
        """Read and decode the response content into a string."""
        self.read()
        return self.http_response.text

    def json(self) -> object:
        """Read and decode the JSON response content."""
        self.read()
        return self.http_response.json()

    def close(self) -> None:
        """Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.http_response.close()

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        """
        A byte-iterator over the decoded response content.

        This automatically handles gzip, deflate and brotli encoded responses.
        """
        for chunk in self.http_response.iter_bytes(chunk_size):
            yield chunk

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        """A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        for chunk in self.http_response.iter_text(chunk_size):
            yield chunk

    def iter_lines(self) -> Iterator[str]:
        """Like `iter_text()` but will only yield chunks for each line"""
        for chunk in self.http_response.iter_lines():
            yield chunk
class AsyncAPIResponse(BaseAPIResponse[R]):
    """Asynchronous response wrapper with helpers to read, iterate and parse the body."""

    @overload
    async def parse(self, *, to: type[_T]) -> _T:
        ...

    @overload
    async def parse(self) -> R:
        ...

    async def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customise the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from anthropic import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = await response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
          - `BaseModel`
          - `dict`
          - `list`
          - `Union`
          - `str`
          - `bytes`
          - `int`
          - `float`
          - `httpx.Response`
        """
        cache_key = to if to is not None else self._cast_to
        # NOTE: a cached `None` is indistinguishable from a miss here, so
        # responses that parse to `None` are re-parsed on every call.
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        # streamed responses are consumed lazily by the stream object instead
        if not self._is_sse_stream:
            await self.read()

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        self._parsed_by_type[cache_key] = parsed
        return parsed

    async def read(self) -> bytes:
        """Read and return the binary response content.

        Raises:
            StreamAlreadyConsumed: the response content has already been streamed.
        """
        try:
            return await self.http_response.aread()
        except httpx.StreamConsumed as exc:
            # the default error raised by httpx isn't very
            # helpful in our case so we re-raise it with
            # a different error message
            raise StreamAlreadyConsumed() from exc

    async def text(self) -> str:
        """Read and decode the response content into a string."""
        await self.read()
        return self.http_response.text

    async def json(self) -> object:
        """Read and decode the JSON response content."""
        await self.read()
        return self.http_response.json()

    async def close(self) -> None:
        """Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.http_response.aclose()

    async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        """
        A byte-iterator over the decoded response content.

        This automatically handles gzip, deflate and brotli encoded responses.
        """
        async for chunk in self.http_response.aiter_bytes(chunk_size):
            yield chunk

    async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        """A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        async for chunk in self.http_response.aiter_text(chunk_size):
            yield chunk

    async def iter_lines(self) -> AsyncIterator[str]:
        """Like `iter_text()` but will only yield chunks for each line"""
        async for chunk in self.http_response.aiter_lines():
            yield chunk
class BinaryAPIResponse(APIResponse[bytes]):
    """`APIResponse` specialisation offering helpers for binary payloads.

    Note: to stream the response data rather than reading it eagerly in one
    go, make the API request through `.with_streaming_response`, e.g.
    `.with_streaming_response.get_binary_response()`
    """

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Save the response body to `file`.

        `file` may be a filename or any path-like object, e.g. pathlib.Path

        Note: to stream the data into the file rather than writing it all at
        once, make the API request through `.with_streaming_response`, e.g.
        `.with_streaming_response.get_binary_response()`
        """
        with open(file, mode="wb") as sink:
            # writes every chunk produced by the byte iterator, in order
            sink.writelines(self.iter_bytes())
class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):
    """`AsyncAPIResponse` specialisation offering helpers for binary payloads.

    Note: to stream the response data rather than reading it eagerly in one
    go, make the API request through `.with_streaming_response`, e.g.
    `.with_streaming_response.get_binary_response()`
    """

    async def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Save the response body to `file`.

        `file` may be a filename or any path-like object, e.g. pathlib.Path

        Note: to stream the data into the file rather than writing it all at
        once, make the API request through `.with_streaming_response`, e.g.
        `.with_streaming_response.get_binary_response()`
        """
        target = anyio.Path(file)
        async with await target.open(mode="wb") as sink:
            async for chunk in self.iter_bytes():
                await sink.write(chunk)
class StreamedBinaryAPIResponse(APIResponse[bytes]):
    """`APIResponse` specialisation that can stream a binary body to disk."""

    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Stream the response body into `file`, chunk by chunk.

        `file` may be a filename or any path-like object, e.g. pathlib.Path;
        `chunk_size` is forwarded to `iter_bytes()`.
        """
        with open(file, mode="wb") as sink:
            # writes every chunk produced by the byte iterator, in order
            sink.writelines(self.iter_bytes(chunk_size))
class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
    """`AsyncAPIResponse` specialisation that can stream a binary body to disk."""

    async def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Stream the response body into `file`, chunk by chunk.

        `file` may be a filename or any path-like object, e.g. pathlib.Path;
        `chunk_size` is forwarded to `iter_bytes()`.
        """
        target = anyio.Path(file)
        async with await target.open(mode="wb") as sink:
            async for chunk in self.iter_bytes(chunk_size):
                await sink.write(chunk)
class MissingStreamClassError(TypeError):
    """Raised when a streamed response has no stream class to be built with."""

    def __init__(self) -> None:
        message = "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `anthropic._streaming` for reference"
        super().__init__(message)
class StreamAlreadyConsumed(AnthropicError):
    """
    Raised on an attempt to read or stream content that has already
    been streamed.

    A typical cause is using a method like `.iter_lines()` and then trying
    to read the entire response body afterwards, e.g.

    ```py
    response = await client.post(...)
    async for line in response.iter_lines():
        ...  # do something with `line`

    content = await response.read()
    # ^ error
    ```

    To get this behaviour, either accumulate the response content manually
    or call `await response.read()` before iterating over the stream.
    """

    def __init__(self) -> None:
        super().__init__(
            "Attempted to read or stream some content, but the content has "
            "already been streamed. "
            "This could be due to attempting to stream the response "
            "content more than once."
            "\n\n"
            "You can fix this by manually accumulating the response content while streaming "
            "or by calling `.read()` before starting to stream."
        )
class ResponseContextManager(Generic[_APIResponseT]):
    """Context manager that defers the request until it is entered
    and closes the resulting response when it exits.
    """

    def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:
        self._request_func = request_func
        # populated on `__enter__`; stays `None` if the manager is never entered
        self.__response: _APIResponseT | None = None

    def __enter__(self) -> _APIResponseT:
        response = self._request_func()
        self.__response = response
        return response

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        response = self.__response
        if response is None:
            return
        response.close()
class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):
    """Async context manager that awaits the request only once it is entered
    and closes the resulting response when it exits.
    """

    def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:
        self._api_request = api_request
        # populated on `__aenter__`; stays `None` if the manager is never entered
        self.__response: _AsyncAPIResponseT | None = None

    async def __aenter__(self) -> _AsyncAPIResponseT:
        response = await self._api_request
        self.__response = response
        return response

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        response = self.__response
        if response is None:
            return
        await response.close()
def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:
    """Higher order function that takes one of our bound API methods and wraps it
    to support streaming and returning the raw `APIResponse` object directly.

    The returned callable does not perform the request itself; it returns a
    `ResponseContextManager` that calls `func` when entered.
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated
        extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "stream"

        kwargs["extra_headers"] = extra_headers

        make_request = functools.partial(func, *args, **kwargs)

        return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))

    return wrapped
def async_to_streamed_response_wrapper(
    func: Callable[P, Awaitable[R]],
) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:
    """Higher order function that takes one of our bound API methods and wraps it
    to support streaming and returning the raw `APIResponse` object directly.

    The returned callable invokes `func` immediately and hands the resulting
    awaitable to an `AsyncResponseContextManager`, which awaits it when entered.
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated
        extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "stream"

        kwargs["extra_headers"] = extra_headers

        make_request = func(*args, **kwargs)

        return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))

    return wrapped
def to_custom_streamed_response_wrapper(
    func: Callable[P, object],
    response_cls: type[_APIResponseT],
) -> Callable[P, ResponseContextManager[_APIResponseT]]:
    """Higher order function that takes one of our bound API methods and an `APIResponse` class
    and wraps the method to support streaming and returning the given response class directly.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated;
        # values are `Any` because the override header carries a class, not a string
        extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "stream"
        extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

        kwargs["extra_headers"] = extra_headers

        make_request = functools.partial(func, *args, **kwargs)

        return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))

    return wrapped
def async_to_custom_streamed_response_wrapper(
    func: Callable[P, Awaitable[object]],
    response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:
    """Higher order function that takes one of our bound API methods and an `APIResponse` class
    and wraps the method to support streaming and returning the given response class directly.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated;
        # values are `Any` because the override header carries a class, not a string
        extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "stream"
        extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

        kwargs["extra_headers"] = extra_headers

        make_request = func(*args, **kwargs)

        return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))

    return wrapped
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
    """Higher order function that takes one of our bound API methods and wraps it
    to support returning the raw `APIResponse` object directly.
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated
        extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "raw"

        kwargs["extra_headers"] = extra_headers

        return cast(APIResponse[R], func(*args, **kwargs))

    return wrapped
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:
    """Higher order function that takes one of our bound API methods and wraps it
    to support returning the raw `APIResponse` object directly.
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated
        extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "raw"

        kwargs["extra_headers"] = extra_headers

        return cast(AsyncAPIResponse[R], await func(*args, **kwargs))

    return wrapped
def to_custom_raw_response_wrapper(
    func: Callable[P, object],
    response_cls: type[_APIResponseT],
) -> Callable[P, _APIResponseT]:
    """Higher order function that takes one of our bound API methods and an `APIResponse` class
    and wraps the method to support returning the given response class directly.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated;
        # values are `Any` because the override header carries a class, not a string
        extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "raw"
        extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

        kwargs["extra_headers"] = extra_headers

        return cast(_APIResponseT, func(*args, **kwargs))

    return wrapped
def async_to_custom_raw_response_wrapper(
    func: Callable[P, Awaitable[object]],
    response_cls: type[_AsyncAPIResponseT],
) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:
    """Higher order function that takes one of our bound API methods and an `APIResponse` class
    and wraps the method to support returning the given response class directly.

    Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
    """

    @functools.wraps(func)  # keep the wrapped method's name, docstring and `__wrapped__`
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
        # copy so that a caller-supplied `extra_headers` mapping is never mutated;
        # values are `Any` because the override header carries a class, not a string
        extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
        extra_headers[RAW_RESPONSE_HEADER] = "raw"
        extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls

        kwargs["extra_headers"] = extra_headers

        # not awaited here: the awaitable returned by `func` is passed straight through
        return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))

    return wrapped
def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
    """Return the type argument `T` of a response type such as `APIResponse[T]`.

    Concrete subclasses are handled as well, e.g.

    ```py
    class MyResponse(APIResponse[bytes]):
        ...

    extract_response_type(MyResponse) -> bytes
    ```
    """
    response_bases = cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse))
    return extract_type_var_from_base(typ, generic_bases=response_bases, index=0)