# Spaces: Sleeping — residue from the hosting page header, not part of the script.
"""
Test script for cross-encoder re-ranking in RAG search.
This script tests:
1. Model loading
2. Re-ranking functionality
3. Comparison of results with/without re-ranking
"""
import sys
import asyncio
from pathlib import Path

# Add backend to path so `mcp_server` imports resolve when this script is run
# from the repository root rather than from inside `backend/`.
backend_dir = Path(__file__).parent / "backend"
sys.path.insert(0, str(backend_dir))

# NOTE(review): `_get_reranker` is a private helper of the reranker module; it is
# imported here only so the test can inspect model loading directly.
from mcp_server.common.reranker import rerank_results, _get_reranker
def test_model_loading():
    """Check that the cross-encoder model can be instantiated at all.

    Returns True on success, False when the model is unavailable or
    loading raises.
    """
    banner = "=" * 60
    print(banner)
    print("Test 1: Model Loading")
    print(banner)
    try:
        model = _get_reranker()
    except Exception as exc:
        print(f"β FAILED: Error loading model: {exc}")
        return False
    # A None model means the optional dependency is missing, not a crash.
    if model is None:
        print("β FAILED: Reranker model is None (sentence-transformers not available?)")
        return False
    print("β SUCCESS: Cross-encoder model loaded successfully")
    print(f" Model type: {type(model).__name__}")
    return True
def test_reranking_basic():
    """Exercise rerank_results on a hand-made candidate set.

    Verifies that the result is non-empty, sorted by descending score,
    and that every surviving candidate carries the 'reranked' marker.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("Test 2: Basic Re-ranking")
    print(divider)

    query = "What is the refund policy?"
    candidates = [
        {"text": "Our refund policy allows returns within 30 days.", "score": 0.85, "relevance": 0.85},
        {"text": "The company was founded in 2020.", "score": 0.45, "relevance": 0.45},
        {"text": "Refunds are processed within 5-7 business days after approval.", "score": 0.72, "relevance": 0.72},
        {"text": "Contact support for assistance.", "score": 0.30, "relevance": 0.30},
    ]

    print(f"Query: {query}")
    print("\nOriginal order (by vector similarity):")
    for rank, item in enumerate(candidates, 1):
        print(f" {rank}. Score: {item['score']:.3f} - {item['text'][:60]}...")

    try:
        top = rerank_results(query, candidates, top_k=3)
        if not top:
            print("β FAILED: Re-ranking returned empty results")
            return False

        print("\nRe-ranked order (by cross-encoder):")
        for rank, item in enumerate(top, 1):
            print(f" {rank}. Score: {item['score']:.3f} - {item['text'][:60]}...")

        # Scores must be non-increasing; any adjacent ascending pair is a failure.
        score_list = [entry.get("score", 0.0) for entry in top]
        if any(a < b for a, b in zip(score_list, score_list[1:])):
            print("β FAILED: Results are not sorted by score")
            return False

        # The reranker is expected to stamp each kept candidate.
        if not all(entry.get("reranked") is True for entry in top):
            print("β FAILED: 'reranked' flag not set")
            return False

        print("β SUCCESS: Re-ranking works correctly")
        return True
    except Exception as exc:
        print(f"β FAILED: Error during re-ranking: {exc}")
        import traceback
        traceback.print_exc()
        return False
def test_reranking_empty():
    """Confirm that an empty candidate list passes through as an empty list."""
    line = "=" * 60
    print("\n" + line)
    print("Test 3: Empty Candidates Handling")
    print(line)
    try:
        outcome = rerank_results("test query", [])
    except Exception as exc:
        print(f"β FAILED: Error with empty candidates: {exc}")
        return False
    if outcome == []:
        print("β SUCCESS: Empty candidates handled correctly")
        return True
    print(f"β FAILED: Expected empty list, got {outcome}")
    return False
async def test_rag_search_integration():
    """End-to-end RAG search smoke test; requires a live database.

    Returns True when the search call succeeds (whether or not re-ranking
    fired), or None when the environment is not available so the caller
    can treat it as skipped rather than failed.
    """
    sep = "=" * 60
    print("\n" + sep)
    print("Test 4: RAG Search Integration (requires database)")
    print(sep)
    try:
        # Imported lazily so the earlier unit tests still run when the
        # search stack (and its DB drivers) is not installed.
        from mcp_server.rag.search import rag_search
        from mcp_server.common.tenant import TenantContext

        # rag_search is tenant-scoped, so fabricate a throwaway tenant.
        context = TenantContext(tenant_id="test_tenant_rerank")
        payload = {
            "query": "test query",
            "limit": 5,
            "threshold": 0.1,
        }
        print(f"Testing RAG search with query: '{payload['query']}'")
        print("Note: This requires a running database with documents.")

        result = await rag_search(context, payload)
        hits = result.get('results', [])
        meta = result.get('metadata', {})
        print(f"\nResults: {len(hits)} items")
        print(f"Metadata: {meta}")
        if meta.get('reranked'):
            print("β SUCCESS: Re-ranking was applied")
        else:
            print("β οΈ WARNING: Re-ranking was not applied (may be normal if no candidates found)")
        return True
    except Exception as exc:
        # Any failure here is treated as "environment missing", not a test bug.
        print(f"β οΈ SKIPPED: Integration test requires database: {exc}")
        return None
def main():
    """Run every test in order and print a pass/fail/skip summary."""
    rule = "=" * 60
    print("\n" + rule)
    print("Cross-Encoder Re-ranking Test Suite")
    print(rule)

    # Each entry: (display name, True=pass / False=fail). Skipped tests are
    # simply not recorded so they do not count toward the total.
    results = [
        ("Model Loading", test_model_loading()),
        ("Basic Re-ranking", test_reranking_basic()),
        ("Empty Candidates", test_reranking_empty()),
    ]

    # The integration test needs a live database; tolerate any setup failure.
    try:
        outcome = asyncio.run(test_rag_search_integration())
    except Exception as exc:
        print(f"β οΈ Integration test skipped: {exc}")
    else:
        if outcome is not None:
            results.append(("RAG Integration", outcome))

    print("\n" + rule)
    print("Test Summary")
    print(rule)
    passed = sum(1 for _, ok in results if ok is True)
    total = len(results)
    for name, ok in results:
        if ok is True:
            status = "β PASS"
        elif ok is False:
            status = "β FAIL"
        else:
            status = "β οΈ SKIP"
        print(f"{status}: {name}")
    print(f"\nTotal: {passed}/{total} tests passed")
    if passed == total:
        print("\nπ All tests passed!")
    else:
        print("\nβ οΈ Some tests failed. Check output above for details.")


if __name__ == "__main__":
    main()