| """ |
| Custom vLLM tool parser plugin for models that use <tool_call> XML tags. |
| |
| The model outputs tool calls in this format: |
| <tool_call> |
| {"name": "function_name", "arguments": {"arg1": "val1"}} |
| </tool_call> |
| |
| Multiple tool calls can appear in a single response (parallel tool calling). |
| |
| Usage: |
| vllm serve <model> \ |
| --enable-auto-tool-choice \ |
| --tool-parser-plugin /absolute/path/to/tool_parser_plugin.py \ |
| --tool-call-parser xml_tool_call \ |
| --chat-template /absolute/path/to/tool_chat_template.jinja |
| """ |
|
|
| import ast |
| import json |
| import re |
| import uuid |
| from typing import Sequence, Union |
|
|
| |
| |
| |
| |
| try: |
| |
| from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest |
| from vllm.entrypoints.openai.engine.protocol import ( |
| DeltaFunctionCall, |
| DeltaMessage, |
| DeltaToolCall, |
| ExtractedToolCallInformation, |
| FunctionCall, |
| ToolCall, |
| ) |
| except ImportError: |
| |
| from vllm.entrypoints.openai.protocol import ( |
| ChatCompletionRequest, |
| DeltaFunctionCall, |
| DeltaMessage, |
| DeltaToolCall, |
| ExtractedToolCallInformation, |
| FunctionCall, |
| ToolCall, |
| ) |
|
|
| try: |
| from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager |
| except ImportError: |
| from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| ToolParser, |
| ToolParserManager, |
| ) |
|
|
| from vllm.logger import init_logger |
|
|
| logger = init_logger(__name__) |
|
|
|
|
| def _generate_tool_call_id() -> str: |
| """Generate a unique tool-call ID in the format expected by OpenAI.""" |
| return f"call_{uuid.uuid4().hex[:24]}" |
|
|
|
|
| |
| |
| |
@ToolParserManager.register_module(["xml_tool_call"])
class XMLToolCallParser(ToolParser):
    """
    Parses tool calls wrapped in <tool_call>...</tool_call> XML tags.

    Handles both single and parallel (multiple) tool calls in one response.
    Supports streaming and non-streaming extraction.
    """

    # Matches one fully closed <tool_call>...</tool_call> block; the
    # non-greedy body plus DOTALL lets the JSON payload span lines.
    TOOL_CALL_RE = re.compile(
        r"<tool_call>\s*(.*?)\s*</tool_call>",
        re.DOTALL,
    )

    # Matches an opened block even when the closing tag has not arrived
    # yet (useful while streaming, where the tag may be mid-generation).
    TOOL_CALL_OPEN_RE = re.compile(
        r"<tool_call>\s*(.*?)(?:</tool_call>|$)",
        re.DOTALL,
    )

    TOOL_CALL_START = "<tool_call>"
    TOOL_CALL_END = "</tool_call>"

    def __init__(self, tokenizer, tools=None):
        # Newer vLLM versions accept ToolParser(tokenizer, tools); older
        # ones only ToolParser(tokenizer). Try the richer signature first
        # and fall back, storing `tools` ourselves in the legacy case.
        try:
            super().__init__(tokenizer, tools)
        except TypeError:
            super().__init__(tokenizer)
            self.tools = tools or []

        # Streaming state, following the conventions of vLLM's built-in
        # parsers:
        #   current_tool_id        - index of the last tool call emitted
        #   current_tool_name_sent - whether that call's name was streamed
        #   prev_tool_call_arr     - parsed dicts of calls already handled
        #   streamed_args_for_tool - argument JSON already sent per call
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

    @staticmethod
    def _parse_tool_json(raw: str) -> dict | None:
        """Parse a tool call JSON block, handling Python-style single quotes.

        Tries strict JSON first, then ``ast.literal_eval`` for model output
        that used Python literal syntax (single quotes, True/False/None).
        Returns ``None`` when neither parse yields a dict.
        """
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, ValueError):
            pass
        try:
            result = ast.literal_eval(raw)
            if isinstance(result, dict):
                return result
        except (ValueError, SyntaxError):
            pass
        return None

    @staticmethod
    def _serialize_arguments(fn_args) -> str:
        """Normalize a parsed ``arguments`` value to a JSON object string.

        - dict: serialized with ``json.dumps``.
        - str: passed through if it is already valid JSON; otherwise we try
          to recover a Python-literal dict, falling back to ``"{}"``.
        - anything else: stringified as-is (best effort).
        """
        if isinstance(fn_args, dict):
            return json.dumps(fn_args)
        if isinstance(fn_args, str):
            try:
                json.loads(fn_args)
                return fn_args
            except (json.JSONDecodeError, ValueError):
                try:
                    recovered = ast.literal_eval(fn_args)
                    if isinstance(recovered, dict):
                        return json.dumps(recovered)
                    return json.dumps({})
                except (ValueError, SyntaxError):
                    return "{}"
        return str(fn_args)

    def adjust_request(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionRequest:
        """No request adjustment needed for this format."""
        return request

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Parse all <tool_call>...</tool_call> blocks from the full model
        output and convert them to OpenAI ToolCall objects.

        Blocks whose JSON cannot be parsed are logged and skipped; any
        text outside the tool-call tags is returned as regular content.
        """
        raw_matches = self.TOOL_CALL_RE.findall(model_output)

        if not raw_matches:
            # No tool calls at all: the whole output is plain content.
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        tool_calls: list[ToolCall] = []
        for raw_json in raw_matches:
            parsed = self._parse_tool_json(raw_json)
            if parsed is None:
                logger.warning(
                    "Failed to parse tool call JSON: %s", raw_json
                )
                continue

            tool_calls.append(
                ToolCall(
                    id=_generate_tool_call_id(),
                    type="function",
                    function=FunctionCall(
                        name=parsed.get("name", ""),
                        arguments=self._serialize_arguments(
                            parsed.get("arguments", {})
                        ),
                    ),
                )
            )

        # Whatever remains outside the tags is ordinary assistant content.
        remaining_content = self.TOOL_CALL_RE.sub("", model_output).strip()

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=remaining_content if remaining_content else None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incrementally parse tool calls from the streaming token output.

        Strategy:
        - Before seeing <tool_call>, stream tokens as regular content.
        - Once <tool_call> is detected, buffer until </tool_call>.
        - On </tool_call>, emit the complete tool call delta.
        - Support multiple sequential tool calls.
        """
        if self.TOOL_CALL_START not in current_text:
            # The text may end with a partial "<tool_call>" prefix that is
            # still being generated; hold it back rather than leaking tag
            # fragments to the client as content.
            for i in range(1, len(self.TOOL_CALL_START)):
                if current_text.endswith(self.TOOL_CALL_START[:i]):
                    return None
            return DeltaMessage(content=delta_text)

        # Count complete (closed) tool-call blocks seen so far versus how
        # many we have already emitted to the client.
        complete_matches = self.TOOL_CALL_RE.findall(current_text)
        num_complete = len(complete_matches)
        num_already_sent = len(self.prev_tool_call_arr)

        if num_complete > num_already_sent:
            # A new tool call just closed; emit it as a single delta.
            new_raw = complete_matches[num_already_sent]
            parsed = self._parse_tool_json(new_raw)
            if parsed is None:
                logger.warning(
                    "Streaming: failed to parse tool call JSON: %s",
                    new_raw,
                )
                # Mark the malformed block as consumed. Without this we
                # would retry (and re-log) the same broken block on every
                # subsequent delta and never reach later, valid tool calls.
                self.prev_tool_call_arr.append({})
                self.streamed_args_for_tool.append("")
                return None

            fn_name = parsed.get("name", "")
            fn_args_str = self._serialize_arguments(
                parsed.get("arguments", {})
            )

            self.current_tool_id += 1
            self.prev_tool_call_arr.append(parsed)
            self.streamed_args_for_tool.append(fn_args_str)
            self.current_tool_name_sent = True

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        id=_generate_tool_call_id(),
                        type="function",
                        function=DeltaFunctionCall(
                            name=fn_name,
                            arguments=fn_args_str,
                        ),
                    )
                ]
            )

        # A tool call is open but not yet closed: keep buffering.
        open_count = current_text.count(self.TOOL_CALL_START)
        close_count = current_text.count(self.TOOL_CALL_END)
        if open_count > close_count:
            return None

        # Nothing new to emit for this delta.
        return None