raviix46 committed on
Commit
2359bec
·
verified ·
1 Parent(s): 59ae20b

Delete rag_retrieval.py

Browse files
Files changed (1) hide show
  1. rag_retrieval.py +0 -280
rag_retrieval.py DELETED
@@ -1,280 +0,0 @@
1
- import json
2
- import time
3
- import uuid
4
- import numpy as np
5
- import re
6
- from datetime import datetime
7
-
8
- from rag_config import RUNS_DIR, ROOT_DIR
9
- from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
10
- from rag_sessions import get_session
11
-
12
# Ensure the run-artifact directory exists before any trace is appended.
RUNS_DIR.mkdir(exist_ok=True)

# --- simple regex patterns for entities ---
# Filenames with common document extensions (e.g. report.pdf, sheet.xlsx).
FILE_PAT = re.compile(r"\b[\w\-.]+\.(?:pdf|docx?|xls[xm]?|pptx?|txt)\b", re.IGNORECASE)
# Email addresses; used downstream as a proxy for 'people'.
EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
# Numeric amounts with thousands separators, optionally prefixed with $ or USD.
AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b")  # very simple date pattern
19
-
20
-
21
def rewrite_query(user_text: str, session: dict) -> str:
    """
    Rewrite a user query for retrieval by injecting the session's thread ID
    and a light summary of known entities from entity_memory.

    Parameters
    ----------
    user_text : str
        The raw question typed by the user.
    session : dict
        Session state; must contain "thread_id" and may contain an
        "entity_memory" dict mapping category name -> list of strings.

    Returns
    -------
    str
        The rewritten query string, optionally prefixed with a
        "Known entities in this thread: ..." summary.
    """
    tid = session["thread_id"]
    mem = session.get("entity_memory") or {}

    # One loop instead of four copy-pasted blocks; category order is fixed
    # so the rewrite stays stable across turns. At most three values each.
    key_bits = []
    for category in ("people", "files", "amounts", "dates"):
        values = mem.get(category) or []
        if values:
            key_bits.append(f"{category}: {', '.join(values[:3])}")

    context_str = ""
    if key_bits:
        context_str = "Known entities in this thread: " + "; ".join(key_bits) + ". "

    return f"In thread {tid}, {context_str}answer this question: {user_text}"
52
-
53
-
54
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
    """
    Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.

    Parameters
    ----------
    rewrite : str
        The (already rewritten) query string; whitespace-tokenized for BM25.
    session : dict
        Session state; "thread_id" scopes retrieval unless overridden.
    search_outside_thread : bool
        When False, only chunks whose "thread_id" matches the session are ranked.

    Returns
    -------
    list[dict]
        Up to 8 chunk records, best first, each with normalized BM25/semantic/
        combined scores plus citation metadata (message_id, page_no, ...).
    """
    tokens = rewrite.split()
    bm25_scores = np.array(bm25.get_scores(tokens))  # (N,)

    # Semantic query vector; embeddings are presumably L2-normalized so the
    # dot product below is cosine similarity — TODO confirm in rag_data.
    q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0]  # (D,)
    sem_scores = embeddings @ q_vec

    # Normalize both signals to [0, 1] before mixing.
    bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
    sem_norm = (sem_scores + 1.0) / 2.0

    thread_id = session["thread_id"]
    N = len(chunks)
    indices = np.arange(N)

    # Thread filter unless overridden by the caller.
    if not search_outside_thread:
        mask = np.array([chunks[i]["thread_id"] == thread_id for i in range(N)])
        indices = indices[mask]
        bm25_norm = bm25_norm[mask]
        sem_norm = sem_norm[mask]

    combined = 0.6 * bm25_norm + 0.4 * sem_norm
    order = np.argsort(-combined)  # best first

    top_k = 8
    top_order = order[:top_k]        # rank positions within the (masked) arrays
    top_indices = indices[top_order]  # corresponding global chunk indices

    retrieved = []
    for rank_pos, chunk_idx in zip(top_order, top_indices):
        c = chunks[chunk_idx]
        retrieved.append({
            "chunk_id": c["chunk_id"],
            "thread_id": c["thread_id"],
            "message_id": c["message_id"],
            "page_no": c.get("page_no"),
            "source": c.get("source", "email"),
            # Index each score array once per hit; the previous code permuted
            # the full array per iteration (bm25_norm[order][rank]) — O(N*k).
            "score_bm25": float(bm25_norm[rank_pos]),
            "score_sem": float(sem_norm[rank_pos]),
            "score_combined": float(combined[rank_pos]),
            "text": c["text"],
            # carry over from/to so entity extraction can see people
            "from_addr": c.get("from"),
            "to_addr": c.get("to"),
            "date": c.get("date"),
        })
    return retrieved
105
-
106
-
107
def build_answer(user_text: str, rewrite: str, retrieved):
    """
    Build a grounded, citation-backed answer from retrieved chunks.

    Behavior:
      - returns a "no clear answer" message when nothing was retrieved, or
        when nothing is strongly relevant (low combined score / no keyword
        overlap with the question)
      - for simple 'when' questions, prepends a direct answer line based on
        the latest parseable email date
      - otherwise lists deduplicated snippets with [msg/page] citations

    Parameters
    ----------
    user_text : str
        Original user question (used for overlap heuristics and display).
    rewrite : str
        Rewritten query; currently unused here, kept for interface stability.
    retrieved : list[dict]
        Chunk records as produced by retrieve_chunks().

    Returns
    -------
    tuple[str, list[dict]]
        (answer markdown, citations); citations are empty on fallback paths.
    """
    if not retrieved:
        return (
            "I couldn’t find any emails or content in this thread that clearly answer your question.",
            []
        )

    if not _has_clear_support(user_text, retrieved):
        # Fallback: nothing strongly relevant in this thread.
        return (
            "Within this thread, I don’t see any email that clearly answers this question. "
            "You may need to search outside this thread or check other conversations.",
            []
        )

    direct_answer_line = _direct_when_answer(user_text, retrieved)

    lines = []
    if direct_answer_line:
        lines.append(direct_answer_line)
        lines.append("")

    lines.append(f"**Question:** {user_text}")
    lines.append("")
    lines.append("**Relevant information:**")

    snippet_lines, citations = _cited_snippets(retrieved)
    lines.extend(snippet_lines)

    return "\n".join(lines), citations


