| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import unittest |
|
|
| from transformers.agents.agent_types import AgentImage |
| from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent |
| from transformers.agents.monitoring import stream_to_gradio |
|
|
|
|
| class MonitoringTester(unittest.TestCase): |
| def test_code_agent_metrics(self): |
| class FakeLLMEngine: |
| def __init__(self): |
| self.last_input_token_count = 10 |
| self.last_output_token_count = 20 |
|
|
| def __call__(self, prompt, **kwargs): |
| return """ |
| Code: |
| ```py |
| final_answer('This is the final answer.') |
| ```""" |
|
|
| agent = ReactCodeAgent( |
| tools=[], |
| llm_engine=FakeLLMEngine(), |
| max_iterations=1, |
| ) |
|
|
| agent.run("Fake task") |
|
|
| self.assertEqual(agent.monitor.total_input_token_count, 10) |
| self.assertEqual(agent.monitor.total_output_token_count, 20) |
|
|
| def test_json_agent_metrics(self): |
| class FakeLLMEngine: |
| def __init__(self): |
| self.last_input_token_count = 10 |
| self.last_output_token_count = 20 |
|
|
| def __call__(self, prompt, **kwargs): |
| return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' |
|
|
| agent = ReactJsonAgent( |
| tools=[], |
| llm_engine=FakeLLMEngine(), |
| max_iterations=1, |
| ) |
|
|
| agent.run("Fake task") |
|
|
| self.assertEqual(agent.monitor.total_input_token_count, 10) |
| self.assertEqual(agent.monitor.total_output_token_count, 20) |
|
|
| def test_code_agent_metrics_max_iterations(self): |
| class FakeLLMEngine: |
| def __init__(self): |
| self.last_input_token_count = 10 |
| self.last_output_token_count = 20 |
|
|
| def __call__(self, prompt, **kwargs): |
| return "Malformed answer" |
|
|
| agent = ReactCodeAgent( |
| tools=[], |
| llm_engine=FakeLLMEngine(), |
| max_iterations=1, |
| ) |
|
|
| agent.run("Fake task") |
|
|
| self.assertEqual(agent.monitor.total_input_token_count, 20) |
| self.assertEqual(agent.monitor.total_output_token_count, 40) |
|
|
| def test_code_agent_metrics_generation_error(self): |
| class FakeLLMEngine: |
| def __init__(self): |
| self.last_input_token_count = 10 |
| self.last_output_token_count = 20 |
|
|
| def __call__(self, prompt, **kwargs): |
| raise AgentError |
|
|
| agent = ReactCodeAgent( |
| tools=[], |
| llm_engine=FakeLLMEngine(), |
| max_iterations=1, |
| ) |
|
|
| agent.run("Fake task") |
|
|
| self.assertEqual(agent.monitor.total_input_token_count, 20) |
| self.assertEqual(agent.monitor.total_output_token_count, 40) |
|
|
| def test_streaming_agent_text_output(self): |
| def dummy_llm_engine(prompt, **kwargs): |
| return """ |
| Code: |
| ```py |
| final_answer('This is the final answer.') |
| ```""" |
|
|
| agent = ReactCodeAgent( |
| tools=[], |
| llm_engine=dummy_llm_engine, |
| max_iterations=1, |
| ) |
|
|
| |
| outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True)) |
|
|
| self.assertEqual(len(outputs), 3) |
| final_message = outputs[-1] |
| self.assertEqual(final_message.role, "assistant") |
| self.assertIn("This is the final answer.", final_message.content) |
|
|
| def test_streaming_agent_image_output(self): |
| def dummy_llm_engine(prompt, **kwargs): |
| return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}' |
|
|
| agent = ReactJsonAgent( |
| tools=[], |
| llm_engine=dummy_llm_engine, |
| max_iterations=1, |
| ) |
|
|
| |
| outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True)) |
|
|
| self.assertEqual(len(outputs), 2) |
| final_message = outputs[-1] |
| self.assertEqual(final_message.role, "assistant") |
| self.assertIsInstance(final_message.content, dict) |
| self.assertEqual(final_message.content["path"], "path.png") |
| self.assertEqual(final_message.content["mime_type"], "image/png") |
|
|
| def test_streaming_with_agent_error(self): |
| def dummy_llm_engine(prompt, **kwargs): |
| raise AgentError("Simulated agent error") |
|
|
| agent = ReactCodeAgent( |
| tools=[], |
| llm_engine=dummy_llm_engine, |
| max_iterations=1, |
| ) |
|
|
| |
| outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True)) |
|
|
| self.assertEqual(len(outputs), 3) |
| final_message = outputs[-1] |
| self.assertEqual(final_message.role, "assistant") |
| self.assertIn("Simulated agent error", final_message.content) |
|
|