|
|
|
|
|
""" |
|
|
Custom tool parser for vLLM with R2E-gym XML format. |
|
|
Same as frogboss_default_parser, but parses R2E-gym XML tool calls (`<function=NAME><parameter=KEY>VALUE</parameter></function>`) instead of JSON.
|
|
|
|
|
Usage: |
|
|
vllm serve microsoft/FrogBoss-32B-2510 \ |
|
|
--tensor-parallel-size 4 \ |
|
|
--enable-auto-tool-choice \ |
|
|
--tool-parser-plugin ./Froggy-Training/src/vllm/frogboss_r2egym_parser.py \ |
|
|
--tool-call-parser froggy \ |
|
|
--enable-log-requests \ |
|
|
--enable-log-outputs \ |
|
|
--max-model-len 32768 |
|
|
""" |
|
|
import json |
|
|
import re |
|
|
import uuid |
|
|
|
|
|
|
|
|
from typing import Sequence, Union |
|
|
|
|
|
from vllm.entrypoints.openai.protocol import ( |
|
|
ChatCompletionRequest, |
|
|
DeltaFunctionCall, |
|
|
DeltaMessage, |
|
|
DeltaToolCall, |
|
|
FunctionCall, |
|
|
ToolCall, |
|
|
) |
|
|
from vllm.tool_parsers import ToolParser, ToolParserManager |
|
|
from vllm.tool_parsers.abstract_tool_parser import ( |
|
|
ExtractedToolCallInformation, |
|
|
) |
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer |
|
|
|
|
|
try:
    # Prefer vLLM's own id generator when this vLLM build exports it.
    from vllm.entrypoints.chat_utils import make_tool_call_id
except ImportError:

    def make_tool_call_id():
        """Return a new tool-call id: ``chatcmpl-tool-`` followed by 24 random hex chars."""
        random_hex = uuid.uuid4().hex
        return "chatcmpl-tool-" + random_hex[:24]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ToolParserManager.register_module(["froggy"])
