#!/usr/bin/env python3
"""
Custom tool parser for vLLM with R2E-gym XML format.
Same as frogboss_default_parser but handles XML format instead of JSON.
Usage:
vllm serve microsoft/FrogBoss-32B-2510 \
--tensor-parallel-size 4 \
--enable-auto-tool-choice \
--tool-parser-plugin ./Froggy-Training/src/vllm/frogboss_r2egym_parser.py \
--tool-call-parser froggy \
--enable-log-requests \
--enable-log-outputs \
--max-model-len 32768
"""
import json
import logging
import re
import uuid

# import the required packages
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers.abstract_tool_parser import (
    ExtractedToolCallInformation,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
try:
    from vllm.entrypoints.chat_utils import make_tool_call_id
except ImportError:
    # The helper could not be imported from vLLM; provide a local stand-in.

    def make_tool_call_id():
        """Return a tool-call id: "chatcmpl-tool-" plus 24 random hex chars."""
        random_suffix = uuid.uuid4().hex[:24]
        return "chatcmpl-tool-" + random_suffix
logger = logging.getLogger(__name__)


# Define a tool parser and register it with vLLM. Every name in the list
# passed to register_module can be selected with --tool-call-parser; a plugin
# file may define as many tool parsers as needed.
@ToolParserManager.register_module(["froggy"])
class FrogyToolParser(ToolParser):
    """Tool parser for the R2E-gym XML tool-call format.

    A tool call in the model output looks like::

        <function=NAME>
        <parameter=KEY>VALUE</parameter>
        </function>

    Every parameter value is kept as a whitespace-stripped string, and the
    arguments of a call are serialized as one compact JSON object. Only the
    text before the first ``<function=`` is treated as message content, in
    both the streaming and the non-streaming code path.
    """

    _FUNCTION_OPEN = "<function="
    _FUNCTION_CLOSE = "</function>"
    # DOTALL: function bodies and parameter values may span several lines.
    _FUNCTION_RE = re.compile(r"<function=(\w+)>(.*?)</function>", re.DOTALL)
    _PARAMETER_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Return ``request`` unchanged.

        This is the hook where a parser could tweak the request (e.g. set
        skip-special-tokens to False for tool call output); this parser makes
        no adjustment.
        """
        return request

    @classmethod
    def _encode_arguments(cls, function_body: str) -> str:
        """Serialize the ``<parameter=...>`` entries of a body as compact JSON."""
        arguments = {
            name: value.strip()
            for name, value in cls._PARAMETER_RE.findall(function_body)
        }
        return json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))

    @classmethod
    def _streamable_content(cls, text: str) -> str:
        """Return the leading part of ``text`` that is safe to emit as content.

        Everything from the first ``<function=`` onwards is withheld. When no
        complete opening marker is present yet, a trailing fragment that could
        still grow into one (``<``, ``<fu``, ``<function`` ...) is withheld as
        well, so a tag split across deltas never leaks to the client. If the
        fragment turns out to be ordinary text, it is released as soon as the
        following characters rule the marker out.
        """
        marker = cls._FUNCTION_OPEN
        start = text.find(marker)
        if start != -1:
            return text[:start]
        # Longest candidate first, so the whole ambiguous tail is held back.
        for size in range(min(len(marker) - 1, len(text)), 0, -1):
            if text.endswith(marker[:size]):
                return text[: len(text) - size]
        return text

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streaming step into a delta message.

        A tool call is emitted as a single delta once its closing
        ``</function>`` has arrived; until then its text is suppressed.

        Args:
            previous_text: Full output text before this step.
            current_text: Full output text including this step.
            delta_text: Text added in this step.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: Unused.

        Returns:
            ``None`` when ``delta_text`` is empty. Otherwise a ``DeltaMessage``
            carrying the tool calls completed in this step and/or the newly
            visible content. The content is ``""`` when there is nothing to
            show, so the stream is kept alive instead of returning ``None``.
        """
        if not delta_text:
            return None

        tool_calls = []
        # Cheap substring guard: without a closing tag nothing can be complete.
        if self._FUNCTION_CLOSE in current_text:
            finished = self._FUNCTION_RE.findall(current_text)
            # Calls already complete in previous_text were emitted in earlier
            # steps. Comparing counts (rather than testing whether the closing
            # tag is new to the text) also detects the second and later calls.
            already_sent = len(self._FUNCTION_RE.findall(previous_text))
            for index in range(already_sent, len(finished)):
                function_name, function_body = finished[index]
                try:
                    tool_calls.append(
                        DeltaToolCall(
                            # Each call needs its own index so clients do not
                            # merge separate calls into one.
                            index=index,
                            id=make_tool_call_id(),
                            type="function",
                            function=DeltaFunctionCall(
                                name=function_name,
                                arguments=self._encode_arguments(function_body),
                            ),
                        )
                    )
                except Exception:
                    # Best effort: a call that cannot be built is skipped, but
                    # the failure is recorded instead of silently dropped.
                    logger.exception(
                        "Failed to build streamed tool call; "
                        "problematic XML (first 200 chars): %s",
                        function_body[:200],
                    )

        # Emit exactly the content that became visible in this step. This
        # releases text that was held back while it could still have been the
        # start of a "<function=" tag.
        shown_before = self._streamable_content(previous_text)
        shown_now = self._streamable_content(current_text)
        new_content = shown_now[len(shown_before):]

        if tool_calls:
            if new_content:
                return DeltaMessage(content=new_content, tool_calls=tool_calls)
            return DeltaMessage(tool_calls=tool_calls)
        return DeltaMessage(content=new_content)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished model output.

        Args:
            model_output: Complete text generated by the model.
            request: Unused.

        Returns:
            The parsed tool calls plus the stripped text preceding the first
            ``<function=``. With tool calls the content is always a string
            (possibly ``""``), which prevents "no response was returned"
            errors in clients like Copilot UI. Content is ``None`` only when
            there are neither tool calls nor text.
        """
        tool_calls = []
        for function_name, function_body in self._FUNCTION_RE.findall(model_output):
            try:
                tool_calls.append(
                    ToolCall(
                        id=make_tool_call_id(),
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=self._encode_arguments(function_body),
                        ),
                    )
                )
            except Exception:
                # Best effort: skip the broken call and keep the others.
                logger.exception(
                    "Failed to parse tool call; "
                    "problematic XML (first 200 chars): %s",
                    function_body[:200],
                )

        # Text content is everything before the first "<function=".
        content = model_output.split(self._FUNCTION_OPEN, 1)[0].strip()
        if not content:
            content = "" if tool_calls else None

        return ExtractedToolCallInformation(
            tools_called=bool(tool_calls),
            tool_calls=tool_calls,
            content=content,
        )
if __name__ == "__main__":
    # Run as a script: hand control to the vLLM command-line entry point.
    # Executing this module has already applied the register_module decorator.
    from vllm.entrypoints.cli.main import main as run_vllm_cli

    run_vllm_cli()
|