# Source: Kaanta / test_optimizer.py
# Uploaded by Eniiyanu — commit c407c64 ("Upload 14 files", verified)
# test_optimizer.py
"""
Quick test script to verify tax optimizer modules work correctly
Run this before starting the API to catch any import/logic errors
"""
def test_imports():
    """Check that the core tax-optimizer modules can be imported.

    Any exception raised while importing is caught, not only ``ImportError``.
    A module that is found but fails while executing its top-level code
    (e.g. ``SyntaxError``, a bad module-level constant) is therefore reported
    as a failed test. Without this it would abort the whole run in ``main()``,
    and catching ``Exception`` matches how every other test here reports
    failures.

    Returns:
        bool: True if every import succeeded, False otherwise.
    """
    print("Testing imports...")
    try:
        from transaction_classifier import TransactionClassifier  # noqa: F401
        from transaction_aggregator import TransactionAggregator  # noqa: F401
        from tax_strategy_extractor import TaxStrategyExtractor  # noqa: F401
        from tax_optimizer import TaxOptimizer  # noqa: F401
    except ImportError as e:
        print(f"[FAIL] Import error: {e}")
        return False
    except Exception as e:
        # The module exists but raised while its top-level code was running.
        print(f"[FAIL] Error while importing: {type(e).__name__}: {e}")
        return False
    print("[PASS] All modules imported successfully")
    return True
def test_classifier():
    """Classify one salary credit and check the classifier's verdict.

    A ``TransactionClassifier`` is built without a RAG pipeline and given a
    single credit whose narration reads like a salary payment. The verdict
    must be:

    - categorised as ``employment_income``;
    - flagged non-deductible;
    - carry a confidence above 0.8.

    Returns:
        bool: True when every check holds, False if anything raised
        (including a failed assertion).
    """
    print("\nTesting TransactionClassifier...")
    try:
        from transaction_classifier import TransactionClassifier

        salary_credit = dict(
            type="credit",
            amount=500000,
            narration="SALARY PAYMENT FROM ABC COMPANY LTD",
            date="2025-01-31",
            balance=750000,
        )
        verdict = TransactionClassifier(rag_pipeline=None).classify_transaction(salary_credit)

        category = verdict["tax_category"]
        assert category == "employment_income", "Should classify as employment income"
        assert verdict["deductible"] == False, "Income should not be deductible"  # noqa: E712
        assert verdict["confidence"] > 0.8, "Should have high confidence"

        print(f"[PASS] Classifier working: {category} (confidence: {verdict['confidence']:.2f})")
        return True
    except Exception as err:
        print(f"[FAIL] Classifier test failed: {err}")
        return False
def test_aggregator():
    """Aggregate a salary credit and a pension debit for tax year 2025.

    The salary credit carries a component breakdown in ``metadata``. The
    pension debit is pre-tagged as ``pension_contribution``. The aggregated
    tax inputs must report gross income of 500,000 and an employee pension
    contribution of 24,000.

    Returns:
        bool: True when both checks hold, False if anything raised.
    """
    print("\nTesting TransactionAggregator...")
    try:
        from transaction_aggregator import TransactionAggregator

        salary = {
            "type": "credit",
            "amount": 500000,
            "narration": "SALARY",
            "date": "2025-01-31",
            "tax_category": "employment_income",
            "metadata": {
                "basic_salary": 300000,
                "housing_allowance": 120000,
                "transport_allowance": 60000,
                "bonus": 20000,
            },
        }
        pension = {
            "type": "debit",
            "amount": 24000,
            "narration": "PENSION",
            "date": "2025-01-31",
            "tax_category": "pension_contribution",
        }

        totals = TransactionAggregator().aggregate_for_tax_year([salary, pension], 2025)
        assert totals["gross_income"] == 500000, "Should aggregate gross income"
        assert totals["employee_pension_contribution"] == 24000, "Should aggregate pension"

        print(f"[PASS] Aggregator working: Gross income = ₦{totals['gross_income']:,.0f}")
        return True
    except Exception as err:
        print(f"[FAIL] Aggregator test failed: {err}")
        return False
