Spaces:
Sleeping
Sleeping
Create rag_retrieval.py
Browse files- rag_retrieval.py +134 -0
rag_retrieval.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
import uuid
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from rag_config import RUNS_DIR, ROOT_DIR
|
| 7 |
+
from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
|
| 8 |
+
from rag_sessions import get_session
|
| 9 |
+
|
| 10 |
+
# Ensure the run-artifact directory exists before log_trace appends to it.
# parents=True also creates missing intermediate directories, so a fresh
# checkout does not crash on the first write.
RUNS_DIR.mkdir(parents=True, exist_ok=True)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def rewrite_query(user_text: str, session: dict) -> str:
    """Rewrite a user question by prefixing it with its thread context.

    Args:
        user_text: The raw question typed by the user.
        session: Session state; only ``thread_id`` is read here.

    Returns:
        The question wrapped in a short thread-scoped instruction.
    """
    return "In thread {}, answer this question: {}".format(
        session["thread_id"], user_text
    )
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
    """Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.

    Args:
        rewrite: The (already rewritten) query string.
        session: Session state; ``thread_id`` scopes the search.
        search_outside_thread: When True, rank across every thread instead of
            restricting results to the session's thread.

    Returns:
        Up to 8 chunk dicts, best first, each carrying source metadata plus
        the individual and combined relevance scores.
    """
    tokens = rewrite.split()
    bm25_scores = np.array(bm25.get_scores(tokens))  # (N,)

    # Semantic query vector; embeddings are normalized, so the dot product is
    # cosine similarity.
    q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0]  # (D,)
    sem_scores = embeddings @ q_vec

    # Normalize both signals to [0, 1] so the weighted sum is meaningful.
    # Guard .max() against an empty corpus (np.max on a 0-length array raises).
    max_bm25 = bm25_scores.max() if bm25_scores.size else 0.0
    bm25_norm = bm25_scores / max_bm25 if max_bm25 > 0 else bm25_scores
    sem_norm = (sem_scores + 1.0) / 2.0  # cosine in [-1, 1] -> [0, 1]

    thread_id = session["thread_id"]
    N = len(chunks)
    indices = np.arange(N)

    # Restrict candidates to the current thread unless the caller opted out.
    if not search_outside_thread:
        mask = np.array([chunks[i]["thread_id"] == thread_id for i in range(N)])
        indices = indices[mask]
        bm25_norm = bm25_norm[mask]
        sem_norm = sem_norm[mask]

    combined = 0.6 * bm25_norm + 0.4 * sem_norm
    order = np.argsort(-combined)

    top_k = 8
    retrieved = []
    # Index the score arrays by position directly instead of re-sorting the
    # whole array on each iteration (the original bm25_norm[order][rank]
    # built a full reordered copy per retrieved chunk).
    for pos in order[:top_k]:
        c = chunks[indices[pos]]
        retrieved.append({
            "chunk_id": c["chunk_id"],
            "thread_id": c["thread_id"],
            "message_id": c["message_id"],
            "page_no": c.get("page_no"),
            "source": c.get("source", "email"),
            "score_bm25": float(bm25_norm[pos]),
            "score_sem": float(sem_norm[pos]),
            "score_combined": float(combined[pos]),
            "text": c["text"],
        })
    return retrieved
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def build_answer(user_text: str, rewrite: str, retrieved):
    """Assemble a markdown answer from retrieved chunks, with citations.

    Args:
        user_text: The original user question (echoed in the answer).
        rewrite: The rewritten query (unused here; kept for interface parity).
        retrieved: Chunk dicts as produced by retrieval.

    Returns:
        A ``(answer_markdown, citations)`` tuple; ``citations`` is a list of
        ``{message_id, page_no, chunk_id}`` dicts, one per snippet shown.
    """
    if not retrieved:
        return (
            "I couldn’t find any emails or content in this thread that clearly answer your question.",
            []
        )

    body = [f"**Question:** {user_text}", "", "**Relevant information:**"]

    for hit in retrieved:
        # Flatten newlines and truncate long snippets for readability.
        flat = hit["text"].replace("\n", " ")
        if len(flat) > 300:
            flat = flat[:300] + "…"

        page = hit.get("page_no")
        if page is None:
            marker = f"[msg: {hit['message_id']}]"
        else:
            marker = f"[msg: {hit['message_id']}, page: {page}]"

        body.append(f"- {flat} {marker}")

    cites = [
        {
            "message_id": h["message_id"],
            "page_no": h.get("page_no"),
            "chunk_id": h["chunk_id"],
        }
        for h in retrieved
    ]
    return "\n".join(body), cites
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
    """Append one retrieval trace record to ``RUNS_DIR/trace.jsonl``.

    Captures the query, its rewrite, per-chunk scores, the final answer and
    citations, keyed by a fresh trace UUID and a wall-clock timestamp.
    """
    trace_path = RUNS_DIR / "trace.jsonl"

    # Resolve the thread for this session; get_session may return a falsy
    # value (e.g. unknown/expired session), in which case the trace is still
    # written with thread_id=None.
    session = get_session(session_id)
    thread_id = session["thread_id"] if session else None

    record = {
        "trace_id": str(uuid.uuid4()),
        "session_id": session_id,
        "thread_id": thread_id,
        "user_text": user_text,
        "rewrite": rewrite,
        # Drop the (potentially large) chunk text; keep ids and scores only.
        "retrieved": [
            {
                "chunk_id": r["chunk_id"],
                "thread_id": r["thread_id"],
                "message_id": r["message_id"],
                "page_no": r["page_no"],
                "score_bm25": r["score_bm25"],
                "score_sem": r["score_sem"],
                "score_combined": r["score_combined"],
            } for r in retrieved
        ],
        "answer": answer,
        "citations": citations,
        "timestamp": time.time(),
    }

    # ensure_ascii=False keeps non-ASCII answer text (e.g. “’”, “…”) readable
    # in the log; the file is explicitly opened with utf-8 encoding anyway.
    with trace_path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
|