|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
from enum import Enum |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
import torch |
|
|
|
|
|
from vllm.sampling_params import SamplingParams |
|
|
from vllm.v1.sample.logits_processor import ( |
|
|
AdapterLogitsProcessor, |
|
|
RequestLogitsProcessor, |
|
|
) |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from vllm.config import VllmConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Token IDs of the Solar Open chat-template control tokens.
BEGIN_TOKEN_ID = 20  # <|begin|>
END_TOKEN_ID = 21  # <|end|>
THINK_TOKEN_ID = 22  # <|think|>
CONTENT_TOKEN_ID = 23  # <|content|>
FLUSH_TOKEN_ID = 24  # <|flush|> (eos)
# The "assistant" role word that follows <|begin|>. It is NOT part of
# _ALL_SPECIAL_TOKEN_IDS below; it only gets special treatment in context.
ASSISTANT_TOKEN_ID = 163444
'''


'assistant' is not a special token exactly, but is treated as one in the logits


processing.


'''
CALLS_TOKEN_ID = 25  # <|calls|> (eos)
TOOL_CALLS_TOKEN_ID = 30  # <|tool_calls|>
TOOL_CALL_BEGIN_TOKEN_ID = 31  # <|tool_call:begin|>
TOOL_CALL_END_TOKEN_ID = 32  # <|tool_call:end|>
TOOL_CALL_NAME_TOKEN_ID = 33  # <|tool_call:name|>
TOOL_CALL_ARGS_TOKEN_ID = 34  # <|tool_call:args|>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Effort level used when a request leaves ``reasoning_effort`` unset/falsy.
DEFAULT_REASONING_EFFORT = "high"

# Think-token budget knobs for "high" effort. The processor computes
#   budget = min(MAX, max(MIN, max_tokens * RATIO // 100))
# and each value can be overridden with the environment variables
# SOLAR_REASONING_BUDGET_HIGH_MAX / _MIN / _RATIO.
DEFAULT_REASONING_BUDGET_HIGH_MAX = 32 * 1024
DEFAULT_REASONING_BUDGET_HIGH_MIN = 8 * 1024
DEFAULT_REASONING_BUDGET_HIGH_RATIO = 60  # percent of max_tokens

# Same knobs for "medium" effort
# (env: SOLAR_REASONING_BUDGET_MEDIUM_MAX / _MIN / _RATIO).
DEFAULT_REASONING_BUDGET_MEDIUM_MAX = 16 * 1024
DEFAULT_REASONING_BUDGET_MEDIUM_MIN = 4 * 1024
DEFAULT_REASONING_BUDGET_MEDIUM_RATIO = 30  # percent of max_tokens

# Number of tool-call id tokens after which <|tool_call:name|> is forced
# (env: SOLAR_TOOL_CALL_ID_BUDGET).
DEFAULT_TOOL_CALL_ID_BUDGET = 10

# Logit value written into a position to make that token unsampleable.
NEG_INF = float("-inf")
|
|
|
|
|
|
|
|
def is_reasoning_request(params: SamplingParams) -> bool:
    """Tell whether a request should start with a think message.

    An unset ``reasoning_effort`` (``None``) counts as a reasoning request;
    otherwise only the "medium" and "high" effort levels do.
    """
    effort = params.reasoning_effort
    if effort is None:
        return True
    return effort in ("medium", "high")
|
|
|
|
|
|
|
|
def is_structured_outputs(params: SamplingParams) -> bool:
    """Tell whether the request carries at least one structured-output constraint.

    Returns False when ``structured_outputs`` is missing, or when it is
    present but reports that every constraint is unset.
    """
    constraints = params.structured_outputs
    if constraints is None:
        return False
    return not constraints.all_constraints_none()
