Spaces:
Sleeping
Sleeping
| import unittest | |
| from unittest.mock import MagicMock, patch | |
| import json | |
| import logging | |
# Import agent_module with the transformers `from_pretrained` loaders replaced
# by MagicMocks. The patches are active only while the import runs. Presumably
# agent_module loads a processor/model/tokenizer at import time and this keeps
# the tests from fetching real weights — confirm against agent_module.
with patch('transformers.AutoProcessor.from_pretrained'):
    with patch('transformers.AutoModelForCausalLM.from_pretrained'):
        with patch('transformers.AutoTokenizer.from_pretrained'):
            import agent_module
class TestChatClean(unittest.TestCase):
    """Exercise ``agent_module.chat_agent_stream`` with the model stack patched out.

    ``agent_module.get_llm``, ``agent_module.get_oracle_data`` and
    ``agent_module.TextIteratorStreamer`` are replaced with mocks. Each test
    controls the streamed tokens by assigning an iterator to
    ``self.mock_streamer.return_value``.
    """

    def setUp(self):
        # Each patcher is registered with addCleanup right after it starts.
        # Cleanups run even if setUp fails part-way through, whereas tearDown
        # does not, so no patch can leak into later tests.
        self.mock_llm_patcher = patch('agent_module.get_llm')
        self.mock_llm = self.mock_llm_patcher.start()
        self.addCleanup(self.mock_llm_patcher.stop)
        self.mock_model = MagicMock()
        self.mock_model.device = "cpu"
        self.mock_processor = MagicMock()
        # get_llm() is stubbed to return a (model, processor) pair.
        self.mock_llm.return_value = (self.mock_model, self.mock_processor)

        self.mock_oracle_patcher = patch('agent_module.get_oracle_data')
        self.mock_oracle = self.mock_oracle_patcher.start()
        self.addCleanup(self.mock_oracle_patcher.stop)
        self.mock_oracle.return_value = {"wisdom": "Silence"}

        # Tests set self.mock_streamer.return_value to an iterator of tokens, so
        # constructing TextIteratorStreamer(...) inside agent_module returns that
        # iterator. NOTE(review): an iterator is single-use. If chat_agent_stream
        # builds more than one streamer (e.g. a second generation after a tool
        # call), the later ones will be empty — confirm against agent_module.
        self.mock_streamer_patcher = patch('agent_module.TextIteratorStreamer')
        self.mock_streamer = self.mock_streamer_patcher.start()
        self.addCleanup(self.mock_streamer_patcher.stop)

    def test_clean_output_tool_hiding(self):
        """Verify tool calls are hidden and status messages removed."""
        tool_json = json.dumps({"name": "oracle_consultation", "arguments": {"topic": "Love", "date_str": "today", "name": "Seeker"}})
        # Simulate the LLM streaming some text followed by a tool call:
        # "Looking up... <tool_call>{...}</tool_call>"
        tokens = ["Looking ", "up...", " <tool_call>", tool_json, "</tool_call>"]
        self.mock_streamer.return_value = iter(tokens)

        results = list(agent_module.chat_agent_stream("Thema: Liebe", []))

        for res in results:
            # Check 1: no raw tool-call JSON or tags leak into any yield.
            self.assertNotIn('{"name":', res, "Output contained raw JSON tool call")
            self.assertNotIn('<tool_call>', res, "Output contained <tool_call> tag")
            # Check 2: the status message must not appear inside any yield.
            # Yields are accumulated text, so the message could be embedded in
            # a longer string. Each one is searched rather than compared whole.
            self.assertNotIn("*(Consulting the Oracle...)*", res, "Output contained removed status message")

        # Check 3: the visible text before the tool call survives. This also
        # fails when the generator yields nothing, which would otherwise let
        # the per-yield checks above pass vacuously.
        self.assertTrue(
            any("Looking up..." in r for r in results),
            "Expected 'Looking up...' in the streamed output",
        )

    def test_accumulation(self):
        """Verify that yielded values are accumulated text, not just tokens."""
        tokens = list("Hello")  # "H", "e", "l", "l", "o"
        self.mock_streamer.return_value = iter(tokens)

        results = list(agent_module.chat_agent_stream("Hi", []))

        # Guard the index-based checks so a short stream reports an assertion
        # failure instead of raising IndexError.
        self.assertGreaterEqual(len(results), 2, "Expected at least two yields for a five-token stream")
        # The sequence should be growing: "H", "He", ...
        self.assertEqual(results[:2], ["H", "He"])
        # The last result should be the full accumulated text.
        self.assertEqual(results[-1], "Hello")
if __name__ == "__main__":
    # Run the test cases when this file is executed directly.
    unittest.main()