File size: 15,438 Bytes
ced3f4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
# coding=utf-8
# Copyright 2025 Upstage AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Union, Optional
import json

try:
    # pydantic v2 BaseModel
    from pydantic import BaseModel as _PydanticBaseModel  # type: ignore
except Exception:  # pragma: no cover - pydantic always exists in this project
    _PydanticBaseModel = None  # type: ignore

# Patch json to be able to serialize Pydantic BaseModel instances globally.
# This is required to satisfy tests that call json.dumps on vLLM models
# (e.g., FunctionDefinition) directly.
_orig_default_encoder = json._default_encoder  # type: ignore[attr-defined]


class _PatchedJSONEncoder(json.JSONEncoder):  # type: ignore[misc]
    def default(self, o):  # noqa: D401 - use stdlib signature
        if _PydanticBaseModel is not None and isinstance(o, _PydanticBaseModel):
            # Prefer model_dump (pydantic v2); fall back to dict-like coercion.
            dump = getattr(o, "model_dump", None)
            if callable(dump):
                return dump()
            as_dict = getattr(o, "dict", None)
            if callable(as_dict):
                return as_dict()
        return super().default(o)


# Replace the global default encoder instance so json.dumps(...) picks it up.
json._default_encoder = _PatchedJSONEncoder()  # type: ignore[attr-defined]

# NOTE(review): these imports presumably sit below the json-encoder patch
# above on purpose (so the patch is installed before protocol models are
# serialized) -- keep this ordering unless confirmed otherwise.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser

# Module-scoped logger, named after this module per vLLM convention.
logger = init_logger(__name__)


