# File size: 10,363 Bytes
# bddf0b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# coding=utf-8
# Copyright 2025 Upstage AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import re
import string
import ast
import json
from collections.abc import Sequence
from typing import Union, Tuple, List, Optional

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    DeltaFunctionCall,
    DeltaToolCall,
    ExtractedToolCallInformation,
    ToolCall,
    FunctionCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser
)
from vllm.logger import init_logger

import pyjson5

class ToolCallID:
    """Wrapper around the string identifier attached to a tool call.

    A *valid* identifier is exactly ``_LENGTH`` characters long and uses only
    lowercase ASCII letters and digits. Validation is opt-in, so any string
    can be wrapped when ``validation`` is left at ``False``.
    """

    _LENGTH = 10
    # Single source of truth for the characters used by ``random()``.
    _ALPHABET = string.ascii_lowercase + string.digits

    def __init__(self, id_val: str, validation: bool = False):
        """Store ``id_val``; check its format only when ``validation`` is true.

        Raises:
            AssertionError: if ``validation`` is true and ``id_val`` does not
                have the expected length and alphabet.
        """
        self._id = id_val
        if validation:
            self._validate()

    @classmethod
    def random(cls, validation=False) -> 'ToolCallID':
        """Build an instance holding a freshly generated random identifier."""
        generated = ''.join(
            random.choice(cls._ALPHABET) for _ in range(cls._LENGTH)
        )
        return cls(generated, validation=validation)

    def _validate(self):
        """Raise ``AssertionError`` unless the stored identifier is well formed.

        The error is raised explicitly instead of through an ``assert``
        statement: ``assert`` is removed when Python runs with ``-O``, which
        would silently turn requested validation into a no-op. The exception
        type is kept so existing callers observe the same failure.
        """
        if len(self._id) != self._LENGTH:
            raise AssertionError(
                f"tool call id must be {self._LENGTH} characters long, "
                f"got {len(self._id)}"
            )
        if re.fullmatch(r'[a-z0-9]+', self._id) is None:
            raise AssertionError(
                "tool call id may only contain lowercase letters and digits: "
                f"{self._id!r}"
            )

    def to_string(self) -> str:
        """Return the wrapped identifier unchanged."""
        return self._id

    def __str__(self) -> str:
        return self.to_string()

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._id!r})"


# Module-level logger, created through vLLM's ``init_logger`` helper and named
# after this module.
logger = init_logger(__name__)


