Email-Rag-Prototype / rag_retrieval.py
raviix46's picture
Update rag_retrieval.py
7a0a491 verified
raw
history blame
8.88 kB
import json
import time
import uuid
import numpy as np
import re
from datetime import datetime
from rag_config import RUNS_DIR, ROOT_DIR
from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
from rag_sessions import get_session
RUNS_DIR.mkdir(exist_ok=True)
# --- simple regex patterns for entities ---
FILE_PAT = re.compile(r"\b[\w\-.]+\.(?:pdf|docx?|xls[xm]?|pptx?|txt)\b", re.IGNORECASE)
EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
AMOUNT_PAT = re.compile(r"\b(?:\$|USD\s*)?\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b") # very simple date pattern
def rewrite_query(user_text: str, session: dict) -> str:
"""
Rewrite user query by injecting thread ID and a light summary
of known entities from entity_memory.
"""
tid = session["thread_id"]
mem = session.get("entity_memory") or {}
key_bits = []
people = mem.get("people") or []
if people:
key_bits.append(f"people: {', '.join(people[:3])}")
files = mem.get("files") or []
if files:
key_bits.append(f"files: {', '.join(files[:3])}")
amounts = mem.get("amounts") or []
if amounts:
key_bits.append(f"amounts: {', '.join(amounts[:3])}")
dates = mem.get("dates") or []
if dates:
key_bits.append(f"dates: {', '.join(dates[:3])}")
context_str = ""
if key_bits:
context_str = "Known entities in this thread: " + "; ".join(key_bits) + ". "
return f"In thread {tid}, {context_str}answer this question: {user_text}"
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
"""
Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.
"""
tokens = rewrite.split()
bm25_scores = np.array(bm25.get_scores(tokens)) # (N,)
# Semantic query vector
q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0] # (D,)
sem_scores = embeddings @ q_vec # cosine similarity
# Normalize to [0,1]
bm25_norm = bm25_scores / bm25_scores.max() if bm25_scores.max() > 0 else bm25_scores
sem_norm = (sem_scores + 1.0) / 2.0
thread_id = session["thread_id"]
N = len(chunks)
indices = np.arange(N)
# Thread filter unless overridden
if not search_outside_thread:
mask = np.array([chunks[i]["thread_id"] == thread_id for i in range(N)])
indices = indices[mask]
bm25_norm = bm25_norm[mask]
sem_norm = sem_norm[mask]
combined = 0.6 * bm25_norm + 0.4 * sem_norm
order = np.argsort(-combined)
top_k = 8
top_indices = indices[order[:top_k]]
retrieved = []
for local_rank, idx in enumerate(top_indices):
c = chunks[idx]
retrieved.append({
"chunk_id": c["chunk_id"],
"thread_id": c["thread_id"],
"message_id": c["message_id"],
"page_no": c.get("page_no"),
"source": c.get("source", "email"),
"score_bm25": float(bm25_norm[order][local_rank]),
"score_sem": float(sem_norm[order][local_rank]),
"score_combined": float(combined[order][local_rank]),
"text": c["text"],
# carry over from/to so entity extraction can see people
"from_addr": c.get("from"),
"to_addr": c.get("to"),
"date": c.get("date"),
})
return retrieved
def build_answer(user_text: str, rewrite: str, retrieved):
"""
Answer builder with:
- 'no clear answer' heuristic
- special handling for simple 'when' questions using email dates
- snippet list with citations for grounding
"""
if not retrieved:
return (
"I couldn’t find any emails or content in this thread that clearly answer your question.",
[]
)
# ---- Heuristic: check scores + keyword overlap ----
question_tokens = {t.lower() for t in user_text.split() if len(t) > 3}
def snippet_has_overlap(snippet: str) -> bool:
words = {w.lower().strip(".,!?;:()[]") for w in snippet.split()}
return len(question_tokens & words) > 0
best_score = max(r["score_combined"] for r in retrieved)
any_overlap = any(snippet_has_overlap(r["text"]) for r in retrieved)
if best_score < 0.2 or not any_overlap:
# Fallback: nothing strongly relevant in this thread
return (
"Within this thread, I don’t see any email that clearly answers this question. "
"You may need to search outside this thread or check other conversations.",
[]
)
# ---- Optional: direct answer for 'when' questions ----
direct_answer_line = None
if "when" in user_text.lower():
dated = []
for r in retrieved:
date_str = r.get("date")
if not date_str:
continue
try:
dt = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
dated.append((dt, r))
except Exception:
continue
if dated:
# pick the latest email as the likely final approval/confirmation
dt_best, r_best = max(dated, key=lambda x: x[0])
nice_date = dt_best.strftime("%Y-%m-%d %H:%M")
direct_answer_line = (
f"**Answer:** The most relevant approval email in this thread "
f"was sent on **{nice_date}** "
f"[msg: {r_best['message_id']}]."
)
# ---- Build snippet-based explanation ----
lines = []
if direct_answer_line:
lines.append(direct_answer_line)
lines.append("")
lines.append(f"**Question:** {user_text}")
lines.append("")
lines.append("**Relevant information:**")
citations = []
seen = set() # avoid exact duplicate snippet+msg combos
for r in retrieved:
msg_id = r["message_id"]
page_no = r.get("page_no")
snippet = r["text"].replace("\n", " ")
snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet
key = (msg_id, snippet)
if key in seen:
continue
seen.add(key)
if page_no is not None:
cite = f"[msg: {msg_id}, page: {page_no}]"
else:
cite = f"[msg: {msg_id}]"
lines.append(f"- {snippet} {cite}")
citations.append({
"message_id": msg_id,
"page_no": page_no,
"chunk_id": r["chunk_id"],
})
answer = "\n".join(lines)
return answer, citations
def extract_entities_for_turn(user_text: str, retrieved):
"""
Extract simple entities from this turn:
- people: email addresses from chunks + question
- files: filenames like something.pdf
- amounts: numbers / $ amounts
- dates: simple date patterns
"""
texts = [user_text] + [r["text"] for r in retrieved]
people = set()
files = set()
amounts = set()
dates = set()
# from/to emails are good 'people' proxies
for r in retrieved:
for field in ("from_addr", "to_addr"):
val = r.get(field)
if not val:
continue
for email_match in EMAIL_PAT.findall(val):
people.add(email_match)
# scan all texts
for t in texts:
for m in EMAIL_PAT.findall(t):
people.add(m)
for m in FILE_PAT.findall(t):
files.add(m)
for m in AMOUNT_PAT.findall(t):
amounts.add(m)
for m in DATE_PAT.findall(t):
dates.add(m)
entities = {
"people": sorted(people),
"amounts": sorted(amounts),
"files": sorted(files),
"dates": sorted(dates),
}
# Strip empty categories
entities = {k: v for k, v in entities.items() if v}
return entities
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
trace_path = RUNS_DIR / "trace.jsonl"
session = get_session(session_id)
thread_id = session["thread_id"] if session else None
record = {
"trace_id": str(uuid.uuid4()),
"session_id": session_id,
"thread_id": thread_id,
"user_text": user_text,
"rewrite": rewrite,
"retrieved": [
{
"chunk_id": r["chunk_id"],
"thread_id": r["thread_id"],
"message_id": r["message_id"],
"page_no": r["page_no"],
"score_bm25": r["score_bm25"],
"score_sem": r["score_sem"],
"score_combined": r["score_combined"],
} for r in retrieved
],
"answer": answer,
"citations": citations,
"timestamp": time.time(),
}
with trace_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(record) + "\n")
return record["trace_id"]