Update api/rag_engine.py
api/rag_engine.py  +46 -6  CHANGED
@@ -2,7 +2,7 @@
 """
 RAG engine:
 - build_rag_chunks_from_file(path, doc_type) -> List[chunk]
-- retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)
+- retrieve_relevant_chunks(query, chunks, ...) -> (context_text, used_chunks)
 
 Chunk format (MVP):
 {
@@ -11,11 +11,17 @@ Chunk format (MVP):
     "section": str,
     "doc_type": str
 }
+
+✅ Update in this version:
+- retrieve_relevant_chunks now supports optional scoping:
+  - allowed_source_files: Optional[List[str]] (match by basename)
+  - allowed_doc_types: Optional[List[str]]
+- Scoping happens BEFORE scoring, so refs returned are guaranteed to be the true used chunks.
 """
 
 import os
 import re
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 
 from pypdf import PdfReader
 from docx import Document
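Note: the chunk keys visible in this diff ("text", "source_file", "section", "doc_type") suggest a shape roughly like the following; the values here are purely illustrative, not taken from the repo:

    chunk = {
        "text": "Either party may terminate with 30 days written notice...",
        "source_file": "contract.pdf",   # basename is what scoping matches on
        "section": "Termination",
        "doc_type": "contract",
    }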
@@ -61,14 +67,12 @@ def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
 
     tk = _safe_import_tiktoken()
     if tk is None:
-        # approximate by chars
         total = _approx_tokens(text)
         if total <= max_tokens:
             return text
         ratio = max_tokens / max(1, total)
         cut = max(50, min(len(text), int(len(text) * ratio)))
         s = text[:cut]
-        # tighten
         while _approx_tokens(s) > max_tokens and len(s) > 50:
             s = s[: int(len(s) * 0.9)]
         return s
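The fallback branch above leans on _approx_tokens, which is outside this diff. A minimal sketch of such a character-based approximation, assuming roughly four characters per token (an assumption, not necessarily the repo's implementation):

    def _approx_tokens(text: str) -> int:
        # Assumed heuristic: ~4 characters per token on average.
        return max(1, len(text) // 4)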
@@ -136,6 +140,13 @@ def _file_label(path: str) -> str:
     return os.path.basename(path) if path else "uploaded_file"
 
 
+def _basename(x: str) -> str:
+    try:
+        return os.path.basename(x or "")
+    except Exception:
+        return x or ""
+
+
 # ----------------------------
 # Parsers
 # ----------------------------
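The new _basename helper is a defensive wrapper around os.path.basename; for example (illustrative inputs):

    _basename("uploads/2024/contract.pdf")  # -> "contract.pdf"
    _basename(None)                         # -> "" (the `x or ""` guard catches falsy input)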
@@ -234,12 +245,20 @@ def retrieve_relevant_chunks(
     chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
     max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
     model_for_tokenizer: str = "",
+    # ✅ NEW: scoping controls
+    allowed_source_files: Optional[List[str]] = None,
+    allowed_doc_types: Optional[List[str]] = None,
 ) -> Tuple[str, List[Dict]]:
     """
     Deterministic lightweight retrieval (no embeddings):
     - score by token overlap
     - return top-k chunks concatenated as context
 
+    ✅ Scoping:
+    - If allowed_source_files provided: only consider chunks whose source_file basename is in the allowlist
+    - If allowed_doc_types provided: only consider chunks whose doc_type is in the allowlist
+    Scoping is applied BEFORE scoring; returned used_chunks are the true sources for refs.
+
     Hard limits implemented:
     - top-k <= 4 (default)
     - each chunk <= 500 tokens
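The overlap scoring itself is not touched by this commit; a minimal sketch consistent with the "score by token overlap" description in the docstring, with the helper name being hypothetical:

    def _overlap_score(query: str, chunk_text: str) -> int:
        # Count distinct query tokens that also appear in the chunk text.
        q = set(re.findall(r"[a-zA-Z0-9]+", query.lower()))
        t = set(re.findall(r"[a-zA-Z0-9]+", chunk_text.lower()))
        return len(q & t)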
@@ -249,6 +268,28 @@ def retrieve_relevant_chunks(
     if not query or not chunks:
         return "", []
 
+    # ----------------------------
+    # ✅ Apply scoping BEFORE scoring
+    # ----------------------------
+    filtered = chunks or []
+
+    if allowed_source_files:
+        allow_files = {_basename(str(x)).strip() for x in allowed_source_files if str(x).strip()}
+        if allow_files:
+            filtered = [
+                c
+                for c in filtered
+                if _basename(str(c.get("source_file", ""))).strip() in allow_files
+            ]
+
+    if allowed_doc_types:
+        allow_dt = {str(x).strip() for x in allowed_doc_types if str(x).strip()}
+        if allow_dt:
+            filtered = [c for c in filtered if str(c.get("doc_type", "")).strip() in allow_dt]
+
+    if not filtered:
+        return "", []
+
     # ✅ Short query gate: avoid wasting time on RAG for greetings / tiny inputs
     q_tokens_list = re.findall(r"[a-zA-Z0-9]+", query.lower())
     if (len(q_tokens_list) < 3) and (len(query) < 20):
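With the new parameters, a scoped call would look roughly like this (all_chunks, the file name, and the doc type are illustrative):

    context_text, used_chunks = retrieve_relevant_chunks(
        query="What is the termination notice period?",
        chunks=all_chunks,
        allowed_source_files=["contract.pdf"],  # matched by basename, per _basename
        allowed_doc_types=["contract"],
    )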
@@ -259,7 +300,7 @@
         return "", []
 
     scored: List[Tuple[int, Dict]] = []
-    for c in chunks:
+    for c in filtered:
         text = (c.get("text") or "")
         if not text:
             continue
@@ -300,7 +341,6 @@
 
     # legacy char cap safety (keep your previous behavior as extra guard)
     if max_context_chars and max_context_chars > 0:
-        # approximate: don't let total string blow up
         current_chars = sum(len(x) for x in truncated_texts)
         if current_chars + len(t) > max_context_chars:
             t = t[: max(0, max_context_chars - current_chars)]
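As a concrete instance of the char cap guard above: with max_context_chars = 12000 and 11900 characters already accumulated, a further 300-character chunk t is cut to t[:100].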