raviix46 committed on
Commit
3ee7858
·
verified ·
1 Parent(s): ee28484

Upload 5 files

Browse files
email_rag/rag_config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import json
3
+
4
# Project-relative paths: everything lives next to this module.
ROOT_DIR = Path(__file__).resolve().parent
DATA_DIR = ROOT_DIR / "data"

CHUNKS_PATH = DATA_DIR / "chunks.jsonl"        # one JSON chunk object per line
THREADS_PATH = DATA_DIR / "threads.json"
MESSAGES_PATH = DATA_DIR / "messages.json"
EMBEDDINGS_PATH = DATA_DIR / "embeddings.npy"  # precomputed chunk embeddings (N, D)
CHUNK_IDS_PATH = DATA_DIR / "chunk_ids.json"   # chunk ids, row-aligned with embeddings

# Output directory for run traces (see rag_retrieval.log_trace).
RUNS_DIR = ROOT_DIR / "runs"
14
+
15
+
16
def load_json(path: Path):
    """Read and parse a UTF-8 encoded JSON file at *path*."""
    raw = path.read_text(encoding="utf-8")
    return json.loads(raw)
19
+
20
+
21
def load_jsonl(path: Path):
    """Parse a JSON-Lines file: one JSON value per non-blank line."""
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(stripped)
                for stripped in (line.strip() for line in handle)
                if stripped]
email_rag/rag_data.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

from rag_config import (
    CHUNKS_PATH,
    THREADS_PATH,
    MESSAGES_PATH,
    EMBEDDINGS_PATH,
    CHUNK_IDS_PATH,
    load_json,
    load_jsonl,
)

# Load base data built by the ingestion step.
chunks = load_jsonl(CHUNKS_PATH)      # list of chunk dicts
threads = load_json(THREADS_PATH)     # thread_id -> message ids
messages = load_json(MESSAGES_PATH)   # message_id -> message dict

# Map chunk_id -> chunk
chunk_id_to_chunk = {c["chunk_id"]: c for c in chunks}

# BM25 corpus over whitespace-tokenized chunk texts.
corpus_tokens = [c["text"].split() for c in chunks]
bm25 = BM25Okapi(corpus_tokens)

# Semantic embeddings; rows align with chunk_ids below.
embeddings = np.load(EMBEDDINGS_PATH)  # (N, D)

# FIX: previously this file was opened twice — an unused
# `with CHUNK_IDS_PATH.open(...) as f:` wrapped a load_json() call on the
# same path. load_json() opens the file itself, so the outer open is dropped.
chunk_ids = load_json(CHUNK_IDS_PATH)

# Map chunk_id -> row index in embeddings
chunk_index = {cid: i for i, cid in enumerate(chunk_ids)}

# SentenceTransformer model (same as used in build_embeddings)
SEM_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
sem_model = SentenceTransformer(SEM_MODEL_NAME)

# Thread IDs for dropdown; sorted() already returns a list.
THREAD_OPTIONS = sorted(threads.keys())
email_rag/rag_retrieval.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import uuid
4
+ import numpy as np
5
+ import re
6
+ from datetime import datetime
7
+
8
+ from rag_config import RUNS_DIR, ROOT_DIR
9
+ from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
10
+ from rag_sessions import get_session
11
+
12
# Ensure the trace output directory exists before any logging happens.
RUNS_DIR.mkdir(exist_ok=True)

# --- simple regex patterns for entities ---
# Filenames with common office/document extensions, e.g. "report.pdf".
FILE_PAT = re.compile(r"\b[\w\-.]+\.(?:pdf|docx?|xls[xm]?|pptx?|txt)\b", re.IGNORECASE)
# Plain email addresses; used as a proxy for "people".
EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
# Optionally $/USD-prefixed numbers with thousands separators.
# NOTE(review): the prefix is optional, so bare numbers also match as "amounts".
AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b")  # very simple date pattern
19
+
20
+
21
def rewrite_query(user_text: str, session: dict) -> str:
    """
    Rewrite user query by injecting thread ID and a light summary
    of known entities from entity_memory.

    The rewritten query has the form:
    "In thread <id>, [Known entities ...] answer this question: <text>"
    with at most three values shown per entity category.
    """
    thread_id = session["thread_id"]
    memory = session.get("entity_memory") or {}

    # Summarize each non-empty category (fixed order, max 3 values each).
    summary_parts = []
    for category in ("people", "files", "amounts", "dates"):
        values = memory.get(category) or []
        if values:
            summary_parts.append(f"{category}: {', '.join(values[:3])}")

    prefix = ""
    if summary_parts:
        prefix = "Known entities in this thread: " + "; ".join(summary_parts) + ". "

    return f"In thread {thread_id}, {prefix}answer this question: {user_text}"
