File size: 2,754 Bytes
a402b9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | import json
import logging
import pytest
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
from sglang.test.ci.ci_register import register_cpu_ci
# Register this test module with the CPU CI runner.
# NOTE(review): the meaning of the 1.0 argument (presumably an estimated
# runtime or scheduling weight) and of the "default" suite name is defined by
# register_cpu_ci in sglang.test.ci.ci_register — confirm there.
register_cpu_ci(1.0, "default")
class DummyDetector(BaseFormatDetector):
    """Minimal concrete detector used to drive ``parse_base_json`` in tests.

    Every input is treated as containing a tool call. ``detect_and_parse``
    decodes the whole text as JSON and never yields any normal text.
    """

    def has_tool_call(self, text: str) -> bool:
        # Unconditionally report a tool call; the tests only exercise parsing.
        return True

    def detect_and_parse(self, text: str, tools):
        """Decode ``text`` with ``json.loads`` and parse it via ``parse_base_json``.

        Returns a ``StreamingParseResult`` with empty ``normal_text`` and the
        calls produced by ``parse_base_json`` for the given ``tools``.
        """
        decoded = json.loads(text)
        parsed_calls = self.parse_base_json(decoded, tools)
        return StreamingParseResult(normal_text="", calls=parsed_calls)

    def structure_info(self):
        # Required by the base class interface but unused here; yields None.
        return None
# Shared fixtures for the unknown-tool tests below. They were previously
# duplicated verbatim in both tests.
_PARSER_LOGGER = "sglang.srt.function_call.base_format_detector"
_UNKNOWN_TOOL_CALL = '{"name":"unknown_tool","parameters":{"city":"Paris"}}'
_UNDEFINED_FUNCTION_WARNING = (
    "Model attempted to call undefined function: unknown_tool"
)


def _weather_tools():
    """Return a tool list whose only entry is ``get_weather``.

    ``unknown_tool`` is deliberately absent so it is treated as undefined.
    """
    return [
        Tool(
            function=Function(
                name="get_weather", parameters={"type": "object", "properties": {}}
            )
        )
    ]


def _parse_unknown_tool_call(caplog):
    """Parse a call to ``unknown_tool`` and assert the undefined-function warning.

    Must be called inside the desired ``SGLANG_FORWARD_UNKNOWN_TOOLS`` override
    so the parse observes that setting.

    Returns:
        The ``StreamingParseResult`` produced by ``DummyDetector``.
    """
    tools = _weather_tools()
    detector = DummyDetector()
    with caplog.at_level(logging.WARNING, logger=_PARSER_LOGGER):
        result = detector.detect_and_parse(_UNKNOWN_TOOL_CALL, tools)
    # The warning is expected regardless of whether the call is forwarded.
    assert any(_UNDEFINED_FUNCTION_WARNING in m for m in caplog.messages)
    return result


def test_unknown_tool_name_dropped_default(caplog):
    """Test that unknown tools are dropped by default (legacy behavior)."""
    with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(False):
        result = _parse_unknown_tool_call(caplog)
    assert len(result.calls) == 0  # dropped in default mode


def test_unknown_tool_name_forwarded(caplog):
    """Test that unknown tools are forwarded when the env var is True."""
    with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(True):
        result = _parse_unknown_tool_call(caplog)
    assert len(result.calls) == 1
    call = result.calls[0]
    assert call.name == "unknown_tool"
    # NOTE(review): -1 presumably means "no index assigned yet" — the
    # forwarded call matches no entry in the tools list; confirm in
    # parse_base_json.
    assert call.tool_index == -1
    # Parameters are a JSON-encoded string that must round-trip the payload.
    assert json.loads(call.parameters)["city"] == "Paris"
if __name__ == "__main__":
    # Allow running this file directly: execute its tests with pytest and use
    # pytest's return code as the process exit status.
    raise SystemExit(pytest.main([__file__]))
|