# FrogBoss-32B-2510 / frogboss_r2egym_parser.py
# Last change by chsingh: "Update frogboss_r2egym_parser.py" (commit cc93095, verified)
#!/usr/bin/env python3
"""
Custom tool parser for vLLM with R2E-gym XML format.
Same as frogboss_default_parser, but parses the R2E-gym XML tool-call format
(<function=NAME><parameter=ARG>VALUE</parameter></function>) instead of JSON.
Usage:
vllm serve microsoft/FrogBoss-32B-2510 \
--tensor-parallel-size 4 \
--enable-auto-tool-choice \
--tool-parser-plugin ./Froggy-Training/src/vllm/frogboss_r2egym_parser.py \
--tool-call-parser froggy \
--enable-log-requests \
--enable-log-outputs \
--max-model-len 32768
"""
# import the required packages
import json
import logging
import re
import uuid
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers.abstract_tool_parser import (
    ExtractedToolCallInformation,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
try:
    from vllm.entrypoints.chat_utils import make_tool_call_id
except ImportError:
    # This vLLM build does not expose the helper, so define a local stand-in
    # producing ids of the form "chatcmpl-tool-<24 lowercase hex chars>".
    def make_tool_call_id():
        """Return a fresh, random tool-call id (fallback implementation)."""
        random_hex = uuid.uuid4().hex
        return "chatcmpl-tool-" + random_hex[:24]
# define a tool parser and register it to vllm
# the name list in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["froggy"])
class FrogyToolParser(ToolParser):
    """vLLM tool parser for the R2E-gym XML tool-call format.

    A tool call in the model output looks like::

        <function=NAME>
        <parameter=ARG>VALUE</parameter>
        ...
        </function>

    Each call is converted into an OpenAI-style tool call whose ``arguments``
    is a compact JSON object mapping every parameter name to its value with
    surrounding whitespace stripped (all values stay strings). The class is
    registered under the name ``froggy`` for ``--tool-call-parser``.
    """

    # Tool/parameter names: word characters plus "-". OpenAI-style tool names
    # may contain hyphens; with a bare \w+ such calls were silently dropped.
    _FUNCTION_RE = re.compile(r"<function=([\w-]+)>(.*?)</function>", re.DOTALL)
    _PARAMETER_RE = re.compile(r"<parameter=([\w-]+)>(.*?)</parameter>", re.DOTALL)
    # Complete markup that must never be shown as assistant text when it
    # appears outside a <function=...> block.
    _STRAY_TAG_RE = re.compile(r"</function>|</parameter>|<parameter=[\w-]+>")
    # An opening parameter tag whose name is still being generated.
    _PARTIAL_PARAMETER_RE = re.compile(r"<parameter=[\w-]*\Z")
    _FUNCTION_OPEN = "<function="
    _FUNCTION_CLOSE = "</function>"
    # Literal tags; a trailing fragment that is a prefix of one is withheld.
    _TAG_LITERALS = ("<function=", "</function>", "<parameter=", "</parameter>")

    _logger = logging.getLogger(__name__)

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Return *request* unchanged.

        Hook for request tweaks (e.g. setting skip_special_tokens to False
        for tool-call output); none are needed for this format.
        """
        return request

    @classmethod
    def _arguments_json(cls, function_body: str) -> str:
        """Serialize the ``<parameter=...>`` entries of *function_body*.

        Returns a compact JSON object (no spaces after separators, non-ASCII
        kept as-is) mapping parameter names to whitespace-stripped values.
        If a parameter name repeats, the last occurrence wins.
        """
        arguments = {
            name: value.strip()
            for name, value in cls._PARAMETER_RE.findall(function_body)
        }
        return json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))

    @classmethod
    def _hold_back_partial_tag(cls, text: str) -> str:
        """Drop a trailing fragment of *text* that could still grow into a tag.

        Only the last ``<`` is considered. If what follows it is a prefix of a
        known tag (e.g. ``<``, ``</func``, ``<parameter=pa``), it is cut off.
        A later call releases it once more text shows it is not a tag.
        """
        last_lt = text.rfind("<")
        if last_lt == -1:
            return text
        fragment = text[last_lt:]
        could_be_tag = any(
            tag.startswith(fragment) for tag in cls._TAG_LITERALS
        ) or cls._PARTIAL_PARAMETER_RE.match(fragment)
        return text[:last_lt] if could_be_tag else text

    @classmethod
    def _visible_content(cls, text: str) -> str:
        """Return the part of a (possibly partial) output that is plain content.

        Removes:
        * every ``<function=`` ... ``</function>`` span;
        * everything from an unclosed ``<function=`` onwards;
        * stray function/parameter tags found outside calls.

        Only the final, still-growing stretch of text has a possible partial
        tag held back. Text right before ``<function=`` cannot grow, so it is
        never held back.
        """
        pieces = []
        pos = 0
        while True:
            start = text.find(cls._FUNCTION_OPEN, pos)
            if start == -1:
                tail = cls._STRAY_TAG_RE.sub("", text[pos:])
                pieces.append(cls._hold_back_partial_tag(tail))
                break
            pieces.append(cls._STRAY_TAG_RE.sub("", text[pos:start]))
            end = text.find(cls._FUNCTION_CLOSE, start + len(cls._FUNCTION_OPEN))
            if end == -1:
                # Call still being generated: hide it and everything after it.
                break
            pos = end + len(cls._FUNCTION_CLOSE)
        return "".join(pieces)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Translate one streamed delta into content and/or tool calls.

        Everything is derived from ``previous_text`` and ``current_text``, so
        no per-request state is kept on the parser. Text inside a tool call is
        never streamed as content. Each call is emitted whole, as one
        ``DeltaToolCall``, in the delta where its ``</function>`` arrives. Its
        ``index`` is the call's position among all calls in the response.

        Note: a withheld partial-tag fragment at the very end of the
        generation (e.g. a final lone ``<``) is never emitted.

        Returns:
            ``None`` when ``delta_text`` is empty. Otherwise a ``DeltaMessage``
            with any newly visible content and any newly completed tool calls.
            When there is nothing to send it is ``DeltaMessage(content="")``,
            rather than ``None``, to keep the stream alive.
        """
        if not delta_text:
            return None

        tool_calls = []
        # Counting closing tags (instead of testing for their presence) lets
        # the 2nd, 3rd, ... call in a response be detected, not just the 1st.
        if current_text.count(self._FUNCTION_CLOSE) > previous_text.count(
            self._FUNCTION_CLOSE
        ):
            # Matches found in previous_text are a prefix of those found in
            # current_text, so skip the ones already emitted.
            already_emitted = len(self._FUNCTION_RE.findall(previous_text))
            for index, match in enumerate(self._FUNCTION_RE.finditer(current_text)):
                if index < already_emitted:
                    continue
                function_name, function_body = match.groups()
                try:
                    tool_calls.append(
                        DeltaToolCall(
                            index=index,
                            id=make_tool_call_id(),
                            type="function",
                            function=DeltaFunctionCall(
                                name=function_name,
                                arguments=self._arguments_json(function_body),
                            ),
                        )
                    )
                except Exception:
                    # Best effort: skip this call but keep streaming.
                    self._logger.exception(
                        "Failed to build streamed tool call %r", function_name
                    )

        previous_visible = self._visible_content(previous_text)
        current_visible = self._visible_content(current_text)
        new_content = ""
        # Defensive: only emit when the visible text strictly extends what was
        # already visible, so nothing is ever sent twice.
        if current_visible.startswith(previous_visible):
            new_content = current_visible[len(previous_visible):]

        if tool_calls:
            return DeltaMessage(content=new_content or None, tool_calls=tool_calls)
        return DeltaMessage(content=new_content)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every complete R2E-gym tool call in a finished response.

        Returns:
            ``ExtractedToolCallInformation`` with one ``ToolCall`` per
            ``<function=...>...</function>`` block, in order. ``content`` is
            the stripped text before the first ``<function=``. When calls were
            found, an empty content is reported as ``""`` rather than ``None``
            so clients such as the Copilot UI do not report "no response".
            ``content`` is ``None`` only when there is neither text nor a call.
        """
        tool_calls = []
        for match in self._FUNCTION_RE.finditer(model_output):
            function_name, function_body = match.groups()
            try:
                tool_calls.append(
                    ToolCall(
                        id=make_tool_call_id(),
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=self._arguments_json(function_body),
                        ),
                    )
                )
            except Exception:
                # Best effort: skip the malformed call but keep the others.
                self._logger.exception(
                    "Failed to parse tool call %r; problematic XML "
                    "(first 200 chars): %s",
                    function_name,
                    function_body[:200],
                )

        content = model_output.split(self._FUNCTION_OPEN, 1)[0].strip()
        if not content:
            content = "" if tool_calls else None
        return ExtractedToolCallInformation(
            tools_called=bool(tool_calls), tool_calls=tool_calls, content=content
        )
if __name__ == "__main__":
    # Executing this file directly has already run the register_module
    # decorator on FrogyToolParser in this process; now hand the command line
    # over to vLLM's CLI entry point.
    from vllm.entrypoints.cli.main import main as vllm_cli_main

    vllm_cli_main()