52
+
53
+
54
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
    """
    Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.

    Args:
        rewrite: the (rewritten) query text.
        session: session dict; session["thread_id"] scopes the search.
        search_outside_thread: if True, skip the per-thread filter.

    Returns:
        Up to 8 chunk dicts, best combined score first, each carrying the
        chunk text/ids plus normalized bm25, semantic, and combined scores.
    """
    tokens = rewrite.split()
    bm25_scores = np.array(bm25.get_scores(tokens))  # (N,)

    # Semantic query vector; embeddings are assumed normalized, so the dot
    # product acts as cosine similarity — TODO confirm build_embeddings does this.
    q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0]  # (D,)
    sem_scores = embeddings @ q_vec  # cosine similarity

    # Normalize both signals to [0, 1] before mixing.
    bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
    sem_norm = (sem_scores + 1.0) / 2.0

    thread_id = session["thread_id"]
    N = len(chunks)
    indices = np.arange(N)

    # Thread filter unless overridden
    if not search_outside_thread:
        mask = np.array([chunks[i]["thread_id"] == thread_id for i in range(N)])
        indices = indices[mask]
        bm25_norm = bm25_norm[mask]
        sem_norm = sem_norm[mask]

    # Weighted blend of the two (normalized) signals.
    combined = 0.6 * bm25_norm + 0.4 * sem_norm
    order = np.argsort(-combined)

    top_k = 8
    # PERF FIX: compute the top-k positions once. The original indexed
    # `bm25_norm[order][local_rank]` (and sem/combined likewise) inside the
    # loop, materializing a full re-sorted copy of each score array on every
    # iteration. Results are identical.
    top_order = order[:top_k]
    top_indices = indices[top_order]

    retrieved = []
    for rank_pos, idx in zip(top_order, top_indices):
        c = chunks[idx]
        retrieved.append({
            "chunk_id": c["chunk_id"],
            "thread_id": c["thread_id"],
            "message_id": c["message_id"],
            "page_no": c.get("page_no"),
            "source": c.get("source", "email"),
            "score_bm25": float(bm25_norm[rank_pos]),
            "score_sem": float(sem_norm[rank_pos]),
            "score_combined": float(combined[rank_pos]),
            "text": c["text"],
            # carry over from/to so entity extraction can see people
            "from_addr": c.get("from"),
            "to_addr": c.get("to"),
            "date": c.get("date"),
        })
    return retrieved
