Spaces:
Paused
Paused
| diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py | |
| index cee3f18e..30c911ed 100644 | |
| --- a/python/sglang/srt/function_call/qwen25_detector.py | |
| +++ b/python/sglang/srt/function_call/qwen25_detector.py | |
| import json | |
| import logging | |
| import re | |
| -from typing import List | |
| +from typing import List, Any | |
| from sglang.srt.entrypoints.openai.protocol import Tool | |
| from sglang.srt.function_call.base_format_detector import BaseFormatDetector | |
| from sglang.srt.function_call.core_types import ( | |
| StreamingParseResult, | |
| StructureInfo, | |
| _GetInfoFunc, | |
| + ToolCallItem, | |
| ) | |
| from sglang.srt.function_call.ebnf_composer import EBNFComposer | |
| class Qwen25Detector(BaseFormatDetector): | |
| """Check if the text contains a Qwen 2.5 format tool call.""" | |
| return self.bot_token in text | |
| + def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: | |
| + """Convert decoded tool-call object(s) into ToolCallItem entries. | |
| + | |
| + Differs from a plain json.dumps of the arguments in that string-valued | |
| + arguments are first decoded (repeatedly, while the result is still a | |
| + string) and then re-serialized once, so JSON-encoded or double-encoded | |
| + argument strings end up as a single-layer JSON string. | |
| + | |
| + Args: | |
| + action: One decoded call object or a list of them; each is read via | |
| + .get("name"), .get("parameters") and .get("arguments"). | |
| + tools: Declared tools; only calls whose name matches a tool's | |
| + function.name are kept. | |
| + | |
| + Returns: | |
| + One ToolCallItem per recognized call, with parameters as a string. | |
| + Calls with a missing or undeclared name are logged and dropped. | |
| + """ | |
| + # Only membership is tested below; the stored index values are not read in this method. | |
| + tool_indices = { | |
| + tool.function.name: i for i, tool in enumerate(tools) if tool.function.name | |
| + } | |
| + # Normalize a single call object to a one-element list | |
| + if not isinstance(action, list): | |
| + action = [action] | |
| + | |
| + results = [] | |
| + for act in action: | |
| + name = act.get("name") | |
| + if name and name in tool_indices: | |
| + # "parameters" wins when present and truthy; otherwise fall back to "arguments" (default {}) | |
| + arguments = act.get("parameters") or act.get("arguments", {}) | |
| + | |
| + # String arguments may be JSON text that was encoded one or more times | |
| + if isinstance(arguments, str): | |
| + try: | |
| + # First decode; failure here means the string is not JSON at all | |
| + parsed_arguments = json.loads(arguments) | |
| + | |
| + # Keep decoding while the result is still a string (double/multiple encoding) | |
| + while isinstance(parsed_arguments, str): | |
| + try: | |
| + parsed_arguments = json.loads(parsed_arguments) | |
| + except (json.JSONDecodeError, TypeError): | |
| + # Inner text is a plain string, not JSON; keep it as the decoded value | |
| + break | |
| + | |
| + # Serialize the decoded value exactly once | |
| + arguments = json.dumps(parsed_arguments, ensure_ascii=False) | |
| + except (json.JSONDecodeError, TypeError): | |
| + # The string is not valid JSON; it is passed through unchanged, | |
| + # so parameters may carry raw non-JSON text in this case | |
| + pass | |
| + else: | |
| + # Non-string arguments (e.g. a dict) are serialized directly | |
| + arguments = json.dumps(arguments, ensure_ascii=False) | |
| + | |
| + results.append( | |
| + ToolCallItem( | |
| + tool_index=-1, # Placeholder; the caller is expected to set the index for the tools array in use | |
| + name=name, | |
| + parameters=arguments, | |
| + ) | |
| + ) | |
| + else: | |
| + logger.warning(f"Model attempted to call undefined function: {name}") | |
| + | |
| + return results | |
| + | |
| def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: | |
| """ | |
| One-time parsing: Detects and parses tool calls in the provided text. | |
| diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py | |
| index 35b75d71..32dfb09e 100644 | |
| --- a/test/srt/test_function_call_parser.py | |
| +++ b/test/srt/test_function_call_parser.py | |
| class TestEBNFGeneration(unittest.TestCase): | |
| self.assertIsNotNone(ebnf) | |
| # Check that the EBNF contains expected patterns | |
| - self.assertIn("<|tool▁calls▁begin|>", ebnf) | |
| - self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf) | |
| + self.assertIn("<|tool▁call▁begin|>function::get_weather", ebnf) | |
| self.assertIn('\\"location\\"" ":" basic_string ', ebnf) | |
| # Validate that the EBNF can be compiled by GrammarCompiler | |
| class TestLlama32Detector(unittest.TestCase): | |
| self.assertTrue(result.normal_text.strip().startswith("Some intro.")) | |
| +class TestQwen25Detector(unittest.TestCase): | |
| + """Tests for Qwen25Detector argument normalization and tool-call parsing.""" | |
| + | |
| + # NOTE(review): this hunk adds no imports; confirm Qwen25Detector, Function, | |
| + # Tool and json are already imported at the top of this test module. | |
| + def setUp(self): | |
| + """Set up test tools and detector for Qwen25 format testing.""" | |
| + self.tools = [ | |
| + Tool( | |
| + type="function", | |
| + function=Function( | |
| + name="get_website_info", | |
| + description="Get information about a website", | |
| + parameters={ | |
| + "type": "object", | |
| + "properties": { | |
| + "url": { | |
| + "type": "string", | |
| + "description": "The URL to get information about", | |
| + } | |
| + }, | |
| + "required": ["url"], | |
| + }, | |
| + ), | |
| + ), | |
| + ] | |
| + self.detector = Qwen25Detector() | |
| + | |
| + def test_parse_json_encoded_string_arguments(self): | |
| + """Test parsing when arguments are provided as JSON-encoded strings (the problematic case).""" | |
| + # "arguments" is a str holding JSON object text rather than a dict | |
| + test_action = { | |
| + "name": "get_website_info", | |
| + "arguments": '{"url": "https://huggingface.co/agentsoc/spaces"}' | |
| + } | |
| + | |
| + result = self.detector.parse_base_json(test_action, self.tools) | |
| + | |
| + self.assertEqual(len(result), 1) | |
| + call = result[0] | |
| + self.assertEqual(call.name, "get_website_info") | |
| + | |
| + # Decoding parameters once must yield the argument object, not another string | |
| + params = json.loads(call.parameters) | |
| + self.assertEqual(params["url"], "https://huggingface.co/agentsoc/spaces") | |
| + | |
| + # parameters itself must be a str containing valid JSON | |
| + self.assertIsInstance(call.parameters, str) | |
| + json.loads(call.parameters) # Should not raise an exception | |
| + | |
| + def test_parse_dict_arguments(self): | |
| + """Test parsing when arguments are provided as dictionaries (the normal case).""" | |
| + test_action = { | |
| + "name": "get_website_info", | |
| + "arguments": {"url": "https://huggingface.co/agentsoc/spaces"} | |
| + } | |
| + | |
| + result = self.detector.parse_base_json(test_action, self.tools) | |
| + | |
| + self.assertEqual(len(result), 1) | |
| + call = result[0] | |
| + self.assertEqual(call.name, "get_website_info") | |
| + | |
| + # The dict must come back as JSON text that decodes to the same value | |
| + params = json.loads(call.parameters) | |
| + self.assertEqual(params["url"], "https://huggingface.co/agentsoc/spaces") | |
| + | |
| + def test_parse_double_encoded_string_arguments(self): | |
| + """Test parsing when arguments are double JSON-encoded strings.""" | |
| + # The literal below evaluates to a JSON string literal whose content is | |
| + # itself JSON object text, i.e. the object was encoded twice | |
| + test_action = { | |
| + "name": "get_website_info", | |
| + "arguments": '"{\\\"url\\\":\\\"https://huggingface.co/agentsoc/spaces\\\"}"' | |
| + } | |
| + | |
| + result = self.detector.parse_base_json(test_action, self.tools) | |
| + | |
| + self.assertEqual(len(result), 1) | |
| + call = result[0] | |
| + self.assertEqual(call.name, "get_website_info") | |
| + | |
| + # A single decode of parameters must already give the object | |
| + params = json.loads(call.parameters) | |
| + self.assertEqual(params["url"], "https://huggingface.co/agentsoc/spaces") | |
| + | |
| + def test_parse_malformed_string_arguments(self): | |
| + """Test parsing when arguments are malformed strings.""" | |
| + test_action = { | |
| + "name": "get_website_info", | |
| + "arguments": "this is not valid json" | |
| + } | |
| + | |
| + result = self.detector.parse_base_json(test_action, self.tools) | |
| + | |
| + self.assertEqual(len(result), 1) | |
| + call = result[0] | |
| + self.assertEqual(call.name, "get_website_info") | |
| + | |
| + # Non-JSON text is passed through unchanged rather than dropped or re-encoded | |
| + self.assertEqual(call.parameters, "this is not valid json") | |
| + | |
| + def test_detect_and_parse_with_tool_call_format(self): | |
| + """Test full detect_and_parse with complete tool call format.""" | |
| + test_text = '<tool_call>\n{"name": "get_website_info", "arguments": {"url": "https://huggingface.co/agentsoc/spaces"}}\n</tool_call>' | |
| + | |
| + result = self.detector.detect_and_parse(test_text, self.tools) | |
| + | |
| + self.assertEqual(len(result.calls), 1) | |
| + call = result.calls[0] | |
| + self.assertEqual(call.name, "get_website_info") | |
| + | |
| + params = json.loads(call.parameters) | |
| + self.assertEqual(params["url"], "https://huggingface.co/agentsoc/spaces") | |
| + | |
| + # Input contains only the tool-call block, so no normal text is expected | |
| + self.assertEqual(result.normal_text, "") | |
| + | |
| + | |
| if __name__ == "__main__": | |
| unittest.main() | |