borderless / tests /test_llm_wrapper.py
spagestic's picture
switched to qwen
71a7158
Raw
History Blame Contribute Delete
2.59 kB
# tests/test_llm_wrapper.py
from __future__ import annotations
import os
import unittest
from unittest.mock import patch
# Establish environment defaults BEFORE importing ``ui.agent.graph.llm`` below.
# NOTE(review): presumably that module reads these variables at import time
# (the tests later patch ``INFERENCE_MODE`` as a module attribute) — confirm.
# ``setdefault`` never overrides a value already exported in the shell, so a
# developer environment with e.g. BORDERLESS_PRELOAD_MODEL set to another value
# will keep that value when running these tests.
os.environ.setdefault("BORDERLESS_INFERENCE_MODE", "hub")
os.environ.setdefault("BORDERLESS_PRELOAD_MODEL", "0")
from langchain_openai import ChatOpenAI
from ui.agent.graph.llm import HubInferenceChatModel, MiniCPMChatModel, build_llm
class BuildLlmTests(unittest.TestCase):
    """Tests for ``build_llm`` backend selection and ``MiniCPMChatModel.bind_tools``."""

    @staticmethod
    def _config(**overrides: object) -> dict[str, dict[str, object]]:
        """Return a ``build_llm`` config carrying the shared generation parameters.

        Keyword arguments are merged into the ``"configurable"`` mapping, so a
        test can add (or override) keys such as ``hf_token`` without repeating
        the common ``max_tokens`` / ``temperature`` / ``top_p`` values.
        """
        configurable: dict[str, object] = {
            "max_tokens": 512,
            "temperature": 0.2,
            "top_p": 0.8,
        }
        configurable.update(overrides)
        return {"configurable": configurable}

    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=True)
    def test_local_mode_returns_minicpm_wrapper(self, _mock_local: unittest.mock.Mock) -> None:
        """With local inference enabled, ``build_llm`` returns a MiniCPM wrapper.

        All generation parameters from the config must reach the model, not
        just ``max_tokens``.
        """
        llm = build_llm(self._config())
        self.assertIsInstance(llm, MiniCPMChatModel)
        self.assertEqual(llm.max_tokens, 512)
        self.assertEqual(llm.temperature, 0.2)
        self.assertEqual(llm.top_p, 0.8)

    # With stacked @patch decorators the bottom one supplies the first mock
    # argument; the INFERENCE_MODE patch passes an explicit value, so it
    # injects no argument.
    @patch("ui.agent.graph.llm.INFERENCE_MODE", "hub")
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=False)
    def test_hub_mode_returns_inference_client_wrapper(
        self,
        _mock_local: unittest.mock.Mock,
    ) -> None:
        """In "hub" mode ``build_llm`` returns the Hub wrapper holding the config token."""
        llm = build_llm(self._config(hf_token="test-token"))
        self.assertIsInstance(llm, HubInferenceChatModel)
        self.assertEqual(llm.hf_token, "test-token")

    @patch("ui.agent.graph.llm.INFERENCE_MODE", "router")
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=False)
    def test_router_mode_returns_chat_openai(self, _mock_local: unittest.mock.Mock) -> None:
        """In "router" mode ``build_llm`` returns a LangChain ``ChatOpenAI`` client."""
        llm = build_llm(self._config(hf_token="test-token"))
        self.assertIsInstance(llm, ChatOpenAI)

    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=True)
    def test_minicpm_bind_tools_preserves_generation_params(
        self,
        _mock_local: unittest.mock.Mock,
    ) -> None:
        """``bind_tools`` keeps every generation parameter and records the tools.

        Previously only ``max_tokens`` was checked, although ``temperature``
        and ``top_p`` are also passed to the constructor.
        """
        llm = MiniCPMChatModel(max_tokens=128, temperature=0.1, top_p=0.5)
        bound = llm.bind_tools(
            [{"type": "function", "function": {"name": "search_immigration_info"}}]
        )
        self.assertEqual(bound.max_tokens, 128)
        self.assertEqual(bound.temperature, 0.1)
        self.assertEqual(bound.top_p, 0.5)
        self.assertEqual(len(bound.tools), 1)
if __name__ == "__main__":
    # Allow running this file directly (``python tests/test_llm_wrapper.py``)
    # in addition to discovery by a test runner.
    unittest.main()