raviix46 committed on
Commit
370b601
·
verified ·
1 Parent(s): 9ca244a

Update email_rag/rag_retrieval.py

Browse files
Files changed (1) hide show
  1. email_rag/rag_retrieval.py +15 -10
email_rag/rag_retrieval.py CHANGED
@@ -17,7 +17,7 @@ EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
17
  AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
18
  DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b") # very simple date pattern
19
 
20
-
21
  def rewrite_query(user_text: str, session: dict) -> str:
22
  """
23
  Rewrite user query by injecting thread ID and a light summary
@@ -28,6 +28,7 @@ def rewrite_query(user_text: str, session: dict) -> str:
28
 
29
  key_bits = []
30
 
 
31
  people = mem.get("people") or []
32
  if people:
33
  key_bits.append(f"people: {', '.join(people[:3])}")
@@ -50,7 +51,7 @@ def rewrite_query(user_text: str, session: dict) -> str:
50
 
51
  return f"In thread {tid}, {context_str}answer this question: {user_text}"
52
 
53
-
54
  def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
55
  """
56
  Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.
@@ -60,7 +61,7 @@ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
60
 
61
  # Semantic query vector
62
  q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0] # (D,)
63
- sem_scores = embeddings @ q_vec # cosine similarity
64
 
65
  # Normalize to [0,1]
66
  bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
@@ -77,12 +78,13 @@ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
77
  bm25_norm = bm25_norm[mask]
78
  sem_norm = sem_norm[mask]
79
 
80
- combined = 0.6 * bm25_norm + 0.4 * sem_norm
81
- order = np.argsort(-combined)
82
 
83
  top_k = 8
84
  top_indices = indices[order[:top_k]]
85
 
 
86
  retrieved = []
87
  for local_rank, idx in enumerate(top_indices):
88
  c = chunks[idx]
@@ -116,8 +118,9 @@ def build_answer(user_text: str, rewrite: str, retrieved):
116
  "I couldn’t find any emails or content in this thread that clearly answer your question.",
117
  []
118
  )
119
-
120
- # ---- Heuristic: check scores + keyword overlap ----
 
121
  question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}
122
 
123
  def snippet_has_overlap(snippet: str) -> bool:
@@ -177,8 +180,10 @@ def build_answer(user_text: str, rewrite: str, retrieved):
177
  page_no = r.get("page_no")
178
  snippet = r["text"].replace("\n", " ")
179
  snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet
180
-
 
181
  key = (msg_id, snippet)
 
182
  if key in seen:
183
  continue
184
  seen.add(key)
@@ -215,7 +220,7 @@ def extract_entities_for_turn(user_text: str, retrieved):
215
  amounts = set()
216
  dates = set()
217
 
218
- # from/to emails are good 'people' proxies
219
  for r in retrieved:
220
  for field in ("from_addr", "to_addr"):
221
  val = r.get(field)
@@ -245,7 +250,7 @@ def extract_entities_for_turn(user_text: str, retrieved):
245
  entities = {k: v for k, v in entities.items() if v}
246
  return entities
247
 
248
-
249
  def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
250
  trace_path = RUNS_DIR / "trace.jsonl"
251
 
 
17
  AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
18
  DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b") # very simple date pattern
19
 
20
+ # Enrich the query: this helps both BM25 and the embedding model see thread_ids and important names.
21
  def rewrite_query(user_text: str, session: dict) -> str:
22
  """
23
  Rewrite user query by injecting thread ID and a light summary
 
28
 
29
  key_bits = []
30
 
31
+ # For each entity type (people, files, amounts, dates), take up to the first 3 values.
32
  people = mem.get("people") or []
33
  if people:
34
  key_bits.append(f"people: {', '.join(people[:3])}")
 
51
 
52
  return f"In thread {tid}, {context_str}answer this question: {user_text}"
53
 
54
+ # Tokenize the rewritten query, produce one score per chunk
55
  def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
56
  """
57
  Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.
 
61
 
62
  # Semantic query vector
63
  q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0] # (D,)
64
+ sem_scores = embeddings @ q_vec # cosine similarity, dot product with every chunk vector
65
 
66
  # Normalize to [0,1]
67
  bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
 
78
  bm25_norm = bm25_norm[mask]
79
  sem_norm = sem_norm[mask]
80
 
81
+ combined = 0.6 * bm25_norm + 0.4 * sem_norm # 60% BM25, 40% semantic
82
+ order = np.argsort(-combined) # indices sorted descending by combined score
83
 
84
  top_k = 8
85
  top_indices = indices[order[:top_k]]
86
 
87
+ # For each top chunk: copy all key metadata and add the score
88
  retrieved = []
89
  for local_rank, idx in enumerate(top_indices):
90
  c = chunks[idx]
 
118
  "I couldn’t find any emails or content in this thread that clearly answer your question.",
119
  []
120
  )
121
+
122
+ # If the retrieved chunks are low-scoring or don’t share keywords with the question, the system refuses to guess and returns a polite “no clear answer” instead of hallucinating
123
+ # Heuristic: check scores + keyword overlap
124
  question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}
125
 
126
  def snippet_has_overlap(snippet: str) -> bool:
 
180
  page_no = r.get("page_no")
181
  snippet = r["text"].replace("\n", " ")
182
  snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet
183
+ # Shorten each chunk to a 300-character snippet
184
+
185
  key = (msg_id, snippet)
186
+ # avoid printing the same snippet twice
187
  if key in seen:
188
  continue
189
  seen.add(key)
 
220
  amounts = set()
221
  dates = set()
222
 
223
+ # Extracts email addresses and adds them to people (from_addr <-> to_addr)
224
  for r in retrieved:
225
  for field in ("from_addr", "to_addr"):
226
  val = r.get(field)
 
250
  entities = {k: v for k, v in entities.items() if v}
251
  return entities
252
 
253
+ # Logs every interaction (query, rewrite, retrieved chunks, answer, citations) into runs/trace.jsonl for evaluation and debugging.
254
  def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
255
  trace_path = RUNS_DIR / "trace.jsonl"
256