|
|
""" |
|
|
Session integration tests - Testing multi-turn conversations and session management |
|
|
""" |
|
|
|
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from unittest.mock import AsyncMock, MagicMock, patch |
|
|
|
|
|
import pytest |
|
|
|
|
|
from mini_agent import LLMClient |
|
|
from mini_agent.agent import Agent |
|
|
from mini_agent.schema import LLMResponse, Message |
|
|
from mini_agent.tools.bash_tool import BashTool |
|
|
from mini_agent.tools.file_tools import ReadTool, WriteTool |
|
|
from mini_agent.tools.note_tool import RecallNoteTool, SessionNoteTool |
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_llm_client():
    """Provide a MagicMock restricted to the public attributes of LLMClient."""
    return MagicMock(spec=LLMClient)
|
|
|
|
|
|
|
|
@pytest.fixture
def temp_workspace():
    """Yield the path (a str) of a fresh temporary directory.

    The directory is created by ``tempfile.TemporaryDirectory`` and is
    deleted when the context manager exits, i.e. during fixture teardown.
    """
    with tempfile.TemporaryDirectory() as workspace_path:
        yield workspace_path
|
|
|
|
|
|
|
|
def test_multi_turn_conversation(mock_llm_client, temp_workspace):
    """Verify that successive user turns accumulate, in order, in the agent history.

    A freshly built agent holds a single system message that embeds the
    prompt and a "Current Workspace" section. Each ``add_user_message``
    call appends exactly one user message.
    """
    prompt_text = "You are an intelligent assistant"
    agent = Agent(
        llm_client=mock_llm_client,
        system_prompt=prompt_text,
        tools=[
            ReadTool(workspace_dir=temp_workspace),
            WriteTool(workspace_dir=temp_workspace),
            SessionNoteTool(),
        ],
        workspace_dir=temp_workspace,
    )

    # Before any turn, only the system message is present.
    assert len(agent.messages) == 1
    system_msg = agent.messages[0]
    assert system_msg.role == "system"
    assert prompt_text in system_msg.content
    assert "Current Workspace" in system_msg.content

    # First user turn.
    agent.add_user_message("Hello")
    assert len(agent.messages) == 2
    first_turn = agent.messages[1]
    assert first_turn.role == "user"
    assert first_turn.content == "Hello"

    # Second user turn.
    agent.add_user_message("Help me create a file")
    assert len(agent.messages) == 3
    assert agent.messages[2].role == "user"

    # Both user turns are retained in the order they were added.
    user_contents = [msg.content for msg in agent.messages if msg.role == "user"]
    assert user_contents == ["Hello", "Help me create a file"]
|
|
|
|
|
|
|
|
def test_session_history_management(mock_llm_client, temp_workspace):
    """Check that history grows by one per user message and can be reset to the system message."""
    agent = Agent(
        llm_client=mock_llm_client,
        system_prompt="System prompt",
        tools=[],
        workspace_dir=temp_workspace,
    )

    turn_count = 5
    for turn in range(turn_count):
        agent.add_user_message(f"Message {turn}")

    # The system message plus one entry per user message.
    assert len(agent.messages) == turn_count + 1

    # Simulate clearing the session by keeping only the leading system message.
    system_message = agent.messages[0]
    agent.messages = [system_message]

    assert len(agent.messages) == 1
    assert agent.messages[0].role == "system"
|
|
|
|
|
|
|
|
def test_get_history(mock_llm_client, temp_workspace):
    """Ensure get_history returns a separate list that can be mutated without affecting the agent."""
    agent = Agent(
        llm_client=mock_llm_client,
        system_prompt="System",
        tools=[],
        workspace_dir=temp_workspace,
    )
    agent.add_user_message("Test message")

    snapshot = agent.get_history()

    # Same number of entries, but not the same list object.
    assert len(snapshot) == len(agent.messages)
    assert snapshot is not agent.messages

    # Appending to the returned list must not change the agent's own history
    # (system message + one user message).
    snapshot.append(Message(role="user", content="New message"))
    assert len(agent.messages) == 2
    assert len(snapshot) == 3
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_note_persistence(temp_workspace):
    """Check that a note recorded by SessionNoteTool is visible to a separate RecallNoteTool.

    Both tools are pointed at the same ``memory.json`` path inside the
    temporary workspace.
    """
    notes_path = str(Path(temp_workspace) / "memory.json")

    # Record a note through one tool instance...
    writer = SessionNoteTool(memory_file=notes_path)
    write_result = await writer.execute(content="Test note", category="test")
    assert write_result.success

    # ...then read it back through an independent instance.
    reader = RecallNoteTool(memory_file=notes_path)
    read_result = await reader.execute()
    assert read_result.success
    assert "Test note" in read_result.content
|
|
|
|
|
|
|
|
def test_message_statistics(mock_llm_client, temp_workspace):
    """Count messages by role after building a mixed user/assistant/tool history."""
    agent = Agent(
        llm_client=mock_llm_client,
        system_prompt="System",
        tools=[],
        workspace_dir=temp_workspace,
    )

    # Interleave user turns with a manually appended assistant reply and tool result.
    agent.add_user_message("User message 1")
    agent.messages.append(Message(role="assistant", content="Assistant response 1"))
    agent.add_user_message("User message 2")
    tool_result = Message(
        role="tool", content="Tool result", tool_call_id="123", name="test_tool"
    )
    agent.messages.append(tool_result)

    roles = [msg.role for msg in agent.messages]
    assert roles.count("user") == 2
    assert roles.count("assistant") == 1
    assert roles.count("tool") == 1
    # The total also includes the initial system message.
    assert len(agent.messages) == 5
|
|
|