Email-Rag-Prototype / rag_retrieval.py
raviix46's picture
Update rag_retrieval.py
b6bbf60 verified
raw
history blame
5.14 kB
import json
import time
import uuid
import numpy as np
from rag_config import RUNS_DIR, ROOT_DIR
from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
from rag_sessions import get_session
# Ensure the trace output directory exists before any logging happens.
# parents=True so a missing intermediate directory doesn't raise
# FileNotFoundError on first import (exist_ok alone doesn't cover that).
RUNS_DIR.mkdir(parents=True, exist_ok=True)
def rewrite_query(user_text: str, session: dict) -> str:
    """Rewrite the raw question by prefixing the session's thread context.

    Reads only ``session["thread_id"]``; no other session state is used.
    """
    thread = session["thread_id"]
    return "In thread {}, answer this question: {}".format(thread, user_text)
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool, top_k: int = 8):
    """
    Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.

    Parameters
    ----------
    rewrite : str
        The (rewritten) query string to retrieve against.
    session : dict
        Session state; only ``session["thread_id"]`` is read here.
    search_outside_thread : bool
        When False, candidates are restricted to chunks whose
        ``thread_id`` matches the session's thread.
    top_k : int, optional
        Maximum number of chunks to return (default 8, the original
        hard-coded value).

    Returns
    -------
    list[dict]
        Top-ranked chunk records with normalized BM25, semantic, and
        combined scores attached.
    """
    tokens = rewrite.split()
    bm25_scores = np.array(bm25.get_scores(tokens))  # (N,)

    # Semantic query vector; embeddings are presumably L2-normalized so the
    # dot product below is cosine similarity — TODO confirm in rag_data.
    q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0]  # (D,)
    sem_scores = embeddings @ q_vec

    # Normalize both signals to [0, 1] so the weighted sum is meaningful.
    # Guard .max() so an empty corpus doesn't raise on the reduction.
    max_bm25 = bm25_scores.max() if bm25_scores.size else 0.0
    bm25_norm = bm25_scores / max_bm25 if max_bm25 > 0 else bm25_scores
    sem_norm = (sem_scores + 1.0) / 2.0  # cosine in [-1, 1] -> [0, 1]

    thread_id = session["thread_id"]
    N = len(chunks)
    indices = np.arange(N)

    # Thread filter unless overridden by the caller.
    if not search_outside_thread:
        mask = np.array([chunks[i]["thread_id"] == thread_id for i in range(N)])
        indices = indices[mask]
        bm25_norm = bm25_norm[mask]
        sem_norm = sem_norm[mask]

    combined = 0.6 * bm25_norm + 0.4 * sem_norm

    # Rank once and slice. The original indexed bm25_norm[order] (a full
    # reordered copy of the array) inside the result loop — O(n) work per
    # retrieved item; indexing by the ranked position directly is equivalent
    # and does the sort-order lookup once.
    top_order = np.argsort(-combined)[:top_k]

    retrieved = []
    for local_idx in top_order:
        c = chunks[indices[local_idx]]  # map filtered position -> global chunk
        retrieved.append({
            "chunk_id": c["chunk_id"],
            "thread_id": c["thread_id"],
            "message_id": c["message_id"],
            "page_no": c.get("page_no"),
            "source": c.get("source", "email"),
            "score_bm25": float(bm25_norm[local_idx]),
            "score_sem": float(sem_norm[local_idx]),
            "score_combined": float(combined[local_idx]),
            "text": c["text"],
        })
    return retrieved
def build_answer(user_text: str, rewrite: str, retrieved):
    """
    Answer builder with a simple 'no clear answer' heuristic.

    - If scores are very low OR none of the retrieved snippets share
      meaningful words with the question, we return a graceful fallback.
    - Otherwise, we list relevant snippets with citations.

    Parameters
    ----------
    user_text : str
        The user's original question (shown verbatim in the answer).
    rewrite : str
        The rewritten query; accepted for interface parity, not used here.
    retrieved : list[dict]
        Chunk records as produced by ``retrieve_chunks``.

    Returns
    -------
    tuple[str, list[dict]]
        The markdown answer and a list of citation dicts
        (``message_id``, ``page_no``, ``chunk_id``).
    """
    if not retrieved:
        return (
            "I couldn’t find any emails or content in this thread that clearly answer your question.",
            []
        )

    _PUNCT = ".,!?;:()[]"

    # ---- Heuristic: check scores + keyword overlap ----
    # BUG FIX: strip punctuation from the question tokens too. The original
    # only stripped it from snippet words, so a question like
    # "What is the deadline?" yielded the token "deadline?" which could never
    # intersect the stripped snippet word "deadline", forcing a false fallback.
    question_tokens = {
        t.lower().strip(_PUNCT)
        for t in user_text.split()
        if len(t.strip(_PUNCT)) > 3
    }

    def snippet_has_overlap(snippet: str) -> bool:
        # One-line purpose: does this snippet share any meaningful word
        # (>3 chars, punctuation-stripped) with the question?
        words = {w.lower().strip(_PUNCT) for w in snippet.split()}
        return len(question_tokens & words) > 0

    best_score = max(r["score_combined"] for r in retrieved)
    # If the question contains no meaningful tokens at all, the overlap test
    # is vacuous — fall back to the score threshold alone rather than
    # unconditionally rejecting.
    any_overlap = (not question_tokens) or any(
        snippet_has_overlap(r["text"]) for r in retrieved
    )

    if best_score < 0.2 or not any_overlap:
        # Fallback: nothing strongly relevant in this thread
        return (
            "Within this thread, I don’t see any email that clearly answers this question. "
            "You may need to search outside this thread or check other conversations.",
            []
        )

    # ---- Normal snippet-based answer ----
    lines = [f"**Question:** {user_text}", "", "**Relevant information:**"]
    citations = []
    for r in retrieved:
        msg_id = r["message_id"]
        page_no = r.get("page_no")
        # Flatten newlines and truncate long snippets for readability.
        snippet = r["text"].replace("\n", " ")
        snippet = (snippet[:300] + "…") if len(snippet) > 300 else snippet
        if page_no is not None:
            cite = f"[msg: {msg_id}, page: {page_no}]"
        else:
            cite = f"[msg: {msg_id}]"
        lines.append(f"- {snippet} {cite}")
        citations.append({
            "message_id": msg_id,
            "page_no": page_no,
            "chunk_id": r["chunk_id"],
        })
    answer = "\n".join(lines)
    return answer, citations
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
    """Append one JSON line describing this turn to RUNS_DIR/trace.jsonl.

    Each record carries a fresh UUID, the session/thread identifiers, the
    query before and after rewriting, a score-level summary of every
    retrieved chunk (text omitted), the final answer with its citations,
    and a wall-clock timestamp.
    """
    session = get_session(session_id)

    # Per-chunk summary: same seven fields the original logged, text dropped.
    summary_keys = (
        "chunk_id", "thread_id", "message_id", "page_no",
        "score_bm25", "score_sem", "score_combined",
    )
    record = {
        "trace_id": str(uuid.uuid4()),
        "session_id": session_id,
        "thread_id": session["thread_id"] if session else None,
        "user_text": user_text,
        "rewrite": rewrite,
        "retrieved": [{k: r[k] for k in summary_keys} for r in retrieved],
        "answer": answer,
        "citations": citations,
        "timestamp": time.time(),
    }

    with (RUNS_DIR / "trace.jsonl").open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")