| import json |
| from collections.abc import Sequence |
| from typing import Any, Optional |
|
|
| from vllm.entrypoints.openai.chat_completion.protocol import ( |
| ChatCompletionRequest, |
| ) |
| from vllm.entrypoints.openai.engine.protocol import ( |
| DeltaFunctionCall, |
| DeltaMessage, |
| DeltaToolCall, |
| ExtractedToolCallInformation, |
| FunctionCall, |
| ToolCall, |
| ) |
| from vllm.tokenizers import TokenizerLike |
| from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager |
|
|
|
|
@ToolParserManager.register_module(["openpipe_llama_dual"])
class OpenPipeLlamaDualParser(ToolParser):
    """Parse either official Llama JSON calls or OpenPipe legacy markers.

    Two wire formats are supported:

    * "official": the whole model output is a single JSON object, either
      ``{"name": ..., "parameters": ...}`` or
      ``{"function": {"name": ..., "arguments": ...}}``.
    * "openpipe_legacy": one or more JSON objects, each wrapped in
      ``<|start_tool_call|>`` / ``<|end_tool_call|>`` markers.

    The active format is selected by
    ``request.chat_template_kwargs["template_variant"]`` when present,
    otherwise inferred from the presence of the legacy start marker.
    """

    # Markers delimiting each legacy-format tool call in the output text.
    LEGACY_START = "<|start_tool_call|>"
    LEGACY_END = "<|end_tool_call|>"
    # Recognized values of chat_template_kwargs["template_variant"].
    VARIANT_LEGACY = "openpipe_legacy"
    VARIANT_OFFICIAL = "official"

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        self.tokenizer = tokenizer

    def _get_template_variant(self, request: ChatCompletionRequest) -> Optional[str]:
        """Return the requested template variant, or ``None`` if unspecified.

        ``chat_template_kwargs`` may be a plain dict or an attribute-style
        object depending on how the request was built; both are handled.
        Non-string values are treated as unspecified.
        """
        kwargs = getattr(request, "chat_template_kwargs", None)
        if kwargs is None:
            return None
        if isinstance(kwargs, dict):
            value = kwargs.get("template_variant")
        else:
            value = getattr(kwargs, "template_variant", None)
        return value if isinstance(value, str) else None

    def _normalize_tool_call(self, payload: Any) -> Optional[dict[str, Any]]:
        """Normalize a parsed payload to ``{"name": ..., "arguments": ...}``.

        Accepts the flat ``{"name", "parameters"}`` shape and the nested
        ``{"function": {"name", "arguments"}}`` shape. Returns ``None``
        for anything else, including non-dict JSON values (``json.loads``
        can yield ints/strings/lists, which previously raised TypeError
        from the ``in`` checks).
        """
        if not isinstance(payload, dict):
            return None
        if "name" in payload and "parameters" in payload:
            return {
                "name": payload["name"],
                "arguments": payload["parameters"],
            }
        function = payload.get("function")
        if isinstance(function, dict) and "name" in function and "arguments" in function:
            return {
                "name": function["name"],
                "arguments": function["arguments"],
            }
        return None

    def _extract_legacy_tool_calls(
        self,
        text: str,
    ) -> list[dict[str, Any]]:
        """Extract every well-formed legacy tool call from *text*.

        Scans left-to-right for LEGACY_START/LEGACY_END pairs. A pair
        whose body is not valid JSON is skipped and the scan continues,
        so a single malformed call no longer aborts the loop and
        discards the valid calls around it (previously ``json.loads``
        raised straight out of the loop). An unterminated start marker
        ends the scan: its call is still streaming in.
        """
        tool_calls: list[dict[str, Any]] = []
        cursor = 0

        while True:
            start_index = text.find(self.LEGACY_START, cursor)
            if start_index == -1:
                break

            end_index = text.find(self.LEGACY_END, start_index)
            if end_index == -1:
                # Incomplete call — wait for more text.
                break

            body = text[start_index + len(self.LEGACY_START) : end_index].strip()
            cursor = end_index + len(self.LEGACY_END)
            try:
                payload = json.loads(body)
            except json.JSONDecodeError:
                # Malformed call body: skip it, keep scanning.
                continue
            normalized = self._normalize_tool_call(payload)
            if normalized:
                tool_calls.append(normalized)

        return tool_calls

    def _extract_official_tool_call(self, text: str) -> Optional[dict[str, Any]]:
        """Parse *text* as one complete official-format JSON tool call.

        Returns ``None`` when the text is not a complete JSON object or
        does not match a known tool-call shape. JSON errors are handled
        here instead of leaking to the callers' broad except blocks.
        """
        stripped = text.strip()
        if not (stripped.startswith("{") and stripped.endswith("}")):
            return None
        try:
            payload = json.loads(stripped)
        except json.JSONDecodeError:
            return None
        return self._normalize_tool_call(payload)

    @staticmethod
    def _serialize_arguments(arguments: Any) -> Any:
        """Render dict/list arguments as a JSON string; pass others through.

        Shared by the streaming and non-streaming builders, which
        previously duplicated this conditional inline.
        """
        if isinstance(arguments, (dict, list)):
            return json.dumps(arguments, ensure_ascii=False)
        return arguments

    def _build_delta_tool_call(
        self,
        tool_call: dict[str, Any],
        index: int = 0,
    ) -> DeltaMessage:
        """Wrap a normalized tool call in a streaming ``DeltaMessage``."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=index,
                    id=f"call_{tool_call['name']}",
                    type="function",
                    function=DeltaFunctionCall(
                        name=tool_call["name"],
                        arguments=self._serialize_arguments(tool_call["arguments"]),
                    ),
                )
            ]
        )

    def _build_tool_calls_response(
        self,
        tool_calls: list[dict[str, Any]],
    ) -> ExtractedToolCallInformation:
        """Build the non-streaming response for normalized tool calls."""
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=[
                ToolCall(
                    id=f"call_{index + 1}",
                    type="function",
                    function=FunctionCall(
                        name=tool_call["name"],
                        arguments=self._serialize_arguments(tool_call["arguments"]),
                    ),
                )
                for index, tool_call in enumerate(tool_calls)
            ],
            content=None,
        )

    def _looks_like_partial_official_json(self, text: str) -> bool:
        """Heuristic: *text* is the prefix of an official JSON tool call.

        True when the text opens a JSON object that has not closed yet
        and already mentions a tool-call key — used to suppress content
        deltas while an official call is still streaming in.
        """
        stripped = text.strip()
        if not stripped.startswith("{"):
            return False
        if stripped.endswith("}"):
            return False
        return '"name"' in stripped or '"parameters"' in stripped or '"function"' in stripped

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Emit the streaming delta for the latest chunk of model output.

        Legacy path: once a complete marker pair is present, emit the
        most recent tool call; while a start marker is unterminated,
        emit nothing. Official path: emit the tool call once the full
        JSON object has arrived; hold output while it looks like a
        partial official call. Otherwise pass the delta through as
        plain content.
        """
        variant = self._get_template_variant(request)

        try:
            if variant == self.VARIANT_LEGACY or self.LEGACY_START in current_text:
                if self.LEGACY_START in current_text:
                    if self.LEGACY_END in current_text:
                        tool_calls = self._extract_legacy_tool_calls(current_text)
                        if tool_calls:
                            return self._build_delta_tool_call(tool_calls[-1])
                    # Marker seen but no complete, well-formed call yet:
                    # suppress output until more text arrives.
                    return None
                return DeltaMessage(content=delta_text)

            official_tool_call = self._extract_official_tool_call(current_text)
            if official_tool_call:
                return self._build_delta_tool_call(official_tool_call)
            if variant == self.VARIANT_OFFICIAL and self._looks_like_partial_official_json(current_text):
                return None
        except Exception:
            # Defensive boundary: a parser bug must never break streaming;
            # degrade to passing the raw delta through as content.
            return DeltaMessage(content=delta_text)

        return DeltaMessage(content=delta_text)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) output.

        Tries the legacy marker format first when the variant or the
        markers indicate it, then falls back to the official single-JSON
        format. When neither yields a call, the whole output is returned
        as plain content.
        """
        variant = self._get_template_variant(request)

        try:
            if variant == self.VARIANT_LEGACY or self.LEGACY_START in model_output:
                tool_calls = self._extract_legacy_tool_calls(model_output)
                if tool_calls:
                    return self._build_tool_calls_response(tool_calls)

            official_tool_call = self._extract_official_tool_call(model_output)
            if official_tool_call:
                return self._build_tool_calls_response([official_tool_call])
        except Exception:
            # Defensive boundary: fall through to the plain-content
            # response rather than failing the request.
            pass

        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )
|
|