Spaces:
Sleeping
Sleeping
| import json | |
| import time | |
| import uuid | |
| import numpy as np | |
| from rag_config import RUNS_DIR, ROOT_DIR | |
| from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS | |
| from rag_sessions import get_session | |
# Make sure the trace directory exists before anything tries to log into it.
# parents=True also creates missing intermediate directories, so a nested
# RUNS_DIR does not fail with FileNotFoundError on a fresh checkout.
RUNS_DIR.mkdir(parents=True, exist_ok=True)
def rewrite_query(user_text: str, session: dict) -> str:
    """Prefix the user's question with the id of the session's thread."""
    return "In thread {}, answer this question: {}".format(
        session["thread_id"], user_text
    )
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool, top_k: int = 8):
    """
    Hybrid retrieval: BM25 + semantic similarity over precomputed embeddings.

    Args:
        rewrite: Rewritten query. It is whitespace-split for BM25 and encoded
            as a whole for the semantic score.
        session: Session dict; its "thread_id" selects the thread to search.
        search_outside_thread: When False, only chunks whose "thread_id"
            equals the session's thread id are candidates; when True every
            chunk is a candidate.
        top_k: Maximum number of chunks to return (default 8).

    Returns:
        A list of at most ``top_k`` dicts, best combined score first, each
        with the chunk's ids, "page_no", "source", the three scores
        ("score_bm25", "score_sem", "score_combined") and its "text".
        Empty when there are no chunks, no chunk passes the thread filter,
        or ``top_k`` is not positive.
    """
    n_chunks = len(chunks)
    if n_chunks == 0 or top_k <= 0:
        # Nothing to rank; also avoids calling .max() on an empty array.
        return []

    tokens = rewrite.split()
    bm25_scores = np.asarray(bm25.get_scores(tokens))  # (N,)

    # Semantic query vector
    q_vec = sem_model.encode([rewrite], normalize_embeddings=True)[0]  # (D,)
    # Dot product with the normalized query; this is cosine similarity
    # provided the stored embeddings are unit-length too -- TODO confirm.
    sem_scores = embeddings @ q_vec

    # Normalize to [0,1]. BM25 is scaled by the best score over ALL chunks
    # (before any thread filter); all-zero scores are left as they are.
    bm25_max = bm25_scores.max()
    bm25_norm = bm25_scores / bm25_max if bm25_max > 0 else bm25_scores
    sem_norm = (sem_scores + 1.0) / 2.0

    thread_id = session["thread_id"]
    indices = np.arange(n_chunks)

    # Thread filter unless overridden
    if not search_outside_thread:
        mask = np.array([c["thread_id"] == thread_id for c in chunks], dtype=bool)
        indices = indices[mask]
        bm25_norm = bm25_norm[mask]
        sem_norm = sem_norm[mask]

    combined = 0.6 * bm25_norm + 0.4 * sem_norm
    # Positions (into the filtered arrays) of the best candidates, best first.
    best_positions = np.argsort(-combined)[:top_k]

    retrieved = []
    for pos in best_positions:
        # indices maps a filtered position back to the chunk's global index.
        c = chunks[int(indices[pos])]
        retrieved.append({
            "chunk_id": c["chunk_id"],
            "thread_id": c["thread_id"],
            "message_id": c["message_id"],
            "page_no": c.get("page_no"),
            "source": c.get("source", "email"),
            "score_bm25": float(bm25_norm[pos]),
            "score_sem": float(sem_norm[pos]),
            "score_combined": float(combined[pos]),
            "text": c["text"],
        })
    return retrieved
def build_answer(user_text: str, rewrite: str, retrieved):
    """
    Answer builder with a simple 'no clear answer' heuristic.

    - If scores are very low OR none of the retrieved snippets share
      meaningful words with the question, we return a graceful fallback.
    - Otherwise, we list relevant snippets with citations.

    "Meaningful" words are the question's words that are longer than three
    characters after lowercasing and stripping surrounding punctuation.
    Snippet words are normalized the same way, so "deadline?" in the question
    matches "deadline" in an email. A question with no such word always
    takes the fallback path.

    Args:
        user_text: The user's original question.
        rewrite: The rewritten query (accepted for interface compatibility;
            not used when building the answer).
        retrieved: Sequence of dicts carrying "message_id", "chunk_id",
            "text", "score_combined" and optionally "page_no".

    Returns:
        An ``(answer, citations)`` tuple: the answer text and a list of
        ``{"message_id", "page_no", "chunk_id"}`` dicts, one per listed
        snippet (empty for the fallback answers).
    """
    if not retrieved:
        return (
            "I couldn’t find any emails or content in this thread that clearly answer your question.",
            [],
        )

    # ---- Heuristic: check scores + keyword overlap ----
    strip_chars = ".,!?;:()[]"

    def normalize(token: str) -> str:
        # One normalization for both sides of the comparison.
        return token.lower().strip(strip_chars)

    # Filter on length AFTER stripping so punctuation does not count.
    question_tokens = {
        word for word in map(normalize, user_text.split()) if len(word) > 3
    }

    def snippet_has_overlap(snippet: str) -> bool:
        return not question_tokens.isdisjoint(normalize(w) for w in snippet.split())

    best_score = max(r["score_combined"] for r in retrieved)
    any_overlap = any(snippet_has_overlap(r["text"]) for r in retrieved)

    if best_score < 0.2 or not any_overlap:
        # Fallback: nothing strongly relevant in this thread
        return (
            "Within this thread, I don’t see any email that clearly answers this question. "
            "You may need to search outside this thread or check other conversations.",
            [],
        )

    # ---- Normal snippet-based answer ----
    lines = [f"**Question:** {user_text}", "", "**Relevant information:**"]
    citations = []
    for r in retrieved:
        msg_id = r["message_id"]
        page_no = r.get("page_no")

        # Flatten newlines and cap the snippet at 300 characters.
        snippet = r["text"].replace("\n", " ")
        if len(snippet) > 300:
            snippet = snippet[:300] + "…"

        if page_no is not None:
            cite = f"[msg: {msg_id}, page: {page_no}]"
        else:
            cite = f"[msg: {msg_id}]"

        lines.append(f"- {snippet} {cite}")
        citations.append({
            "message_id": msg_id,
            "page_no": page_no,
            "chunk_id": r["chunk_id"],
        })

    return "\n".join(lines), citations
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
    """
    Append one JSON line describing this question/answer turn to
    ``RUNS_DIR / "trace.jsonl"``.

    The record carries a fresh UUID, the session and thread ids (thread id is
    None when the session lookup returns nothing), the question, its rewrite,
    the ids and scores of each retrieved chunk (chunk text is not logged),
    the answer, its citations and a ``time.time()`` timestamp.
    """
    trace_path = RUNS_DIR / "trace.jsonl"

    session = get_session(session_id)
    if session:
        thread_id = session["thread_id"]
    else:
        thread_id = None

    # Fields copied from every retrieved item, in the order they are logged.
    logged_fields = (
        "chunk_id",
        "thread_id",
        "message_id",
        "page_no",
        "score_bm25",
        "score_sem",
        "score_combined",
    )

    record = {
        "trace_id": str(uuid.uuid4()),
        "session_id": session_id,
        "thread_id": thread_id,
        "user_text": user_text,
        "rewrite": rewrite,
        "retrieved": [
            {field: item[field] for field in logged_fields} for item in retrieved
        ],
        "answer": answer,
        "citations": citations,
        "timestamp": time.time(),
    }

    with trace_path.open("a", encoding="utf-8") as sink:
        sink.write(json.dumps(record) + "\n")