SarahXia0405 committed on
Commit
3268902
·
verified ·
1 Parent(s): 037dc25

Update api/rag_engine.py

Browse files
Files changed (1) hide show
  1. api/rag_engine.py +121 -17
api/rag_engine.py CHANGED
@@ -21,6 +21,77 @@ from pypdf import PdfReader
21
  from docx import Document
22
  from pptx import Presentation
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # ----------------------------
25
  # Helpers
26
  # ----------------------------
@@ -157,19 +228,22 @@ def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
157
  def retrieve_relevant_chunks(
158
  query: str,
159
  chunks: List[Dict],
160
- k: int = 1,
161
- max_context_chars: int = 600,
162
  min_score: int = 6,
 
 
 
163
  ) -> Tuple[str, List[Dict]]:
164
  """
165
  Deterministic lightweight retrieval (no embeddings):
166
  - score by token overlap
167
  - return top-k chunks concatenated as context
168
 
169
- Speed improvements:
170
- - short/generic queries won't trigger RAG
171
- - higher min_score prevents accidental triggers
172
- - smaller max_context_chars reduces LLM prompt size
173
  """
174
  query = _clean_text(query)
175
  if not query or not chunks:
@@ -198,22 +272,52 @@ def retrieve_relevant_chunks(
198
  return "", []
199
 
200
  scored.sort(key=lambda x: x[0], reverse=True)
 
 
 
201
  top = [c for _, c in scored[:k]]
202
 
203
- buf_parts: List[str] = []
204
  used: List[Dict] = []
205
- total = 0
 
 
206
  for c in top:
207
- t = c.get("text") or ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  if not t:
209
  continue
210
- if total + len(t) > max_context_chars:
211
- t = t[: max(0, max_context_chars - total)]
212
- if t:
213
- buf_parts.append(t)
214
- used.append(c)
215
- total += len(t)
216
- if total >= max_context_chars:
217
  break
218
 
219
- return "\n\n---\n\n".join(buf_parts), used
 
 
 
 
 
21
  from docx import Document
22
  from pptx import Presentation
23
 
24
+
25
+ # ============================
26
+ # Token helpers (optional tiktoken)
27
+ # ============================
28
+ def _safe_import_tiktoken():
29
+ try:
30
+ import tiktoken # type: ignore
31
+ return tiktoken
32
+ except Exception:
33
+ return None
34
+
35
+
36
+ def _approx_tokens(text: str) -> int:
37
+ if not text:
38
+ return 0
39
+ return max(1, int(len(text) / 4))
40
+
41
+
42
+ def _count_text_tokens(text: str, model: str = "") -> int:
43
+ tk = _safe_import_tiktoken()
44
+ if tk is None:
45
+ return _approx_tokens(text)
46
+
47
+ try:
48
+ enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
49
+ except Exception:
50
+ enc = tk.get_encoding("cl100k_base")
51
+
52
+ return len(enc.encode(text or ""))
53
+
54
+
55
+ def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
56
+ """
57
+ Deterministic truncation. Uses tiktoken if available; otherwise approximates by char ratio.
58
+ """
59
+ if not text:
60
+ return text
61
+
62
+ tk = _safe_import_tiktoken()
63
+ if tk is None:
64
+ # approximate by chars
65
+ total = _approx_tokens(text)
66
+ if total <= max_tokens:
67
+ return text
68
+ ratio = max_tokens / max(1, total)
69
+ cut = max(50, min(len(text), int(len(text) * ratio)))
70
+ s = text[:cut]
71
+ # tighten
72
+ while _approx_tokens(s) > max_tokens and len(s) > 50:
73
+ s = s[: int(len(s) * 0.9)]
74
+ return s
75
+
76
+ try:
77
+ enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
78
+ except Exception:
79
+ enc = tk.get_encoding("cl100k_base")
80
+
81
+ ids = enc.encode(text or "")
82
+ if len(ids) <= max_tokens:
83
+ return text
84
+ return enc.decode(ids[:max_tokens])
85
+
86
+
87
# ============================
# RAG hard limits
# ============================
# Maximum number of chunks ever returned by retrieval, regardless of caller k.
RAG_TOPK_LIMIT = 4
# Per-chunk token cap applied before chunks are joined into the context.
RAG_CHUNK_TOKEN_LIMIT = 500
# Total token cap for the concatenated context handed to the LLM.
RAG_CONTEXT_TOKEN_LIMIT = 2000  # 4 * 500
93
+
94
+
95
  # ----------------------------
96
  # Helpers
97
  # ----------------------------
 
228
  def retrieve_relevant_chunks(
229
  query: str,
230
  chunks: List[Dict],
231
+ k: int = RAG_TOPK_LIMIT,
232
+ max_context_chars: int = 600, # kept for backward compatibility (still used as a safety cap)
233
  min_score: int = 6,
234
+ chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
235
+ max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
236
+ model_for_tokenizer: str = "",
237
  ) -> Tuple[str, List[Dict]]:
238
  """
239
  Deterministic lightweight retrieval (no embeddings):
240
  - score by token overlap
241
  - return top-k chunks concatenated as context
242
 
243
+ Hard limits implemented:
244
+ - top-k <= 4 (default)
245
+ - each chunk <= 500 tokens
246
+ - total context <= 2000 tokens (default)
247
  """
248
  query = _clean_text(query)
249
  if not query or not chunks:
 
272
  return "", []
273
 
274
  scored.sort(key=lambda x: x[0], reverse=True)
275
+
276
+ # hard cap k
277
+ k = min(int(k or RAG_TOPK_LIMIT), RAG_TOPK_LIMIT)
278
  top = [c for _, c in scored[:k]]
279
 
280
+ # truncate each chunk to <= chunk_token_limit
281
  used: List[Dict] = []
282
+ truncated_texts: List[str] = []
283
+ total_tokens = 0
284
+
285
  for c in top:
286
+ raw = c.get("text") or ""
287
+ if not raw:
288
+ continue
289
+
290
+ t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
291
+
292
+ # enforce total context tokens cap
293
+ t_tokens = _count_text_tokens(t, model=model_for_tokenizer)
294
+ if total_tokens + t_tokens > max_context_tokens:
295
+ remaining = max_context_tokens - total_tokens
296
+ if remaining <= 0:
297
+ break
298
+ t = _truncate_to_tokens(t, max_tokens=remaining, model=model_for_tokenizer)
299
+ t_tokens = _count_text_tokens(t, model=model_for_tokenizer)
300
+
301
+ # legacy char cap safety (keep your previous behavior as extra guard)
302
+ if max_context_chars and max_context_chars > 0:
303
+ # approximate: don't let total string blow up
304
+ current_chars = sum(len(x) for x in truncated_texts)
305
+ if current_chars + len(t) > max_context_chars:
306
+ t = t[: max(0, max_context_chars - current_chars)]
307
+
308
+ t = _clean_text(t)
309
  if not t:
310
  continue
311
+
312
+ truncated_texts.append(t)
313
+ used.append(c)
314
+ total_tokens += t_tokens
315
+
316
+ if total_tokens >= max_context_tokens:
 
317
  break
318
 
319
+ if not truncated_texts:
320
+ return "", []
321
+
322
+ context = "\n\n---\n\n".join(truncated_texts)
323
+ return context, used