# Provenance: uploaded by Ranjit0034 as tests/test_rag.py with huggingface_hub
# (commit f28b6fd, verified).
"""
Tests for FinEE RAG Engine
==========================
Comprehensive tests for:
- Merchant Knowledge Base
- Category Taxonomy
- Vector Store
- RAG Engine
Author: Ranjit Behera
"""
import pytest
import json
from pathlib import Path
import tempfile
class TestMerchantKnowledgeBase:
    """Checks for MerchantKnowledgeBase: defaults, lookup, search, add, persistence."""

    def setup_method(self):
        """Create a new knowledge base and keep a handle on the Merchant type."""
        from finee.rag import Merchant, MerchantKnowledgeBase

        self.kb = MerchantKnowledgeBase()
        self.Merchant = Merchant

    def test_default_merchants_loaded(self):
        """A newly constructed KB holds more than 30 merchants, including key ones."""
        known = self.kb.merchants
        assert len(known) > 30
        for key in ("swiggy", "amazon", "zerodha"):
            assert key in known

    def test_lookup_by_name(self):
        """Looking up a merchant by display name returns its record."""
        found = self.kb.lookup("Swiggy")
        assert found is not None
        assert (found.name, found.category) == ("Swiggy", "food")

    def test_lookup_by_vpa(self):
        """Looking up a VPA handle resolves to the owning merchant."""
        found = self.kb.lookup("swiggy@ybl")
        assert found is not None
        assert found.name == "Swiggy"

    def test_lookup_by_alias(self):
        """Looking up an alias resolves to the canonical merchant."""
        found = self.kb.lookup("amzn")
        assert found is not None
        assert found.name == "Amazon"

    def test_lookup_partial_match(self):
        """A lower-cased query still resolves to the merchant."""
        found = self.kb.lookup("netflix")
        assert found is not None
        assert found.name == "Netflix"

    def test_lookup_not_found(self):
        """An unrecognised query yields None."""
        assert self.kb.lookup("unknownmerchant123") is None

    def test_search_by_text(self):
        """Free-text search ranks the merchant named in the text first."""
        hits = self.kb.search("food delivery swiggy order")
        assert len(hits) > 0
        top_hit = hits[0]
        assert top_hit.name == "Swiggy"

    def test_add_merchant(self):
        """A merchant added at runtime is reachable by key, VPA and alias."""
        fields = {
            "name": "TestMerchant",
            "vpa": "test@ybl",
            "category": "test",
            "aliases": ["tm"],
        }
        self.kb.add_merchant(self.Merchant(**fields))

        assert "testmerchant" in self.kb.merchants
        for query in ("test@ybl", "tm"):
            assert self.kb.lookup(query) is not None

    def test_get_category_merchants(self):
        """Filtering by category returns a non-empty, homogeneous list."""
        in_food = self.kb.get_category_merchants("food")
        assert len(in_food) > 0
        assert all(entry.category == "food" for entry in in_food)

    def test_save_and_load(self):
        """Saving to JSON and loading back preserves the merchant count."""
        from finee.rag import MerchantKnowledgeBase

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "merchants.json"

            # Write the KB out and confirm the file was created.
            self.kb.save(target)
            assert target.exists()

            # Read it back and compare sizes.
            restored = MerchantKnowledgeBase.load(target)
            assert len(restored.merchants) == len(self.kb.merchants)
class TestCategoryTaxonomy:
    """Checks for CategoryTaxonomy hierarchy lookup and keyword inference."""

    def setup_method(self):
        """Keep a reference to the taxonomy class (it is used without instantiation)."""
        from finee.rag import CategoryTaxonomy

        self.taxonomy = CategoryTaxonomy

    def test_get_hierarchy_simple(self):
        """The 'food' category maps to a single-element hierarchy."""
        assert self.taxonomy.get_hierarchy("food") == ["food"]

    def test_get_hierarchy_nested(self):
        """The 'grocery' hierarchy contains both 'shopping' and 'grocery'."""
        chain = self.taxonomy.get_hierarchy("grocery")
        for level in ("shopping", "grocery"):
            assert level in chain

    def test_get_hierarchy_unknown(self):
        """An unrecognised category is returned as its own hierarchy."""
        assert self.taxonomy.get_hierarchy("unknown") == ["unknown"]

    def test_infer_category_food(self):
        """Food-related wording is classified as 'food'."""
        inferred = self.taxonomy.infer_category("lunch delivery from restaurant")
        assert inferred == "food"

    def test_infer_category_investment(self):
        """Investment-related wording is classified as 'investment'."""
        inferred = self.taxonomy.infer_category("SIP mutual fund trading invest")
        assert inferred == "investment"

    def test_infer_category_bills(self):
        """Bill-payment wording is classified as 'bills'."""
        inferred = self.taxonomy.infer_category("electricity bill payment recharge")
        assert inferred == "bills"

    def test_infer_category_unknown(self):
        """Text with no recognised keywords is classified as 'other'."""
        inferred = self.taxonomy.infer_category("random text here")
        assert inferred == "other"
