| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
import ast
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Optional

import regex
from pydantic import BaseModel

from verl.tools.schemas import OpenAIFunctionToolSchema
from verl.utils.ray_utils import get_event_loop
from verl.utils.rollout_trace import rollout_trace_op
|
|
# Module-level logger. Its level comes from the VERL_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
# One tool/function invocation produced by a ToolParser: the function name plus
# its arguments, serialized as a JSON string.
class FunctionCall(BaseModel):
    # The arguments to call the function with, as generated by the model in JSON
    # format. The model does not always generate valid JSON and may hallucinate
    # parameters that are not defined by the function schema, so validate the
    # arguments before calling the function.
    arguments: str

    # The name of the function to call.
    name: str
|
|
|
|
class ToolParser(ABC):
    """Base class for parsers that turn response token ids into content and tool calls.

    Concrete parsers are attached to a name with the :meth:`register` class decorator
    and instantiated by that name through :meth:`get_tool_parser`.
    """

    # Name -> parser class. `register` writes into this dict and `get_tool_parser`
    # reads from it; subclasses share this single dict through attribute lookup.
    _registry: dict[str, type["ToolParser"]] = {}

    def __init__(self, tokenizer) -> None:
        """Store the tokenizer that subclasses use to decode response ids."""
        self.tokenizer = tokenizer

    @abstractmethod
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: Optional[list[OpenAIFunctionToolSchema]] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Extract tool calls from the responses.

        Args:
            responses_ids (List[int]): The ids of the responses.
            tools (List[OpenAIFunctionToolSchema], optional): OpenAI function tool schema.

        Returns:
            Tuple[str, List[FunctionCall]]: Content and extracted tool calls.
        """
        raise NotImplementedError

    @classmethod
    def get_tool_parser(cls, name: str, tokenizer) -> "ToolParser":
        """Instantiate the parser registered under ``name`` with ``tokenizer``.

        Raises:
            ValueError: If no parser is registered under ``name``. The message lists
                the registered names to make configuration typos easy to spot.
        """
        if name not in cls._registry:
            available = ", ".join(sorted(cls._registry)) or "<none>"
            raise ValueError(f"Unknown tool parser: {name}. Available tool parsers: {available}")
        return cls._registry[name](tokenizer)

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers the decorated parser under ``name``.

        The decorated class is returned unchanged. Registering an already-used name
        replaces the previous entry.
        """

        def decorator(subclass: type[ToolParser]) -> type[ToolParser]:
            cls._registry[name] = subclass
            return subclass

        return decorator
|
|
|
|
@ToolParser.register("hermes")
class HermesToolParser(ToolParser):
    """Parser for Hermes-style ``<tool_call>{"name": ..., "arguments": ...}</tool_call>`` blocks.

    Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)
        # Delimiters that wrap each JSON-encoded tool call.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        # Non-greedy with DOTALL: each call is captured separately and may span lines.
        self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)

    @rollout_trace_op
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Decode the response and pull out every well-formed ``<tool_call>`` block.

        Blocks whose payload is not JSON, or lacks ``name``/``arguments``, are logged
        and skipped. Every matched block, valid or not, is removed from the returned
        content. If either delimiter is absent, the decoded text is returned untouched
        with no calls.
        """
        # The tokenizer call is synchronous; run it in the default executor so the
        # event loop is not blocked.
        text = await get_event_loop().run_in_executor(None, self.tokenizer.decode, responses_ids)

        has_start = self.tool_call_start_token in text
        has_end = self.tool_call_end_token in text
        if not (has_start and has_end):
            return text, []

        parsed: list[FunctionCall] = []
        for payload in self.tool_call_regex.findall(text):
            try:
                call = json.loads(payload)
                parsed.append(
                    FunctionCall(
                        name=call["name"],
                        # Re-serialize so `arguments` is always a JSON string.
                        arguments=json.dumps(call["arguments"], ensure_ascii=False),
                    )
                )
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")

        return self.tool_call_regex.sub("", text), parsed
