| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | from enum import Enum |
| | from typing import TYPE_CHECKING |
| |
|
| | import torch |
| |
|
| | from vllm.sampling_params import SamplingParams |
| | from vllm.v1.sample.logits_processor import ( |
| | AdapterLogitsProcessor, |
| | RequestLogitsProcessor, |
| | ) |
| |
|
| | if TYPE_CHECKING: |
| | from vllm.config import VllmConfig |
| |
|
| | |
| |
|
| | |
| | BEGIN_TOKEN_ID = 20 |
| | END_TOKEN_ID = 21 |
| | THINK_TOKEN_ID = 22 |
| | CONTENT_TOKEN_ID = 23 |
| | FLUSH_TOKEN_ID = 24 |
| | ASSISTANT_TOKEN_ID = 163444 |
| | ''' |
| | 'assistant' is not a special token exactly, but is treated as one in the logits |
| | processing. |
| | ''' |
| |
|
| | |
| | CALLS_TOKEN_ID = 25 |
| | TOOL_CALLS_TOKEN_ID = 30 |
| | TOOL_CALL_BEGIN_TOKEN_ID = 31 |
| | TOOL_CALL_END_TOKEN_ID = 32 |
| | TOOL_CALL_NAME_TOKEN_ID = 33 |
| | TOOL_CALL_ARGS_TOKEN_ID = 34 |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | DEFAULT_REASONING_EFFORT = "high" |
| |
|
| | |
| | DEFAULT_REASONING_BUDGET_HIGH_MAX = 32 * 1024 |
| | DEFAULT_REASONING_BUDGET_HIGH_MIN = 8 * 1024 |
| | DEFAULT_REASONING_BUDGET_HIGH_RATIO = 60 |
| |
|
| | |
| | DEFAULT_REASONING_BUDGET_MEDIUM_MAX = 16 * 1024 |
| | DEFAULT_REASONING_BUDGET_MEDIUM_MIN = 4 * 1024 |
| | DEFAULT_REASONING_BUDGET_MEDIUM_RATIO = 30 |
| |
|
| | |
| | DEFAULT_TOOL_CALL_ID_BUDGET = 10 |
| |
|
| | |
| | NEG_INF = float("-inf") |
| |
|
| |
|
| | def is_reasoning_request(params: SamplingParams) -> bool: |
| | """Check if the request is a reasoning request based on reasoning_effort.""" |
| | return (params.reasoning_effort is None) or (params.reasoning_effort in ("medium", "high")) |
| |
|
| |
|
| | def is_structured_outputs(params: SamplingParams) -> bool: |
| | """Check if the request has structured outputs constraints.""" |
| | return ( |
| | params.structured_outputs is not None |
| | and not params.structured_outputs.all_constraints_none() |
| | ) |
| |
|
| |
|
| | class GenerationState(Enum): |
| | """Enum representing the current state of response generation.""" |
| |
|
| | |
| | INITIAL = "initial" |
| |
|
| | |
| | NEW_MESSAGE_BEGIN = "new_message_begin" |
| | NEW_MESSAGE_ASSISTANT = "new_message_assistant" |
| |
|
| | |
| | THINK_BEGIN = "think_begin" |
| | THINK_IN_PROGRESS = "think_in_progress" |
| | THINK_END = "think_end" |
| | THINK_FLUSH = "think_flush" |
| |
|
| | |
| | CONTENT_BEGIN = "content_begin" |
| | CONTENT_IN_PROGRESS = "content_in_progress" |
| | CONTENT_END = "content_end" |
| | CONTENT_FLUSH = "content_flush" |
| |
|
| | |
| | |
| | |
| | TOOL_CALLS_BEGIN = "tool_calls_begin" |
| | TOOL_CALL_BEGIN = "tool_call_begin" |
| | TOOL_CALL_ID_IN_PROGRESS = "tool_call_id_in_progress" |
| | TOOL_CALL_NAME_BEGIN = "tool_call_name_begin" |
| | TOOL_CALL_NAME_IN_PROGRESS = "tool_call_name_in_progress" |
| | TOOL_CALL_ARGS_BEGIN = "tool_call_args_begin" |
| | TOOL_CALL_ARGS_IN_PROGRESS = "tool_call_args_in_progress" |
| | TOOL_CALL_END = "tool_call_end" |
| | CALLS = "calls" |
| |
|
| |
|
| | def get_generation_state( |
| | output_token_ids: list[int], |
| | begin_token_id: int = BEGIN_TOKEN_ID, |
| | end_token_id: int = END_TOKEN_ID, |
| | flush_token_id: int = FLUSH_TOKEN_ID, |
| | think_token_id: int = THINK_TOKEN_ID, |
| | content_token_id: int = CONTENT_TOKEN_ID, |
| | tool_calls_token_id: int = TOOL_CALLS_TOKEN_ID, |
| | tool_call_begin_token_id: int = TOOL_CALL_BEGIN_TOKEN_ID, |
| | tool_call_name_token_id: int = TOOL_CALL_NAME_TOKEN_ID, |
| | tool_call_args_token_id: int = TOOL_CALL_ARGS_TOKEN_ID, |
| | tool_call_end_token_id: int = TOOL_CALL_END_TOKEN_ID, |
| | calls_token_id: int = CALLS_TOKEN_ID, |
| | assistant_token_id: int = ASSISTANT_TOKEN_ID, |
| | ) -> GenerationState: |
| | """Determine the current generation state based on output token IDs. |
| | |
| | Analyzes the sequence of generated tokens to determine which phase |
| | of the chat template the generation is currently in. |
| | |
| | Response format specs: |
| | - think mode: <|think|>{{think-tokens}}<|end|><|begin|>assistant<|content|>{{content-tokens}}<|flush|> |
| | - tool mode: <|begin|>assistant<|tool_calls|><|tool_call:begin|>{{id}}<|tool_call:name|>{{name}}<|tool_call:args|>{{args}}<|tool_call:end|><|calls|> |
| | - tool mode (with think): <|think|>{{think-tokens}}<|end|><|begin|>assistant<|tool_calls|>...<|calls|> |
| | - no-think mode: <|content|>{{content-tokens}}<|flush|> |
| | |
| | Args: |
| | output_token_ids: List of token IDs generated so far. |
| | begin_token_id: Token ID for <|begin|>. |
| | end_token_id: Token ID for <|end|>. |
| | flush_token_id: Token ID for <|flush|> (eos). |
| | think_token_id: Token ID for <|think|>. |
| | content_token_id: Token ID for <|content|>. |
| | tool_calls_token_id: Token ID for <|tool_calls|>. |
| | tool_call_begin_token_id: Token ID for <|tool_call:begin|>. |
| | tool_call_name_token_id: Token ID for <|tool_call:name|>. |
| | tool_call_args_token_id: Token ID for <|tool_call:args|>. |
| | tool_call_end_token_id: Token ID for <|tool_call:end|>. |
| | calls_token_id: Token ID for <|calls|> (eos). |
| | assistant_token_id: Token ID for assistant. |
| | |
| | Returns: |
| | GenerationState indicating the current phase of generation. |
| | """ |
| | if not output_token_ids: |
| | return GenerationState.INITIAL |
| |
|
| | |
| | state = GenerationState.INITIAL |
| | in_think = False |
| | in_content = False |
| |
|
| | for token_id in output_token_ids: |
| | if token_id == think_token_id: |
| | state = GenerationState.THINK_BEGIN |
| | in_think = True |
| | in_content = False |
| |
|
| | elif token_id == content_token_id: |
| | state = GenerationState.CONTENT_BEGIN |
| | in_content = True |
| | in_think = False |
| |
|
| | elif token_id == tool_calls_token_id: |
| | state = GenerationState.TOOL_CALLS_BEGIN |
| | in_think = False |
| | in_content = False |
| |
|
| | elif token_id == tool_call_begin_token_id: |
| | state = GenerationState.TOOL_CALL_BEGIN |
| |
|
| | elif token_id == tool_call_name_token_id: |
| | state = GenerationState.TOOL_CALL_NAME_BEGIN |
| |
|
| | elif token_id == tool_call_args_token_id: |
| | state = GenerationState.TOOL_CALL_ARGS_BEGIN |
| |
|
| | elif token_id == tool_call_end_token_id: |
| | state = GenerationState.TOOL_CALL_END |
| |
|
| | elif token_id == calls_token_id: |
| | state = GenerationState.CALLS |
| |
|
| | elif token_id == begin_token_id: |
| | state = GenerationState.NEW_MESSAGE_BEGIN |
| |
|
| | elif token_id == assistant_token_id: |
| | if state == GenerationState.NEW_MESSAGE_BEGIN: |
| | state = GenerationState.NEW_MESSAGE_ASSISTANT |
| |
|
| | elif token_id == end_token_id: |
| | if in_think: |
| | state = GenerationState.THINK_END |
| | in_think = False |
| | elif in_content: |
| | state = GenerationState.CONTENT_END |
| | in_content = False |
| |
|
| | elif token_id == flush_token_id: |
| | if in_think: |
| | state = GenerationState.THINK_FLUSH |
| | in_think = False |
| | elif in_content: |
| | state = GenerationState.CONTENT_FLUSH |
| | in_content = False |
| |
|
| | else: |
| | |
| | if state == GenerationState.THINK_BEGIN: |
| | state = GenerationState.THINK_IN_PROGRESS |
| | elif state == GenerationState.THINK_IN_PROGRESS: |
| | pass |
| | elif state == GenerationState.CONTENT_BEGIN: |
| | state = GenerationState.CONTENT_IN_PROGRESS |
| | elif state == GenerationState.CONTENT_IN_PROGRESS: |
| | pass |
| | elif state == GenerationState.TOOL_CALL_BEGIN: |
| | state = GenerationState.TOOL_CALL_ID_IN_PROGRESS |
| | elif state == GenerationState.TOOL_CALL_ID_IN_PROGRESS: |
| | pass |
| | elif state == GenerationState.TOOL_CALL_NAME_BEGIN: |
| | state = GenerationState.TOOL_CALL_NAME_IN_PROGRESS |
| | elif state == GenerationState.TOOL_CALL_NAME_IN_PROGRESS: |
| | pass |
| | elif state == GenerationState.TOOL_CALL_ARGS_BEGIN: |
| | state = GenerationState.TOOL_CALL_ARGS_IN_PROGRESS |
| | elif state == GenerationState.TOOL_CALL_ARGS_IN_PROGRESS: |
| | pass |
| |
|
| | return state |
| |
|
| |
|
| | |
| | _ALL_SPECIAL_TOKEN_IDS = [ |
| | BEGIN_TOKEN_ID, |
| | END_TOKEN_ID, |
| | THINK_TOKEN_ID, |
| | CONTENT_TOKEN_ID, |
| | FLUSH_TOKEN_ID, |
| | CALLS_TOKEN_ID, |
| | TOOL_CALLS_TOKEN_ID, |
| | TOOL_CALL_BEGIN_TOKEN_ID, |
| | TOOL_CALL_END_TOKEN_ID, |
| | TOOL_CALL_NAME_TOKEN_ID, |
| | TOOL_CALL_ARGS_TOKEN_ID, |
| | ] |
| |
|
| | |
| | _SPECIAL_EXCEPT_END = [ |
| | BEGIN_TOKEN_ID, FLUSH_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID, |
| | TOOL_CALLS_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID, |
| | TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID, |
| | ] |
| |
|
| | _SPECIAL_EXCEPT_CONTENT_TOOLCALLS = [ |
| | THINK_TOKEN_ID, BEGIN_TOKEN_ID, END_TOKEN_ID, FLUSH_TOKEN_ID, |
| | CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, |
| | TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID, |
| | ] |
| |
|
| | _SPECIAL_EXCEPT_FLUSH = [ |
| | BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID, |
| | TOOL_CALLS_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID, |
| | TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID, |
| | ] |
| |
|
| | _SPECIAL_EXCEPT_TOOLCALL_NAME = [ |
| | BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID, |
| | FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID, |
| | TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID, |
| | ] |
| |
|
| | _SPECIAL_EXCEPT_TOOLCALL_ARGS = [ |
| | BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID, |
| | FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID, |
| | TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, |
| | ] |
| |
|
| | _SPECIAL_EXCEPT_TOOLCALL_END = [ |
| | BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID, |
| | FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID, |
| | TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID, |
| | ] |
| |
|
| |
|
| | def _forbid_all_special_tokens(logits: torch.Tensor) -> None: |
| | """Set all special token logits to -inf.""" |
| | logits[_ALL_SPECIAL_TOKEN_IDS] = NEG_INF |
| |
|
| |
|
| | class SolarOpenTemplateEnforcer: |
| | """Request-level logits processor that enforces Solar Open chat template. |
| | |
| | Enforces the following generation rules: |
| | - think mode: <|think|>{{tokens}}<|end|><|begin|>assistant<|content|>{{tokens}}<|flush|> |
| | - tool mode: <|tool_calls|><|tool_call:begin|>{{id}}<|tool_call:name|>{{name}}<|tool_call:args|>{{args}}<|tool_call:end|><|calls|> |
| | - tool+think mode: <|think|>{{tokens}}<|end|><|begin|>assistant<|tool_calls|>...<|calls|> |
| | - no-think mode: <|content|>{{tokens}}<|flush|> |
| | |
| | Key constraints: |
| | - Think message can only appear first |
| | - Think message must be followed by another message |
| | - Content and tool messages cannot coexist |
| | - Maximum 2 messages (think + content/tool, or just content/tool) |
| | |
| | Performance optimization: |
| | - Uses incremental state tracking to avoid full token sequence scan on each call |
| | - Maintains local counters for budget tracking |
| | - Uses pre-computed constants to avoid repeated object creation |
| | """ |
| |
|
| | |
| | _REASONING_STATES = frozenset({ |
| | GenerationState.INITIAL, |
| | GenerationState.THINK_BEGIN, |
| | GenerationState.THINK_IN_PROGRESS, |
| | }) |
| |
|
| | def __init__( |
| | self, |
| | is_reasoning_request: bool, |
| | is_structured_outputs: bool, |
| | reasoning_budget: int | None = None, |
| | tool_call_id_budget: int = DEFAULT_TOOL_CALL_ID_BUDGET, |
| | ): |
| | self._is_reasoning_request = is_reasoning_request |
| | self._is_structured_outputs = is_structured_outputs |
| | self._reasoning_budget = reasoning_budget |
| | self._tool_call_id_budget = tool_call_id_budget |
| |
|
| | |
| | self._state = GenerationState.INITIAL |
| | self._last_processed_len = 0 |
| | self._in_think = False |
| | self._in_content = False |
| |
|
| | |
| | self._think_token_count = 0 |
| | self._tool_call_id_token_count = 0 |
| |
|
| | def _reset_state(self) -> None: |
| | """Reset all incremental state to initial values. |
| | |
| | Called when defensive reprocessing is needed (e.g., token sequence inconsistency). |
| | """ |
| | self._state = GenerationState.INITIAL |
| | self._last_processed_len = 0 |
| | self._in_think = False |
| | self._in_content = False |
| | self._think_token_count = 0 |
| | self._tool_call_id_token_count = 0 |
| |
|
| | def _process_token(self, token_id: int) -> None: |
| | """Process a single token and update internal state incrementally. |
| | |
| | Args: |
| | token_id: The token ID to process. |
| | """ |
| | if token_id == THINK_TOKEN_ID: |
| | self._state = GenerationState.THINK_BEGIN |
| | self._in_think = True |
| | self._in_content = False |
| | self._think_token_count = 0 |
| |
|
| | elif token_id == CONTENT_TOKEN_ID: |
| | self._state = GenerationState.CONTENT_BEGIN |
| | self._in_content = True |
| | self._in_think = False |
| |
|
| | elif token_id == TOOL_CALLS_TOKEN_ID: |
| | self._state = GenerationState.TOOL_CALLS_BEGIN |
| | self._in_think = False |
| | self._in_content = False |
| |
|
| | elif token_id == TOOL_CALL_BEGIN_TOKEN_ID: |
| | self._state = GenerationState.TOOL_CALL_BEGIN |
| | self._tool_call_id_token_count = 0 |
| |
|
| | elif token_id == TOOL_CALL_NAME_TOKEN_ID: |
| | self._state = GenerationState.TOOL_CALL_NAME_BEGIN |
| |
|
| | elif token_id == TOOL_CALL_ARGS_TOKEN_ID: |
| | self._state = GenerationState.TOOL_CALL_ARGS_BEGIN |
| |
|
| | elif token_id == TOOL_CALL_END_TOKEN_ID: |
| | self._state = GenerationState.TOOL_CALL_END |
| |
|
| | elif token_id == CALLS_TOKEN_ID: |
| | self._state = GenerationState.CALLS |
| |
|
| | elif token_id == BEGIN_TOKEN_ID: |
| | self._state = GenerationState.NEW_MESSAGE_BEGIN |
| |
|
| | elif token_id == ASSISTANT_TOKEN_ID: |
| | if self._state == GenerationState.NEW_MESSAGE_BEGIN: |
| | self._state = GenerationState.NEW_MESSAGE_ASSISTANT |
| |
|
| | elif token_id == END_TOKEN_ID: |
| | if self._in_think: |
| | self._state = GenerationState.THINK_END |
| | self._in_think = False |
| | elif self._in_content: |
| | self._state = GenerationState.CONTENT_END |
| | self._in_content = False |
| |
|
| | elif token_id == FLUSH_TOKEN_ID: |
| | if self._in_think: |
| | self._state = GenerationState.THINK_FLUSH |
| | self._in_think = False |
| | elif self._in_content: |
| | self._state = GenerationState.CONTENT_FLUSH |
| | self._in_content = False |
| |
|
| | else: |
| | |
| | if self._state == GenerationState.THINK_BEGIN: |
| | self._state = GenerationState.THINK_IN_PROGRESS |
| | self._think_token_count += 1 |
| | elif self._state == GenerationState.THINK_IN_PROGRESS: |
| | self._think_token_count += 1 |
| | elif self._state == GenerationState.CONTENT_BEGIN: |
| | self._state = GenerationState.CONTENT_IN_PROGRESS |
| | elif self._state == GenerationState.CONTENT_IN_PROGRESS: |
| | pass |
| | elif self._state == GenerationState.TOOL_CALL_BEGIN: |
| | self._state = GenerationState.TOOL_CALL_ID_IN_PROGRESS |
| | self._tool_call_id_token_count += 1 |
| | elif self._state == GenerationState.TOOL_CALL_ID_IN_PROGRESS: |
| | self._tool_call_id_token_count += 1 |
| | elif self._state == GenerationState.TOOL_CALL_NAME_BEGIN: |
| | self._state = GenerationState.TOOL_CALL_NAME_IN_PROGRESS |
| | elif self._state == GenerationState.TOOL_CALL_NAME_IN_PROGRESS: |
| | pass |
| | elif self._state == GenerationState.TOOL_CALL_ARGS_BEGIN: |
| | self._state = GenerationState.TOOL_CALL_ARGS_IN_PROGRESS |
| | elif self._state == GenerationState.TOOL_CALL_ARGS_IN_PROGRESS: |
| | pass |
| |
|
| | def _update_state_incremental(self, output_token_ids: list[int]) -> None: |
| | """Update internal state by processing only new tokens. |
| | |
| | Args: |
| | output_token_ids: Full list of output token IDs. |
| | """ |
| | current_len = len(output_token_ids) |
| |
|
| | |
| | if current_len < self._last_processed_len: |
| | self._reset_state() |
| |
|
| | |
| | for i in range(self._last_processed_len, current_len): |
| | self._process_token(output_token_ids[i]) |
| |
|
| | self._last_processed_len = current_len |
| |
|
| | @staticmethod |
| | def _count_think_tokens(output_token_ids: list[int]) -> int: |
| | """Count the number of tokens generated after <|think|> token. |
| | |
| | Returns 0 if <|think|> token is not found (defensive). |
| | Note: This static method is kept for backward compatibility and testing. |
| | The incremental version uses _think_token_count instead. |
| | """ |
| | try: |
| | think_index = output_token_ids.index(THINK_TOKEN_ID) |
| | return len(output_token_ids) - think_index - 1 |
| | except ValueError: |
| | return 0 |
| |
|
| | @staticmethod |
| | def _count_tool_call_id_tokens(output_token_ids: list[int]) -> int: |
| | """Count the number of tokens generated after the last <|tool_call:begin|> token. |
| | |
| | Returns 0 if <|tool_call:begin|> token is not found (defensive). |
| | Note: This static method is kept for backward compatibility and testing. |
| | The incremental version uses _tool_call_id_token_count instead. |
| | """ |
| | |
| | try: |
| | |
| | reversed_index = output_token_ids[::-1].index(TOOL_CALL_BEGIN_TOKEN_ID) |
| | last_begin_index = len(output_token_ids) - 1 - reversed_index |
| | return len(output_token_ids) - last_begin_index - 1 |
| | except ValueError: |
| | return 0 |
| |
|
| | def __call__( |
| | self, |
| | output_token_ids: list[int], |
| | logits: torch.Tensor, |
| | ) -> torch.Tensor: |
| | |
| | self._update_state_incremental(output_token_ids) |
| | state = self._state |
| |
|
| | |
| | if self._is_structured_outputs: |
| | if not self._is_reasoning_request: |
| | |
| | return logits |
| | else: |
| | |
| | |
| | if state not in self._REASONING_STATES: |
| | |
| | return logits |
| |
|
| | if state == GenerationState.INITIAL: |
| | if self._is_reasoning_request: |
| | |
| | think_logit = logits[THINK_TOKEN_ID].clone() |
| | logits.fill_(NEG_INF) |
| | logits[THINK_TOKEN_ID] = think_logit |
| | else: |
| | |
| | content_logit = logits[CONTENT_TOKEN_ID].clone() |
| | tool_calls_logit = logits[TOOL_CALLS_TOKEN_ID].clone() |
| | logits.fill_(NEG_INF) |
| | logits[CONTENT_TOKEN_ID] = content_logit |
| | logits[TOOL_CALLS_TOKEN_ID] = tool_calls_logit |
| |
|
| | elif state in (GenerationState.THINK_BEGIN, GenerationState.THINK_IN_PROGRESS): |
| | |
| | if ( |
| | self._reasoning_budget is not None |
| | and state == GenerationState.THINK_IN_PROGRESS |
| | ): |
| | if self._think_token_count >= self._reasoning_budget: |
| | |
| | logits.fill_(NEG_INF) |
| | logits[END_TOKEN_ID] = 0.0 |
| | return logits |
| |
|
| | |
| | |
| | logits[END_TOKEN_ID] = torch.maximum(logits[END_TOKEN_ID], logits[FLUSH_TOKEN_ID]) |
| | |
| | logits[_SPECIAL_EXCEPT_END] = NEG_INF |
| |
|
| | elif state == GenerationState.THINK_END: |
| | |
| | |
| | logits.fill_(NEG_INF) |
| | logits[BEGIN_TOKEN_ID] = 0.0 |
| |
|
| | elif state == GenerationState.NEW_MESSAGE_BEGIN: |
| | |
| | logits.fill_(NEG_INF) |
| | logits[ASSISTANT_TOKEN_ID] = 0.0 |
| |
|
| | elif state == GenerationState.NEW_MESSAGE_ASSISTANT: |
| | |
| | |
| | logits[_SPECIAL_EXCEPT_CONTENT_TOOLCALLS] = NEG_INF |
| |
|
| | elif state in (GenerationState.CONTENT_BEGIN, GenerationState.CONTENT_IN_PROGRESS): |
| | |
| | |
| | logits[FLUSH_TOKEN_ID] = torch.maximum(logits[FLUSH_TOKEN_ID], logits[END_TOKEN_ID]) |
| | |
| | logits[_SPECIAL_EXCEPT_FLUSH] = NEG_INF |
| |
|
| | elif state == GenerationState.TOOL_CALLS_BEGIN: |
| | |
| | tool_call_begin_logit = logits[TOOL_CALL_BEGIN_TOKEN_ID].clone() |
| | logits.fill_(NEG_INF) |
| | logits[TOOL_CALL_BEGIN_TOKEN_ID] = tool_call_begin_logit |
| |
|
| | elif state == GenerationState.TOOL_CALL_BEGIN: |
| | |
| | |
| | _forbid_all_special_tokens(logits) |
| |
|
| | elif state == GenerationState.TOOL_CALL_ID_IN_PROGRESS: |
| | |
| | if self._tool_call_id_token_count >= self._tool_call_id_budget: |
| | |
| | logits.fill_(NEG_INF) |
| | logits[TOOL_CALL_NAME_TOKEN_ID] = 0.0 |
| | return logits |
| |
|
| | |
| | |
| | logits[_SPECIAL_EXCEPT_TOOLCALL_NAME] = NEG_INF |
| |
|
| | elif state == GenerationState.TOOL_CALL_NAME_BEGIN: |
| | |
| | |
| | _forbid_all_special_tokens(logits) |
| |
|
| | elif state == GenerationState.TOOL_CALL_NAME_IN_PROGRESS: |
| | |
| | |
| | logits[_SPECIAL_EXCEPT_TOOLCALL_ARGS] = NEG_INF |
| |
|
| | elif state == GenerationState.TOOL_CALL_ARGS_BEGIN: |
| | |
| | |
| | _forbid_all_special_tokens(logits) |
| |
|
| | elif state == GenerationState.TOOL_CALL_ARGS_IN_PROGRESS: |
| | |
| | |
| | logits[_SPECIAL_EXCEPT_TOOLCALL_END] = NEG_INF |
| |
|
| | elif state == GenerationState.TOOL_CALL_END: |
| | |
| | |
| | tool_call_begin_logit = logits[TOOL_CALL_BEGIN_TOKEN_ID].clone() |
| | calls_logit = logits[CALLS_TOKEN_ID].clone() |
| | logits.fill_(NEG_INF) |
| | logits[TOOL_CALL_BEGIN_TOKEN_ID] = tool_call_begin_logit |
| | logits[CALLS_TOKEN_ID] = calls_logit |
| |
|
| | |
| |
|
| | return logits |
| |
|
| | class SolarOpenTemplateLogitsProcessor(AdapterLogitsProcessor): |
| | """ |
| | Logits processor that enforces Solar Open chat template. |
| | This processor manages the generation flow according to the |
| | Solar Open chat template by tracking generation states. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | vllm_config: "VllmConfig", |
| | device: torch.device, |
| | is_pin_memory: bool, |
| | ): |
| | super().__init__(vllm_config, device, is_pin_memory) |
| |
|
| | |
| | self._high_max = self._parse_env_int( |
| | "SOLAR_REASONING_BUDGET_HIGH_MAX", DEFAULT_REASONING_BUDGET_HIGH_MAX |
| | ) |
| | self._high_min = self._parse_env_int( |
| | "SOLAR_REASONING_BUDGET_HIGH_MIN", DEFAULT_REASONING_BUDGET_HIGH_MIN |
| | ) |
| | self._high_ratio = self._parse_env_int( |
| | "SOLAR_REASONING_BUDGET_HIGH_RATIO", DEFAULT_REASONING_BUDGET_HIGH_RATIO |
| | ) |
| |
|
| | |
| | self._medium_max = self._parse_env_int( |
| | "SOLAR_REASONING_BUDGET_MEDIUM_MAX", DEFAULT_REASONING_BUDGET_MEDIUM_MAX |
| | ) |
| | self._medium_min = self._parse_env_int( |
| | "SOLAR_REASONING_BUDGET_MEDIUM_MIN", DEFAULT_REASONING_BUDGET_MEDIUM_MIN |
| | ) |
| | self._medium_ratio = self._parse_env_int( |
| | "SOLAR_REASONING_BUDGET_MEDIUM_RATIO", DEFAULT_REASONING_BUDGET_MEDIUM_RATIO |
| | ) |
| |
|
| | self._tool_call_id_budget: int = self._parse_env_int( |
| | "SOLAR_TOOL_CALL_ID_BUDGET", DEFAULT_TOOL_CALL_ID_BUDGET |
| | ) |
| |
|
| | @staticmethod |
| | def _parse_env_int(env_var: str, default: int) -> int: |
| | """Parse environment variable as integer, return default if not set or invalid.""" |
| | value = os.environ.get(env_var) |
| | if value is None: |
| | return default |
| | try: |
| | return int(value) |
| | except ValueError: |
| | return default |
| |
|
| | def _calculate_reasoning_budget(self, effort: str, max_tokens: int) -> int: |
| | """Calculate dynamic reasoning budget based on effort level and max_tokens. |
| | |
| | Priority (higher priority conditions are applied first): |
| | 1. max_budget: Upper limit for reasoning tokens |
| | 2. min_budget: Lower limit for reasoning tokens |
| | 3. ratio: Percentage of max_tokens allocated for reasoning (e.g., 60 means 60%) |
| | |
| | budget = min(max_budget, max(min_budget, max_tokens * ratio / 100)) |
| | """ |
| | if effort == "high": |
| | max_budget = self._high_max |
| | min_budget = self._high_min |
| | ratio = self._high_ratio |
| | elif effort == "medium": |
| | max_budget = self._medium_max |
| | min_budget = self._medium_min |
| | ratio = self._medium_ratio |
| | else: |
| | |
| | max_budget = self._high_max |
| | min_budget = self._high_min |
| | ratio = self._high_ratio |
| |
|
| | |
| | ratio_budget = max_tokens * ratio // 100 |
| |
|
| | |
| | budget = min(max_budget, max(min_budget, ratio_budget)) |
| |
|
| | return budget |
| |
|
| | def is_argmax_invariant(self) -> bool: |
| | """This processor can change argmax result by forcing specific tokens.""" |
| | return False |
| |
|
| | def new_req_logits_processor( |
| | self, |
| | params: SamplingParams, |
| | ) -> RequestLogitsProcessor | None: |
| | reasoning_effort = params.reasoning_effort or DEFAULT_REASONING_EFFORT |
| | reasoning_budget = self._calculate_reasoning_budget( |
| | reasoning_effort, params.max_tokens |
| | ) |
| | return SolarOpenTemplateEnforcer( |
| | is_reasoning_request=is_reasoning_request(params), |
| | is_structured_outputs=is_structured_outputs(params), |
| | reasoning_budget=reasoning_budget, |
| | tool_call_id_budget=self._tool_call_id_budget, |
| | ) |
| |
|
| |
|