| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Sequence, Union, Optional |
| import json |
|
|
# pydantic is treated as an optional dependency here: if importing it fails
# for any reason, _PydanticBaseModel is set to None and code that checks
# `_PydanticBaseModel is not None` skips its pydantic-specific handling.
try:
    from pydantic import BaseModel as _PydanticBaseModel
except Exception:
    _PydanticBaseModel = None


# Reference to the encoder instance that the json module holds in its private
# `_default_encoder` attribute, captured at import time before this module
# replaces that attribute further down.
# NOTE(review): nothing in the visible part of this file reads this name back;
# presumably it is kept so the patch can be undone — confirm.
_orig_default_encoder = json._default_encoder
|
|
|
|
class _PatchedJSONEncoder(json.JSONEncoder):
    """JSON encoder that can additionally serialize pydantic models.

    Objects that are instances of the pydantic ``BaseModel`` (when pydantic
    was importable) are converted by calling their ``model_dump()`` method
    or, if that is not available, their ``dict()`` method. Everything else
    is delegated to ``json.JSONEncoder.default``.
    """

    def default(self, o):
        """Return a JSON-serializable stand-in for ``o``.

        Falls back to the parent implementation when ``o`` is not a pydantic
        model or exposes neither conversion method.
        """
        is_model = _PydanticBaseModel is not None and isinstance(
            o, _PydanticBaseModel
        )
        if is_model:
            # Try `model_dump` first and `dict` second, in that order.
            for method_name in ("model_dump", "dict"):
                converter = getattr(o, method_name, None)
                if callable(converter):
                    return converter()
        return super().default(o)