|
|
|
|
@ToolParser.register("gpt-oss")
class GptOssToolParser(ToolParser):
    """
    Tool parser for gpt-oss model.
    Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/gpt_oss_detector.py

    The parser first strips pad tokens and analysis-channel messages. It then turns every
    ``<|start|>assistant<|channel|>... to=functions.NAME <|constrain|>json<|message|>ARGS<|call|>``
    segment into a :class:`FunctionCall`, whose ``arguments`` is the raw ``ARGS`` text.

    Args:
        tokenizer: The tokenizer to use.
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)

        # A complete analysis (chain-of-thought) message, including its header.
        self.cot_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>analysis<\|message\|>.*?<\|end\|>", regex.DOTALL
        )
        # An analysis message without the `<|start|>assistant` header. Presumably this is the
        # response's first message, whose header sits in the prompt (TODO confirm).
        self.partial_cot_pattern = regex.compile(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", regex.DOTALL)
        # Group 1: function name after `to=functions.`; group 2: raw JSON argument payload.
        self.tool_call_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>[^<]* to=functions\.([^<]+) "
            r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>",
            regex.DOTALL,
        )

    @rollout_trace_op
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: Optional[list[OpenAIFunctionToolSchema]] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Extract gpt-oss tool calls from the response.

        Args:
            responses_ids: Token ids of the response.
            tools: Not used by this parser; accepted for interface compatibility.

        Returns:
            ``(content, tool_calls)``. ``content`` is the decoded text with pad tokens,
            analysis messages and tool-call segments removed. When no tool call is found,
            the text is returned with only pad tokens and analysis messages removed,
            together with an empty list.
        """
        loop = get_event_loop()
        # Keep special tokens: the `<|...|>` markers matched below must survive decoding.
        text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))

        # `pad_token` may be None (tokenizer without a pad token). `str.replace(None, "")`
        # raises TypeError, so only strip it when it is set.
        pad_token = self.tokenizer.pad_token
        if pad_token:
            text = text.replace(pad_token, "")

        text = self.cot_pattern.sub("", text)
        text = self.partial_cot_pattern.sub("", text)

        matches = self.tool_call_pattern.findall(text)
        if not matches:
            return text, []

        function_calls = []
        # With two capture groups, findall yields (name, arguments) tuples.
        for name, arguments in matches:
            try:
                # Arguments are passed through verbatim; they are not re-serialized or JSON-validated.
                function_calls.append(FunctionCall(name=name, arguments=arguments))
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")

        content = self.tool_call_pattern.sub("", text)
        return content, function_calls