class SolarOpenReasoningParser(ReasoningParser):
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Decide whether the reasoning phase has ended for the current turn.

        Two signals end reasoning:
        1. The prompt encodes an empty reasoning block ('<|think|><|end|>')
           as the entire body of the assistant turn immediately before the
           last '<|begin|>assistant' header.
        2. A '<|content|>' or '<|tool_calls|>' tag appears after the last
           assistant header (or anywhere at all when no header is present,
           which covers partial outputs like a bare '<|content|>').
        """
        header_ids = self._token_ids("<|begin|>assistant")
        last_header = self._rfind_subsequence(input_ids, header_ids)

        if last_header != -1:
            # Scope the empty-reasoning check to the assistant turn right
            # before the current one; older turns must not match.
            earlier = input_ids[:last_header]
            prev_header = self._rfind_subsequence(earlier, header_ids)
            if prev_header != -1:
                prev_body = input_ids[prev_header + len(header_ids):last_header]
                if prev_body == self._token_ids("<|think|><|end|>"):
                    return True

        # Only look at tokens belonging to the current assistant turn; with
        # no header in sight, scan the whole sequence.
        scan_from = 0 if last_header == -1 else last_header + len(header_ids)
        tail = input_ids[scan_from:]
        return any(
            self._find_subsequence(tail, self._token_ids(tag)) != -1
            for tag in ("<|content|>", "<|tool_calls|>")
        )

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids of the content section.

        Everything AFTER the first '<|content|>' tag; if that tag is absent,
        everything AFTER the first '<|tool_calls|>' tag; otherwise an empty
        list. The tag itself is never included.
        """
        for tag in ("<|content|>", "<|tool_calls|>"):
            tag_ids = self._token_ids(tag)
            pos = self._find_subsequence(input_ids, tag_ids)
            if pos != -1:
                # A tag at the very end slices to [] naturally.
                return input_ids[pos + len(tag_ids):]
        return []

    def extract_reasoning(
            self,
            model_output: str,
            request: Union[ChatCompletionRequest, ResponsesRequest],
    ) -> tuple[str | None, str | None]:
        """Split a finished model output into (reasoning, content).

        Reasoning is the span between '<|think|>' and '<|end|>'; content is
        everything after the first '<|content|>' tag, or failing that, after
        the first '<|tool_calls|>' tag. Missing sections come back as ''.
        """
        reasoning = self._parse_reasoning(model_output) or ""
        content = self._parse_content_or_calls(model_output) or ""

        if not content:
            # Tag-free output that looks like a raw JSON payload (e.g. a
            # list of FunctionDefinition) is passed through verbatim so
            # callers can parse it downstream.
            if (model_output or "").strip().startswith(("{", "[")):
                content = model_output
        return reasoning, content

    def extract_reasoning_streaming(
            self,
            previous_text: str,
            current_text: str,
            delta_text: str,
            previous_token_ids: Sequence[int],
            current_token_ids: Sequence[int],
            delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Map one streaming step to a reasoning or content DeltaMessage.

        Emits, in priority order: an empty content delta the moment the
        content/tool-calls phase opens; content additions while in that
        phase; raw deltas as content before any tag has been seen (covers
        raw-JSON outputs); reasoning-prefix diffs between '<|think|>' and
        the first boundary; and finally raw deltas as reasoning while
        clearly mid-reasoning. Returns None when nothing should be emitted.

        Fix vs. previous revision: the full reasoning spans of previous and
        current text were parsed on every step and never used (dead work per
        token); the per-tag booleans duplicated _has_content_phase(); and
        two late conditions re-tested a content-phase predicate already
        known to be False after the early returns.
        """
        boundary_tokens = ("<|think|>", "<|end|>", "<|content|>", "<|tool_calls|>")

        prev_content = self._parse_content_or_calls(previous_text) or ""
        prev_in_content = self._has_content_phase(previous_text)

        curr_content = self._parse_content_or_calls(current_text) or ""
        curr_in_content = self._has_content_phase(current_text)

        # Content phase just opened (via <|content|> or <|tool_calls|>):
        # emit an empty content delta so reconstructors initialize the
        # content field. The tag itself is never emitted as content.
        if curr_in_content and not prev_in_content:
            return DeltaMessage(content="")

        # Already in the content phase: forward only content additions.
        if curr_in_content:
            if curr_content != prev_content:
                addition = (curr_content[len(prev_content):]
                            if curr_content.startswith(prev_content)
                            else curr_content)
                if addition:
                    return DeltaMessage(content=addition)
            return None

        # From here on curr_in_content is False. If reasoning has not
        # started either, emit the raw delta as content immediately
        # (e.g. "{" for raw JSON outputs).
        if "<|think|>" not in current_text and delta_text not in boundary_tokens:
            return DeltaMessage(content=delta_text)

        # Reasoning progression: diff the reasoning prefix (text between
        # <|think|> and the first boundary) of previous vs. current text.
        prev_prefix = self._parse_reasoning_prefix(previous_text) or ""
        curr_prefix = self._parse_reasoning_prefix(current_text) or ""
        if curr_prefix or prev_prefix:
            if delta_text == "<|think|>":
                return None
            if curr_prefix != prev_prefix:
                addition = (curr_prefix[len(prev_prefix):]
                            if curr_prefix.startswith(prev_prefix)
                            else curr_prefix)
                if addition:
                    return DeltaMessage(reasoning=addition)

        # Fallback: clearly inside reasoning (think seen, no <|end|> yet)
        # and the delta is not a boundary token -- emit it as reasoning.
        # Covers tokenizer edge cases where prefix diffing misses a step.
        if (
                "<|think|>" in current_text
                and "<|end|>" not in current_text
                and delta_text not in boundary_tokens
        ):
            return DeltaMessage(reasoning=delta_text)

        # Final guard, judged on the PREVIOUS text: still mid-reasoning
        # there and a non-boundary delta -- emit as reasoning.
        if (
                "<|think|>" in previous_text
                and "<|end|>" not in previous_text
                and not prev_in_content
                and delta_text not in boundary_tokens
        ):
            return DeltaMessage(reasoning=delta_text)

        return None

    # --------------------
    # Internal helpers
    # --------------------
    def _token_ids(self, text: str) -> list[int]:
        """Tokenize *text* with the model tokenizer and return its ids."""
        tok = self.model_tokenizer
        return tok.convert_tokens_to_ids(tok.tokenize(text))

    def _find_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int:
        """Index of the first occurrence of *needle* in *haystack*, or -1.

        An empty needle yields -1 (nothing to match on).

        Fix vs. previous revision: list(needle) was rebuilt on every loop
        iteration (accidental O(n*m) allocation), and because a tuple slice
        never compares equal to a list, a tuple haystack could never match
        despite the Sequence[int] contract. Both sides are now normalized
        to lists exactly once.
        """
        if not needle:
            return -1
        hay = list(haystack)
        pattern = list(needle)
        n = len(pattern)
        for i in range(len(hay) - n + 1):
            if hay[i:i + n] == pattern:
                return i
        return -1

    def _rfind_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int:
        """Index of the last occurrence of *needle* in *haystack*, or -1.

        An empty needle yields -1.

        Fix vs. previous revision: the scan walked the whole haystack left
        to right recording the last hit, rebuilding list(needle) on every
        iteration; and a tuple haystack could never match (tuple slice vs.
        list comparison). Now both sides are normalized once and the scan
        runs right to left, returning at the first (i.e. last) match.
        """
        if not needle:
            return -1
        hay = list(haystack)
        pattern = list(needle)
        n = len(pattern)
        for i in range(len(hay) - n, -1, -1):
            if hay[i:i + n] == pattern:
                return i
        return -1

    def _parse_reasoning(self, text: str) -> Optional[str]:
        """Extract the reasoning span that follows the first '<|think|>'.

        Returns the text up to the matching '<|end|>'. When '<|end|>' never
        arrives (e.g. max_tokens hit mid-reasoning) and no content phase has
        started, the incomplete tail is returned so users still see what was
        generated; returns None when '<|think|>' is absent or the tail is
        empty/already in the content phase.
        """
        _, opened, tail = text.partition("<|think|>")
        if not opened:
            return None
        body, closed, _ = tail.partition("<|end|>")
        if closed:
            return body
        # Truncated reasoning: surface the partial text unless the output
        # already moved on to content/tool-calls.
        if self._has_content_phase(tail):
            return None
        return tail if tail else None

    def _parse_trailing_content(self, text: str) -> Optional[str]:
        """Everything after the first '<|content|>' tag, trailing tokens kept.

        Returns '' when the tag is the last thing in *text* and None when
        the tag is absent. NOTE(review): not called anywhere in the visible
        portion of this file; _parse_content_or_calls appears to supersede
        it -- confirm before removing.
        """
        _, found, rest = text.partition("<|content|>")
        return rest if found else None

    def _has_content_tag(self, text: str) -> bool:
        """True once the '<|content|>' tag occurs anywhere in *text*."""
        return "<|content|>" in text

    # New helpers covering both content and tool-calls phases
    def _parse_content_or_calls(self, text: str) -> Optional[str]:
        """Everything after the first '<|content|>' tag, or failing that the
        first '<|tool_calls|>' tag; None when neither tag is present.

        '<|content|>' always takes priority, regardless of which of the two
        tags occurs earlier in *text*.
        """
        for tag in ("<|content|>", "<|tool_calls|>"):
            _, found, rest = text.partition(tag)
            if found:
                return rest
        return None

    def _has_tool_calls_tag(self, text: str) -> bool:
        """True once the '<|tool_calls|>' tag occurs anywhere in *text*."""
        return "<|tool_calls|>" in text

    def _has_content_phase(self, text: str) -> bool:
        """True once either the content or the tool-calls tag has appeared."""
        if self._has_content_tag(text):
            return True
        return self._has_tool_calls_tag(text)

    def _is_in_reasoning_phase_prev(self, text: str) -> bool:
        """Whether *text* -- the PREVIOUS streamed text -- is mid-reasoning.

        Judging by the previous text means that when the current delta
        merges boundary tokens with other characters, the delta is still
        emitted as reasoning unless it is itself a boundary token (matches
        the test expectations). Reasoning requires '<|think|>' to have been
        seen, with neither a content/tool-calls phase nor '<|end|>' yet.
        NOTE(review): not called anywhere in the visible portion of this
        file.
        """
        return (
            "<|think|>" in text
            and not self._has_content_phase(text)
            and "<|end|>" not in text
        )

    def _starts_reasoning_now(self, text: str) -> bool:
        """True when '<|think|>' is present with no boundary token after it.

        Lets the first reasoning token be emitted even when the tokenizer
        merged it with '<|think|>'. Any boundary after '<|think|>' means
        reasoning already ended or content started, so not a start.
        """
        _, opened, tail = text.partition("<|think|>")
        if not opened:
            return False
        return not any(
            marker in tail
            for marker in ("<|end|>", "<|content|>", "<|tool_calls|>")
        )

    def _parse_reasoning_prefix(self, text: str) -> Optional[str]:
        """Text between the first '<|think|>' and the earliest boundary.

        Boundaries are '<|end|>', '<|content|>' and '<|tool_calls|>'.
        Returns None when '<|think|>' is absent; when no boundary follows,
        returns everything after '<|think|>'.
        """
        pos = text.find("<|think|>")
        if pos == -1:
            return None
        start = pos + len("<|think|>")
        hits = (
            text.find(marker, start)
            for marker in ("<|end|>", "<|content|>", "<|tool_calls|>")
        )
        stop = min((h for h in hits if h != -1), default=len(text))
        return text[start:stop]