frogboss parser
Browse files- frogboss_r2egym_parser.py +256 -0
frogboss_r2egym_parser.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Custom tool parser for vLLM with R2E-gym XML format.
|
| 4 |
+
Same as frogboss_default_parser but handles XML format instead of JSON.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
vllm serve microsoft/FrogBoss-2510 \
|
| 8 |
+
--tensor-parallel-size 4 \
|
| 9 |
+
--enable-auto-tool-choice \
|
| 10 |
+
--tool-parser-plugin frogboss_r2egym_parser.py \
|
| 11 |
+
--tool-call-parser froggy \
|
| 12 |
+
--enable-log-requests \
|
| 13 |
+
--enable-log-outputs \
|
| 14 |
+
--max-model-len 32768
|
| 15 |
+
"""
|
| 16 |
+
import json
|
| 17 |
+
import re
|
| 18 |
+
import uuid
|
| 19 |
+
|
| 20 |
+
# import the required packages
|
| 21 |
+
from typing import Sequence, Union
|
| 22 |
+
|
| 23 |
+
from vllm.entrypoints.openai.protocol import (
|
| 24 |
+
ChatCompletionRequest,
|
| 25 |
+
DeltaFunctionCall,
|
| 26 |
+
DeltaMessage,
|
| 27 |
+
DeltaToolCall,
|
| 28 |
+
FunctionCall,
|
| 29 |
+
ToolCall,
|
| 30 |
+
)
|
| 31 |
+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
| 32 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 33 |
+
ExtractedToolCallInformation,
|
| 34 |
+
)
|
| 35 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
from vllm.entrypoints.chat_utils import make_tool_call_id
|
| 39 |
+
except ImportError:
|
| 40 |
+
# Fallback if import fails
|
| 41 |
+
def make_tool_call_id():
|
| 42 |
+
return f"chatcmpl-tool-{uuid.uuid4().hex[:24]}"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# define a tool parser and register it to vllm
|
| 46 |
+
# the name list in register_module can be used
|
| 47 |
+
# in --tool-call-parser. you can define as many
|
| 48 |
+
# tool parsers as you want here.
|
| 49 |
+
@ToolParserManager.register_module(["froggy"])
|
| 50 |
+
class FrogyToolParser(ToolParser):
|
| 51 |
+
def __init__(self, tokenizer: AnyTokenizer):
|
| 52 |
+
super().__init__(tokenizer)
|
| 53 |
+
|
| 54 |
+
# adjust request. e.g.: set skip special tokens
|
| 55 |
+
# to False for tool call output.
|
| 56 |
+
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
| 57 |
+
return request
|
| 58 |
+
|
| 59 |
+
# implement the tool call parse for stream call
|
| 60 |
+
def extract_tool_calls_streaming(
|
| 61 |
+
self,
|
| 62 |
+
previous_text: str,
|
| 63 |
+
current_text: str,
|
| 64 |
+
delta_text: str,
|
| 65 |
+
previous_token_ids: Sequence[int],
|
| 66 |
+
current_token_ids: Sequence[int],
|
| 67 |
+
delta_token_ids: Sequence[int],
|
| 68 |
+
request: ChatCompletionRequest,
|
| 69 |
+
) -> Union[DeltaMessage, None]:
|
| 70 |
+
# For streaming, we need to handle partial tool calls progressively
|
| 71 |
+
# Check if we're currently in a tool call (between XML function tags)
|
| 72 |
+
|
| 73 |
+
# If there's no delta text, return None
|
| 74 |
+
if not delta_text:
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
# Check if we've started a function call in the current text
|
| 78 |
+
function_started = (
|
| 79 |
+
"<function=" in current_text and "<function=" not in previous_text
|
| 80 |
+
)
|
| 81 |
+
in_function_call = (
|
| 82 |
+
"<function=" in current_text and "</function>" not in current_text
|
| 83 |
+
)
|
| 84 |
+
function_completed = (
|
| 85 |
+
"</function>" in current_text and "</function>" not in previous_text
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# If we just completed a function call, parse it
|
| 89 |
+
if function_completed:
|
| 90 |
+
# Extract the completed function call
|
| 91 |
+
pattern = r"<function=(\w+)>(.*?)</function>"
|
| 92 |
+
matches = re.findall(pattern, current_text, re.DOTALL)
|
| 93 |
+
|
| 94 |
+
if matches:
|
| 95 |
+
# Get the last completed function call
|
| 96 |
+
function_name, function_body = matches[-1]
|
| 97 |
+
try:
|
| 98 |
+
# Parse parameters from the function body
|
| 99 |
+
param_pattern = r"<parameter=(\w+)>(.*?)</parameter>"
|
| 100 |
+
param_matches = re.findall(param_pattern, function_body, re.DOTALL)
|
| 101 |
+
|
| 102 |
+
# Build arguments dict from parameters
|
| 103 |
+
arguments = {}
|
| 104 |
+
for param_name, param_value in param_matches:
|
| 105 |
+
# Strip whitespace from parameter values
|
| 106 |
+
param_value = param_value.strip()
|
| 107 |
+
arguments[param_name] = param_value
|
| 108 |
+
|
| 109 |
+
# Create tool call
|
| 110 |
+
tool_calls = []
|
| 111 |
+
tool_call = DeltaToolCall(
|
| 112 |
+
index=0,
|
| 113 |
+
id=make_tool_call_id(),
|
| 114 |
+
type="function",
|
| 115 |
+
function=DeltaFunctionCall(
|
| 116 |
+
name=function_name,
|
| 117 |
+
arguments=json.dumps(
|
| 118 |
+
arguments,
|
| 119 |
+
ensure_ascii=False,
|
| 120 |
+
separators=(",", ":"),
|
| 121 |
+
),
|
| 122 |
+
),
|
| 123 |
+
)
|
| 124 |
+
tool_calls.append(tool_call)
|
| 125 |
+
|
| 126 |
+
# Return delta with tool calls
|
| 127 |
+
return DeltaMessage(tool_calls=tool_calls)
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
# If parsing fails, just return the delta text
|
| 131 |
+
pass
|
| 132 |
+
|
| 133 |
+
# Similar to default parser, but for XML format
|
| 134 |
+
# If we just completed a function call, it's already handled above
|
| 135 |
+
|
| 136 |
+
# If we're currently inside a function call, suppress all content
|
| 137 |
+
# (we'll send it all as a tool call when </function> completes)
|
| 138 |
+
if in_function_call and not function_started:
|
| 139 |
+
return DeltaMessage(content="")
|
| 140 |
+
|
| 141 |
+
# For regular text (not in function call), handle partial tag detection
|
| 142 |
+
# The challenge: tags like "<function=read_file>" can leak through if split across tokens
|
| 143 |
+
# For example: delta1="<", delta2="function", delta3="=read_file>"
|
| 144 |
+
# We need to suppress ALL deltas while we're forming an opening tag
|
| 145 |
+
|
| 146 |
+
# First, check if we just added a lone "<" character
|
| 147 |
+
# This catches the very start of tag formation
|
| 148 |
+
if current_text.endswith("<") and not previous_text.endswith("<"):
|
| 149 |
+
# Just added a "<" - might be starting a tag, suppress it
|
| 150 |
+
return DeltaMessage(content="")
|
| 151 |
+
|
| 152 |
+
# Check if we're in the middle of forming an opening tag
|
| 153 |
+
# Look for unclosed "<function" or "<parameter" tags in current_text
|
| 154 |
+
last_function_open = current_text.rfind("<function")
|
| 155 |
+
last_function_close = current_text.rfind(">", last_function_open if last_function_open != -1 else 0)
|
| 156 |
+
|
| 157 |
+
# If we found "<function" and there's no ">" after it, we're forming the tag
|
| 158 |
+
if last_function_open != -1 and (last_function_close < last_function_open):
|
| 159 |
+
# We're in the middle of forming "<function=name>" - suppress
|
| 160 |
+
return DeltaMessage(content="")
|
| 161 |
+
|
| 162 |
+
# Same check for parameter tags
|
| 163 |
+
last_param_open = current_text.rfind("<parameter")
|
| 164 |
+
last_param_close = current_text.rfind(">", last_param_open if last_param_open != -1 else 0)
|
| 165 |
+
|
| 166 |
+
if last_param_open != -1 and (last_param_close < last_param_open):
|
| 167 |
+
# We're in the middle of forming "<parameter=name>" - suppress
|
| 168 |
+
return DeltaMessage(content="")
|
| 169 |
+
|
| 170 |
+
# Check for closing tags being formed
|
| 171 |
+
if current_text.endswith("</function") or current_text.endswith("</parameter"):
|
| 172 |
+
# Partial closing tag - suppress until complete
|
| 173 |
+
return DeltaMessage(content="")
|
| 174 |
+
|
| 175 |
+
# For regular text, filter out complete tags
|
| 176 |
+
filtered_delta = delta_text
|
| 177 |
+
|
| 178 |
+
# Remove complete tags if they appear in this delta
|
| 179 |
+
filtered_delta = filtered_delta.replace("<function=", "").replace(
|
| 180 |
+
"</function>", ""
|
| 181 |
+
)
|
| 182 |
+
# Also filter parameter tags
|
| 183 |
+
filtered_delta = re.sub(r"<parameter=\w+>", "", filtered_delta)
|
| 184 |
+
filtered_delta = filtered_delta.replace("</parameter>", "")
|
| 185 |
+
|
| 186 |
+
if filtered_delta:
|
| 187 |
+
return DeltaMessage(content=filtered_delta)
|
| 188 |
+
|
| 189 |
+
# Return empty content instead of None to keep the stream alive
|
| 190 |
+
return DeltaMessage(content="")
|
| 191 |
+
|
| 192 |
+
# implement the tool parse for non-stream call
|
| 193 |
+
def extract_tool_calls(
|
| 194 |
+
self,
|
| 195 |
+
model_output: str,
|
| 196 |
+
request: ChatCompletionRequest,
|
| 197 |
+
) -> ExtractedToolCallInformation:
|
| 198 |
+
# Parse <function=...>...</function> tags (R2E-gym XML format)
|
| 199 |
+
pattern = r"<function=(\w+)>(.*?)</function>"
|
| 200 |
+
matches = re.findall(pattern, model_output, re.DOTALL)
|
| 201 |
+
|
| 202 |
+
tool_calls = []
|
| 203 |
+
|
| 204 |
+
for i, (function_name, function_body) in enumerate(matches):
|
| 205 |
+
try:
|
| 206 |
+
# Parse parameters from the function body
|
| 207 |
+
param_pattern = r"<parameter=(\w+)>(.*?)</parameter>"
|
| 208 |
+
param_matches = re.findall(param_pattern, function_body, re.DOTALL)
|
| 209 |
+
|
| 210 |
+
# Build arguments dict from parameters
|
| 211 |
+
arguments = {}
|
| 212 |
+
for param_name, param_value in param_matches:
|
| 213 |
+
# Strip whitespace from parameter values
|
| 214 |
+
param_value = param_value.strip()
|
| 215 |
+
arguments[param_name] = param_value
|
| 216 |
+
|
| 217 |
+
# Create tool call
|
| 218 |
+
tool_call = ToolCall(
|
| 219 |
+
id=make_tool_call_id(),
|
| 220 |
+
type="function",
|
| 221 |
+
function=FunctionCall(
|
| 222 |
+
name=function_name,
|
| 223 |
+
arguments=json.dumps(
|
| 224 |
+
arguments,
|
| 225 |
+
ensure_ascii=False,
|
| 226 |
+
separators=(",", ":"),
|
| 227 |
+
),
|
| 228 |
+
),
|
| 229 |
+
)
|
| 230 |
+
tool_calls.append(tool_call)
|
| 231 |
+
|
| 232 |
+
except Exception as e:
|
| 233 |
+
# If parsing fails, log the error with the problematic XML
|
| 234 |
+
print(f"Failed to parse tool call: {e}")
|
| 235 |
+
print(f"Problematic XML (first 200 chars): {function_body[:200]}")
|
| 236 |
+
continue
|
| 237 |
+
|
| 238 |
+
# Extract text content (everything before first <function=)
|
| 239 |
+
content = re.split(r"<function=", model_output)[0].strip()
|
| 240 |
+
|
| 241 |
+
# Important: When there are tool calls, always provide a content value (even if empty string)
|
| 242 |
+
# to prevent "no response was returned" errors in clients like Copilot UI.
|
| 243 |
+
# Only set content to None when there are no tool calls AND no content.
|
| 244 |
+
if not content:
|
| 245 |
+
content = "" if len(tool_calls) > 0 else None
|
| 246 |
+
|
| 247 |
+
return ExtractedToolCallInformation(
|
| 248 |
+
tools_called=len(tool_calls) > 0, tool_calls=tool_calls, content=content
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if __name__ == "__main__":
|
| 253 |
+
# When run as a script, start vLLM with this parser registered
|
| 254 |
+
from vllm.entrypoints.cli.main import main
|
| 255 |
+
|
| 256 |
+
main()
|