File size: 14,615 Bytes
1faccd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Optional

import regex
from pydantic import BaseModel

from verl.tools.schemas import OpenAIFunctionToolSchema
from verl.utils.ray_utils import get_event_loop
from verl.utils.rollout_trace import rollout_trace_op

# Module-level logger, named after this file's path (``__file__``).
# Its level is read from the VERL_LOGGING_LEVEL environment variable and falls back to "WARN".
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class FunctionCall(BaseModel):
    """One function (tool) call extracted by a parser: a function name plus its arguments as a string."""

    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""


class ToolParser(ABC):
    """Abstract base class for tool-call parsers, with a name -> parser-class registry.

    Concrete parsers register themselves with ``@ToolParser.register("<name>")`` and are
    instantiated through ``ToolParser.get_tool_parser("<name>", tokenizer)``.
    """

    # Shared registry mapping a parser name to the parser class registered under it.
    _registry: dict[str, type["ToolParser"]] = {}

    def __init__(self, tokenizer) -> None:
        # Kept so that subclasses can decode response token ids back into text.
        self.tokenizer = tokenizer

    @abstractmethod
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Extract tool calls from the responses.

        Args:
            responses_ids (List[int]): The ids of the responses.
            tools (List[OpenAIFunctionToolSchema], optional): OpenAI function tool schema.

        Returns:
            Tuple[str, List[FunctionCall]]: Content and extracted tool calls.
        """
        raise NotImplementedError

    @classmethod
    def get_tool_parser(cls, name: str, tokenizer):
        """Instantiate the parser registered under ``name``.

        Args:
            name: Registry key given to ``register``.
            tokenizer: Passed straight to the parser's constructor.

        Returns:
            A new instance of the registered parser class.

        Raises:
            ValueError: If no parser is registered under ``name``.
        """
        parser_cls = cls._registry.get(name)
        if parser_cls is None:
            raise ValueError(f"Unknown tool parser: {name}")
        return parser_cls(tokenizer)

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that records the decorated class under ``name``.

        The decorated class is returned unchanged; registering the same name again
        overwrites the previous entry.
        """

        def _bind(subclass: type[ToolParser]) -> type[ToolParser]:
            cls._registry[name] = subclass
            return subclass

        return _bind


@ToolParser.register("hermes")
class HermesToolParser(ToolParser):
    """Parser for ``<tool_call>{json}</tool_call>`` blocks whose JSON carries ``name`` and ``arguments``.

    Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)

        # Non-greedy capture of everything between one opening and one closing tag;
        # DOTALL lets the JSON payload span multiple lines.
        self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

    @rollout_trace_op
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Decode the response ids and pull out every well-formed tool call.

        Args:
            responses_ids: Token ids of the model response.
            tools: Not used by this parser.

        Returns:
            ``(content, function_calls)``. When the decoded text lacks either the opening
            or the closing tag, it is returned unchanged with an empty list. Otherwise
            ``content`` is the decoded text with all complete tool-call blocks removed, and
            ``function_calls`` holds one entry per block that parsed successfully.
        """
        # Tokenizer decoding is run in the loop's default executor.
        decoded = await get_event_loop().run_in_executor(None, self.tokenizer.decode, responses_ids)

        has_open_tag = self.tool_call_start_token in decoded
        has_close_tag = self.tool_call_end_token in decoded
        if not (has_open_tag and has_close_tag):
            return decoded, []

        parsed_calls = []
        for payload in self.tool_call_regex.findall(decoded):
            # A malformed block (bad JSON, missing key, ...) is logged and skipped so the
            # other blocks in the same response are still returned.
            try:
                spec = json.loads(payload)
                call_name = spec["name"]
                call_args = spec["arguments"]
                parsed_calls.append(FunctionCall(name=call_name, arguments=json.dumps(call_args, ensure_ascii=False)))
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")

        # Remaining text, with the complete tool-call blocks stripped out.
        remaining = self.tool_call_regex.sub("", decoded)
        return remaining, parsed_calls


