"""Anthropic 协议适配器。""" from __future__ import annotations import json import time import uuid as uuid_mod from collections.abc import AsyncIterator from typing import Any from core.api.conv_parser import ( decode_latest_session_id, extract_session_id_marker, strip_session_id_suffix, ) from core.api.react import format_react_final_answer_content, parse_react_output from core.api.react_stream_parser import ReactStreamParser from core.hub.schemas import OpenAIStreamEvent from core.protocol.base import ProtocolAdapter from core.protocol.schemas import ( CanonicalChatRequest, CanonicalContentBlock, CanonicalMessage, CanonicalToolSpec, ) class AnthropicProtocolAdapter(ProtocolAdapter): protocol_name = "anthropic" def parse_request( self, provider: str, raw_body: dict[str, Any], ) -> CanonicalChatRequest: messages = raw_body.get("messages") or [] if not isinstance(messages, list): raise ValueError("messages 必须为数组") system_blocks = self._parse_content(raw_body.get("system")) canonical_messages: list[CanonicalMessage] = [] resume_session_id: str | None = None for item in messages: if not isinstance(item, dict): continue blocks = self._parse_content(item.get("content")) for block in blocks: text = block.text or "" decoded = decode_latest_session_id(text) if decoded: resume_session_id = decoded block.text = strip_session_id_suffix(text) canonical_messages.append( CanonicalMessage( role=str(item.get("role") or "user"), content=blocks, ) ) for block in system_blocks: text = block.text or "" decoded = decode_latest_session_id(text) if decoded: resume_session_id = decoded block.text = strip_session_id_suffix(text) tools = [self._parse_tool(tool) for tool in list(raw_body.get("tools") or [])] stop_sequences = raw_body.get("stop_sequences") or [] return CanonicalChatRequest( protocol="anthropic", provider=provider, model=str(raw_body.get("model") or ""), system=system_blocks, messages=canonical_messages, stream=bool(raw_body.get("stream") or False), max_tokens=raw_body.get("max_tokens"), temperature=raw_body.get("temperature"), top_p=raw_body.get("top_p"), stop_sequences=[str(v) for v in stop_sequences if isinstance(v, str)], tools=tools, tool_choice=raw_body.get("tool_choice"), resume_session_id=resume_session_id, ) def render_non_stream( self, req: CanonicalChatRequest, raw_events: list[OpenAIStreamEvent], ) -> dict[str, Any]: full = "".join( ev.content or "" for ev in raw_events if ev.type == "content_delta" and ev.content ) session_marker = extract_session_id_marker(full) text = strip_session_id_suffix(full) message_id = self._message_id(req) if req.tools: parsed = parse_react_output(text) if parsed and parsed.get("type") == "tool_call": content: list[dict[str, Any]] = [ { "type": "tool_use", "id": f"toolu_{uuid_mod.uuid4().hex[:24]}", "name": str(parsed.get("tool") or ""), "input": parsed.get("params") or {}, } ] if session_marker: content.append({"type": "text", "text": session_marker}) return self._message_response( req, message_id, content, stop_reason="tool_use", ) rendered = format_react_final_answer_content(text) else: rendered = text if session_marker: rendered += session_marker return self._message_response( req, message_id, [{"type": "text", "text": rendered}], stop_reason="end_turn", ) async def render_stream( self, req: CanonicalChatRequest, raw_stream: AsyncIterator[OpenAIStreamEvent], ) -> AsyncIterator[str]: message_id = self._message_id(req) parser = ReactStreamParser( chat_id=f"chatcmpl-{uuid_mod.uuid4().hex[:24]}", model=req.model, created=int(time.time()), has_tools=bool(req.tools), ) session_marker = "" translator = _AnthropicStreamTranslator(req, message_id) async for event in raw_stream: if event.type == "content_delta" and event.content: chunk = event.content if extract_session_id_marker(chunk) and not strip_session_id_suffix( chunk ): session_marker = chunk continue for sse in parser.feed(chunk): for out in translator.feed_openai_sse(sse): yield out elif event.type == "finish": break for sse in parser.finish(): for out in translator.feed_openai_sse(sse, session_marker=session_marker): yield out def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]: status = 400 if isinstance(exc, ValueError) else 500 err_type = "invalid_request_error" if status == 400 else "api_error" return ( status, { "type": "error", "error": {"type": err_type, "message": str(exc)}, }, ) @staticmethod def _parse_tool(tool: dict[str, Any]) -> CanonicalToolSpec: return CanonicalToolSpec( name=str(tool.get("name") or ""), description=str(tool.get("description") or ""), input_schema=tool.get("input_schema") or {}, ) @staticmethod def _parse_content(value: Any) -> list[CanonicalContentBlock]: if value is None: return [] if isinstance(value, str): return [CanonicalContentBlock(type="text", text=value)] if isinstance(value, list): blocks: list[CanonicalContentBlock] = [] for item in value: if isinstance(item, str): blocks.append(CanonicalContentBlock(type="text", text=item)) continue if not isinstance(item, dict): continue item_type = str(item.get("type") or "") if item_type == "text": blocks.append( CanonicalContentBlock( type="text", text=str(item.get("text") or "") ) ) elif item_type == "image": source = item.get("source") or {} source_type = source.get("type") if source_type == "base64": blocks.append( CanonicalContentBlock( type="image", mime_type=str(source.get("media_type") or ""), data=str(source.get("data") or ""), ) ) elif item_type == "tool_result": text_parts = AnthropicProtocolAdapter._parse_content( item.get("content") ) blocks.append( CanonicalContentBlock( type="tool_result", tool_use_id=str(item.get("tool_use_id") or ""), text="\n".join( part.text or "" for part in text_parts if part.type == "text" ), is_error=bool(item.get("is_error") or False), ) ) return blocks raise ValueError("content 格式不合法") @staticmethod def _message_response( req: CanonicalChatRequest, message_id: str, content: list[dict[str, Any]], *, stop_reason: str, ) -> dict[str, Any]: return { "id": message_id, "type": "message", "role": "assistant", "model": req.model, "content": content, "stop_reason": stop_reason, "stop_sequence": None, "usage": {"input_tokens": 0, "output_tokens": 0}, } @staticmethod def _message_id(req: CanonicalChatRequest) -> str: return str( req.metadata.setdefault( "anthropic_message_id", f"msg_{uuid_mod.uuid4().hex}" ) ) class _AnthropicStreamTranslator: def __init__(self, req: CanonicalChatRequest, message_id: str) -> None: self._req = req self._message_id = message_id self._started = False self._current_block_type: str | None = None self._current_index = -1 self._pending_tool_id: str | None = None self._pending_tool_name: str | None = None self._stopped = False def feed_openai_sse( self, sse: str, *, session_marker: str = "", ) -> list[str]: lines = [line for line in sse.splitlines() if line.startswith("data: ")] out: list[str] = [] for line in lines: payload = line[6:].strip() if payload == "[DONE]": continue obj = json.loads(payload) choice = (obj.get("choices") or [{}])[0] delta = choice.get("delta") or {} finish_reason = choice.get("finish_reason") if not self._started: out.append( self._event( "message_start", { "type": "message_start", "message": { "id": self._message_id, "type": "message", "role": "assistant", "model": self._req.model, "content": [], "stop_reason": None, "stop_sequence": None, "usage": {"input_tokens": 0, "output_tokens": 0}, }, }, ) ) self._started = True content = delta.get("content") if isinstance(content, str) and content: out.extend(self._ensure_text_block()) out.append( self._event( "content_block_delta", { "type": "content_block_delta", "index": self._current_index, "delta": {"type": "text_delta", "text": content}, }, ) ) tool_calls = delta.get("tool_calls") or [] if tool_calls: head = tool_calls[0] if head.get("id") and head.get("function", {}).get("name") is not None: out.extend(self._close_current_block()) self._current_index += 1 self._current_block_type = "tool_use" self._pending_tool_id = str(head.get("id") or "") self._pending_tool_name = str( head.get("function", {}).get("name") or "" ) out.append( self._event( "content_block_start", { "type": "content_block_start", "index": self._current_index, "content_block": { "type": "tool_use", "id": self._pending_tool_id, "name": self._pending_tool_name, "input": {}, }, }, ) ) args_delta = head.get("function", {}).get("arguments") if args_delta: out.append( self._event( "content_block_delta", { "type": "content_block_delta", "index": self._current_index, "delta": { "type": "input_json_delta", "partial_json": str(args_delta), }, }, ) ) if finish_reason: if session_marker: if finish_reason == "tool_calls": out.extend(self._close_current_block()) out.extend(self._emit_marker_text_block(session_marker)) else: out.extend(self._ensure_text_block()) out.append( self._event( "content_block_delta", { "type": "content_block_delta", "index": self._current_index, "delta": { "type": "text_delta", "text": session_marker, }, }, ) ) out.extend(self._close_current_block()) stop_reason = ( "tool_use" if finish_reason == "tool_calls" else "end_turn" ) out.append( self._event( "message_delta", { "type": "message_delta", "delta": { "stop_reason": stop_reason, "stop_sequence": None, }, "usage": {"output_tokens": 0}, }, ) ) out.append(self._event("message_stop", {"type": "message_stop"})) self._stopped = True return out def _ensure_text_block(self) -> list[str]: if self._current_block_type == "text": return [] out = self._close_current_block() self._current_index += 1 self._current_block_type = "text" out.append( self._event( "content_block_start", { "type": "content_block_start", "index": self._current_index, "content_block": {"type": "text", "text": ""}, }, ) ) return out def _emit_marker_text_block(self, marker: str) -> list[str]: self._current_index += 1 self._current_block_type = "text" return [ self._event( "content_block_start", { "type": "content_block_start", "index": self._current_index, "content_block": {"type": "text", "text": ""}, }, ), self._event( "content_block_delta", { "type": "content_block_delta", "index": self._current_index, "delta": {"type": "text_delta", "text": marker}, }, ), self._event( "content_block_stop", {"type": "content_block_stop", "index": self._current_index}, ), ] def _close_current_block(self) -> list[str]: if self._current_block_type is None: return [] block_index = self._current_index self._current_block_type = None return [ self._event( "content_block_stop", {"type": "content_block_stop", "index": block_index}, ) ] @staticmethod def _event(event_name: str, payload: dict[str, Any]) -> str: del event_name return f"event: {payload['type']}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"