RJuro Claude Opus 4.6 committed on
Commit
683f135
·
1 Parent(s): 3a6e6ca

Add cross-encoder reranking (two-stage retrieval pipeline)

Browse files

Stage 1 retrieves top_k*4 candidates (capped at 40) via the bi-encoder;
Stage 2 reranks them with the ms-marco-MiniLM-L-6-v2 cross-encoder. A toggle
in the UI lets users compare rankings with a dual-score display.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. Dockerfile +1 -0
  2. app.py +47 -17
  3. static/index.html +102 -4
Dockerfile CHANGED
@@ -2,6 +2,7 @@ FROM python:3.11-slim
2
  WORKDIR /app
3
  COPY requirements.txt .
4
  RUN pip install --no-cache-dir -r requirements.txt
 
5
  COPY . .
6
  EXPOSE 7860
7
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
2
  WORKDIR /app
3
  COPY requirements.txt .
4
  RUN pip install --no-cache-dir -r requirements.txt
5
+ RUN python -c "from sentence_transformers import CrossEncoder; CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')"
6
  COPY . .
7
  EXPOSE 7860
8
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -9,24 +9,28 @@ import os
9
  from fastapi import FastAPI, Query
10
  from fastapi.responses import FileResponse
11
  from fastapi.staticfiles import StaticFiles
12
- from sentence_transformers import SentenceTransformer
13
  import chromadb
14
 
15
  MODEL_NAME = "intfloat/multilingual-e5-small"
 
16
  CHROMA_PATH = os.path.join(os.path.dirname(__file__), "data", "chroma_db")
17
  COLLECTION_NAME = "scifact"
18
 
19
  app = FastAPI(title="SciFact Multilingual Semantic Search")
20
 
21
  model: SentenceTransformer = None
 
22
  collection: chromadb.Collection = None
23
 
24
 
25
  @app.on_event("startup")
26
  def startup():
27
- global model, collection
28
  print(f"Loading model: {MODEL_NAME}")
29
  model = SentenceTransformer(MODEL_NAME)
 
 
30
  print(f"Loading ChromaDB from: {CHROMA_PATH}")
31
  client = chromadb.PersistentClient(path=CHROMA_PATH)
32
  collection = client.get_collection(COLLECTION_NAME)
@@ -41,36 +45,62 @@ def index():
41
 
42
 
43
  @app.get("/search")
44
- def search(q: str = Query(..., min_length=1), top_k: int = Query(5, ge=1, le=20)):
 
 
 
 
 
 
 
 
45
  query_embedding = model.encode(
46
- [f"query: {q.strip()}"],
47
  normalize_embeddings=True,
48
  ).tolist()
49
 
50
  results = collection.query(
51
  query_embeddings=query_embedding,
52
- n_results=top_k,
53
  include=["metadatas", "distances", "documents"],
54
  )
55
 
56
- items = []
57
- for i, (meta, dist, doc) in enumerate(
58
- zip(
59
- results["metadatas"][0],
60
- results["distances"][0],
61
- results["documents"][0],
62
- )
63
  ):
64
- items.append(
65
  {
66
- "rank": i + 1,
67
- "score": round(1 - dist, 4),
68
  "title": meta.get("title", ""),
69
- "text": doc[:300] if doc else meta.get("text", "")[:300],
70
  }
71
  )
72
 
73
- return {"query": q, "results": items}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
 
9
  from fastapi import FastAPI, Query
10
  from fastapi.responses import FileResponse
11
  from fastapi.staticfiles import StaticFiles
12
+ from sentence_transformers import SentenceTransformer, CrossEncoder
13
  import chromadb
14
 
15
  MODEL_NAME = "intfloat/multilingual-e5-small"
16
+ RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
17
  CHROMA_PATH = os.path.join(os.path.dirname(__file__), "data", "chroma_db")
18
  COLLECTION_NAME = "scifact"
19
 
20
  app = FastAPI(title="SciFact Multilingual Semantic Search")
21
 
22
  model: SentenceTransformer = None
