raviix46 commited on
Commit
51b20fd
·
verified ·
1 Parent(s): 6747e10

Create rag_retrieval.py

Browse files
Files changed (1) hide show
  1. rag_retrieval.py +134 -0
rag_retrieval.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import uuid
4
+ import numpy as np
5
+
6
+ from rag_config import RUNS_DIR, ROOT_DIR
7
+ from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
8
+ from rag_sessions import get_session
9
+
10
+ RUNS_DIR.mkdir(exist_ok=True)
11
+
12
+
13
+ def rewrite_query(user_text: str, session: dict) -> str:
14
+ """Very simple rewrite: attach thread context."""
15
+ tid = session["thread_id"]
16
+ return f"In thread {tid}, answer this question: {user_text}"
17
+
18
+
19
+ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
20
+ """
21
+ Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.
22
+ """
23
+ tokens = rewrite.split()
24
+ bm25_scores = np.array(bm25.get_scores(tokens)) # (N,)
25
+
26
+ # Semantic query vector
27
+ q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0] # (D,)
28
+ sem_scores = embeddings @ q_vec # cosine similarity
29
+
30
+ # Normalize to [0,1]
31
+ bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
32
+ sem_norm = (sem_scores + 1.0) / 2.0
33
+
34
+ thread_id = session["thread_id"]
35
+ N = len(chunks)
36
+ indices = np.arange(N)
37
+
38
+ # Thread filter unless overridden
39
+ if not search_outside_thread:
40
+ mask = np.array([chunks[i]["thread_id"] == thread_id for i in range(N)])
41
+ indices = indices[mask]
42
+ bm25_norm = bm25_norm[mask]
43
+ sem_norm = sem_norm[mask]
44
+
45
+ combined = 0.6 * bm25_norm + 0.4 * sem_norm
46
+ order = np.argsort(-combined)
47
+
48
+ top_k = 8
49
+ top_indices = indices[order[:top_k]]
50
+
51
+ retrieved = []
52
+ for local_rank, idx in enumerate(top_indices):
53
+ c = chunks[idx]
54
+ retrieved.append({
55
+ "chunk_id": c["chunk_id"],
56
+ "thread_id": c["thread_id"],
57
+ "message_id": c["message_id"],
58
+ "page_no": c.get("page_no"),
59
+ "source": c.get("source", "email"),
60
+ "score_bm25": float(bm25_norm[order][local_rank]),
61
+ "score_sem": float(sem_norm[order][local_rank]),
62
+ "score_combined": float(combined[order][local_rank]),
63
+ "text": c["text"],
64
+ })
65
+ return retrieved
66
+
67
+
68
+ def build_answer(user_text: str, rewrite: str, retrieved):
69
+ """
70
+ Simple answer builder:
71
+ - Show relevant snippets with citations.
72
+ """
73
+ if not retrieved:
74
+ return (
75
+ "I couldn’t find any emails or content in this thread that clearly answer your question.",
76
+ []
77
+ )
78
+
79
+ lines = [f"**Question:** {user_text}", "", "**Relevant information:**"]
80
+ citations = []
81
+
82
+ for r in retrieved:
83
+ msg_id = r["message_id"]
84
+ page_no = r.get("page_no")
85
+ snippet = r["text"].replace("\n", " ")
86
+ snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet
87
+
88
+ if page_no is not None:
89
+ cite = f"[msg: {msg_id}, page: {page_no}]"
90
+ else:
91
+ cite = f"[msg: {msg_id}]"
92
+
93
+ lines.append(f"- {snippet} {cite}")
94
+
95
+ citations.append({
96
+ "message_id": msg_id,
97
+ "page_no": page_no,
98
+ "chunk_id": r["chunk_id"],
99
+ })
100
+
101
+ answer = "\n".join(lines)
102
+ return answer, citations
103
+
104
+
105
+ def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
106
+ trace_path = RUNS_DIR / "trace.jsonl"
107
+
108
+ session = get_session(session_id)
109
+ thread_id = session["thread_id"] if session else None
110
+
111
+ record = {
112
+ "trace_id": str(uuid.uuid4()),
113
+ "session_id": session_id,
114
+ "thread_id": thread_id,
115
+ "user_text": user_text,
116
+ "rewrite": rewrite,
117
+ "retrieved": [
118
+ {
119
+ "chunk_id": r["chunk_id"],
120
+ "thread_id": r["thread_id"],
121
+ "message_id": r["message_id"],
122
+ "page_no": r["page_no"],
123
+ "score_bm25": r["score_bm25"],
124
+ "score_sem": r["score_sem"],
125
+ "score_combined": r["score_combined"],
126
+ } for r in retrieved
127
+ ],
128
+ "answer": answer,
129
+ "citations": citations,
130
+ "timestamp": time.time(),
131
+ }
132
+
133
+ with trace_path.open("a", encoding="utf-8") as f:
134
+ f.write(json.dumps(record) + "\n")