|
|
|
|
|
|
|
|
class GenerationState(Enum):
    """Enum representing the current state of response generation.

    Each member names the part of the Solar Open chat template that the
    generated tokens have reached. A ``*_BEGIN`` member is entered when its
    control token is emitted; the matching ``*_IN_PROGRESS`` member is the
    state the trackers move to once ordinary text follows it.
    """

    # Nothing has been generated yet.
    INITIAL = "initial"

    # After <|begin|>, and after the "assistant" role token that follows it.
    NEW_MESSAGE_BEGIN = "new_message_begin"
    NEW_MESSAGE_ASSISTANT = "new_message_assistant"

    # Think message: <|think|>, its text, then closed by <|end|> or <|flush|>.
    THINK_BEGIN = "think_begin"
    THINK_IN_PROGRESS = "think_in_progress"
    THINK_END = "think_end"
    THINK_FLUSH = "think_flush"

    # Content message: <|content|>, its text, then closed by <|end|> or <|flush|>.
    CONTENT_BEGIN = "content_begin"
    CONTENT_IN_PROGRESS = "content_in_progress"
    CONTENT_END = "content_end"
    CONTENT_FLUSH = "content_flush"

    # Tool message: <|tool_calls|> followed by one or more
    # <|tool_call:begin|>{id}<|tool_call:name|>{name}<|tool_call:args|>{args}<|tool_call:end|>
    # groups, terminated by <|calls|>.
    TOOL_CALLS_BEGIN = "tool_calls_begin"
    TOOL_CALL_BEGIN = "tool_call_begin"
    TOOL_CALL_ID_IN_PROGRESS = "tool_call_id_in_progress"
    TOOL_CALL_NAME_BEGIN = "tool_call_name_begin"
    TOOL_CALL_NAME_IN_PROGRESS = "tool_call_name_in_progress"
    TOOL_CALL_ARGS_BEGIN = "tool_call_args_begin"
    TOOL_CALL_ARGS_IN_PROGRESS = "tool_call_args_in_progress"
    TOOL_CALL_END = "tool_call_end"
    CALLS = "calls"
