| import json | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, List | |
| import orjson | |
| from partial_json_parser.core.exceptions import MalformedJSON | |
| from partial_json_parser.core.options import Allow | |
| from sglang.srt.entrypoints.openai.protocol import Tool | |
| from sglang.srt.function_call.core_types import ( | |
| StreamingParseResult, | |
| ToolCallItem, | |
| _GetInfoFunc, | |
| ) | |
| from sglang.srt.function_call.utils import ( | |
| _find_common_prefix, | |
| _is_complete_json, | |
| _partial_json_loads, | |
| ) | |
| logger = logging.getLogger(__name__) | |
class BaseFormatDetector(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses configure ``bot_token`` / ``eot_token`` / ``tool_call_separator``
    and implement ``has_tool_call``, ``structure_info`` and ``build_ebnf``.
    The one-shot entry point is ``detect_and_parse``; the incremental entry
    point is ``parse_streaming_increment``, which keeps its state on the
    instance between calls.
    """

    def __init__(self):
        """Initialise the streaming state and the default token configuration."""
        # Streaming state management
        # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
        self._buffer = ""
        # Stores complete tool call info (name and arguments) for each tool being parsed.
        # Used by serving layer for completion handling when streaming ends.
        # Format: [{"name": str, "arguments": dict}, ...]
        self.prev_tool_call_arr: List[Dict] = []
        # Index of currently streaming tool call. Starts at -1 (no active tool),
        # increments as each tool completes. Tracks which tool's arguments are streaming.
        self.current_tool_id: int = -1
        # Flag for whether current tool's name has been sent to client.
        # Tool names sent first with empty parameters, then arguments stream incrementally.
        self.current_tool_name_sent: bool = False
        # Tracks raw JSON string content streamed to client for each tool's arguments.
        # Critical for serving layer to calculate remaining content when streaming ends.
        # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
        self.streamed_args_for_tool: List[str] = []

        # Token configuration (override in subclasses)
        self.bot_token = ""
        self.eot_token = ""
        self.tool_call_separator = ", "

    def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
        """
        Get a mapping of tool names to their indices in the tools list.

        This utility method creates a dictionary mapping function names to their
        indices in the tools list, which is commonly needed for tool validation
        and ToolCallItem creation. Tools whose function name is falsy are skipped.

        Args:
            tools: List of available tools

        Returns:
            Dictionary mapping tool names to their indices
        """
        return {
            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
        }

    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
        """
        Convert already-decoded JSON tool call(s) into ToolCallItem objects.

        Args:
            action: A single decoded call object or a list of them. Each entry is
                read with ``.get("name")`` and ``.get("parameters")`` /
                ``.get("arguments")``.
            tools: List of available tools; calls naming a tool that is not in
                this list are dropped with a warning.

        Returns:
            One ToolCallItem per recognised call, in input order. ``tool_index``
            is always -1 here and ``parameters`` is the JSON-encoded argument
            object (non-ASCII characters are kept as-is).
        """
        tool_indices = self._get_tool_indices(tools)
        if not isinstance(action, list):
            action = [action]

        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=-1,  # Caller should update this based on the actual tools array called
                        name=name,
                        # "parameters" wins when truthy; otherwise fall back to
                        # "arguments", defaulting to an empty object.
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
            else:
                logger.warning(f"Model attempted to call undefined function: {name}")

        return results

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the text in one go.

        This base implementation decodes ``text`` as JSON with ``orjson.loads``
        and hands the result to ``parse_base_json``. Decoding errors are not
        caught here and propagate to the caller.

        Args:
            text: The complete text to parse.
            tools: List of available tools used to validate call names.

        Returns:
            A StreamingParseResult whose ``calls`` holds the recognised tool calls.
        """
        action = orjson.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
        """
        Check if buffer ends with a partial bot_token.
        Return the length of the partial bot_token.

        The result is the length of the longest suffix of ``buffer`` that is a
        proper prefix of ``bot_token`` (a complete ``bot_token`` at the end of the
        buffer is not counted as partial), or 0 when there is no such suffix.

        For some format, the bot_token is not a token in model's vocabulary, such as
        `[TOOL_CALLS] [` in Mistral.
        """
        # Only proper prefixes qualify, hence the len(bot_token) - 1 upper bound.
        longest_candidate = min(len(buffer), len(bot_token) - 1)
        # Scan from the longest candidate down so that tokens with a repeated
        # leading character (e.g. "aab" after "...aa") report the full overlap
        # instead of stopping at the first one-character match.
        for i in range(longest_candidate, 0, -1):
            if bot_token.startswith(buffer[-i:]):
                return i
        return 0

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        This base implementation works best with formats where:
        1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
        2. JSON can be parsed incrementally using partial_json_loads
        3. Multiple tool calls are separated by ``self.tool_call_separator``
           (", " by default)

        Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
        - Each tool call is wrapped in a separate block: See Qwen25Detector
        - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
        - Tool call is Pythonic style

        For incompatible formats, detectors should override this method with custom logic.

        Args:
            new_text: The newly generated chunk; it is appended to the internal buffer.
            tools: List of available tools. The name->index map is built from it on
                the first call that reaches the parsing stage and then reused.

        Returns:
            A StreamingParseResult carrying either plain text (``normal_text``), a
            tool name announcement, an argument fragment, or nothing when more
            input is needed. Any exception raised while parsing is logged and an
            empty result is returned.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer

        # The current_text has tool_call if it is the start of a new tool call sequence
        # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
        if not (
            self.has_tool_call(current_text)
            or (
                self.current_tool_id > 0
                and current_text.startswith(self.tool_call_separator)
            )
        ):
            # Only clear buffer if we're sure no tool call is starting
            if not self._ends_with_partial_token(self._buffer, self.bot_token):
                normal_text = self._buffer
                self._buffer = ""
                if self.eot_token in normal_text:
                    normal_text = normal_text.replace(self.eot_token, "")
                return StreamingParseResult(normal_text=normal_text)
            else:
                # Might be partial bot_token, keep buffering
                return StreamingParseResult()

        # Build tool indices if not already built
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        # Until the name has been sent, partial strings are disallowed so that a
        # half-generated tool name is never surfaced.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            try:
                # Locate where the JSON payload starts inside the buffer.
                tool_call_pos = current_text.find(self.bot_token)
                if tool_call_pos != -1:
                    start_idx = tool_call_pos + len(self.bot_token)
                elif self.current_tool_id > 0 and current_text.startswith(
                    self.tool_call_separator
                ):
                    start_idx = len(self.tool_call_separator)
                else:
                    start_idx = 0

                # Nothing after the marker yet; wait for more text.
                if start_idx >= len(current_text):
                    return StreamingParseResult()

                (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)

                is_current_complete = _is_complete_json(
                    current_text[start_idx : start_idx + end_idx]
                )

                # Validate tool name if present
                if "name" in obj and obj["name"] not in self._tool_indices:
                    # Invalid tool name - reset state
                    self._buffer = ""
                    self.current_tool_id = -1
                    self.current_tool_name_sent = False
                    if self.streamed_args_for_tool:
                        self.streamed_args_for_tool.pop()
                    return StreamingParseResult()

                # Handle parameters/arguments consistency
                # NOTE: we assume here that the obj is always partial of a single tool call
                if "parameters" in obj:
                    if "arguments" in obj:
                        # Explicit raise instead of assert so the check is not
                        # stripped under `python -O`; the outer handler below
                        # logs it and returns an empty result.
                        raise ValueError(
                            "model generated both parameters and arguments"
                        )
                    obj["arguments"] = obj["parameters"]

                current_tool_call = obj

            except MalformedJSON:
                return StreamingParseResult()

            if not current_tool_call:
                return StreamingParseResult()

            # Case 1: Handle tool name streaming
            # This happens when we encounter a tool but haven't sent its name yet
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")

                if function_name and function_name in self._tool_indices:
                    # If this is a new tool (current_tool_id was -1), initialize it
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                        self.streamed_args_for_tool.append("")
                    # If this is a subsequent tool, ensure streamed_args_for_tool is large enough
                    elif self.current_tool_id >= len(self.streamed_args_for_tool):
                        while len(self.streamed_args_for_tool) <= self.current_tool_id:
                            self.streamed_args_for_tool.append("")

                    # Send the tool name with empty parameters
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Case 2: Handle streaming arguments
            # This happens when we've already sent the tool name and now need to stream arguments incrementally
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                # NOTE(review): a missing or empty arguments value (e.g. {}) skips
                # this whole branch, so a completed call with empty arguments does
                # not trim the buffer or advance current_tool_id here - confirm
                # that this case is handled elsewhere.
                if cur_arguments:
                    # Calculate how much of the arguments we've already streamed
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = None
                    if self.current_tool_id < len(self.prev_tool_call_arr):
                        prev_arguments = self.prev_tool_call_arr[
                            self.current_tool_id
                        ].get("arguments")

                    argument_diff = None

                    # If the current tool's JSON is complete, send all remaining arguments
                    if is_current_complete:
                        argument_diff = cur_args_json[sent:]
                        completing_tool_id = (
                            self.current_tool_id
                        )  # Save the ID of the tool that's completing

                        # Only remove the processed portion, keep unprocessed content
                        self._buffer = current_text[start_idx + end_idx :]

                    # If the tool is still being parsed, send incremental changes
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # Only the part shared with the previous snapshot is
                            # emitted; the tail may still change as more text arrives.
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    # Update prev_tool_call_arr with current state
                    if self.current_tool_id >= 0:
                        # Ensure prev_tool_call_arr is large enough
                        while len(self.prev_tool_call_arr) <= self.current_tool_id:
                            self.prev_tool_call_arr.append({})
                        self.prev_tool_call_arr[self.current_tool_id] = (
                            current_tool_call
                        )

                    # Advance to next tool if complete
                    if is_current_complete:
                        self.current_tool_name_sent = False
                        self.current_tool_id += 1

                    # Send the argument diff if there's something new
                    if argument_diff is not None:
                        # Use the correct tool_index: completing_tool_id for completed tools, current_tool_id for ongoing
                        tool_index_to_use = (
                            completing_tool_id
                            if is_current_complete
                            else self.current_tool_id
                        )
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=tool_index_to_use,
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[tool_index_to_use] += argument_diff

            return res

        except Exception as e:
            # Best-effort boundary: a parsing failure must not break the stream.
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    def has_tool_call(self, text: str) -> bool:
        """
        Check if the given text contains function call markers specific to this format.

        Must be overridden by subclasses; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError()

    def supports_structural_tag(self) -> bool:
        """Return True if this detector supports structural tag format."""
        return True

    def structure_info(self) -> _GetInfoFunc:
        """
        Return a function that creates StructureInfo for constrained generation.

        The returned function takes a tool name and returns a StructureInfo object
        containing the begin/end patterns and trigger tokens needed for constrained
        generation of function calls in this format.

        Must be overridden by subclasses; the base implementation raises
        NotImplementedError.

        Returns:
            A function that takes a tool name (str) and returns StructureInfo
        """
        raise NotImplementedError()

    def build_ebnf(self, tools: List[Tool]) -> str:
        """
        Build an EBNF grammar for constrained generation of function calls.

        This method generates an Extended Backus-Naur Form (EBNF) grammar that
        constrains the model's output to valid function calls in this format.
        The grammar should include all available tools and their parameter schemas.

        Must be overridden by subclasses; the base implementation raises
        NotImplementedError.

        Args:
            tools: List of available tools/functions that can be called

        Returns:
            A string containing the EBNF grammar for this function call format

        The EBNF grammar should:
            - Define the overall structure of function calls in this format
            - Include all tool names from the provided tools list
            - Define valid JSON structures for function arguments
            - Handle multiple function calls if the format supports them

        Note:
            Most implementations use EBNFComposer.build_ebnf() utility with
            format-specific parameters rather than writing EBNF from scratch.
        """
        raise NotImplementedError()
Xet Storage Details
- Size:
- 15.6 kB
- Xet hash:
- 181601ec8359ae97161bb690db310b73ea0cd5a79c883c43e2158a56c37b39fa
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.