File size: 25,562 Bytes
fb9bda4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 | # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import json
from collections.abc import Sequence
from enum import Enum
from typing import Any, Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
class ParsedStructure(Enum):
CONTENT = 1
REASONING_CONTENT = 2
TOOL_CALL = 3
TOOL_CALL_DELIMITER = 4
TOOL_CALL_START_TAG = 5
TOOL_CALL_END_TAG = 6
@ToolParserManager.register_module("tng_r1t2")
class TngR1T2ToolParser(ToolParser):
"""Tool Parser for models like tngtech/DeepSeek-TNG-R1T2-Chimera,
It is compatible with hermes tool call templates
but does not require <tool_call> and </tool_call>
to be single tokens in the vocabulary;
instead only the string representation of the model output
is parsed, making this tool call parser robust and versatile."""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
# For backward compatibility with serving code
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
self.think_tag_pattern = r"(<think>[\s\S]*?</think>)"
self.think_start_tag = "<think>"
self.think_end_tag = "</think>"
self.tool_call_tag_pattern = r"<tool_call>([\s\S]*?)</tool_call>"
self.tool_call_start_tag = "<tool_call>"
self.tool_call_end_tag = "</tool_call>"
# Define streaming state type to be initialized later
self.streaming_state: dict[str, Any] = {
"streamed_tool_calls": [],
"buffer": "",
"parsed_structure": ParsedStructure.CONTENT,
}
def extract_tool_call_from_nonthink_output(
self, raw_text: str) -> tuple[str, list[dict]]:
parts = re.split(self.tool_call_tag_pattern, raw_text)
content = ""
tool_calls: list[dict] = []
for i, part in enumerate(parts):
is_potential_tool_call = i % 2 == 1
if is_potential_tool_call:
try:
more_tool_calls = json.loads(part)
if isinstance(more_tool_calls, list):
tool_calls.extend(more_tool_calls)
else:
tool_calls.extend([more_tool_calls])
except json.JSONDecodeError:
logger.warning("Invalid tool call json "
"-> parse as text content")
content += part
continue
else:
content += part
return content, tool_calls
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract tool calls from a complete model output.
"""
# split at think traces -> those will not be parsed for tool calls
think_parts = re.split(self.think_tag_pattern, model_output)
content = ""
tool_calls = []
for i, part in enumerate(think_parts):
parse_output = i % 2 == 0
if parse_output:
more_content, more_tool_calls = (
self.extract_tool_call_from_nonthink_output(part))
content += more_content
tool_calls += more_tool_calls
else:
content += part
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=content,
)
tool_call_objs: list[ToolCall] = []
for idx, call in enumerate(tool_calls):
if (not isinstance(call, dict) or "name" not in call
or "arguments" not in call):
logger.warning("Invalid tool call format, ignore.")
continue
tool_call = ToolCall(
id=f"call_{idx}_{random_uuid()}",
type="function",
function=FunctionCall(
name=call["name"],
arguments=(json.dumps(call["arguments"]) if isinstance(
call["arguments"], dict) else call["arguments"]),
),
)
tool_call_objs.append(tool_call)
return ExtractedToolCallInformation(
tools_called=len(tool_call_objs) > 0,
tool_calls=tool_call_objs,
content=content,
)
def _parse_think_trace(self, raw_text: str) -> tuple[str, bool, str]:
"""
Returns: (unambiguous_text_content, found_think_end, rest_string)
"""
# Either a complete think_end_tag can be somewhere in raw_text
think_end_pos = raw_text.find(self.think_end_tag)
if think_end_pos >= 0:
# in contrast to tool_call_start_tags, </think> remains part of content
think_end_tag_end_pos = think_end_pos + len(self.think_end_tag)
return (raw_text[:think_end_tag_end_pos], True,
raw_text[think_end_tag_end_pos:])
# or the end of raw_text can be continued to a complete think_end_tag
think_end_pos = (
len(raw_text) -
self._ends_with_partial_token(raw_text, self.think_end_tag))
return raw_text[:think_end_pos], False, raw_text[think_end_pos:]
def _parse_unambiguous_text_content(
self, raw_text: str) -> tuple[str, Union[str, None], str]:
"""
Returns: (unambiguous_text_content, interrupting_tag, rest_string)
"""
# Either a complete tool_call_start_tag or think_start can be somewhere in raw_text
search_tags = [self.think_start_tag, self.tool_call_start_tag]
tag_positions = [(tag, pos) for tag in search_tags
if (pos := raw_text.find(tag)) >= 0]
tag_positions.sort(key=lambda tag_and_pos: tag_and_pos[1])
if len(tag_positions) > 0:
first_tag, tag_pos = tag_positions[0]
return raw_text[:tag_pos], first_tag, raw_text[tag_pos:]
# or the end of raw_text can be continued to a complete tag
tag_positions = [
(tag, len(raw_text) - self._ends_with_partial_token(raw_text, tag))
for tag in search_tags
]
tag_positions.sort(key=lambda tag_and_pos: tag_and_pos[1])
first_tag, tag_pos = tag_positions[0]
if tag_pos < len(raw_text):
return raw_text[:tag_pos], None, raw_text[tag_pos:]
return raw_text, None, ""
def _parse_tool_call_start_tag(self, raw_text: str) -> tuple[bool, str]:
"""
Removes tool_call_start_tag from the beginning of raw_text,
and an optional "[", and leading whitespace.
Returns: (found_complete_tool_call_start_tag, rest_string)
"""
if not raw_text.startswith(self.tool_call_start_tag):
return False, raw_text
rest = raw_text[len(self.tool_call_start_tag):].lstrip()
if rest.startswith("["):
rest = rest[1:].lstrip()
return True, rest
def _parse_tool_call_end_tag(
self, raw_text: str) -> tuple[Union[bool, None], str]:
"""
Removes tool_call_end_tag from the beginning of raw_text,
and an optional "]" before it, and leading whitespace.
Returns: tuple
found_complete_tool_call_end_tag (or None if not decidable yet)
rest_string
"""
# remove optional whitespace and closing ] bracket from json list notation
rest = raw_text.lstrip()
if rest.startswith("]"):
rest = rest[1:].lstrip()
if rest.startswith(self.tool_call_end_tag):
# found a complete tool call end tag
return True, rest[len(self.tool_call_end_tag):]
if (len(rest) >= len(self.tool_call_end_tag)
or rest != self.tool_call_end_tag[:len(rest)]):
# evidence that rest_string does not start with a tool call end tag
return False, raw_text
# incomplete tool call end tag, can not be decided yet
return None, raw_text
def _extract_arguments_from_partial_tool_call(
self, raw_text: str) -> Union[str, None]:
"""
Extracts the raw text of the "arguments" field of a complete
or partial tool call.
Args:
raw_text: tool call raw text,
e.g `{"name": "my_tool", "arguments": {"firstarg": "some`
Returns:
raw text of the "arguments" field, which is not valid JSON
unless the tool call is complete,
e.g. `{"firstarg": "some` for the example raw_text above
"""
# assumptions:
# - "arguments" is always an object
# - there is no other field of type object in the function call
# - `raw_text` contains first "name", then "arguments" (otherwise,
# we'd have to find the end of "arguments" before returning its
# raw text value)
# typically, at position 0, but there might be leading whitespace
tool_call_start_pos = raw_text.find("{")
assert raw_text[:tool_call_start_pos].strip() == ""
arguments_start_pos = raw_text.find("{", tool_call_start_pos + 1)
if arguments_start_pos < 0:
return None
arguments_raw_text = raw_text[arguments_start_pos:]
return arguments_raw_text
def _parse_complete_tool_call(
self, raw_text: str) -> tuple[Union[dict, None], str]:
"""
Returns: tuple
parsed tool call if complete, None otherwise
rest_string that needs to be parsed again or may contain
a partial tool call
"""
# raw_text must start without whitespace for correct parsing
obj, end_pos = self.extract_complete_json_dict(raw_text)
if obj is None:
return None, raw_text
tool_call_raw_text = raw_text[:end_pos]
# `tool_call_raw_text` is something like:
# '{"name": "tool-name", "arguments": {...xyz...} }'
# we want to extract `{...xyz...}`,
# but `extract_arguments_from_partial_tool_call` would return
# everything after the second '{', i.e. `{...xyz...} }`
arguments_raw_text = self._extract_arguments_from_partial_tool_call(
tool_call_raw_text.removesuffix("}").rstrip())
tool_call = {
"name": obj.get("name"),
"arguments": obj.get("arguments"),
"arguments_raw_text": arguments_raw_text,
"is_complete": True,
}
return tool_call, raw_text[end_pos:]
def _parse_partial_tool_call(self, raw_text: str) -> Union[dict, None]:
# raw_text must start without whitespace for correct parsing
obj = partial_json_parser.loads(raw_text, Allow.ALL)
arguments_raw_text = (
self._extract_arguments_from_partial_tool_call(raw_text))
tool_call = {
"name": obj.get("name"),
"arguments": obj.get("arguments"),
"arguments_raw_text": arguments_raw_text,
"is_complete": False,
}
return tool_call
def _parse_tool_call(
self,
raw_text: str) -> tuple[Union[bool, None], Union[dict, None], str]:
# remove optional whitespace and closing ] bracket
# from json list notation
rest = raw_text.lstrip()
if rest == "":
# no json has been received yet
# -> can't tell if this will be a valid tool call
return None, None, raw_text
if not rest.startswith("{"):
# can't be a tool call json
return False, None, raw_text
tool_call, rest = self._parse_complete_tool_call(rest)
if tool_call:
return True, tool_call, rest
try:
tool_call = self._parse_partial_tool_call(rest)
# need to re-parse partial tool call later again
# -> return None, not True
return None, tool_call, rest
except json.JSONDecodeError:
# invalid json -> neither complete nor partial tool call
return False, None, rest
def _parse_tool_call_delimiter(
self, raw_text: str) -> tuple[Union[bool, None], str]:
"""
Returns: tuple
does raw_text start with tool call delimiter?
(None if undecidable/incomplete)
rest_string
"""
rest = raw_text.lstrip()
if rest == "":
return None, raw_text
has_next_tool_call = rest.startswith(",")
if not has_next_tool_call:
return False, raw_text
rest = rest[1:].lstrip()
if rest == "":
return None, raw_text
has_next_tool_call = rest.startswith("{")
if not has_next_tool_call:
return False, raw_text
return True, rest
def _parse_all(
self, raw_text: str, start_mode: ParsedStructure
) -> tuple[str, list[dict], str, ParsedStructure]:
if start_mode == ParsedStructure.REASONING_CONTENT:
content, found_closing_think, rest = self._parse_think_trace(
raw_text)
if found_closing_think:
more_content, tool_calls, rest, structure = self._parse_all(
rest, start_mode=ParsedStructure.CONTENT)
return content + more_content, tool_calls, rest, structure
return content, [], rest, ParsedStructure.REASONING_CONTENT
elif start_mode == ParsedStructure.CONTENT:
content, interrupting_tag, rest = self._parse_unambiguous_text_content(
raw_text)
# rest might contain a tool call start tag or a think start tag
if interrupting_tag == self.tool_call_start_tag:
more_content, tool_calls, rest, structure = self._parse_all(
rest, start_mode=ParsedStructure.TOOL_CALL_START_TAG)
return content + more_content, tool_calls, rest, structure
elif interrupting_tag == self.think_start_tag:
more_content, tool_calls, rest, structure = self._parse_all(
rest, start_mode=ParsedStructure.REASONING_CONTENT)
return content + more_content, tool_calls, rest, structure
else:
return content, [], rest, ParsedStructure.CONTENT
elif start_mode == ParsedStructure.TOOL_CALL_START_TAG:
found_tool_call_start_tag, rest = self._parse_tool_call_start_tag(
raw_text)
if not found_tool_call_start_tag:
return "", [], raw_text, ParsedStructure.CONTENT
# we found a complete start tag, but we haven't seen the begin of a tool call json yet
content, tool_calls, rest, structure = self._parse_all(
rest, start_mode=ParsedStructure.TOOL_CALL)
if not content and not tool_calls:
# We haven't reached the opening "{" of the tool call yet.
# We might see a "[" before the "{", so let's process the start tag again next chunk.
return content, [], raw_text, ParsedStructure.CONTENT
return content, tool_calls, rest, structure
elif start_mode == ParsedStructure.TOOL_CALL:
found_tool_call, tool_call, rest = self._parse_tool_call(raw_text)
if found_tool_call is True:
tool_calls = [tool_call] if tool_call else []
content, more_tool_calls, rest, structure = self._parse_all(
rest, start_mode=ParsedStructure.TOOL_CALL_DELIMITER)
return (content, tool_calls + more_tool_calls, rest, structure)
elif found_tool_call is None:
# partial tool call -> need to parse again with next chunk
tool_calls = ([tool_call] if tool_call is not None else [])
return "", tool_calls, rest, ParsedStructure.TOOL_CALL
else:
logger.warning(
"Invalid tool call -> continue with parsing model output as text content"
)
return self._parse_all(raw_text,
start_mode=ParsedStructure.CONTENT)
elif start_mode == ParsedStructure.TOOL_CALL_DELIMITER:
found_tool_call_delimiter, rest = self._parse_tool_call_delimiter(
raw_text)
if found_tool_call_delimiter is True:
return self._parse_all(rest,
start_mode=ParsedStructure.TOOL_CALL)
elif found_tool_call_delimiter is None:
# could neither confirm nor deny that raw_text starts with a tool call delimiter
return "", [], rest, ParsedStructure.TOOL_CALL_DELIMITER
else:
return self._parse_all(
raw_text, start_mode=ParsedStructure.TOOL_CALL_END_TAG)
elif start_mode == ParsedStructure.TOOL_CALL_END_TAG:
found_tool_call_end_tag, rest = self._parse_tool_call_end_tag(
raw_text)
if found_tool_call_end_tag is True:
return self._parse_all(rest,
start_mode=ParsedStructure.CONTENT)
elif found_tool_call_end_tag is None:
return "", [], rest, ParsedStructure.TOOL_CALL_END_TAG
else:
return self._parse_all(raw_text,
start_mode=ParsedStructure.CONTENT)
logger.warning(
f"Unknown tool call parser start_mode '{start_mode}'. Falling back to text content."
)
return self._parse_all(raw_text, start_mode=ParsedStructure.CONTENT)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Extract tool calls for streaming mode.
"""
raw_text = self.streaming_state["buffer"] + delta_text
structure = self.streaming_state["parsed_structure"]
content, tool_calls, rest, new_structure = self._parse_all(
raw_text, start_mode=structure)
self.streaming_state["buffer"] = rest
self.streaming_state["parsed_structure"] = new_structure
already_streamed_tool_calls = (
self.streaming_state["streamed_tool_calls"])
already_streamed_complete_tool_calls = [
tool_call for tool_call in already_streamed_tool_calls
if tool_call["is_complete"]
]
all_tool_calls = (already_streamed_complete_tool_calls +
(tool_calls or []))
to_be_streamed_tool_calls = self._calculate_delta_tool_calls(
all_tool_calls, already_streamed_tool_calls)
if not content and not to_be_streamed_tool_calls:
return None
self.update_state_vars(all_tool_calls)
return DeltaMessage(content=content if content else None,
tool_calls=to_be_streamed_tool_calls)
def _calculate_delta_tool_calls(
self, current_tool_calls: Union[list[dict], None],
already_streamed_tool_calls: list[dict]) -> list[DeltaToolCall]:
if not current_tool_calls:
return []
new_deltas = []
for tool_call_idx, partial_tool_call in enumerate(current_tool_calls):
if (partial_tool_call.get("name") is None
or partial_tool_call.get("arguments") is None):
# do not stream arguments for an unknown tool name;
# and unless arguments appear in the partial json,
# it might be that "name" has not been received completely
# (assuming a template like `{"name": "mytool", "arguments": ...}`)
continue
partial_tool_call["tool_call_idx"] = tool_call_idx
partial_tool_call["arguments_raw_text"] = (
partial_tool_call.get('arguments_raw_text') or "")
if len(already_streamed_tool_calls) > tool_call_idx:
# parts of this tool_call_idx have already been streamed
already_streamed_tool_call = already_streamed_tool_calls[
tool_call_idx]
delta_tool_call = self._delta_for_partial_tool_call(
partial_tool_call, already_streamed_tool_call)
if delta_tool_call is not None:
new_deltas.append(delta_tool_call)
already_streamed_tool_calls[tool_call_idx] = (
already_streamed_tool_call | partial_tool_call)
else:
# no parts of this tool_call_idx have been streamed yet
tool_call = self._delta_for_new_tool_call(partial_tool_call)
new_deltas.append(tool_call)
already_streamed_tool_calls.append(partial_tool_call)
return new_deltas
def _delta_for_new_tool_call(self, tool_call_dict: dict) -> DeltaToolCall:
"""constructs DeltaToolCall for new tool call,
with tool_call_id, name, and all arguments seen so far.
Updates tool_call dictionary with tool_call_id.
"""
tool_call_idx = tool_call_dict["tool_call_idx"]
tool_call_dict[
"tool_call_id"] = f"call_{tool_call_idx}_{random_uuid()}"
tool_call_dict["arguments_raw_text"] = tool_call_dict.get(
'arguments_raw_text') or ""
delta_tool_call = DeltaToolCall(
index=tool_call_idx,
type="function",
id=tool_call_dict["tool_call_id"],
function=DeltaFunctionCall(
name=tool_call_dict.get("name"),
arguments=tool_call_dict["arguments_raw_text"]))
return delta_tool_call
def _delta_for_partial_tool_call(
self, new_tool_call: dict,
already_streamed_tool_call: dict) -> Union[DeltaToolCall, None]:
"""Calculate delta for a tool call of which some parts have already been streamed."""
assert new_tool_call["name"] == already_streamed_tool_call["name"]
assert already_streamed_tool_call.get("tool_call_id")
if already_streamed_tool_call.get("is_complete"):
return None
to_be_streamed_arguments = (
new_tool_call["arguments_raw_text"].removeprefix(
already_streamed_tool_call["arguments_raw_text"]))
if not to_be_streamed_arguments:
return None
delta_tool_call = DeltaToolCall(
index=new_tool_call["tool_call_idx"],
type="function",
function=DeltaFunctionCall(arguments=to_be_streamed_arguments))
return delta_tool_call
def update_state_vars(self, all_tools: list[dict]) -> None:
# `tool_parser.streamed_args_for_tool` and
# `tool_parser.prev_tool_call_arr` are checked in serving_chat.py
# relevant is {"arguments": {...}}
self.prev_tool_call_arr = all_tools
# json-serialized argument
self.streamed_args_for_tool = [
tool_call.get("arguments_raw_text", "") for tool_call in all_tools
]
@classmethod
def _ends_with_partial_token(cls, buffer: str, tag: str) -> int:
"""
Check if buffer ends with a partial tag.
Return the length of the partial tag.
"""
for i in range(1, min(len(buffer) + 1, len(tag))):
if tag.startswith(buffer[-i:]):
return i
return 0
@classmethod
def extract_complete_json_dict(cls, json_str: str):
try:
decoder = json.JSONDecoder()
obj, end_pos = decoder.raw_decode(
json_str) # ignore any text after the end of the json object
if isinstance(obj, dict):
return obj, end_pos
return None, 0
except json.JSONDecodeError:
return None, 0
|