105
+
106
+
107
def build_answer(user_text: str, rewrite: str, retrieved):
    """
    Answer builder with:
    - 'no clear answer' heuristic
    - special handling for simple 'when' questions using email dates
    - snippet list with citations for grounding

    Args:
        user_text: the user's original question.
        rewrite: the rewritten query (not used by this builder; kept for
            signature parity with the retrieval pipeline).
        retrieved: list of chunk dicts as produced by retrieve_chunks.

    Returns:
        (answer_markdown, citations) where citations is a list of
        {message_id, page_no, chunk_id} dicts, one per snippet shown.
    """
    if not retrieved:
        return (
            "I couldn’t find any emails or content in this thread that clearly answer your question.",
            []
        )

    # ---- Heuristic: check scores + keyword overlap ----
    # Only "content" words (longer than 3 chars) from the question count.
    question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}

    def snippet_has_overlap(snippet: str) -> bool:
        # True when the snippet shares at least one content word with the question.
        words = {w.lower().strip(".,!?;:()[]") for w in snippet.split()}
        return len(question_tokens & words) > 0

    best_score = max(r["score_combined"] for r in retrieved)
    any_overlap = any(snippet_has_overlap(r["text"]) for r in retrieved)

    if best_score < 0.2 or not any_overlap:
        # Fallback: nothing strongly relevant in this thread
        return (
            "Within this thread, I don’t see any email that clearly answers this question. "
            "You may need to search outside this thread or check other conversations.",
            []
        )

    # ---- Optional: direct answer for 'when' questions ----
    direct_answer_line = None
    if "when" in user_text.lower():
        dated = []
        for r in retrieved:
            date_str = r.get("date")
            if not date_str:
                continue
            try:
                # Map a trailing "Z" to an explicit UTC offset so
                # datetime.fromisoformat accepts the timestamp.
                dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
                dated.append((dt, r))
            except Exception:
                # Unparseable dates are skipped rather than failing the answer.
                continue

        if dated:
            # pick the latest email as the likely final approval/confirmation
            dt_best, r_best = max(dated, key=lambda x: x[0])
            nice_date = dt_best.strftime("%Y-%m-%d %H:%M")
            direct_answer_line = (
                f"**Answer:** The most relevant approval email in this thread "
                f"was sent on **{nice_date}** "
                f"[msg: {r_best['message_id']}]."
            )

    # ---- Build snippet-based explanation ----
    lines = []
    if direct_answer_line:
        lines.append(direct_answer_line)
        lines.append("")

    lines.append(f"**Question:** {user_text}")
    lines.append("")
    lines.append("**Relevant information:**")

    citations = []
    seen = set()  # avoid exact duplicate snippet+msg combos

    for r in retrieved:
        msg_id = r["message_id"]
        page_no = r.get("page_no")
        snippet = r["text"].replace("\n", " ")
        # Truncate long snippets so the markdown answer stays readable.
        snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet

        key = (msg_id, snippet)
        if key in seen:
            continue
        seen.add(key)

        if page_no is not None:
            cite = f"[msg: {msg_id}, page: {page_no}]"
        else:
            cite = f"[msg: {msg_id}]"

        lines.append(f"- {snippet} {cite}")

        citations.append({
            "message_id": msg_id,
            "page_no": page_no,
            "chunk_id": r["chunk_id"],
        })

    answer = "\n".join(lines)
    return answer, citations
201
+
202
+
203
def extract_entities_for_turn(user_text: str, retrieved):
    """
    Extract simple entities from this turn.

    Scans the question text plus every retrieved chunk for:
    - people: email addresses (also mined from chunk from/to headers)
    - files: filenames like something.pdf
    - amounts: numbers / $ amounts
    - dates: simple date patterns

    Returns a dict of sorted value lists; empty categories are omitted.
    """
    found = {
        "people": set(),
        "files": set(),
        "amounts": set(),
        "dates": set(),
    }

    # from/to emails are good 'people' proxies
    for item in retrieved:
        for header_val in (item.get("from_addr"), item.get("to_addr")):
            if header_val:
                found["people"].update(EMAIL_PAT.findall(header_val))

    # scan the question plus every chunk body
    for text in [user_text] + [item["text"] for item in retrieved]:
        found["people"].update(EMAIL_PAT.findall(text))
        found["files"].update(FILE_PAT.findall(text))
        found["amounts"].update(AMOUNT_PAT.findall(text))
        found["dates"].update(DATE_PAT.findall(text))

    # Sorted lists, skipping empty categories; key order matches the
    # original output (people, amounts, files, dates).
    return {
        category: sorted(found[category])
        for category in ("people", "amounts", "files", "dates")
        if found[category]
    }
