nothingworry committed on
Commit
0e8c152
·
1 Parent(s): fe818bb

feat: update the encoding model

Browse files
backend/mcp_server/common/reranker.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cross-encoder re-ranking for RAG search results.
3
+
4
+ Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for fast, accurate re-ranking
5
+ of vector search results to improve retrieval accuracy.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from functools import lru_cache
11
+ from typing import List, Dict, Any, Optional
12
+
13
+ try:
14
+ from sentence_transformers import CrossEncoder
15
+ except ImportError:
16
+ CrossEncoder = None # type: ignore
17
+
18
+
19
@lru_cache(maxsize=1)
def _get_reranker() -> Optional[Any]:
    """
    Return the process-wide cross-encoder instance, or None if unavailable.

    The model is cross-encoder/ms-marco-MiniLM-L-6-v2. Because the function is
    wrapped in ``lru_cache(maxsize=1)``, whatever the first call returns --
    the loaded model, or None when sentence-transformers is missing or the
    load raised -- is reused by every later call in the same process.
    """
    reranker: Optional[Any] = None

    # CrossEncoder is None when the sentence_transformers import failed.
    if CrossEncoder is not None:
        try:
            reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        except Exception as load_error:
            # Best effort: callers treat None as "skip re-ranking".
            print(f"Warning: Failed to load cross-encoder model: {load_error}")
            print("RAG search will continue without re-ranking.")

    return reranker
