SarahXia0405 commited on
Commit
a6f0418
·
verified ·
1 Parent(s): d9f042c

Update api/server.py

Browse files
Files changed (1) hide show
  1. api/server.py +53 -2
api/server.py CHANGED
@@ -180,6 +180,50 @@ def _should_force_rag(message: str) -> bool:
180
  ]
181
  return any(t in m for t in triggers)
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # ----------------------------
185
  # Warmup
@@ -466,11 +510,18 @@ def chat(req: ChatReq):
466
 
467
  # NEW: do NOT bypass RAG for document actions (so UI refs are preserved)
468
  force_rag = _should_force_rag(msg)
469
-
 
 
470
  if (len(msg) < 20 and ("?" not in msg)) and (not force_rag):
471
  rag_context_text, rag_used_chunks = "", []
472
  else:
473
- rag_context_text, rag_used_chunks = retrieve_relevant_chunks(msg, sess["rag_chunks"])
 
 
 
 
 
474
 
475
 
476
 
 
180
  ]
181
  return any(t in m for t in triggers)
182
 
183
+ def _extract_filename_hint(message: str) -> Optional[str]:
184
+ m = (message or "").strip()
185
+ if not m:
186
+ return None
187
+ # 极简:如果用户直接提到了 .pdf/.ppt/.docx 文件名,就用它
188
+ for token in m.replace("“", '"').replace("”", '"').split():
189
+ if any(token.lower().endswith(ext) for ext in [".pdf", ".ppt", ".pptx", ".doc", ".docx"]):
190
+ return os.path.basename(token.strip('"').strip("'").strip())
191
+ return None
192
+
193
+
194
+ def _resolve_rag_scope(sess: Dict[str, Any], msg: str) -> Tuple[Optional[List[str]], Optional[List[str]]]:
195
+ """
196
+ Return (allowed_source_files, allowed_doc_types)
197
+ - If user is asking about "uploaded file"/document action -> restrict to latest uploaded file.
198
+ - If message contains an explicit filename -> restrict to that filename if we have it.
199
+ - Else no restriction (None, None).
200
+ """
201
+ files = sess.get("uploaded_files") or []
202
+ msg_l = (msg or "").lower()
203
+
204
+ # 1) explicit filename mentioned
205
+ hinted = _extract_filename_hint(msg)
206
+ if hinted:
207
+ # only restrict if that file exists in session uploads
208
+ known = {os.path.basename(f.get("filename", "")) for f in files if f.get("filename")}
209
+ if hinted in known:
210
+ return ([hinted], None)
211
+
212
+ # 2) generic "uploaded file" intent
213
+ uploaded_intent = any(t in msg_l for t in [
214
+ "uploaded file", "uploaded files", "the uploaded file", "this file", "this document",
215
+ "上传的文件", "这份文件", "这个文件", "文档", "课件", "讲义"
216
+ ])
217
+ if uploaded_intent and files:
218
+ last = files[-1]
219
+ fn = os.path.basename(last.get("filename", "")).strip() or None
220
+ dt = (last.get("doc_type") or "").strip() or None
221
+ allowed_files = [fn] if fn else None
222
+ allowed_doc_types = [dt] if dt else None
223
+ return (allowed_files, allowed_doc_types)
224
+
225
+ return (None, None)
226
+
227
 
228
  # ----------------------------
229
  # Warmup
 
510
 
511
  # NEW: do NOT bypass RAG for document actions (so UI refs are preserved)
512
  force_rag = _should_force_rag(msg)
513
+
514
+ allowed_files, allowed_doc_types = _resolve_rag_scope(sess, msg)
515
+
516
  if (len(msg) < 20 and ("?" not in msg)) and (not force_rag):
517
  rag_context_text, rag_used_chunks = "", []
518
  else:
519
+ rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
520
+ msg,
521
+ sess["rag_chunks"],
522
+ allowed_source_files=allowed_files,
523
+ allowed_doc_types=allowed_doc_types,
524
+ )
525
 
526
 
527