SarahXia0405 committed on
Commit
6f94a8a
·
verified ·
1 Parent(s): 0c571ff

Update api/rag_engine.py

Browse files
Files changed (1) hide show
  1. api/rag_engine.py +23 -27
api/rag_engine.py CHANGED
@@ -21,7 +21,6 @@ from pypdf import PdfReader
21
  from docx import Document
22
  from pptx import Presentation
23
 
24
-
25
  # ----------------------------
26
  # Helpers
27
  # ----------------------------
@@ -158,21 +157,30 @@ def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
158
  def retrieve_relevant_chunks(
159
  query: str,
160
  chunks: List[Dict],
161
- k: int = 2,
162
- max_context_chars: int = 1200,
163
- min_score: int = 3,
164
  ) -> Tuple[str, List[Dict]]:
165
  """
166
  Deterministic lightweight retrieval (no embeddings):
167
- - score by token overlap (fast)
168
- - ONLY include context when overlap score is meaningful (>= min_score)
169
- - keep context short to reduce LLM latency
 
 
 
 
170
  """
171
  query = _clean_text(query)
172
  if not query or not chunks:
173
  return "", []
174
 
175
- q_tokens = set(re.findall(r"[a-zA-Z0-9]+", query.lower()))
 
 
 
 
 
176
  if not q_tokens:
177
  return "", []
178
 
@@ -183,19 +191,13 @@ def retrieve_relevant_chunks(
183
  continue
184
  t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
185
  score = len(q_tokens.intersection(t_tokens))
186
- if score > 0:
187
  scored.append((score, c))
188
 
189
  if not scored:
190
  return "", []
191
 
192
  scored.sort(key=lambda x: x[0], reverse=True)
193
-
194
- # 如果最相关的都很弱,就别塞 RAG(避免白白变慢)
195
- best_score = scored[0][0]
196
- if best_score < min_score:
197
- return "", []
198
-
199
  top = [c for _, c in scored[:k]]
200
 
201
  buf_parts: List[str] = []
@@ -205,18 +207,12 @@ def retrieve_relevant_chunks(
205
  t = c.get("text") or ""
206
  if not t:
207
  continue
208
-
209
- remaining = max_context_chars - total
210
- if remaining <= 0:
211
- break
212
-
213
- if len(t) > remaining:
214
- t = t[:remaining]
215
-
216
- buf_parts.append(t)
217
- used.append(c)
218
- total += len(t)
219
-
220
  if total >= max_context_chars:
221
  break
222
 
 
21
  from docx import Document
22
  from pptx import Presentation
23
 
 
24
  # ----------------------------
25
  # Helpers
26
  # ----------------------------
 
157
  def retrieve_relevant_chunks(
158
  query: str,
159
  chunks: List[Dict],
160
+ k: int = 1,
161
+ max_context_chars: int = 600,
162
+ min_score: int = 6,
163
  ) -> Tuple[str, List[Dict]]:
164
  """
165
  Deterministic lightweight retrieval (no embeddings):
166
+ - score by token overlap
167
+ - return top-k chunks concatenated as context
168
+
169
+ Speed improvements:
170
+ - short/generic queries won't trigger RAG
171
+ - higher min_score prevents accidental triggers
172
+ - smaller max_context_chars reduces LLM prompt size
173
  """
174
  query = _clean_text(query)
175
  if not query or not chunks:
176
  return "", []
177
 
178
+ # Short query gate: avoid wasting time on RAG for greetings / tiny inputs
179
+ q_tokens_list = re.findall(r"[a-zA-Z0-9]+", query.lower())
180
+ if (len(q_tokens_list) < 3) and (len(query) < 20):
181
+ return "", []
182
+
183
+ q_tokens = set(q_tokens_list)
184
  if not q_tokens:
185
  return "", []
186
 
 
191
  continue
192
  t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
193
  score = len(q_tokens.intersection(t_tokens))
194
+ if score >= min_score:
195
  scored.append((score, c))
196
 
197
  if not scored:
198
  return "", []
199
 
200
  scored.sort(key=lambda x: x[0], reverse=True)
 
 
 
 
 
 
201
  top = [c for _, c in scored[:k]]
202
 
203
  buf_parts: List[str] = []
 
207
  t = c.get("text") or ""
208
  if not t:
209
  continue
210
+ if total + len(t) > max_context_chars:
211
+ t = t[: max(0, max_context_chars - total)]
212
+ if t:
213
+ buf_parts.append(t)
214
+ used.append(c)
215
+ total += len(t)
 
 
 
 
 
 
216
  if total >= max_context_chars:
217
  break
218