claude-code-proxy / core /anthropic /emitted_sse_tracker.py
Yash030's picture
Deploy claude-code-nvidia proxy to Hugging Face Spaces
0157ac7
"""Track content-block state for native Anthropic SSE strings we emit to clients."""
from __future__ import annotations
import uuid
from collections.abc import Iterator
from contextlib import suppress
from typing import Any
from core.anthropic.sse import SSEBuilder, format_sse_event
from core.anthropic.stream_contracts import SSEEvent, event_index, parse_sse_lines
class EmittedNativeSseTracker:
    """Parse emitted SSE frames so mid-stream errors can close blocks and pick a fresh index.

    The tracker is fed the SSE text that is sent to the client. From the
    complete frames it sees, it remembers the message id and model announced
    by ``message_start``, which content-block indices have been started but
    not stopped, and the highest block index started so far.
    """

    def __init__(self) -> None:
        # Text received via ``feed`` that has not yet formed a complete frame.
        self._buf = ""
        # Indices of blocks started but not yet stopped, in start order.
        self._open_stack: list[int] = []
        # Highest index seen in a ``content_block_start``; -1 means none yet.
        self._max_index = -1
        # Both are taken from ``message_start`` payloads when present there.
        self.message_id: str | None = None
        self.model: str = ""

    def feed(self, chunk: str) -> None:
        """Record SSE frames completed by ``chunk`` (handles splitting across reads).

        ``chunk`` is appended to an internal buffer. Every frame terminated by
        a blank line (``"\\n\\n"``) is removed from the buffer, parsed with
        ``parse_sse_lines`` and observed; a trailing partial frame stays
        buffered until a later call completes it.
        """
        self._buf += chunk
        while True:
            sep = self._buf.find("\n\n")
            if sep < 0:
                # No complete frame left; keep the remainder for the next call.
                break
            frame = self._buf[:sep]
            self._buf = self._buf[sep + 2 :]
            if not frame.strip():
                # Whitespace-only frames carry no event.
                continue
            for event in parse_sse_lines(frame.splitlines()):
                self._observe(event)

    def _observe(self, event: SSEEvent) -> None:
        """Update tracked state from one parsed event; other event types are ignored."""
        if event.event == "message_start":
            message = event.data.get("message")
            if isinstance(message, dict):
                # Only non-empty strings replace the stored values.
                mid = message.get("id")
                if isinstance(mid, str) and mid:
                    self.message_id = mid
                model = message.get("model")
                if isinstance(model, str) and model:
                    self.model = model
            return
        if event.event == "content_block_start":
            idx = event_index(event)
            self._max_index = max(self._max_index, idx)
            self._open_stack.append(idx)
            return
        if event.event == "content_block_stop":
            idx = event_index(event)
            if self._open_stack and self._open_stack[-1] == idx:
                # Common case: the most recently opened block is being closed.
                self._open_stack.pop()
            else:
                # Out-of-order stop, or a stop for an index never started:
                # drop it if tracked, otherwise ignore it.
                with suppress(ValueError):
                    self._open_stack.remove(idx)

    def next_content_index(self) -> int:
        """Next unused content block index based on emitted starts.

        Returns 0 when no ``content_block_start`` has been observed.
        """
        return self._max_index + 1

    def iter_close_unclosed_blocks(self) -> Iterator[str]:
        """Yield ``content_block_stop`` events for blocks that were started but not stopped.

        Blocks are closed most-recently-opened first. Each index is removed
        from the tracked set as its event is produced, so iterating a second
        time yields nothing.
        """
        while self._open_stack:
            idx = self._open_stack.pop()
            yield format_sse_event(
                "content_block_stop",
                {"type": "content_block_stop", "index": idx},
            )

    def iter_midstream_error_tail(
        self,
        error_message: str,
        *,
        request: Any,
        input_tokens: int,
        log_raw_sse_events: bool,
    ) -> Iterator[str]:
        """Close dangling blocks, emit a text error block at a fresh index, then message tail.

        Args:
            error_message: Text passed to ``SSEBuilder.emit_error``.
            request: Object whose ``model`` attribute (if any) is used as the
                model name when no ``message_start`` supplied one.
            input_tokens: Forwarded to ``SSEBuilder``.
            log_raw_sse_events: Forwarded to ``SSEBuilder`` as ``log_raw_events``.

        Yields:
            SSE strings: a ``content_block_stop`` for every block still open,
            then the output of ``emit_error``, ``message_delta`` and
            ``message_stop`` from a fresh ``SSEBuilder``.
        """
        # Stop every block still open before anything else is emitted. The
        # stack is drained as it is consumed, so this yields nothing when the
        # caller has already run ``iter_close_unclosed_blocks`` itself.
        yield from self.iter_close_unclosed_blocks()

        # Reuse the id/model already announced to the client when known.
        mid = self.message_id or f"msg_{uuid.uuid4()}"
        model = self.model or (getattr(request, "model", "") or "")
        sse = SSEBuilder(
            mid,
            model,
            input_tokens,
            log_raw_events=log_raw_sse_events,
        )
        # Start the error block after every index that was already started.
        sse.blocks.next_index = self.next_content_index()
        yield from sse.emit_error(error_message)
        # NOTE(review): the literal 1 is presumably an output-token count --
        # confirm against SSEBuilder.message_delta.
        yield sse.message_delta("end_turn", 1)
        yield sse.message_stop()