| """OpenAI 协议适配器。""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| import time |
| import uuid as uuid_mod |
| from collections.abc import AsyncIterator |
| from typing import Any |
|
|
| from core.api.conv_parser import ( |
| extract_session_id_marker, |
| parse_conv_uuid_from_messages, |
| strip_session_id_suffix, |
| ) |
| from core.api.function_call import build_tool_calls_response |
| from core.api.react import ( |
| format_react_final_answer_content, |
| parse_react_output, |
| react_output_to_tool_calls, |
| ) |
| from core.api.react_stream_parser import ReactStreamParser |
| from core.api.schemas import OpenAIChatRequest, OpenAIContentPart, OpenAIMessage |
| from core.hub.schemas import OpenAIStreamEvent |
| from core.protocol.base import ProtocolAdapter |
| from core.protocol.schemas import ( |
| CanonicalChatRequest, |
| CanonicalContentBlock, |
| CanonicalMessage, |
| CanonicalToolSpec, |
| ) |
|
|
|
|
class OpenAIProtocolAdapter(ProtocolAdapter):
    """Adapter between OpenAI Chat Completions payloads and the canonical model.

    Incoming request bodies are validated and converted into a
    ``CanonicalChatRequest``. Upstream ``OpenAIStreamEvent`` sequences are
    rendered back as an OpenAI ``chat.completion`` dict (non-stream) or as
    ``chat.completion.chunk`` SSE lines (stream).
    """

    protocol_name = "openai"

    # Extracts the ReAct "Thought:" text that precedes "Action:" (or runs to
    # end of text). Accepts both ASCII ':' and full-width '：' colons and is
    # case-insensitive; DOTALL lets the thought span multiple lines.
    _THOUGHT_RE = re.compile(
        r"Thought[:：]\s*(.+?)(?=\s*Action[:：]|$)",
        re.DOTALL | re.IGNORECASE,
    )

    def parse_request(
        self,
        provider: str,
        raw_body: dict[str, Any],
    ) -> CanonicalChatRequest:
        """Validate an OpenAI chat request body and convert it to canonical form.

        System messages are merged into ``system``. All other messages keep
        their role and are converted to content blocks. Session-id suffixes
        are stripped from text content; whatever
        ``parse_conv_uuid_from_messages`` finds in the raw messages becomes
        ``resume_session_id``.

        Args:
            provider: Upstream provider name recorded on the canonical request.
            raw_body: Decoded JSON body of the incoming request.

        Returns:
            The equivalent ``CanonicalChatRequest``.

        Raises:
            Validation errors from ``OpenAIChatRequest.model_validate``
            propagate. Assuming pydantic v2, its ``ValidationError`` is a
            ``ValueError``, which ``render_error`` maps to HTTP 400.
        """
        req = OpenAIChatRequest.model_validate(raw_body)
        # Scan the raw (unstripped) messages for a session id; _to_blocks
        # below strips the session-id suffix from the text it keeps.
        resume_session_id = parse_conv_uuid_from_messages(
            [self._message_to_raw_dict(m) for m in req.messages]
        )
        system_blocks: list[CanonicalContentBlock] = []
        messages: list[CanonicalMessage] = []
        for msg in req.messages:
            blocks = self._to_blocks(msg.content)
            if msg.role == "system":
                system_blocks.extend(blocks)
            else:
                messages.append(CanonicalMessage(role=msg.role, content=blocks))
        tools = [self._to_tool_spec(tool) for tool in req.tools or []]
        return CanonicalChatRequest(
            protocol="openai",
            provider=provider,
            model=req.model,
            system=system_blocks,
            messages=messages,
            stream=req.stream,
            tools=tools,
            tool_choice=req.tool_choice,
            resume_session_id=resume_session_id,
        )

    def render_non_stream(
        self,
        req: CanonicalChatRequest,
        raw_events: list[OpenAIStreamEvent],
    ) -> dict[str, Any]:
        """Collapse buffered upstream events into a single OpenAI response body.

        All ``content_delta`` texts are concatenated. The session-id marker
        is separated out before parsing and re-attached to the final text.

        * If the request declared tools and the reply parses as ReAct tool
          calls, the result comes from ``build_tool_calls_response``. Its
          text content is ``<think>{thought}</think>`` plus the marker, or
          just the marker when no thought is found.
        * Otherwise a ``chat.completion`` with ``finish_reason="stop"`` is
          returned. When tools were declared, the text first goes through
          ``format_react_final_answer_content``.

        Args:
            req: The canonical request; its metadata supplies the id/timestamp.
            raw_events: Every event received from upstream for this request.

        Returns:
            An OpenAI-shaped response dict.
        """
        reply = "".join(
            ev.content
            for ev in raw_events
            if ev.type == "content_delta" and ev.content
        )
        session_marker = extract_session_id_marker(reply)
        content_for_parse = strip_session_id_suffix(reply)
        chat_id, created = self._response_context(req)
        if req.tools:
            parsed = parse_react_output(content_for_parse)
            tool_calls_list = react_output_to_tool_calls(parsed) if parsed else []
            if tool_calls_list:
                thought = self._extract_thought(content_for_parse)
                text_content = (
                    f"<think>{thought}</think>\n{session_marker}".strip()
                    if thought
                    else session_marker
                )
                return build_tool_calls_response(
                    tool_calls_list,
                    chat_id,
                    req.model,
                    created,
                    text_content=text_content,
                )
            content_reply = format_react_final_answer_content(content_for_parse)
        else:
            content_reply = content_for_parse
        # Re-attach the marker on every path, as render_stream does.
        # Previously the no-tools path silently dropped it. Presumably
        # clients echo it back in later assistant messages, where
        # parse_request recovers the session id.
        if session_marker:
            content_reply += session_marker
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": created,
            "model": req.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content_reply},
                    "finish_reason": "stop",
                }
            ],
        }

    async def render_stream(
        self,
        req: CanonicalChatRequest,
        raw_stream: AsyncIterator[OpenAIStreamEvent],
    ) -> AsyncIterator[str]:
        """Re-emit upstream events as OpenAI ``chat.completion.chunk`` SSE lines.

        Content deltas are fed through ``ReactStreamParser``, configured
        with whether the request declared tools. A chunk consisting only of
        a session-id marker is held back and emitted as its own content
        delta just before the parser's closing output. Reading stops at the
        first ``finish`` event.

        Args:
            req: The canonical request; its metadata supplies the id/timestamp.
            raw_stream: Upstream events, consumed in order.

        Yields:
            SSE-formatted strings.
        """
        chat_id, created = self._response_context(req)
        parser = ReactStreamParser(
            chat_id=chat_id,
            model=req.model,
            created=created,
            has_tools=bool(req.tools),
        )
        session_marker = ""
        async for event in raw_stream:
            if event.type == "content_delta" and event.content:
                chunk = event.content
                # A chunk that is empty once the marker is stripped is
                # withheld (last one wins) rather than fed to the parser.
                if extract_session_id_marker(chunk) and not strip_session_id_suffix(
                    chunk
                ):
                    session_marker = chunk
                    continue
                for sse in parser.feed(chunk):
                    yield sse
            elif event.type == "finish":
                break
        # Emitted before parser.finish(); presumably that call produces the
        # closing chunk(s), after which no further content should follow.
        if session_marker:
            yield self._content_delta(chat_id, req.model, created, session_marker)
        for sse in parser.finish():
            yield sse

    def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]:
        """Map an exception to ``(status, OpenAI-style error body)``.

        ``ValueError`` (and subclasses) become 400 ``invalid_request_error``.
        Any other exception becomes 500 ``server_error``. The message is
        ``str(exc)``.
        """
        status = 400 if isinstance(exc, ValueError) else 500
        err_type = "invalid_request_error" if status == 400 else "server_error"
        return (
            status,
            {"error": {"message": str(exc), "type": err_type}},
        )

    @classmethod
    def _extract_thought(cls, text: str) -> str:
        """Return the stripped ReAct ``Thought:`` text, or ``""`` if absent."""
        match = cls._THOUGHT_RE.search(text)
        return match.group(1).strip() if match else ""

    @staticmethod
    def _message_to_raw_dict(msg: OpenAIMessage) -> dict[str, Any]:
        """Dump a validated message back to a plain dict for marker scanning.

        List content is dumped part by part. ``tool_calls`` and
        ``tool_call_id`` are included only when they are not ``None``.
        """
        if isinstance(msg.content, list):
            content: str | list[dict[str, Any]] = [p.model_dump() for p in msg.content]
        else:
            content = msg.content
        out: dict[str, Any] = {"role": msg.role, "content": content}
        if msg.tool_calls is not None:
            out["tool_calls"] = msg.tool_calls
        if msg.tool_call_id is not None:
            out["tool_call_id"] = msg.tool_call_id
        return out

    @staticmethod
    def _to_blocks(
        content: str | list[OpenAIContentPart] | None,
    ) -> list[CanonicalContentBlock]:
        """Convert OpenAI message content into canonical content blocks.

        ``None`` yields no blocks. Strings and ``text`` parts become text
        blocks with the session-id suffix stripped. ``image_url`` parts
        become image blocks: ``data:`` URLs go into ``data``, anything else
        into ``url``. Parts without a URL, and parts of any other type, are
        dropped.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [
                CanonicalContentBlock(
                    type="text", text=strip_session_id_suffix(content)
                )
            ]
        blocks: list[CanonicalContentBlock] = []
        for part in content:
            if part.type == "text":
                blocks.append(
                    CanonicalContentBlock(
                        type="text",
                        text=strip_session_id_suffix(part.text or ""),
                    )
                )
            elif part.type == "image_url":
                image_url = part.image_url
                # Accept both {"url": ...} objects and a bare URL value.
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                if not url:
                    continue
                if isinstance(url, str) and url.startswith("data:"):
                    blocks.append(CanonicalContentBlock(type="image", data=url))
                else:
                    blocks.append(CanonicalContentBlock(type="image", url=str(url)))
        return blocks

    @staticmethod
    def _to_tool_spec(tool: dict[str, Any]) -> CanonicalToolSpec:
        """Convert one tool definition into a ``CanonicalToolSpec``.

        Accepts two shapes:

        * the nested Chat Completions shape
          ``{"type": "function", "function": {...}}``;
        * a flat definition carrying ``name``/``description``/``parameters``
          (or ``input_schema``) directly on the tool object.

        Missing fields default to ``""``, ``{}`` and ``False``.
        """
        nested = tool.get("function")
        # Use the nested definition only when it is actually a dict. A
        # ``{"type": "function"}`` tool without one used to crash here with
        # AttributeError on None; it is now read as a flat definition.
        if tool.get("type") == "function" and isinstance(nested, dict):
            function = nested
        else:
            function = tool
        return CanonicalToolSpec(
            name=str(function.get("name") or ""),
            description=str(function.get("description") or ""),
            input_schema=function.get("parameters")
            or function.get("input_schema")
            or {},
            strict=bool(function.get("strict") or False),
        )

    @staticmethod
    def _content_delta(chat_id: str, model: str, created: int, text: str) -> str:
        """Build one ``chat.completion.chunk`` SSE line carrying ``text``."""
        return (
            "data: "
            + json.dumps(
                {
                    "id": chat_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                },
                ensure_ascii=False,
            )
            + "\n\n"
        )

    @staticmethod
    def _response_context(req: CanonicalChatRequest) -> tuple[str, int]:
        """Return ``(response_id, created)`` for ``req``, generating them once.

        Both values are stored in ``req.metadata`` via ``setdefault``. Repeated
        calls for the same request therefore return the same id and timestamp.
        """
        chat_id = str(
            req.metadata.setdefault(
                "response_id", f"chatcmpl-{uuid_mod.uuid4().hex[:24]}"
            )
        )
        created = int(req.metadata.setdefault("created", int(time.time())))
        return chat_id, created
|
|