#!/usr/bin/env python3
"""
Custom vLLM tool parser for the R2E-gym XML tool-call format, i.e.
<function=NAME><parameter=KEY>VALUE</parameter>...</function>.
Same as frogboss_default_parser, but parses this XML format instead of JSON.
Usage:
vllm serve microsoft/FrogBoss-32B-2510 \
--tensor-parallel-size 4 \
--enable-auto-tool-choice \
--tool-parser-plugin ./Froggy-Training/src/vllm/frogboss_r2egym_parser.py \
--tool-call-parser froggy \
--enable-log-requests \
--enable-log-outputs \
--max-model-len 32768
"""
# import the required packages
import json
import logging
import re
import uuid
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers.abstract_tool_parser import (
    ExtractedToolCallInformation,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
try:
    from vllm.entrypoints.chat_utils import make_tool_call_id
except ImportError:
    # If the helper cannot be imported from vLLM, fall back to a local
    # equivalent so tool calls still receive unique ids.
    def make_tool_call_id():
        """Return a fresh random tool-call id of the form ``chatcmpl-tool-<hex>``."""
        random_suffix = uuid.uuid4().hex[:24]
        return "chatcmpl-tool-" + random_suffix
logger = logging.getLogger(__name__)

# Building blocks of the R2E-gym tool-call format:
#
#     <function=NAME>
#     <parameter=KEY>VALUE</parameter>
#     ...
#     </function>
_FUNCTION_OPEN = "<function="
_FUNCTION_CLOSE = "</function>"
# Every tag literal that must never leak into user-visible content.
_TAG_LITERALS = ("<function=", "</function>", "<parameter=", "</parameter>")
# One complete function call: group 1 is the tool name, group 2 its body.
_FUNCTION_CALL_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
# One parameter inside a function body: group 1 is the key, group 2 the value.
_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
# Parameter / closing tags found outside any function call; they are removed
# from content. `[^<>]` stops a match from swallowing a second tag.
_STRAY_TAG_RE = re.compile(r"<parameter=[^<>]*>|</parameter>|</function>")


def _parse_arguments(function_body: str) -> dict[str, str]:
    """Return the ``<parameter=KEY>VALUE</parameter>`` pairs of a function body.

    Keys and values are whitespace-stripped; when a key repeats, the last
    value wins.
    """
    return {
        key.strip(): value.strip()
        for key, value in _PARAMETER_RE.findall(function_body)
    }


def _arguments_json(function_body: str) -> str:
    """Serialize a function body's parameters as a compact JSON object string."""
    return json.dumps(
        _parse_arguments(function_body),
        ensure_ascii=False,
        separators=(",", ":"),
    )


def _split_partial_tag(text: str) -> tuple[str, str]:
    """Split *text* into ``(safe, pending)``.

    ``pending`` is a trailing fragment that may still grow into a tool-call
    tag (e.g. ``"<"``, ``"<func"``, ``"</param"``, ``"<parameter=cm"``) and
    must be held back until more text arrives; ``safe`` is everything before
    it. ``pending`` is empty when the tail cannot become a tag.
    """
    cut = text.rfind("<")
    if cut == -1:
        return text, ""
    tail = text[cut:]
    could_be_tag = ">" not in tail and (
        any(tag.startswith(tail) for tag in _TAG_LITERALS)
        or tail.startswith(("<function=", "<parameter="))
    )
    if could_be_tag:
        return text[:cut], tail
    return text, ""


def _visible_content(text: str) -> str:
    """Return the part of *text* that may be shown as plain assistant content.

    Complete ``<function=...>...</function>`` blocks are removed, everything
    from an unterminated ``<function=`` onward is withheld, stray parameter /
    closing tags are stripped, and a trailing fragment that could still turn
    into a tag is held back. Appending text to the input only ever appends to
    the result, so the difference between two calls is the newly visible
    content.
    """
    pieces = []
    pos = 0
    while True:
        start = text.find(_FUNCTION_OPEN, pos)
        if start == -1:
            break
        # Text in front of a function call is final: later text cannot change it.
        pieces.append(_STRAY_TAG_RE.sub("", text[pos:start]))
        end = text.find(_FUNCTION_CLOSE, start)
        if end == -1:
            # The call is still being generated; hide everything from here on.
            return "".join(pieces)
        pos = end + len(_FUNCTION_CLOSE)
    safe, _pending = _split_partial_tag(text[pos:])
    pieces.append(_STRAY_TAG_RE.sub("", safe))
    return "".join(pieces)


# The names passed to register_module are the values accepted by
# `--tool-call-parser`; this parser is selected with `--tool-call-parser froggy`.
@ToolParserManager.register_module(["froggy"])
class FrogyToolParser(ToolParser):
    """Tool parser for R2E-gym style XML tool calls.

    A call such as::

        <function=read_file>
        <parameter=path>src/main.py</parameter>
        </function>

    becomes an OpenAI tool call named ``read_file`` whose ``arguments`` is the
    compact JSON object ``{"path":"src/main.py"}`` (values whitespace-stripped).
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Return *request* unchanged; this parser needs no request tweaks."""
        return request

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Translate one streamed delta into content and/or tool-call deltas.

        Works only from ``previous_text`` / ``current_text`` (no parser state).
        Every function call that becomes complete with this delta is emitted
        whole, using its position among all calls in the response as
        ``index``. Content that became visible (see ``_visible_content``) is
        emitted as text, so tags are never leaked, even when split across
        tokens, and a withheld ``"<"`` is released once it proves not to be a
        tag.

        Returns:
            ``None`` for an empty delta. Otherwise returns a ``DeltaMessage``.
            When there is nothing to show yet, the message has
            ``content=""`` rather than being ``None``, which keeps the stream
            alive.
        """
        if not delta_text:
            return None

        # NOTE(review): some vLLM versions read self.prev_tool_call_arr /
        # self.streamed_args_for_tool when a stream finishes; this parser does
        # not populate them -- confirm that is safe for the vLLM in use.
        tool_calls: list[DeltaToolCall] = []
        # A call can only have completed if a closing tag ends inside the new
        # text, so the full-text scans are skipped for most deltas.
        delta_start = len(current_text) - len(delta_text)
        search_from = max(0, delta_start - len(_FUNCTION_CLOSE) + 1)
        if current_text.find(_FUNCTION_CLOSE, search_from) != -1:
            already_sent = len(_FUNCTION_CALL_RE.findall(previous_text))
            completed = _FUNCTION_CALL_RE.findall(current_text)
            for index, (function_name, function_body) in enumerate(
                completed[already_sent:], start=already_sent
            ):
                try:
                    tool_calls.append(
                        DeltaToolCall(
                            index=index,
                            id=make_tool_call_id(),
                            type="function",
                            function=DeltaFunctionCall(
                                name=function_name.strip(),
                                arguments=_arguments_json(function_body),
                            ),
                        )
                    )
                except Exception:
                    # Skip this call but keep streaming the rest.
                    logger.exception(
                        "Failed to build streamed tool call; body (first 200 chars): %s",
                        function_body[:200],
                    )

        previous_visible = _visible_content(previous_text)
        current_visible = _visible_content(current_text)
        if current_visible.startswith(previous_visible):
            new_content = current_visible[len(previous_visible):]
        else:
            # Defensive: visible text only grows; never retract streamed text.
            new_content = ""

        if tool_calls:
            return DeltaMessage(content=new_content or None, tool_calls=tool_calls)
        return DeltaMessage(content=new_content)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete R2E-gym function call from a finished response.

        When no call can be parsed, returns ``tools_called=False`` with the
        untouched output as content. Otherwise returns the parsed calls plus
        the whitespace-stripped text before the first ``<function=`` as
        content, or ``None`` if that text is empty.
        """
        tool_calls: list[ToolCall] = []
        for function_name, function_body in _FUNCTION_CALL_RE.findall(model_output):
            try:
                tool_calls.append(
                    ToolCall(
                        id=make_tool_call_id(),
                        type="function",
                        function=FunctionCall(
                            name=function_name.strip(),
                            arguments=_arguments_json(function_body),
                        ),
                    )
                )
            except Exception:
                # Skip the malformed call but keep any others.
                logger.exception(
                    "Failed to parse tool call; problematic XML (first 200 chars): %s",
                    function_body[:200],
                )

        if not tool_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # At least one call matched, so the opening tag is present.
        content = model_output[: model_output.find(_FUNCTION_OPEN)].strip()
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=content or None
        )
if __name__ == "__main__":
    # Executing this file directly defines (and, through the decorator,
    # registers) the parser class, then hands control to the vLLM CLI.
    from vllm.entrypoints.cli.main import main as vllm_cli_main

    vllm_cli_main()