#!/usr/bin/env python3
r"""
Custom tool parser for vLLM with R2E-gym XML format.

Same as frogboss_default_parser but handles XML format instead of JSON.
Tool calls are expected in this shape::

    <function=NAME>
    <parameter=ARG>VALUE</parameter>
    ...
    </function>

Each call becomes an OpenAI-style tool call whose ``arguments`` field is a
JSON object mapping parameter names to their whitespace-stripped string
values.

Usage:
    vllm serve microsoft/FrogBoss-32B-2510 \
        --tensor-parallel-size 4 \
        --enable-auto-tool-choice \
        --tool-parser-plugin ./Froggy-Training/src/vllm/frogboss_r2egym_parser.py \
        --tool-call-parser froggy \
        --enable-log-requests \
        --enable-log-outputs \
        --max-model-len 32768
"""

import json
import re
import uuid

# import the required packages
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers.abstract_tool_parser import (
    ExtractedToolCallInformation,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer

try:
    from vllm.entrypoints.chat_utils import make_tool_call_id
except ImportError:
    # Fallback if import fails
    def make_tool_call_id():
        return f"chatcmpl-tool-{uuid.uuid4().hex[:24]}"


FUNCTION_OPEN = "<function="
FUNCTION_CLOSE = "</function>"

# <function=NAME>BODY</function>; DOTALL so BODY may span several lines.
FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
# <parameter=NAME>VALUE</parameter> entries inside a function body.
PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)


def _arguments_json(function_body: str) -> str:
    """Serialize the ``<parameter=...>`` entries of a call body as compact JSON.

    Names and values are whitespace-stripped; if a name repeats, the last
    value wins.
    """
    arguments = {
        name.strip(): value.strip()
        for name, value in PARAMETER_RE.findall(function_body)
    }
    return json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))


def _partial_prefix_len(text: str, tag: str) -> int:
    """Return the length of the longest proper prefix of *tag* that *text* ends with."""
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text.endswith(tag[:size]):
            return size
    return 0


def _visible_text(text: str, hold_partial: bool = True) -> str:
    """Return *text* with every ``<function=...>...</function>`` span removed.

    An unterminated trailing call is removed through the end of *text*.
    With *hold_partial*, a trailing fragment that could still grow into
    ``<function=`` (e.g. ``"<"`` or ``"<func"``) is withheld as well, so a
    streamed tag never leaks. Because the result for a longer text always
    extends the result for its prefix, a withheld fragment is released as
    soon as a later delta shows it is plain text.
    """
    pieces = []
    pos = 0
    while True:
        start = text.find(FUNCTION_OPEN, pos)
        if start == -1:
            break
        pieces.append(text[pos:start])
        end = text.find(FUNCTION_CLOSE, start + len(FUNCTION_OPEN))
        if end == -1:
            # Still inside a call: nothing after its opening tag is visible yet.
            return "".join(pieces)
        pos = end + len(FUNCTION_CLOSE)
    tail = text[pos:]
    if hold_partial:
        tail = tail[: len(tail) - _partial_prefix_len(tail, FUNCTION_OPEN)]
    pieces.append(tail)
    return "".join(pieces)


# define a tool parser and register it to vllm
# the name list in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["froggy"])
class FrogyToolParser(ToolParser):
    """Parse R2E-gym ``<function=...>`` XML tool calls out of model output.

    The parser keeps no state of its own between streaming calls. Every
    decision is derived from ``previous_text`` and ``current_text``.

    NOTE(review): this parser never fills the base class's streaming
    bookkeeping (``prev_tool_call_arr`` / ``streamed_args_for_tool``).
    Confirm how your vLLM version's serving layer uses those fields
    (e.g. for ``finish_reason``).
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    # adjust request. e.g.: set skip special tokens
    # to False for tool call output.
    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Return *request* unchanged; this parser needs no request tweaks."""
        return request

    # implement the tool call parse for stream call
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streamed delta into content and/or completed tool calls.

        Tool calls:
            A call is emitted whole (name plus all arguments) in the delta
            that completes its ``</function>`` tag. Its ``index`` is its
            position among all completed calls in the response, so several
            calls in one response get distinct indices.

        Content:
            Text outside calls is streamed as content. Text inside calls is
            never streamed.

        Limitation:
            A partial ``<function=`` fragment left at the very end of
            generation is withheld and never flushed.

        Returns:
            None for an empty delta. Otherwise a DeltaMessage, with
            ``content=""`` when the delta is entirely suppressed so the
            stream keeps producing chunks.
        """
        if not delta_text:
            return None

        # Assumes current_text == previous_text + delta_text (how vLLM feeds
        # streaming parsers) -- confirm if reusing elsewhere.
        prev_len = len(previous_text)

        tool_calls = []
        # Rescan only when this delta may have completed a closing tag; the
        # tag can straddle the previous/current boundary.
        window_start = max(0, prev_len - len(FUNCTION_CLOSE) + 1)
        if FUNCTION_CLOSE in current_text[window_start:]:
            for index, match in enumerate(FUNCTION_RE.finditer(current_text)):
                if match.end() <= prev_len:
                    continue  # completed, and emitted, in an earlier delta
                function_name, function_body = match.groups()
                tool_calls.append(
                    DeltaToolCall(
                        index=index,
                        id=make_tool_call_id(),
                        type="function",
                        function=DeltaFunctionCall(
                            name=function_name.strip(),
                            arguments=_arguments_json(function_body),
                        ),
                    )
                )

        # Content is whatever the visible (non-tool-call) text grew by.
        content = _visible_text(current_text)[len(_visible_text(previous_text)):]

        if tool_calls:
            return DeltaMessage(content=content or None, tool_calls=tool_calls)
        # Return empty content instead of None to keep the stream alive.
        return DeltaMessage(content=content)

    # implement the tool parse for non-stream call
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every complete ``<function=...>...</function>`` call in *model_output*.

        If at least one call is found, ``content`` is the text outside the
        calls, stripped, or None if that text is empty. Otherwise no tool is
        reported and *model_output* is returned unchanged as content.
        """
        tool_calls = [
            ToolCall(
                id=make_tool_call_id(),
                type="function",
                function=FunctionCall(
                    name=function_name.strip(),
                    arguments=_arguments_json(function_body),
                ),
            )
            for function_name, function_body in FUNCTION_RE.findall(model_output)
        ]

        if not tool_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Same notion of "content" as the streaming path, minus the hold-back
        # (the output is final, so nothing can still grow into a tag).
        content = _visible_text(model_output, hold_partial=False).strip()
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=content or None
        )


if __name__ == "__main__":
    # When run as a script, start vLLM with this parser registered
    from vllm.entrypoints.cli.main import main

    main()