|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Sequence, Union, Optional |
|
|
import json |
|
|
|
|
|
# Pydantic support in the JSON patch below is optional: when pydantic cannot be
# imported, _PydanticBaseModel is None and the patched encoder falls straight
# through to the stdlib behaviour.
try:
    from pydantic import BaseModel as _PydanticBaseModel
except ImportError:
    _PydanticBaseModel = None

# The stdlib's shared default encoder, captured before this module replaces it.
# Nothing in this module reads it afterwards.
# NOTE(review): presumably kept so the original encoder can be restored -- confirm
# before removing.
_orig_default_encoder = json._default_encoder
|
|
|
|
|
|
|
|
class _PatchedJSONEncoder(json.JSONEncoder):
    """``JSONEncoder`` that additionally serializes pydantic models.

    This module installs an instance as ``json``'s shared default encoder, so
    ``json.dumps(obj)`` calls that use that shared encoder accept pydantic
    models nested anywhere inside ``obj``.
    """

    def default(self, o):
        """Convert a pydantic model to a JSON-ready dict; defer anything else.

        Models exposing ``model_dump`` (pydantic v2) are dumped with
        ``mode="json"``. In that mode pydantic converts field values such as
        ``datetime``, ``UUID``, ``Enum`` or ``Decimal`` to JSON-native types.
        A plain ``model_dump()`` would leave them as Python objects, and the
        encoder would then raise ``TypeError`` on them. Models exposing only
        ``dict()`` (the pydantic v1-style API) fall back to that. Every other
        object goes to ``JSONEncoder.default``, which raises ``TypeError``.
        """
        if _PydanticBaseModel is not None and isinstance(o, _PydanticBaseModel):
            model_dump = getattr(o, "model_dump", None)
            if callable(model_dump):
                return model_dump(mode="json")

            legacy_dict = getattr(o, "dict", None)
            if callable(legacy_dict):
                return legacy_dict()

        return super().default(o)
