# chatassistant_retail / tests / integration / test_state_manager.py
# github-actions[bot]
# Sync from https://github.com/samir72/chatassistant_retail
# 8b30412
"""Integration tests for Langgraph state manager."""
import pytest
from langchain_core.messages import HumanMessage
from chatassistant_retail.state import ConversationState, LanggraphManager
class MockLLMClient:
    """Mock LLM client for testing.

    Responses are plain dictionaries shaped like a chat-completion payload:
    ``{"choices": [{"message": {"content": ..., "role": ..., "tool_calls": ...}}]}``.
    """

    async def call_llm(self, messages, tools=None):
        """Mock LLM call - returns dictionary format.

        Args:
            messages: Ignored; accepted only to match the real client's signature.
            tools: Ignored; accepted only to match the real client's signature.

        Returns:
            A canned response with fixed assistant text and no tool calls.
        """
        return {
            "choices": [
                {
                    "message": {
                        "content": "This is a test response.",
                        "role": "assistant",
                        "tool_calls": None,
                    }
                }
            ]
        }

    @staticmethod
    def _first_message(response):
        """Return the first choice's ``message`` dict, or ``{}`` when absent.

        Tolerates non-dict responses, missing/empty/None ``choices`` and a
        ``None`` message, so both extractors share one defensive lookup.
        """
        if not isinstance(response, dict):
            return {}
        choices = response.get("choices") or []
        if not choices:
            return {}
        return choices[0].get("message") or {}

    async def extract_response_content(self, response):
        """Extract response content from dictionary.

        Returns:
            The first message's text, or ``""`` when there is no message or its
            ``content`` is missing or ``None`` (never ``None``).
        """
        return self._first_message(response).get("content") or ""

    async def extract_tool_calls(self, response):
        """Extract tool calls from dictionary.

        Returns:
            A list of ``{"name": ..., "arguments": ...}`` dicts, one per entry in
            the first message's ``tool_calls``; ``[]`` when there are none.
            ``arguments`` is passed through unchanged (not JSON-decoded).
        """
        tool_calls = self._first_message(response).get("tool_calls") or []
        # A tool call whose "function" is missing or None yields empty defaults.
        functions = (tc.get("function") or {} for tc in tool_calls)
        return [
            {
                "name": fn.get("name", ""),
                "arguments": fn.get("arguments", {}),
            }
            for fn in functions
        ]
class MockRAGRetriever:
    """Mock RAG retriever for testing.

    Always yields the same single product record, regardless of the query.
    """

    async def retrieve(self, query, top_k=5):
        """Return a one-element list holding a canned product.

        Args:
            query: Ignored by the mock.
            top_k: Ignored by the mock; kept for signature compatibility.
        """
        product = dict(
            sku="SKU-10000",
            name="Test Product",
            category="Electronics",
            price=99.99,
            current_stock=5,
            reorder_level=10,
        )
        return [product]
class MockToolExecutor:
    """Mock tool executor for testing.

    Reports success for every call and echoes what it was asked to run.
    """

    async def execute_tool(self, tool_name, args):
        """Pretend to run ``tool_name`` with ``args`` and report success."""
        summary = f"Executed {tool_name} with args {args}"
        return dict(success=True, message=summary)
class TestLanggraphManager:
    """Test Langgraph state management.

    Each test builds a fresh ``LanggraphManager`` wired to the mock
    collaborators defined in this module.
    """

    @staticmethod
    def _make_manager(llm_client=None):
        """Build a manager from the module's mock collaborators.

        Args:
            llm_client: Optional LLM client override (e.g. a failing client);
                defaults to a fresh ``MockLLMClient``.
        """
        if llm_client is None:
            llm_client = MockLLMClient()
        return LanggraphManager(llm_client, MockRAGRetriever(), MockToolExecutor())

    @staticmethod
    def _make_state(text, **extra):
        """Build a conversation state holding a single user message ``text``.

        Extra keyword arguments (e.g. ``context``) are forwarded to
        ``ConversationState``.
        """
        return ConversationState(
            session_id="test-session",
            messages=[HumanMessage(content=text)],
            **extra,
        )

    @pytest.mark.asyncio
    async def test_greeting_classification(self):
        """Test that greetings are classified correctly."""
        manager = self._make_manager()
        state = self._make_state("Hello")
        # Call the classification node directly to test it in isolation.
        state = await manager._classify_intent_node(state)
        assert state.current_intent == "greeting"

    @pytest.mark.asyncio
    async def test_rag_classification(self):
        """Test that product queries are classified as RAG."""
        manager = self._make_manager()
        state = self._make_state("Find me a wireless mouse")
        state = await manager._classify_intent_node(state)
        assert state.current_intent == "rag"
        assert state.needs_rag is True

    @pytest.mark.asyncio
    async def test_tool_classification(self):
        """Test that tool-related queries are classified correctly."""
        manager = self._make_manager()
        state = self._make_state("Check low stock items")
        state = await manager._classify_intent_node(state)
        assert state.current_intent == "tool"
        assert state.needs_tool is True

    @pytest.mark.asyncio
    async def test_rag_retrieval_node(self):
        """Test RAG retrieval node."""
        manager = self._make_manager()
        state = self._make_state("Find wireless mouse")
        state = await manager._rag_retrieval_node(state)
        # MockRAGRetriever always returns exactly one "Test Product" record.
        assert "products" in state.context
        assert len(state.context["products"]) > 0
        assert state.context["products"][0]["name"] == "Test Product"

    @pytest.mark.asyncio
    async def test_generate_response_node(self):
        """Test response generation node."""
        manager = self._make_manager()
        state = self._make_state("Hello", context={"products": []})
        state = await manager._generate_response_node(state)
        # Should have added an AI message carrying MockLLMClient's canned text.
        assert len(state.messages) == 2
        assert state.messages[1].content == "This is a test response."

    @pytest.mark.asyncio
    async def test_full_workflow_greeting(self):
        """Test full workflow for greeting."""
        manager = self._make_manager()
        final_state = await manager.process(self._make_state("Hi there"))
        # Should have exactly one response appended and no recorded error.
        assert len(final_state.messages) == 2
        assert final_state.current_intent == "greeting"
        assert final_state.error is None

    @pytest.mark.asyncio
    async def test_full_workflow_rag(self):
        """Test full workflow for RAG query."""
        manager = self._make_manager()
        final_state = await manager.process(self._make_state("Find electronics products"))
        # Should have retrieved products and generated a response without error,
        # mirroring the checks in test_full_workflow_greeting.
        assert "products" in final_state.context
        assert len(final_state.messages) == 2
        assert final_state.current_intent == "rag"
        assert final_state.error is None

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """Test error handling in workflow."""

        class FailingLLMClient:
            """LLM client whose every call raises, to exercise error handling."""

            async def call_llm(self, messages, tools=None):
                raise Exception("LLM error")

        manager = self._make_manager(FailingLLMClient())
        # process() should handle the failure gracefully and record it on the
        # returned state rather than propagating the exception.
        final_state = await manager.process(self._make_state("Hello"))
        assert final_state.error is not None
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails
    # (non-zero exit) when any test fails; pytest.main() only returns it.
    raise SystemExit(pytest.main([__file__, "-v"]))