raviix46 committed on
Commit
80521e2
·
verified ·
1 Parent(s): cf88796

Update rag_retrieval.py

Browse files
Files changed (1) hide show
  1. rag_retrieval.py +86 -2
rag_retrieval.py CHANGED
@@ -1,7 +1,9 @@
 
1
  import json
2
  import time
3
  import uuid
4
  import numpy as np
 
5
 
6
  from rag_config import RUNS_DIR, ROOT_DIR
7
  from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
@@ -9,11 +11,44 @@ from rag_sessions import get_session
9
 
10
  RUNS_DIR.mkdir(exist_ok=True)
11
 
 
 
 
 
 
 
12
 
13
  def rewrite_query(user_text: str, session: dict) -> str:
14
- """Very simple rewrite: attach thread context."""
 
 
 
15
  tid = session["thread_id"]
16
- return f"In thread {tid}, answer this question: {user_text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
@@ -61,6 +96,9 @@ def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
61
  "score_sem": float(sem_norm[order][local_rank]),
62
  "score_combined": float(combined[order][local_rank]),
63
  "text": c["text"],
 
 
 
64
  })
65
  return retrieved
66
 
@@ -124,6 +162,52 @@ def build_answer(user_text: str, rewrite: str, retrieved):
124
  return answer, citations
125
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
128
  trace_path = RUNS_DIR / "trace.jsonl"
129
 
 
1
+ # rag_retrieval.py
2
  import json
3
  import time
4
  import uuid
5
  import numpy as np
6
+ import re
7
 
8
  from rag_config import RUNS_DIR, ROOT_DIR
9
  from rag_data import chunks, bm25, embeddings, sem_model, THREAD_OPTIONS
 
11
 
12
  RUNS_DIR.mkdir(exist_ok=True)
13
 
# --- Simple regex patterns for entity extraction ---

# File names ending in a known document extension, e.g. "Q3-report.pdf".
FILE_PAT = re.compile(r"\b[\w\-.]+\.(?:pdf|docx?|xls[xm]?|pptx?|txt)\b", re.IGNORECASE)

# Email addresses.
EMAIL_PAT = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b")

# Numeric / currency amounts such as "$1,200.50", "USD 500" or "1500".
#  - A lookbehind replaces the leading \b so that a "$" prefix (a non-word
#    character) is kept as part of the match instead of being dropped.
#  - Accepts either comma-grouped digits or a plain run of digits, so
#    unformatted values like "1500" are no longer missed.
#  - The surrounding lookarounds reject digits that are glued to letters,
#    "/" or "-"/"." separators, so pieces of dates ("12/05/2023",
#    "2024-01-05") and version strings ("v1.2.3") are not reported as amounts.
AMOUNT_PAT = re.compile(
    r"(?<![\w./])(?:\$\s?|USD\s*)?(?<!\d-)"
    r"(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
    r"(?![\w/]|[.-]\d)"
)

# Very simple numeric date pattern: D/M/Y or M/D/Y with 2-4 digit year.
DATE_PAT = re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b")
19
+
20
 
def rewrite_query(user_text: str, session: dict) -> str:
    """Rewrite the user query with thread context and known entities.

    The thread ID is always injected. If the session carries an
    ``entity_memory`` mapping, up to three values from each of the
    ``people``, ``files``, ``amounts`` and ``dates`` categories are added
    as a short "Known entities" summary ahead of the question.

    Args:
        user_text: The raw question typed by the user.
        session: Session state; must contain ``"thread_id"`` and may
            contain ``"entity_memory"``.

    Returns:
        The rewritten query string.

    Raises:
        KeyError: If ``session`` has no ``"thread_id"`` entry.
    """
    tid = session["thread_id"]
    mem = session.get("entity_memory") or {}

    key_bits = []
    # Fixed category order keeps the rewritten query stable between turns.
    for label in ("people", "files", "amounts", "dates"):
        # list() tolerates tuples/sets; str() tolerates non-string values
        # (e.g. numeric amounts) that would otherwise break str.join.
        values = list(mem.get(label) or [])[:3]
        if values:
            key_bits.append(f"{label}: {', '.join(str(v) for v in values)}")

    context_str = ""
    if key_bits:
        context_str = "Known entities in this thread: " + "; ".join(key_bits) + ". "

    return f"In thread {tid}, {context_str}answer this question: {user_text}"
52
 
53
 
54
  def retrieve_chunks(rewrite: str, session: dict, search_outside_thread: bool):
 
96
  "score_sem": float(sem_norm[order][local_rank]),
97
  "score_combined": float(combined[order][local_rank]),
98
  "text": c["text"],
99
+ # carry over from/to so entity extraction can see people
100
+ "from_addr": c.get("from"),
101
+ "to_addr": c.get("to"),
102
  })
103
  return retrieved
104
 
 
162
  return answer, citations
163
 
164
 
def extract_entities_for_turn(user_text: str, retrieved):
    """Extract simple entities mentioned in the current turn.

    Scans the user's question and the text of every retrieved chunk with
    the module-level regex patterns and collects:

    - ``people``: email addresses (from chunk text, the question, and the
      chunks' ``from_addr`` / ``to_addr`` fields)
    - ``files``: file names such as ``something.pdf``
    - ``amounts``: numbers / currency amounts
    - ``dates``: simple numeric dates

    Args:
        user_text: The raw question typed by the user.
        retrieved: Iterable of chunk dicts; ``text``, ``from_addr`` and
            ``to_addr`` are read when present.

    Returns:
        A dict mapping each non-empty category name to a sorted list of
        unique matches. Categories with no matches are omitted.
    """
    people = set()
    files = set()
    amounts = set()
    dates = set()

    # from/to emails are good 'people' proxies.
    for r in retrieved:
        for field in ("from_addr", "to_addr"):
            val = r.get(field)
            if not val:
                continue
            # An address field may hold one string or several recipients;
            # normalise to an iterable of strings before matching.
            if isinstance(val, str):
                candidates = [val]
            elif isinstance(val, (list, tuple, set, frozenset)):
                candidates = val
            else:
                candidates = [str(val)]
            for addr in candidates:
                if addr:
                    people.update(EMAIL_PAT.findall(str(addr)))

    # Scan the question and all chunk texts, skipping missing/non-text values.
    texts = [user_text] + [r.get("text") for r in retrieved]
    for t in texts:
        if not t or not isinstance(t, str):
            continue
        people.update(EMAIL_PAT.findall(t))
        files.update(FILE_PAT.findall(t))
        dates.update(DATE_PAT.findall(t))

        # Blank out spans already recognised as emails, files or dates so
        # their digits are not counted a second time as amounts.
        masked = t
        for pat in (EMAIL_PAT, FILE_PAT, DATE_PAT):
            masked = pat.sub(" ", masked)
        amounts.update(AMOUNT_PAT.findall(masked))

    entities = {
        "people": sorted(people),
        "amounts": sorted(amounts),
        "files": sorted(files),
        "dates": sorted(dates),
    }
    # Strip empty categories.
    return {k: v for k, v in entities.items() if v}
209
+
210
+
211
  def log_trace(session_id: str, user_text: str, rewrite: str, retrieved, answer, citations):
212
  trace_path = RUNS_DIR / "trace.jsonl"
213