# Domyn-Small-v1.0 — tool_parser_plugin.py (iGenius-AI-Team, commit 96389c0, "squash commits")
"""
Custom vLLM tool parser plugin for models that use <tool_call> XML tags.
The model outputs tool calls in this format:
<tool_call>
{"name": "function_name", "arguments": {"arg1": "val1"}}
</tool_call>
Multiple tool calls can appear in a single response (parallel tool calling).
Usage:
vllm serve <model> \
--enable-auto-tool-choice \
--tool-parser-plugin /absolute/path/to/tool_parser_plugin.py \
--tool-call-parser xml_tool_call \
--chat-template /absolute/path/to/tool_chat_template.jinja
"""
import ast
import json
import re
import uuid
from typing import Sequence, Union
# ---------------------------------------------------------------------------
# Import compatibility: vLLM >=0.8 moved tool_parsers to vllm.tool_parsers;
# older versions keep them under vllm.entrypoints.openai.tool_parsers.
# ---------------------------------------------------------------------------
try:
# Newer vLLM, roughly 0.15+
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
except ImportError:
# Older vLLM
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
try:
from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
except ImportError:
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def _generate_tool_call_id() -> str:
"""Generate a unique tool-call ID in the format expected by OpenAI."""
return f"call_{uuid.uuid4().hex[:24]}"
# ---------------------------------------------------------------------------
# Register the parser so it can be referenced via --tool-call-parser
# ---------------------------------------------------------------------------
@ToolParserManager.register_module(["xml_tool_call"])
class XMLToolCallParser(ToolParser):
    """
    Parses tool calls wrapped in <tool_call>...</tool_call> XML tags.

    Handles both single and parallel (multiple) tool calls in one response.
    Supports streaming and non-streaming extraction.
    """

    # Regex to match complete <tool_call>...</tool_call> blocks
    TOOL_CALL_RE = re.compile(
        r"<tool_call>\s*(.*?)\s*</tool_call>",
        re.DOTALL,
    )
    # Regex that also matches an incomplete (still-streaming) block
    TOOL_CALL_OPEN_RE = re.compile(
        r"<tool_call>\s*(.*?)(?:</tool_call>|$)",
        re.DOTALL,
    )
    TOOL_CALL_START = "<tool_call>"
    TOOL_CALL_END = "</tool_call>"

    def __init__(self, tokenizer, tools=None):
        # vLLM newer versions: ToolParser.__init__(tokenizer, tools)
        # vLLM older versions: ToolParser.__init__(tokenizer)
        try:
            super().__init__(tokenizer, tools)
        except TypeError:
            super().__init__(tokenizer)
        self.tools = tools or []
        # ---- streaming state ----
        # Index of the tool call currently being emitted (-1 = none yet).
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        # Parsed dicts of every tool call already emitted to the client.
        self.prev_tool_call_arr: list[dict] = []
        # Argument JSON strings already emitted, parallel to prev_tool_call_arr.
        self.streamed_args_for_tool: list[str] = []

    @staticmethod
    def _parse_tool_json(raw: str) -> dict | None:
        """Parse a tool-call payload, tolerating Python-style single quotes.

        Tries strict JSON first, then falls back to ``ast.literal_eval`` for
        Python-literal dicts. Returns None if neither yields a dict.
        """
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, ValueError):
            pass
        try:
            result = ast.literal_eval(raw)
            if isinstance(result, dict):
                return result
        except (ValueError, SyntaxError):
            pass
        return None

    @staticmethod
    def _normalize_arguments(fn_args) -> str:
        """Coerce a tool call's ``arguments`` field to a JSON string (OpenAI format).

        - dict: serialized with ``json.dumps``.
        - str: passed through if it is valid JSON; otherwise recovered via
          ``ast.literal_eval`` (single-quoted Python dicts) or replaced with
          ``"{}"`` so downstream ``json.loads`` never sees an invalid string.
        - anything else: stringified as-is.
        """
        if isinstance(fn_args, dict):
            return json.dumps(fn_args)
        if isinstance(fn_args, str):
            try:
                json.loads(fn_args)
                return fn_args
            except (json.JSONDecodeError, ValueError):
                try:
                    recovered = ast.literal_eval(fn_args)
                    return json.dumps(recovered) if isinstance(recovered, dict) else json.dumps({})
                except (ValueError, SyntaxError):
                    return "{}"
        return str(fn_args)

    # ------------------------------------------------------------------
    # Optional: adjust the request before inference
    # ------------------------------------------------------------------
    def adjust_request(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionRequest:
        """No request adjustment needed for this format; return it unchanged."""
        return request

    # ------------------------------------------------------------------
    # NON-STREAMING extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Parse all <tool_call>...</tool_call> blocks from the full model
        output and convert them to OpenAI ToolCall objects.

        Returns tools_called=False (content passed through untouched) when no
        block is present OR when every present block fails to parse.
        """
        raw_matches = self.TOOL_CALL_RE.findall(model_output)
        if not raw_matches:
            # No tool calls found — return the text as-is
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
        tool_calls: list[ToolCall] = []
        for raw_json in raw_matches:
            parsed = self._parse_tool_json(raw_json)
            if parsed is None:
                logger.warning(
                    "Failed to parse tool call JSON: %s", raw_json
                )
                continue
            tool_calls.append(
                ToolCall(
                    id=_generate_tool_call_id(),
                    type="function",
                    function=FunctionCall(
                        name=parsed.get("name", ""),
                        arguments=self._normalize_arguments(
                            parsed.get("arguments", {})
                        ),
                    ),
                )
            )
        if not tool_calls:
            # Every block was malformed: don't claim tools were called with an
            # empty call list — surface the raw output instead.
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )
        # Strip tool-call blocks from content to get any surrounding text
        remaining_content = self.TOOL_CALL_RE.sub("", model_output).strip()
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=remaining_content if remaining_content else None,
        )

    # ------------------------------------------------------------------
    # STREAMING extraction
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incrementally parse tool calls from the streaming token output.

        Strategy:
        - Before seeing <tool_call>, stream tokens as regular content.
        - Once <tool_call> is detected, buffer until </tool_call>.
        - On </tool_call>, emit the complete tool call as a single delta.
        - Support multiple sequential tool calls.
        """
        # If we haven't seen a tool_call opening tag yet, pass through as
        # regular content (unless the start tag is partially forming).
        if self.TOOL_CALL_START not in current_text:
            # Hold back deltas while the text ends with a proper prefix of the
            # start tag, so partial tags never leak into streamed content.
            for i in range(1, len(self.TOOL_CALL_START)):
                if current_text.endswith(self.TOOL_CALL_START[:i]):
                    return None
            return DeltaMessage(content=delta_text)
        # ---- We are inside or past a <tool_call> block ----
        # Find all *complete* tool call blocks so far
        complete_matches = self.TOOL_CALL_RE.findall(current_text)
        # Determine how many we've already streamed
        num_already_sent = len(self.prev_tool_call_arr)
        if len(complete_matches) > num_already_sent:
            # A new tool call just completed — emit it in one delta
            new_raw = complete_matches[num_already_sent]
            parsed = self._parse_tool_json(new_raw)
            if parsed is None:
                # Leave streaming state untouched; the malformed block is
                # simply never emitted.
                logger.warning(
                    "Streaming: failed to parse tool call JSON: %s",
                    new_raw,
                )
                return None
            fn_args_str = self._normalize_arguments(parsed.get("arguments", {}))
            self.current_tool_id += 1
            self.prev_tool_call_arr.append(parsed)
            self.streamed_args_for_tool.append(fn_args_str)
            self.current_tool_name_sent = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        id=_generate_tool_call_id(),
                        type="function",
                        function=DeltaFunctionCall(
                            name=parsed.get("name", ""),
                            arguments=fn_args_str,
                        ),
                    )
                ]
            )
        # Either still buffering inside an incomplete <tool_call> block, or
        # past all blocks with nothing new to emit — stay silent either way.
        return None