247
+
248
+
249
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
    """
    Append one JSON record for this turn to runs/trace.jsonl.

    Captures the session/thread ids, raw and rewritten queries, per-chunk
    retrieval scores, the final answer, and its citations.

    Returns:
        The freshly generated trace_id string.
    """
    session = get_session(session_id)

    # Keep only the score/id fields of each hit — the full text would bloat the log.
    slim_hits = [
        {
            "chunk_id": r["chunk_id"],
            "thread_id": r["thread_id"],
            "message_id": r["message_id"],
            "page_no": r["page_no"],
            "score_bm25": r["score_bm25"],
            "score_sem": r["score_sem"],
            "score_combined": r["score_combined"],
        }
        for r in retrieved
    ]

    record = {
        "trace_id": str(uuid.uuid4()),
        "session_id": session_id,
        "thread_id": session["thread_id"] if session else None,
        "user_text": user_text,
        "rewrite": rewrite,
        "retrieved": slim_hits,
        "answer": answer,
        "citations": citations,
        "timestamp": time.time(),
    }

    trace_path = RUNS_DIR / "trace.jsonl"
    with trace_path.open("a", encoding="utf-8") as sink:
        sink.write(json.dumps(record) + "\n")

    return record["trace_id"]
email_rag/rag_sessions.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# rag_sessions.py
import uuid

# In-memory, process-local session store (not persisted across restarts):
SESSIONS = {}  # session_id -> {thread_id, recent_turns, entity_memory}
5
+
6
+
7
+ def _init_entity_memory():
8
+ """Create a fresh entity memory structure."""
9
+ return {
10
+ "people": [],
11
+ "amounts": [],
12
+ "files": [],
13
+ "dates": [],
14
+ }
15
+
16
+
17
def start_session(thread_id: str) -> str:
    """Create a new session fixed to a given thread; return the new session id."""
    session_id = str(uuid.uuid4())
    state = {
        "thread_id": thread_id,
        "recent_turns": [],
        "entity_memory": _init_entity_memory(),
    }
    SESSIONS[session_id] = state
    return session_id
26
+
27
+
28
def get_session(session_id: str):
    """Return the session dict for *session_id*, or None if it is unknown."""
    try:
        return SESSIONS[session_id]
    except KeyError:
        return None
30
+
31
+
32
def reset_session(session_id: str):
    """Reset memory but keep the same thread. Unknown sessions are ignored."""
    existing = SESSIONS.get(session_id)
    if existing is None:
        return
    SESSIONS[session_id] = {
        "thread_id": existing["thread_id"],
        "recent_turns": [],
        "entity_memory": _init_entity_memory(),
    }
41
+
42
+
43
def update_entity_memory(session_id: str, new_entities: dict):
    """
    Merge newly extracted entities into the session's entity_memory.

    *new_entities* maps a category name ("people", "amounts", "files",
    "dates") to a list of values. Each value is appended only if not
    already present, so insertion order is preserved and duplicates are
    skipped. Unknown session ids are ignored.
    """
    session = get_session(session_id)
    if session is None:
        return

    memory = session.get("entity_memory")
    if not memory:
        # Session existed without memory — start a fresh structure.
        memory = _init_entity_memory()
        session["entity_memory"] = memory

    for category, values in new_entities.items():
        bucket = memory.setdefault(category, [])
        # De-duplicate while keeping first-seen order.
        for value in values:
            if value not in bucket:
                bucket.append(value)
email_rag/rag_timeline.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from rag_data import threads, messages
3
+
4
+
5
def build_timeline(thread_id: str) -> str:
    """
    Build a simple markdown timeline for a thread:
    - one line per message
    - sorted by date
    - with [msg: <id>] citations
    """
    message_ids = threads.get(thread_id, [])
    if not message_ids:
        return f"No messages found for thread {thread_id}."

    entries = []
    for message_id in message_ids:
        msg = messages.get(message_id)
        if not msg:
            continue

        raw_date = msg.get("date") or ""
        sender = msg.get("from") or "(unknown)"
        subject = msg.get("subject") or "(no subject)"

        # Prefer a normalized timestamp; fall back to the raw string.
        try:
            parsed = datetime.fromisoformat(raw_date.replace("Z", "+00:00"))
            shown_date = parsed.strftime("%Y-%m-%d %H:%M")
        except Exception:
            shown_date = raw_date

        rendered = f"- **{shown_date}** — **{sender}** — _{subject}_ [msg: {message_id}]"
        entries.append((raw_date, rendered))

    # Sort by raw date string; not perfect but fine for this dataset
    entries.sort(key=lambda pair: pair[0])

    header = [f"### Timeline for thread {thread_id}", ""]
    return "\n".join(header + [text for _, text in entries])