| | import ast |
| | import json |
| | import logging |
| | import re |
| | from typing import Any, List, Optional |
| |
|
| | from sglang.srt.entrypoints.openai.protocol import Tool |
| | from sglang.srt.function_call.base_format_detector import BaseFormatDetector |
| | from sglang.srt.function_call.core_types import ( |
| | StreamingParseResult, |
| | ToolCallItem, |
| | _GetInfoFunc, |
| | ) |
| |
|
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
class Qwen3CoderDetector(BaseFormatDetector):
    """Detector for the Qwen3-Coder XML-style tool-call format.

    Tool calls look roughly like::

        <tool_call>
        <function=NAME>
        <parameter=ARG>
        VALUE
        </parameter>
        </function>
        </tool_call>

    Parameter values are raw text; they are converted to typed values using
    the schema of the matching tool (see ``_convert_param_value``). A value
    normally ends at ``</parameter>``; when that tag is missing it ends at the
    next ``<parameter=`` or at ``</function>``.
    """

    def __init__(self):
        super().__init__()

        # Literal tags of the format.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"

        # Regexes used by the one-shot (non-streaming) parser. The function and
        # parameter regexes also match an element that is cut off at the end
        # of the text (the `$` alternatives).
        self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
            re.DOTALL,
        )

        # --- Streaming state ---
        # Offset into self._buffer up to which text has already been consumed.
        self.parsed_pos: int = 0
        # Parameters already emitted for the current call; decides whether the
        # next JSON fragment needs a leading ", ".
        self.current_tool_param_count: int = 0
        # Whether the opening "{" of the current call's arguments was emitted.
        self.json_started: bool = False
        # True between <tool_call> and </tool_call>; plain text seen in that
        # region is not reported as normal text.
        self.is_inside_tool_call: bool = False
        # Name of the function whose parameters are currently being parsed.
        self.current_func_name: Optional[str] = None

    def has_tool_call(self, text: str) -> bool:
        """Return True if ``text`` contains a ``<tool_call>`` start tag."""
        return self.tool_call_start_token in text

    def _get_arguments_config(
        self, func_name: str, tools: Optional[list[Tool]]
    ) -> dict:
        """Return the parameter schema for ``func_name`` from ``tools``.

        Returns the ``properties`` mapping of the function's ``parameters``
        when present, otherwise the ``parameters`` dict itself. Returns ``{}``
        when ``tools`` is None, when the parameters are not a dict, or when no
        function tool named ``func_name`` exists (the last case is logged).
        """
        if tools is None:
            return {}
        for config in tools:
            try:
                config_type = config.type
                config_function = config.function
                config_function_name = config_function.name
            except AttributeError:
                # Entry does not have the shape of a function tool; skip it.
                continue

            if config_type != "function" or config_function_name != func_name:
                continue

            params = getattr(config_function, "parameters", None)
            if not isinstance(params, dict):
                return {}
            return params["properties"] if "properties" in params else params

        logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
        return {}

    @staticmethod
    def _strip_one_newline(value: str) -> str:
        """Drop at most one leading and one trailing newline from ``value``.

        Only the newline adjacent to each tag is removed; any further blank
        lines that belong to the value are kept.
        """
        if value.startswith("\n"):
            value = value[1:]
        if value.endswith("\n"):
            value = value[:-1]
        return value

    def _ensure_json_compatible(
        self, value: Any, raw_value: str, param_name: str, func_name: str
    ) -> Any:
        """Return ``value`` if it encodes as strict JSON, else ``raw_value``.

        Guards against conversions that would later make ``json.dumps`` raise
        (sets, bytes, complex, tuple-keyed dicts, ...) or emit non-standard
        JSON (NaN / Infinity).
        """
        try:
            json.dumps(value, allow_nan=False)
        except (TypeError, ValueError, RecursionError):
            logger.warning(
                f"Parsed value '{raw_value}' of parameter '{param_name}' in tool "
                f"'{func_name}' converts to a value that is not valid JSON, "
                f"degenerating to string."
            )
            return raw_value
        return value

    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        """Convert a raw parameter string according to its schema type.

        The type is ``param_config[param_name]["type"]``; it is treated as
        ``"string"`` when that is missing. Rules, checked in order:

        * the literal ``null`` (any case) -> ``None``, for every type;
        * parameter not in ``param_config`` -> the raw string. A warning is
          logged unless ``param_config`` is empty;
        * string-like types -> the raw string;
        * ``int*``/``uint*``/``long*``/``short*``/``unsigned*`` -> ``int``;
        * ``num*``/``float*`` -> ``int`` for literals without ``.`` or an
          exponent that are integral, otherwise ``float``;
        * ``boolean``/``bool``/``binary`` -> ``True`` only for ``true``
          (case-insensitive); any other value gives ``False`` with a warning;
        * anything else -> ``json.loads`` (object/array-like types only),
          then ``ast.literal_eval``.

        A failed numeric/JSON/literal conversion returns the raw string with a
        warning. So does a result that is not strict-JSON encodable. This way
        the returned value can always be passed to ``json.dumps``.
        """
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config:
                logger.warning(
                    f"Parsed parameter '{param_name}' is not defined in the tool "
                    f"parameters for tool '{func_name}', directly returning the string value."
                )
            return param_value

        schema = param_config[param_name]
        if isinstance(schema, dict) and "type" in schema:
            param_type = str(schema["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ("string", "str", "text", "varchar", "char", "enum"):
            return param_value

        if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except ValueError:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value

        if param_type.startswith(("num", "float")):
            # Literals written with a decimal point or exponent stay floats;
            # plain integral literals such as "3" become ints.
            keep_float = "." in param_value or "e" in param_value.lower()
            try:
                number = float(param_value)
            except ValueError:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                    f"'{func_name}', degenerating to string."
                )
                return param_value
            if not keep_float and number.is_integer():
                return int(number)
            # Rejects nan/inf, which have no strict-JSON representation.
            return self._ensure_json_compatible(
                number, param_value, param_name, func_name
            )

        if param_type in ("boolean", "bool", "binary"):
            normalized = param_value.lower()
            if normalized not in ("true", "false"):
                logger.warning(
                    f"Parsed value '{normalized}' of parameter '{param_name}' is not a boolean (`true` or `false`) in tool '{func_name}', degenerating to false."
                )
            return normalized == "true"

        if param_type in ("object", "array", "arr") or param_type.startswith(
            ("dict", "list")
        ):
            try:
                loaded = json.loads(param_value)
            except Exception:
                logger.warning(
                    f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
                    f"'{func_name}', will try other methods to parse it."
                )
            else:
                # json.loads accepts NaN/Infinity, which json.dumps must not emit.
                return self._ensure_json_compatible(
                    loaded, param_value, param_name, func_name
                )

        # NOTE(review): literal_eval only evaluates literals, but very deeply
        # nested model output can still be expensive to parse.
        try:
            evaluated = ast.literal_eval(param_value)
        except Exception:
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
            )
            return param_value
        # literal_eval can produce sets, bytes, complex, ... which JSON cannot hold.
        return self._ensure_json_compatible(
            evaluated, param_value, param_name, func_name
        )

    def _parse_parameters(
        self, params_str: str, func_name: str, tools: Optional[List[Tool]]
    ) -> dict:
        """Parse every ``<parameter=NAME>VALUE`` in ``params_str`` into a dict."""
        param_config = self._get_arguments_config(func_name, tools)
        parsed_params = {}
        for p_match in self.tool_call_parameter_regex.findall(params_str):
            # The name runs up to the first ">"; everything after it is the value.
            p_name, sep, p_val = p_match.partition(">")
            if not sep:
                continue
            parsed_params[p_name] = self._convert_param_value(
                self._strip_one_newline(p_val), p_name, param_config, func_name
            )
        return parsed_params

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """One-shot parsing for non-streaming scenarios.

        Returns one ``ToolCallItem`` per ``<function=...>`` found, with its
        arguments JSON-encoded. Every closed ``<tool_call>`` block is parsed,
        plus a trailing block that was never closed (e.g. truncated output).
        ``normal_text`` is the text before the first ``<tool_call>``; text
        after the tool calls is not returned. If an unexpected error occurs,
        it is logged and the whole input is returned as normal text.
        """
        if self.tool_call_start_token not in text:
            return StreamingParseResult(normal_text=text)

        calls: List[ToolCallItem] = []
        try:
            raw_tool_calls = self.tool_call_regex.findall(text)

            # tool_call_regex needs a closing tag, so a trailing unclosed call
            # is not matched; pick it up explicitly. With no closing tag at all
            # this is the whole text.
            last_end = text.rfind(self.tool_call_end_token)
            if last_end == -1:
                unclosed_tail = text
            else:
                tail_start = text.find(
                    self.tool_call_start_token,
                    last_end + len(self.tool_call_end_token),
                )
                unclosed_tail = text[tail_start:] if tail_start != -1 else ""
            if self.tool_call_prefix in unclosed_tail:
                raw_tool_calls.append(unclosed_tail)

            for tool_content in raw_tool_calls:
                for func_match in self.tool_call_function_regex.findall(tool_content):
                    # Group 1 is a closed <function>, group 2 an unclosed one.
                    func_body = func_match[0] or func_match[1]
                    func_name, sep, params_str = func_body.partition(">")
                    if not sep:
                        continue
                    parsed_params = self._parse_parameters(params_str, func_name, tools)
                    calls.append(
                        ToolCallItem(
                            tool_index=len(calls),
                            name=func_name,
                            parameters=json.dumps(parsed_params, ensure_ascii=False),
                        )
                    )

            # The start token is guaranteed present by the early return above.
            normal_text = text[: text.find(self.tool_call_start_token)]
            return StreamingParseResult(normal_text=normal_text, calls=calls)

        except Exception as e:
            # Best-effort boundary: never fail the request because of parsing.
            logger.exception(f"Error in detect_and_parse: {e}")
            return StreamingParseResult(normal_text=text)

    def _find_param_value_end(self, rest: str) -> Optional[tuple[int, int]]:
        """Locate the end of a streamed parameter value.

        Returns ``(end_pos, consumed_len)`` for the earliest terminator in
        ``rest``. ``</parameter>`` is consumed. The fallbacks used when the
        model omitted it (the next ``<parameter=`` or ``</function>``) are
        left in place for the next parsing step. Returns None if no
        terminator has arrived yet.
        """
        candidates = []
        for token, consumed_len in (
            (self.parameter_end_token, len(self.parameter_end_token)),
            (self.parameter_prefix, 0),
            (self.function_end_token, 0),
        ):
            pos = rest.find(token)
            if pos != -1:
                candidates.append((pos, consumed_len))
        return min(candidates, key=lambda c: c[0]) if candidates else None

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Robust cursor-based streaming parser.

        Appends ``new_text`` to ``self._buffer`` and consumes it from
        ``self.parsed_pos`` as far as it can be parsed unambiguously. For each
        ``<function=NAME>`` it emits an item carrying the name with empty
        parameters. It then emits the JSON arguments object as fragments:
        ``"{"``, one ``"key": value`` fragment per completed parameter, and
        ``"}"`` at ``</function>``. A parameter is emitted only after its
        terminator arrives, so a value is never split across chunks. Text
        outside ``<tool_call>`` that is not part of a tag is returned as
        ``normal_text``. Parsing pauses on an incomplete tag until more text
        arrives.
        """
        self._buffer += new_text
        if not self._buffer:
            return StreamingParseResult()

        calls: List[ToolCallItem] = []
        normal_text_chunks: List[str] = []
        # A "<..." remainder that is a prefix of one of these may still turn
        # into a tag, so it is held back instead of being consumed.
        known_tags = (
            self.tool_call_start_token,
            self.tool_call_end_token,
            self.tool_call_prefix,
            self.function_end_token,
            self.parameter_prefix,
            self.parameter_end_token,
        )

        while True:
            current_slice = self._buffer[self.parsed_pos :]
            if not current_slice:
                break

            # <tool_call>: enter tool-call mode.
            if current_slice.startswith(self.tool_call_start_token):
                self.parsed_pos += len(self.tool_call_start_token)
                self.is_inside_tool_call = True
                continue

            # <function=NAME>: start a new tool call and announce its name.
            if current_slice.startswith(self.tool_call_prefix):
                end_angle = current_slice.find(">")
                if end_angle == -1:
                    break  # function name not complete yet
                func_name = current_slice[len(self.tool_call_prefix) : end_angle]

                self.current_tool_id += 1
                self.current_tool_name_sent = True
                self.current_tool_param_count = 0
                self.json_started = False
                self.current_func_name = func_name

                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=func_name,
                        parameters="",
                    )
                )
                self.parsed_pos += end_angle + 1
                continue

            # <parameter=NAME>VALUE<terminator>: emit one "key": value fragment.
            if current_slice.startswith(self.parameter_prefix):
                name_end = current_slice.find(">")
                value_end = (
                    self._find_param_value_end(current_slice[name_end + 1 :])
                    if name_end != -1
                    else None
                )
                if value_end is None:
                    break  # name or value incomplete; wait for more text
                end_pos, end_token_len = value_end

                param_name = current_slice[len(self.parameter_prefix) : name_end]
                raw_value = self._strip_one_newline(
                    current_slice[name_end + 1 : name_end + 1 + end_pos]
                )

                if not self.json_started:
                    calls.append(
                        ToolCallItem(tool_index=self.current_tool_id, parameters="{")
                    )
                    self.json_started = True

                param_config = self._get_arguments_config(
                    self.current_func_name, tools
                )
                converted_val = self._convert_param_value(
                    raw_value, param_name, param_config, self.current_func_name
                )

                # ensure_ascii=False on both parts keeps the streamed arguments
                # consistent with the non-streaming json.dumps output.
                json_key_val = (
                    f"{json.dumps(param_name, ensure_ascii=False)}: "
                    f"{json.dumps(converted_val, ensure_ascii=False)}"
                )
                fragment = (
                    f", {json_key_val}"
                    if self.current_tool_param_count > 0
                    else json_key_val
                )
                calls.append(
                    ToolCallItem(tool_index=self.current_tool_id, parameters=fragment)
                )
                self.current_tool_param_count += 1

                self.parsed_pos += (name_end + 1) + end_pos + end_token_len
                continue

            # </function>: close the arguments object ("{}" if no parameters).
            if current_slice.startswith(self.function_end_token):
                if not self.json_started:
                    calls.append(
                        ToolCallItem(tool_index=self.current_tool_id, parameters="{")
                    )
                    self.json_started = True
                calls.append(
                    ToolCallItem(tool_index=self.current_tool_id, parameters="}")
                )
                self.parsed_pos += len(self.function_end_token)
                self.current_func_name = None
                continue

            # </tool_call>: leave tool-call mode.
            if current_slice.startswith(self.tool_call_end_token):
                self.parsed_pos += len(self.tool_call_end_token)
                self.is_inside_tool_call = False
                continue

            # Not at a complete tag: consume plain text up to the next "<".
            next_open_angle = current_slice.find("<")

            if next_open_angle == -1:
                # No tag can start in the rest of the buffer.
                if not self.is_inside_tool_call:
                    normal_text_chunks.append(current_slice)
                self.parsed_pos += len(current_slice)
                continue

            if next_open_angle > 0:
                if not self.is_inside_tool_call:
                    normal_text_chunks.append(current_slice[:next_open_angle])
                self.parsed_pos += next_open_angle
                continue

            # The slice starts with "<" but matched no complete tag above.
            if any(tag.startswith(current_slice) for tag in known_tags):
                break  # may still become a tag; wait for more text
            if not self.is_inside_tool_call:
                normal_text_chunks.append("<")
            self.parsed_pos += 1

        # Drop consumed text so the buffer only holds the unparsed remainder.
        if self.parsed_pos > 0:
            self._buffer = self._buffer[self.parsed_pos :]
            self.parsed_pos = 0

        return StreamingParseResult(
            calls=calls, normal_text="".join(normal_text_chunks)
        )

    def supports_structural_tag(self) -> bool:
        """Structural-tag constrained decoding is not supported by this format."""
        return False

    def structure_info(self) -> _GetInfoFunc:
        """Not implemented for this format; always raises NotImplementedError."""
        raise NotImplementedError
| |
|