|
|
|
|
|
|
|
|
|
|
|
# Swap the json module's private shared encoder for the pydantic-aware one.
# In CPython, json.dumps()/json.dump() reuse this instance only when called
# without encoder-customising keyword arguments (cls, indent, default,
# sort_keys, ...). Calls that pass such arguments build a fresh
# JSONEncoder and are unaffected by this patch.
setattr(json, "_default_encoder", _PatchedJSONEncoder())
|
|
|
|
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest, DeltaMessage |
|
|
from vllm.logger import init_logger |
|
|
from vllm.reasoning import ReasoningParser |
|
|
|
|
|
# Module-level logger obtained from vLLM's logging helper. It is not
# referenced by the code visible in this file.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
class SolarOpenReasoningParser(ReasoningParser):
    """Reasoning parser for output delimited by Solar Open's literal markers.

    Four markers are recognised:

    * ``<|think|>`` opens a reasoning section;
    * ``<|end|>`` closes it;
    * ``<|content|>`` opens the final answer;
    * ``<|tool_calls|>`` opens a tool-call section. It is treated as the start
      of content when no ``<|content|>`` marker is present.

    Text-level parsing searches decoded strings for these markers. Token-level
    checks (``is_reasoning_end`` and ``extract_content_ids``) tokenize the same
    markers with ``self.model_tokenizer`` and search for the resulting id
    sequences.
    """

    _THINK_TAG = "<|think|>"
    _END_TAG = "<|end|>"
    _CONTENT_TAG = "<|content|>"
    _TOOL_CALLS_TAG = "<|tool_calls|>"
    # All four markers recognised by this parser.
    _ALL_TAGS = (_THINK_TAG, _END_TAG, _CONTENT_TAG, _TOOL_CALLS_TAG)

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether reasoning is finished for the latest assistant turn.

        Returns True when either:

        * the token span between the second-to-last and the last
          ``<|begin|>assistant`` sequence is exactly the tokens of
          ``<|think|><|end|>`` (an empty reasoning section), or
        * a ``<|content|>`` or ``<|tool_calls|>`` sequence appears after the
          last ``<|begin|>assistant``, or anywhere in ``input_ids`` when there
          is no ``<|begin|>assistant`` at all.

        Args:
            input_ids: Token ids to inspect.
        """
        # NOTE(review): the first rule presumably recognises a prompt in which
        # thinking was switched off for the turn -- confirm against the chat
        # template.
        begin_assistant = self._token_ids("<|begin|>assistant")
        last_idx = self._rfind_subsequence(input_ids, begin_assistant)

        if last_idx != -1:
            prev_idx = self._rfind_subsequence(input_ids[:last_idx], begin_assistant)
            if prev_idx != -1:
                prev_body = list(input_ids[prev_idx + len(begin_assistant):last_idx])
                if prev_body == self._token_ids("<|think|><|end|>"):
                    return True

        # Restrict the marker search to the latest assistant turn when one is marked.
        tail_start = last_idx + len(begin_assistant) if last_idx != -1 else 0
        tail = input_ids[tail_start:]
        return any(
            self._find_subsequence(tail, self._token_ids(tag)) != -1
            for tag in (self._CONTENT_TAG, self._TOOL_CALLS_TAG)
        )

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids that follow the content marker.

        The first ``<|content|>`` sequence is used. If there is none, the first
        ``<|tool_calls|>`` sequence is used instead. Returns ``[]`` when
        neither marker is present or nothing follows it.
        """
        for tag in (self._CONTENT_TAG, self._TOOL_CALLS_TAG):
            tag_ids = self._token_ids(tag)
            idx = self._find_subsequence(input_ids, tag_ids)
            if idx != -1:
                return list(input_ids[idx + len(tag_ids):])
        return []

    def extract_reasoning(
        self,
        model_output: str,
        request: Union[ChatCompletionRequest, ResponsesRequest],
    ) -> tuple[str | None, str | None]:
        """Split a complete (non-streamed) generation into reasoning and content.

        Args:
            model_output: Full decoded text produced by the model.
            request: The originating request. It is not consulted by this parser.

        Returns:
            ``(reasoning, content)``. Both are always strings, ``""`` when
            absent.

            * ``reasoning`` is what ``_parse_reasoning`` finds after
              ``<|think|>``.
            * ``content`` is the text after the first ``<|content|>`` (or
              ``<|tool_calls|>``) marker.
            * When that yields nothing, the whole output becomes content if it
              looks like a bare JSON value (starts with ``{`` or ``[``) or if
              it contains none of the four markers. The latter matches
              ``extract_reasoning_streaming``, which forwards untagged text as
              content.
        """
        text = model_output or ""
        reasoning = self._parse_reasoning(text) or ""
        content = self._parse_content_or_calls(text) or ""

        if not content:
            looks_like_json = text.strip().startswith(("{", "["))
            if looks_like_json or not self._contains_any_tag(text):
                content = text

        return reasoning, content

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Classify one streamed delta as reasoning, content, or nothing.

        Only the decoded texts are inspected; the token-id arguments are not
        used here. Rules are applied in order:

        1. The delta enters the content/tool-call phase: emit the content that
           already follows the marker (``""`` for a bare-marker delta). Also
           emit any reasoning text that preceded the marker in the same delta,
           so multi-token deltas lose nothing.
        2. Already in the content phase: emit the newly appended content.
        3. No ``<|think|>`` seen yet: forward the delta verbatim as content,
           unless it is exactly one of the markers.
        4. The reasoning prefix (text after ``<|think|>`` up to the first
           boundary marker) grew: emit the growth as reasoning. A bare
           ``<|think|>`` delta emits nothing.
        5. Fallbacks while a ``<|think|>`` block is open in the current or
           previous text: forward the raw delta as reasoning, but never one
           that contains marker text.

        Returns None when the delta produces nothing to emit.
        """
        prev_in_content = self._has_content_phase(previous_text)
        curr_in_content = self._has_content_phase(current_text)
        prev_c = self._parse_content_or_calls(previous_text) or ""
        curr_c = self._parse_content_or_calls(current_text) or ""
        prev_prefix = self._parse_reasoning_prefix(previous_text) or ""
        curr_prefix = self._parse_reasoning_prefix(current_text) or ""

        # Rule 1: prev_c is "" here, so curr_c is exactly the content carried by this delta.
        if curr_in_content and not prev_in_content:
            pending_reasoning = self._incremental_suffix(prev_prefix, curr_prefix)
            if pending_reasoning:
                return DeltaMessage(reasoning=pending_reasoning, content=curr_c)
            return DeltaMessage(content=curr_c)

        # Rule 2.
        if curr_in_content:
            addition = self._incremental_suffix(prev_c, curr_c)
            return DeltaMessage(content=addition) if addition else None

        # Rule 3 (no content phase is possible past this point).
        if self._THINK_TAG not in current_text and delta_text not in self._ALL_TAGS:
            return DeltaMessage(content=delta_text)

        # Rule 4.
        if curr_prefix or prev_prefix:
            if delta_text == self._THINK_TAG:
                return None
            addition = self._incremental_suffix(prev_prefix, curr_prefix)
            if addition:
                return DeltaMessage(reasoning=addition)

        # Rule 5: a delta reaching here added nothing to the reasoning prefix,
        # so forwarding one that carries a marker would leak marker text
        # (e.g. "<|end|>x") into the reasoning stream.
        if (
            self._THINK_TAG in current_text
            and self._END_TAG not in current_text
            and not self._contains_any_tag(delta_text)
        ):
            return DeltaMessage(reasoning=delta_text)

        if self._is_in_reasoning_phase_prev(previous_text) and not self._contains_any_tag(delta_text):
            return DeltaMessage(reasoning=delta_text)

        return None

    @staticmethod
    def _incremental_suffix(old: str, new: str) -> str:
        """Return what ``new`` adds over ``old``.

        Returns ``""`` when they are equal and the tail beyond ``old`` when
        ``new`` extends it. Otherwise the texts diverged and all of ``new`` is
        returned.
        """
        if new == old:
            return ""
        return new[len(old):] if new.startswith(old) else new

    def _token_ids(self, text: str) -> list[int]:
        """Tokenize ``text`` with ``self.model_tokenizer`` and return its ids.

        The same marker strings are tokenized on every ``is_reasoning_end`` and
        ``extract_content_ids`` call, so results are memoized per instance.
        The cache is discarded if ``model_tokenizer`` is replaced. A fresh list
        is returned so callers cannot mutate the cached value.
        """
        tokenizer = self.model_tokenizer
        cache = getattr(self, "_token_ids_cache", None)
        if cache is None or cache[0] is not tokenizer:
            cache = (tokenizer, {})
            self._token_ids_cache = cache

        table = cache[1]
        ids = table.get(text)
        if ids is None:
            ids = tuple(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
            table[text] = ids
        return list(ids)

    def _find_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int:
        """Return the first index where ``needle`` occurs in ``haystack``, else -1.

        An empty ``needle`` yields -1. Windows are normalised to ``list`` so
        that any ``Sequence`` (e.g. a tuple) compares correctly.
        """
        target = list(needle)
        width = len(target)
        if width == 0:
            return -1
        head = target[0]
        for i in range(len(haystack) - width + 1):
            if haystack[i] == head and list(haystack[i:i + width]) == target:
                return i
        return -1

    def _rfind_subsequence(self, haystack: Sequence[int], needle: Sequence[int]) -> int:
        """Return the last index where ``needle`` occurs in ``haystack``, else -1.

        An empty ``needle`` yields -1. Scans from the end and stops at the
        first match.
        """
        target = list(needle)
        width = len(target)
        if width == 0:
            return -1
        head = target[0]
        for i in range(len(haystack) - width, -1, -1):
            if haystack[i] == head and list(haystack[i:i + width]) == target:
                return i
        return -1

    def _parse_reasoning(self, text: str) -> Optional[str]:
        """Return the text between the first ``<|think|>`` and the next ``<|end|>``.

        If ``<|end|>`` never follows, the remainder after ``<|think|>`` is
        returned. None is returned instead when that remainder is empty or
        contains a ``<|content|>``/``<|tool_calls|>`` marker. Returns None when
        ``<|think|>`` is absent.
        """
        think_at = text.find(self._THINK_TAG)
        if think_at == -1:
            return None
        start = think_at + len(self._THINK_TAG)

        end = text.find(self._END_TAG, start)
        if end != -1:
            return text[start:end]

        if start >= len(text) or self._has_content_phase(text[start:]):
            return None
        return text[start:]

    def _parse_trailing_content(self, text: str) -> Optional[str]:
        """Return everything after the first ``<|content|>`` marker.

        Returns ``""`` when nothing follows the marker, or None when there is
        no marker.
        """
        idx = text.find(self._CONTENT_TAG)
        if idx == -1:
            return None
        return text[idx + len(self._CONTENT_TAG):]

    def _has_content_tag(self, text: str) -> bool:
        """Return whether ``text`` contains ``<|content|>``."""
        return self._CONTENT_TAG in text

    def _parse_content_or_calls(self, text: str) -> Optional[str]:
        """Return the text after the first ``<|content|>`` (else ``<|tool_calls|>``).

        ``<|content|>`` takes precedence regardless of position. Returns None
        when neither marker is present.
        """
        for tag in (self._CONTENT_TAG, self._TOOL_CALLS_TAG):
            idx = text.find(tag)
            if idx != -1:
                return text[idx + len(tag):]
        return None

    def _has_tool_calls_tag(self, text: str) -> bool:
        """Return whether ``text`` contains ``<|tool_calls|>``."""
        return self._TOOL_CALLS_TAG in text

    def _has_content_phase(self, text: str) -> bool:
        """Return whether ``text`` contains a content or tool-call marker."""
        return self._has_content_tag(text) or self._has_tool_calls_tag(text)

    def _contains_any_tag(self, text: str) -> bool:
        """Return whether ``text`` contains any of the four markers."""
        return any(tag in text for tag in self._ALL_TAGS)

    def _is_in_reasoning_phase_prev(self, text: str) -> bool:
        """Return whether a ``<|think|>`` block is still open at the end of ``text``.

        True iff ``<|think|>`` is present and neither ``<|end|>`` nor a
        content/tool-call marker appears anywhere in ``text``.
        """
        return (
            self._THINK_TAG in text
            and not self._has_content_phase(text)
            and self._END_TAG not in text
        )

    def _starts_reasoning_now(self, text: str) -> bool:
        """Return whether ``text`` has ``<|think|>`` with no boundary marker after it.

        Boundary markers are ``<|end|>``, ``<|content|>`` and ``<|tool_calls|>``;
        only text after the first ``<|think|>`` is checked.
        """
        think_at = text.find(self._THINK_TAG)
        if think_at == -1:
            return False
        after = text[think_at + len(self._THINK_TAG):]
        return not any(
            tag in after
            for tag in (self._END_TAG, self._CONTENT_TAG, self._TOOL_CALLS_TAG)
        )

    def _parse_reasoning_prefix(self, text: str) -> Optional[str]:
        """Return the text after the first ``<|think|>`` up to the next boundary.

        The boundary is the earliest ``<|end|>``, ``<|content|>`` or
        ``<|tool_calls|>`` after the think marker. Without one, the rest of
        ``text`` is returned. Returns None when ``<|think|>`` is absent.
        """
        think_at = text.find(self._THINK_TAG)
        if think_at == -1:
            return None
        start = think_at + len(self._THINK_TAG)

        boundary_hits = [
            pos
            for pos in (
                text.find(tag, start)
                for tag in (self._END_TAG, self._CONTENT_TAG, self._TOOL_CALLS_TAG)
            )
            if pos != -1
        ]
        return text[start:min(boundary_hits, default=len(text))]
|
|
|