class SolarOpenToolParser(ToolParser):

    def extract_tool_calls(
            self,
            model_output: str,
            request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete (non-streaming) model output.

        Args:
            model_output: Full text produced by the model.
            request: The originating chat request; not consulted here.

        Returns:
            The parsed tool calls together with the leading assistant
            content. ``tools_called`` is true when at least one complete tool
            call was found, and empty content is reported as ``None``.
        """
        leading_content, parsed_calls = self._parse_text(model_output)
        return ExtractedToolCallInformation(
            tools_called=bool(parsed_calls),
            tool_calls=parsed_calls,
            content=leading_content or None,
        )

    def extract_tool_calls_streaming(
            self,
            previous_text: str,
            current_text: str,
            delta_text: str,
            previous_token_ids: Sequence[int],
            current_token_ids: Sequence[int],
            delta_token_ids: Sequence[int],
            request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        # 1) Emit plain content tokens immediately until content terminator
        # tags or tool_calls section begins. Be careful when tokenizer groups
        # multiple special tags into a single delta (e.g., "<|tool_calls|><|tool_call:begin|>").
        # Only emit as content if BOTH:
        #  - previous_text has not seen any special markers, and
        #  - delta_text does NOT contain any of those markers as a substring.
        if delta_text:
            # Do NOT emit content if we have already started any special section
            # including tool call tags. Content should only be emitted at the
            # very beginning before any markers show up.
            special_markers = (
                "<|flush|>",
                "<|end|>",
                "<|begin|>",
                "<|tool_calls|>",
                "<|tool_call:begin|>",
                "<|tool_call:name|>",
                "<|tool_call:args|>",
                "<|tool_call:end|>",
                "<|calls|>",
            )
            if not any(tag in previous_text for tag in special_markers):
                if not any(tag in delta_text for tag in special_markers):
                    return DeltaMessage(content=delta_text, tool_calls=[])

        tool_call_deltas: list[DeltaToolCall] = []

        # Helper lambdas to analyze current_text state
        def _completed_calls_count(txt: str) -> int:
            return len(self._parse_tool_calls(txt))

        # Detect if a new tool_call started streaming its args just now.
        if delta_text and "<|tool_call:args|>" in delta_text:
            # Extract id and name for the latest tool call block present so far.
            begin_tag = "<|tool_call:begin|>"
            name_tag = "<|tool_call:name|>"
            args_tag = "<|tool_call:args|>"

            latest_args = current_text.rfind(args_tag)
            latest_name = current_text.rfind(name_tag, 0, latest_args if latest_args != -1 else None)
            latest_begin = current_text.rfind(begin_tag, 0, latest_name if latest_name != -1 else None)
            if latest_begin != -1 and latest_name != -1 and latest_args != -1 and latest_begin < latest_name < latest_args:
                tool_id = current_text[latest_begin + len(begin_tag):latest_name]
                func_name = current_text[latest_name + len(name_tag):latest_args]
                # Index equals number of args tags seen before this delta
                index = previous_text.count(args_tag)
                tool_call_deltas.append(
                    DeltaToolCall(
                        id=tool_id,
                        type="function",
                        index=index,
                        function=DeltaFunctionCall(name=func_name, arguments=""),
                    )
                )

        # If we are inside args (after last args tag without end), stream arg chunk
        begin_tag = "<|tool_call:begin|>"
        args_tag = "<|tool_call:args|>"
        end_tag = "<|tool_call:end|>"
        last_args_pos = current_text.rfind(args_tag)
        last_end_pos = current_text.rfind(end_tag)
        if last_args_pos != -1 and (last_end_pos == -1 or last_args_pos > last_end_pos):
            # Currently within args for the latest tool call
            # Determine previous args text and current args text to compute delta
            prev_last_args = previous_text.rfind(args_tag)
            prev_last_end = previous_text.rfind(end_tag)
            if prev_last_args != -1 and (prev_last_end == -1 or prev_last_args > prev_last_end):
                # Already inside args previously: emit only the delta_text
                if delta_text and delta_text not in (begin_tag, args_tag, end_tag):
                    # Stream into the most recently started (but not yet ended) call
                    index = max(previous_text.count(args_tag) - 1, 0)
                    tool_call_deltas.append(
                        DeltaToolCall(
                            id=None,
                            type=None,
                            index=index,
                            function=DeltaFunctionCall(name=None, arguments=delta_text),
                        )
                    )

        if not tool_call_deltas:
            return None

        return DeltaMessage(content=None, tool_calls=tool_call_deltas)

    # --------------------
    # Internal helpers
    # --------------------
    def _parse_text(self, text: str) -> Tuple[Optional[str], List[ToolCall]]:
        """Parse the completed segments from the given text.

        Returns (content, tool_calls) where content is extracted as the leading
        text up to the first '<|flush|>' or '<|end|>' marker, and tool_calls is
        a list of fully parsed tool calls inside '<|tool_calls|> ... <|calls|>'.
        """
        content = self._parse_content(text)
        tool_calls = self._parse_tool_calls(text)
        return content, tool_calls

    def _parse_content(self, text: str) -> Optional[str]:
        """Extract assistant content from the text.

        Rule: take the leading content before the first '<|flush|>' or
        '<|end|>' marker. If neither marker exists, return None.
        """
        end_tags = ["<|flush|>", "<|end|>"]

        # Take leading content before the first end tag
        end_positions = [pos for tag in end_tags if (pos := text.find(tag)) != -1]
        if not end_positions:
            return None
        end = min(end_positions)
        # Trim only the extracted portion; tests expect exact substring
        return text[:end]

    def _parse_tool_call_args(self, text: str) -> str:
        try:
            # Try to parse as JSON
            args = json.loads(text)
        except json.JSONDecodeError:
            try:
                # Try to parse as JSON5
                args = pyjson5.decode(text)
            except pyjson5.Json5DecoderException:
                try:
                    # Try to parse as Python literal
                    args = ast.literal_eval(text)
                except Exception:
                    # Fallback: return the original string
                    args = text
        if not isinstance(args, str):
            # Always convert back to JSON string
            args = json.dumps(args)
        return args

    def _parse_tool_calls(self, text: str) -> List[ToolCall]:
        tool_calls: list[ToolCall] = []
        # Parse globally; wrapper '<|tool_calls|>' may or may not be present.
        section_start = 0
        # section ends at <|calls|> if present, else use end of text
        section_end = text.find("<|calls|>")
        if section_end == -1:
            section_end = len(text)
        i = section_start
        while True:
            begin_tag = "<|tool_call:begin|>"
            name_tag = "<|tool_call:name|>"
            args_tag = "<|tool_call:args|>"
            end_tag = "<|tool_call:end|>"

            b = text.find(begin_tag, i, section_end)
            if b == -1:
                break
            b += len(begin_tag)
            n = text.find(name_tag, b, section_end)
            if n == -1:
                break
            tool_id = text[b:n]
            n += len(name_tag)
            a = text.find(args_tag, n, section_end)
            if a == -1:
                break
            name = text[n:a]
            a += len(args_tag)
            e = text.find(end_tag, a, section_end)
            if e == -1:
                break
            args = text[a:e]
            tool_calls.append(
                ToolCall(
                    id=tool_id,
                    function=FunctionCall(name=name, arguments=self._parse_tool_call_args(args)),
                ))
            i = e + len(end_tag)

        return tool_calls