# chatassistant_retail / tests/unit/test_context_utils.py
# github-actions[bot]
# Sync from https://github.com/samir72/chatassistant_retail
# 8b30412
"""Unit tests for context utilities module."""
from __future__ import annotations
import time
from typing import Any
from chatassistant_retail.state.langgraph_manager import ConversationState
from chatassistant_retail.tools.context_utils import (
get_products_from_context,
get_sales_from_context,
update_products_cache,
update_sales_cache,
)
# Sample test data: three products across two categories. The Electronics
# items have stock levels 5 and 15 and the Furniture item has stock 3, so the
# full list mixes both categories and stock levels on either side of 10.
_PRODUCT_FIELDS = ("sku", "name", "category", "price", "current_stock", "reorder_level")
SAMPLE_PRODUCTS = [
    dict(zip(_PRODUCT_FIELDS, row))
    for row in (
        ("SKU-10000", "Laptop Pro", "Electronics", 1299.99, 5, 10),
        ("SKU-10001", "Wireless Mouse", "Electronics", 29.99, 15, 20),
        ("SKU-20000", "Office Chair", "Furniture", 249.99, 3, 5),
    )
]
# Two sales records, both for SKU-10000, on consecutive days.
_SALE_FIELDS = ("sale_id", "sku", "quantity", "sale_price", "timestamp")
SAMPLE_SALES = [
    dict(zip(_SALE_FIELDS, row))
    for row in (
        ("SALE-001", "SKU-10000", 2, 1299.99, "2025-01-01T10:00:00"),
        ("SALE-002", "SKU-10000", 1, 1299.99, "2025-01-02T14:30:00"),
    )
]
def create_state_with_context(context: dict[str, Any]) -> ConversationState:
    """Build a tool-intent ConversationState that carries ``context``.

    The state has no messages or tool calls and uses the fixed session id
    ``"test-session"``. It is flagged as needing a tool but not RAG.

    Args:
        context: Mapping handed unchanged to the ``context`` constructor argument.

    Returns:
        A newly constructed ``ConversationState``.
    """
    state_kwargs: dict[str, Any] = {
        "messages": [],
        "context": context,
        "tool_calls": [],
        "session_id": "test-session",
        "current_intent": "tool",
        "needs_rag": False,
        "needs_tool": True,
    }
    return ConversationState(**state_kwargs)
class TestGetProductsFromContext:
    """Tests for get_products_from_context function.

    When a filter is expected to be served from the cache, the tests check the
    result size as well as a per-item predicate. ``all(...)`` is vacuously true
    for an empty list, so a predicate-only check would let a regression that
    returns ``[]`` pass unnoticed.
    """

    def test_get_products_from_rag_data(self):
        """Test retrieving products from RAG-populated context."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS[:2]})
        result = get_products_from_context(state)
        assert result is not None
        assert len(result) == 2
        assert result[0]["sku"] == "SKU-10000"

    def test_get_products_from_cache(self):
        """Test retrieving products from structured products_cache."""
        cache = {
            "data": SAMPLE_PRODUCTS[:2],
            "source": "tool",
            "timestamp": time.time(),
            "filter_applied": {},
        }
        state = create_state_with_context({"products_cache": cache})
        result = get_products_from_context(state)
        assert result is not None
        assert len(result) == 2
        # Pin the actual cached products, not just how many there are.
        assert result == SAMPLE_PRODUCTS[:2]

    def test_get_products_with_sku_filter_exact_match(self):
        """Test retrieving single product by SKU."""
        state = create_state_with_context({"products": [SAMPLE_PRODUCTS[0]]})
        result = get_products_from_context(state, sku="SKU-10000")
        assert result is not None
        assert len(result) == 1
        assert result[0]["sku"] == "SKU-10000"

    def test_get_products_with_sku_filter_in_multiple(self):
        """Test retrieving product by SKU from multi-product cache."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})
        result = get_products_from_context(state, sku="SKU-20000")
        assert result is not None
        assert any(p["sku"] == "SKU-20000" for p in result)

    def test_get_products_with_sku_filter_not_found(self):
        """Test SKU filter returns None when SKU not in cache."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS[:2]})
        result = get_products_from_context(state, sku="SKU-99999")
        assert result is None

    def test_get_products_with_category_filter_match(self):
        """Test retrieving products by category."""
        electronics = [p for p in SAMPLE_PRODUCTS if p["category"] == "Electronics"]
        state = create_state_with_context({"products": electronics})
        result = get_products_from_context(state, category="Electronics")
        assert result is not None
        # Every cached item matches the category, so all of them must come
        # back; this also rules out a vacuous pass of all() on [].
        assert len(result) == len(electronics)
        assert all(p["category"] == "Electronics" for p in result)

    def test_get_products_with_category_filter_mixed_cache(self):
        """Test category filter returns None when cache has mixed categories."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})  # Mixed categories
        result = get_products_from_context(state, category="Electronics")
        assert result is None  # Can't reliably use mixed cache

    def test_get_products_with_low_stock_filter_match(self):
        """Test retrieving low stock products."""
        low_stock_products = [p for p in SAMPLE_PRODUCTS if p["current_stock"] <= 10]
        state = create_state_with_context({"products": low_stock_products})
        result = get_products_from_context(state, low_stock=True, threshold=10)
        assert result is not None
        # Every cached item is at or below the threshold, so none may be
        # dropped; guards against all() passing vacuously on [].
        assert len(result) == len(low_stock_products)
        assert all(p["current_stock"] <= 10 for p in result)

    def test_get_products_with_low_stock_filter_mixed_cache(self):
        """Test low stock filter returns None when cache has mixed stock levels."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})  # Mixed stock
        result = get_products_from_context(state, low_stock=True, threshold=10)
        assert result is None  # Can't use cache with products above threshold

    def test_get_products_with_none_state(self):
        """Test function handles None state gracefully."""
        result = get_products_from_context(None)
        assert result is None

    def test_get_products_with_empty_context(self):
        """Test function handles empty context gracefully."""
        state = create_state_with_context({})
        result = get_products_from_context(state)
        assert result is None

    def test_get_products_large_cache_rejected(self):
        """Test that very large caches (>50 products) are not used for unfiltered queries."""
        large_cache = [
            {"sku": f"SKU-{i}", "name": f"Product {i}", "category": "Test", "current_stock": 10}
            for i in range(100)
        ]
        state = create_state_with_context({"products": large_cache})
        result = get_products_from_context(state)  # No filter
        assert result is None  # Large cache should be rejected

    def test_get_products_cache_priority_over_rag(self):
        """Test that products_cache takes priority over RAG products."""
        cache_products = [SAMPLE_PRODUCTS[0]]
        rag_products = [SAMPLE_PRODUCTS[1]]
        state = create_state_with_context(
            {
                "products": rag_products,
                "products_cache": {
                    "data": cache_products,
                    "source": "tool",
                    "timestamp": time.time(),
                },
            }
        )
        result = get_products_from_context(state)
        assert result is not None
        assert result[0]["sku"] == "SKU-10000"  # From cache, not RAG
class TestGetSalesFromContext:
    """Tests for get_sales_from_context: when a cached ``sales_cache`` entry
    may be reused and when the function must return None instead."""

    @staticmethod
    def _state_with_sales_cache(sku_filter, data=None):
        """Return a state whose context holds only a fresh sales_cache entry.

        ``data`` defaults to SAMPLE_SALES when omitted (None).
        """
        sales_cache = {
            "data": SAMPLE_SALES if data is None else data,
            "timestamp": time.time(),
            "sku_filter": sku_filter,
        }
        return create_state_with_context({"sales_cache": sales_cache})

    def test_get_sales_without_filter(self):
        """An unfiltered cache satisfies an unfiltered request."""
        state = self._state_with_sales_cache(None)
        sales = get_sales_from_context(state)
        assert sales is not None
        assert len(sales) == 2

    def test_get_sales_with_matching_sku_filter(self):
        """A cache filtered on the requested SKU is reused."""
        state = self._state_with_sales_cache("SKU-10000")
        sales = get_sales_from_context(state, sku="SKU-10000")
        assert sales is not None
        assert len(sales) == 2

    def test_get_sales_with_mismatched_sku_filter(self):
        """A cache filtered on a different SKU yields None."""
        state = self._state_with_sales_cache("SKU-10000")
        assert get_sales_from_context(state, sku="SKU-99999") is None

    def test_get_sales_requesting_all_with_filtered_cache(self):
        """Asking for all sales cannot be served from an SKU-filtered cache."""
        state = self._state_with_sales_cache("SKU-10000")
        # The cache only covers one SKU, so it cannot stand in for "all sales".
        assert get_sales_from_context(state, sku=None) is None

    def test_get_sales_with_none_state(self):
        """A None state yields None."""
        assert get_sales_from_context(None) is None

    def test_get_sales_with_empty_context(self):
        """An empty context yields None."""
        assert get_sales_from_context(create_state_with_context({})) is None

    def test_get_sales_with_no_cache(self):
        """A context with products but no sales_cache yields None."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})
        assert get_sales_from_context(state) is None

    def test_get_sales_with_empty_data(self):
        """A sales_cache whose data list is empty yields None."""
        state = self._state_with_sales_cache(None, data=[])
        assert get_sales_from_context(state) is None
class TestUpdateProductsCache:
    """Tests for update_products_cache: the ``products_cache`` entry it writes
    into ``state.context`` and how it replaces an existing entry."""

    def test_update_products_cache_basic(self):
        """A fresh entry records data, source, a timestamp and an empty filter."""
        conv = create_state_with_context({})

        update_products_cache(conv, SAMPLE_PRODUCTS[:2], source="tool")

        assert "products_cache" in conv.context
        entry = conv.context["products_cache"]
        # Data and source compare equal to what was passed; no filter by default.
        assert entry["data"] == SAMPLE_PRODUCTS[:2]
        assert entry["source"] == "tool"
        assert "timestamp" in entry
        assert entry["filter_applied"] == {}

    def test_update_products_cache_with_filter(self):
        """Filter metadata passed in is recorded on the entry."""
        conv = create_state_with_context({})
        sku_filter = {"sku": "SKU-10000"}

        update_products_cache(
            conv, [SAMPLE_PRODUCTS[0]], source="rag", filter_applied=sku_filter
        )

        entry = conv.context["products_cache"]
        assert entry["filter_applied"] == sku_filter
        assert entry["source"] == "rag"

    def test_update_products_cache_overwrites_existing(self):
        """A second update replaces the previous entry."""
        stale_entry = {
            "data": [SAMPLE_PRODUCTS[0]],
            "source": "old",
            "timestamp": time.time() - 100,
        }
        conv = create_state_with_context({"products_cache": stale_entry})

        update_products_cache(conv, SAMPLE_PRODUCTS[:2], source="new")

        entry = conv.context["products_cache"]
        assert len(entry["data"]) == 2
        assert entry["source"] == "new"

    def test_update_products_cache_initializes_context(self):
        """Updating works on a ConversationState built without a context."""
        conv = ConversationState(messages=[], tool_calls=[], session_id="test")
        # No context argument is given, so the state starts with an empty dict.
        assert conv.context == {}

        update_products_cache(conv, SAMPLE_PRODUCTS[:1])

        assert "products_cache" in conv.context
class TestUpdateSalesCache:
    """Tests for update_sales_cache: the ``sales_cache`` entry it writes into
    ``state.context`` and how it replaces an existing entry."""

    def test_update_sales_cache_basic(self):
        """A fresh entry records data, a timestamp and no SKU filter."""
        conv = create_state_with_context({})

        update_sales_cache(conv, SAMPLE_SALES)

        assert "sales_cache" in conv.context
        entry = conv.context["sales_cache"]
        assert entry["data"] == SAMPLE_SALES
        assert "timestamp" in entry
        assert entry["sku_filter"] is None

    def test_update_sales_cache_with_sku_filter(self):
        """The SKU filter passed in is recorded on the entry."""
        conv = create_state_with_context({})

        update_sales_cache(conv, SAMPLE_SALES, sku_filter="SKU-10000")

        assert conv.context["sales_cache"]["sku_filter"] == "SKU-10000"

    def test_update_sales_cache_overwrites_existing(self):
        """A second update replaces the previous entry."""
        stale_entry = {
            "data": [],
            "timestamp": time.time() - 100,
            "sku_filter": "OLD",
        }
        conv = create_state_with_context({"sales_cache": stale_entry})

        update_sales_cache(conv, SAMPLE_SALES, sku_filter="NEW")

        entry = conv.context["sales_cache"]
        assert len(entry["data"]) == 2
        assert entry["sku_filter"] == "NEW"

    def test_update_sales_cache_initializes_context(self):
        """Updating works on a ConversationState built without a context."""
        conv = ConversationState(messages=[], tool_calls=[], session_id="test")
        # No context argument is given, so the state starts with an empty dict.
        assert conv.context == {}

        update_sales_cache(conv, SAMPLE_SALES)

        assert "sales_cache" in conv.context
class TestIntegratedScenarios:
    """Integration tests for realistic usage scenarios.

    Each scenario writes to the context with the update_* helpers and reads it
    back with the get_* helpers. The scenarios check what comes back, not only
    that something non-None came back.
    """

    def test_rag_then_tool_workflow(self):
        """Test workflow: RAG retrieves products, then tool uses cached data."""
        state = create_state_with_context({})
        # Step 1: RAG retrieves products and stores them in the cache
        update_products_cache(state, SAMPLE_PRODUCTS[:2], source="rag")
        # Step 2: Tool tries to get products
        result = get_products_from_context(state)
        assert result is not None
        assert len(result) == 2
        # The tool must see exactly what RAG cached, not merely the same count.
        assert result == SAMPLE_PRODUCTS[:2]

    def test_tool_loads_then_reuses(self):
        """Test workflow: Tool loads data, caches it, then reuses on next call."""
        state = create_state_with_context({})
        # First tool call: loads fresh data
        products_from_load = SAMPLE_PRODUCTS.copy()
        update_products_cache(state, products_from_load, source="tool")
        # Second tool call: reuses cached data
        result = get_products_from_context(state)
        assert result is not None
        assert result == products_from_load

    def test_multiple_tools_share_cache(self):
        """Test workflow: Multiple tools access same cached data."""
        state = create_state_with_context({})
        # Tool 1 caches products and sales
        update_products_cache(state, SAMPLE_PRODUCTS, source="tool")
        update_sales_cache(state, SAMPLE_SALES, sku_filter="SKU-10000")
        # Tool 2 retrieves products
        products = get_products_from_context(state, sku="SKU-10000")
        assert products is not None
        # A non-None result must actually contain the requested SKU.
        assert any(p["sku"] == "SKU-10000" for p in products)
        # Tool 3 retrieves sales
        sales = get_sales_from_context(state, sku="SKU-10000")
        assert sales is not None
        assert len(sales) == 2