File size: 8,933 Bytes
77169b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
"""OpenAI 协议适配器。"""

from __future__ import annotations

import json
import re
import time
import uuid as uuid_mod
from collections.abc import AsyncIterator
from typing import Any

from core.api.conv_parser import (
    extract_session_id_marker,
    parse_conv_uuid_from_messages,
    strip_session_id_suffix,
)
from core.api.function_call import build_tool_calls_response
from core.api.react import (
    format_react_final_answer_content,
    parse_react_output,
    react_output_to_tool_calls,
)
from core.api.react_stream_parser import ReactStreamParser
from core.api.schemas import OpenAIChatRequest, OpenAIContentPart, OpenAIMessage
from core.hub.schemas import OpenAIStreamEvent
from core.protocol.base import ProtocolAdapter
from core.protocol.schemas import (
    CanonicalChatRequest,
    CanonicalContentBlock,
    CanonicalMessage,
    CanonicalToolSpec,
)


class OpenAIProtocolAdapter(ProtocolAdapter):
    """Adapter between the OpenAI chat-completions wire protocol and the
    canonical internal chat model.

    Responsibilities:

    - ``parse_request``: validate an OpenAI request body and convert it to
      a ``CanonicalChatRequest`` (system messages folded out, tools
      normalized, resume-session UUID extracted).
    - ``render_non_stream`` / ``render_stream``: turn internal
      ``OpenAIStreamEvent``s back into an OpenAI-shaped JSON response or a
      stream of SSE chunks.  When the request declared tools, assistant
      output is parsed as ReAct text and surfaced as OpenAI ``tool_calls``.
    - ``render_error``: map exceptions onto OpenAI-style error payloads.
    """

    protocol_name = "openai"

    def parse_request(
        self,
        provider: str,
        raw_body: dict[str, Any],
    ) -> CanonicalChatRequest:
        """Validate ``raw_body`` and convert it to a canonical request.

        Args:
            provider: Upstream provider name to stamp on the request.
            raw_body: The raw OpenAI chat-completions request body.

        Returns:
            A ``CanonicalChatRequest`` with system content separated from
            the conversation messages.

        Raises:
            pydantic.ValidationError: If ``raw_body`` is not a valid
                OpenAI chat request.
        """
        req = OpenAIChatRequest.model_validate(raw_body)
        # A conversation UUID may be embedded in earlier message text;
        # extract it so the hub can resume the existing session.
        resume_session_id = parse_conv_uuid_from_messages(
            [self._message_to_raw_dict(m) for m in req.messages]
        )
        system_blocks: list[CanonicalContentBlock] = []
        messages: list[CanonicalMessage] = []
        for msg in req.messages:
            blocks = self._to_blocks(msg.content)
            if msg.role == "system":
                # OpenAI permits multiple system messages; merge them all
                # into the single canonical system block list.
                system_blocks.extend(blocks)
            else:
                messages.append(CanonicalMessage(role=msg.role, content=blocks))
        tools = [self._to_tool_spec(tool) for tool in (req.tools or [])]
        return CanonicalChatRequest(
            protocol="openai",
            provider=provider,
            model=req.model,
            system=system_blocks,
            messages=messages,
            stream=req.stream,
            tools=tools,
            tool_choice=req.tool_choice,
            resume_session_id=resume_session_id,
        )

    def render_non_stream(
        self,
        req: CanonicalChatRequest,
        raw_events: list[OpenAIStreamEvent],
    ) -> dict[str, Any]:
        """Render buffered stream events as one OpenAI completion response.

        Concatenates all ``content_delta`` events, strips the trailing
        session-id marker, and — when the request declared tools — tries to
        parse the reply as ReAct output.  A successful ReAct parse yields a
        ``tool_calls`` response; otherwise a plain text completion is
        returned (with the session marker re-appended so the client can
        echo it back on the next turn).
        """
        reply = "".join(
            ev.content or ""
            for ev in raw_events
            if ev.type == "content_delta" and ev.content
        )
        session_marker = extract_session_id_marker(reply)
        content_for_parse = strip_session_id_suffix(reply)
        chat_id, created = self._response_context(req)
        if req.tools:
            parsed = parse_react_output(content_for_parse)
            tool_calls_list = react_output_to_tool_calls(parsed) if parsed else []
            if tool_calls_list:
                # Preserve the model's reasoning ("Thought: ...") as a
                # <think> prefix on the tool-call message, if present.
                # The character class matches both ASCII and fullwidth colons.
                thought_ns = ""
                if "Thought" in content_for_parse:
                    match = re.search(
                        r"Thought[::]\s*(.+?)(?=\s*Action[::]|$)",
                        content_for_parse,
                        re.DOTALL | re.I,
                    )
                    thought_ns = (match.group(1) or "").strip() if match else ""
                text_content = (
                    f"<think>{thought_ns}</think>\n{session_marker}".strip()
                    if thought_ns
                    else session_marker
                )
                return build_tool_calls_response(
                    tool_calls_list,
                    chat_id,
                    req.model,
                    created,
                    text_content=text_content,
                )
            # No tool call parsed: surface the ReAct final answer as text.
            content_reply = format_react_final_answer_content(content_for_parse)
            if session_marker:
                content_reply += session_marker
        else:
            content_reply = content_for_parse
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": created,
            "model": req.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content_reply},
                    "finish_reason": "stop",
                }
            ],
        }

    async def render_stream(
        self,
        req: CanonicalChatRequest,
        raw_stream: AsyncIterator[OpenAIStreamEvent],
    ) -> AsyncIterator[str]:
        """Render a live event stream as OpenAI SSE chunks.

        Content deltas are fed through a ``ReactStreamParser`` (which emits
        the appropriate SSE frames, including incremental tool-call frames
        when tools are declared).  A chunk that consists purely of the
        session-id marker is held back and emitted as the last content
        delta, so the marker always lands at the end of the reply.
        """
        chat_id, created = self._response_context(req)
        parser = ReactStreamParser(
            chat_id=chat_id,
            model=req.model,
            created=created,
            has_tools=bool(req.tools),
        )
        session_marker = ""
        async for event in raw_stream:
            if event.type == "content_delta" and event.content:
                chunk = event.content
                # A marker-only chunk (nothing left after stripping the
                # marker) is deferred to the end of the stream.
                if extract_session_id_marker(chunk) and not strip_session_id_suffix(
                    chunk
                ):
                    session_marker = chunk
                    continue
                for sse in parser.feed(chunk):
                    yield sse
            elif event.type == "finish":
                break
        if session_marker:
            yield self._content_delta(chat_id, req.model, created, session_marker)
        for sse in parser.finish():
            yield sse

    def render_error(self, exc: Exception) -> tuple[int, dict[str, Any]]:
        """Map an exception onto ``(http_status, OpenAI error body)``.

        ``ValueError`` is treated as a client error (400); everything else
        is a server error (500).
        """
        status = 400 if isinstance(exc, ValueError) else 500
        err_type = "invalid_request_error" if status == 400 else "server_error"
        return (
            status,
            {"error": {"message": str(exc), "type": err_type}},
        )

    @staticmethod
    def _message_to_raw_dict(msg: OpenAIMessage) -> dict[str, Any]:
        """Dump an ``OpenAIMessage`` back to a plain dict.

        Used so ``parse_conv_uuid_from_messages`` can scan raw message
        payloads.  ``tool_calls`` / ``tool_call_id`` are only included when
        present, mirroring the original wire format.
        """
        if isinstance(msg.content, list):
            content: str | list[dict[str, Any]] = [p.model_dump() for p in msg.content]
        else:
            content = msg.content
        out: dict[str, Any] = {"role": msg.role, "content": content}
        if msg.tool_calls is not None:
            out["tool_calls"] = msg.tool_calls
        if msg.tool_call_id is not None:
            out["tool_call_id"] = msg.tool_call_id
        return out

    @staticmethod
    def _to_blocks(
        content: str | list[OpenAIContentPart] | None,
    ) -> list[CanonicalContentBlock]:
        """Convert OpenAI message content to canonical content blocks.

        Session-id suffixes are stripped from all text so internal markers
        never leak into the upstream conversation.  Image parts become
        ``image`` blocks: ``data:`` URLs are stored inline via ``data``,
        other URLs via ``url``.  Unknown part types are silently dropped.
        """
        if content is None:
            return []
        if isinstance(content, str):
            return [
                CanonicalContentBlock(
                    type="text", text=strip_session_id_suffix(content)
                )
            ]
        blocks: list[CanonicalContentBlock] = []
        for part in content:
            if part.type == "text":
                blocks.append(
                    CanonicalContentBlock(
                        type="text",
                        text=strip_session_id_suffix(part.text or ""),
                    )
                )
            elif part.type == "image_url":
                # image_url may be either {"url": ...} or a bare string.
                image_url = part.image_url
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                if not url:
                    continue
                if isinstance(url, str) and url.startswith("data:"):
                    blocks.append(CanonicalContentBlock(type="image", data=url))
                else:
                    blocks.append(CanonicalContentBlock(type="image", url=str(url)))
        return blocks

    @staticmethod
    def _to_tool_spec(tool: dict[str, Any]) -> CanonicalToolSpec:
        """Normalize an OpenAI tool declaration to a ``CanonicalToolSpec``.

        Accepts both the nested ``{"type": "function", "function": {...}}``
        form and a bare function dict.  Accepts ``parameters`` (OpenAI) or
        ``input_schema`` (Anthropic-style) for the schema.
        """
        function = tool.get("function") if tool.get("type") == "function" else tool
        if not isinstance(function, dict):
            # Malformed declaration (e.g. type=="function" with no
            # "function" key): fall back to an empty spec rather than
            # raising AttributeError on None.
            function = {}
        return CanonicalToolSpec(
            name=str(function.get("name") or ""),
            description=str(function.get("description") or ""),
            input_schema=function.get("parameters")
            or function.get("input_schema")
            or {},
            strict=bool(function.get("strict") or False),
        )

    @staticmethod
    def _content_delta(chat_id: str, model: str, created: int, text: str) -> str:
        """Build a single OpenAI ``chat.completion.chunk`` SSE frame."""
        return (
            "data: "
            + json.dumps(
                {
                    "id": chat_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                },
                ensure_ascii=False,
            )
            + "\n\n"
        )

    @staticmethod
    def _response_context(req: CanonicalChatRequest) -> tuple[str, int]:
        """Return a stable ``(chat_id, created)`` pair for this request.

        Values are memoized in ``req.metadata`` via ``setdefault`` so the
        same id/timestamp is reused across every chunk of one response.
        """
        chat_id = str(
            req.metadata.setdefault(
                "response_id", f"chatcmpl-{uuid_mod.uuid4().hex[:24]}"
            )
        )
        created = int(req.metadata.setdefault("created", int(time.time())))
        return chat_id, created