Spaces:
Sleeping
Sleeping
Update rag_retrieval.py
Browse files- rag_retrieval.py +86 -2
rag_retrieval.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
import time
|
| 3 |
import uuid
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
|
| 6 |
from rag_config import RUNS_DIR, ROOT_DIR
|
| 7 |
from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
|
|
@@ -9,11 +11,44 @@ from rag_sessions import get_session
|
|
| 9 |
|
| 10 |
RUNS_DIR.mkdir(exist_ok=True)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def rewrite_query(user_text: str, session: dict) -> str:
|
| 14 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 15 |
tid = session["thread_id"]
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
|
|
@@ -61,6 +96,9 @@ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
|
|
| 61 |
"score_sem": float(sem_norm[order][local_rank]),
|
| 62 |
"score_combined": float(combined[order][local_rank]),
|
| 63 |
"text": c["text"],
|
|
|
|
|
|
|
|
|
|
| 64 |
})
|
| 65 |
return retrieved
|
| 66 |
|
|
@@ -124,6 +162,52 @@ def build_answer(user_text: str, rewrite: str, retrieved):
|
|
| 124 |
return answer, citations
|
| 125 |
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
|
| 128 |
trace_path = RUNS_DIR / "trace.jsonl"
|
| 129 |
|
|
|
|
| 1 |
+
# rag_retrieval.py
|
| 2 |
import json
|
| 3 |
import time
|
| 4 |
import uuid
|
| 5 |
import numpy as np
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
from rag_config import RUNS_DIR, ROOT_DIR
|
| 9 |
from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
|
|
|
|
| 11 |
|
| 12 |
# Ensure the runs directory exists before any trace files are written to it.
RUNS_DIR.mkdir(exist_ok=True)
|
| 13 |
|
| 14 |
+
# --- simple regex patterns for entities ---
FILE_PAT = re.compile(r"\b[\w\-.]+\.(?:pdf|docx?|xls[xm]?|pptx?|txt)\b", re.IGNORECASE)
EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")
# FIX: the original pattern put \b *before* the optional "$". A word boundary
# can never occur between whitespace/start-of-string and "$" (both non-word
# characters), so the dollar sign was silently dropped from every match.
# Placing \b after the optional currency prefix keeps bare numbers anchored
# while letting "$1,200.50" match in full.
AMOUNT_PAT = re.compile(r"(?:\$|USD\s*)?\b\d{1,3}(?:,\d{3})*(?:\.\d+)?\b")
DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b")  # very simple date pattern
|
| 19 |
+
|
| 20 |
|
| 21 |
def rewrite_query(user_text: str, session: dict) -> str:
    """
    Rewrite a user query for retrieval by injecting the session's thread ID
    and a light summary of known entities from entity_memory.

    Args:
        user_text: The raw question typed by the user.
        session: Session dict; must contain "thread_id" and may contain an
            "entity_memory" dict with "people"/"files"/"amounts"/"dates" lists.

    Returns:
        The rewritten query string handed to the retriever.
    """
    tid = session["thread_id"]
    mem = session.get("entity_memory") or {}

    # The original repeated an identical lookup/format block per category;
    # collapse to one comprehension. Order is fixed (people, files, amounts,
    # dates) and at most three values per category are surfaced, exactly as
    # before. The walrus binding keeps the lookup and emptiness check together.
    key_bits = [
        f"{category}: {', '.join(values[:3])}"
        for category in ("people", "files", "amounts", "dates")
        if (values := mem.get(category) or [])
    ]

    context_str = ""
    if key_bits:
        context_str = "Known entities in this thread: " + "; ".join(key_bits) + ". "

    return f"In thread {tid}, {context_str}answer this question: {user_text}"
|
| 52 |
|
| 53 |
|
| 54 |
def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
|
|
|
|
| 96 |
"score_sem": float(sem_norm[order][local_rank]),
|
| 97 |
"score_combined": float(combined[order][local_rank]),
|
| 98 |
"text": c["text"],
|
| 99 |
+
# carry over from/to so entity extraction can see people
|
| 100 |
+
"from_addr": c.get("from"),
|
| 101 |
+
"to_addr": c.get("to"),
|
| 102 |
})
|
| 103 |
return retrieved
|
| 104 |
|
|
|
|
| 162 |
return answer, citations
|
| 163 |
|
| 164 |
|
| 165 |
+
def extract_entities_for_turn(user_text: str, retrieved):
    """
    Collect lightweight entities mentioned in this turn.

    Scans the user question plus the text of every retrieved chunk for:
      - people: email addresses (including chunk from_addr/to_addr headers)
      - files: filenames like something.pdf
      - amounts: numbers / $ amounts
      - dates: simple numeric date patterns

    Returns a dict keyed by category with sorted, de-duplicated value lists;
    categories with no hits are omitted entirely.
    """
    buckets = {"people": set(), "files": set(), "amounts": set(), "dates": set()}

    # Chunk from/to headers are good 'people' proxies — harvest them first.
    for chunk in retrieved:
        for header in (chunk.get("from_addr"), chunk.get("to_addr")):
            if header:
                buckets["people"].update(EMAIL_PAT.findall(header))

    # Then scan the question and every chunk body against each pattern.
    pattern_table = (
        ("people", EMAIL_PAT),
        ("files", FILE_PAT),
        ("amounts", AMOUNT_PAT),
        ("dates", DATE_PAT),
    )
    for text in [user_text] + [chunk["text"] for chunk in retrieved]:
        for category, pattern in pattern_table:
            buckets[category].update(pattern.findall(text))

    # Same insertion order as the original result dict (people, amounts,
    # files, dates), with empty categories stripped.
    return {
        category: sorted(buckets[category])
        for category in ("people", "amounts", "files", "dates")
        if buckets[category]
    }
|
| 209 |
+
|
| 210 |
+
|
| 211 |
def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
|
| 212 |
trace_path = RUNS_DIR / "trace.jsonl"
|
| 213 |
|