File size: 2,587 Bytes
147e220 71a7158 147e220 71a7158 147e220 71a7158 147e220 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | # tests/test_llm_wrapper.py
from __future__ import annotations
import os
import unittest
from unittest.mock import patch
# Give inference-related env vars defaults *before* ``ui.agent.graph.llm`` is
# imported further down, presumably because that module reads them at import
# time (TODO confirm). ``setdefault`` leaves any value the caller already
# exported untouched.
for _env_name, _env_default in (
    ("BORDERLESS_INFERENCE_MODE", "hub"),
    ("BORDERLESS_PRELOAD_MODEL", "0"),
):
    os.environ.setdefault(_env_name, _env_default)
from langchain_openai import ChatOpenAI
from ui.agent.graph.llm import HubInferenceChatModel, MiniCPMChatModel, build_llm
class BuildLlmTests(unittest.TestCase):
    """Tests for ``build_llm`` backend selection and ``MiniCPMChatModel.bind_tools``."""

    # Generation parameters sent through ``configurable`` by every ``build_llm``
    # test; never mutated, only spread into fresh dicts by ``_config``.
    _GENERATION_PARAMS = {"max_tokens": 512, "temperature": 0.2, "top_p": 0.8}

    @classmethod
    def _config(cls, **extra: object) -> dict[str, dict[str, object]]:
        """Build a fresh ``{"configurable": {...}}`` dict for ``build_llm``.

        Keys in *extra* (e.g. ``hf_token``) come first, followed by the shared
        generation parameters.
        """
        return {"configurable": {**extra, **cls._GENERATION_PARAMS}}

    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=True)
    def test_local_mode_returns_minicpm_wrapper(self, _mock_local: unittest.mock.Mock) -> None:
        """Local inference yields a ``MiniCPMChatModel`` with the requested ``max_tokens``."""
        llm = build_llm(self._config())
        self.assertIsInstance(llm, MiniCPMChatModel)
        self.assertEqual(llm.max_tokens, self._GENERATION_PARAMS["max_tokens"])

    @patch("ui.agent.graph.llm.INFERENCE_MODE", "hub")
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=False)
    def test_hub_mode_returns_inference_client_wrapper(
        self,
        _mock_local: unittest.mock.Mock,
    ) -> None:
        """Hub mode (local disabled) yields a ``HubInferenceChatModel`` holding the config token."""
        # Only the ``local_inference_enabled`` patch injects a mock argument;
        # the ``INFERENCE_MODE`` patch supplies its replacement value directly.
        llm = build_llm(self._config(hf_token="test-token"))
        self.assertIsInstance(llm, HubInferenceChatModel)
        self.assertEqual(llm.hf_token, "test-token")

    @patch("ui.agent.graph.llm.INFERENCE_MODE", "router")
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=False)
    def test_router_mode_returns_chat_openai(self, _mock_local: unittest.mock.Mock) -> None:
        """Router mode (local disabled) yields a LangChain ``ChatOpenAI`` client."""
        llm = build_llm(self._config(hf_token="test-token"))
        self.assertIsInstance(llm, ChatOpenAI)

    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=True)
    def test_minicpm_bind_tools_preserves_generation_params(
        self,
        _mock_local: unittest.mock.Mock,
    ) -> None:
        """``bind_tools`` keeps every generation param and records the bound tool."""
        # NOTE(review): the model is constructed directly, so this patch only
        # matters if MiniCPMChatModel itself consults ``local_inference_enabled``
        # — confirm it is still needed.
        llm = MiniCPMChatModel(max_tokens=128, temperature=0.1, top_p=0.5)
        bound = llm.bind_tools(
            [{"type": "function", "function": {"name": "search_immigration_info"}}]
        )
        self.assertEqual(bound.max_tokens, 128)
        # The test name promises that all generation params survive binding,
        # not just ``max_tokens``.
        self.assertEqual(bound.temperature, 0.1)
        self.assertEqual(bound.top_p, 0.5)
        self.assertEqual(len(bound.tools), 1)
if __name__ == "__main__":
    # Allow direct execution (``python tests/test_llm_wrapper.py``) in addition
    # to discovery by a test runner.
    unittest.main()
|