|
|
|
|
|
|
|
|
def get_generation_state(
    output_token_ids: list[int],
    begin_token_id: int = BEGIN_TOKEN_ID,
    end_token_id: int = END_TOKEN_ID,
    flush_token_id: int = FLUSH_TOKEN_ID,
    think_token_id: int = THINK_TOKEN_ID,
    content_token_id: int = CONTENT_TOKEN_ID,
    tool_calls_token_id: int = TOOL_CALLS_TOKEN_ID,
    tool_call_begin_token_id: int = TOOL_CALL_BEGIN_TOKEN_ID,
    tool_call_name_token_id: int = TOOL_CALL_NAME_TOKEN_ID,
    tool_call_args_token_id: int = TOOL_CALL_ARGS_TOKEN_ID,
    tool_call_end_token_id: int = TOOL_CALL_END_TOKEN_ID,
    calls_token_id: int = CALLS_TOKEN_ID,
    assistant_token_id: int = ASSISTANT_TOKEN_ID,
) -> GenerationState:
    """Determine the current generation state based on output token IDs.

    Analyzes the sequence of generated tokens to determine which phase
    of the chat template the generation is currently in.

    Response format specs:
    - think mode: <|think|>{{think-tokens}}<|end|><|begin|>assistant<|content|>{{content-tokens}}<|flush|>
    - tool mode: <|begin|>assistant<|tool_calls|><|tool_call:begin|>{{id}}<|tool_call:name|>{{name}}<|tool_call:args|>{{args}}<|tool_call:end|><|calls|>
    - tool mode (with think): <|think|>{{think-tokens}}<|end|><|begin|>assistant<|tool_calls|>...<|calls|>
    - no-think mode: <|content|>{{content-tokens}}<|flush|>

    The ``assistant`` token is an ordinary vocabulary word. It acts as the
    role marker only when it directly follows ``<|begin|>``; anywhere else
    (think text, content, tool-call id/name/args) it is regular text and
    advances ``*_BEGIN`` states like any other word.

    Args:
        output_token_ids: List of token IDs generated so far.
        begin_token_id: Token ID for <|begin|>.
        end_token_id: Token ID for <|end|>.
        flush_token_id: Token ID for <|flush|> (eos).
        think_token_id: Token ID for <|think|>.
        content_token_id: Token ID for <|content|>.
        tool_calls_token_id: Token ID for <|tool_calls|>.
        tool_call_begin_token_id: Token ID for <|tool_call:begin|>.
        tool_call_name_token_id: Token ID for <|tool_call:name|>.
        tool_call_args_token_id: Token ID for <|tool_call:args|>.
        tool_call_end_token_id: Token ID for <|tool_call:end|>.
        calls_token_id: Token ID for <|calls|> (eos).
        assistant_token_id: Token ID for assistant.

    Returns:
        GenerationState indicating the current phase of generation
        (``INITIAL`` for an empty sequence).
    """
    # A regular (non-control) token moves a "*_BEGIN" state to its
    # "*_IN_PROGRESS" counterpart; every other state is unchanged by text.
    advance_on_text = {
        GenerationState.THINK_BEGIN: GenerationState.THINK_IN_PROGRESS,
        GenerationState.CONTENT_BEGIN: GenerationState.CONTENT_IN_PROGRESS,
        GenerationState.TOOL_CALL_BEGIN: GenerationState.TOOL_CALL_ID_IN_PROGRESS,
        GenerationState.TOOL_CALL_NAME_BEGIN: GenerationState.TOOL_CALL_NAME_IN_PROGRESS,
        GenerationState.TOOL_CALL_ARGS_BEGIN: GenerationState.TOOL_CALL_ARGS_IN_PROGRESS,
    }

    state = GenerationState.INITIAL
    # Which message (if any) a following <|end|>/<|flush|> would close.
    in_think = False
    in_content = False

    for token_id in output_token_ids:
        if token_id == think_token_id:
            state = GenerationState.THINK_BEGIN
            in_think, in_content = True, False
        elif token_id == content_token_id:
            state = GenerationState.CONTENT_BEGIN
            in_think, in_content = False, True
        elif token_id == tool_calls_token_id:
            state = GenerationState.TOOL_CALLS_BEGIN
            in_think, in_content = False, False
        elif token_id == tool_call_begin_token_id:
            state = GenerationState.TOOL_CALL_BEGIN
        elif token_id == tool_call_name_token_id:
            state = GenerationState.TOOL_CALL_NAME_BEGIN
        elif token_id == tool_call_args_token_id:
            state = GenerationState.TOOL_CALL_ARGS_BEGIN
        elif token_id == tool_call_end_token_id:
            state = GenerationState.TOOL_CALL_END
        elif token_id == calls_token_id:
            state = GenerationState.CALLS
        elif token_id == begin_token_id:
            state = GenerationState.NEW_MESSAGE_BEGIN
        elif (
            token_id == assistant_token_id
            and state == GenerationState.NEW_MESSAGE_BEGIN
        ):
            # Role marker position: "<|begin|>assistant".
            state = GenerationState.NEW_MESSAGE_ASSISTANT
        elif token_id == end_token_id:
            # <|end|> closes the open think/content message; otherwise ignored.
            if in_think:
                state = GenerationState.THINK_END
                in_think = False
            elif in_content:
                state = GenerationState.CONTENT_END
                in_content = False
        elif token_id == flush_token_id:
            # <|flush|> closes the open think/content message; otherwise ignored.
            if in_think:
                state = GenerationState.THINK_FLUSH
                in_think = False
            elif in_content:
                state = GenerationState.CONTENT_FLUSH
                in_content = False
        else:
            # Regular text, including "assistant" outside the role position.
            state = advance_on_text.get(state, state)

    return state