|
|
|
|
| |
# Replace the encoder instance stored in json's private `_default_encoder`
# attribute with one that can also serialize pydantic models.
# NOTE(review): this is a process-wide side effect of importing this module;
# presumably json.dumps() only uses this instance when called without custom
# encoder options — confirm against the CPython json source.
json._default_encoder = _PatchedJSONEncoder()
|
|
| from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest, DeltaMessage |
| from vllm.logger import init_logger |
| from vllm.reasoning import ReasoningParser |
|
|
# Module-level logger obtained from vLLM's init_logger, named after this module.
logger = init_logger(__name__)
|
|
|
|
| class SolarOpenReasoningParser(ReasoningParser): |
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Report whether the token sequence has moved past the reasoning part.

    Returns True in either of two situations:

    * the assistant turn preceding the last ``<|begin|>assistant`` header
      consists of exactly the tokens of ``<|think|><|end|>`` (an empty
      reasoning block), or
    * the tokens of ``<|content|>`` or ``<|tool_calls|>`` occur after the
      last ``<|begin|>assistant`` header (or anywhere in ``input_ids`` when
      no such header is present).

    Otherwise returns False.
    """
    header = self._token_ids("<|begin|>assistant")
    header_len = len(header)
    last_header = self._rfind_subsequence(input_ids, header)

    # Where the search for content/tool-call markers begins.
    tail_start = 0
    if last_header != -1:
        tail_start = last_header + header_len

        # Look at the assistant turn that sits between the previous header
        # and the last one.
        earlier_header = self._rfind_subsequence(input_ids[:last_header], header)
        if earlier_header != -1:
            earlier_body = input_ids[earlier_header + header_len:last_header]
            if earlier_body == self._token_ids("<|think|><|end|>"):
                return True

    tail = input_ids[tail_start:]
    markers = [self._token_ids(tag) for tag in ("<|content|>", "<|tool_calls|>")]
    return any(self._find_subsequence(tail, marker) != -1 for marker in markers)
|
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """Return the token ids that follow the content marker.

    The first occurrence of the ``<|content|>`` token sequence is preferred;
    if it is absent, the first occurrence of ``<|tool_calls|>`` is used
    instead. Returns an empty list when neither marker is present or when
    nothing follows the marker.
    """
    markers = [self._token_ids(tag) for tag in ("<|content|>", "<|tool_calls|>")]

    for marker in markers:
        position = self._find_subsequence(input_ids, marker)
        if position == -1:
            continue
        after_marker = position + len(marker)
        return input_ids[after_marker:] if after_marker < len(input_ids) else []

    return []
|
|
def extract_reasoning(
    self,
    model_output: str,
    request: Union[ChatCompletionRequest, ResponsesRequest],
) -> tuple[str | None, str | None]:
    """Split a complete model output into ``(reasoning, content)``.

    Reasoning is whatever ``_parse_reasoning`` extracts and content is
    whatever ``_parse_content_or_calls`` extracts; a missing part is
    reported as an empty string. ``request`` is not consulted.

    When no content was extracted and the stripped output starts with
    ``{`` or ``[``, the whole (unstripped) output is returned as content.
    """
    reasoning_text = self._parse_reasoning(model_output) or ""
    answer_text = self._parse_content_or_calls(model_output) or ""

    if answer_text:
        return reasoning_text, answer_text

    # Nothing was found after a content/tool-call tag: fall back to the raw
    # output when it looks like a bare JSON object or array.
    looks_like_json = (model_output or "").strip().startswith(("{", "["))
    if looks_like_json:
        answer_text = model_output
    return reasoning_text, answer_text
|
|
def extract_reasoning_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
    """Classify the newest streamed text as reasoning and/or content.

    Only the text arguments are used; the token-id arguments are accepted
    for interface compatibility and ignored.

    Behaviour, in order of precedence:

    1. ``current_text`` contains ``<|content|>`` or ``<|tool_calls|>``:
       * if ``previous_text`` did not, this delta crosses into the content
         phase. Everything after the tag is emitted as content (an empty
         string when the delta is just the tag), together with any
         reasoning text that arrived in the same delta before the tag.
       * otherwise the newly added content is emitted, or None when
         nothing was added.
    2. No ``<|think|>`` in ``current_text`` and the delta is not a lone
       control tag: the delta is emitted as content.
    3. The reasoning text between ``<|think|>`` and its first boundary
       grew: the growth is emitted as reasoning. A delta that is exactly
       ``<|think|>`` emits nothing.
    4. ``<|think|>`` is open (no ``<|end|>`` yet) and the delta is not a
       lone control tag: the delta is emitted as reasoning.
    5. Anything else returns None.

    Returns:
        A ``DeltaMessage`` carrying ``content`` and/or ``reasoning``, or
        None when there is nothing to emit for this delta.
    """
    control_tags = ("<|think|>", "<|end|>", "<|content|>", "<|tool_calls|>")
    delta_is_lone_tag = delta_text in control_tags

    prev_reasoning = self._parse_reasoning_prefix(previous_text) or ""
    curr_reasoning = self._parse_reasoning_prefix(current_text) or ""

    if self._has_content_phase(current_text):
        prev_content = self._parse_content_or_calls(previous_text) or ""
        curr_content = self._parse_content_or_calls(current_text) or ""

        if not self._has_content_phase(previous_text):
            # Transition into the content phase. A delta may hold more
            # than the tag itself (several tokens can be flushed at
            # once), so forward the text on both sides of the tag instead
            # of discarding it. With a tag-only delta this still yields
            # DeltaMessage(content="").
            if curr_reasoning.startswith(prev_reasoning):
                reasoning_tail = curr_reasoning[len(prev_reasoning):]
            else:
                reasoning_tail = ""
            if reasoning_tail:
                return DeltaMessage(reasoning=reasoning_tail, content=curr_content)
            return DeltaMessage(content=curr_content)

        if curr_content != prev_content:
            if curr_content.startswith(prev_content):
                addition = curr_content[len(prev_content):]
            else:
                addition = curr_content
            if addition:
                return DeltaMessage(content=addition)
        return None

    has_think = "<|think|>" in current_text

    # Output that never opened a reasoning block is passed through as
    # content; lone control tags are swallowed.
    if not has_think and not delta_is_lone_tag:
        return DeltaMessage(content=delta_text)

    # Emit whatever was appended to the reasoning text, bounded by the
    # first <|end|>/<|content|>/<|tool_calls|> after <|think|>.
    if curr_reasoning or prev_reasoning:
        if delta_text == "<|think|>":
            return None
        if curr_reasoning != prev_reasoning:
            if curr_reasoning.startswith(prev_reasoning):
                addition = curr_reasoning[len(prev_reasoning):]
            else:
                addition = curr_reasoning
            if addition:
                return DeltaMessage(reasoning=addition)

    # Reasoning block is open and nothing was emitted above.
    if has_think and "<|end|>" not in current_text and not delta_is_lone_tag:
        return DeltaMessage(reasoning=delta_text)

    # No further fallback on `previous_text`: once the checks above have
    # declined, an open-reasoning previous_text implies the delta begins
    # with <|end|>, and emitting it would leak control tags into the
    # reasoning stream.
    return None
|
|
| |
| |
| |
| def _token_ids(self, text: str) -> list[int]: |
| tokenizer = self.model_tokenizer |
| tokens = tokenizer.tokenize(text) |
| return tokenizer.convert_tokens_to_ids(tokens) |
|
|
| def _find_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int: |
| if not needle: |
| return -1 |
| n = len(needle) |
| limit = len(haystack) - n + 1 |
| for i in range(limit): |
| if haystack[i:i + n] == list(needle): |
| return i |
| return -1 |
|
|
| def _rfind_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int: |
| if not needle: |
| return -1 |
| n = len(needle) |
| limit = len(haystack) - n |
| last = -1 |
| for i in range(0, limit + 1): |
| if haystack[i:i + n] == list(needle): |
| last = i |
| return last |
|
|
| def _parse_reasoning(self, text: str) -> Optional[str]: |
| |
| think_tag = "<|think|>" |
| end_tag = "<|end|>" |
| s = text.find(think_tag) |
| if s == -1: |
| return None |
| s += len(think_tag) |
| e = text.find(end_tag, s) |
| if e == -1: |
| |
| |
| |
| if not self._has_content_phase(text[s:]): |
| return text[s:] if s < len(text) else None |
| return None |
| return text[s:e] |
|
|
| def _parse_trailing_content(self, text: str) -> Optional[str]: |
| |
| content_tag = "<|content|>" |
| s = text.find(content_tag) |
| if s == -1: |
| return None |
| s += len(content_tag) |
| if s >= len(text): |
| |
| return "" |
| return text[s:] |
|
|
| def _has_content_tag(self, text: str) -> bool: |
| return text.find("<|content|>") != -1 |
|
|
| |
| def _parse_content_or_calls(self, text: str) -> Optional[str]: |
| content_tag = "<|content|>" |
| tool_calls_tag = "<|tool_calls|>" |
|
|
| ci = text.find(content_tag) |
| ti = text.find(tool_calls_tag) |
|
|
| if ci != -1: |
| |
| start = ci + len(content_tag) |
| return text[start:] if start <= len(text) else "" |
| if ti != -1: |
| |
| start = ti + len(tool_calls_tag) |
| return text[start:] if start <= len(text) else "" |
| return None |
|
|
| def _has_tool_calls_tag(self, text: str) -> bool: |
| return text.find("<|tool_calls|>") != -1 |
|
|
| def _has_content_phase(self, text: str) -> bool: |
| return self._has_content_tag(text) or self._has_tool_calls_tag(text) |
|
|
| def _is_in_reasoning_phase_prev(self, text: str) -> bool: |
| |
| |
| |
| |
| if text.find("<|think|>") == -1: |
| return False |
| |
| if self._has_content_phase(text): |
| return False |
| |
| if text.find("<|end|>") != -1: |
| return False |
| return True |
|
|
| def _starts_reasoning_now(self, text: str) -> bool: |
| |
| |
| |
| i = text.find("<|think|>") |
| if i == -1: |
| return False |
| after = text[i + len("<|think|>"):] |
| |
| |
| for b in ("<|end|>", "<|content|>", "<|tool_calls|>"): |
| if after.find(b) != -1: |
| return False |
| return True |
|
|
| def _parse_reasoning_prefix(self, text: str) -> Optional[str]: |
| |
| |
| |
| ti = text.find("<|think|>") |
| if ti == -1: |
| return None |
| start = ti + len("<|think|>") |
| |
| boundaries = [ |
| i for i in ( |
| text.find("<|end|>", start), |
| text.find("<|content|>", start), |
| text.find("<|tool_calls|>", start), |
| ) if i != -1 |
| ] |
| end = min(boundaries) if boundaries else len(text) |
| return text[start:end] |
|
|