Update email_rag/rag_retrieval.py

email_rag/rag_retrieval.py (+15 -10) CHANGED
@@ -17,7 +17,7 @@ EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
 AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
 DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b")  # very simple date pattern
 
-
+# Enrich the query: this helps both BM25 and the embedding model see thread IDs and important names.
 def rewrite_query(user_text: str, session: dict) -> str:
     """
     Rewrite user query by injecting thread ID and a light summary
@@ -28,6 +28,7 @@ def rewrite_query(user_text: str, session: dict) -> str:
 
     key_bits = []
 
+    # For each entity type (people, files, amounts, dates), take up to the first 3 values.
     people = mem.get("people") or []
     if people:
         key_bits.append(f"people: {', '.join(people[:3])}")
@@ -50,7 +51,7 @@ def rewrite_query(user_text: str, session: dict) -> str:
 
     return f"In thread {tid}, {context_str}answer this question: {user_text}"
 
-
+# Tokenize the rewritten query and compute one BM25 score per chunk.
 def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
     """
     Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.
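
The three hunks above annotate rewrite_query: it prepends the thread ID plus a short summary of remembered entities (at most three values per type) so that both BM25 and the embedding model see them. A minimal, self-contained sketch of that idea; the session layout and the helper name here are illustrative assumptions, not the module's actual code:

def rewrite_query_sketch(user_text: str, session: dict) -> str:
    # Hypothetical session layout: {"thread_id": ..., "memory": {"people": [...], "files": [...], ...}}
    tid = session.get("thread_id", "unknown-thread")
    mem = session.get("memory", {})

    key_bits = []
    for kind in ("people", "files", "amounts", "dates"):
        values = mem.get(kind) or []
        if values:
            # Cap each entity type at its first 3 values to keep the query short.
            key_bits.append(f"{kind}: {', '.join(str(v) for v in values[:3])}")

    context_str = f"(context: {'; '.join(key_bits)}) " if key_bits else ""
    return f"In thread {tid}, {context_str}answer this question: {user_text}"


print(rewrite_query_sketch(
    "What did we agree on?",
    {"thread_id": "T-42", "memory": {"people": ["alice@example.com", "bob@example.com"]}},
))
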
@@ -60,7 +61,7 @@ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
 
     # Semantic query vector
     q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0]  # (D,)
-    sem_scores = embeddings @ q_vec  # cosine similarity
+    sem_scores = embeddings @ q_vec  # cosine similarity: dot product with every chunk vector
 
     # Normalize to [0,1]
     bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
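
On the comment above: because the chunk embeddings and the query vector are both L2-normalized, a plain matrix-vector product already yields cosine similarities. A small numpy-only illustration (random vectors stand in for the real sem_model output):

import numpy as np

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 8))                              # 5 chunk vectors, dim 8
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)   # L2-normalize each row

q_vec = rng.normal(size=8)
q_vec /= np.linalg.norm(q_vec)                                    # L2-normalize the query

sem_scores = embeddings @ q_vec   # dot product == cosine similarity for unit vectors
print(sem_scores.shape, float(sem_scores.max()))
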
@@ -77,12 +78,13 @@ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
         bm25_norm = bm25_norm[mask]
         sem_norm = sem_norm[mask]
 
-    combined = 0.6 * bm25_norm + 0.4 * sem_norm
-    order = np.argsort(-combined)
+    combined = 0.6 * bm25_norm + 0.4 * sem_norm  # weighted blend: 60% BM25, 40% semantic
+    order = np.argsort(-combined)  # indices sorted descending by combined score
 
     top_k = 8
     top_indices = indices[order[:top_k]]
 
+    # For each top chunk, copy its key metadata and attach the score.
     retrieved = []
     for local_rank, idx in enumerate(top_indices):
         c = chunks[idx]
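
The hunk above is the core ranking step: scale both score vectors to [0, 1], blend them 60/40 in favour of BM25, sort descending, and keep the top chunks with their metadata and score. A rough sketch with toy data; how sem_norm is actually computed is not shown in this diff, so the min-max scaling and the chunk fields below are assumptions:

import numpy as np

bm25_scores = np.array([3.1, 0.0, 1.2, 2.4])
sem_scores = np.array([0.12, 0.55, 0.80, 0.33])
chunks = [{"msg_id": f"m{i}", "text": f"chunk {i}"} for i in range(4)]

# Bring both signals onto a [0, 1] scale so the weights are comparable.
bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
sem_norm = (sem_scores - sem_scores.min()) / max(np.ptp(sem_scores), 1e-9)

combined = 0.6 * bm25_norm + 0.4 * sem_norm   # 60% lexical, 40% semantic
order = np.argsort(-combined)                 # descending by blended score

top_k = 2
retrieved = []
for local_rank, idx in enumerate(order[:top_k]):
    c = dict(chunks[idx])                     # copy the chunk's metadata
    c["score"] = float(combined[idx])         # attach the blended score
    c["rank"] = local_rank
    retrieved.append(c)

print(retrieved)
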
@@ -116,8 +118,9 @@ def build_answer(user_text: str, rewrite: str, retrieved):
             "I couldn’t find any emails or content in this thread that clearly answer your question.",
             []
         )
-
-    #
+
+    # If the retrieved chunks are low-scoring or don’t share keywords with the question, the system refuses to guess and returns a polite “no clear answer” instead of hallucinating.
+    # Heuristic: check scores + keyword overlap.
     question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}
 
     def snippet_has_overlap(snippet: str) -> bool:
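
The new comments describe a guard in build_answer: when the retrieved chunks score too low or share no meaningful keywords with the question, the function returns a fixed “no clear answer” message instead of guessing. A rough sketch of such a guard; the threshold value and field names are assumptions, not taken from the file:

def should_refuse(user_text: str, retrieved: list, min_score: float = 0.15) -> bool:
    if not retrieved:
        return True

    # Tokens longer than 3 characters, lowercased, as in the hunk above.
    question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}

    def snippet_has_overlap(snippet: str) -> bool:
        snippet_tokens = {t.lower() for t in snippet.split()}
        return bool(question_tokens & snippet_tokens)

    best_score = max(r.get("score", 0.0) for r in retrieved)
    any_overlap = any(snippet_has_overlap(r.get("text", "")) for r in retrieved)
    return best_score < min_score or not any_overlap


print(should_refuse("When is the invoice due?", [{"score": 0.4, "text": "The invoice is due Friday."}]))
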
@@ -177,8 +180,10 @@ def build_answer(user_text: str, rewrite: str, retrieved):
         page_no = r.get("page_no")
         snippet = r["text"].replace("\n", " ")
         snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet
-
+        # Shorten each chunk to a 300-character snippet.
+
         key = (msg_id, snippet)
+        # Avoid printing the same snippet twice.
         if key in seen:
             continue
         seen.add(key)
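
The hunk above trims each cited chunk to a 300-character snippet and skips duplicates keyed by (msg_id, snippet). The same dedup pattern in isolation, with made-up records:

retrieved = [
    {"msg_id": "m1", "text": "Please see the attached contract.\nThanks"},
    {"msg_id": "m1", "text": "Please see the attached contract.\nThanks"},   # exact duplicate
    {"msg_id": "m2", "text": "Budget approved. " * 30},                      # long text, gets truncated
]

seen = set()
citations = []
for r in retrieved:
    snippet = r["text"].replace("\n", " ")
    snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet  # cap at 300 characters

    key = (r["msg_id"], snippet)
    if key in seen:          # same message + same snippet: skip it
        continue
    seen.add(key)
    citations.append({"msg_id": r["msg_id"], "snippet": snippet})

print(len(citations))  # 2: the duplicate was dropped
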
@@ -215,7 +220,7 @@ def extract_entities_for_turn(user_text: str, retrieved):
     amounts = set()
     dates = set()
 
-    #
+    # Extract email addresses and add them to people (from both from_addr and to_addr).
     for r in retrieved:
         for field in ("from_addr", "to_addr"):
             val = r.get(field)
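
extract_entities_for_turn collects people, amounts, and dates from the retrieved chunks, treating from_addr / to_addr as people. A compact sketch reusing the regexes from the top of the file; the chunk fields and sample text are made up:

import re

EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b")

retrieved = [{
    "from_addr": "alice@example.com",
    "to_addr": "bob@example.com",
    "text": "Invoice of $1,250.00 is due on 12/03/2024.",
}]

people, amounts, dates = set(), set(), set()
for r in retrieved:
    for field in ("from_addr", "to_addr"):      # sender/recipient addresses count as people
        val = r.get(field)
        if val:
            people.update(EMAIL_PAT.findall(val))
    text = r.get("text", "")
    people.update(EMAIL_PAT.findall(text))
    amounts.update(AMOUNT_PAT.findall(text))
    dates.update(DATE_PAT.findall(text))

entities = {k: v for k, v in {"people": people, "amounts": amounts, "dates": dates}.items() if v}
print(entities)
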
@@ -245,7 +250,7 @@ def extract_entities_for_turn(user_text: str, retrieved):
     entities = {k: v for k, v in entities.items() if v}
     return entities
 
-
+# Log every interaction (query, rewrite, retrieved chunks, answer, citations) to runs/trace.jsonl for evaluation and debugging.
 def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
     trace_path = RUNS_DIR / "trace.jsonl"
 
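
The final hunk documents log_trace, which appends one JSON object per interaction to runs/trace.jsonl. A minimal sketch of that append-one-JSON-line pattern; the exact record fields are assumptions mirroring the function's parameters:

import json
import time
from pathlib import Path

RUNS_DIR = Path("runs")

def log_trace_sketch(session_id, user_text, rewrite, retrieved, answer, citations):
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    record = {
        "ts": time.time(),
        "session_id": session_id,
        "user_text": user_text,
        "rewrite": rewrite,
        "retrieved": retrieved,
        "answer": answer,
        "citations": citations,
    }
    # One JSON object per line keeps the trace easy to stream, grep, and evaluate later.
    with (RUNS_DIR / "trace.jsonl").open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

log_trace_sketch("s1", "When is it due?",
                 "In thread T-42, answer this question: When is it due?",
                 [], "Due Friday.", [])
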