Spaces:
Paused
Paused
| # Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py | |
| from __future__ import annotations | |
| import abc | |
| import json | |
| import inspect | |
| import warnings | |
| from types import TracebackType | |
| from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast | |
| from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable | |
| import httpx | |
| from ._utils import is_dict, extract_type_var_from_base | |
| if TYPE_CHECKING: | |
| from ._client import Anthropic, AsyncAnthropic | |
# Item type produced by iterating a `Stream` / `AsyncStream` (the `cast_to` type of each processed event).
_T = TypeVar("_T")
class _SyncStreamMeta(abc.ABCMeta):
    """Metaclass for `Stream` that keeps `isinstance(obj, Stream)` true for `MessageStream` objects.

    A previous version of `MessageStream` inherited from `Stream`. It no longer
    does, so this metaclass preserves the old `isinstance()` answer (while emitting
    a `DeprecationWarning`) instead of making the change a breaking one. Every other
    instance check is handled by the normal `ABCMeta` logic.
    """

    def __instancecheck__(self, instance: Any) -> bool:
        # Imported inside the method rather than at module level; presumably to
        # avoid a circular import with `.lib.streaming` -- TODO confirm.
        from .lib.streaming import MessageStream

        if isinstance(instance, MessageStream):
            warnings.warn(
                "Using `isinstance()` to check if a `MessageStream` object is an instance of `Stream` is deprecated & will be removed in the next major version",
                DeprecationWarning,
                stacklevel=2,
            )
            return True

        # Defer to the default ABC check so instances of real `Stream` subclasses
        # (and classes registered via `Stream.register`) are still recognised.
        # CPython only short-circuits exact-type matches before calling
        # `__instancecheck__`, so a hard-coded `False` would reject subclasses.
        return super().__instancecheck__(instance)
class Stream(Generic[_T], metaclass=_SyncStreamMeta):
    """Provides the core interface to iterate over a synchronous stream response.

    Raw bytes from `response` are turned into server-sent events by the client's
    SSE decoder. Each relevant event is then converted into a `_T` through the
    client's `_process_response_data`. An `error` event raises the error built by
    the client's `_make_status_error`.
    """

    response: httpx.Response

    _decoder: SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: Anthropic,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # Creating the generator does not run it; work starts on the first `next()`.
        self._iterator = self.__stream__()

    def __next__(self) -> _T:
        return next(self._iterator)

    def __iter__(self) -> Iterator[_T]:
        # A plain re-yield loop: unlike `yield from`, abandoning this loop does not
        # forward `close()` to the shared `self._iterator`, so iteration can resume.
        for chunk in self._iterator:
            yield chunk

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        byte_chunks = self.response.iter_bytes()
        yield from self._decoder.iter_bytes(byte_chunks)

    def __stream__(self) -> Iterator[_T]:
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        message_events = (
            "message_start",
            "message_delta",
            "message_stop",
            "content_block_start",
            "content_block_delta",
            "content_block_stop",
        )
        events = self._iter_events()

        for sse in events:
            name = sse.event

            if name == "completion":
                yield process_data(data=sse.json(), cast_to=cast_to, response=response)
            elif name in message_events:
                payload = sse.json()
                # If the payload carries no `type` key, fall back to the SSE event name.
                if is_dict(payload) and "type" not in payload:
                    payload["type"] = name
                yield process_data(data=payload, cast_to=cast_to, response=response)
            elif name == "ping":
                continue
            elif name == "error":
                # Prefer the JSON-decoded body; fall back to the raw text on failure.
                body = sse.data
                try:
                    body = sse.json()
                    err_msg = f"{body}"
                except Exception:
                    err_msg = sse.data or f"Error code: {response.status_code}"
                raise self._client._make_status_error(
                    err_msg,
                    body=body,
                    response=self.response,
                )

        # Drain whatever may remain on the event iterator.
        for _sse in events:
            ...

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        self.response.close()
class _AsyncStreamMeta(abc.ABCMeta):
    """Metaclass for `AsyncStream` that keeps `isinstance(obj, AsyncStream)` true for `AsyncMessageStream` objects.

    A previous version of `AsyncMessageStream` inherited from `AsyncStream`. It no
    longer does, so this metaclass preserves the old `isinstance()` answer (while
    emitting a `DeprecationWarning`) instead of making the change a breaking one.
    Every other instance check is handled by the normal `ABCMeta` logic.
    """

    def __instancecheck__(self, instance: Any) -> bool:
        # Imported inside the method rather than at module level; presumably to
        # avoid a circular import with `.lib.streaming` -- TODO confirm.
        from .lib.streaming import AsyncMessageStream

        if isinstance(instance, AsyncMessageStream):
            warnings.warn(
                "Using `isinstance()` to check if a `AsyncMessageStream` object is an instance of `AsyncStream` is deprecated & will be removed in the next major version",
                DeprecationWarning,
                stacklevel=2,
            )
            return True

        # Defer to the default ABC check so instances of real `AsyncStream`
        # subclasses (and registered virtual subclasses) are still recognised.
        # CPython only short-circuits exact-type matches before calling
        # `__instancecheck__`, so a hard-coded `False` would reject subclasses.
        return super().__instancecheck__(instance)
