# web2api/core/protocol/openai.py
# (upstream commit 77169b4: "feat: align hosted Space deployment with latest upstream")
"""OpenAI 协议适配器。"""
from __future__ import annotations
import json
import re
import time
import uuid as uuid_mod
from collections.abc import AsyncIterator
from typing import Any
from core.api.conv_parser import (
extract_session_id_marker,
parse_conv_uuid_from_messages,
strip_session_id_suffix,
)
from core.api.function_call import build_tool_calls_response
from core.api.react import (
format_react_final_answer_content,
parse_react_output,
react_output_to_tool_calls,
)
from core.api.react_stream_parser import ReactStreamParser
from core.api.schemas import OpenAIChatRequest, OpenAIContentPart, OpenAIMessage
from core.hub.schemas import OpenAIStreamEvent
from core.protocol.base import ProtocolAdapter
from core.protocol.schemas import (
CanonicalChatRequest,
CanonicalContentBlock,
CanonicalMessage,
CanonicalToolSpec,
)
class OpenAIProtocolAdapter(ProtocolAdapter):
    """Adapter between the OpenAI chat-completions wire format and the
    project's canonical request model.

    Responsibilities visible in this module:

    * ``parse_request`` — validate an OpenAI-style request body and convert
      it into a :class:`CanonicalChatRequest`, hoisting system turns and
      recovering an embedded session id from the echoed messages.
    * ``render_non_stream`` / ``render_stream`` — turn upstream
      ``OpenAIStreamEvent`` items back into an OpenAI-compatible JSON body
      or SSE ``data:`` chunks, including ReAct tool-call extraction and
      session-id marker handling.
    * ``render_error`` — map exceptions to OpenAI-style error payloads.
    """

    # Protocol identifier used by the dispatch layer (see ProtocolAdapter).
    protocol_name = "openai"

    def parse_request(
        self,
        provider: str,
        raw_body: dict[str, Any],
    ) -> CanonicalChatRequest:
        """Validate *raw_body* as an OpenAI chat request and canonicalize it.

        Args:
            provider: Upstream provider name recorded on the canonical request.
            raw_body: Raw JSON body of the incoming chat-completions call.

        Returns:
            A ``CanonicalChatRequest`` with system content separated from the
            conversation turns and tools normalized to ``CanonicalToolSpec``.

        Raises:
            pydantic.ValidationError: If *raw_body* is not a valid
                ``OpenAIChatRequest``.
        """
        req = OpenAIChatRequest.model_validate(raw_body)
        # A previous reply may have embedded a session-id marker in the
        # assistant text; scan the echoed messages so the upstream session
        # can be resumed instead of starting a new one.
        resume_session_id = parse_conv_uuid_from_messages(
            [self._message_to_raw_dict(m) for m in req.messages]
        )
        system_blocks: list[CanonicalContentBlock] = []
        messages: list[CanonicalMessage] = []
        for msg in req.messages:
            blocks = self._to_blocks(msg.content)
            # System turns are hoisted into the request-level system prompt;
            # every other role stays in the ordered message list.
            if msg.role == "system":
                system_blocks.extend(blocks)
            else:
                messages.append(CanonicalMessage(role=msg.role, content=blocks))
        tools = [self._to_tool_spec(tool) for tool in list(req.tools or [])]
        return CanonicalChatRequest(
            protocol="openai",
            provider=provider,
            model=req.model,
            system=system_blocks,
            messages=messages,
            stream=req.stream,
            tools=tools,
            tool_choice=req.tool_choice,
            resume_session_id=resume_session_id,
        )

    def render_non_stream(
        self,
        req: CanonicalChatRequest,
        raw_events: list[OpenAIStreamEvent],
    ) -> dict[str, Any]:
        """Assemble a complete ``chat.completion`` JSON object.

        Joins all content deltas from *raw_events*, sets the session-id
        marker aside, and — when the request declared tools — tries to
        interpret the text as ReAct output, emitting ``tool_calls`` instead
        of plain content when an action was parsed.
        """
        # Concatenate every content delta into the full reply text.
        reply = "".join(
            ev.content or ""
            for ev in raw_events
            if ev.type == "content_delta" and ev.content
        )
        # The reply may carry an opaque session-id suffix; keep the marker
        # aside and parse the cleaned text.
        session_marker = extract_session_id_marker(reply)
        content_for_parse = strip_session_id_suffix(reply)
        chat_id, created = self._response_context(req)
        if req.tools:
            parsed = parse_react_output(content_for_parse)
            tool_calls_list = react_output_to_tool_calls(parsed) if parsed else []
            if tool_calls_list:
                # Surface the model's "Thought" (if present) as <think> text
                # so clients can show reasoning alongside the tool calls.
                thought_ns = ""
                if "Thought" in content_for_parse:
                    # Capture the thought between the "Thought" label and the
                    # "Action" label (the character class accepts both colon
                    # variants emitted by upstream output).
                    match = re.search(
                        r"Thought[：:]\s*(.+?)(?=\s*Action[：:]|$)",
                        content_for_parse,
                        re.DOTALL | re.I,
                    )
                    thought_ns = (match.group(1) or "").strip() if match else ""
                # Keep the session marker in the text so the client echoes it
                # back on the next turn.
                text_content = (
                    f"<think>{thought_ns}</think>\n{session_marker}".strip()
                    if thought_ns
                    else session_marker
                )
                return build_tool_calls_response(
                    tool_calls_list,
                    chat_id,
                    req.model,
                    created,
                    text_content=text_content,
                )
            # Tools were declared but no action was parsed: present the text
            # as a ReAct final answer and re-append the session marker.
            content_reply = format_react_final_answer_content(content_for_parse)
            if session_marker:
                content_reply += session_marker
        else:
            # NOTE(review): in this plain-text path the session marker is
            # stripped and never re-appended, unlike the tools branches —
            # confirm that dropping it here is intentional.
            content_reply = content_for_parse
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": created,
            "model": req.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content_reply},
                    "finish_reason": "stop",
                }
            ],
        }

    async def render_stream(
        self,
        req: CanonicalChatRequest,
        raw_stream: AsyncIterator[OpenAIStreamEvent],
    ) -> AsyncIterator[str]:
        """Yield OpenAI-compatible SSE strings for *raw_stream*.

        Content deltas are fed through ``ReactStreamParser`` (which handles
        incremental tool-call detection when the request declared tools).
        A chunk consisting solely of the session-id marker is withheld from
        the parser and re-emitted as a final content delta instead.
        """
        chat_id, created = self._response_context(req)
        parser = ReactStreamParser(
            chat_id=chat_id,
            model=req.model,
            created=created,
            has_tools=bool(req.tools),
        )
        session_marker = ""
        async for event in raw_stream:
            if event.type == "content_delta" and event.content:
                chunk = event.content
                # A chunk that is exactly the session marker (nothing left
                # after stripping) is held back until the stream ends so the
                # parser never sees it.
                if extract_session_id_marker(chunk) and not strip_session_id_suffix(
                    chunk
                ):
                    session_marker = chunk
                    continue
                for sse in parser.feed(chunk):
                    yield sse
            elif event.type == "finish":
                # Upstream signalled completion; stop consuming.
                break
        if session_marker:
            # Re-emit the withheld marker as an ordinary content delta so the
            # client can echo it back on the next turn.
            yield self._content_delta(chat_id, req.model, created, session_marker)
        # Flush anything the parser buffered plus its terminal chunks.
        for sse in parser.finish():
            yield sse

    def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]:
        """Map *exc* to an ``(HTTP status, OpenAI-style error body)`` pair.

        ``ValueError`` is treated as a client error (400,
        ``invalid_request_error``); anything else becomes a 500
        ``server_error``.
        """
        status = 400 if isinstance(exc, ValueError) else 500
        err_type = "invalid_request_error" if status == 400 else "server_error"
        return (
            status,
            {"error": {"message": str(exc), "type": err_type}},
        )

    @staticmethod
    def _message_to_raw_dict(msg: OpenAIMessage) -> dict[str, Any]:
        """Convert an ``OpenAIMessage`` model back to a plain dict.

        Used to feed ``parse_conv_uuid_from_messages``, which expects raw
        OpenAI-shaped dicts. Structured content parts are dumped to dicts;
        ``tool_calls`` / ``tool_call_id`` are carried over only when set.
        """
        if isinstance(msg.content, list):
            content: str | list[dict[str, Any]] = [p.model_dump() for p in msg.content]
        else:
            content = msg.content
        out: dict[str, Any] = {"role": msg.role, "content": content}
        if msg.tool_calls is not None:
            out["tool_calls"] = msg.tool_calls
        if msg.tool_call_id is not None:
            out["tool_call_id"] = msg.tool_call_id
        return out

    @staticmethod
    def _to_blocks(
        content: str | list[OpenAIContentPart] | None,
    ) -> list[CanonicalContentBlock]:
        """Normalize OpenAI message content into canonical content blocks.

        * ``None`` → empty list.
        * Plain string → a single text block (session-id suffix stripped).
        * Part list → text parts become text blocks; ``image_url`` parts
          become image blocks — ``data:`` URLs are stored inline via
          ``data``, any other URL via ``url``. Unknown part types are
          silently dropped.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [
                CanonicalContentBlock(
                    type="text", text=strip_session_id_suffix(content)
                )
            ]
        blocks: list[CanonicalContentBlock] = []
        for part in content:
            if part.type == "text":
                blocks.append(
                    CanonicalContentBlock(
                        type="text",
                        text=strip_session_id_suffix(part.text or ""),
                    )
                )
            elif part.type == "image_url":
                # OpenAI allows either {"url": ...} or a bare URL value here.
                image_url = part.image_url
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                if not url:
                    continue
                if isinstance(url, str) and url.startswith("data:"):
                    # Inline base64 data URL — keep the payload itself.
                    blocks.append(CanonicalContentBlock(type="image", data=url))
                else:
                    blocks.append(CanonicalContentBlock(type="image", url=str(url)))
        return blocks

    @staticmethod
    def _to_tool_spec(tool: dict[str, Any]) -> CanonicalToolSpec:
        """Normalize an OpenAI tool declaration into a ``CanonicalToolSpec``.

        Accepts both the wrapped form ``{"type": "function", "function":
        {...}}`` and a bare function dict; missing fields fall back to
        empty-string / empty-schema defaults.
        """
        function = tool.get("function") if tool.get("type") == "function" else tool
        return CanonicalToolSpec(
            name=str(function.get("name") or ""),
            description=str(function.get("description") or ""),
            input_schema=function.get("parameters")
            or function.get("input_schema")
            or {},
            strict=bool(function.get("strict") or False),
        )

    @staticmethod
    def _content_delta(chat_id: str, model: str, created: int, text: str) -> str:
        """Build one OpenAI ``chat.completion.chunk`` SSE frame for *text*.

        Returns the full ``data: {json}\\n\\n`` string; ``ensure_ascii=False``
        keeps non-ASCII content readable in the wire payload.
        """
        return (
            "data: "
            + json.dumps(
                {
                    "id": chat_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                },
                ensure_ascii=False,
            )
            + "\n\n"
        )

    @staticmethod
    def _response_context(req: CanonicalChatRequest) -> tuple[str, int]:
        """Return the ``(response id, created timestamp)`` for *req*.

        Both values are cached on ``req.metadata`` via ``setdefault`` so
        repeated calls for the same request (e.g. across rendering steps)
        yield an identical id and timestamp.
        """
        chat_id = str(
            req.metadata.setdefault(
                "response_id", f"chatcmpl-{uuid_mod.uuid4().hex[:24]}"
            )
        )
        created = int(req.metadata.setdefault("created", int(time.time())))
        return chat_id, created