llaa33219 committed on
Commit
d0160ff
·
verified ·
1 Parent(s): eb32ca4

Upload parallel_tool_call_logits_processor.py with huggingface_hub

Browse files
parallel_tool_call_logits_processor.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Upstage AI.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ import torch
19
+
20
+ from vllm.sampling_params import SamplingParams
21
+ from vllm.v1.sample.logits_processor import (
22
+ AdapterLogitsProcessor,
23
+ RequestLogitsProcessor,
24
+ )
25
+
26
+ if TYPE_CHECKING:
27
+ from vllm.config import VllmConfig
28
+
29
# Hardcoded token IDs for the Solar tokenizer.
# NOTE(review): these IDs are vocabulary-specific — confirm they match the
# deployed tokenizer before reusing this processor with another model.
TOOL_CALL_END_TOKEN_ID = 32  # <|tool_call:end|> — closes one tool-call span
CALLS_TOKEN_ID = 25  # <|calls|> — stop token forced after a single tool call
32
+
33
+
34
class SingleToolCallEnforcer:
    """Request-level logits processor that enforces a single tool call.

    As soon as the <|tool_call:end|> token has just been generated, every
    vocabulary entry except <|calls|> (a stop token) is masked to -inf, so
    generation cannot open a second, parallel tool call.
    """

    def __init__(
        self,
        tool_call_end_token_id: int,
        calls_token_id: int,
    ):
        # Token that closes a tool-call span.
        self._end_id = tool_call_end_token_id
        # Stop token we force immediately after a tool call ends.
        self._calls_id = calls_token_id

    def __call__(
        self,
        output_token_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """Return `logits` unchanged, unless the previous token was
        <|tool_call:end|> — then return a copy where only <|calls|>
        keeps its original score and everything else is -inf."""
        last_token = output_token_ids[-1] if output_token_ids else None
        if last_token != self._end_id:
            return logits
        # Preserve the <|calls|> score so sampling still picks it naturally.
        masked = torch.full_like(logits, float("-inf"))
        masked[self._calls_id] = logits[self._calls_id]
        return masked
62
+
63
+
64
class ParallelToolCallLogitsProcessor(AdapterLogitsProcessor):
    """Logits processor that enforces a single tool call when
    ``parallel_tool_calls=False``.

    When parallel tool calls are disabled for a request, this processor
    attaches a :class:`SingleToolCallEnforcer`, which forces the token
    generated after <|tool_call:end|> to be <|calls|> (a stop token),
    preventing the model from emitting multiple tool calls.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        device: torch.device,
        is_pin_memory: bool,
    ):
        super().__init__(vllm_config, device, is_pin_memory)

    def is_argmax_invariant(self) -> bool:
        """Masking can change the argmax result, so report False."""
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Return a request-level logits processor if parallel_tool_calls=False.

        Args:
            params: Request sampling params.

        Returns:
            SingleToolCallEnforcer if ``parallel_tool_calls`` is explicitly
            False, otherwise None.
        """
        # Use getattr with a None default: stock SamplingParams does not
        # declare a `parallel_tool_calls` field, so a plain attribute access
        # would raise AttributeError for every request that doesn't carry
        # the flag. Requests without the flag are left untouched.
        # NOTE(review): assumes a patched/extended SamplingParams supplies
        # this attribute — confirm against the serving layer that sets it.
        if getattr(params, "parallel_tool_calls", None) is False:
            return SingleToolCallEnforcer(
                tool_call_end_token_id=TOOL_CALL_END_TOKEN_ID,
                calls_token_id=CALLS_TOKEN_ID,
            )

        return None
104
+