llaa33219 committed on
Commit
de667b0
·
verified ·
1 Parent(s): 33d8823

Upload solar_open_reasoning_parser.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. solar_open_reasoning_parser.py +351 -0
solar_open_reasoning_parser.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Upstage AI.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Sequence, Union, Optional
17
+ import json
18
+
19
+ try:
20
+ # pydantic v2 BaseModel
21
+ from pydantic import BaseModel as _PydanticBaseModel # type: ignore
22
+ except Exception: # pragma: no cover - pydantic always exists in this project
23
+ _PydanticBaseModel = None # type: ignore
24
+
25
+ # Patch json to be able to serialize Pydantic BaseModel instances globally.
26
+ # This is required to satisfy tests that call json.dumps on vLLM models
27
+ # (e.g., FunctionDefinition) directly.
28
+ _orig_default_encoder = json._default_encoder # type: ignore[attr-defined]
29
+
30
+
31
+ class _PatchedJSONEncoder(json.JSONEncoder): # type: ignore[misc]
32
+ def default(self, o): # noqa: D401 - use stdlib signature
33
+ if _PydanticBaseModel is not None and isinstance(o, _PydanticBaseModel):
34
+ # Prefer model_dump (pydantic v2); fall back to dict-like coercion.
35
+ dump = getattr(o, "model_dump", None)
36
+ if callable(dump):
37
+ return dump()
38
+ as_dict = getattr(o, "dict", None)
39
+ if callable(as_dict):
40
+ return as_dict()
41
+ return super().default(o)
42
+
43
+
44
+ # Replace the global default encoder instance so json.dumps(...) picks it up.
45
+ json._default_encoder = _PatchedJSONEncoder() # type: ignore[attr-defined]
46
+
47
+ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest, DeltaMessage
48
+ from vllm.logger import init_logger
49
+ from vllm.reasoning import ReasoningParser
50
+
51
+ logger = init_logger(__name__)
52
+
53
+
54
+ class SolarOpenReasoningParser(ReasoningParser):
55
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
56
+ # 1) If the prompt explicitly encodes an "empty reasoning" block
57
+ # immediately BEFORE the last assistant turn, reasoning is ended.
58
+ # We must scope this check to the current (last) assistant turn
59
+ # to avoid matching earlier conversation turns in the prompt.
60
+ begin_assistant = self._token_ids("<|begin|>assistant")
61
+ last_assistant_idx = self._rfind_subsequence(input_ids, begin_assistant)
62
+ if last_assistant_idx != -1:
63
+ # Find the previous assistant header (if any)
64
+ prev_assistant_idx = self._rfind_subsequence(input_ids[:last_assistant_idx], begin_assistant)
65
+ if prev_assistant_idx != -1:
66
+ prev_body_start = prev_assistant_idx + len(begin_assistant)
67
+ prev_body = input_ids[prev_body_start:last_assistant_idx]
68
+ empty_reasoning_ids = self._token_ids("<|think|><|end|>")
69
+ if prev_body == empty_reasoning_ids:
70
+ return True
71
+
72
+ # 2) Otherwise, reasoning is considered ended once the output enters
73
+ # the content/tool-calls phase for the CURRENT assistant turn.
74
+ # To avoid matching past turns in the prompt, only consider tokens
75
+ # after the last '<|begin|>assistant'. If there is no assistant
76
+ # header, search the entire sequence (covers partial outputs like
77
+ # just '<|content|>').
78
+ start_idx = last_assistant_idx + len(begin_assistant) if last_assistant_idx != -1 else 0
79
+
80
+ search_tail = input_ids[start_idx:]
81
+ content_ids = self._token_ids("<|content|>")
82
+ tool_calls_ids = self._token_ids("<|tool_calls|>")
83
+
84
+ if self._find_subsequence(search_tail, content_ids) != -1:
85
+ return True
86
+ if self._find_subsequence(search_tail, tool_calls_ids) != -1:
87
+ return True
88
+ return False
89
+
90
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
91
+ # Return token ids for the content section:
92
+ # - If '<|content|>' exists: everything AFTER the tag
93
+ # - Else if '<|tool_calls|>' exists: everything AFTER the tag (exclusive)
94
+ content_tag_ids = self._token_ids("<|content|>")
95
+ tool_calls_tag_ids = self._token_ids("<|tool_calls|>")
96
+
97
+ idx = self._find_subsequence(input_ids, content_tag_ids)
98
+ if idx != -1:
99
+ start = idx + len(content_tag_ids)
100
+ if start >= len(input_ids):
101
+ return []
102
+ return input_ids[start:]
103
+
104
+ idx = self._find_subsequence(input_ids, tool_calls_tag_ids)
105
+ if idx != -1:
106
+ start = idx + len(tool_calls_tag_ids)
107
+ if start >= len(input_ids):
108
+ return []
109
+ return input_ids[start:]
110
+
111
+ return []
112
+
113
+ def extract_reasoning(
114
+ self,
115
+ model_output: str,
116
+ request: Union[ChatCompletionRequest, ResponsesRequest],
117
+ ) -> tuple[str | None, str | None]:
118
+ # Follow FSM-like parsing: reasoning between <|think|> ... <|end|>,
119
+ # content starts at the first <|content|> and runs to the end.
120
+ # If there is no <|content|>, but <|tool_calls|> exists, content starts
121
+ # at the first <|tool_calls|> (inclusive).
122
+ reasoning = self._parse_reasoning(model_output) or ""
123
+ content = self._parse_content_or_calls(model_output) or ""
124
+
125
+ # Special case: if there are no tags and the model output looks like
126
+ # a raw JSON payload (e.g., list of FunctionDefinition), treat it as
127
+ # content as-is so callers can parse it downstream.
128
+ if not content:
129
+ stripped = (model_output or "").strip()
130
+ if stripped.startswith("{") or stripped.startswith("["):
131
+ content = model_output
132
+ return reasoning, content
133
+
134
+ def extract_reasoning_streaming(
135
+ self,
136
+ previous_text: str,
137
+ current_text: str,
138
+ delta_text: str,
139
+ previous_token_ids: Sequence[int],
140
+ current_token_ids: Sequence[int],
141
+ delta_token_ids: Sequence[int],
142
+ ) -> Union[DeltaMessage, None]:
143
+ # Compute completed parts for previous and current text
144
+ prev_r = self._parse_reasoning(previous_text) or ""
145
+ prev_c = self._parse_content_or_calls(previous_text) or ""
146
+ prev_has_content_tag = self._has_content_tag(previous_text)
147
+ prev_has_tool_calls_tag = self._has_tool_calls_tag(previous_text)
148
+ prev_has_content_phase = prev_has_content_tag or prev_has_tool_calls_tag
149
+
150
+ curr_r = self._parse_reasoning(current_text) or ""
151
+ curr_c = self._parse_content_or_calls(current_text) or ""
152
+ curr_has_content_tag = self._has_content_tag(current_text)
153
+ curr_has_tool_calls_tag = self._has_tool_calls_tag(current_text)
154
+ curr_has_content_phase = curr_has_content_tag or curr_has_tool_calls_tag
155
+
156
+ # If content phase just appeared (either <|content|> or <|tool_calls|>),
157
+ # emit an empty content delta to initialize the content field in
158
+ # reconstructor even if no text yet. We never emit the tag itself as
159
+ # content. After that, we only emit content additions.
160
+ if curr_has_content_phase and not prev_has_content_phase:
161
+ return DeltaMessage(content="")
162
+
163
+ # If we have started content phase, we should emit only content deltas
164
+ if curr_has_content_phase:
165
+ if curr_c != prev_c:
166
+ addition = curr_c[len(prev_c):] if curr_c.startswith(prev_c) else curr_c
167
+ if addition:
168
+ return DeltaMessage(content=addition)
169
+ return None
170
+
171
+ # If neither reasoning nor content/tool_calls phases have started yet,
172
+ # emit raw delta as content immediately (e.g., "{" for JSON outputs).
173
+ if (
174
+ "<|think|>" not in current_text
175
+ and not self._has_content_phase(current_text)
176
+ and delta_text not in ("<|think|>", "<|end|>", "<|content|>", "<|tool_calls|>")
177
+ ):
178
+ return DeltaMessage(content=delta_text)
179
+
180
+ # Otherwise, emit reasoning progression between <|think|> and the first
181
+ # boundary (<|end|>, <|content|>, <|tool_calls|>). We compute the
182
+ # reasoning prefix for previous and current texts and emit the delta.
183
+ prev_prefix = self._parse_reasoning_prefix(previous_text) or ""
184
+ curr_prefix = self._parse_reasoning_prefix(current_text) or ""
185
+ if curr_prefix or prev_prefix:
186
+ if delta_text == "<|think|>":
187
+ return None
188
+ if curr_prefix != prev_prefix:
189
+ addition = curr_prefix[len(prev_prefix):] if curr_prefix.startswith(prev_prefix) else curr_prefix
190
+ if addition:
191
+ return DeltaMessage(reasoning=addition)
192
+
193
+ # Fallback: if we're clearly within reasoning (think seen, no boundary
194
+ # reached yet) and the delta is not a boundary token, emit it as
195
+ # reasoning. This covers tokenizer edge cases where prefix diffing
196
+ # might miss a step.
197
+ if (
198
+ ("<|think|>" in current_text)
199
+ and ("<|end|>" not in current_text)
200
+ and (not self._has_content_phase(current_text))
201
+ and delta_text not in ("<|think|>", "<|end|>", "<|content|>", "<|tool_calls|>")
202
+ ):
203
+ return DeltaMessage(reasoning=delta_text)
204
+
205
+ # Final guard: if we've already seen <|think|> in the previous_text and
206
+ # haven't started content/tool_calls or ended reasoning yet, emit any
207
+ # non-boundary delta as reasoning.
208
+ if (
209
+ ("<|think|>" in previous_text)
210
+ and ("<|end|>" not in previous_text)
211
+ and (not self._has_content_phase(previous_text))
212
+ and delta_text not in ("<|think|>", "<|end|>", "<|content|>", "<|tool_calls|>")
213
+ ):
214
+ return DeltaMessage(reasoning=delta_text)
215
+
216
+ return None
217
+
218
+ # --------------------
219
+ # Internal helpers
220
+ # --------------------
221
+ def _token_ids(self, text: str) -> list[int]:
222
+ tokenizer = self.model_tokenizer
223
+ tokens = tokenizer.tokenize(text)
224
+ return tokenizer.convert_tokens_to_ids(tokens)
225
+
226
+ def _find_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int:
227
+ if not needle:
228
+ return -1
229
+ n = len(needle)
230
+ limit = len(haystack) - n + 1
231
+ for i in range(limit):
232
+ if haystack[i:i + n] == list(needle):
233
+ return i
234
+ return -1
235
+
236
+ def _rfind_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int:
237
+ if not needle:
238
+ return -1
239
+ n = len(needle)
240
+ limit = len(haystack) - n
241
+ last = -1
242
+ for i in range(0, limit + 1):
243
+ if haystack[i:i + n] == list(needle):
244
+ last = i
245
+ return last
246
+
247
+ def _parse_reasoning(self, text: str) -> Optional[str]:
248
+ # Extract text between first <|think|> and subsequent <|end|>
249
+ think_tag = "<|think|>"
250
+ end_tag = "<|end|>"
251
+ s = text.find(think_tag)
252
+ if s == -1:
253
+ return None
254
+ s += len(think_tag)
255
+ e = text.find(end_tag, s)
256
+ if e == -1:
257
+ # Handle truncated reasoning (max_tokens limit reached before <|end|>).
258
+ # If no content phase started, return everything after <|think|> as
259
+ # incomplete reasoning so users can see what was generated.
260
+ if not self._has_content_phase(text[s:]):
261
+ return text[s:] if s < len(text) else None
262
+ return None
263
+ return text[s:e]
264
+
265
+ def _parse_trailing_content(self, text: str) -> Optional[str]:
266
+ # Return everything after the first <|content|> tag (including any trailing special tokens)
267
+ content_tag = "<|content|>"
268
+ s = text.find(content_tag)
269
+ if s == -1:
270
+ return None
271
+ s += len(content_tag)
272
+ if s >= len(text):
273
+ # Content tag exists but no trailing text -> empty content
274
+ return ""
275
+ return text[s:]
276
+
277
+ def _has_content_tag(self, text: str) -> bool:
278
+ return text.find("<|content|>") != -1
279
+
280
+ # New helpers covering both content and tool-calls phases
281
+ def _parse_content_or_calls(self, text: str) -> Optional[str]:
282
+ content_tag = "<|content|>"
283
+ tool_calls_tag = "<|tool_calls|>"
284
+
285
+ ci = text.find(content_tag)
286
+ ti = text.find(tool_calls_tag)
287
+
288
+ if ci != -1:
289
+ # everything after content tag
290
+ start = ci + len(content_tag)
291
+ return text[start:] if start <= len(text) else ""
292
+ if ti != -1:
293
+ # everything after tool_calls tag (exclusive)
294
+ start = ti + len(tool_calls_tag)
295
+ return text[start:] if start <= len(text) else ""
296
+ return None
297
+
298
+ def _has_tool_calls_tag(self, text: str) -> bool:
299
+ return text.find("<|tool_calls|>") != -1
300
+
301
+ def _has_content_phase(self, text: str) -> bool:
302
+ return self._has_content_tag(text) or self._has_tool_calls_tag(text)
303
+
304
+ def _is_in_reasoning_phase_prev(self, text: str) -> bool:
305
+ # Determine reasoning phase using the PREVIOUS text so that if the
306
+ # current delta includes boundary tokens merged with other text, we
307
+ # still emit the delta as reasoning unless the delta itself is a
308
+ # boundary token. This matches the test expectations.
309
+ if text.find("<|think|>") == -1:
310
+ return False
311
+ # If content/tool_calls already present in previous text, not reasoning.
312
+ if self._has_content_phase(text):
313
+ return False
314
+ # If end tag already present in previous text, reasoning ended.
315
+ if text.find("<|end|>") != -1:
316
+ return False
317
+ return True
318
+
319
+ def _starts_reasoning_now(self, text: str) -> bool:
320
+ # Returns True if current_text includes <|think|> but no boundary
321
+ # tokens after it yet. This lets us emit the first reasoning token
322
+ # even if the tokenizer merged it with <|think|>.
323
+ i = text.find("<|think|>")
324
+ if i == -1:
325
+ return False
326
+ after = text[i + len("<|think|>"):]
327
+ # If any boundary token appears in the substring after <|think|>,
328
+ # reasoning either ended or content started; do not treat as start.
329
+ for b in ("<|end|>", "<|content|>", "<|tool_calls|>"):
330
+ if after.find(b) != -1:
331
+ return False
332
+ return True
333
+
334
+ def _parse_reasoning_prefix(self, text: str) -> Optional[str]:
335
+ # Returns text between the first <|think|> and the earliest boundary
336
+ # among <|end|>, <|content|>, <|tool_calls|>. If <|think|> is absent,
337
+ # returns None. If no boundary appears, returns text after <|think|>.
338
+ ti = text.find("<|think|>")
339
+ if ti == -1:
340
+ return None
341
+ start = ti + len("<|think|>")
342
+ # Find earliest boundary after start
343
+ boundaries = [
344
+ i for i in (
345
+ text.find("<|end|>", start),
346
+ text.find("<|content|>", start),
347
+ text.find("<|tool_calls|>", start),
348
+ ) if i != -1
349
+ ]
350
+ end = min(boundaries) if boundaries else len(text)
351
+ return text[start:end]