class TestSimpleVectorStore:
    """Checks for SimpleVectorStore: insertion, similarity search, persistence."""

    def setup_method(self):
        """Create an empty store and keep a handle on the Transaction type."""
        from finee.rag import SimpleVectorStore, Transaction

        self.store = SimpleVectorStore()
        self.Transaction = Transaction

    def test_add_transaction(self):
        """Adding one transaction leaves exactly one document in the store."""
        fields = {
            "id": "test1",
            "text": "HDFC Bank Rs.500 debited Swiggy",
            "amount": 500.0,
            "type": "debit",
            "merchant": "Swiggy",
            "category": "food",
        }
        self.store.add(self.Transaction(**fields))
        assert len(self.store.documents) == 1

    def test_search_similar(self):
        """A food-related query ranks a food transaction first."""
        # Seed the store with two food transactions and one shopping transaction.
        rows = (
            ("1", "HDFC Rs.500 Swiggy food", 500, "debit", "Swiggy", "food"),
            ("2", "SBI Rs.1000 Amazon shopping", 1000, "debit", "Amazon", "shopping"),
            ("3", "ICICI Rs.250 Zomato dinner", 250, "debit", "Zomato", "food"),
        )
        for row in rows:
            self.store.add(self.Transaction(*row))

        hits = self.store.search("Swiggy food delivery", limit=2)
        assert len(hits) > 0

        # Either Swiggy or Zomato may win; both belong to the food category.
        best_txn = hits[0][0]
        assert best_txn.category == "food"

    def test_search_empty_store(self):
        """Searching a store with no documents yields an empty list."""
        assert self.store.search("any query") == []

    def test_save_and_load(self):
        """Saving to JSON and loading back preserves the document count."""
        from finee.rag import SimpleVectorStore

        self.store.add(self.Transaction("1", "Test transaction", 100, "debit"))

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "store.json"

            # Write the store out and confirm the file was created.
            self.store.save(target)
            assert target.exists()

            # Read it back and check the single document survived.
            restored = SimpleVectorStore.load(target)
            assert len(restored.documents) == 1
class TestRAGEngine:
    """Checks for RAGEngine retrieval, prompt augmentation, history, persistence."""

    def setup_method(self):
        """Create a new engine for every test."""
        from finee.rag import RAGEngine

        self.rag = RAGEngine()

    def test_retrieve_with_merchant(self):
        """A message carrying a known VPA yields merchant info and a positive boost."""
        ctx = self.rag.retrieve("HDFC Bank Rs.499 debited UPI:swiggy@ybl")
        info = ctx.merchant_info
        assert info is not None
        assert info["name"] == "Swiggy"
        assert info["category"] == "food"
        assert ctx.confidence_boost > 0

    def test_retrieve_with_category(self):
        """A Netflix message produces a hierarchy containing 'entertainment'."""
        ctx = self.rag.retrieve("Netflix subscription payment")
        chain = ctx.category_hierarchy
        assert chain is not None
        assert "entertainment" in chain

    def test_retrieve_investment(self):
        """A Zerodha message is categorised as 'investment'."""
        ctx = self.rag.retrieve("Rs.25000 transferred to Zerodha trading")
        info = ctx.merchant_info
        assert info is not None
        assert info["category"] == "investment"

    def test_augment_prompt(self):
        """The augmented prompt mentions both the merchant and its category."""
        message = "Swiggy food order"
        prompt = self.rag.augment_prompt(message, self.rag.retrieve(message))
        for expected in ("Swiggy", "food"):
            assert expected in prompt

    def test_add_transaction(self):
        """Recording a transaction puts at least one document in the vector store."""
        extracted = {"amount": 100, "merchant": "Test", "category": "test"}
        self.rag.add_transaction("Test transaction text", extracted)
        assert len(self.rag.vector_store.documents) > 0

    def test_save_and_load(self):
        """Saving engine state writes merchants.json, and a new engine can load it."""
        from finee.rag import RAGEngine

        with tempfile.TemporaryDirectory() as workdir:
            state_dir = Path(workdir) / "rag_state"

            # Persist the engine and confirm the merchant file was written.
            self.rag.save(state_dir)
            assert (state_dir / "merchants.json").exists()

            # Load the saved state into a second engine.
            restored = RAGEngine()
            restored.load(state_dir)
            assert len(restored.merchant_kb.merchants) > 0
class TestRAGIntegration:
    """End-to-end merchant detection through RAGEngine.retrieve."""

    def setup_method(self):
        """Create a new engine for every parametrized case."""
        from finee.rag import RAGEngine

        self.rag = RAGEngine()

    @pytest.mark.parametrize(
        ("message", "expected_merchant", "expected_category"),
        [
            ("HDFC Rs.499 UPI:swiggy@ybl", "Swiggy", "food"),
            ("SBI Rs.999 Netflix subscription", "Netflix", "entertainment"),
            ("ICICI Rs.25000 Zerodha trading", "Zerodha", "investment"),
            ("Axis Rs.1500 Amazon order", "Amazon", "shopping"),
            ("Kotak Rs.350 Uber ride", "Uber", "transport"),
            ("PNB Rs.100 Airtel recharge", "Airtel", "bills"),
        ],
    )
    def test_merchant_detection(self, message, expected_merchant, expected_category):
        """Each sample message resolves to the expected merchant and category."""
        info = self.rag.retrieve(message).merchant_info
        assert info is not None
        assert info["name"] == expected_merchant
        assert info["category"] == expected_category
# ============================================================================
# RUN TESTS
# ============================================================================
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code; propagate it so a failing run is visible to the shell / CI instead
    # of the process always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))