@ToolParser.register("gpt-oss")
class GptOssToolParser(ToolParser):
    """
    Tool parser for gpt-oss model.
    Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/gpt_oss_detector.py

    Args:
        tokenizer: The tokenizer to use.
    """

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)
        # check https://cookbook.openai.com/articles/openai-harmony for more details.
        # A complete chain-of-thought (analysis channel) message, including its start marker.
        self.cot_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>analysis<\|message\|>.*?<\|end\|>", regex.DOTALL
        )
        # <|start|>assistant may be pre-appended in prompts, so we need to remove it.
        self.partial_cot_pattern = regex.compile(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", regex.DOTALL)
        # Group 1 is the function name after "to=functions.", group 2 is the raw argument payload.
        self.tool_call_pattern = regex.compile(
            r"<\|start\|>assistant<\|channel\|>[^<]* to=functions\.([^<]+) "
            r"<\|constrain\|>json<\|message\|>(.*?)<\|call\|>",
            regex.DOTALL,
        )

    @rollout_trace_op
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Decode the response ids and extract the tool calls they contain.

        Args:
            responses_ids: Token ids of the model response.
            tools: Not used by this parser.

        Returns:
            ``(content, function_calls)``. ``content`` is the decoded text with padding
            tokens, chain-of-thought segments and (if any were found) tool-call segments
            removed. Each ``FunctionCall.arguments`` is the raw payload string; it is not
            validated as JSON here.
        """
        loop = get_event_loop()
        # We need to keep special tokens for gpt-oss model for better tool call extraction.
        text = await loop.run_in_executor(None, lambda: self.tokenizer.decode(responses_ids, skip_special_tokens=False))

        # Need to remove padding tokens for better tool call extraction.
        # A tokenizer without a pad token reports None here, and str.replace(None, "")
        # would raise TypeError, so only strip when a pad token actually exists.
        pad_token = self.tokenizer.pad_token
        if pad_token:
            text = text.replace(pad_token, "")

        # Need to remove COT since COT may contain tool call tokens, but they are not valid tool calls.
        text = self.cot_pattern.sub("", text)
        text = self.partial_cot_pattern.sub("", text)

        # Check whether there are any tool calls in the text.
        matches = self.tool_call_pattern.findall(text)
        if not matches:
            return text, []

        function_calls = []
        for name, arguments in matches:
            try:
                # don't check if arguments is valid JSON and leave it to client
                function_calls.append(FunctionCall(name=name, arguments=arguments))
            except Exception as e:
                logger.error(f"Failed to decode tool call: {e}")

        # Remaining text, excluding the tool call segments.
        content = self.tool_call_pattern.sub("", text)

        return content, function_calls


@ToolParser.register("qwen3_coder")
class Qwen3XMLToolParser(ToolParser):
    """
    Tool parser for qwen3_coder/qwen3.5 model.
    Adapted from https://huggingface.co/Qwen/Qwen3-Coder-30B-A3B-Instruct/blob/main/qwen3coder_tool_parser.py

    The parser looks for ``<tool_call>`` blocks containing ``<function=NAME>`` blocks, which
    in turn contain ``<parameter=KEY>VALUE</parameter>`` entries. Every pattern also accepts
    a missing closing tag, in which case the match runs to the end of the text.

    Args:
        tokenizer: The tokenizer to use.
    """

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="

        self.tool_call_complete_regex = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
        # Each of the following has two alternatives: a properly closed block (group 1)
        # or an unclosed block running to the end of the text (group 2).
        self.tool_call_regex = regex.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", regex.DOTALL)
        self.tool_call_function_regex = regex.compile(r"<function=(.*?)</function>|<function=(.*)$", regex.DOTALL)
        self.tool_call_parameter_regex = regex.compile(r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", regex.DOTALL)

    def _parse_xml_function_call(
        self, function_call_str: str, tools: Optional[list[OpenAIFunctionToolSchema]]
    ) -> FunctionCall:
        """Convert the body of one ``<function=...>`` block into a ``FunctionCall``.

        Args:
            function_call_str: The text captured after ``<function=``, i.e.
                ``NAME>`` followed by the parameter entries.
            tools: Tool schemas used to look up the declared type of each parameter.
                May be ``None``; then no type information is available and every value
                other than ``null`` is kept as a string.

        Returns:
            A ``FunctionCall`` whose ``arguments`` is the JSON dump of the converted parameters.

        Raises:
            ValueError: If the function name or a parameter name is not terminated by ``>``.
        """

        def get_arguments_config(func_name: str) -> dict:
            # No schemas supplied (the public entry point defaults ``tools`` to None):
            # there is nothing to look up, so fall back to untyped string values
            # instead of failing on iteration.
            if tools is None:
                return {}
            for config in tools:
                if config.type == "function" and config.function.name == func_name:
                    properties = config.function.parameters.properties
                    return {k: v.model_dump() for k, v in properties.items()}
            logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
            return {}

        def convert_param_value(param_value: str, param_name: str, param_config: dict, func_name: str) -> Any:
            # Handle null value for any type
            if param_value.lower() == "null":
                return None

            if param_name not in param_config:
                # Only warn when a schema exists but does not mention this parameter.
                if param_config != {}:
                    logger.warning(
                        f"Parsed parameter '{param_name}' is not defined in the tool "
                        f"parameters for tool '{func_name}', directly returning the string value."
                    )
                return param_value

            if isinstance(param_config[param_name], dict) and "type" in param_config[param_name]:
                param_type = str(param_config[param_name]["type"]).strip().lower()
            else:
                param_type = "string"
            if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
                return param_value
            elif (
                param_type.startswith("int")
                or param_type.startswith("uint")
                or param_type.startswith("long")
                or param_type.startswith("short")
                or param_type.startswith("unsigned")
            ):
                try:
                    param_value = int(param_value)
                except Exception:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                        f"'{func_name}', degenerating to string."
                    )
                return param_value
            elif param_type.startswith("num") or param_type.startswith("float"):
                try:
                    float_param_value = float(param_value)
                    # Whole-number floats are emitted as ints (e.g. "3.0" -> 3).
                    param_value = (
                        float_param_value if float_param_value - int(float_param_value) != 0 else int(float_param_value)
                    )
                except Exception:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                        f"'{func_name}', degenerating to string."
                    )
                return param_value
            elif param_type in ["boolean", "bool", "binary"]:
                param_value = param_value.lower()
                if param_value not in ["true", "false"]:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' is not a "
                        f"boolean (`true` of `false`) in tool '{func_name}', degenerating to false."
                    )
                return param_value == "true"
            else:
                if param_type == "object" or param_type.startswith("dict"):
                    try:
                        param_value = json.loads(param_value)
                        return param_value
                    except Exception:
                        logger.warning(
                            f"Parsed value '{param_value}' of parameter '{param_name}' is not a valid "
                            f"JSON object in tool '{func_name}', will try other methods to parse it."
                        )
                try:
                    # SECURITY: param_value is model-generated text. The previous implementation
                    # passed it to eval(), which executes arbitrary Python. ast.literal_eval only
                    # accepts literals (numbers, strings, lists, dicts, tuples, sets, booleans,
                    # None); anything else falls through to the string fallback below.
                    param_value = ast.literal_eval(param_value)
                except Exception:
                    logger.warning(
                        f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted "
                        f"via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
                    )
                return param_value

        # Extract function name
        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = get_arguments_config(function_name)
        parameters = function_call_str[end_index + 1 :]
        param_dict = {}
        for match in self.tool_call_parameter_regex.findall(parameters):
            # Group 0 is a closed <parameter=...> entry, group 1 an unclosed one.
            match_text = match[0] if match[0] else match[1]
            idx = match_text.index(">")
            param_name = match_text[:idx]
            param_value = str(match_text[idx + 1 :])
            # Remove prefix and trailing \n
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]

            param_dict[param_name] = convert_param_value(param_value, param_name, param_config, function_name)
        return FunctionCall(name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False))

    def _get_function_calls(self, model_output: str) -> list[str]:
        """Return the raw text of every ``<function=...>`` block found in ``model_output``.

        Each returned string starts right after ``<function=`` and excludes the closing
        ``</function>`` tag when one is present.
        """
        # Find all tool calls
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [match[0] if match[0] else match[1] for match in matched_ranges]

        # Back-off strategy if no tool_call tags found
        if not raw_tool_calls:
            raw_tool_calls = [model_output]

        raw_function_calls = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))

        function_calls = [match[0] if match[0] else match[1] for match in raw_function_calls]
        return function_calls

    @rollout_trace_op
    async def extract_tool_calls(
        self, responses_ids: list[int], tools: list[OpenAIFunctionToolSchema] = None
    ) -> tuple[str, list[FunctionCall]]:
        """Decode the response ids and extract the XML-style tool calls they contain.

        Args:
            responses_ids: Token ids of the model response.
            tools: Tool schemas used to convert parameter values to their declared types.
                When omitted, values are returned as strings (``null`` becomes ``None``).

        Returns:
            ``(content, tool_calls)``. If tool calls are found, ``content`` is the decoded
            text preceding the first ``<tool_call>`` tag. If none are found, or if parsing
            raises, the full decoded text is returned with an empty list.
        """
        loop = get_event_loop()
        text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
        if self.tool_call_start_token not in text:
            return text, []

        try:
            function_calls = self._get_function_calls(text)
            if not function_calls:
                return text, []

            tool_calls = [
                self._parse_xml_function_call(function_call_str, tools) for function_call_str in function_calls
            ]

            # Extract content before tool calls
            content_index = text.find(self.tool_call_start_token)
            content_index = content_index if content_index >= 0 else text.find(self.tool_call_prefix)
            content = text[:content_index]  # .rstrip()

            return content, tool_calls
        except Exception as e:
            # Best effort: a parsing failure yields no tool calls rather than an exception.
            logger.exception(f"Error in extracting tool call from response: {e}")
            return text, []