# gemma-sage/tests/test_chat_clean.py
# fix: Chat Output & Tool Hiding (commit 5713c03)
import json
import logging
import unittest
from unittest.mock import MagicMock, patch

# Patch the heavy transformers loaders *before* importing the agent module,
# so that importing it never tries to download or load a real model.
with patch('transformers.AutoProcessor.from_pretrained'), \
     patch('transformers.AutoModelForCausalLM.from_pretrained'), \
     patch('transformers.AutoTokenizer.from_pretrained'):
    import agent_module
class TestChatClean(unittest.TestCase):
    """Tests that chat_agent_stream hides tool calls and yields accumulated text."""

    def setUp(self):
        # Replace the LLM loader so no real model/processor is needed.
        self.mock_llm_patcher = patch('agent_module.get_llm')
        self.mock_llm = self.mock_llm_patcher.start()
        self.mock_model = MagicMock()
        self.mock_model.device = "cpu"
        self.mock_processor = MagicMock()
        self.mock_llm.return_value = (self.mock_model, self.mock_processor)

        # Canned oracle payload returned whenever a tool call is executed.
        self.mock_oracle_patcher = patch('agent_module.get_oracle_data')
        self.mock_oracle = self.mock_oracle_patcher.start()
        self.mock_oracle.return_value = {"wisdom": "Silence"}

        # Mock the streamer class so each test can feed its own token sequence
        # by setting `self.mock_streamer.return_value` to an iterator.
        self.mock_streamer_patcher = patch('agent_module.TextIteratorStreamer')
        self.mock_streamer = self.mock_streamer_patcher.start()

    def tearDown(self):
        self.mock_llm_patcher.stop()
        self.mock_oracle_patcher.stop()
        self.mock_streamer_patcher.stop()

    def test_clean_output_tool_hiding(self):
        """Verify tool calls are hidden and status messages removed."""
        tool_json = json.dumps({"name": "oracle_consultation", "arguments": {"topic": "Love", "date_str": "today", "name": "Seeker"}})
        # Simulate the LLM emitting plain text followed by a tool call:
        # "Looking up... <tool_call>...</tool_call>"
        tokens = ["Looking ", "up...", " <tool_call>", tool_json, "</tool_call>"]
        self.mock_streamer.return_value = iter(tokens)
        gen = agent_module.chat_agent_stream("Thema: Liebe", [])
        results = list(gen)
        for res in results:
            # Check 1: no raw JSON or tool-call markup leaks to the user.
            self.assertNotIn('{"name":', res, "Output contained raw JSON tool call")
            self.assertNotIn('<tool_call>', res, "Output contained <tool_call> tag")
            # Check 2: the removed status message must not appear anywhere.
            # NOTE: this must be a substring check on each yielded chunk —
            # checking list membership (`assertNotIn(msg, results)`) only
            # catches the message as an exact whole element and would miss it
            # embedded inside accumulated text, passing trivially.
            self.assertNotIn("*(Consulting the Oracle...)*", res, "Output contained removed status message")
        # Check 3: the clean leading text survives in the accumulated output.
        self.assertTrue(any("Looking up..." in r for r in results))

    def test_accumulation(self):
        """Verify that yielded values are accumulated text, not just tokens."""
        # Tokens: "H", "e", "l", "l", "o"
        tokens = list("Hello")
        self.mock_streamer.return_value = iter(tokens)
        gen = agent_module.chat_agent_stream("Hi", [])
        results = list(gen)
        # The final yield must be the fully accumulated string ...
        self.assertEqual(results[-1], "Hello")
        # ... and earlier yields must be growing prefixes: "H", "He", ...
        self.assertEqual(results[0], "H")
        self.assertEqual(results[1], "He")
if __name__ == '__main__':  # Allow running this test module directly.
    unittest.main()