SarahXia0405 committed (verified)
Commit d83a16b · Parent(s): a6f0418

Update api/rag_engine.py

Files changed (1):
  api/rag_engine.py (+46 -6)
api/rag_engine.py CHANGED
@@ -2,7 +2,7 @@
 """
 RAG engine:
 - build_rag_chunks_from_file(path, doc_type) -> List[chunk]
-- retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)
+- retrieve_relevant_chunks(query, chunks, ...) -> (context_text, used_chunks)
 
 Chunk format (MVP):
 {
@@ -11,11 +11,17 @@ Chunk format (MVP):
     "section": str,
     "doc_type": str
 }
+
+✅ Update in this version:
+- retrieve_relevant_chunks now supports optional scoping:
+  - allowed_source_files: Optional[List[str]] (match by basename)
+  - allowed_doc_types: Optional[List[str]]
+- Scoping happens BEFORE scoring, so the refs returned are guaranteed to be the true used chunks.
 """
 
 import os
 import re
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 from pypdf import PdfReader
 from docx import Document
@@ -61,14 +67,12 @@ def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
 
     tk = _safe_import_tiktoken()
     if tk is None:
-        # approximate by chars
        total = _approx_tokens(text)
        if total <= max_tokens:
            return text
        ratio = max_tokens / max(1, total)
        cut = max(50, min(len(text), int(len(text) * ratio)))
        s = text[:cut]
-        # tighten
        while _approx_tokens(s) > max_tokens and len(s) > 50:
            s = s[: int(len(s) * 0.9)]
        return s
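Side note: the fallback path above leans on _approx_tokens, which this diff does not show. A minimal standalone sketch of the same proportional-cut-then-tighten idea, assuming a rough 4-characters-per-token heuristic (the helper name and the ratio are assumptions, not the repo's actual implementation):

def approx_tokens(text: str) -> int:
    # Hypothetical stand-in for _approx_tokens: ~4 chars per token.
    return max(1, len(text) // 4)


def truncate_by_chars(text: str, max_tokens: int) -> str:
    total = approx_tokens(text)
    if total <= max_tokens:
        return text
    # Proportional first cut, never below 50 chars...
    ratio = max_tokens / max(1, total)
    s = text[: max(50, min(len(text), int(len(text) * ratio)))]
    # ...then shrink by 10% steps until under budget; this terminates
    # because the slice strictly shrinks while len(s) > 50.
    while approx_tokens(s) > max_tokens and len(s) > 50:
        s = s[: int(len(s) * 0.9)]
    return s


assert approx_tokens(truncate_by_chars("word " * 1000, 100)) <= 100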
@@ -136,6 +140,13 @@ def _file_label(path: str) -> str:
     return os.path.basename(path) if path else "uploaded_file"
 
 
+def _basename(x: str) -> str:
+    try:
+        return os.path.basename(x or "")
+    except Exception:
+        return x or ""
+
+
 # ----------------------------
 # Parsers
 # ----------------------------
@@ -234,12 +245,20 @@ def retrieve_relevant_chunks(
     chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
     max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
     model_for_tokenizer: str = "",
+    # ✅ NEW: scoping controls
+    allowed_source_files: Optional[List[str]] = None,
+    allowed_doc_types: Optional[List[str]] = None,
 ) -> Tuple[str, List[Dict]]:
     """
     Deterministic lightweight retrieval (no embeddings):
     - score by token overlap
     - return top-k chunks concatenated as context
 
+    ✅ Scoping:
+    - If allowed_source_files is provided: only consider chunks whose source_file basename is in the allowlist
+    - If allowed_doc_types is provided: only consider chunks whose doc_type is in the allowlist
+    Scoping is applied BEFORE scoring; the returned used_chunks are the true sources for refs.
+
     Hard limits implemented:
     - top-k <= 4 (default)
     - each chunk <= 500 tokens
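For callers, a hedged usage sketch of the new scoping controls (the chunk dicts follow the MVP format from the module docstring; file names, texts, and doc types below are invented):

from api.rag_engine import retrieve_relevant_chunks

chunks = [
    {
        "text": "Either party may terminate with 30 days written notice.",
        "source_file": "/tmp/uploads/contract_v2.pdf",
        "section": "Termination",
        "doc_type": "contract",
    },
    {
        "text": "Q3 revenue grew 12% year over year.",
        "source_file": "/tmp/uploads/q3_report.docx",
        "section": "Financials",
        "doc_type": "report",
    },
]

# Scoping is applied before scoring, so used_chunks can only
# reference contract_v2.pdf here (matched by basename).
context, used_chunks = retrieve_relevant_chunks(
    "What is the notice period for termination?",
    chunks,
    allowed_source_files=["contract_v2.pdf"],
    allowed_doc_types=["contract"],
)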
@@ -249,6 +268,28 @@
     if not query or not chunks:
         return "", []
 
+    # ----------------------------
+    # ✅ Apply scoping BEFORE scoring
+    # ----------------------------
+    filtered = chunks or []
+
+    if allowed_source_files:
+        allow_files = {_basename(str(x)).strip() for x in allowed_source_files if str(x).strip()}
+        if allow_files:
+            filtered = [
+                c
+                for c in filtered
+                if _basename(str(c.get("source_file", ""))).strip() in allow_files
+            ]
+
+    if allowed_doc_types:
+        allow_dt = {str(x).strip() for x in allowed_doc_types if str(x).strip()}
+        if allow_dt:
+            filtered = [c for c in filtered if str(c.get("doc_type", "")).strip() in allow_dt]
+
+    if not filtered:
+        return "", []
+
     # ✅ Short query gate: avoid wasting time on RAG for greetings / tiny inputs
     q_tokens_list = re.findall(r"[a-zA-Z0-9]+", query.lower())
     if (len(q_tokens_list) < 3) and (len(query) < 20):
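One subtlety worth calling out: because both sides of the comparison are reduced with _basename, callers may pass bare file names while chunks carry full paths. A toy sketch of the same filter, with os.path.basename standing in for _basename and made-up paths:

import os

chunks = [
    {"text": "...", "source_file": "/data/uploads/policy.pdf", "doc_type": "policy"},
    {"text": "...", "source_file": "notes.docx", "doc_type": "notes"},
]

# Both sides are reduced to a basename, so "policy.pdf" matches the full path.
allow_files = {os.path.basename(x).strip() for x in ["policy.pdf"]}
filtered = [
    c for c in chunks
    if os.path.basename(str(c.get("source_file", ""))).strip() in allow_files
]
assert [c["source_file"] for c in filtered] == ["/data/uploads/policy.pdf"]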
@@ -259,7 +300,7 @@
         return "", []
 
     scored: List[Tuple[int, Dict]] = []
-    for c in chunks:
+    for c in filtered:
         text = (c.get("text") or "")
         if not text:
             continue
@@ -300,7 +341,6 @@
 
         # legacy char cap safety (keep your previous behavior as extra guard)
         if max_context_chars and max_context_chars > 0:
-            # approximate: don't let total string blow up
            current_chars = sum(len(x) for x in truncated_texts)
            if current_chars + len(t) > max_context_chars:
                t = t[: max(0, max_context_chars - current_chars)]
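The assembly loop around this guard is not part of the diff; the sketch below shows roughly how the char cap keeps the joined context bounded, with the loop shape and the 6000-char value as assumptions:

max_context_chars = 6000  # assumed value for the legacy cap
truncated_texts = []

for t in ["a" * 4000, "b" * 4000]:  # stand-ins for token-truncated chunk texts
    if max_context_chars and max_context_chars > 0:
        current_chars = sum(len(x) for x in truncated_texts)
        if current_chars + len(t) > max_context_chars:
            # Trim the overflowing chunk so the joined context stays capped.
            t = t[: max(0, max_context_chars - current_chars)]
    if t:
        truncated_texts.append(t)

assert sum(len(x) for x in truncated_texts) == 6000  # 4000 + 2000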
 