|
|
|
|
@ToolParser.register("qwen3_coder")
class Qwen3XMLToolParser(ToolParser):
    """
    Tool parser for qwen3_coder/qwen3.5 model.
    Adapted from https://huggingface.co/Qwen/Qwen3-Coder-30B-A3B-Instruct/blob/main/qwen3coder_tool_parser.py

    Parses XML-style tool calls of the form::

        <tool_call>
        <function=NAME>
        <parameter=ARG>
        VALUE
        </parameter>
        </function>
        </tool_call>

    Parameter values are coerced to the types declared in the matching tool schema.

    Args:
        tokenizer: The tokenizer to use.
    """

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="

        self.tool_call_complete_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
        # In each pattern below, the second alternative (`...$`) also accepts an unterminated
        # trailing element, so output that stops mid-call (e.g. truncated) is still parsed.
        self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", regex.DOTALL)
        self.tool_call_function_regex = regex.compile(r"<function=(.*?)</function>|<function=(.*)$", regex.DOTALL)
        self.tool_call_parameter_regex = regex.compile(r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", regex.DOTALL)

    def _parse_xml_function_call(
        self, function_call_str: str, tools: Optional[list[OpenAIFunctionToolSchema]]
    ) -> FunctionCall:
        """Parse the body of one ``<function=...>`` element into a :class:`FunctionCall`.

        Args:
            function_call_str: Text following ``<function=``: the function name, ``>``,
                then zero or more ``<parameter=NAME>VALUE</parameter>`` elements.
            tools: Tool schemas used to coerce parameter values to their declared types.
                ``None`` means no schema is available, so values stay strings (apart from
                a literal ``null``).

        Returns:
            A FunctionCall whose ``arguments`` is the JSON-encoded dict of parsed parameters.

        Raises:
            ValueError: If the function name or a parameter name is not terminated by ``>``.
        """

        def get_arguments_config(func_name: str) -> dict:
            """Return ``{param_name: property_schema_dict}`` for ``func_name`` (``{}`` if unavailable)."""
            if tools is None:
                # `extract_tool_calls` defaults `tools` to None. Without this guard, iterating
                # None raised TypeError and every tool call in the response was dropped.
                return {}
            for config in tools:
                if config.type == "function" and config.function.name == func_name:
                    properties = config.function.parameters.properties
                    return {k: v.model_dump() for k, v in properties.items()}
            logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
            return {}

        def convert_param_value(param_value: str, param_name: str, param_config: dict, func_name: str) -> Any:
            """Coerce a raw parameter string to the type declared in ``param_config``.

            Undeclared parameters, and values that fail to convert, are returned unchanged
            as strings. The exception is booleans, which degenerate to ``False``.
            """
            # A literal `null` (any case) maps to None regardless of the declared type.
            if param_value.lower() == "null":
                return None

            if param_name not in param_config:
                # Only warn when a schema exists; with no schema every parameter is unknown.
                if param_config != {}:
                    logger.warning(
                        f"Parsed parameter '{param_name}' is not defined in the tool "
                        f"parameters for tool '{func_name}', directly returning the string value."
                    )
                return param_value

            if isinstance(param_config[param_name], dict) and "type" in param_config[param_name]:
                param_type = str(param_config[param_name]["type"]).strip().lower()
            else:
                param_type = "string"

            if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
                return param_value
            elif (
                param_type.startswith("int")
                or param_type.startswith("uint")
                or param_type.startswith("long")
                or param_type.startswith("short")
                or param_type.startswith("unsigned")
            ):
                try:
                    param_value = int(param_value)
                except Exception:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                        f"'{func_name}', degenerating to string."
                    )
                return param_value
            elif param_type.startswith("num") or param_type.startswith("float"):
                try:
                    float_param_value = float(param_value)
                    # Integral floats (e.g. "3.0") are returned as int.
                    param_value = (
                        float_param_value if float_param_value - int(float_param_value) != 0 else int(float_param_value)
                    )
                except Exception:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                        f"'{func_name}', degenerating to string."
                    )
                return param_value
            elif param_type in ["boolean", "bool", "binary"]:
                param_value = param_value.lower()
                if param_value not in ["true", "false"]:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a "
                        f"boolean (`true` or `false`) in tool '{func_name}', degenerating to false."
                    )
                return param_value == "true"
            else:
                # Structured schemas (objects and arrays) are most naturally encoded as JSON,
                # so try that first (JSON `true`/`null` are not Python literals).
                if param_type in ("object", "array") or param_type.startswith(("dict", "list")):
                    try:
                        param_value = json.loads(param_value)
                        return param_value
                    except Exception:
                        logger.warning(
                            f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid "
                            f"JSON object in tool '{func_name}', will try other methods to parse it."
                        )
                # SECURITY: the value comes straight from model output and must never reach
                # `eval()`, which would allow arbitrary code execution. `ast.literal_eval`
                # accepts only Python literals (strings, numbers, tuples, lists, dicts, sets,
                # booleans, None), which covers the intended fallback, e.g. single-quoted
                # strings or `True`/`False`.
                try:
                    param_value = ast.literal_eval(param_value)
                except Exception:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted "
                        f"via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
                    )
                return param_value

        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = get_arguments_config(function_name)
        parameters = function_call_str[end_index + 1 :]
        param_dict = {}
        for match in self.tool_call_parameter_regex.findall(parameters):
            # Exactly one alternative of the pattern matched; take its non-empty group.
            match_text = match[0] if match[0] else match[1]
            idx = match_text.index(">")
            param_name = match_text[:idx]
            # Strip at most one separating newline on each side; inner whitespace is kept.
            param_value = match_text[idx + 1 :].removeprefix("\n").removesuffix("\n")
            param_dict[param_name] = convert_param_value(param_value, param_name, param_config, function_name)
        return FunctionCall(name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False))

    def _get_function_calls(self, model_output: str) -> list[str]:
        """Return the raw text of every ``<function=...>`` element in ``model_output``.

        Each returned string has the ``<function=`` prefix and any ``</function>`` suffix
        removed. If no ``<tool_call>`` wrapper is found, the whole output is searched.
        """
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [match[0] if match[0] else match[1] for match in matched_ranges]

        if not raw_tool_calls:
            raw_tool_calls = [model_output]

        raw_function_calls = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))

        return [match[0] if match[0] else match[1] for match in raw_function_calls]

    @rollout_trace_op
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: Optional[list[OpenAIFunctionToolSchema]] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Extract Qwen3 XML tool calls from the response.

        Args:
            responses_ids: Token ids of the response.
            tools: Tool schemas used to coerce parameter values. If None, values are not
                type-coerced.

        Returns:
            ``(content, tool_calls)``, where ``content`` is the text before the first
            ``<tool_call>``. If there is no ``<tool_call>`` or ``<function=...>`` element,
            or any call fails to parse, returns the full text and ``[]``.
        """
        loop = get_event_loop()
        text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
        if self.tool_call_start_token not in text:
            return text, []

        try:
            function_calls = self._get_function_calls(text)
            if not function_calls:
                return text, []

            tool_calls = [
                self._parse_xml_function_call(function_call_str, tools) for function_call_str in function_calls
            ]

            # Content is everything before the first tool-call marker.
            content_index = text.find(self.tool_call_start_token)
            content_index = content_index if content_index >= 0 else text.find(self.tool_call_prefix)
            content = text[:content_index]
            return content, tool_calls
        except Exception as e:
            logger.exception(f"Error in extracting tool call from response: {e}")
            return text, []
|
|