|
|
|
|
|
|
|
|
|
|
|
# Every template control token ID. ASSISTANT_TOKEN_ID is not included: the
# "assistant" word is an ordinary vocabulary token.
_ALL_SPECIAL_TOKEN_IDS = [
    BEGIN_TOKEN_ID,
    END_TOKEN_ID,
    THINK_TOKEN_ID,
    CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID,
    CALLS_TOKEN_ID,
    TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID,
    TOOL_CALL_END_TOKEN_ID,
    TOOL_CALL_NAME_TOKEN_ID,
    TOOL_CALL_ARGS_TOKEN_ID,
]

# Pre-built index lists for masking. Each one holds every entry of
# _ALL_SPECIAL_TOKEN_IDS except the token(s) named in its suffix, so writing
# NEG_INF at these indices leaves regular text plus the excepted control
# token(s) as the only sampleable tokens.

# All control tokens except <|end|>.
_SPECIAL_EXCEPT_END = [
    BEGIN_TOKEN_ID, FLUSH_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    TOOL_CALLS_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID,
    TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

# All control tokens except <|content|> and <|tool_calls|>.
_SPECIAL_EXCEPT_CONTENT_TOOLCALLS = [
    THINK_TOKEN_ID, BEGIN_TOKEN_ID, END_TOKEN_ID, FLUSH_TOKEN_ID,
    CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID,
    TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

# All control tokens except <|flush|>.
_SPECIAL_EXCEPT_FLUSH = [
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    TOOL_CALLS_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID,
    TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

# All control tokens except <|tool_call:name|>.
_SPECIAL_EXCEPT_TOOLCALL_NAME = [
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

# All control tokens except <|tool_call:args|>.
_SPECIAL_EXCEPT_TOOLCALL_ARGS = [
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID,
]

# All control tokens except <|tool_call:end|>.
_SPECIAL_EXCEPT_TOOLCALL_END = [
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]
|
|
|
|
|
|
|
|
def _forbid_all_special_tokens(logits: torch.Tensor) -> None:
    """Mask every template control token in place.

    Writes NEG_INF at each index listed in ``_ALL_SPECIAL_TOKEN_IDS``; all
    other positions of ``logits`` are left untouched. Returns nothing.
    """
    control_ids = _ALL_SPECIAL_TOKEN_IDS
    logits[control_ids] = NEG_INF
|
|
|
|
|
|
|
|
class SolarOpenTemplateEnforcer:
    """Request-level logits processor that enforces Solar Open chat template.

    Enforces the following generation rules:
    - think mode: <|think|>{{tokens}}<|end|><|begin|>assistant<|content|>{{tokens}}<|flush|>
    - tool mode: <|tool_calls|><|tool_call:begin|>{{id}}<|tool_call:name|>{{name}}<|tool_call:args|>{{args}}<|tool_call:end|><|calls|>
    - tool+think mode: <|think|>{{tokens}}<|end|><|begin|>assistant<|tool_calls|>...<|calls|>
    - no-think mode: <|content|>{{tokens}}<|flush|>

    Key constraints:
    - Think message can only appear first
    - Think message must be followed by another message
    - Content and tool messages cannot coexist
    - Maximum 2 messages (think + content/tool, or just content/tool)

    The ``assistant`` token is an ordinary vocabulary word. It is treated as
    the role marker only immediately after ``<|begin|>``; anywhere else it is
    regular text, so it advances ``*_BEGIN`` states and counts toward the
    reasoning and tool-call-id budgets like any other token.

    Performance optimization:
    - Uses incremental state tracking to avoid full token sequence scan on each call
    - Maintains local counters for budget tracking
    - Uses pre-computed constants to avoid repeated object creation
    """

    # For structured-output reasoning requests the enforcer only acts while
    # the response is still in one of these states (i.e. up to the end of the
    # think message); afterwards the logits are returned unchanged.
    _REASONING_STATES = frozenset({
        GenerationState.INITIAL,
        GenerationState.THINK_BEGIN,
        GenerationState.THINK_IN_PROGRESS,
    })

    def __init__(
        self,
        is_reasoning_request: bool,
        is_structured_outputs: bool,
        reasoning_budget: int | None = None,
        tool_call_id_budget: int = DEFAULT_TOOL_CALL_ID_BUDGET,
    ):
        """Create an enforcer for one request.

        Args:
            is_reasoning_request: If True the response must open with
                <|think|>; otherwise with <|content|> or <|tool_calls|>.
            is_structured_outputs: If True, enforcement is skipped entirely
                for non-reasoning requests and limited to the think prefix for
                reasoning requests.
            reasoning_budget: Number of think tokens after which <|end|> is
                forced. None disables the limit.
            tool_call_id_budget: Number of tool-call id tokens after which
                <|tool_call:name|> is forced.
        """
        self._is_reasoning_request = is_reasoning_request
        self._is_structured_outputs = is_structured_outputs
        self._reasoning_budget = reasoning_budget
        self._tool_call_id_budget = tool_call_id_budget

        # Incremental tracking state (see _reset_state for the meaning).
        self._state = GenerationState.INITIAL
        self._last_processed_len = 0
        self._in_think = False
        self._in_content = False
        self._think_token_count = 0
        self._tool_call_id_token_count = 0

    def _reset_state(self) -> None:
        """Reset all incremental state to initial values.

        Called when defensive reprocessing is needed (e.g., token sequence inconsistency).
        """
        self._state = GenerationState.INITIAL
        # Number of output tokens already folded into the state.
        self._last_processed_len = 0
        # Which message a following <|end|>/<|flush|> would close.
        self._in_think = False
        self._in_content = False
        # Budget counters.
        self._think_token_count = 0
        self._tool_call_id_token_count = 0

    def _process_token(self, token_id: int) -> None:
        """Process a single token and update internal state incrementally.

        Args:
            token_id: The token ID to process.
        """
        if token_id == THINK_TOKEN_ID:
            self._state = GenerationState.THINK_BEGIN
            self._in_think = True
            self._in_content = False
            self._think_token_count = 0
        elif token_id == CONTENT_TOKEN_ID:
            self._state = GenerationState.CONTENT_BEGIN
            self._in_content = True
            self._in_think = False
        elif token_id == TOOL_CALLS_TOKEN_ID:
            self._state = GenerationState.TOOL_CALLS_BEGIN
            self._in_think = False
            self._in_content = False
        elif token_id == TOOL_CALL_BEGIN_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_BEGIN
            self._tool_call_id_token_count = 0
        elif token_id == TOOL_CALL_NAME_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_NAME_BEGIN
        elif token_id == TOOL_CALL_ARGS_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_ARGS_BEGIN
        elif token_id == TOOL_CALL_END_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_END
        elif token_id == CALLS_TOKEN_ID:
            self._state = GenerationState.CALLS
        elif token_id == BEGIN_TOKEN_ID:
            self._state = GenerationState.NEW_MESSAGE_BEGIN
        elif (
            token_id == ASSISTANT_TOKEN_ID
            and self._state == GenerationState.NEW_MESSAGE_BEGIN
        ):
            # Role marker position: "<|begin|>assistant".
            self._state = GenerationState.NEW_MESSAGE_ASSISTANT
        elif token_id == END_TOKEN_ID:
            # <|end|> closes the open think/content message; otherwise ignored.
            if self._in_think:
                self._state = GenerationState.THINK_END
                self._in_think = False
            elif self._in_content:
                self._state = GenerationState.CONTENT_END
                self._in_content = False
        elif token_id == FLUSH_TOKEN_ID:
            # <|flush|> closes the open think/content message; otherwise ignored.
            if self._in_think:
                self._state = GenerationState.THINK_FLUSH
                self._in_think = False
            elif self._in_content:
                self._state = GenerationState.CONTENT_FLUSH
                self._in_content = False
        else:
            # Regular text, including "assistant" outside the role position.
            self._process_text_token()

    def _process_text_token(self) -> None:
        """Advance state and budget counters for one regular (non-control) token."""
        state = self._state
        if state in (GenerationState.THINK_BEGIN, GenerationState.THINK_IN_PROGRESS):
            self._state = GenerationState.THINK_IN_PROGRESS
            self._think_token_count += 1
        elif state in (
            GenerationState.TOOL_CALL_BEGIN,
            GenerationState.TOOL_CALL_ID_IN_PROGRESS,
        ):
            self._state = GenerationState.TOOL_CALL_ID_IN_PROGRESS
            self._tool_call_id_token_count += 1
        elif state == GenerationState.CONTENT_BEGIN:
            self._state = GenerationState.CONTENT_IN_PROGRESS
        elif state == GenerationState.TOOL_CALL_NAME_BEGIN:
            self._state = GenerationState.TOOL_CALL_NAME_IN_PROGRESS
        elif state == GenerationState.TOOL_CALL_ARGS_BEGIN:
            self._state = GenerationState.TOOL_CALL_ARGS_IN_PROGRESS
        # Any other state is unchanged by regular text.

    def _update_state_incremental(self, output_token_ids: list[int]) -> None:
        """Update internal state by processing only new tokens.

        Args:
            output_token_ids: Full list of output token IDs.
        """
        current_len = len(output_token_ids)

        # The output should only grow between calls. If it shrank, the cached
        # state no longer describes it, so rebuild from the start.
        if current_len < self._last_processed_len:
            self._reset_state()

        for token_id in output_token_ids[self._last_processed_len:current_len]:
            self._process_token(token_id)

        self._last_processed_len = current_len

    @staticmethod
    def _count_think_tokens(output_token_ids: list[int]) -> int:
        """Count the number of tokens generated after <|think|> token.

        Returns 0 if <|think|> token is not found (defensive).
        Note: This static method is kept for backward compatibility and testing.
        The incremental version uses _think_token_count instead.
        """
        try:
            think_index = output_token_ids.index(THINK_TOKEN_ID)
        except ValueError:
            return 0
        return len(output_token_ids) - think_index - 1

    @staticmethod
    def _count_tool_call_id_tokens(output_token_ids: list[int]) -> int:
        """Count the number of tokens generated after the last <|tool_call:begin|> token.

        Returns 0 if <|tool_call:begin|> token is not found (defensive).
        Note: This static method is kept for backward compatibility and testing.
        The incremental version uses _tool_call_id_token_count instead.
        """
        # Walking backwards, the offset of the last <|tool_call:begin|> equals
        # the number of tokens that follow it. No reversed copy is built.
        for offset, token_id in enumerate(reversed(output_token_ids)):
            if token_id == TOOL_CALL_BEGIN_TOKEN_ID:
                return offset
        return 0

    @staticmethod
    def _keep_only(logits: torch.Tensor, *token_ids: int) -> None:
        """Mask every position except ``token_ids``, whose logits are preserved."""
        kept = [logits[token_id].clone() for token_id in token_ids]
        logits.fill_(NEG_INF)
        for token_id, value in zip(token_ids, kept):
            logits[token_id] = value

    @staticmethod
    def _force(logits: torch.Tensor, token_id: int) -> None:
        """Make ``token_id`` the only candidate: logit 0.0, everything else NEG_INF."""
        logits.fill_(NEG_INF)
        logits[token_id] = 0.0

    def __call__(
        self,
        output_token_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """Mask ``logits`` in place according to the template position reached.

        Args:
            output_token_ids: All tokens generated so far for this request.
            logits: Next-token logits, indexed by token ID; modified in place.

        Returns:
            The same ``logits`` tensor.
        """
        self._update_state_incremental(output_token_ids)
        state = self._state

        if self._is_structured_outputs:
            # Non-reasoning structured requests are left entirely to the
            # structured-output backend.
            if not self._is_reasoning_request:
                return logits
            # Reasoning structured requests: only the think prefix is enforced.
            if state not in self._REASONING_STATES:
                return logits

        if state == GenerationState.INITIAL:
            if self._is_reasoning_request:
                # The response must open with the think message.
                self._keep_only(logits, THINK_TOKEN_ID)
            else:
                self._keep_only(logits, CONTENT_TOKEN_ID, TOOL_CALLS_TOKEN_ID)

        elif state in (GenerationState.THINK_BEGIN, GenerationState.THINK_IN_PROGRESS):
            budget_exhausted = (
                self._reasoning_budget is not None
                and state == GenerationState.THINK_IN_PROGRESS
                and self._think_token_count >= self._reasoning_budget
            )
            if budget_exhausted:
                self._force(logits, END_TOKEN_ID)
            else:
                # The think message may only be closed by <|end|>. Give <|end|>
                # the larger of the two closing logits so a model that prefers
                # <|flush|> here still ends the think message.
                logits[END_TOKEN_ID] = torch.maximum(
                    logits[END_TOKEN_ID], logits[FLUSH_TOKEN_ID]
                )
                logits[_SPECIAL_EXCEPT_END] = NEG_INF

        elif state == GenerationState.THINK_END:
            # A think message must be followed by another message.
            self._force(logits, BEGIN_TOKEN_ID)

        elif state == GenerationState.NEW_MESSAGE_BEGIN:
            self._force(logits, ASSISTANT_TOKEN_ID)

        elif state == GenerationState.NEW_MESSAGE_ASSISTANT:
            # "<|begin|>assistant" must be followed directly by <|content|> or
            # <|tool_calls|>. Regular text does not move the state machine out of
            # this state, so allowing it would let the model stall here with
            # both eos tokens masked.
            self._keep_only(logits, CONTENT_TOKEN_ID, TOOL_CALLS_TOKEN_ID)

        elif state in (GenerationState.CONTENT_BEGIN, GenerationState.CONTENT_IN_PROGRESS):
            # Content may only be closed by <|flush|>. Give it the larger of
            # the two closing logits.
            logits[FLUSH_TOKEN_ID] = torch.maximum(
                logits[FLUSH_TOKEN_ID], logits[END_TOKEN_ID]
            )
            logits[_SPECIAL_EXCEPT_FLUSH] = NEG_INF

        elif state == GenerationState.TOOL_CALLS_BEGIN:
            self._keep_only(logits, TOOL_CALL_BEGIN_TOKEN_ID)

        elif state == GenerationState.TOOL_CALL_BEGIN:
            # The tool-call id needs at least one text token.
            _forbid_all_special_tokens(logits)

        elif state == GenerationState.TOOL_CALL_ID_IN_PROGRESS:
            if self._tool_call_id_token_count >= self._tool_call_id_budget:
                self._force(logits, TOOL_CALL_NAME_TOKEN_ID)
            else:
                logits[_SPECIAL_EXCEPT_TOOLCALL_NAME] = NEG_INF

        elif state == GenerationState.TOOL_CALL_NAME_BEGIN:
            # The tool name needs at least one text token.
            _forbid_all_special_tokens(logits)

        elif state == GenerationState.TOOL_CALL_NAME_IN_PROGRESS:
            logits[_SPECIAL_EXCEPT_TOOLCALL_ARGS] = NEG_INF

        elif state == GenerationState.TOOL_CALL_ARGS_BEGIN:
            # The arguments need at least one text token.
            _forbid_all_special_tokens(logits)

        elif state == GenerationState.TOOL_CALL_ARGS_IN_PROGRESS:
            logits[_SPECIAL_EXCEPT_TOOLCALL_END] = NEG_INF

        elif state == GenerationState.TOOL_CALL_END:
            # Either another tool call starts, or the tool message ends.
            self._keep_only(logits, TOOL_CALL_BEGIN_TOKEN_ID, CALLS_TOKEN_ID)

        # Remaining states (CALLS, THINK_FLUSH, CONTENT_END, CONTENT_FLUSH) are
        # left unconstrained.
        return logits
|
|
|
|
|
class SolarOpenTemplateLogitsProcessor(AdapterLogitsProcessor):
    """
    Logits processor that enforces Solar Open chat template.
    This processor manages the generation flow according to the
    Solar Open chat template by tracking generation states.

    Budget knobs are read once at construction from the environment
    (SOLAR_REASONING_BUDGET_{HIGH,MEDIUM}_{MAX,MIN,RATIO} and
    SOLAR_TOOL_CALL_ID_BUDGET), falling back to the module defaults when a
    variable is unset or not an integer.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        device: torch.device,
        is_pin_memory: bool,
    ):
        super().__init__(vllm_config, device, is_pin_memory)

        # Reasoning budget settings for "high" effort.
        self._high_max = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_HIGH_MAX", DEFAULT_REASONING_BUDGET_HIGH_MAX
        )
        self._high_min = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_HIGH_MIN", DEFAULT_REASONING_BUDGET_HIGH_MIN
        )
        self._high_ratio = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_HIGH_RATIO", DEFAULT_REASONING_BUDGET_HIGH_RATIO
        )

        # Reasoning budget settings for "medium" effort.
        self._medium_max = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_MEDIUM_MAX", DEFAULT_REASONING_BUDGET_MEDIUM_MAX
        )
        self._medium_min = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_MEDIUM_MIN", DEFAULT_REASONING_BUDGET_MEDIUM_MIN
        )
        self._medium_ratio = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_MEDIUM_RATIO", DEFAULT_REASONING_BUDGET_MEDIUM_RATIO
        )

        # Maximum tool-call id length (in tokens) before the name is forced.
        self._tool_call_id_budget: int = self._parse_env_int(
            "SOLAR_TOOL_CALL_ID_BUDGET", DEFAULT_TOOL_CALL_ID_BUDGET
        )

    @staticmethod
    def _parse_env_int(env_var: str, default: int) -> int:
        """Parse environment variable as integer, return default if not set or invalid."""
        value = os.environ.get(env_var)
        if value is None:
            return default
        try:
            return int(value)
        except ValueError:
            return default

    def _calculate_reasoning_budget(self, effort: str, max_tokens: int | None) -> int:
        """Calculate dynamic reasoning budget based on effort level and max_tokens.

        Priority (higher priority conditions are applied first):
        1. max_budget: Upper limit for reasoning tokens
        2. min_budget: Lower limit for reasoning tokens
        3. ratio: Percentage of max_tokens allocated for reasoning (e.g., 60 means 60%)

            budget = min(max_budget, max(min_budget, max_tokens * ratio // 100))

        "medium" uses the medium settings; "high" and any other value use the
        high settings. When ``max_tokens`` is None (no output cap) the ratio
        term is unbounded, so the budget is ``max_budget``.
        """
        if effort == "medium":
            max_budget = self._medium_max
            min_budget = self._medium_min
            ratio = self._medium_ratio
        else:
            max_budget = self._high_max
            min_budget = self._high_min
            ratio = self._high_ratio

        if max_tokens is None:
            # SamplingParams.max_tokens is Optional; without a cap the ratio
            # term would be unbounded and only max_budget applies.
            ratio_budget = max_budget
        else:
            ratio_budget = max_tokens * ratio // 100

        return min(max_budget, max(min_budget, ratio_budget))

    def is_argmax_invariant(self) -> bool:
        """This processor can change argmax result by forcing specific tokens."""
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Build the per-request template enforcer for ``params``."""
        reasoning_effort = params.reasoning_effort or DEFAULT_REASONING_EFFORT
        reasoning_budget = self._calculate_reasoning_budget(
            reasoning_effort, params.max_tokens
        )
        return SolarOpenTemplateEnforcer(
            is_reasoning_request=is_reasoning_request(params),
            is_structured_outputs=is_structured_outputs(params),
            reasoning_budget=reasoning_budget,
            tool_call_id_budget=self._tool_call_id_budget,
        )
|
|
|
|
|
|