| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import TYPE_CHECKING |
| |
|
| | import torch |
| |
|
| | from vllm.sampling_params import SamplingParams |
| | from vllm.v1.sample.logits_processor import ( |
| | AdapterLogitsProcessor, |
| | RequestLogitsProcessor, |
| | ) |
| |
|
| | if TYPE_CHECKING: |
| | from vllm.config import VllmConfig |
| |
|
| | |
# Special-token ids used by the single-tool-call enforcement below.
# NOTE(review): these are hard-coded for one specific tokenizer; confirm they
# match the served model's vocabulary before reusing this module elsewhere.
TOOL_CALL_END_TOKEN_ID: int = 32  # <|tool_call:end|>
CALLS_TOKEN_ID: int = 25  # <|calls|>, forced after a tool call ends
| |
|
| |
|
class SingleToolCallEnforcer:
    """Request-level logits processor that enforces a single tool call.

    When the most recently generated token is ``<|tool_call:end|>``, every
    logit except the one for ``<|calls|>`` (a stop token) is set to ``-inf``,
    so ``<|calls|>`` is the only token that can be chosen next. This prevents
    the model from starting a second (parallel) tool call.

    Args:
        tool_call_end_token_id: Token id of ``<|tool_call:end|>``.
        calls_token_id: Token id of ``<|calls|>``, the token to force.
    """

    def __init__(
        self,
        tool_call_end_token_id: int,
        calls_token_id: int,
    ):
        self._tool_call_end_token_id = tool_call_end_token_id
        self._calls_token_id = calls_token_id

    def __call__(
        self,
        output_token_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """Return next-token logits, forcing ``<|calls|>`` after a tool call.

        Args:
            output_token_ids: Tokens generated so far for this request.
            logits: Next-token logits for this request, indexed by token id.
                NOTE(review): assumed to be a 1-D vocab-sized row — confirm
                against how the adapter invokes request-level processors.

        Returns:
            ``logits`` itself (unmodified) unless the last generated token is
            ``<|tool_call:end|>``. In that case, returns a new tensor that is
            ``-inf`` everywhere except at ``<|calls|>``. The input is never
            mutated.
        """
        last_is_tool_call_end = (
            bool(output_token_ids)
            and output_token_ids[-1] == self._tool_call_end_token_id
        )
        if not last_is_tool_call_end:
            return logits

        forced = logits[self._calls_token_id]
        # If the forced logit is already non-finite (e.g. another processor
        # set it to -inf), copying it would leave the row with no finite
        # entry, and sampling would see NaN probabilities. Map non-finite
        # values to 0.0 so <|calls|> remains the sole candidate. Finite logits
        # pass through unchanged. nan_to_num stays on-device (no host sync).
        forced = torch.nan_to_num(forced, nan=0.0, posinf=0.0, neginf=0.0)

        mask = torch.full_like(logits, float("-inf"))
        mask[self._calls_token_id] = forced
        return mask
| |
|
| |
|
class ParallelToolCallLogitsProcessor(AdapterLogitsProcessor):
    """Adapter that disables parallel tool calls for opted-out requests.

    A ``SingleToolCallEnforcer`` is attached to each request whose
    ``SamplingParams.parallel_tool_calls`` is exactly ``False``. For such a
    request, once ``<|tool_call:end|>`` is generated, the next token is forced
    to ``<|calls|>`` (a stop token), so at most one tool call is emitted.
    All other requests get no request-level processor.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        device: torch.device,
        is_pin_memory: bool,
    ):
        # No processor-specific state; defer entirely to the base class.
        super().__init__(vllm_config, device, is_pin_memory)

    def is_argmax_invariant(self) -> bool:
        """Report that this processor can change the argmax token.

        Forcing ``<|calls|>`` may select a token other than the unmodified
        argmax, so this returns ``False``.
        """
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Build the per-request processor for ``params``, if one is needed.

        Args:
            params: The request's sampling parameters.

        Returns:
            A ``SingleToolCallEnforcer`` when ``params.parallel_tool_calls``
            is exactly ``False``; ``None`` otherwise.
        """
        # Only an explicit False opts in; None or True leave parallel tool
        # calls allowed.
        if params.parallel_tool_calls is not False:
            return None

        return SingleToolCallEnforcer(
            calls_token_id=CALLS_TOKEN_ID,
            tool_call_end_token_id=TOOL_CALL_END_TOKEN_ID,
        )
| |
|
| |
|