chsingh commited on
Commit
5329af4
·
verified ·
1 Parent(s): a484535

frogboss parser

Browse files
Files changed (1) hide show
  1. frogboss_r2egym_parser.py +256 -0
frogboss_r2egym_parser.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Custom tool parser for vLLM with R2E-gym XML format.
4
+ Same as frogboss_default_parser but handles XML format instead of JSON.
5
+
6
+ Usage:
7
+ vllm serve microsoft/FrogBoss-2510 \
8
+ --tensor-parallel-size 4 \
9
+ --enable-auto-tool-choice \
10
+ --tool-parser-plugin frogboss_r2egym_parser.py \
11
+ --tool-call-parser froggy \
12
+ --enable-log-requests \
13
+ --enable-log-outputs \
14
+ --max-model-len 32768
15
+ """
16
+ import json
17
+ import re
18
+ import uuid
19
+
20
+ # import the required packages
21
+ from typing import Sequence, Union
22
+
23
+ from vllm.entrypoints.openai.protocol import (
24
+ ChatCompletionRequest,
25
+ DeltaFunctionCall,
26
+ DeltaMessage,
27
+ DeltaToolCall,
28
+ FunctionCall,
29
+ ToolCall,
30
+ )
31
+ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
32
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
33
+ ExtractedToolCallInformation,
34
+ )
35
+ from vllm.transformers_utils.tokenizer import AnyTokenizer
36
+
37
+ try:
38
+ from vllm.entrypoints.chat_utils import make_tool_call_id
39
+ except ImportError:
40
+ # Fallback if import fails
41
+ def make_tool_call_id():
42
+ return f"chatcmpl-tool-{uuid.uuid4().hex[:24]}"
43
+
44
+
45
+ # define a tool parser and register it to vllm
46
+ # the name list in register_module can be used
47
+ # in --tool-call-parser. you can define as many
48
+ # tool parsers as you want here.
49
+ @ToolParserManager.register_module(["froggy"])
50
+ class FrogyToolParser(ToolParser):
51
+ def __init__(self, tokenizer: AnyTokenizer):
52
+ super().__init__(tokenizer)
53
+
54
+ # adjust request. e.g.: set skip special tokens
55
+ # to False for tool call output.
56
+ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
57
+ return request
58
+
59
+ # implement the tool call parse for stream call
60
+ def extract_tool_calls_streaming(
61
+ self,
62
+ previous_text: str,
63
+ current_text: str,
64
+ delta_text: str,
65
+ previous_token_ids: Sequence[int],
66
+ current_token_ids: Sequence[int],
67
+ delta_token_ids: Sequence[int],
68
+ request: ChatCompletionRequest,
69
+ ) -> Union[DeltaMessage, None]:
70
+ # For streaming, we need to handle partial tool calls progressively
71
+ # Check if we're currently in a tool call (between XML function tags)
72
+
73
+ # If there's no delta text, return None
74
+ if not delta_text:
75
+ return None
76
+
77
+ # Check if we've started a function call in the current text
78
+ function_started = (
79
+ "<function=" in current_text and "<function=" not in previous_text
80
+ )
81
+ in_function_call = (
82
+ "<function=" in current_text and "</function>" not in current_text
83
+ )
84
+ function_completed = (
85
+ "</function>" in current_text and "</function>" not in previous_text
86
+ )
87
+
88
+ # If we just completed a function call, parse it
89
+ if function_completed:
90
+ # Extract the completed function call
91
+ pattern = r"<function=(\w+)>(.*?)</function>"
92
+ matches = re.findall(pattern, current_text, re.DOTALL)
93
+
94
+ if matches:
95
+ # Get the last completed function call
96
+ function_name, function_body = matches[-1]
97
+ try:
98
+ # Parse parameters from the function body
99
+ param_pattern = r"<parameter=(\w+)>(.*?)</parameter>"
100
+ param_matches = re.findall(param_pattern, function_body, re.DOTALL)
101
+
102
+ # Build arguments dict from parameters
103
+ arguments = {}
104
+ for param_name, param_value in param_matches:
105
+ # Strip whitespace from parameter values
106
+ param_value = param_value.strip()
107
+ arguments[param_name] = param_value
108
+
109
+ # Create tool call
110
+ tool_calls = []
111
+ tool_call = DeltaToolCall(
112
+ index=0,
113
+ id=make_tool_call_id(),
114
+ type="function",
115
+ function=DeltaFunctionCall(
116
+ name=function_name,
117
+ arguments=json.dumps(
118
+ arguments,
119
+ ensure_ascii=False,
120
+ separators=(",", ":"),
121
+ ),
122
+ ),
123
+ )
124
+ tool_calls.append(tool_call)
125
+
126
+ # Return delta with tool calls
127
+ return DeltaMessage(tool_calls=tool_calls)
128
+
129
+ except Exception as e:
130
+ # If parsing fails, just return the delta text
131
+ pass
132
+
133
+ # Similar to default parser, but for XML format
134
+ # If we just completed a function call, it's already handled above
135
+
136
+ # If we're currently inside a function call, suppress all content
137
+ # (we'll send it all as a tool call when </function> completes)
138
+ if in_function_call and not function_started:
139
+ return DeltaMessage(content="")
140
+
141
+ # For regular text (not in function call), handle partial tag detection
142
+ # The challenge: tags like "<function=read_file>" can leak through if split across tokens
143
+ # For example: delta1="<", delta2="function", delta3="=read_file>"
144
+ # We need to suppress ALL deltas while we're forming an opening tag
145
+
146
+ # First, check if we just added a lone "<" character
147
+ # This catches the very start of tag formation
148
+ if current_text.endswith("<") and not previous_text.endswith("<"):
149
+ # Just added a "<" - might be starting a tag, suppress it
150
+ return DeltaMessage(content="")
151
+
152
+ # Check if we're in the middle of forming an opening tag
153
+ # Look for unclosed "<function" or "<parameter" tags in current_text
154
+ last_function_open = current_text.rfind("<function")
155
+ last_function_close = current_text.rfind(">", last_function_open if last_function_open != -1 else 0)
156
+
157
+ # If we found "<function" and there's no ">" after it, we're forming the tag
158
+ if last_function_open != -1 and (last_function_close < last_function_open):
159
+ # We're in the middle of forming "<function=name>" - suppress
160
+ return DeltaMessage(content="")
161
+
162
+ # Same check for parameter tags
163
+ last_param_open = current_text.rfind("<parameter")
164
+ last_param_close = current_text.rfind(">", last_param_open if last_param_open != -1 else 0)
165
+
166
+ if last_param_open != -1 and (last_param_close < last_param_open):
167
+ # We're in the middle of forming "<parameter=name>" - suppress
168
+ return DeltaMessage(content="")
169
+
170
+ # Check for closing tags being formed
171
+ if current_text.endswith("</function") or current_text.endswith("</parameter"):
172
+ # Partial closing tag - suppress until complete
173
+ return DeltaMessage(content="")
174
+
175
+ # For regular text, filter out complete tags
176
+ filtered_delta = delta_text
177
+
178
+ # Remove complete tags if they appear in this delta
179
+ filtered_delta = filtered_delta.replace("<function=", "").replace(
180
+ "</function>", ""
181
+ )
182
+ # Also filter parameter tags
183
+ filtered_delta = re.sub(r"<parameter=\w+>", "", filtered_delta)
184
+ filtered_delta = filtered_delta.replace("</parameter>", "")
185
+
186
+ if filtered_delta:
187
+ return DeltaMessage(content=filtered_delta)
188
+
189
+ # Return empty content instead of None to keep the stream alive
190
+ return DeltaMessage(content="")
191
+
192
+ # implement the tool parse for non-stream call
193
+ def extract_tool_calls(
194
+ self,
195
+ model_output: str,
196
+ request: ChatCompletionRequest,
197
+ ) -> ExtractedToolCallInformation:
198
+ # Parse <function=...>...</function> tags (R2E-gym XML format)
199
+ pattern = r"<function=(\w+)>(.*?)</function>"
200
+ matches = re.findall(pattern, model_output, re.DOTALL)
201
+
202
+ tool_calls = []
203
+
204
+ for i, (function_name, function_body) in enumerate(matches):
205
+ try:
206
+ # Parse parameters from the function body
207
+ param_pattern = r"<parameter=(\w+)>(.*?)</parameter>"
208
+ param_matches = re.findall(param_pattern, function_body, re.DOTALL)
209
+
210
+ # Build arguments dict from parameters
211
+ arguments = {}
212
+ for param_name, param_value in param_matches:
213
+ # Strip whitespace from parameter values
214
+ param_value = param_value.strip()
215
+ arguments[param_name] = param_value
216
+
217
+ # Create tool call
218
+ tool_call = ToolCall(
219
+ id=make_tool_call_id(),
220
+ type="function",
221
+ function=FunctionCall(
222
+ name=function_name,
223
+ arguments=json.dumps(
224
+ arguments,
225
+ ensure_ascii=False,
226
+ separators=(",", ":"),
227
+ ),
228
+ ),
229
+ )
230
+ tool_calls.append(tool_call)
231
+
232
+ except Exception as e:
233
+ # If parsing fails, log the error with the problematic XML
234
+ print(f"Failed to parse tool call: {e}")
235
+ print(f"Problematic XML (first 200 chars): {function_body[:200]}")
236
+ continue
237
+
238
+ # Extract text content (everything before first <function=)
239
+ content = re.split(r"<function=", model_output)[0].strip()
240
+
241
+ # Important: When there are tool calls, always provide a content value (even if empty string)
242
+ # to prevent "no response was returned" errors in clients like Copilot UI.
243
+ # Only set content to None when there are no tool calls AND no content.
244
+ if not content:
245
+ content = "" if len(tool_calls) > 0 else None
246
+
247
+ return ExtractedToolCallInformation(
248
+ tools_called=len(tool_calls) > 0, tool_calls=tool_calls, content=content
249
+ )
250
+
251
+
252
+ if __name__ == "__main__":
253
+ # When run as a script, start vLLM with this parser registered
254
+ from vllm.entrypoints.cli.main import main
255
+
256
+ main()