Spaces:
Running
Running
| """Track content-block state for native Anthropic SSE strings we emit to clients.""" | |
| from __future__ import annotations | |
| import uuid | |
| from collections.abc import Iterator | |
| from contextlib import suppress | |
| from typing import Any | |
| from core.anthropic.sse import SSEBuilder, format_sse_event | |
| from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines | |
class EmittedNativeSseTracker:
    """Track content-block state from the SSE text passed to :meth:`feed`.

    The tracker parses complete SSE frames out of the fed text and remembers
    the message id and model from ``message_start``, which content blocks have
    been started but not stopped, and the highest content-block index started.
    When a stream fails part-way, that state is used to close the dangling
    blocks and to emit an error block at an index that has not been used.
    """

    def __init__(self) -> None:
        """Start with an empty buffer and no observed events."""
        # Text received so far that is not yet terminated by a blank line.
        self._buf = ""
        # Indices of blocks that were started and not yet stopped, in start order.
        self._open_stack: list[int] = []
        # Highest index seen in a ``content_block_start``; -1 until one is seen.
        self._max_index = -1
        # Taken from the ``message_start`` event, when one has been observed.
        self.message_id: str | None = None
        self.model: str = ""

    def feed(self, chunk: str) -> None:
        """Record every SSE frame completed by ``chunk``.

        Frames are delimited by a blank line (``"\\n\\n"``). Text after the last
        delimiter is kept in the buffer, so a frame split across several calls
        is processed once its terminator arrives.

        Args:
            chunk: Next piece of SSE text, in the order it was emitted.
        """
        self._buf += chunk
        while True:
            frame, delimiter, remainder = self._buf.partition("\n\n")
            if not delimiter:
                # No complete frame left; keep the partial text for the next call.
                break
            # Consume the frame before processing it, one frame at a time.
            self._buf = remainder
            if not frame.strip():
                continue
            for event in parse_sse_lines(frame.splitlines()):
                self._observe(event)

    def _observe(self, event: SSEEvent) -> None:
        """Update tracked state from one parsed event; other event types are ignored."""
        kind = event.event

        if kind == "message_start":
            message = event.data.get("message")
            if isinstance(message, dict):
                # Only non-empty strings replace the current values.
                mid = message.get("id")
                if isinstance(mid, str) and mid:
                    self.message_id = mid
                model = message.get("model")
                if isinstance(model, str) and model:
                    self.model = model
            return

        if kind == "content_block_start":
            idx = event_index(event)
            self._max_index = max(self._max_index, idx)
            self._open_stack.append(idx)
            return

        if kind == "content_block_stop":
            idx = event_index(event)
            if self._open_stack and self._open_stack[-1] == idx:
                # Common case: the most recently started block is the one stopping.
                self._open_stack.pop()
            else:
                # Out-of-order stop, or a stop for a block never seen starting:
                # drop the index if it is tracked, otherwise do nothing.
                with suppress(ValueError):
                    self._open_stack.remove(idx)

    def next_content_index(self) -> int:
        """Return one more than the highest started block index (0 if none started)."""
        return self._max_index + 1

    def iter_close_unclosed_blocks(self) -> Iterator[str]:
        """Yield a ``content_block_stop`` SSE string for every block still open.

        Blocks are closed most-recently-started first. Each index is removed
        from the open set as it is yielded, so iterating again yields nothing.
        """
        while self._open_stack:
            idx = self._open_stack.pop()
            yield format_sse_event(
                "content_block_stop",
                {"type": "content_block_stop", "index": idx},
            )

    def iter_midstream_error_tail(
        self,
        error_message: str,
        *,
        request: Any,
        input_tokens: int,
        log_raw_sse_events: bool,
    ) -> Iterator[str]:
        """Close dangling blocks, emit an error block at a fresh index, then the message tail.

        Args:
            error_message: Text handed to ``SSEBuilder.emit_error``.
            request: Only its ``model`` attribute is read, as a fallback when
                no model was observed in the stream.
            input_tokens: Forwarded to ``SSEBuilder``.
            log_raw_sse_events: Forwarded to ``SSEBuilder`` as ``log_raw_events``.

        Yields:
            SSE strings: a ``content_block_stop`` for each open block, whatever
            ``emit_error`` produces, then the message delta and message stop.
        """
        # Open blocks must be stopped before the error block is emitted. This
        # yields nothing if the caller has already closed them.
        yield from self.iter_close_unclosed_blocks()

        # Reuse the streamed identity when known; otherwise synthesize one.
        mid = self.message_id or f"msg_{uuid.uuid4()}"
        model = self.model or (getattr(request, "model", "") or "")
        sse = SSEBuilder(
            mid,
            model,
            input_tokens,
            log_raw_events=log_raw_sse_events,
        )
        # Start the error block after every index that has already been emitted.
        sse.blocks.next_index = self.next_content_index()
        yield from sse.emit_error(error_message)
        # NOTE(review): the second argument is presumably the output token count;
        # confirm against SSEBuilder.message_delta.
        yield sse.message_delta("end_turn", 1)
        yield sse.message_stop()