Spaces:
Sleeping
Sleeping
File size: 6,319 Bytes
0e8c152 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
"""
Test script for cross-encoder re-ranking in RAG search.
This script tests:
1. Model loading
2. Re-ranking functionality
3. Comparison of results with/without re-ranking
"""
import sys
import asyncio
from pathlib import Path
# Make the "backend" directory that sits next to this script importable.
# It is inserted at index 0 so it is searched before the other sys.path entries.
backend_dir = Path(__file__).parent / "backend"
sys.path.insert(0, str(backend_dir))
# Project-local import; presumably resolved through the path entry added above.
from mcp_server.common.reranker import rerank_results, _get_reranker
def test_model_loading():
    """Check that the cross-encoder re-ranker can be obtained.

    Calls ``_get_reranker()`` and reports the outcome on stdout. A ``None``
    result or any exception raised by the call counts as a failure.

    Returns:
        bool: ``True`` when a non-``None`` model object was returned,
        ``False`` otherwise.
    """
    print("=" * 60)
    print("Test 1: Model Loading")
    print("=" * 60)
    # Only the loader call is guarded; this is a test boundary, so any
    # exception is reported as a failure instead of propagating.
    try:
        reranker = _get_reranker()
    except Exception as e:
        print(f"❌ FAILED: Error loading model: {e}")
        return False
    if reranker is None:
        print("❌ FAILED: Reranker model is None (sentence-transformers not available?)")
        return False
    print("✅ SUCCESS: Cross-encoder model loaded successfully")
    print(f" Model type: {type(reranker).__name__}")
    return True
def test_reranking_basic():
    """Re-rank a small fixed candidate list and validate the result.

    Four hand-written candidates (each with ``text``, ``score`` and
    ``relevance`` keys) are passed to ``rerank_results`` with ``top_k=3``.
    The test passes when the returned list is non-empty, its ``score``
    values are in descending order, and every item has ``reranked`` set to
    ``True``. Both orderings are printed for manual comparison.

    Returns:
        bool: ``True`` when all checks pass, ``False`` on any failed check
        or on any exception (the traceback is printed).
    """
    print("\n" + "=" * 60)
    print("Test 2: Basic Re-ranking")
    print("=" * 60)
    query = "What is the refund policy?"
    candidates = [
        {"text": "Our refund policy allows returns within 30 days.", "score": 0.85, "relevance": 0.85},
        {"text": "The company was founded in 2020.", "score": 0.45, "relevance": 0.45},
        {"text": "Refunds are processed within 5-7 business days after approval.", "score": 0.72, "relevance": 0.72},
        {"text": "Contact support for assistance.", "score": 0.30, "relevance": 0.30},
    ]
    print(f"Query: {query}")
    print("\nOriginal order (by vector similarity):")
    for i, cand in enumerate(candidates, 1):
        print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")
    try:
        reranked = rerank_results(query, candidates, top_k=3)
        if not reranked:
            print("❌ FAILED: Re-ranking returned empty results")
            return False
        print("\nRe-ranked order (by cross-encoder):")
        for i, cand in enumerate(reranked, 1):
            print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")
        # Check that results are sorted by score (descending)
        scores = [c.get("score", 0.0) for c in reranked]
        if scores != sorted(scores, reverse=True):
            print("❌ FAILED: Results are not sorted by score")
            return False
        # Check that the reranked flag is set on every returned item
        if not all(c.get("reranked") is True for c in reranked):
            print("❌ FAILED: 'reranked' flag not set")
            return False
        print("✅ SUCCESS: Re-ranking works correctly")
        return True
    except Exception as e:
        print(f"❌ FAILED: Error during re-ranking: {e}")
        # Imported lazily: only needed on the failure path.
        import traceback
        traceback.print_exc()
        return False
