Spaces:
Sleeping
Sleeping
Update api/rag_engine.py
Browse files - api/rag_engine.py +23 -27
api/rag_engine.py
CHANGED
|
@@ -21,7 +21,6 @@ from pypdf import PdfReader
|
|
| 21 |
from docx import Document
|
| 22 |
from pptx import Presentation
|
| 23 |
|
| 24 |
-
|
| 25 |
# ----------------------------
|
| 26 |
# Helpers
|
| 27 |
# ----------------------------
|
|
@@ -158,21 +157,30 @@ def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
|
|
| 158 |
def retrieve_relevant_chunks(
|
| 159 |
query: str,
|
| 160 |
chunks: List[Dict],
|
| 161 |
-
k: int =
|
| 162 |
-
max_context_chars: int =
|
| 163 |
-
min_score: int =
|
| 164 |
) -> Tuple[str, List[Dict]]:
|
| 165 |
"""
|
| 166 |
Deterministic lightweight retrieval (no embeddings):
|
| 167 |
-
- score by token overlap
|
| 168 |
-
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
"""
|
| 171 |
query = _clean_text(query)
|
| 172 |
if not query or not chunks:
|
| 173 |
return "", []
|
| 174 |
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
if not q_tokens:
|
| 177 |
return "", []
|
| 178 |
|
|
@@ -183,19 +191,13 @@ def retrieve_relevant_chunks(
|
|
| 183 |
continue
|
| 184 |
t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
|
| 185 |
score = len(q_tokens.intersection(t_tokens))
|
| 186 |
-
if score >
|
| 187 |
scored.append((score, c))
|
| 188 |
|
| 189 |
if not scored:
|
| 190 |
return "", []
|
| 191 |
|
| 192 |
scored.sort(key=lambda x: x[0], reverse=True)
|
| 193 |
-
|
| 194 |
-
# If even the most relevant chunk is only a weak match, skip RAG entirely (avoids slowing things down for nothing)
|
| 195 |
-
best_score = scored[0][0]
|
| 196 |
-
if best_score < min_score:
|
| 197 |
-
return "", []
|
| 198 |
-
|
| 199 |
top = [c for _, c in scored[:k]]
|
| 200 |
|
| 201 |
buf_parts: List[str] = []
|
|
@@ -205,18 +207,12 @@ def retrieve_relevant_chunks(
|
|
| 205 |
t = c.get("text") or ""
|
| 206 |
if not t:
|
| 207 |
continue
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
if
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
t = t[:remaining]
|
| 215 |
-
|
| 216 |
-
buf_parts.append(t)
|
| 217 |
-
used.append(c)
|
| 218 |
-
total += len(t)
|
| 219 |
-
|
| 220 |
if total >= max_context_chars:
|
| 221 |
break
|
| 222 |
|
|
|
|
| 21 |
from docx import Document
|
| 22 |
from pptx import Presentation
|
| 23 |
|
|
|
|
| 24 |
# ----------------------------
|
| 25 |
# Helpers
|
| 26 |
# ----------------------------
|
|
|
|
| 157 |
def retrieve_relevant_chunks(
|
| 158 |
query: str,
|
| 159 |
chunks: List[Dict],
|
| 160 |
+
k: int = 1,
|
| 161 |
+
max_context_chars: int = 600,
|
| 162 |
+
min_score: int = 6,
|
| 163 |
) -> Tuple[str, List[Dict]]:
|
| 164 |
"""
|
| 165 |
Deterministic lightweight retrieval (no embeddings):
|
| 166 |
+
- score by token overlap
|
| 167 |
+
- return top-k chunks concatenated as context
|
| 168 |
+
|
| 169 |
+
Speed improvements:
|
| 170 |
+
- short/generic queries won't trigger RAG
|
| 171 |
+
- higher min_score prevents accidental triggers
|
| 172 |
+
- smaller max_context_chars reduces LLM prompt size
|
| 173 |
"""
|
| 174 |
query = _clean_text(query)
|
| 175 |
if not query or not chunks:
|
| 176 |
return "", []
|
| 177 |
|
| 178 |
+
# ✅ Short query gate: avoid wasting time on RAG for greetings / tiny inputs
|
| 179 |
+
q_tokens_list = re.findall(r"[a-zA-Z0-9]+", query.lower())
|
| 180 |
+
if (len(q_tokens_list) < 3) and (len(query) < 20):
|
| 181 |
+
return "", []
|
| 182 |
+
|
| 183 |
+
q_tokens = set(q_tokens_list)
|
| 184 |
if not q_tokens:
|
| 185 |
return "", []
|
| 186 |
|
|
|
|
| 191 |
continue
|
| 192 |
t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
|
| 193 |
score = len(q_tokens.intersection(t_tokens))
|
| 194 |
+
if score >= min_score:
|
| 195 |
scored.append((score, c))
|
| 196 |
|
| 197 |
if not scored:
|
| 198 |
return "", []
|
| 199 |
|
| 200 |
scored.sort(key=lambda x: x[0], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
top = [c for _, c in scored[:k]]
|
| 202 |
|
| 203 |
buf_parts: List[str] = []
|
|
|
|
| 207 |
t = c.get("text") or ""
|
| 208 |
if not t:
|
| 209 |
continue
|
| 210 |
+
if total + len(t) > max_context_chars:
|
| 211 |
+
t = t[: max(0, max_context_chars - total)]
|
| 212 |
+
if t:
|
| 213 |
+
buf_parts.append(t)
|
| 214 |
+
used.append(c)
|
| 215 |
+
total += len(t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
if total >= max_context_chars:
|
| 217 |
break
|
| 218 |
|