File size: 10,377 Bytes
5329af4
 
 
 
 
 
cc93095
5329af4
 
cc93095
5329af4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc93095
 
5329af4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#!/usr/bin/env python3
"""
Custom tool parser for vLLM with R2E-gym XML format.
Same as frogboss_default_parser but handles XML format instead of JSON.

Usage:
    vllm serve microsoft/FrogBoss-32B-2510 \
        --tensor-parallel-size 4 \
        --enable-auto-tool-choice \
        --tool-parser-plugin ./Froggy-Training/src/vllm/frogboss_r2egym_parser.py \
        --tool-call-parser froggy \
        --enable-log-requests \
        --enable-log-outputs \
        --max-model-len 32768
"""
import json
import logging
import re
import uuid

# import the required packages
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers.abstract_tool_parser import (
    ExtractedToolCallInformation,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer

try:
    from vllm.entrypoints.chat_utils import make_tool_call_id
except ImportError:
    # The vLLM helper could not be imported, so provide a local stand-in
    # that is called the same way this module calls it (no arguments).

    def make_tool_call_id() -> str:
        """Return a random tool-call ID: ``chatcmpl-tool-`` + 24 hex chars."""
        random_hex = uuid.uuid4().hex
        return "chatcmpl-tool-" + random_hex[:24]


# define a tool parser and register it to vllm
# the name list in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["froggy"])
class FrogyToolParser(ToolParser):
    """Tool parser for tool calls written in the R2E-gym XML format.

    A tool call in the model output looks like::

        <function=NAME>
        <parameter=KEY>VALUE</parameter>
        ...
        </function>

    Every parameter value is whitespace-stripped and kept as a string; the
    resulting mapping is serialized as a compact JSON object and used as the
    tool call's ``arguments``. The parser is registered under the name
    ``froggy`` (usable with ``--tool-call-parser``).

    The parser keeps no state between calls: the streaming method derives
    everything from the ``previous_text`` / ``current_text`` pair it is given.
    """

    # Literal markers that delimit one tool call in the model output.
    _FUNCTION_OPEN = "<function="
    _FUNCTION_CLOSE = "</function>"

    # A complete tool call: (function name, raw body between the tags).
    _FUNCTION_RE = re.compile(r"<function=(\w+)>(.*?)</function>", re.DOTALL)
    # One parameter inside a tool-call body: (name, raw value).
    _PARAMETER_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)
    # Complete tags that may appear outside a tool call and must never be
    # shown to the client as content.
    _STRAY_TAG_RE = re.compile(r"<(?:/function|/parameter|parameter=\w+)>")
    # An opening parameter tag whose name is still being generated.
    _PARTIAL_PARAMETER_RE = re.compile(r"<parameter=\w*")
    # Fixed tag spellings; any prefix of one of these at the very end of the
    # text may still grow into a real tag.
    _TAG_PREFIXES = ("<function=", "</function>", "</parameter>", "<parameter=")

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialise the base ``ToolParser`` with the given tokenizer."""
        super().__init__(tokenizer)

    # adjust request. e.g.: set skip special tokens
    # to False for tool call output.
    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Return ``request`` unchanged; this parser needs no adjustments."""
        return request

    @classmethod
    def _encode_arguments(cls, function_body: str) -> str:
        """Turn the body of one ``<function=...>`` block into a JSON string.

        Parameter values are stripped of surrounding whitespace and kept as
        strings. If a parameter name occurs more than once the last
        occurrence wins.
        """
        arguments = {
            name: value.strip()
            for name, value in cls._PARAMETER_RE.findall(function_body)
        }
        return json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))

    @classmethod
    def _is_partial_tag(cls, fragment: str) -> bool:
        """Return True if ``fragment`` could still grow into a known tag."""
        if any(tag.startswith(fragment) for tag in cls._TAG_PREFIXES):
            return True
        return cls._PARTIAL_PARAMETER_RE.fullmatch(fragment) is not None

    @classmethod
    def _visible_content(cls, text: str) -> str:
        """Return the part of ``text`` that may be shown as plain content.

        Hidden are: every ``<function=`` ... ``</function>`` block, everything
        from an unterminated ``<function=`` to the end, complete stray
        parameter/closing tags outside a block, and a trailing fragment that
        might be the beginning of a tag. A held-back fragment reappears in a
        later call once it turns out not to be a tag (e.g. the ``<`` in
        ``a < b``), so appending text only ever extends the returned string.
        """
        pieces = []
        position = 0
        while True:
            open_at = text.find(cls._FUNCTION_OPEN, position)
            if open_at == -1:
                break
            # Stray tags are removed per segment so that text on either side
            # of a tool call can never be glued together into a fake tag.
            pieces.append(cls._STRAY_TAG_RE.sub("", text[position:open_at]))
            close_at = text.find(cls._FUNCTION_CLOSE, open_at)
            if close_at == -1:
                # Still inside a tool call: nothing after the opening tag is
                # content.
                return "".join(pieces)
            position = close_at + len(cls._FUNCTION_CLOSE)

        tail = cls._STRAY_TAG_RE.sub("", text[position:])
        last_angle = tail.rfind("<")
        if last_angle != -1 and cls._is_partial_tag(tail[last_angle:]):
            # Hold back a possible tag start until more text arrives.
            tail = tail[:last_angle]
        pieces.append(tail)
        return "".join(pieces)

    # implement the tool call parse for stream call
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the streaming delta for the newest chunk of model output.

        Args:
            previous_text: Output generated before this chunk.
            current_text: Output generated so far, including this chunk.
            delta_text: The newly generated chunk.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: Unused.

        Returns:
            ``None`` when ``delta_text`` is empty. Otherwise a ``DeltaMessage``
            carrying the tool calls whose ``</function>`` was completed by
            this chunk (each sent whole, with its position among the completed
            calls as ``index``) and/or the newly visible plain text. The
            content is an empty string while output is being suppressed,
            which keeps the stream alive.
        """
        # If there's no delta text, return None
        if not delta_text:
            return None

        tool_calls = []
        close_tag = self._FUNCTION_CLOSE
        # Cheap substring count first; the regexes only run on the chunk
        # that actually completes a closing tag.
        if current_text.count(close_tag) > previous_text.count(close_tag):
            already_sent = len(self._FUNCTION_RE.findall(previous_text))
            completed = self._FUNCTION_RE.findall(current_text)
            # Everything past the calls already present in previous_text is
            # new. The running position is used as the index so that several
            # calls in one response stay distinct for the client.
            for index in range(already_sent, len(completed)):
                function_name, function_body = completed[index]
                try:
                    tool_calls.append(
                        DeltaToolCall(
                            index=index,
                            id=make_tool_call_id(),
                            type="function",
                            function=DeltaFunctionCall(
                                name=function_name,
                                arguments=self._encode_arguments(function_body),
                            ),
                        )
                    )
                except Exception:
                    # Best effort: skip the broken call but leave a trace
                    # instead of failing silently.
                    logging.getLogger(__name__).exception(
                        "Failed to build streamed tool call %r", function_name
                    )

        # New plain text is whatever became visible with this chunk. Tool-call
        # markup and half-formed tags are never part of the visible text, so
        # tag fragments split across tokens cannot leak.
        visible_before = self._visible_content(previous_text)
        visible_now = self._visible_content(current_text)
        content_delta = visible_now[len(visible_before):]

        if tool_calls:
            if content_delta:
                return DeltaMessage(content=content_delta, tool_calls=tool_calls)
            return DeltaMessage(tool_calls=tool_calls)

        # Return empty content instead of None to keep the stream alive
        return DeltaMessage(content=content_delta)

    # implement the tool parse for non-stream call
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete model response.

        Args:
            model_output: The full generated text.
            request: Unused.

        Returns:
            ``ExtractedToolCallInformation`` with one ``ToolCall`` per
            complete ``<function=...>`` block. ``content`` is the stripped
            text before the first ``<function=``; when that text is empty it
            is ``""`` if at least one tool call was found and ``None``
            otherwise.
        """
        tool_calls = []
        for function_name, function_body in self._FUNCTION_RE.findall(model_output):
            try:
                tool_calls.append(
                    ToolCall(
                        id=make_tool_call_id(),
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=self._encode_arguments(function_body),
                        ),
                    )
                )
            except Exception:
                # Skip the broken call, but record which XML caused it.
                logging.getLogger(__name__).exception(
                    "Failed to parse tool call. "
                    "Problematic XML (first 200 chars): %s",
                    function_body[:200],
                )

        # Extract text content (everything before first <function=)
        content = model_output.split(self._FUNCTION_OPEN, 1)[0].strip()

        # Important: When there are tool calls, always provide a content value (even if empty string)
        # to prevent "no response was returned" errors in clients like Copilot UI.
        # Only set content to None when there are no tool calls AND no content.
        if not content:
            content = "" if tool_calls else None

        return ExtractedToolCallInformation(
            tools_called=bool(tool_calls), tool_calls=tool_calls, content=content
        )


if __name__ == "__main__":
    # Script entry point. By the time this guard is reached the module body
    # has already run, so the @ToolParserManager.register_module decorator
    # above has registered the "froggy" parser; control is then handed to
    # vLLM's CLI entry function.
    # NOTE(review): presumably main() reads the command line the same way the
    # `vllm` executable does (see the usage example in the module docstring)
    # - confirm against the installed vLLM version.
    # The import is kept local so it only happens when run as a script.
    from vllm.entrypoints.cli.main import main

    main()