import ast
import json
import re
import uuid
from collections.abc import Sequence
from enum import Enum
from typing import Union, Optional, Any, List, Dict

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionToolsParam,
    DeltaMessage,
    DeltaToolCall,
    DeltaFunctionCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer

# Module-level logger shared by every parser instance in this file.
logger = init_logger(__name__)


@ToolParserManager.register_module("qwen3_xml")
class Qwen3XMLToolParser(ToolParser):
    """Parse Qwen3 XML-style tool calls out of model output.

    The markers used by this class imply the following wire format::

        <tool_call>
        <function=NAME>
        <parameter=KEY>
        VALUE
        </parameter>
        </function>
        </tool_call>

    Two entry points are provided: ``extract_tool_calls`` for a complete
    response and ``extract_tool_calls_streaming`` for incremental deltas.
    Both convert raw parameter text with the same schema-driven rules, so
    the streamed ``arguments`` JSON agrees with the non-streaming result.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up markers, regexes and streaming state.

        Raises:
            ValueError: if no tokenizer is available on the instance.
            RuntimeError: if ``<tool_call>`` / ``</tool_call>`` are not
                present in the tokenizer vocabulary.
        """
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        # Tool call ids produced here are strings (see
        # _generate_tool_call_id); None means "no tool call in progress".
        self.current_tool_id: Optional[str] = None
        self.streamed_args_for_tool: list[str] = []

        # Literal markers of the XML-like tool-call format.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        self._reset_streaming_state()

        # Each "open-ended" alternative (``...$``) tolerates output that was
        # truncated before the closing tag was generated.
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL
        )
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.info(f"vLLM Successfully import tool parser {self.__class__.__name__} !")

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _reset_streaming_state(self):
        """Reset all streaming state, including the recorded tool calls."""
        self.current_tool_index = 0
        self.is_tool_call_started = False
        self.header_sent = False
        self.current_tool_id = None
        self.current_function_name = None
        self.current_param_name = None
        self.current_param_value = ""
        self.param_count = 0
        # Kept for backward compatibility; this class never sets it to True.
        self.in_param = False
        self.in_function = False
        self.accumulated_text = ""
        self.json_started = False
        self.json_closed = False
        # Streaming records tool calls by index, so entries left over from an
        # earlier stream would be mistaken for calls of the new one.
        self.prev_tool_call_arr.clear()

    @staticmethod
    def _find_token_offsets(text: str, token: str) -> list[int]:
        """Return the start offsets of every non-overlapping ``token`` in ``text``."""
        offsets = []
        cursor = text.find(token)
        while cursor != -1:
            offsets.append(cursor)
            cursor = text.find(token, cursor + len(token))
        return offsets

    def _get_arguments_config(
        self, func_name: str, tools: Optional[list[ChatCompletionToolsParam]]
    ) -> dict:
        """Return the parameter schema (``properties`` if present) of a tool.

        Returns an empty dict when ``tools`` is None, when the tool is not
        declared (a warning is logged) or when its schema is not a dict.
        """
        if tools is None:
            return {}
        for config in tools:
            function = getattr(config, "function", None)
            if not hasattr(config, "type") or not hasattr(function, "name"):
                continue
            if config.type == "function" and function.name == func_name:
                params = getattr(function, "parameters", None)
                if not isinstance(params, dict):
                    return {}
                properties = params.get("properties", params)
                return properties if isinstance(properties, dict) else {}
        logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
        return {}

    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        """Convert raw parameter text to the type declared in the tool schema.

        Unknown parameters and failed conversions fall back to the raw string
        (booleans fall back to ``False``); a warning is logged in each case.
        The literal text ``null`` (any case) always becomes ``None``.
        """
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config != {}:
                logger.warning(
                    f"Parsed parameter '{param_name}' is not defined in the tool "
                    f"parameters for tool '{func_name}', directly returning the string value."
                )
            return param_value

        spec = param_config[param_name]
        if isinstance(spec, dict) and "type" in spec:
            param_type = str(spec["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value

        if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except (ValueError, TypeError):
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

        if param_type.startswith(("num", "float")):
            try:
                as_float = float(param_value)
                # int() raises on inf/nan, so those degrade to the raw
                # string instead of producing non-standard JSON.
                return as_float if as_float - int(as_float) != 0 else int(as_float)
            except (ValueError, OverflowError, TypeError):
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

        if param_type in ["boolean", "bool", "binary"]:
            param_value = param_value.lower()
            if param_value not in ["true", "false"]:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false."
                )
            return param_value == "true"

        # Structured types: JSON is tried first so that `true`/`false`/`null`
        # inside objects and arrays are understood.
        if (
            param_type in ["object", "array", "arr"]
            or param_type.startswith("dict")
            or param_type.startswith("list")
        ):
            try:
                return json.loads(param_value)
            except (ValueError, TypeError):
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid JSON object in tool "
                    f"'{func_name}', will try other methods to parse it."
                )

        # SECURITY: param_value is model-generated (untrusted) text. It must
        # never reach eval(); ast.literal_eval only accepts Python literals.
        try:
            literal = ast.literal_eval(param_value)
            # Reject literals (sets, bytes, ...) that cannot be serialised
            # into the JSON arguments string.
            json.dumps(literal)
            return literal
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via "
                f"`ast.literal_eval()` in tool '{func_name}', degenerating to string."
            )
            return param_value

    def _parse_xml_function_call(
        self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]]
    ) -> Optional[ToolCall]:
        """Parse the text following ``<function=`` into a ToolCall.

        Args:
            function_call_str: ``NAME>...parameters...`` i.e. the function
                block without its ``<function=`` prefix and closing tag.
            tools: tool declarations used to type-convert parameter values.

        Returns:
            The parsed ToolCall, or None when the header has no closing ``>``.
        """
        name_end = function_call_str.find(">")
        if name_end == -1:
            logger.warning("Skipping malformed function call without a closing '>'.")
            return None

        function_name = function_call_str[:name_end]
        param_config = self._get_arguments_config(function_name, tools)
        parameters = function_call_str[name_end + 1 :]

        param_dict = {}
        for closed, unclosed in self.tool_call_parameter_regex.findall(parameters):
            match_text = closed if closed else unclosed
            separator = match_text.find(">")
            if separator == -1:
                # Truncated `<parameter=NAME` with no value yet.
                continue
            param_name = match_text[:separator]
            param_value = match_text[separator + 1 :]
            # The format puts the value on its own line(s); drop exactly one
            # framing newline on each side.
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]

            param_dict[param_name] = self._convert_param_value(
                param_value, param_name, param_config, function_name
            )

        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
            ),
        )

    def _get_function_calls(self, model_output: str) -> list[str]:
        """Return the raw text of every ``<function=...`` block in the output.

        Functions are searched inside ``<tool_call>`` wrappers when present,
        otherwise in the whole output.
        """
        raw_tool_calls = [
            closed if closed else unclosed
            for closed, unclosed in self.tool_call_regex.findall(model_output)
        ]
        if not raw_tool_calls:
            raw_tool_calls = [model_output]

        raw_function_calls = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))

        return [closed if closed else unclosed for closed, unclosed in raw_function_calls]

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete model response.

        Text preceding the first tool call is returned as ``content``. If no
        valid tool call is found, or parsing fails, the whole output is
        returned as plain content with ``tools_called=False``.
        """
        if self.tool_call_prefix not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            parsed_calls = [
                self._parse_xml_function_call(function_call_str, request.tools)
                for function_call_str in self._get_function_calls(model_output)
            ]
            # Malformed calls are skipped rather than discarding the rest.
            tool_calls = [call for call in parsed_calls if call is not None]
            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            self.prev_tool_call_arr.clear()
            for tool_call in tool_calls:
                self.prev_tool_call_arr.append(
                    {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    }
                )

            content_index = model_output.find(self.tool_call_start_token)
            if content_index < 0:
                content_index = model_output.find(self.tool_call_prefix)
            content = model_output[:content_index]

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def _arguments_delta(self, arguments: str) -> DeltaMessage:
        """Wrap an ``arguments`` fragment for the tool call being streamed."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_index,
                    function=DeltaFunctionCall(arguments=arguments),
                )
            ]
        )

    def _drain_completed_parameters(
        self, tool_text: str, tools: Optional[list[ChatCompletionToolsParam]]
    ) -> list[str]:
        """Return JSON fragments for parameters completed since the last call.

        A parameter is complete once its ``</parameter>`` has arrived.
        ``self.param_count`` is advanced for every fragment returned.
        """
        fragments: list[str] = []
        completed = tool_text.count(self.parameter_end_token)
        if self.param_count >= completed:
            return fragments

        param_config = self._get_arguments_config(self.current_function_name, tools)
        starts = self._find_token_offsets(tool_text, self.parameter_prefix)

        while self.param_count < completed and self.param_count < len(starts):
            name_start = starts[self.param_count] + len(self.parameter_prefix)
            name_end = tool_text.find(">", name_start)
            if name_end == -1:
                break
            value_end = tool_text.find(self.parameter_end_token, name_end + 1)
            if value_end == -1:
                break

            self.current_param_name = tool_text[name_start:name_end]
            raw_value = tool_text[name_end + 1 : value_end]
            if raw_value.startswith("\n"):
                raw_value = raw_value[1:]
            if raw_value.endswith("\n"):
                raw_value = raw_value[:-1]

            # Same conversion as the non-streaming path, so both agree.
            value = self._convert_param_value(
                raw_value,
                self.current_param_name,
                param_config,
                self.current_function_name,
            )
            separator = ", " if self.param_count > 0 else ""
            fragments.append(
                separator
                + json.dumps(self.current_param_name, ensure_ascii=False)
                + ": "
                + json.dumps(value, ensure_ascii=False)
            )
            self.param_count += 1

        return fragments

    def _record_final_arguments(
        self, tool_text: str, tools: Optional[list[ChatCompletionToolsParam]]
    ) -> None:
        """Store the fully parsed arguments of the current tool call.

        Best effort: a parsing failure is logged and leaves the placeholder
        entry in ``prev_tool_call_arr`` untouched.
        """
        func_start = tool_text.find(self.tool_call_prefix) + len(self.tool_call_prefix)
        func_content_end = tool_text.find(self.function_end_token, func_start)
        if func_content_end == -1:
            return
        try:
            parsed_tool = self._parse_xml_function_call(
                tool_text[func_start:func_content_end], tools
            )
        except Exception:
            logger.exception("Error in parsing completed streaming tool call.")
            return
        if parsed_tool and self.current_tool_index < len(self.prev_tool_call_arr):
            # Indexed by position, not name, so repeated calls to the same
            # function keep separate entries.
            self.prev_tool_call_arr[self.current_tool_index]["arguments"] = (
                parsed_tool.function.arguments
            )

    def _stream_function_arguments(
        self, tool_text: str, request: Optional[ChatCompletionRequest]
    ) -> Optional[DeltaMessage]:
        """Emit everything of the arguments JSON that is ready to be sent.

        The opening brace, all newly completed parameters and the closing
        brace are combined into one delta, so nothing is lost when several of
        them arrive in the same chunk.
        """
        tools = request.tools if request else None
        pieces: list[str] = []

        if not self.json_started:
            self.json_started = True
            pieces.append("{")

        pieces.extend(self._drain_completed_parameters(tool_text, tools))

        if not self.json_closed and self.function_end_token in tool_text:
            self._record_final_arguments(tool_text, tools)
            pieces.append("}")
            self.json_closed = True
            self.in_function = False

        if not pieces:
            return None
        return self._arguments_delta("".join(pieces))

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Translate one streaming step into a DeltaMessage.

        Args:
            previous_text: text generated before this step; empty on the
                first step, which resets the streaming state.
            current_text: all text generated so far, including ``delta_text``.
            delta_text: text generated in this step.
            previous_token_ids: unused by this parser.
            current_token_ids: unused by this parser.
            delta_token_ids: token ids generated in this step.
            request: the chat request; its ``tools`` drive type conversion.

        Returns:
            A content delta, a tool-call delta (header or ``arguments``
            fragment), or None when there is nothing to send for this step.
        """
        if not delta_text:
            # A step with tokens but no text that is not the tool-call end
            # token: return an empty content delta once every tool call is
            # closed, or when the response was plain content.
            # NOTE(review): presumably this lets the caller emit the finish
            # reason on e.g. EOS; confirm against the serving layer.
            if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text)
                )
                if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                    open_calls = current_text.count(
                        self.tool_call_start_token
                    ) - current_text.count(self.tool_call_end_token)
                    if open_calls == 0:
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    return DeltaMessage(content="")
            return None

        if not previous_text:
            self._reset_streaming_state()

        self.accumulated_text = current_text

        # The current function is closed: once its </tool_call> has arrived,
        # move on to the next tool call.
        if self.json_closed and not self.in_function:
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False

                if self.current_tool_index >= current_text.count(
                    self.tool_call_start_token
                ):
                    # No further tool call has started yet.
                    self.is_tool_call_started = False
                return None

        # Plain content outside of any tool call.
        if not self.is_tool_call_started:
            if (
                self.tool_call_start_token_id in delta_token_ids
                or self.tool_call_start_token in delta_text
            ):
                self.is_tool_call_started = True
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[
                        : delta_text.index(self.tool_call_start_token)
                    ]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None

            # Whitespace right after a finished tool call is not content.
            if (
                current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""
            ):
                return None
            return DeltaMessage(content=delta_text)

        # Locate the text of the tool call currently being streamed.
        tool_starts = self._find_token_offsets(current_text, self.tool_call_start_token)
        if self.current_tool_index >= len(tool_starts):
            return None

        tool_start_idx = tool_starts[self.current_tool_index]
        tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx)
        if tool_end_idx == -1:
            tool_text = current_text[tool_start_idx:]
        else:
            tool_text = current_text[
                tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
            ]

        # First send the header (id + function name) once the name is complete.
        if not self.header_sent:
            if self.tool_call_prefix in tool_text:
                func_start = tool_text.find(self.tool_call_prefix) + len(
                    self.tool_call_prefix
                )
                func_end = tool_text.find(">", func_start)

                if func_end != -1:
                    self.current_function_name = tool_text[func_start:func_end]
                    self.current_tool_id = self._generate_tool_call_id()
                    self.header_sent = True
                    self.in_function = True

                    # Record the call right away with placeholder arguments;
                    # they are replaced when </function> arrives.
                    entry = {
                        "name": self.current_function_name,
                        "arguments": "{}",
                    }
                    if self.current_tool_index < len(self.prev_tool_call_arr):
                        self.prev_tool_call_arr[self.current_tool_index] = entry
                    else:
                        self.prev_tool_call_arr.append(entry)

                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                id=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    name=self.current_function_name, arguments=""
                                ),
                                type="function",
                            )
                        ]
                    )
            return None

        if self.in_function:
            return self._stream_function_arguments(tool_text, request)

        return None