|
|
| from __future__ import annotations
|
|
|
| import os
|
| import unittest
|
| from unittest.mock import patch
|
|
|
# Seed environment defaults before ui.agent.graph.llm is imported below.
# NOTE(review): presumably that module reads these variables at import time —
# confirm. setdefault() leaves any value the caller already exported untouched.
for _env_name, _env_default in (
    ("BORDERLESS_INFERENCE_MODE", "hub"),
    ("BORDERLESS_PRELOAD_MODEL", "0"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
|
|
|
| from langchain_openai import ChatOpenAI
|
|
|
| from ui.agent.graph.llm import HubInferenceChatModel, MiniCPMChatModel, build_llm
|
|
|
|
|
class BuildLlmTests(unittest.TestCase):
    """Tests for ``build_llm`` backend selection and ``MiniCPMChatModel.bind_tools``.

    ``local_inference_enabled`` (and, for the remote backends,
    ``INFERENCE_MODE``) is patched on ``ui.agent.graph.llm`` so that each test
    selects its backend deterministically, independent of the process
    environment.
    """

    @staticmethod
    def _runnable_config(**extra: object) -> dict[str, dict[str, object]]:
        """Return a fresh ``{"configurable": {...}}`` dict for ``build_llm``.

        New dicts are built on every call so that no test can observe a
        mutation another test's ``build_llm`` call might have made.

        Args:
            **extra: Additional ``configurable`` entries (e.g. ``hf_token``).

        Returns:
            A config dict holding the shared generation parameters plus *extra*.
        """
        configurable: dict[str, object] = {
            "max_tokens": 512,
            "temperature": 0.2,
            "top_p": 0.8,
        }
        configurable.update(extra)
        return {"configurable": configurable}

    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=True)
    def test_local_mode_returns_minicpm_wrapper(
        self,
        _mock_local: unittest.mock.MagicMock,
    ) -> None:
        """With local inference enabled, build_llm returns the MiniCPM wrapper."""
        llm = build_llm(self._runnable_config())
        self.assertIsInstance(llm, MiniCPMChatModel)
        self.assertEqual(llm.max_tokens, 512)

    @patch("ui.agent.graph.llm.INFERENCE_MODE", "hub")
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=False)
    def test_hub_mode_returns_inference_client_wrapper(
        self,
        _mock_local: unittest.mock.MagicMock,
    ) -> None:
        """In "hub" mode, build_llm returns HubInferenceChatModel with the token."""
        llm = build_llm(self._runnable_config(hf_token="test-token"))
        self.assertIsInstance(llm, HubInferenceChatModel)
        self.assertEqual(llm.hf_token, "test-token")

    @patch("ui.agent.graph.llm.INFERENCE_MODE", "router")
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=False)
    def test_router_mode_returns_chat_openai(
        self,
        _mock_local: unittest.mock.MagicMock,
    ) -> None:
        """In "router" mode, build_llm returns a ChatOpenAI client."""
        llm = build_llm(self._runnable_config(hf_token="test-token"))
        self.assertIsInstance(llm, ChatOpenAI)

    # NOTE(review): kept from the original; presumably MiniCPMChatModel
    # consults local_inference_enabled during construction — confirm.
    @patch("ui.agent.graph.llm.local_inference_enabled", return_value=True)
    def test_minicpm_bind_tools_preserves_generation_params(
        self,
        _mock_local: unittest.mock.MagicMock,
    ) -> None:
        """bind_tools keeps every generation parameter and records the tools."""
        llm = MiniCPMChatModel(max_tokens=128, temperature=0.1, top_p=0.5)
        bound = llm.bind_tools(
            [{"type": "function", "function": {"name": "search_immigration_info"}}]
        )
        self.assertEqual(bound.max_tokens, 128)
        # The test name promises *all* generation params survive binding,
        # so pin temperature and top_p too, not only max_tokens.
        self.assertEqual(bound.temperature, 0.1)
        self.assertEqual(bound.top_p, 0.5)
        self.assertEqual(len(bound.tools), 1)
|
|
|
|
|
# Allow running this file directly (`python <file>`) in addition to pytest /
# `python -m unittest` discovery; "__main__" is unittest.main's default module.
if __name__ == "__main__":
    unittest.main(module="__main__")
|
|
|