IntegraChat / test_reranking.py
nothingworry's picture
feat: update the encoding model
0e8c152
raw
history blame
6.32 kB
"""
Test script for cross-encoder re-ranking in RAG search.
This script tests:
1. Model loading
2. Re-ranking functionality
3. Comparison of results with/without re-ranking
"""
import sys
import asyncio
from pathlib import Path
# Add backend to path
backend_dir = Path(__file__).parent / "backend"
sys.path.insert(0, str(backend_dir))
from mcp_server.common.reranker import rerank_results, _get_reranker
def test_model_loading():
    """Check that the cross-encoder re-ranker model can be loaded.

    Calls ``_get_reranker()`` and prints a pass/fail report to stdout.

    Returns:
        bool: True if the loader returned a model object; False if it
        returned None or raised any exception.
    """
    print("=" * 60)
    print("Test 1: Model Loading")
    print("=" * 60)

    try:
        reranker = _get_reranker()
    except Exception as e:
        # Report the failure instead of propagating it so the remaining
        # checks in the suite can still run.
        print(f"❌ FAILED: Error loading model: {e}")
        return False

    if reranker is None:
        print("❌ FAILED: Reranker model is None (sentence-transformers not available?)")
        return False

    print("✅ SUCCESS: Cross-encoder model loaded successfully")
    print(f" Model type: {type(reranker).__name__}")
    return True
def test_reranking_basic():
    """Re-rank a small fixed candidate list and validate the output.

    Prints the candidates in their original order, calls
    ``rerank_results(query, candidates, top_k=3)``, prints the re-ranked
    order, and then verifies that:

    * the result is non-empty,
    * the ``score`` values are in descending order, and
    * every returned candidate has ``reranked`` set to ``True``.

    Returns:
        bool: True if all checks pass; False if any check fails or any
        exception is raised while re-ranking or inspecting the results.
    """
    print("\n" + "=" * 60)
    print("Test 2: Basic Re-ranking")
    print("=" * 60)

    query = "What is the refund policy?"
    candidates = [
        {"text": "Our refund policy allows returns within 30 days.", "score": 0.85, "relevance": 0.85},
        {"text": "The company was founded in 2020.", "score": 0.45, "relevance": 0.45},
        {"text": "Refunds are processed within 5-7 business days after approval.", "score": 0.72, "relevance": 0.72},
        {"text": "Contact support for assistance.", "score": 0.30, "relevance": 0.30},
    ]

    print(f"Query: {query}")
    print("\nOriginal order (by vector similarity):")
    for i, cand in enumerate(candidates, 1):
        print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")

    try:
        reranked = rerank_results(query, candidates, top_k=3)
        if not reranked:
            print("❌ FAILED: Re-ranking returned empty results")
            return False

        print("\nRe-ranked order (by cross-encoder):")
        for i, cand in enumerate(reranked, 1):
            print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")

        # Results must come back sorted by score, highest first.
        scores = [c.get("score", 0.0) for c in reranked]
        if scores != sorted(scores, reverse=True):
            print("❌ FAILED: Results are not sorted by score")
            return False

        # Every returned candidate must be explicitly marked as re-ranked.
        if not all(c.get("reranked") is True for c in reranked):
            print("❌ FAILED: 'reranked' flag not set")
            return False

        print("✅ SUCCESS: Re-ranking works correctly")
        return True
    except Exception as e:
        print(f"❌ FAILED: Error during re-ranking: {e}")
        # Imported lazily: only needed to print diagnostics on failure.
        import traceback
        traceback.print_exc()
        return False
def test_reranking_empty():
    """Check that re-ranking an empty candidate list yields an empty list.

    Calls ``rerank_results("test query", [])`` and prints a pass/fail
    report to stdout.

    Returns:
        bool: True if the call returned a value equal to ``[]``; False if
        it returned anything else or raised any exception.
    """
    print("\n" + "=" * 60)
    print("Test 3: Empty Candidates Handling")
    print("=" * 60)

    try:
        reranked = rerank_results("test query", [])
    except Exception as e:
        print(f"❌ FAILED: Error with empty candidates: {e}")
        return False

    if reranked == []:
        print("✅ SUCCESS: Empty candidates handled correctly")
        return True

    print(f"❌ FAILED: Expected empty list, got {reranked}")
    return False
async def test_rag_search_integration():
    """Run a RAG search end-to-end and report whether re-ranking was applied.

    Imports ``rag_search`` and ``TenantContext`` lazily, builds a context
    for the tenant id ``"test_tenant_rerank"``, awaits
    ``rag_search(context, payload)``, and prints the number of results and
    the returned metadata.

    Returns:
        bool | None: True when the search call completes (a missing
        ``reranked`` metadata flag only produces a warning); None when
        anything raises, which the caller treats as "skipped".
    """
    print("\n" + "=" * 60)
    print("Test 4: RAG Search Integration (requires database)")
    print("=" * 60)

    try:
        # Imported here so the rest of the suite can run even when the
        # search stack (and its database dependencies) is unavailable.
        from mcp_server.rag.search import rag_search
        from mcp_server.common.tenant import TenantContext

        # Create a test tenant context
        context = TenantContext(tenant_id="test_tenant_rerank")

        # Test search
        payload = {
            "query": "test query",
            "limit": 5,
            "threshold": 0.1,
        }
        print(f"Testing RAG search with query: '{payload['query']}'")
        print("Note: This requires a running database with documents.")

        result = await rag_search(context, payload)
        metadata = result.get('metadata', {})
        print(f"\nResults: {len(result.get('results', []))} items")
        print(f"Metadata: {metadata}")

        if metadata.get('reranked'):
            print("✅ SUCCESS: Re-ranking was applied")
        else:
            print("⚠️ WARNING: Re-ranking was not applied (may be normal if no candidates found)")
        return True
    except Exception as e:
        # NOTE(review): every exception is treated as "environment not
        # available", so a genuine bug in rag_search also shows up as a skip.
        print(f"⚠️ SKIPPED: Integration test requires database: {e}")
        return None
def main():
    """Run every check, print a summary, and return a process exit status.

    The three unit-style checks always run. The integration check is
    optional: when it returns None or cannot be started it is recorded as
    skipped and excluded from the pass/total count.

    Returns:
        int: 0 if every executed (non-skipped) check passed, 1 otherwise.
    """
    print("\n" + "=" * 60)
    print("Cross-Encoder Re-ranking Test Suite")
    print("=" * 60)

    # List elements are evaluated in order, so the checks run 1 -> 2 -> 3.
    results = [
        ("Model Loading", test_model_loading()),
        ("Basic Re-ranking", test_reranking_basic()),
        ("Empty Candidates", test_reranking_empty()),
    ]

    # Test 4: Integration (optional, requires DB)
    try:
        integration_result = asyncio.run(test_rag_search_integration())
    except Exception as e:
        print(f"⚠️ Integration test skipped: {e}")
        integration_result = None
    # Record skips too, so the summary shows the check was not executed.
    results.append(("RAG Integration", integration_result))

    # Summary
    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)

    for test_name, result in results:
        if result is True:
            status = "✅ PASS"
        elif result is False:
            status = "❌ FAIL"
        else:
            status = "⚠️ SKIP"
        print(f"{status}: {test_name}")

    # Skipped checks (None) count neither as passed nor towards the total.
    executed = [result for _, result in results if result is not None]
    passed = sum(1 for result in executed if result is True)
    total = len(executed)
    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed!")
        return 0

    print("\n⚠️ Some tests failed. Check output above for details.")
    return 1


if __name__ == "__main__":
    # Propagate the outcome so shells and CI can detect a failing run.
    sys.exit(main())