23
+ reranker: CrossEncoder = None
24
  collection: chromadb.Collection = None
25
 
26
 
27
  @app.on_event("startup")
28
  def startup():
29
+ global model, reranker, collection
30
  print(f"Loading model: {MODEL_NAME}")
31
  model = SentenceTransformer(MODEL_NAME)
32
+ print(f"Loading reranker: {RERANKER_MODEL}")
33
+ reranker = CrossEncoder(RERANKER_MODEL)
34
  print(f"Loading ChromaDB from: {CHROMA_PATH}")
35
  client = chromadb.PersistentClient(path=CHROMA_PATH)
36
  collection = client.get_collection(COLLECTION_NAME)
 
45
 
46
 
47
  @app.get("/search")
48
+ def search(
49
+ q: str = Query(..., min_length=1),
50
+ top_k: int = Query(5, ge=1, le=20),
51
+ rerank: bool = Query(True),
52
+ ):
53
+ raw_query = q.strip()
54
+
55
+ # Stage 1: bi-encoder retrieval
56
+ n_candidates = min(top_k * 4, 40) if rerank else top_k
57
  query_embedding = model.encode(
58
+ [f"query: {raw_query}"],
59
  normalize_embeddings=True,
60
  ).tolist()
61
 
62
  results = collection.query(
63
  query_embeddings=query_embedding,
64
+ n_results=n_candidates,
65
  include=["metadatas", "distances", "documents"],
66
  )
67
 
68
+ candidates = []
69
+ for meta, dist, doc in zip(
70
+ results["metadatas"][0],
71
+ results["distances"][0],
72
+ results["documents"][0],
 
 
73
  ):
74
+ candidates.append(
75
  {
76
+ "bi_score": round(1 - dist, 4),
 
77
  "title": meta.get("title", ""),
78
+ "full_text": doc if doc else meta.get("text", ""),
79
  }
80
  )
81
 
82
+ # Stage 2: cross-encoder reranking
83
+ if rerank and candidates:
84
+ pairs = [(raw_query, c["full_text"]) for c in candidates]
85
+ ce_scores = reranker.predict(pairs).tolist()
86
+ for c, score in zip(candidates, ce_scores):
87
+ c["ce_score"] = round(score, 4)
88
+ candidates.sort(key=lambda c: c["ce_score"], reverse=True)
89
+
90
+ items = []
91
+ for i, c in enumerate(candidates[:top_k]):
92
+ item = {
93
+ "rank": i + 1,
94
+ "bi_score": c["bi_score"],
95
+ "title": c["title"],
96
+ "text": c["full_text"][:300],
97
+ "reranked": rerank,
98
+ }
99
+ if rerank:
100
+ item["ce_score"] = c.get("ce_score")
101
+ items.append(item)
102
+
103
+ return {"query": raw_query, "results": items}
104
 
105
 
106
  app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
