Spaces:
Sleeping
Sleeping
Update api/server.py
Browse files- api/server.py +53 -2
api/server.py
CHANGED
|
@@ -180,6 +180,50 @@ def _should_force_rag(message: str) -> bool:
|
|
| 180 |
]
|
| 181 |
return any(t in m for t in triggers)
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
# ----------------------------
|
| 185 |
# Warmup
|
|
@@ -466,11 +510,18 @@ def chat(req: ChatReq):
|
|
| 466 |
|
| 467 |
# NEW: do NOT bypass RAG for document actions (so UI refs are preserved)
|
| 468 |
force_rag = _should_force_rag(msg)
|
| 469 |
-
|
|
|
|
|
|
|
| 470 |
if (len(msg) < 20 and ("?" not in msg)) and (not force_rag):
|
| 471 |
rag_context_text, rag_used_chunks = "", []
|
| 472 |
else:
|
| 473 |
-
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
|
| 476 |
|
|
|
|
| 180 |
]
|
| 181 |
return any(t in m for t in triggers)
|
| 182 |
|
| 183 |
+
def _extract_filename_hint(message: str) -> Optional[str]:
|
| 184 |
+
m = (message or "").strip()
|
| 185 |
+
if not m:
|
| 186 |
+
return None
|
| 187 |
+
# 极简:如果用户直接提到了 .pdf/.ppt/.docx 文件名,就用它
|
| 188 |
+
for token in m.replace("“", '"').replace("”", '"').split():
|
| 189 |
+
if any(token.lower().endswith(ext) for ext in [".pdf", ".ppt", ".pptx", ".doc", ".docx"]):
|
| 190 |
+
return os.path.basename(token.strip('"').strip("'").strip())
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _resolve_rag_scope(sess: Dict[str, Any], msg: str) -> Tuple[Optional[List[str]], Optional[List[str]]]:
|
| 195 |
+
"""
|
| 196 |
+
Return (allowed_source_files, allowed_doc_types)
|
| 197 |
+
- If user is asking about "uploaded file"/document action -> restrict to latest uploaded file.
|
| 198 |
+
- If message contains an explicit filename -> restrict to that filename if we have it.
|
| 199 |
+
- Else no restriction (None, None).
|
| 200 |
+
"""
|
| 201 |
+
files = sess.get("uploaded_files") or []
|
| 202 |
+
msg_l = (msg or "").lower()
|
| 203 |
+
|
| 204 |
+
# 1) explicit filename mentioned
|
| 205 |
+
hinted = _extract_filename_hint(msg)
|
| 206 |
+
if hinted:
|
| 207 |
+
# only restrict if that file exists in session uploads
|
| 208 |
+
known = {os.path.basename(f.get("filename", "")) for f in files if f.get("filename")}
|
| 209 |
+
if hinted in known:
|
| 210 |
+
return ([hinted], None)
|
| 211 |
+
|
| 212 |
+
# 2) generic "uploaded file" intent
|
| 213 |
+
uploaded_intent = any(t in msg_l for t in [
|
| 214 |
+
"uploaded file", "uploaded files", "the uploaded file", "this file", "this document",
|
| 215 |
+
"上传的文件", "这份文件", "这个文件", "文档", "课件", "讲义"
|
| 216 |
+
])
|
| 217 |
+
if uploaded_intent and files:
|
| 218 |
+
last = files[-1]
|
| 219 |
+
fn = os.path.basename(last.get("filename", "")).strip() or None
|
| 220 |
+
dt = (last.get("doc_type") or "").strip() or None
|
| 221 |
+
allowed_files = [fn] if fn else None
|
| 222 |
+
allowed_doc_types = [dt] if dt else None
|
| 223 |
+
return (allowed_files, allowed_doc_types)
|
| 224 |
+
|
| 225 |
+
return (None, None)
|
| 226 |
+
|
| 227 |
|
| 228 |
# ----------------------------
|
| 229 |
# Warmup
|
|
|
|
| 510 |
|
| 511 |
# NEW: do NOT bypass RAG for document actions (so UI refs are preserved)
|
| 512 |
force_rag = _should_force_rag(msg)
|
| 513 |
+
|
| 514 |
+
allowed_files, allowed_doc_types = _resolve_rag_scope(sess, msg)
|
| 515 |
+
|
| 516 |
if (len(msg) < 20 and ("?" not in msg)) and (not force_rag):
|
| 517 |
rag_context_text, rag_used_chunks = "", []
|
| 518 |
else:
|
| 519 |
+
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
|
| 520 |
+
msg,
|
| 521 |
+
sess["rag_chunks"],
|
| 522 |
+
allowed_source_files=allowed_files,
|
| 523 |
+
allowed_doc_types=allowed_doc_types,
|
| 524 |
+
)
|
| 525 |
|
| 526 |
|
| 527 |
|