"""Test script for cross-encoder re-ranking in RAG search.

This script tests:
1. Model loading
2. Re-ranking functionality
3. Comparison of results with/without re-ranking
"""

import sys
import asyncio
from pathlib import Path

# Add backend to path so `mcp_server` resolves when the script is run from the repo root.
backend_dir = Path(__file__).parent / "backend"
sys.path.insert(0, str(backend_dir))

from mcp_server.common.reranker import rerank_results, _get_reranker


def test_model_loading():
    """Test that the cross-encoder model loads correctly.

    Returns:
        bool: True if the reranker model is available, False otherwise.
    """
    print("=" * 60)
    print("Test 1: Model Loading")
    print("=" * 60)
    try:
        reranker = _get_reranker()
        if reranker is None:
            print("❌ FAILED: Reranker model is None (sentence-transformers not available?)")
            return False
        print("✅ SUCCESS: Cross-encoder model loaded successfully")
        print(f" Model type: {type(reranker).__name__}")
        return True
    except Exception as e:
        print(f"❌ FAILED: Error loading model: {e}")
        return False


def test_reranking_basic():
    """Test basic re-ranking functionality.

    Feeds a small set of candidates to `rerank_results` and verifies that
    the output is non-empty, sorted by score descending, and flagged as
    re-ranked.

    Returns:
        bool: True on success, False on any failed check or exception.
    """
    print("\n" + "=" * 60)
    print("Test 2: Basic Re-ranking")
    print("=" * 60)

    query = "What is the refund policy?"
    candidates = [
        {"text": "Our refund policy allows returns within 30 days.", "score": 0.85, "relevance": 0.85},
        {"text": "The company was founded in 2020.", "score": 0.45, "relevance": 0.45},
        {"text": "Refunds are processed within 5-7 business days after approval.", "score": 0.72, "relevance": 0.72},
        {"text": "Contact support for assistance.", "score": 0.30, "relevance": 0.30},
    ]

    print(f"Query: {query}")
    print(f"\nOriginal order (by vector similarity):")
    for i, cand in enumerate(candidates, 1):
        print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")

    try:
        reranked = rerank_results(query, candidates, top_k=3)

        if not reranked:
            print("❌ FAILED: Re-ranking returned empty results")
            return False

        print(f"\nRe-ranked order (by cross-encoder):")
        for i, cand in enumerate(reranked, 1):
            print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")

        # Check that results are sorted by score (descending)
        scores = [c.get("score", 0.0) for c in reranked]
        if scores != sorted(scores, reverse=True):
            print("❌ FAILED: Results are not sorted by score")
            return False

        # Check that reranked flag is set on every result
        if not all(c.get("reranked") is True for c in reranked):
            print("❌ FAILED: 'reranked' flag not set")
            return False

        print("✅ SUCCESS: Re-ranking works correctly")
        return True
    except Exception as e:
        print(f"❌ FAILED: Error during re-ranking: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_reranking_empty():
    """Test re-ranking with empty candidates.

    Returns:
        bool: True if an empty candidate list yields an empty result list.
    """
    print("\n" + "=" * 60)
    print("Test 3: Empty Candidates Handling")
    print("=" * 60)
    try:
        reranked = rerank_results("test query", [])
        if reranked == []:
            print("✅ SUCCESS: Empty candidates handled correctly")
            return True
        else:
            print(f"❌ FAILED: Expected empty list, got {reranked}")
            return False
    except Exception as e:
        print(f"❌ FAILED: Error with empty candidates: {e}")
        return False


async def test_rag_search_integration():
    """Test RAG search with re-ranking (requires database).

    Returns:
        bool | None: True when the search runs (re-ranked or not),
        None when the test is skipped (e.g. no database available).
    """
    print("\n" + "=" * 60)
    print("Test 4: RAG Search Integration (requires database)")
    print("=" * 60)
    try:
        from mcp_server.rag.search import rag_search
        from mcp_server.common.tenant import TenantContext

        # Create a test tenant context
        context = TenantContext(tenant_id="test_tenant_rerank")

        # Test search
        payload = {
            "query": "test query",
            "limit": 5,
            "threshold": 0.1
        }

        print(f"Testing RAG search with query: '{payload['query']}'")
        print("Note: This requires a running database with documents.")

        result = await rag_search(context, payload)

        print(f"\nResults: {len(result.get('results', []))} items")
        print(f"Metadata: {result.get('metadata', {})}")

        if result.get('metadata', {}).get('reranked'):
            print("✅ SUCCESS: Re-ranking was applied")
        else:
            # Not necessarily a failure: an empty candidate set legitimately skips re-ranking.
            print("⚠️ WARNING: Re-ranking was not applied (may be normal if no candidates found)")
        return True
    except Exception as e:
        print(f"⚠️ SKIPPED: Integration test requires database: {e}")
        return None


def main():
    """Run all tests and return a process exit code.

    Returns:
        int: 0 when every executed test passed, 1 otherwise.
    """
    print("\n" + "=" * 60)
    print("Cross-Encoder Re-ranking Test Suite")
    print("=" * 60)

    results = []

    # Test 1: Model loading
    results.append(("Model Loading", test_model_loading()))

    # Test 2: Basic re-ranking
    results.append(("Basic Re-ranking", test_reranking_basic()))

    # Test 3: Empty candidates
    results.append(("Empty Candidates", test_reranking_empty()))

    # Test 4: Integration (optional, requires DB); None means skipped.
    try:
        integration_result = asyncio.run(test_rag_search_integration())
        if integration_result is not None:
            results.append(("RAG Integration", integration_result))
    except Exception as e:
        print(f"⚠️ Integration test skipped: {e}")

    # Summary
    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)
    passed = sum(1 for _, result in results if result is True)
    total = len(results)
    for test_name, result in results:
        status = "✅ PASS" if result is True else "❌ FAIL" if result is False else "⚠️ SKIP"
        print(f"{status}: {test_name}")
    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed!")
    else:
        print("\n⚠️ Some tests failed. Check output above for details.")

    # Fix: previously the script always exited 0, so CI could not detect failures.
    return 0 if passed == total else 1


if __name__ == "__main__":
    sys.exit(main())