static/index.html CHANGED
@@ -162,6 +162,77 @@
162
  .search-row button:hover { background: #C42A20; }
163
  .search-row button:disabled { background: var(--gray-2); color: var(--gray-3); cursor: not-allowed; }
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  /* ── Chips ── */
166
  .chips {
167
  display: flex;
@@ -446,6 +517,13 @@
446
  <input id="q" type="text" placeholder="Enter a query in any language..." autocomplete="off" spellcheck="false">
447
  <button id="btn" onclick="doSearch()">SEARCH</button>
448
  </div>
 
 
 
 
 
 
 
449
  <div class="chips">
450
  <div class="chip" onclick="chipSearch(this)"><span class="lang">EN</span>effects of vaccination</div>
451
  <div class="chip" onclick="chipSearch(this)"><span class="lang">FR</span>effets de la vaccination</div>
@@ -469,7 +547,7 @@
469
 
470
  <!-- Footer -->
471
  <div class="footer">
472
- <span>intfloat/multilingual-e5-small</span>
473
  <span>chromadb<span class="dot">/</span>fastapi<span class="dot">/</span>cpu</span>
474
  </div>
475
 
@@ -479,9 +557,20 @@
479
  const input = document.getElementById('q');
480
  const btn = document.getElementById('btn');
481
  const results = document.getElementById('results');
 
482
 
483
  input.addEventListener('keydown', e => { if (e.key === 'Enter') doSearch(); });
484
 
 
 
 
 
 
 
 
 
 
 
485
  function chipSearch(el) {
486
  const text = el.textContent.replace(/^[A-Z]{2}/, '').trim();
487
  input.value = text;
@@ -497,7 +586,7 @@ async function doSearch() {
497
  results.innerHTML = skeletonHTML();
498
 
499
  try {
500
- const res = await fetch('/search?q=' + encodeURIComponent(q) + '&top_k=5');
501
  if (!res.ok) throw new Error(res.statusText);
502
  const data = await res.json();
503
  renderResults(data.results);
@@ -526,11 +615,20 @@ function renderResults(items) {
526
  return;
527
  }
528
  results.innerHTML = items.map((r, i) => {
529
- const pct = Math.max(0, Math.min(100, r.score * 100));
 
 
 
 
 
 
 
 
530
  return '<div class="result-card" style="animation-delay:' + (i * .08) + 's">' +
531
- '<div class="rank-col"><div class="rank-num">0' + r.rank + '</div><div class="rank-line"></div></div>' +
532
  '<div class="result-body">' +
533
  '<div class="result-title">' + esc(r.title) + '</div>' +
 
534
  '<div class="score-row">' +
535
  '<div class="score-track"><div class="score-fill" style="width:' + pct + '%"></div></div>' +
536
  '<div class="score-val">' + pct.toFixed(1) + '%</div>' +
 
162
  .search-row button:hover { background: #C42A20; }
163
  .search-row button:disabled { background: var(--gray-2); color: var(--gray-3); cursor: not-allowed; }
164
 
165
+ /* ── Rerank Toggle ── */
166
+ .toggle-row {
167
+ display: flex;
168
+ align-items: center;
169
+ gap: 12px;
170
+ margin-top: 14px;
171
+ }
172
+ .toggle-label {
173
+ font-family: 'JetBrains Mono', monospace;
174
+ font-size: .7rem;
175
+ color: var(--gray-4);
176
+ letter-spacing: .5px;
177
+ text-transform: uppercase;
178
+ }
179
+ .toggle-track {
180
+ position: relative;
181
+ width: 42px;
182
+ height: 22px;
183
+ border: 2px solid var(--white);
184
+ background: transparent;
185
+ cursor: pointer;
186
+ transition: all .2s;
187
+ flex-shrink: 0;
188
+ }
189
+ .toggle-track.active {
190
+ border-color: var(--red);
191
+ background: var(--red);
192
+ }
193
+ .toggle-thumb {
194
+ position: absolute;
195
+ top: 2px;
196
+ left: 2px;
197
+ width: 14px;
198
+ height: 14px;
199
+ background: var(--white);
200
+ transition: transform .2s;
201
+ }
202
+ .toggle-track.active .toggle-thumb {
203
+ transform: translateX(20px);
204
+ }
205
+ .toggle-status {
206
+ font-family: 'JetBrains Mono', monospace;
207
+ font-size: .62rem;
208
+ color: var(--gray-3);
209
+ letter-spacing: .5px;
210
+ }
211
+
212
+ /* ── CE Score / Badge ── */
213
+ .ce-row {
214
+ display: flex;
215
+ align-items: center;
216
+ gap: 10px;
217
+ margin-bottom: 4px;
218
+ }
219
+ .ce-badge {
220
+ font-family: 'JetBrains Mono', monospace;
221
+ font-size: .58rem;
222
+ letter-spacing: 1px;
223
+ color: var(--black);
224
+ background: var(--red);
225
+ padding: 2px 7px;
226
+ font-weight: 600;
227
+ }
228
+ .ce-val {
229
+ font-family: 'JetBrains Mono', monospace;
230
+ font-size: .72rem;
231
+ color: var(--white);
232
+ font-weight: 600;
233
+ letter-spacing: .5px;
234
+ }
235
+
236
  /* ── Chips ── */
237
  .chips {
238
  display: flex;
 
517
  <input id="q" type="text" placeholder="Enter a query in any language..." autocomplete="off" spellcheck="false">
518
  <button id="btn" onclick="doSearch()">SEARCH</button>
519
  </div>
520
+ <div class="toggle-row">
521
+ <div class="toggle-track active" id="rerankToggle" onclick="toggleRerank()">
522
+ <div class="toggle-thumb"></div>
523
+ </div>
524
+ <span class="toggle-label">Cross-encoder reranking</span>
525
+ <span class="toggle-status" id="rerankStatus">ON — ms-marco-MiniLM-L-6-v2</span>
526
+ </div>
527
  <div class="chips">
528
  <div class="chip" onclick="chipSearch(this)"><span class="lang">EN</span>effects of vaccination</div>
529
  <div class="chip" onclick="chipSearch(this)"><span class="lang">FR</span>effets de la vaccination</div>
 
547
 
548
  <!-- Footer -->
549
  <div class="footer">
550
+ <span>bi-encoder: multilingual-e5-small<span class="dot">/</span>cross-encoder: ms-marco-MiniLM-L-6-v2</span>
551
  <span>chromadb<span class="dot">/</span>fastapi<span class="dot">/</span>cpu</span>
552
  </div>
553
 
 
557
  const input = document.getElementById('q');
558
  const btn = document.getElementById('btn');
559
  const results = document.getElementById('results');
560
+ let rerankEnabled = true;
561
 
562
  input.addEventListener('keydown', e => { if (e.key === 'Enter') doSearch(); });
563
 
564
+ function toggleRerank() {
565
+ rerankEnabled = !rerankEnabled;
566
+ const toggle = document.getElementById('rerankToggle');
567
+ const status = document.getElementById('rerankStatus');
568
+ toggle.classList.toggle('active', rerankEnabled);
569
+ status.textContent = rerankEnabled ? 'ON — ms-marco-MiniLM-L-6-v2' : 'OFF — bi-encoder only';
570
+ // Re-run search if there are existing results
571
+ if (input.value.trim()) doSearch();
572
+ }
573
+
574
  function chipSearch(el) {
575
  const text = el.textContent.replace(/^[A-Z]{2}/, '').trim();
576
  input.value = text;
 
586
  results.innerHTML = skeletonHTML();
587
 
588
  try {
589
+ const res = await fetch('/search?q=' + encodeURIComponent(q) + '&top_k=5&rerank=' + rerankEnabled);
590
  if (!res.ok) throw new Error(res.statusText);
591
  const data = await res.json();
592
  renderResults(data.results);
 
615
  return;
616
  }
617
  results.innerHTML = items.map((r, i) => {
618
+ const pct = Math.max(0, Math.min(100, r.bi_score * 100));
619
+ const rankStr = r.rank < 10 ? '0' + r.rank : '' + r.rank;
620
+ let ceHTML = '';
621
+ if (r.reranked && r.ce_score != null) {
622
+ ceHTML = '<div class="ce-row">' +
623
+ '<span class="ce-badge">RERANKED</span>' +
624
+ '<span class="ce-val">CE ' + r.ce_score.toFixed(2) + '</span>' +
625
+ '</div>';
626
+ }
627
  return '<div class="result-card" style="animation-delay:' + (i * .08) + 's">' +
628
+ '<div class="rank-col"><div class="rank-num">' + rankStr + '</div><div class="rank-line"></div></div>' +
629
  '<div class="result-body">' +
630
  '<div class="result-title">' + esc(r.title) + '</div>' +
631
+ ceHTML +
632
  '<div class="score-row">' +
633
  '<div class="score-track"><div class="score-fill" style="width:' + pct + '%"></div></div>' +
634
  '<div class="score-val">' + pct.toFixed(1) + '%</div>' +