# openpipe_llama_dual.py
# Dual-format tool-call parser for Llama-3.1-8B-Instruct (uploaded by kovsbo).
import json
from collections.abc import Sequence
from typing import Any, Optional
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
@ToolParserManager.register_module(["openpipe_llama_dual"])
class OpenPipeLlamaDualParser(ToolParser):
    """Parse tool calls emitted in either of two formats.

    * **Official Llama format** — the entire model output is one JSON
      object such as ``{"name": ..., "parameters": ...}`` (or a wrapped
      ``{"function": {"name": ..., "arguments": ...}}``).
    * **OpenPipe legacy format** — one or more JSON payloads wrapped in
      ``<|start_tool_call|>`` / ``<|end_tool_call|>`` markers, possibly
      interleaved with plain content.

    The active format is chosen per request via the
    ``template_variant`` chat-template kwarg, with marker sniffing as a
    fallback.
    """

    # Sentinel markers wrapping each legacy OpenPipe tool-call payload.
    LEGACY_START = "<|start_tool_call|>"
    LEGACY_END = "<|end_tool_call|>"

    # Recognized values of the ``template_variant`` chat-template kwarg.
    VARIANT_LEGACY = "openpipe_legacy"
    VARIANT_OFFICIAL = "official"

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        self.tokenizer = tokenizer

    def _get_template_variant(self, request: ChatCompletionRequest) -> Optional[str]:
        """Return the request's ``template_variant`` kwarg, or ``None``.

        Accepts either a dict-shaped or attribute-shaped
        ``chat_template_kwargs``; non-string values are ignored.
        """
        kwargs = getattr(request, "chat_template_kwargs", None)
        if kwargs is None:
            return None
        if isinstance(kwargs, dict):
            value = kwargs.get("template_variant")
        else:
            value = getattr(kwargs, "template_variant", None)
        return value if isinstance(value, str) else None

    def _normalize_tool_call(self, payload: Any) -> Optional[dict[str, Any]]:
        """Normalize a decoded payload to ``{"name", "arguments"}``.

        Supports the flat ``{"name", "parameters"}`` shape and the
        wrapped ``{"function": {"name", "arguments"}}`` shape. Returns
        ``None`` for anything else — including non-dict payloads, which
        the previous implementation would crash on (``"name" in str``
        is a substring test, then ``payload["name"]`` raises).
        """
        if not isinstance(payload, dict):
            return None
        if "name" in payload and "parameters" in payload:
            return {"name": payload["name"], "arguments": payload["parameters"]}
        function = payload.get("function")
        if isinstance(function, dict) and "name" in function and "arguments" in function:
            return {"name": function["name"], "arguments": function["arguments"]}
        return None

    def _extract_legacy_tool_calls(self, text: str) -> list[dict[str, Any]]:
        """Extract every well-formed legacy tool call from *text*.

        A malformed JSON payload between one marker pair is skipped so
        that the remaining well-formed calls are still returned
        (previously the ``JSONDecodeError`` escaped and the callers'
        broad handler discarded *all* calls).
        """
        tool_calls: list[dict[str, Any]] = []
        cursor = 0
        while True:
            start = text.find(self.LEGACY_START, cursor)
            if start == -1:
                break
            end = text.find(self.LEGACY_END, start)
            if end == -1:
                # Unterminated call: still streaming or truncated output.
                break
            raw = text[start + len(self.LEGACY_START): end].strip()
            try:
                payload = json.loads(raw)
            except json.JSONDecodeError:
                payload = None  # skip this call, keep scanning
            normalized = self._normalize_tool_call(payload)
            if normalized:
                tool_calls.append(normalized)
            cursor = end + len(self.LEGACY_END)
        return tool_calls

    def _extract_official_tool_call(self, text: str) -> Optional[dict[str, Any]]:
        """Parse *text* as one official-format JSON tool call, else ``None``.

        Handles malformed JSON locally instead of relying on the
        callers' catch-all handler.
        """
        stripped = text.strip()
        if not (stripped.startswith("{") and stripped.endswith("}")):
            return None
        try:
            payload = json.loads(stripped)
        except json.JSONDecodeError:
            return None
        return self._normalize_tool_call(payload)

    @staticmethod
    def _dump_arguments(arguments: Any) -> Any:
        """Serialize dict/list arguments to JSON; pass strings through."""
        if isinstance(arguments, (dict, list)):
            return json.dumps(arguments, ensure_ascii=False)
        return arguments

    def _build_delta_tool_call(
        self,
        tool_call: dict[str, Any],
        index: int = 0,
    ) -> DeltaMessage:
        """Wrap one normalized call in a streaming ``DeltaMessage``.

        The id is positional (``call_{index+1}``) for consistency with
        ``_build_tool_calls_response``; the previous name-based id
        (``call_{name}``) collided when a function was called twice.
        """
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=index,
                    id=f"call_{index + 1}",
                    type="function",
                    function=DeltaFunctionCall(
                        name=tool_call["name"],
                        arguments=self._dump_arguments(tool_call["arguments"]),
                    ),
                )
            ]
        )

    def _build_tool_calls_response(
        self,
        tool_calls: list[dict[str, Any]],
    ) -> ExtractedToolCallInformation:
        """Wrap normalized calls in a non-streaming extraction result."""
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=[
                ToolCall(
                    id=f"call_{index + 1}",
                    type="function",
                    function=FunctionCall(
                        name=tool_call["name"],
                        arguments=self._dump_arguments(tool_call["arguments"]),
                    ),
                )
                for index, tool_call in enumerate(tool_calls)
            ],
            content=None,
        )

    def _looks_like_partial_official_json(self, text: str) -> bool:
        """Heuristic: *text* is an official-format call still streaming.

        True when the text opens a JSON object that is not yet closed
        and mentions a tool-call key.
        """
        stripped = text.strip()
        if not stripped.startswith("{") or stripped.endswith("}"):
            return False
        return any(
            key in stripped for key in ('"name"', '"parameters"', '"function"')
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incremental parsing hook called once per generated delta.

        Emits each completed tool call exactly once — on the delta that
        completes it. The previous implementation re-emitted the full
        call on every delta after completion, duplicating tool calls
        downstream.
        """
        variant = self._get_template_variant(request)
        try:
            if variant == self.VARIANT_LEGACY or self.LEGACY_START in current_text:
                # A legacy call just completed iff the END-marker count grew.
                if current_text.count(self.LEGACY_END) > previous_text.count(
                    self.LEGACY_END
                ):
                    tool_calls = self._extract_legacy_tool_calls(current_text)
                    if tool_calls:
                        return self._build_delta_tool_call(
                            tool_calls[-1], index=len(tool_calls) - 1
                        )
                if self.LEGACY_START in current_text:
                    # Inside an unfinished call, or after an emitted one:
                    # suppress output rather than leak marker text.
                    return None
                return DeltaMessage(content=delta_text)
            official_tool_call = self._extract_official_tool_call(current_text)
            if official_tool_call:
                if self._extract_official_tool_call(previous_text) is None:
                    return self._build_delta_tool_call(official_tool_call)
                return None  # already emitted; suppress duplicate
            if variant == self.VARIANT_OFFICIAL and self._looks_like_partial_official_json(
                current_text
            ):
                return None
        except Exception:
            # Defensive boundary: never break the stream over a parser
            # error — fall back to passing the text through as content.
            return DeltaMessage(content=delta_text)
        return DeltaMessage(content=delta_text)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Non-streaming hook: parse the complete *model_output*.

        Tries the legacy marker format first (when selected by variant
        or when markers are present), then the official single-object
        format; otherwise returns the output as plain content.
        """
        variant = self._get_template_variant(request)
        try:
            if variant == self.VARIANT_LEGACY or self.LEGACY_START in model_output:
                tool_calls = self._extract_legacy_tool_calls(model_output)
                if tool_calls:
                    return self._build_tool_calls_response(tool_calls)
            official_tool_call = self._extract_official_tool_call(model_output)
            if official_tool_call:
                return self._build_tool_calls_response([official_tool_call])
        except Exception:
            # Defensive boundary: treat unparseable output as content.
            pass
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )