Spaces:
Sleeping
Sleeping
| """Unit tests for context utilities module.""" | |
| from __future__ import annotations | |
| import time | |
| from typing import Any | |
| from chatassistant_retail.state.langgraph_manager import ConversationState | |
| from chatassistant_retail.tools.context_utils import ( | |
| get_products_from_context, | |
| get_sales_from_context, | |
| update_products_cache, | |
| update_sales_cache, | |
| ) | |
# Sample test data: a small catalogue spanning two categories (Electronics and
# Furniture). Two items have current_stock <= 10 and one (Wireless Mouse) does
# not, which the mixed-category and mixed-stock tests rely on.
SAMPLE_PRODUCTS = [
    dict(
        sku="SKU-10000",
        name="Laptop Pro",
        category="Electronics",
        price=1299.99,
        current_stock=5,
        reorder_level=10,
    ),
    dict(
        sku="SKU-10001",
        name="Wireless Mouse",
        category="Electronics",
        price=29.99,
        current_stock=15,
        reorder_level=20,
    ),
    dict(
        sku="SKU-20000",
        name="Office Chair",
        category="Furniture",
        price=249.99,
        current_stock=3,
        reorder_level=5,
    ),
]
# Two sales records for the same SKU (SKU-10000), so a sales cache filtered to
# that SKU legitimately contains both of them.
SAMPLE_SALES = [
    dict(
        sale_id="SALE-001",
        sku="SKU-10000",
        quantity=2,
        sale_price=1299.99,
        timestamp="2025-01-01T10:00:00",
    ),
    dict(
        sale_id="SALE-002",
        sku="SKU-10000",
        quantity=1,
        sale_price=1299.99,
        timestamp="2025-01-02T14:30:00",
    ),
]
def create_state_with_context(context: dict[str, Any]) -> ConversationState:
    """Return a fresh tool-routed ConversationState holding *context*.

    The state has no messages or tool calls, session id ``"test-session"``,
    intent ``"tool"``, and routing flags set so a tool (not RAG) is needed.
    """
    state_fields: dict[str, Any] = {
        "messages": [],
        "context": context,
        "tool_calls": [],
        "session_id": "test-session",
        "current_intent": "tool",
        "needs_rag": False,
        "needs_tool": True,
    }
    return ConversationState(**state_fields)
class TestGetProductsFromContext:
    """Tests for get_products_from_context function.

    Covers both context sources exercised here: RAG-populated ``products`` and
    the structured ``products_cache``. Also covers the SKU, category and
    low-stock filters, and the cases where cached data must be rejected
    (``None`` returned).
    """

    def test_get_products_from_rag_data(self):
        """Test retrieving products from RAG-populated context."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS[:2]})
        result = get_products_from_context(state)
        assert result is not None
        assert len(result) == 2
        assert result[0]["sku"] == "SKU-10000"

    def test_get_products_from_cache(self):
        """Test retrieving products from structured products_cache."""
        cache = {
            "data": SAMPLE_PRODUCTS[:2],
            "source": "tool",
            "timestamp": time.time(),
            "filter_applied": {},
        }
        state = create_state_with_context({"products_cache": cache})
        result = get_products_from_context(state)
        assert result is not None
        assert len(result) == 2
        # Check which products came back, not only how many: a bare length
        # check would still pass if two wrong records were returned.
        assert sorted(p["sku"] for p in result) == ["SKU-10000", "SKU-10001"]

    def test_get_products_with_sku_filter_exact_match(self):
        """Test retrieving single product by SKU."""
        state = create_state_with_context({"products": [SAMPLE_PRODUCTS[0]]})
        result = get_products_from_context(state, sku="SKU-10000")
        assert result is not None
        assert len(result) == 1
        assert result[0]["sku"] == "SKU-10000"

    def test_get_products_with_sku_filter_in_multiple(self):
        """Test retrieving product by SKU from multi-product cache."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})
        result = get_products_from_context(state, sku="SKU-20000")
        assert result is not None
        assert any(p["sku"] == "SKU-20000" for p in result)

    def test_get_products_with_sku_filter_not_found(self):
        """Test SKU filter returns None when SKU not in cache."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS[:2]})
        result = get_products_from_context(state, sku="SKU-99999")
        assert result is None

    def test_get_products_with_category_filter_match(self):
        """Test retrieving products by category."""
        electronics = [p for p in SAMPLE_PRODUCTS if p["category"] == "Electronics"]
        state = create_state_with_context({"products": electronics})
        result = get_products_from_context(state, category="Electronics")
        # Require a non-empty result: all() over an empty list is vacuously
        # True, so returning [] would otherwise pass this test.
        assert result
        assert all(p["category"] == "Electronics" for p in result)

    def test_get_products_with_category_filter_mixed_cache(self):
        """Test category filter returns None when cache has mixed categories."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})  # Mixed categories
        result = get_products_from_context(state, category="Electronics")
        assert result is None  # Can't reliably use mixed cache

    def test_get_products_with_low_stock_filter_match(self):
        """Test retrieving low stock products."""
        low_stock_products = [p for p in SAMPLE_PRODUCTS if p["current_stock"] <= 10]
        state = create_state_with_context({"products": low_stock_products})
        result = get_products_from_context(state, low_stock=True, threshold=10)
        # Non-empty check guards against the vacuous all() pass on [].
        assert result
        assert all(p["current_stock"] <= 10 for p in result)

    def test_get_products_with_low_stock_filter_mixed_cache(self):
        """Test low stock filter returns None when cache has mixed stock levels."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})  # Mixed stock
        result = get_products_from_context(state, low_stock=True, threshold=10)
        assert result is None  # Can't use cache with products above threshold

    def test_get_products_with_none_state(self):
        """Test function handles None state gracefully."""
        result = get_products_from_context(None)
        assert result is None

    def test_get_products_with_empty_context(self):
        """Test function handles empty context gracefully."""
        state = create_state_with_context({})
        result = get_products_from_context(state)
        assert result is None

    def test_get_products_large_cache_rejected(self):
        """Test that very large caches (>50 products) are not used for unfiltered queries."""
        large_cache = [
            {"sku": f"SKU-{i}", "name": f"Product {i}", "category": "Test", "current_stock": 10} for i in range(100)
        ]
        state = create_state_with_context({"products": large_cache})
        result = get_products_from_context(state)  # No filter
        assert result is None  # Large cache should be rejected

    def test_get_products_cache_priority_over_rag(self):
        """Test that products_cache takes priority over RAG products."""
        cache_products = [SAMPLE_PRODUCTS[0]]
        rag_products = [SAMPLE_PRODUCTS[1]]
        state = create_state_with_context(
            {
                "products": rag_products,
                "products_cache": {
                    "data": cache_products,
                    "source": "tool",
                    "timestamp": time.time(),
                },
            }
        )
        result = get_products_from_context(state)
        assert result is not None
        assert result[0]["sku"] == "SKU-10000"  # From cache, not RAG
class TestGetSalesFromContext:
    """Behavioural tests for ``get_sales_from_context``.

    Most cases place a ``sales_cache`` entry (data, timestamp, sku_filter) in
    the context and check whether the cached sales may be reused for a
    request. The rest check the missing-state and missing-cache paths.
    """

    @staticmethod
    def _state_with_sales(data: list[dict[str, Any]], sku_filter: str | None = None):
        """Return a state whose context holds only a freshly stamped sales_cache."""
        sales_cache = {"data": data, "timestamp": time.time(), "sku_filter": sku_filter}
        return create_state_with_context({"sales_cache": sales_cache})

    def test_get_sales_without_filter(self):
        """An unfiltered cache is returned whole when no SKU is requested."""
        sales = get_sales_from_context(self._state_with_sales(SAMPLE_SALES))
        assert sales is not None
        assert len(sales) == 2

    def test_get_sales_with_matching_sku_filter(self):
        """A SKU-filtered cache is reused for a request naming the same SKU."""
        state = self._state_with_sales(SAMPLE_SALES, sku_filter="SKU-10000")
        sales = get_sales_from_context(state, sku="SKU-10000")
        assert sales is not None and len(sales) == 2

    def test_get_sales_with_mismatched_sku_filter(self):
        """A cache filtered to one SKU is not reused for a different SKU."""
        state = self._state_with_sales(SAMPLE_SALES, sku_filter="SKU-10000")
        assert get_sales_from_context(state, sku="SKU-99999") is None

    def test_get_sales_requesting_all_with_filtered_cache(self):
        """A SKU-filtered cache cannot answer a request for all sales."""
        state = self._state_with_sales(SAMPLE_SALES, sku_filter="SKU-10000")
        # sku=None is passed explicitly to model an "all sales" request.
        assert get_sales_from_context(state, sku=None) is None

    def test_get_sales_with_none_state(self):
        """A missing state yields None instead of raising."""
        assert get_sales_from_context(None) is None

    def test_get_sales_with_empty_context(self):
        """An empty context yields None."""
        assert get_sales_from_context(create_state_with_context({})) is None

    def test_get_sales_with_no_cache(self):
        """Product data without any sales_cache yields None."""
        state = create_state_with_context({"products": SAMPLE_PRODUCTS})
        assert get_sales_from_context(state) is None

    def test_get_sales_with_empty_data(self):
        """A sales_cache holding no records is treated as a miss."""
        assert get_sales_from_context(self._state_with_sales([])) is None
class TestUpdateProductsCache:
    """Behavioural tests for ``update_products_cache``."""

    def test_update_products_cache_basic(self):
        """A write records the data, the source, a timestamp and an empty filter."""
        state = create_state_with_context({})
        update_products_cache(state, SAMPLE_PRODUCTS[:2], source="tool")
        assert "products_cache" in state.context
        written = state.context["products_cache"]
        assert "timestamp" in written
        assert {key: written[key] for key in ("data", "source", "filter_applied")} == {
            "data": SAMPLE_PRODUCTS[:2],
            "source": "tool",
            "filter_applied": {},
        }

    def test_update_products_cache_with_filter(self):
        """Caller-supplied filter metadata and source are stored on the cache."""
        state = create_state_with_context({})
        sku_filter = {"sku": "SKU-10000"}
        update_products_cache(state, [SAMPLE_PRODUCTS[0]], source="rag", filter_applied=sku_filter)
        written = state.context["products_cache"]
        assert (written["filter_applied"], written["source"]) == (sku_filter, "rag")

    def test_update_products_cache_overwrites_existing(self):
        """A new write overwrites the data and source of an existing cache entry."""
        stale = {
            "data": [SAMPLE_PRODUCTS[0]],
            "source": "old",
            "timestamp": time.time() - 100,
        }
        state = create_state_with_context({"products_cache": stale})
        update_products_cache(state, SAMPLE_PRODUCTS[:2], source="new")
        written = state.context["products_cache"]
        assert len(written["data"]) == 2
        assert written["source"] == "new"

    def test_update_products_cache_initializes_context(self):
        """A state built without an explicit context can still be written to."""
        fresh = ConversationState(messages=[], tool_calls=[], session_id="test")
        # The default context of a ConversationState is an empty dict.
        assert fresh.context == {}
        update_products_cache(fresh, SAMPLE_PRODUCTS[:1])
        assert "products_cache" in fresh.context
class TestUpdateSalesCache:
    """Behavioural tests for ``update_sales_cache``."""

    def test_update_sales_cache_basic(self):
        """A write without a SKU filter stores the data, a timestamp and sku_filter=None."""
        state = create_state_with_context({})
        update_sales_cache(state, SAMPLE_SALES)
        assert "sales_cache" in state.context
        written = state.context["sales_cache"]
        assert written["data"] == SAMPLE_SALES
        assert "timestamp" in written
        assert written["sku_filter"] is None

    def test_update_sales_cache_with_sku_filter(self):
        """The caller's SKU filter is recorded on the cache entry."""
        state = create_state_with_context({})
        update_sales_cache(state, SAMPLE_SALES, sku_filter="SKU-10000")
        assert state.context["sales_cache"]["sku_filter"] == "SKU-10000"

    def test_update_sales_cache_overwrites_existing(self):
        """A new write overwrites the data and SKU filter of an existing entry."""
        stale = {"data": [], "timestamp": time.time() - 100, "sku_filter": "OLD"}
        state = create_state_with_context({"sales_cache": stale})
        update_sales_cache(state, SAMPLE_SALES, sku_filter="NEW")
        written = state.context["sales_cache"]
        assert (len(written["data"]), written["sku_filter"]) == (2, "NEW")

    def test_update_sales_cache_initializes_context(self):
        """A state built without an explicit context can still be written to."""
        fresh = ConversationState(messages=[], tool_calls=[], session_id="test")
        # The default context of a ConversationState is an empty dict.
        assert fresh.context == {}
        update_sales_cache(fresh, SAMPLE_SALES)
        assert "sales_cache" in fresh.context
class TestIntegratedScenarios:
    """Integration tests for realistic usage scenarios.

    Each scenario writes to the context with the update_* helpers and then
    reads it back with the get_* helpers, as successive tool calls would.
    """

    def test_rag_then_tool_workflow(self):
        """Test workflow: RAG retrieves products, then tool uses cached data."""
        # Step 1: RAG retrieves products
        state = create_state_with_context({})
        update_products_cache(state, SAMPLE_PRODUCTS[:2], source="rag")
        # Step 2: Tool tries to get products
        result = get_products_from_context(state)
        assert result is not None
        assert len(result) == 2
        # The tool must see exactly what RAG cached, not just two arbitrary items.
        assert result == SAMPLE_PRODUCTS[:2]

    def test_tool_loads_then_reuses(self):
        """Test workflow: Tool loads data, caches it, then reuses on next call."""
        state = create_state_with_context({})
        # First tool call: loads fresh data
        products_from_load = SAMPLE_PRODUCTS.copy()
        update_products_cache(state, products_from_load, source="tool")
        # Second tool call: reuses cached data
        result = get_products_from_context(state)
        assert result is not None
        assert result == products_from_load

    def test_multiple_tools_share_cache(self):
        """Test workflow: Multiple tools access same cached data."""
        state = create_state_with_context({})
        # Tool 1 caches products and sales. Copies are passed (as in
        # test_tool_loads_then_reuses) so the cached lists are not the
        # module-level SAMPLE_* list objects shared with other tests.
        update_products_cache(state, list(SAMPLE_PRODUCTS), source="tool")
        update_sales_cache(state, list(SAMPLE_SALES), sku_filter="SKU-10000")
        # Tool 2 retrieves products for one SKU; the requested SKU must be in
        # the result, not merely some non-None value.
        products = get_products_from_context(state, sku="SKU-10000")
        assert products is not None
        assert any(p["sku"] == "SKU-10000" for p in products)
        # Tool 3 retrieves sales
        sales = get_sales_from_context(state, sku="SKU-10000")
        assert sales is not None
        assert len(sales) == 2