class FrogyToolParser(ToolParser):
    """Tool parser for R2E-gym style XML tool calls.

    A tool call in the model output looks like::

        <function=NAME>
        <parameter=KEY>VALUE</parameter>
        ...
        </function>

    Parameter values are returned as whitespace-stripped strings (no type
    coercion). If a parameter repeats, the last occurrence wins. Arguments are
    serialized as compact JSON with non-ASCII characters preserved.
    Registered with vLLM under the tool-call-parser name ``froggy``.
    """

    _FUNCTION_OPEN = "<function="
    _FUNCTION_CLOSE = "</function>"
    # Names may contain word characters and hyphens (hyphens are legal in
    # OpenAI-style tool names); whitespace around the name is tolerated.
    _CALL_RE = re.compile(r"<function=\s*([\w-]+)\s*>(.*?)</function>", re.DOTALL)
    _PARAM_RE = re.compile(r"<parameter=\s*([\w-]+)\s*>(.*?)</parameter>", re.DOTALL)

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Return *request* unchanged."""
        # NOTE(review): overriding this skips whatever the base
        # ToolParser.adjust_request would apply -- confirm that is intended.
        return request

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _parse_calls(cls, text: str) -> list[tuple[str, dict[str, str]]]:
        """Return ``(name, arguments)`` for every complete call in *text*, in order."""
        return [
            (
                name,
                # A dict comprehension keeps "last duplicate wins" semantics.
                {key: value.strip() for key, value in cls._PARAM_RE.findall(body)},
            )
            for name, body in cls._CALL_RE.findall(text)
        ]

    @staticmethod
    def _dump_arguments(arguments: dict[str, str]) -> str:
        """Serialize tool arguments as compact JSON, keeping non-ASCII text."""
        return json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))

    @classmethod
    def _pending_open_tag_len(cls, text: str) -> int:
        """Length of the longest suffix of *text* that could still become ``<function=``."""
        tag = cls._FUNCTION_OPEN
        for size in range(min(len(tag) - 1, len(text)), 0, -1):
            if text.endswith(tag[:size]):
                return size
        return 0

    @classmethod
    def _visible_content(cls, text: str) -> str:
        """Return the part of *text* that should be shown as ordinary content.

        - Complete ``<function=...>...</function>`` blocks are removed.
        - An unterminated trailing block is hidden from its opening tag onward.
        - A trailing partial ``<function=`` prefix (e.g. ``<`` or ``<func``) is
          held back until later text shows whether it really is a tag.

        Because streamed text only grows, the result for a longer text extends
        the result for any of its prefixes. Content deltas can therefore be
        taken by slicing.
        """
        pieces = []
        pos = 0
        while True:
            start = text.find(cls._FUNCTION_OPEN, pos)
            if start == -1:
                tail = text[pos:]
                pieces.append(tail[: len(tail) - cls._pending_open_tag_len(tail)])
                return "".join(pieces)
            pieces.append(text[pos:start])
            end = text.find(cls._FUNCTION_CLOSE, start)
            if end == -1:
                # Call still being generated: hide everything from its opening tag.
                return "".join(pieces)
            pos = end + len(cls._FUNCTION_CLOSE)

    @classmethod
    def _newly_completed_calls(
        cls, previous_text: str, current_text: str
    ) -> list[tuple[int, str, dict[str, str]]]:
        """Return ``(index, name, arguments)`` for calls completed by the latest delta.

        ``index`` is the call's position among all complete calls in the
        response, so each streamed call gets a distinct index.
        """
        close = cls._FUNCTION_CLOSE
        # Cheap pre-check: re-parse only if a new closing tag has appeared.
        if current_text.count(close) == previous_text.count(close):
            return []
        done_before = len(cls._parse_calls(previous_text))
        return [
            (index, name, arguments)
            for index, (name, arguments) in enumerate(cls._parse_calls(current_text))
            if index >= done_before
        ]

    # --------------------------------------------------------------- streaming

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streamed delta into content and/or completed tool calls.

        Each tool call is emitted whole (name and full JSON arguments) in the
        delta where its ``</function>`` closes. Text outside call blocks streams
        as content. Text inside a call is never shown as content.

        Returns:
            ``None`` for an empty ``delta_text``.
            ``DeltaMessage(content="")`` when the delta produces nothing visible
            (for example, while a call is being buffered).
            Otherwise a ``DeltaMessage`` carrying the new content and/or tool calls.

        Note:
            A partial ``<function=`` prefix left at the very end of the stream
            is never flushed.
        """
        if not delta_text:
            return None

        previous_visible = self._visible_content(previous_text)
        content_delta = self._visible_content(current_text)[len(previous_visible):]

        new_calls = self._newly_completed_calls(previous_text, current_text)
        if not new_calls:
            return DeltaMessage(content=content_delta)

        tool_calls = []
        for index, name, arguments in new_calls:
            serialized = self._dump_arguments(arguments)
            tool_calls.append(
                DeltaToolCall(
                    index=index,
                    id=make_tool_call_id(),
                    type="function",
                    function=DeltaFunctionCall(name=name, arguments=serialized),
                )
            )
            # Base-class bookkeeping (initialized by ToolParser.__init__).
            # vLLM's streaming chat handler reads it after generation; a
            # non-empty prev_tool_call_arr is what yields finish_reason
            # "tool_calls".
            self.prev_tool_call_arr.append({"name": name, "arguments": arguments})
            self.streamed_args_for_tool.append(serialized)
            self.current_tool_id = index

        return DeltaMessage(content=content_delta or None, tool_calls=tool_calls)

    # ----------------------------------------------------------- non-streaming

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete XML tool call from a finished model output.

        Returns:
            If at least one call parses: ``tools_called=True``, the calls in
            order, and ``content`` set to the stripped text before the first
            ``<function=`` ("" if there is none).
            Otherwise: ``tools_called=False`` and ``content`` set to the whole
            stripped output (``None`` if empty). Malformed or unterminated
            calls therefore stay visible instead of being dropped.
        """
        tool_calls = [
            ToolCall(
                id=make_tool_call_id(),
                type="function",
                function=FunctionCall(
                    name=name,
                    arguments=self._dump_arguments(arguments),
                ),
            )
            for name, arguments in self._parse_calls(model_output)
        ]

        if not tool_calls:
            content = model_output.strip()
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=content or None
            )

        content = model_output.split(self._FUNCTION_OPEN, 1)[0].strip()
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=content
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands control to the vLLM command-line
    # entry point.
    from vllm.entrypoints.cli.main import main as vllm_cli_main

    vllm_cli_main()
|
|
|