def test_reranking_empty():
    """Check that re-ranking an empty candidate list yields an empty list.

    Calls ``rerank_results("test query", [])`` and compares the result with
    ``[]``.

    Returns:
        bool: ``True`` when the result equals ``[]``; ``False`` when it does
        not or when the call raises.
    """
    print("\n" + "=" * 60)
    print("Test 3: Empty Candidates Handling")
    print("=" * 60)
    # Only the call under test is guarded; any exception is a test failure.
    try:
        reranked = rerank_results("test query", [])
    except Exception as e:
        print(f"❌ FAILED: Error with empty candidates: {e}")
        return False
    if reranked == []:
        print("✅ SUCCESS: Empty candidates handled correctly")
        return True
    print(f"❌ FAILED: Expected empty list, got {reranked}")
    return False
async def test_rag_search_integration():
    """Run one RAG search end to end and report whether re-ranking was used.

    Imports ``rag_search`` and ``TenantContext`` lazily, builds a context for
    the tenant id ``"test_tenant_rerank"`` and awaits ``rag_search`` with a
    fixed payload (query, limit 5, threshold 0.1). The number of results and
    the returned metadata are printed.

    Returns:
        bool | None: ``True`` when the search call completed (a missing
        ``reranked`` metadata flag only produces a warning), or ``None``
        when anything raised, which is reported as a skipped test.
    """
    print("\n" + "=" * 60)
    print("Test 4: RAG Search Integration (requires database)")
    print("=" * 60)
    try:
        # Imported inside the try so that an import failure is reported as
        # a skip like any other error in this optional test.
        from mcp_server.rag.search import rag_search
        from mcp_server.common.tenant import TenantContext
        # Create a test tenant context
        context = TenantContext(tenant_id="test_tenant_rerank")
        # Test search
        payload = {
            "query": "test query",
            "limit": 5,
            "threshold": 0.1
        }
        print(f"Testing RAG search with query: '{payload['query']}'")
        print("Note: This requires a running database with documents.")
        result = await rag_search(context, payload)
        metadata = result.get('metadata', {})
        print(f"\nResults: {len(result.get('results', []))} items")
        print(f"Metadata: {metadata}")
        if metadata.get('reranked'):
            print("✅ SUCCESS: Re-ranking was applied")
        else:
            print("⚠️ WARNING: Re-ranking was not applied (may be normal if no candidates found)")
        return True
    except Exception as e:
        # Deliberately broad: every error is treated as a skip, not a failure.
        print(f"⚠️ SKIPPED: Integration test requires database: {e}")
        return None
def main():
    """Run all re-ranking checks and print a pass/fail/skip summary.

    The three synchronous tests always run. The integration test is
    optional: a ``None`` result, or an exception while running it, is
    recorded as skipped and left out of the pass/total count.

    Returns:
        int: ``0`` when every counted test passed, ``1`` otherwise, so the
        value can be handed to ``sys.exit``.
    """
    print("\n" + "=" * 60)
    print("Cross-Encoder Re-ranking Test Suite")
    print("=" * 60)
    # List elements are evaluated in order, so tests 1-3 run sequentially.
    results = [
        ("Model Loading", test_model_loading()),
        ("Basic Re-ranking", test_reranking_basic()),
        ("Empty Candidates", test_reranking_empty()),
    ]
    # Test 4: Integration (optional, requires DB)
    try:
        integration_result = asyncio.run(test_rag_search_integration())
    except Exception as e:
        print(f"⚠️ Integration test skipped: {e}")
        integration_result = None
    # Record the outcome even when skipped so it shows up in the summary.
    results.append(("RAG Integration", integration_result))
    # Summary
    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)
    passed = sum(1 for _, result in results if result is True)
    skipped = sum(1 for _, result in results if result is None)
    # Skipped tests do not count towards the total.
    total = len(results) - skipped
    for test_name, result in results:
        if result is True:
            status = "✅ PASS"
        elif result is False:
            status = "❌ FAIL"
        else:
            status = "⚠️ SKIP"
        print(f"{status}: {test_name}")
    print(f"\nTotal: {passed}/{total} tests passed")
    all_passed = passed == total
    if all_passed:
        print("\n🎉 All tests passed!")
    else:
        print("\n⚠️ Some tests failed. Check output above for details.")
    return 0 if all_passed else 1
if __name__ == "__main__":
    # Hand main()'s return value to sys.exit so that a non-zero status, if
    # main() returns one, reaches the shell; a None return still exits 0.
    sys.exit(main())
|