def test_integration():
    """Run classifier -> aggregator -> rules engine end to end, without RAG.

    Steps:

    1. Classify two raw transactions (a salary credit and a pension debit).
    2. Aggregate them for tax year 2025.
    3. Add the inputs the minimum-wage exemption rule needs.
    4. Run the PIT rules loaded from ``rules/rules_all.yaml`` at state
       jurisdiction.

    If income is at or below the exemption threshold (12 x monthly minimum
    wage) but the engine still reports tax due, a ``[WARN]`` is printed. This
    does not fail the test. In that case the summary no longer labels the
    result "(EXEMPT)".

    Returns:
        bool: True if the pipeline ran without raising, False otherwise
        (the traceback is printed).
    """
    print("\nTesting integration (without RAG)...")
    try:
        from transaction_classifier import TransactionClassifier
        from transaction_aggregator import TransactionAggregator
        from rules_engine import RuleCatalog, TaxEngine
        from datetime import date

        # Initialize components
        classifier = TransactionClassifier(rag_pipeline=None)
        aggregator = TransactionAggregator()
        catalog = RuleCatalog.from_yaml_files(["rules/rules_all.yaml"])
        engine = TaxEngine(catalog, rounding_mode="half_up")

        # Raw (unclassified) transactions, as they would arrive from a bank feed
        transactions = [
            {
                "type": "credit",
                "amount": 500000,
                "narration": "SALARY PAYMENT",
                "date": "2025-01-31",
                "balance": 500000,
            },
            {
                "type": "debit",
                "amount": 40000,
                "narration": "PENSION CONTRIBUTION",
                "date": "2025-01-31",
                "balance": 460000,
            },
        ]

        classified = classifier.classify_batch(transactions)
        tax_inputs = aggregator.aggregate_for_tax_year(classified, 2025)

        # Add required inputs for minimum wage exemption rule
        tax_inputs["employment_income_annual"] = tax_inputs.get("gross_income", 0)
        tax_inputs["min_wage_monthly"] = 70000  # Current minimum wage

        result = engine.run(
            tax_type="PIT",
            as_of=date(2025, 12, 31),
            jurisdiction="state",
            inputs=tax_inputs,
        )
        tax_due = result.values.get("tax_due", 0)
        gross_income = tax_inputs["gross_income"]
        min_wage_threshold = tax_inputs["min_wage_monthly"] * 12

        below_threshold = gross_income <= min_wage_threshold
        # Only call the result exempt when the engine actually produced no tax.
        # Labelling on income alone contradicted the [WARN] below.
        exempt = below_threshold and tax_due <= 0

        # Verify minimum wage exemption
        if below_threshold and tax_due > 0:
            print(f"[WARN] Income ₦{gross_income:,.0f} is below exemption threshold ₦{min_wage_threshold:,.0f}")
            print(f" But tax is ₦{tax_due:,.0f} (should be ₦0)")
            print(" This indicates the minimum wage exemption rule is not applying correctly")

        print("[PASS] Integration test passed:")
        print(f" Transactions: {len(transactions)}")
        print(f" Classified: {len([t for t in classified if t['tax_category'] != 'uncategorized'])}")
        print(f" Gross Income: ₦{gross_income:,.0f}")
        print(f" Exemption Threshold: ₦{min_wage_threshold:,.0f}")
        print(f" Tax Due: ₦{tax_due:,.0f}{' (EXEMPT)' if exempt else ''}")
        return True
    except Exception as e:
        print(f"[FAIL] Integration test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_with_rag():
    """Exercise ``TaxOptimizer.optimize`` end to end with a live RAG pipeline.

    The test is skipped (and reported as passing) when ``GROQ_API_KEY`` is
    unset or empty, or when ``data/`` holds no PDFs. Otherwise it:

    - loads or builds the vector store in ``vector_store/``;
    - wires the classifier, aggregator, strategy extractor and rules engine
      into a ``TaxOptimizer``;
    - optimizes two sample transactions for PIT, tax year 2025.

    Returns:
        bool: True on success or skip, False if anything raised
        (the traceback is printed).
    """
    print("\nTesting with RAG pipeline...")
    try:
        import os
        from pathlib import Path
        from transaction_classifier import TransactionClassifier
        from transaction_aggregator import TransactionAggregator
        from tax_strategy_extractor import TaxStrategyExtractor
        from tax_optimizer import TaxOptimizer
        from rules_engine import RuleCatalog, TaxEngine
        from rag_pipeline import RAGPipeline, DocumentStore

        # Missing prerequisites skip the test rather than fail it.
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            print("[SKIP] GROQ_API_KEY not set - skipping RAG test")
            print(" Set GROQ_API_KEY in .env to enable RAG testing")
            return True

        pdf_dir = Path("data")
        has_pdfs = pdf_dir.exists() and any(pdf_dir.glob("*.pdf"))
        if not has_pdfs:
            print("[SKIP] No PDFs found in data/ - skipping RAG test")
            return True

        print(" Initializing RAG pipeline (this may take a moment)...")
        store = DocumentStore(
            persist_dir=Path("vector_store"),
            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        )
        store.build_vector_store(store.discover_pdfs(pdf_dir), force_rebuild=False)
        rag = RAGPipeline(doc_store=store, model="llama-3.3-70b-versatile", temperature=0.1)

        engine = TaxEngine(
            RuleCatalog.from_yaml_files(["rules/rules_all.yaml"]),
            rounding_mode="half_up",
        )
        optimizer = TaxOptimizer(
            classifier=TransactionClassifier(rag_pipeline=rag),
            aggregator=TransactionAggregator(),
            strategy_extractor=TaxStrategyExtractor(rag_pipeline=rag),
            tax_engine=engine,
        )

        sample = [
            {
                "type": "credit",
                "amount": 500000,
                "narration": "SALARY PAYMENT FROM ABC COMPANY",
                "date": "2025-01-31",
                "balance": 500000,
            },
            {
                "type": "debit",
                "amount": 40000,
                "narration": "PENSION CONTRIBUTION TO XYZ PFA",
                "date": "2025-01-31",
                "balance": 460000,
            },
        ]

        print(" Running optimization with RAG...")
        outcome = optimizer.optimize(
            user_id="test_user",
            transactions=sample,
            tax_year=2025,
            tax_type="PIT",
            jurisdiction="state",
        )

        print("[PASS] RAG integration test passed:")
        print(f" Baseline Tax: ₦{outcome['baseline_tax_liability']:,.0f}")
        print(f" Potential Savings: ₦{outcome['total_potential_savings']:,.0f}")
        print(f" Recommendations: {outcome['recommendation_count']}")
        if outcome["recommendation_count"] > 0:
            print(f" Top Strategy: {outcome['recommendations'][0]['strategy_name']}")
        return True
    except Exception as exc:
        print(f"[FAIL] RAG integration test failed: {exc}")
        import traceback
        traceback.print_exc()
        return False
# Fixture parameters for the ₦10M/year high-earner scenario. The printed
# summary is derived from these values, so it cannot drift from the
# transactions actually sent to the optimizer.
_HIGH_EARNER_MONTHLY_GROSS = 833333  # ₦10M / 12, truncated (₦9,999,996/year)
_HIGH_EARNER_MONTHLY_PENSION = 40000  # 8% of the ₦500,000 basic salary
_HIGH_EARNER_MONTHLY_NHF = 12500  # 2.5% of the ₦500,000 basic salary
_HIGH_EARNER_MONTHLY_RENT = 300000


def _build_high_earner_transactions():
    """Return one year (2025) of bank transactions for a ₦10M/year employee.

    For each month the list holds, in order, a salary credit (with its
    component breakdown in ``metadata``), a pension debit and an NHF debit.
    These are followed by one annual life-insurance debit and then twelve
    monthly rent debits. The ``balance`` fields are illustrative only and do
    not form a consistent running balance.
    """
    gross = _HIGH_EARNER_MONTHLY_GROSS
    transactions = []

    for month in range(1, 13):
        date_str = f"2025-{month:02d}-28"
        transactions.append({
            "type": "credit",
            "amount": gross,
            "narration": "SALARY PAYMENT FROM XYZ CORPORATION",
            "date": date_str,
            "balance": gross,
            # A fresh dict per month, so no two transactions share mutable state.
            "metadata": {
                "basic_salary": 500000,  # 60% basic
                "housing_allowance": 200000,  # 24% housing
                "transport_allowance": 100000,  # 12% transport
                "bonus": 33333,  # 4% bonus
            },
        })
        transactions.append({
            "type": "debit",
            "amount": _HIGH_EARNER_MONTHLY_PENSION,
            "narration": "PENSION CONTRIBUTION TO ABC PFA RSA",
            "date": date_str,
            "balance": gross - _HIGH_EARNER_MONTHLY_PENSION,
        })
        transactions.append({
            "type": "debit",
            "amount": _HIGH_EARNER_MONTHLY_NHF,
            "narration": "NHF HOUSING FUND DEDUCTION",
            "date": date_str,
            "balance": gross - _HIGH_EARNER_MONTHLY_PENSION - _HIGH_EARNER_MONTHLY_NHF,
        })

    # Annual life insurance
    transactions.append({
        "type": "debit",
        "amount": 100000,
        "narration": "LIFE INSURANCE PREMIUM - ANNUAL",
        "date": "2025-01-15",
        "balance": 700000,
    })

    # Monthly rent
    for month in range(1, 13):
        transactions.append({
            "type": "debit",
            "amount": _HIGH_EARNER_MONTHLY_RENT,
            "narration": "RENT PAYMENT TO LANDLORD",
            "date": f"2025-{month:02d}-05",
            "balance": 500000,
        })

    return transactions


def _print_high_earner_report(result):
    """Print the tax summary, current deductions and top-5 recommendations.

    ``result`` is the dict returned by ``TaxOptimizer.optimize``. Any key read
    below that is missing raises ``KeyError`` to the caller.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("HIGH EARNER OPTIMIZATION RESULTS (₦10M/year)")
    print(rule)

    print("\nTax Summary:")
    print(f" Baseline Tax: ₦{result['baseline_tax_liability']:,.0f}")
    print(f" Optimized Tax: ₦{result['optimized_tax_liability']:,.0f}")
    print(f" Potential Savings: ₦{result['total_potential_savings']:,.0f}")
    print(f" Savings Percentage: {result['savings_percentage']:.1f}%")

    print("\nIncome & Deductions:")
    print(f" Total Annual Income: ₦{result['total_annual_income']:,.0f}")
    print(" Current Deductions:")
    deductions = result["current_deductions"]
    for key, value in deductions.items():
        if key != "total" and value > 0:
            print(f" - {key.replace('_', ' ').title()}: ₦{value:,.0f}")
    print(f" Total: ₦{deductions['total']:,.0f}")

    print("\nTop Recommendations:")
    for i, rec in enumerate(result["recommendations"][:5], 1):
        print(f"\n {i}. {rec['strategy_name']}")
        print(f" Annual Savings: ₦{rec['annual_tax_savings']:,.0f}")
        print(f" Description: {rec['description']}")
        print(f" Risk: {rec['risk_level'].upper()} | Complexity: {rec['complexity'].upper()}")
        if rec["implementation_steps"]:
            print(" Implementation:")
            for step in rec["implementation_steps"][:2]:
                print(f" • {step}")
    print(f"\n{rule}")


def test_high_earner():
    """Run a full-year optimization for a ₦10M/year earner with RAG enabled.

    The test is skipped (and reported as passing) when ``GROQ_API_KEY`` is
    unset or empty, or when ``data/`` holds no PDFs. Otherwise it:

    - builds the RAG-backed optimizer;
    - optimizes a year of salary, pension, NHF, insurance and rent
      transactions;
    - prints a report;
    - sanity-checks the headline figures.

    Returns:
        bool: True on success or skip, False if anything raised
        (the traceback is printed).
    """
    print("\nTesting high earner optimization (₦10M/year)...")
    try:
        import os
        from pathlib import Path
        from transaction_classifier import TransactionClassifier
        from transaction_aggregator import TransactionAggregator
        from tax_strategy_extractor import TaxStrategyExtractor
        from tax_optimizer import TaxOptimizer
        from rules_engine import RuleCatalog, TaxEngine
        from rag_pipeline import RAGPipeline, DocumentStore

        # Check if GROQ_API_KEY is set
        if not os.getenv("GROQ_API_KEY"):
            print("[SKIP] GROQ_API_KEY not set - skipping high earner test")
            return True
        # Check if PDFs exist
        pdf_source = Path("data")
        if not pdf_source.exists() or not list(pdf_source.glob("*.pdf")):
            print("[SKIP] No PDFs found - skipping high earner test")
            return True

        print(" Initializing components...")
        doc_store = DocumentStore(
            persist_dir=Path("vector_store"),
            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        )
        pdfs = doc_store.discover_pdfs(pdf_source)
        doc_store.build_vector_store(pdfs, force_rebuild=False)
        rag = RAGPipeline(doc_store=doc_store, model="llama-3.3-70b-versatile", temperature=0.1)

        catalog = RuleCatalog.from_yaml_files(["rules/rules_all.yaml"])
        engine = TaxEngine(catalog, rounding_mode="half_up")

        optimizer = TaxOptimizer(
            classifier=TransactionClassifier(rag_pipeline=rag),
            aggregator=TransactionAggregator(),
            strategy_extractor=TaxStrategyExtractor(rag_pipeline=rag),
            tax_engine=engine,
        )

        transactions = _build_high_earner_transactions()
        print(f" Created {len(transactions)} transactions")
        print(f" Annual gross income: ₦{_HIGH_EARNER_MONTHLY_GROSS * 12:,}")
        print(f" Current pension: ₦{_HIGH_EARNER_MONTHLY_PENSION * 12:,}/year (8%)")
        print(" Running optimization...")

        result = optimizer.optimize(
            user_id="high_earner_test",
            transactions=transactions,
            tax_year=2025,
            tax_type="PIT",
            jurisdiction="state",
        )

        _print_high_earner_report(result)

        # Verify results make sense
        assert result['baseline_tax_liability'] > 0, "High earner should have tax liability"
        assert result['total_annual_income'] >= 9900000, "Should have ~₦10M income (allowing for rounding)"
        # A count can't be negative; this only guards that the field is a
        # comparable number. Zero is acceptable when already optimal.
        assert result['recommendation_count'] >= 0, "Should have recommendations (or 0 if already optimal)"

        print("[PASS] High earner test passed!")
        return True
    except Exception as e:
        print(f"[FAIL] High earner test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run every test in order, print a summary table and return overall success.

    The tests run sequentially in the order listed. The RAG-backed tests
    return True when skipped, so a skip is counted as a pass in the summary.

    Returns:
        bool: True only if every test function returned True.
    """
    print("=" * 80)
    print("TAX OPTIMIZER MODULE TESTS")
    print("=" * 80)

    results = [
        ("Imports", test_imports()),
        ("Classifier", test_classifier()),
        ("Aggregator", test_aggregator()),
        ("Integration (no RAG)", test_integration()),
        ("Integration (with RAG)", test_with_rag()),
        ("High Earner (₦10M)", test_high_earner()),
    ]

    print("\n" + "=" * 80)
    print("TEST RESULTS")
    print("=" * 80)
    # Size the name column to the longest label (minimum 20) so the status
    # column lines up. A fixed width of 20 was too narrow for
    # "Integration (with RAG)" (22 characters).
    name_width = max(20, max(len(name) for name, _ in results))
    for test_name, passed in results:
        status = "[PASS]" if passed else "[FAIL]"
        print(f"{test_name:{name_width}s} {status}")

    all_passed = all(passed for _, passed in results)

    print("\n" + "=" * 80)
    if all_passed:
        print("[SUCCESS] ALL TESTS PASSED - Ready to start API")
        print("\nNext steps:")
        print("1. Ensure GROQ_API_KEY is set in .env")
        print("2. Start API: uvicorn orchestrator:app --reload --port 8000")
        print("3. Test endpoint: python example_optimize.py")
    else:
        print("[ERROR] SOME TESTS FAILED - Fix errors before starting API")
    print("=" * 80)
    return all_passed
if __name__ == "__main__":
    import sys

    # Non-zero exit status tells the calling shell that at least one test failed.
    sys.exit(int(not main()))