38
+
39
+
40
def rerank_results(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Re-rank search results using cross-encoder for improved accuracy.

    Args:
        query: The search query
        candidates: List of candidate results, each with at least a "text" field
        top_k: Optional limit on number of results to return after re-ranking;
            None or a non-positive value returns every candidate

    Returns:
        Re-ranked list of new candidate dicts (the inputs are not mutated),
        sorted by score descending, each carrying the original keys plus
        updated "score" and "relevance" fields and ``"reranked": True``.
        When the cross-encoder is unavailable or scoring raises, the original
        ``candidates`` list is returned unchanged (no "reranked" flag).
    """
    import math

    if not candidates:
        return []

    reranker = _get_reranker()

    # If cross-encoder is not available, return original results
    if reranker is None:
        return candidates

    def _sigmoid(logit: float) -> float:
        # Numerically stable logistic: only ever exponentiate a non-positive
        # value, so extreme logits cannot overflow.
        if logit >= 0.0:
            return 1.0 / (1.0 + math.exp(-logit))
        exp_logit = math.exp(logit)
        return exp_logit / (1.0 + exp_logit)

    try:
        # Prepare pairs: (query, candidate_text) for each candidate.
        # A missing or None "text" is scored as an empty passage.
        pairs = [(query, candidate.get("text") or "") for candidate in candidates]

        # Cross-encoder outputs raw logits (higher = more relevant)
        scores = reranker.predict(pairs)

        reranked = []
        for candidate, score in zip(candidates, scores):
            # Squash the logit into [0, 1] so it is comparable with the
            # vector-similarity scores used elsewhere.
            normalized_score = _sigmoid(float(score))
            reranked.append(
                {
                    **candidate,
                    "score": normalized_score,
                    "relevance": normalized_score,  # Keep both for compatibility
                    "reranked": True,  # Flag to indicate this was re-ranked
                }
            )

        # Sort by re-ranked score (descending)
        reranked.sort(key=lambda item: item["score"], reverse=True)

        # Return top_k if specified
        if top_k is not None and top_k > 0:
            reranked = reranked[:top_k]

        return reranked

    except Exception as e:
        # Re-ranking is an optional refinement: fall back to vector order.
        print(f"Warning: Cross-encoder re-ranking failed: {e}")
        print("Returning original results without re-ranking.")
        return candidates
111
+
backend/mcp_server/rag/search.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Mapping
6
  from backend.mcp_server.common.database import search_vectors
7
  from backend.mcp_server.common.embeddings import embed_text
8
  from backend.mcp_server.common.logging import log_rag_search_metrics
 
9
  from backend.mcp_server.common.tenant import TenantContext
10
  from backend.mcp_server.common.utils import ToolValidationError, tool_handler
11
 
@@ -33,32 +34,70 @@ async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict
33
  raise ToolValidationError("threshold must be a float between 0.0 and 1.0")
34
 
35
  embedding = embed_text(query)
36
- raw_results = search_vectors(context.tenant_id, embedding, limit=limit_value)
37
- # Return top results even if slightly below threshold, but prioritize high-scoring ones
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  filtered = []
39
- for chunk in raw_results:
40
- similarity = chunk.get("similarity", 0.0)
 
41
  if similarity >= threshold_value:
42
  filtered.append({
43
  "text": chunk.get("text", ""),
44
  "relevance": similarity,
45
  "score": similarity # Add score field for compatibility
46
  })
47
- # If we have results above threshold, return top 3. Otherwise, return top 1 even if below threshold.
 
48
  if filtered:
49
- filtered = sorted(filtered, key=lambda x: x.get("relevance", 0.0), reverse=True)[:3]
50
- elif raw_results:
51
  # Return the top result even if below threshold, as it might still be relevant
52
- top_chunk = raw_results[0]
 
53
  filtered = [{
54
  "text": top_chunk.get("text", ""),
55
- "relevance": top_chunk.get("similarity", 0.0),
56
- "score": top_chunk.get("similarity", 0.0)
57
  }]
58
 
59
- hits = len(raw_results)
60
- avg_score = mean([item.get("similarity", 0.0) for item in raw_results]) if raw_results else None
61
- top_score = raw_results[0].get("similarity") if raw_results else None
 
 
 
 
 
62
 
63
  log_rag_search_metrics(
64
  tenant_id=context.tenant_id,
@@ -74,7 +113,8 @@ async def rag_search(context: TenantContext, payload: Mapping[str, Any]) -> dict
74
  "metadata": {
75
  "limit": limit_value,
76
  "threshold": threshold_value,
77
- "hits_before_filter": hits,
 
78
  },
79
  }
80
 
 
6
  from backend.mcp_server.common.database import search_vectors
7
  from backend.mcp_server.common.embeddings import embed_text
8
  from backend.mcp_server.common.logging import log_rag_search_metrics
9
+ from backend.mcp_server.common.reranker import rerank_results
10
  from backend.mcp_server.common.tenant import TenantContext
11
  from backend.mcp_server.common.utils import ToolValidationError, tool_handler
12
 
 
34
  raise ToolValidationError("threshold must be a float between 0.0 and 1.0")
35
 
36
  embedding = embed_text(query)
37
+
38
+ # Step 1: Get top 10 candidates from vector search for re-ranking
39
+ # We fetch more candidates than requested to allow cross-encoder to find the best matches
40
+ rerank_candidates_count = max(10, limit_value * 2) # Get at least 10, or 2x the requested limit
41
+ raw_results = search_vectors(context.tenant_id, embedding, limit=rerank_candidates_count)
42
+
43
+ # Step 2: Re-rank candidates using cross-encoder for improved accuracy
44
+ # Re-rank up to top 10 candidates (or all if fewer than 10)
45
+ candidates_for_rerank = raw_results[:10] # Re-rank top 10 (or all available)
46
+ reranked_results = None
47
+
48
+ if candidates_for_rerank:
49
+ # Prepare candidates with text and initial similarity score
50
+ candidates = [
51
+ {
52
+ "text": chunk.get("text", ""),
53
+ "relevance": chunk.get("similarity", 0.0),
54
+ "score": chunk.get("similarity", 0.0),
55
+ }
56
+ for chunk in candidates_for_rerank
57
+ ]
58
+
59
+ # Re-rank using cross-encoder (returns top_k results already sorted)
60
+ reranked = rerank_results(query, candidates, top_k=limit_value)
61
+
62
+ if reranked:
63
+ reranked_results = reranked
64
+
65
+ # Step 3: Use re-ranked results if available, otherwise use original vector search results
66
+ results_to_filter = reranked_results if reranked_results else raw_results
67
+
68
+ # Step 4: Filter by threshold and return top results
69
  filtered = []
70
+ for chunk in results_to_filter:
71
+ # Re-ranked results have "score" and "relevance", original have "similarity"
72
+ similarity = chunk.get("similarity") or chunk.get("score") or chunk.get("relevance") or 0.0
73
  if similarity >= threshold_value:
74
  filtered.append({
75
  "text": chunk.get("text", ""),
76
  "relevance": similarity,
77
  "score": similarity # Add score field for compatibility
78
  })
79
+
80
+ # If we have results above threshold, return top results. Otherwise, return top 1 even if below threshold.
81
  if filtered:
82
+ filtered = sorted(filtered, key=lambda x: x.get("relevance", 0.0), reverse=True)[:limit_value]
83
+ elif results_to_filter:
84
  # Return the top result even if below threshold, as it might still be relevant
85
+ top_chunk = results_to_filter[0]
86
+ similarity = top_chunk.get("similarity") or top_chunk.get("score") or top_chunk.get("relevance") or 0.0
87
  filtered = [{
88
  "text": top_chunk.get("text", ""),
89
+ "relevance": similarity,
90
+ "score": similarity
91
  }]
92
 
93
+ # Calculate metrics from the results we're using (re-ranked or original)
94
+ hits = len(results_to_filter)
95
+ scores_for_metrics = [
96
+ item.get("similarity") or item.get("score") or item.get("relevance") or 0.0
97
+ for item in results_to_filter
98
+ ]
99
+ avg_score = mean(scores_for_metrics) if scores_for_metrics else None
100
+ top_score = scores_for_metrics[0] if scores_for_metrics else None
101
 
102
  log_rag_search_metrics(
103
  tenant_id=context.tenant_id,
 
113
  "metadata": {
114
  "limit": limit_value,
115
  "threshold": threshold_value,
116
+ "hits_before_filter": len(raw_results),
117
+ "reranked": reranked_results is not None,
118
  },
119
  }
120
 
frontend/app/admin-rules/page.tsx CHANGED
@@ -52,44 +52,6 @@ export default function AdminRulesPage() {
52
  const [lastUpdated, setLastUpdated] = useState<string>("");
53
  const fileInputRef = useRef<HTMLInputElement>(null);
54
 
55
- // Check permissions early
56
- if (!canManageRules(role)) {
57
- return (
58
- <main className="mx-auto flex min-h-screen max-w-5xl flex-col gap-10 px-4 pb-16 pt-12 sm:px-6 lg:px-8">
59
- <header className="flex flex-col gap-4 rounded-2xl border border-white/10 bg-white/5 px-6 py-6 text-slate-100 shadow-lg shadow-slate-950/40">
60
- <div className="flex items-center justify-between gap-3">
61
- <div className="flex items-center gap-3 text-base font-semibold">
62
- <span className="inline-flex h-10 w-10 items-center justify-center rounded-2xl bg-gradient-to-br from-sky-400 to-cyan-500 text-slate-950">
63
- IC
64
- </span>
65
- IntegraChat Β· Admin Rules
66
- </div>
67
- <div className="flex items-center gap-4">
68
- <TenantSelector />
69
- <Link href="/" className="text-xs font-semibold uppercase tracking-[0.3em] text-cyan-300 hover:text-white">
70
- ← Back Home
71
- </Link>
72
- </div>
73
- </div>
74
- </header>
75
-
76
- <div className="rounded-2xl border border-red-500/50 bg-red-500/10 p-8 text-center">
77
- <h2 className="text-2xl font-bold text-red-300 mb-2">Access Denied</h2>
78
- <p className="text-slate-300 mb-4">
79
- You need <strong>Admin</strong> or <strong>Owner</strong> role to manage rules.
80
- </p>
81
- <p className="text-sm text-slate-400">
82
- Your current role: <strong className="text-slate-200">{role.charAt(0).toUpperCase() + role.slice(1)}</strong>
83
- </p>
84
- <p className="text-sm text-slate-400 mt-2">
85
- Please switch your role using the dropdown in the header.
86
- </p>
87
- </div>
88
- <Footer />
89
- </main>
90
- );
91
- }
92
-
93
  // Set initial time only on client side to avoid hydration mismatch
94
  useEffect(() => {
95
  setLastUpdated(new Date().toLocaleTimeString());
@@ -316,6 +278,44 @@ export default function AdminRulesPage() {
316
  }
317
  }, [deleteInput, handleRefresh, headers, requireTenant]);
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  return (
320
  <main className="mx-auto flex min-h-screen max-w-5xl flex-col gap-10 px-4 pb-16 pt-12 sm:px-6 lg:px-8">
321
  <header className="flex flex-col gap-4 rounded-2xl border border-white/10 bg-white/5 px-6 py-6 text-slate-100 shadow-lg shadow-slate-950/40">
 
52
  const [lastUpdated, setLastUpdated] = useState<string>("");
53
  const fileInputRef = useRef<HTMLInputElement>(null);
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  // Set initial time only on client side to avoid hydration mismatch
56
  useEffect(() => {
57
  setLastUpdated(new Date().toLocaleTimeString());
 
278
  }
279
  }, [deleteInput, handleRefresh, headers, requireTenant]);
280
 
281
+ // Check permissions AFTER all hooks are called
282
+ if (!canManageRules(role)) {
283
+ return (
284
+ <main className="mx-auto flex min-h-screen max-w-5xl flex-col gap-10 px-4 pb-16 pt-12 sm:px-6 lg:px-8">
285
+ <header className="flex flex-col gap-4 rounded-2xl border border-white/10 bg-white/5 px-6 py-6 text-slate-100 shadow-lg shadow-slate-950/40">
286
+ <div className="flex items-center justify-between gap-3">
287
+ <div className="flex items-center gap-3 text-base font-semibold">
288
+ <span className="inline-flex h-10 w-10 items-center justify-center rounded-2xl bg-gradient-to-br from-sky-400 to-cyan-500 text-slate-950">
289
+ IC
290
+ </span>
291
+ IntegraChat Β· Admin Rules
292
+ </div>
293
+ <div className="flex items-center gap-4">
294
+ <TenantSelector />
295
+ <Link href="/" className="text-xs font-semibold uppercase tracking-[0.3em] text-cyan-300 hover:text-white">
296
+ ← Back Home
297
+ </Link>
298
+ </div>
299
+ </div>
300
+ </header>
301
+
302
+ <div className="rounded-2xl border border-red-500/50 bg-red-500/10 p-8 text-center">
303
+ <h2 className="text-2xl font-bold text-red-300 mb-2">Access Denied</h2>
304
+ <p className="text-slate-300 mb-4">
305
+ You need <strong>Admin</strong> or <strong>Owner</strong> role to manage rules.
306
+ </p>
307
+ <p className="text-sm text-slate-400">
308
+ Your current role: <strong className="text-slate-200">{role.charAt(0).toUpperCase() + role.slice(1)}</strong>
309
+ </p>
310
+ <p className="text-sm text-slate-400 mt-2">
311
+ Please switch your role using the dropdown in the header.
312
+ </p>
313
+ </div>
314
+ <Footer />
315
+ </main>
316
+ );
317
+ }
318
+
319
  return (
320
  <main className="mx-auto flex min-h-screen max-w-5xl flex-col gap-10 px-4 pb-16 pt-12 sm:px-6 lg:px-8">
321
  <header className="flex flex-col gap-4 rounded-2xl border border-white/10 bg-white/5 px-6 py-6 text-slate-100 shadow-lg shadow-slate-950/40">
frontend/components/knowledge-base-panel.tsx CHANGED
@@ -20,7 +20,7 @@ type Document = {
20
  type SourceType = "raw_text" | "url" | "pdf" | "docx" | "txt" | "markdown";
21
 
22
  const API_BASE =
23
- process.env.NEXT_PUBLIC_API_URL?.replace(/\/$/, "") || "http://localhost:8000";
24
 
25
  export function KnowledgeBasePanel() {
26
  const { tenantId, isLoading: tenantLoading, role } = useTenant();
@@ -242,7 +242,7 @@ export function KnowledgeBasePanel() {
242
  setDocuments([]);
243
  return;
244
  } else if (response.status === 503) {
245
- console.error("Cannot connect to RAG MCP server");
246
  setDocuments([]);
247
  return;
248
  } else {
@@ -253,8 +253,15 @@ export function KnowledgeBasePanel() {
253
  const data = await response.json();
254
  setDocuments(data.documents || []);
255
  } catch (err) {
256
- console.error(err);
257
- setDocuments([]);
 
 
 
 
 
 
 
258
  // Don't show error in status for document loading - it's not critical
259
  } finally {
260
  setIsLoadingDocs(false);
@@ -338,7 +345,8 @@ export function KnowledgeBasePanel() {
338
  if (!tenantLoading && tenantId && tenantId.trim()) {
339
  loadDocuments();
340
  }
341
- }, [tenantId, tenantLoading]);
 
342
 
343
  return (
344
  <section
 
20
  type SourceType = "raw_text" | "url" | "pdf" | "docx" | "txt" | "markdown";
21
 
22
  const API_BASE =
23
+ process.env.NEXT_PUBLIC_BACKEND_BASE_URL?.replace(/\/$/, "") || "http://localhost:8000";
24
 
25
  export function KnowledgeBasePanel() {
26
  const { tenantId, isLoading: tenantLoading, role } = useTenant();
 
242
  setDocuments([]);
243
  return;
244
  } else if (response.status === 503) {
245
+ console.warn("Cannot connect to RAG MCP server");
246
  setDocuments([]);
247
  return;
248
  } else {
 
253
  const data = await response.json();
254
  setDocuments(data.documents || []);
255
  } catch (err) {
256
+ // Handle network errors (e.g., backend not running, CORS, etc.)
257
+ if (err instanceof TypeError && err.message === "Failed to fetch") {
258
+ // Network error - backend likely not running or unreachable
259
+ console.warn("Cannot connect to backend. Make sure the backend server is running.");
260
+ setDocuments([]);
261
+ } else {
262
+ console.error("Error loading documents:", err);
263
+ setDocuments([]);
264
+ }
265
  // Don't show error in status for document loading - it's not critical
266
  } finally {
267
  setIsLoadingDocs(false);
 
345
  if (!tenantLoading && tenantId && tenantId.trim()) {
346
  loadDocuments();
347
  }
348
+ // eslint-disable-next-line react-hooks/exhaustive-deps
349
+ }, [tenantId, tenantLoading, role]);
350
 
351
  return (
352
  <section
test_reranking.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test script for cross-encoder re-ranking in RAG search.
3
+
4
+ This script tests:
5
+ 1. Model loading
6
+ 2. Re-ranking functionality
7
+ 3. Comparison of results with/without re-ranking
8
+ """
9
+
10
+ import sys
11
+ import asyncio
12
+ from pathlib import Path
13
+
14
+ # Add backend to path
15
+ backend_dir = Path(__file__).parent / "backend"
16
+ sys.path.insert(0, str(backend_dir))
17
+
18
+ from mcp_server.common.reranker import rerank_results, _get_reranker
19
+
20
+
21
def test_model_loading():
    """Test that the cross-encoder model loads correctly.

    Returns:
        True if ``_get_reranker()`` returned a model, False if it returned
        None or raised.
    """
    print("=" * 60)
    print("Test 1: Model Loading")
    print("=" * 60)

    try:
        reranker = _get_reranker()
        if reranker is None:
            print("❌ FAILED: Reranker model is None (sentence-transformers not available?)")
            return False
        print("✅ SUCCESS: Cross-encoder model loaded successfully")
        print(f" Model type: {type(reranker).__name__}")
        return True
    except Exception as e:
        print(f"❌ FAILED: Error loading model: {e}")
        return False
38
+
39
+
40
def test_reranking_basic():
    """Test basic re-ranking functionality.

    Re-ranks four hand-written candidates for a refund-policy query with
    ``top_k=3`` and checks that the output is non-empty, sorted by "score"
    descending, and that every item carries ``"reranked": True``.

    Returns:
        True when all checks pass, False on a failed check or an exception.
    """
    print("\n" + "=" * 60)
    print("Test 2: Basic Re-ranking")
    print("=" * 60)

    query = "What is the refund policy?"
    candidates = [
        {"text": "Our refund policy allows returns within 30 days.", "score": 0.85, "relevance": 0.85},
        {"text": "The company was founded in 2020.", "score": 0.45, "relevance": 0.45},
        {"text": "Refunds are processed within 5-7 business days after approval.", "score": 0.72, "relevance": 0.72},
        {"text": "Contact support for assistance.", "score": 0.30, "relevance": 0.30},
    ]

    print(f"Query: {query}")
    print("\nOriginal order (by vector similarity):")
    for i, cand in enumerate(candidates, 1):
        print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")

    try:
        reranked = rerank_results(query, candidates, top_k=3)

        if not reranked:
            print("❌ FAILED: Re-ranking returned empty results")
            return False

        print("\nRe-ranked order (by cross-encoder):")
        for i, cand in enumerate(reranked, 1):
            print(f" {i}. Score: {cand['score']:.3f} - {cand['text'][:60]}...")

        # Check that results are sorted by score (descending)
        scores = [c.get("score", 0.0) for c in reranked]
        if scores != sorted(scores, reverse=True):
            print("❌ FAILED: Results are not sorted by score")
            return False

        # Check that reranked flag is set; it is absent when rerank_results
        # fell back to returning the original candidates.
        if not all(c.get("reranked") is True for c in reranked):
            print("❌ FAILED: 'reranked' flag not set")
            return False

        print("✅ SUCCESS: Re-ranking works correctly")
        return True

    except Exception as e:
        print(f"❌ FAILED: Error during re-ranking: {e}")
        import traceback
        traceback.print_exc()
        return False
89
+
90
+
91
def test_reranking_empty():
    """Test re-ranking with empty candidates.

    Returns:
        True if ``rerank_results`` returns an empty list for an empty
        candidate list, False if it returns anything else or raises.
    """
    print("\n" + "=" * 60)
    print("Test 3: Empty Candidates Handling")
    print("=" * 60)

    try:
        reranked = rerank_results("test query", [])
        if reranked == []:
            print("✅ SUCCESS: Empty candidates handled correctly")
            return True
        else:
            print(f"❌ FAILED: Expected empty list, got {reranked}")
            return False
    except Exception as e:
        print(f"❌ FAILED: Error with empty candidates: {e}")
        return False
108
+
109
+
110
async def test_rag_search_integration():
    """Test RAG search with re-ranking (requires database).

    Returns:
        True when ``rag_search`` completed (whether or not re-ranking was
        applied), or None when the call raised, which is reported as a skip.
    """
    print("\n" + "=" * 60)
    print("Test 4: RAG Search Integration (requires database)")
    print("=" * 60)

    try:
        # Imported lazily so the rest of the suite runs without the RAG stack.
        from mcp_server.rag.search import rag_search
        from mcp_server.common.tenant import TenantContext

        # Create a test tenant context
        context = TenantContext(tenant_id="test_tenant_rerank")

        # Test search
        payload = {
            "query": "test query",
            "limit": 5,
            "threshold": 0.1
        }

        print(f"Testing RAG search with query: '{payload['query']}'")
        print("Note: This requires a running database with documents.")

        result = await rag_search(context, payload)

        print(f"\nResults: {len(result.get('results', []))} items")
        print(f"Metadata: {result.get('metadata', {})}")

        if result.get('metadata', {}).get('reranked'):
            print("✅ SUCCESS: Re-ranking was applied")
        else:
            print("⚠️ WARNING: Re-ranking was not applied (may be normal if no candidates found)")

        return True

    except Exception as e:
        # Any failure here is treated as "environment not available".
        print(f"⚠️ SKIPPED: Integration test requires database: {e}")
        return None
148
+
149
+
150
def main():
    """Run all tests and print a summary.

    Each test returns True (pass), False (fail) or None (skipped). Skipped
    tests are listed in the summary but excluded from the pass/total count.

    Returns:
        0 when every executed test passed, 1 otherwise, suitable for
        ``sys.exit``.
    """
    print("\n" + "=" * 60)
    print("Cross-Encoder Re-ranking Test Suite")
    print("=" * 60)

    # Tests 1-3 run in order; each prints its own section as it executes.
    results = [
        ("Model Loading", test_model_loading()),
        ("Basic Re-ranking", test_reranking_basic()),
        ("Empty Candidates", test_reranking_empty()),
    ]

    # Test 4: Integration (optional, requires DB)
    try:
        integration_result = asyncio.run(test_rag_search_integration())
    except Exception as e:
        print(f"⚠️ Integration test skipped: {e}")
        integration_result = None
    # Record the outcome even when skipped so it shows up in the summary.
    results.append(("RAG Integration", integration_result))

    # Summary
    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)

    passed = sum(1 for _, result in results if result is True)
    total = sum(1 for _, result in results if result is not None)

    for test_name, result in results:
        if result is True:
            status = "✅ PASS"
        elif result is False:
            status = "❌ FAIL"
        else:
            status = "⚠️ SKIP"
        print(f"{status}: {test_name}")

    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed!")
        return 0

    print("\n⚠️ Some tests failed. Check output above for details.")
    return 1


if __name__ == "__main__":
    # Propagate failures to the shell / CI via the process exit code.
    sys.exit(main())
197
+