class AsyncStream(Generic[_T], metaclass=_AsyncStreamMeta):
    """Provides the core interface to iterate over an asynchronous stream response.

    Raw bytes from `response` are turned into server-sent events by the client's
    SSE decoder. Each relevant event is then converted into a `_T` through the
    client's `_process_response_data`. An `error` event raises the error built by
    the client's `_make_status_error`.
    """

    response: httpx.Response

    _decoder: SSEDecoder | SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: AsyncAnthropic,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # Creating the async generator does not run it; work starts on the first `__anext__()`.
        self._iterator = self.__stream__()

    async def __anext__(self) -> _T:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[_T]:
        # Re-yield from the shared iterator so `__anext__` and `async for` stay in sync.
        async for chunk in self._iterator:
            yield chunk

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        byte_chunks = self.response.aiter_bytes()
        async for event in self._decoder.aiter_bytes(byte_chunks):
            yield event

    async def __stream__(self) -> AsyncIterator[_T]:
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        message_events = (
            "message_start",
            "message_delta",
            "message_stop",
            "content_block_start",
            "content_block_delta",
            "content_block_stop",
        )
        events = self._iter_events()

        async for sse in events:
            name = sse.event

            if name == "completion":
                yield process_data(data=sse.json(), cast_to=cast_to, response=response)
            elif name in message_events:
                payload = sse.json()
                # If the payload carries no `type` key, fall back to the SSE event name.
                if is_dict(payload) and "type" not in payload:
                    payload["type"] = name
                yield process_data(data=payload, cast_to=cast_to, response=response)
            elif name == "ping":
                continue
            elif name == "error":
                # Prefer the JSON-decoded body; fall back to the raw text on failure.
                body = sse.data
                try:
                    body = sse.json()
                    err_msg = f"{body}"
                except Exception:
                    err_msg = sse.data or f"Error code: {response.status_code}"
                raise self._client._make_status_error(
                    err_msg,
                    body=body,
                    response=self.response,
                )

        # Drain whatever may remain on the event iterator.
        async for _sse in events:
            ...

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called if the response body is read to completion.
        """
        await self.response.aclose()
class ServerSentEvent:
    """A single event parsed from a server-sent-events (`text/event-stream`) response.

    The parsed fields are exposed through read-only properties. They must be
    properties, not plain methods, because callers compare them directly
    (e.g. `sse.event == "completion"`). A bound method would never compare equal.
    """

    def __init__(
        self,
        *,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        if data is None:
            data = ""

        self._id = id
        self._data = data
        # An empty event name is normalised to `None`.
        self._event = event or None
        self._retry = retry

    @property
    def event(self) -> str | None:
        """The event name, or `None` when no (or an empty) name was given."""
        return self._event

    @property
    def id(self) -> str | None:
        """The event id, or `None` when none was given."""
        return self._id

    @property
    def retry(self) -> int | None:
        """The reconnection time from the `retry` field, or `None` when absent."""
        return self._retry

    @property
    def data(self) -> str:
        """The event payload text; `""` when no data was given."""
        return self._data

    def json(self) -> Any:
        """Parse `data` as JSON.

        Raises:
            json.JSONDecodeError: if `data` is not valid JSON.
        """
        return json.loads(self.data)

    def __repr__(self) -> str:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class SSEDecoder:
    """Incremental decoder for the `text/event-stream` wire format.

    Incoming bytes are first grouped into chunks that end on a blank line. Each
    line of a chunk is then fed to `decode`, which accumulates fields and returns
    a `ServerSentEvent` whenever an empty line closes an event.
    """

    _data: list[str]
    _event: str | None
    _retry: int | None
    _last_event_id: str | None

    # Byte suffixes that close a chunk: a blank line under each newline convention.
    _EVENT_TERMINATORS = (b"\r\r", b"\n\n", b"\r\n\r\n")

    def __init__(self) -> None:
        self._event = None
        self._data = []
        self._last_event_id = None
        self._retry = None

    def _events_in_chunk(self, chunk: bytes) -> Iterator[ServerSentEvent]:
        """Decode every line of one chunk, yielding each event that gets completed."""
        # Split while still bytes: `bytes.splitlines()` only breaks on \r, \n and \r\n,
        # whereas `str.splitlines()` would also break on other Unicode separators.
        for raw_line in chunk.splitlines():
            sse = self.decode(raw_line.decode("utf-8"))
            if sse:
                yield sse

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Decode a synchronous stream of raw bytes, yielding every completed event."""
        for chunk in self._iter_chunks(iterator):
            yield from self._events_in_chunk(chunk)

    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
        """Regroup arbitrary byte pieces into chunks that each end with a blank line.

        Any trailing bytes without a terminating blank line are yielded as a final chunk.
        """
        pending = bytearray()
        for piece in iterator:
            for line in piece.splitlines(keepends=True):
                pending += line
                if pending.endswith(self._EVENT_TERMINATORS):
                    yield bytes(pending)
                    pending.clear()
        if pending:
            yield bytes(pending)

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Decode an asynchronous stream of raw bytes, yielding every completed event."""
        async for chunk in self._aiter_chunks(iterator):
            for sse in self._events_in_chunk(chunk):
                yield sse

    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
        """Async counterpart of `_iter_chunks`."""
        pending = bytearray()
        async for piece in iterator:
            for line in piece.splitlines(keepends=True):
                pending += line
                if pending.endswith(self._EVENT_TERMINATORS):
                    yield bytes(pending)
                    pending.clear()
        if pending:
            yield bytes(pending)

    def _dispatch(self) -> ServerSentEvent | None:
        """Build an event from the buffered fields on an empty line, then reset the buffers."""
        # NOTE(review): the WHATWG algorithm skips dispatch when the data buffer is
        # empty; this dispatches whenever any field (including a remembered id) is
        # buffered. Confirm this is intended.
        if not (self._event or self._data or self._last_event_id or self._retry is not None):
            return None

        sse = ServerSentEvent(
            event=self._event,
            data="\n".join(self._data),
            id=self._last_event_id,
            retry=self._retry,
        )

        # NOTE: as per the SSE spec, do not reset last_event_id.
        self._event = None
        self._data = []
        self._retry = None
        return sse

    def decode(self, line: str) -> ServerSentEvent | None:
        """Process one decoded line; return an event only when an empty line completes one."""
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501
        if not line:
            return self._dispatch()

        if line.startswith(":"):
            # Comment line.
            return None

        fieldname, _, value = line.partition(":")
        # A single leading space after the colon is not part of the value.
        if value[:1] == " ":
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Ids containing NUL are ignored.
            if "\0" not in value:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        # Any other field name is ignored.

        return None
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural interface for objects that turn a raw byte stream into `ServerSentEvent`s.

    Marked `@runtime_checkable` so `isinstance(obj, SSEBytesDecoder)` works instead
    of raising `TypeError`. That runtime check only verifies that both methods
    exist, not their signatures.
    """

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
        ...
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
    """Report whether `typ` is a `Stream` / `AsyncStream` class (or subclass).

    Parameterised aliases such as `Stream[Foo]` are unwrapped to their origin
    first. Anything that is not a class yields `False`.
    """
    origin = get_origin(typ)
    if not origin:
        origin = typ

    if not inspect.isclass(origin):
        return False
    return issubclass(origin, (Stream, AsyncStream))
def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Stream[T]`, returns the type argument supplied for `T`.

    This also handles the case where a concrete subclass is given, e.g.

    ```py
    class MyStream(Stream[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```

    `failure_message` is passed through to `extract_type_var_from_base`.
    """
    # NOTE(review): these local names shadow the module-level `Stream` / `AsyncStream`
    # and come from `._base_client`; presumably imported lazily to avoid an import
    # cycle -- confirm.
    from ._base_client import Stream, AsyncStream

    stream_bases = cast("tuple[type, ...]", (Stream, AsyncStream))
    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=stream_bases,
        failure_message=failure_message,
    )