""" Custom vLLM tool parser plugin for models that use XML tags. The model outputs tool calls in this format: {"name": "function_name", "arguments": {"arg1": "val1"}} Multiple tool calls can appear in a single response (parallel tool calling). Usage: vllm serve \ --enable-auto-tool-choice \ --tool-parser-plugin /absolute/path/to/tool_parser_plugin.py \ --tool-call-parser xml_tool_call \ --chat-template /absolute/path/to/tool_chat_template.jinja """ import ast import json import re import uuid from typing import Sequence, Union # --------------------------------------------------------------------------- # Import compatibility: vLLM >=0.8 moved tool_parsers to vllm.tool_parsers; # older versions keep them under vllm.entrypoints.openai.tool_parsers. # --------------------------------------------------------------------------- try: # Newer vLLM, roughly 0.15+ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.engine.protocol import ( DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall, ) except ImportError: # Older vLLM from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall, ) try: from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager except ImportError: from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager, ) from vllm.logger import init_logger logger = init_logger(__name__) def _generate_tool_call_id() -> str: """Generate a unique tool-call ID in the format expected by OpenAI.""" return f"call_{uuid.uuid4().hex[:24]}" # --------------------------------------------------------------------------- # Register the parser so it can be referenced via --tool-call-parser # --------------------------------------------------------------------------- 
@ToolParserManager.register_module(["xml_tool_call"])
class XMLToolCallParser(ToolParser):
    """
    Parses tool calls wrapped in <tool_call> ... </tool_call> XML tags.

    Handles both single and parallel (multiple) tool calls in one response.
    Supports streaming and non-streaming extraction.
    """

    # Regex to match complete <tool_call> ... </tool_call> blocks.
    # The inner group is the JSON payload; DOTALL lets it span newlines.
    TOOL_CALL_RE = re.compile(
        r"<tool_call>\s*(.*?)\s*</tool_call>",
        re.DOTALL,
    )
    # Regex that also matches an incomplete (still-streaming) block:
    # the closing tag may be missing if generation is mid-block.
    TOOL_CALL_OPEN_RE = re.compile(
        r"<tool_call>\s*(.*?)(?:</tool_call>|$)",
        re.DOTALL,
    )

    TOOL_CALL_START = "<tool_call>"
    TOOL_CALL_END = "</tool_call>"

    def __init__(self, tokenizer, tools=None):
        # vLLM newer versions: ToolParser.__init__(tokenizer, tools)
        # vLLM older versions: ToolParser.__init__(tokenizer)
        try:
            super().__init__(tokenizer, tools)
        except TypeError:
            super().__init__(tokenizer)
        self.tools = tools or []

        # ---- streaming state ----
        # Index of the tool call currently being emitted (-1 = none yet).
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        # Parsed dicts for every tool call already streamed to the client.
        self.prev_tool_call_arr: list[dict] = []
        # Argument JSON strings already streamed, parallel to the list above.
        self.streamed_args_for_tool: list[str] = []

    @staticmethod
    def _parse_tool_json(raw: str) -> dict | None:
        """Parse a tool call JSON block, handling Python-style single quotes.

        Returns the parsed dict, or None if neither strict JSON nor a
        Python-literal dict could be recovered from *raw*.
        """
        # Try standard JSON first
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, ValueError):
            pass
        # Fall back to ast.literal_eval for Python-style dicts with single quotes
        try:
            result = ast.literal_eval(raw)
            if isinstance(result, dict):
                return result
        except (ValueError, SyntaxError):
            pass
        return None

    @staticmethod
    def _normalize_arguments(fn_args) -> str:
        """Return tool-call arguments as a JSON object string (OpenAI format).

        Accepts:
          * dict  -> serialized with json.dumps
          * str   -> validated as JSON and passed through; if invalid, retried
                     as a Python literal (e.g. single-quoted dict); if that
                     also fails, degraded to "{}" so downstream json.loads
                     never sees an invalid string
          * other -> stringified
        """
        if isinstance(fn_args, dict):
            return json.dumps(fn_args)
        if isinstance(fn_args, str):
            # Model may emit arguments as a JSON string — validate and pass through
            try:
                json.loads(fn_args)
                return fn_args
            except (json.JSONDecodeError, ValueError):
                pass
            # Try ast.literal_eval for Python-style dicts (e.g. single quotes).
            try:
                recovered = ast.literal_eval(fn_args)
                return json.dumps(recovered) if isinstance(recovered, dict) else json.dumps({})
            except (ValueError, SyntaxError):
                return "{}"
        return str(fn_args)

    # ------------------------------------------------------------------
    # Optional: adjust the request before inference
    # ------------------------------------------------------------------
    def adjust_request(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionRequest:
        """Hook to rewrite the request before inference; identity here."""
        return request

    # ------------------------------------------------------------------
    # NON-STREAMING extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Parse all <tool_call> ... </tool_call> blocks from the full model
        output and convert them to OpenAI ToolCall objects.

        Returns:
            ExtractedToolCallInformation with the parsed tool calls and any
            surrounding free text as ``content`` (None if nothing remains).
        """
        # Find all complete tool-call blocks
        raw_matches = self.TOOL_CALL_RE.findall(model_output)
        if not raw_matches:
            # No tool calls found — return the text as-is
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        tool_calls: list[ToolCall] = []
        for raw_json in raw_matches:
            parsed = self._parse_tool_json(raw_json)
            if parsed is None:
                logger.warning(
                    "Failed to parse tool call JSON: %s", raw_json
                )
                continue

            tool_calls.append(
                ToolCall(
                    id=_generate_tool_call_id(),
                    type="function",
                    function=FunctionCall(
                        name=parsed.get("name", ""),
                        arguments=self._normalize_arguments(
                            parsed.get("arguments", {})
                        ),
                    ),
                )
            )

        # Strip tool-call blocks from content to get any surrounding text
        remaining_content = self.TOOL_CALL_RE.sub("", model_output).strip()

        return ExtractedToolCallInformation(
            # Fix: only report tools_called when at least one block actually
            # parsed; previously this was True even if every block failed.
            tools_called=bool(tool_calls),
            tool_calls=tool_calls,
            content=remaining_content if remaining_content else None,
        )

    # ------------------------------------------------------------------
    # STREAMING extraction
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incrementally parse tool calls from the streaming token output.

        Strategy:
          - Before seeing <tool_call>, stream tokens as regular content.
          - Once <tool_call> is detected, buffer until </tool_call>.
          - On </tool_call>, emit the complete tool call delta.
          - Support multiple sequential tool calls.
        """
        # If we haven't seen a tool_call opening tag yet, pass through as
        # regular content (unless the start tag is partially forming).
        if self.TOOL_CALL_START not in current_text:
            # Check if the current text ends with a partial match of the
            # start tag — if so, hold back to avoid emitting partial tags.
            for i in range(1, len(self.TOOL_CALL_START)):
                if current_text.endswith(self.TOOL_CALL_START[:i]):
                    # Possibly forming the start tag — hold delta
                    return None
            return DeltaMessage(content=delta_text)

        # ---- We are inside or past a <tool_call> block ----

        # Find all *complete* tool call blocks so far
        complete_matches = self.TOOL_CALL_RE.findall(current_text)
        num_complete = len(complete_matches)

        # Determine how many we've already streamed
        num_already_sent = len(self.prev_tool_call_arr)

        if num_complete > num_already_sent:
            # A new tool call just completed — emit it
            new_raw = complete_matches[num_already_sent]
            parsed = self._parse_tool_json(new_raw)
            if parsed is None:
                logger.warning(
                    "Streaming: failed to parse tool call JSON: %s",
                    new_raw,
                )
                return None

            fn_name = parsed.get("name", "")
            fn_args_str = self._normalize_arguments(parsed.get("arguments", {}))

            self.current_tool_id += 1
            self.prev_tool_call_arr.append(parsed)
            self.streamed_args_for_tool.append(fn_args_str)
            self.current_tool_name_sent = True

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        id=_generate_tool_call_id(),
                        type="function",
                        function=DeltaFunctionCall(
                            name=fn_name,
                            arguments=fn_args_str,
                        ),
                    )
                ]
            )

        # If we're currently inside an incomplete tool call block,
        # don't emit anything — wait for it to complete.
        # Check if there's an open <tool_call> without a matching close
        open_count = current_text.count(self.TOOL_CALL_START)
        close_count = current_text.count(self.TOOL_CALL_END)
        if open_count > close_count:
            # Still buffering inside a tool call
            return None

        # If we're past all tool call blocks, stream remaining content
        # (unlikely for most models but handles edge cases)
        return None