def _has_clear_support(user_text: str, retrieved) -> bool:
    """True when a chunk clears the score floor AND some chunk shares a keyword with the question."""
    # Only words longer than 3 chars count as question keywords.
    question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}

    def snippet_has_overlap(snippet: str) -> bool:
        words = {w.lower().strip(".,!?;:()[]") for w in snippet.split()}
        return len(question_tokens & words) > 0

    best_score = max(r["score_combined"] for r in retrieved)
    any_overlap = any(snippet_has_overlap(r["text"]) for r in retrieved)
    # 0.2 is an empirical floor on the normalized combined score.
    return best_score >= 0.2 and any_overlap


def _direct_when_answer(user_text: str, retrieved):
    """For 'when' questions, return a direct-answer line from the latest dated chunk, else None."""
    if "when" not in user_text.lower():
        return None

    dated = []
    for r in retrieved:
        date_str = r.get("date")
        if not date_str:
            continue
        try:
            # Accept a trailing 'Z' (UTC), which fromisoformat() rejects pre-3.11.
            dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
            dated.append((dt, r))
        except (ValueError, TypeError):
            # Unparseable or non-string date: skip this chunk.
            continue

    if not dated:
        return None

    # Pick the latest email as the likely final approval/confirmation.
    dt_best, r_best = max(dated, key=lambda x: x[0])
    nice_date = dt_best.strftime("%Y-%m-%d %H:%M")
    return (
        f"**Answer:** The most relevant approval email in this thread "
        f"was sent on **{nice_date}** "
        f"[msg: {r_best['message_id']}]."
    )


def _cited_snippets(retrieved):
    """Return (bullet lines, citation dicts), dropping exact duplicate snippet+message pairs."""
    lines = []
    citations = []
    seen = set()

    for r in retrieved:
        msg_id = r["message_id"]
        page_no = r.get("page_no")
        snippet = r["text"].replace("\n", " ")
        snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet

        key = (msg_id, snippet)
        if key in seen:
            continue
        seen.add(key)

        if page_no is not None:
            cite = f"[msg: {msg_id}, page: {page_no}]"
        else:
            cite = f"[msg: {msg_id}]"

        lines.append(f"- {snippet} {cite}")
        citations.append({
            "message_id": msg_id,
            "page_no": page_no,
            "chunk_id": r["chunk_id"],
        })

    return lines, citations
201
-
202
-
203
def extract_entities_for_turn(user_text: str, retrieved):
    """
    Extract lightweight entities mentioned in this turn.

    Scans the user question plus every retrieved chunk (and the chunks'
    from/to address fields) for:
      - people: email addresses
      - files: filenames such as something.pdf
      - amounts: plain numbers / $ amounts
      - dates: simple d/m/yy-style patterns

    Returns a dict containing only the non-empty categories, each a sorted list.
    """
    people = set()
    files = set()
    amounts = set()
    dates = set()

    # Sender/recipient addresses are a good proxy for 'people'.
    for r in retrieved:
        for field in ("from_addr", "to_addr"):
            val = r.get(field)
            if val:
                people.update(EMAIL_PAT.findall(val))

    # Scan the question followed by every chunk body.
    for text in [user_text, *(r["text"] for r in retrieved)]:
        people.update(EMAIL_PAT.findall(text))
        files.update(FILE_PAT.findall(text))
        amounts.update(AMOUNT_PAT.findall(text))
        dates.update(DATE_PAT.findall(text))

    collected = {
        "people": sorted(people),
        "amounts": sorted(amounts),
        "files": sorted(files),
        "dates": sorted(dates),
    }
    # Drop categories with no hits so callers see only what was found.
    return {k: v for k, v in collected.items() if v}
247
-
248
-
249
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
    """
    Append one JSONL trace record for this turn and return its trace ID.

    The record captures the question, its rewrite, per-chunk retrieval
    scores (without chunk text), the final answer, and its citations.
    """
    session = get_session(session_id)

    # Slim score/identity view of each retrieved chunk — full text is not logged.
    slim_hits = []
    for hit in retrieved:
        slim_hits.append({
            "chunk_id": hit["chunk_id"],
            "thread_id": hit["thread_id"],
            "message_id": hit["message_id"],
            "page_no": hit["page_no"],
            "score_bm25": hit["score_bm25"],
            "score_sem": hit["score_sem"],
            "score_combined": hit["score_combined"],
        })

    trace_id = str(uuid.uuid4())
    record = {
        "trace_id": trace_id,
        "session_id": session_id,
        "thread_id": session["thread_id"] if session else None,
        "user_text": user_text,
        "rewrite": rewrite,
        "retrieved": slim_hits,
        "answer": answer,
        "citations": citations,
        "timestamp": time.time(),
    }

    with (RUNS_DIR / "trace.jsonl").open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

    return trace_id