Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
import time as _time
|
| 3 |
-
_BOOT = _time.time()
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import logging
|
|
@@ -11,11 +9,6 @@ from datetime import datetime
|
|
| 11 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
-
def _perf(label):
    """Log, at INFO level, the seconds elapsed since ``_BOOT`` under ``label``."""
    elapsed = _time.time() - _BOOT
    logger.info(f"[PERF] {label}: {elapsed:.2f}s")
|
| 16 |
-
|
| 17 |
-
_perf("stdlib imports done")
|
| 18 |
-
|
| 19 |
# numpy 延迟导入
|
| 20 |
_np_module = None
|
| 21 |
def _get_np():
|
|
@@ -23,14 +16,13 @@ def _get_np():
|
|
| 23 |
if _np_module is None:
|
| 24 |
import numpy
|
| 25 |
_np_module = numpy
|
| 26 |
-
_perf("numpy loaded")
|
| 27 |
return _np_module
|
| 28 |
|
| 29 |
# =========================
|
| 30 |
# 1. 页面配置 & 样式注入
|
| 31 |
# =========================
|
| 32 |
st.set_page_config(page_title="RAG 知识库助手 v3 (HF+Supabase)", page_icon="🛡️", layout="wide")
|
| 33 |
-
|
| 34 |
|
| 35 |
|
| 36 |
def inject_custom_css():
|
|
@@ -59,7 +51,6 @@ def inject_custom_css():
|
|
| 59 |
|
| 60 |
inject_custom_css()
|
| 61 |
st.title("🛡️ 智能知识库助手 v3")
|
| 62 |
-
_perf("CSS + title done")
|
| 63 |
|
| 64 |
# =========================
|
| 65 |
# 1.5 Supabase 客户端初始化
|
|
@@ -85,8 +76,6 @@ def _sb():
|
|
| 85 |
return _get_supabase()
|
| 86 |
|
| 87 |
|
| 88 |
-
_perf("supabase client ready")
|
| 89 |
-
|
| 90 |
# =========================
|
| 91 |
# 2. 用户管理(Supabase users 表)
|
| 92 |
# =========================
|
|
@@ -199,10 +188,10 @@ def verify_user(username, password):
|
|
| 199 |
users = _load_users()
|
| 200 |
user_info = users.get(username)
|
| 201 |
if not user_info or not isinstance(user_info, dict):
|
| 202 |
-
return False, None
|
| 203 |
if user_info.get("password_hash") != _hash_password(password):
|
| 204 |
-
return False, None
|
| 205 |
-
return True, user_info.get("role", "user")
|
| 206 |
|
| 207 |
|
| 208 |
# --- 认证 UI ---
|
|
@@ -238,11 +227,14 @@ with st.sidebar:
|
|
| 238 |
if input_username == "" or input_password == "":
|
| 239 |
st.stop()
|
| 240 |
|
| 241 |
-
ok, role = verify_user(input_username, input_password)
|
| 242 |
if not ok:
|
| 243 |
st.session_state.login_attempts += 1
|
| 244 |
remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
| 246 |
st.stop()
|
| 247 |
else:
|
| 248 |
st.session_state.login_attempts = 0
|
|
@@ -263,7 +255,10 @@ with st.sidebar:
|
|
| 263 |
else:
|
| 264 |
ok, msg = register_user(reg_user, reg_pass, reg_code)
|
| 265 |
if ok:
|
| 266 |
-
st.
|
|
|
|
|
|
|
|
|
|
| 267 |
time.sleep(1)
|
| 268 |
st.rerun()
|
| 269 |
else:
|
|
@@ -272,7 +267,7 @@ with st.sidebar:
|
|
| 272 |
|
| 273 |
CURRENT_USER = st.session_state.current_user
|
| 274 |
IS_ADMIN = st.session_state.current_role == "admin"
|
| 275 |
-
|
| 276 |
|
| 277 |
# =========================
|
| 278 |
# 3. 安全配置与 Embedding 策略
|
|
@@ -355,32 +350,6 @@ def encode_query(text):
|
|
| 355 |
# 4. Supabase 索引管理(替代本地文件)
|
| 356 |
# =========================
|
| 357 |
|
| 358 |
-
def _load_library(scope):
    """Load every document chunk stored for ``scope`` from the Supabase ``documents`` table.

    Returns a ``(docs, embeddings, sources)`` tuple of parallel lists, where
    each embedding is a float32 NumPy array. If anything fails, the error is
    logged and three empty lists are returned instead.
    """
    np = _get_np()
    try:
        table_query = _sb().table("documents").select("content, embedding, source_file")
        resp = table_query.eq("scope", scope).execute()

        docs, embeddings, sources = [], [], []
        for row in resp.data:
            docs.append(row["content"])
            raw_vector = row["embedding"]
            # The Supabase REST API may return the vector as a string; parse it.
            vector = json.loads(raw_vector) if isinstance(raw_vector, str) else raw_vector
            embeddings.append(np.array(vector, dtype=np.float32))
            sources.append(row["source_file"])
    except Exception as e:
        logger.error(f"加载索引失败 [scope={scope}]: {e}")
        return [], [], []
    return docs, embeddings, sources
|
| 382 |
-
|
| 383 |
-
|
| 384 |
def _save_chunks_to_db(scope, chunks, vectors, source_file):
|
| 385 |
"""将新切片批量写入 Supabase documents 表。"""
|
| 386 |
rows = []
|
|
@@ -523,52 +492,29 @@ def _clear_uploaded_files_storage(scope):
|
|
| 523 |
logger.warning(f"清空文件失败 [scope={scope}]: {e}")
|
| 524 |
|
| 525 |
|
| 526 |
-
# --- 初始化 session_state 中的缓存 ---
|
| 527 |
-
def _init_library(key_prefix, scope):
    """Populate ``st.session_state`` with the index for ``scope`` if not done yet.

    Nothing happens when ``<key_prefix>_docs`` is already present and the
    ``<key_prefix>_loaded`` flag is truthy. Otherwise the library is loaded
    via ``_load_library`` and stored under the ``_docs``, ``_embeddings`` and
    ``_sources`` keys, and the ``_loaded`` flag is set to ``True``.
    """
    state = st.session_state
    loaded_key = f"{key_prefix}_loaded"

    # Guard clause: already cached for this session.
    if f"{key_prefix}_docs" in state and state.get(loaded_key):
        return

    docs, embeddings, sources = _load_library(scope)
    state[f"{key_prefix}_docs"] = docs
    state[f"{key_prefix}_embeddings"] = embeddings
    state[f"{key_prefix}_sources"] = sources
    state[loaded_key] = True
|
| 540 |
-
|
| 541 |
|
| 542 |
-
def _refresh_library(key_prefix, scope):
    """Unconditionally reload the index for ``scope`` into ``st.session_state``.

    Overwrites the ``<key_prefix>_docs``, ``<key_prefix>_embeddings`` and
    ``<key_prefix>_sources`` entries with the result of ``_load_library``.
    """
    docs, embeddings, sources = _load_library(scope)
    refreshed = (
        ("docs", docs),
        ("embeddings", embeddings),
        ("sources", sources),
    )
    for suffix, value in refreshed:
        st.session_state[f"{key_prefix}_{suffix}"] = value
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
_perf("before init_library")
|
| 551 |
PUBLIC_SCOPE = "public"
|
| 552 |
-
_init_library("public", PUBLIC_SCOPE)
|
| 553 |
PRIVATE_SCOPE = CURRENT_USER # 私有库 scope = 用户名
|
| 554 |
-
_init_library("private", PRIVATE_SCOPE)
|
| 555 |
-
_perf("init_library done")
|
| 556 |
|
|
|
|
|
|
|
| 557 |
|
| 558 |
-
def
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
|
|
|
|
|
|
| 572 |
|
| 573 |
|
| 574 |
# =========================
|
|
@@ -605,7 +551,6 @@ def _get_text_splitter():
|
|
| 605 |
if _text_splitter_cache is None:
|
| 606 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 607 |
_text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
|
| 608 |
-
_perf("text_splitter loaded")
|
| 609 |
return _text_splitter_cache
|
| 610 |
|
| 611 |
SYSTEM_PROMPT = (
|
|
@@ -730,9 +675,6 @@ def process_upload(uploaded_files, target_prefix, scope):
|
|
| 730 |
for src_file, (chunks, vecs) in file_groups.items():
|
| 731 |
_save_chunks_to_db(scope, chunks, vecs, src_file)
|
| 732 |
|
| 733 |
-
# 刷新 session_state 缓存
|
| 734 |
-
_refresh_library(target_prefix, scope)
|
| 735 |
-
|
| 736 |
st.session_state[fp_key] = file_fingerprint
|
| 737 |
# 递增上传组件 key
|
| 738 |
ukey = f"_upload_ver_{target_prefix}"
|
|
@@ -772,9 +714,9 @@ model_mapping = {
|
|
| 772 |
"🏢 百度文心 (官方)": "ernie-3.5-8k",
|
| 773 |
}
|
| 774 |
|
| 775 |
-
|
| 776 |
with st.sidebar:
|
| 777 |
-
pub_chunk_count =
|
| 778 |
with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
|
| 779 |
st.caption("所有人可搜索")
|
| 780 |
|
|
@@ -789,7 +731,6 @@ with st.sidebar:
|
|
| 789 |
if col_del.button("🗑", key=f"delpub_{fname}", help=f"删除 {fname}"):
|
| 790 |
_delete_chunks_by_file(PUBLIC_SCOPE, fname)
|
| 791 |
_delete_uploaded_file_from_storage(PUBLIC_SCOPE, fname)
|
| 792 |
-
_refresh_library("public", PUBLIC_SCOPE)
|
| 793 |
st.success(f"已删除 {fname}")
|
| 794 |
time.sleep(0.5)
|
| 795 |
st.rerun()
|
|
@@ -812,7 +753,6 @@ with st.sidebar:
|
|
| 812 |
if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
|
| 813 |
_clear_all_chunks(PUBLIC_SCOPE)
|
| 814 |
_clear_uploaded_files_storage(PUBLIC_SCOPE)
|
| 815 |
-
_refresh_library("public", PUBLIC_SCOPE)
|
| 816 |
st.success("公共知识库已清空。")
|
| 817 |
time.sleep(0.5)
|
| 818 |
st.rerun()
|
|
@@ -820,7 +760,7 @@ with st.sidebar:
|
|
| 820 |
st.caption("*仅管理员可维护公共库*")
|
| 821 |
|
| 822 |
# --- 私有知识库 ---
|
| 823 |
-
priv_chunk_count =
|
| 824 |
with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
|
| 825 |
st.caption(f"用户:{CURRENT_USER},仅自己可见")
|
| 826 |
|
|
@@ -833,7 +773,6 @@ with st.sidebar:
|
|
| 833 |
if col_del.button("🗑", key=f"delpriv_{fname}", help=f"删除 {fname}"):
|
| 834 |
_delete_chunks_by_file(PRIVATE_SCOPE, fname)
|
| 835 |
_delete_uploaded_file_from_storage(PRIVATE_SCOPE, fname)
|
| 836 |
-
_refresh_library("private", PRIVATE_SCOPE)
|
| 837 |
st.success(f"已删除 {fname}")
|
| 838 |
time.sleep(0.5)
|
| 839 |
st.rerun()
|
|
@@ -853,7 +792,6 @@ with st.sidebar:
|
|
| 853 |
if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
|
| 854 |
_clear_all_chunks(PRIVATE_SCOPE)
|
| 855 |
_clear_uploaded_files_storage(PRIVATE_SCOPE)
|
| 856 |
-
_refresh_library("private", PRIVATE_SCOPE)
|
| 857 |
st.success("私有知识库已清空。")
|
| 858 |
time.sleep(0.5)
|
| 859 |
st.rerun()
|
|
@@ -878,7 +816,7 @@ with st.sidebar:
|
|
| 878 |
new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
|
| 879 |
new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
|
| 880 |
if st.button("✅ 确认修改", key="btn_change_pass"):
|
| 881 |
-
ok, _ = verify_user(CURRENT_USER, old_pass)
|
| 882 |
if not ok:
|
| 883 |
st.error("当前密码错误")
|
| 884 |
elif len(new_pass1) < 4:
|
|
@@ -971,49 +909,59 @@ with st.sidebar:
|
|
| 971 |
display_users[k] = v
|
| 972 |
st.json(display_users)
|
| 973 |
|
| 974 |
-
# ---
|
| 975 |
-
with st.expander("
|
| 976 |
-
st.
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 980 |
|
| 981 |
|
| 982 |
# =========================
|
| 983 |
# 8. 核心搜索逻辑(合并公共库 + 私有库)
|
| 984 |
# =========================
|
| 985 |
-
def _cosine_scores(query_vec, matrix):
    """Return the cosine similarity between ``query_vec`` and each row of ``matrix``.

    A query whose norm is below 1e-10 yields an all-zero score vector with one
    entry per matrix row. Row norms are clamped to at least 1e-10 so that the
    division never hits zero.
    """
    np = _get_np()
    q_len = np.linalg.norm(query_vec)
    if q_len < 1e-10:
        return np.zeros(matrix.shape[0])
    row_lens = np.maximum(np.linalg.norm(matrix, axis=1), 1e-10)
    dots = matrix @ query_vec
    return dots / (row_lens * q_len)
|
| 993 |
-
|
| 994 |
-
|
| 995 |
def search_local(query, top_k, threshold):
|
|
|
|
| 996 |
query_vec = encode_query(query)
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
if priv_docs and priv_np.size > 0:
|
| 1010 |
-
scores = _cosine_scores(query_vec, priv_np)
|
| 1011 |
-
for i, s in enumerate(scores):
|
| 1012 |
-
if s > threshold:
|
| 1013 |
-
all_results.append((float(s), priv_docs[i]))
|
| 1014 |
-
|
| 1015 |
-
all_results.sort(key=lambda x: x[0], reverse=True)
|
| 1016 |
-
return [doc for _, doc in all_results[:top_k]]
|
| 1017 |
|
| 1018 |
|
| 1019 |
# =========================
|
|
@@ -1106,11 +1054,63 @@ def llm_answer(query, context_docs, selected_display_name, web_enabled):
|
|
| 1106 |
yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
|
| 1107 |
|
| 1108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1109 |
# =========================
|
| 1110 |
# 10. 聊天渲染
|
| 1111 |
# =========================
|
| 1112 |
if "messages" not in st.session_state:
|
| 1113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
|
| 1115 |
for m in st.session_state.messages:
|
| 1116 |
with st.chat_message(m["role"]):
|
|
@@ -1120,6 +1120,7 @@ for m in st.session_state.messages:
|
|
| 1120 |
|
| 1121 |
if q := st.chat_input("输入问题...", key="chat_input_v3"):
|
| 1122 |
st.session_state.messages.append({"role": "user", "content": q})
|
|
|
|
| 1123 |
with st.chat_message("user"):
|
| 1124 |
st.markdown(q)
|
| 1125 |
|
|
@@ -1142,8 +1143,8 @@ if q := st.chat_input("输入问题...", key="chat_input_v3"):
|
|
| 1142 |
st.session_state.messages.append(
|
| 1143 |
{"role": "assistant", "content": full_response, "meta": meta_info}
|
| 1144 |
)
|
|
|
|
| 1145 |
except Exception as e:
|
| 1146 |
logger.error(f"模型调用异常: {e}")
|
| 1147 |
container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")
|
| 1148 |
|
| 1149 |
-
_perf("script execution complete")
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import time
|
| 4 |
import logging
|
|
|
|
| 9 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# numpy 延迟导入
|
| 13 |
_np_module = None
|
| 14 |
def _get_np():
|
|
|
|
| 16 |
if _np_module is None:
|
| 17 |
import numpy
|
| 18 |
_np_module = numpy
|
|
|
|
| 19 |
return _np_module
|
| 20 |
|
| 21 |
# =========================
|
| 22 |
# 1. 页面配置 & 样式注入
|
| 23 |
# =========================
|
| 24 |
st.set_page_config(page_title="RAG 知识库助手 v3 (HF+Supabase)", page_icon="🛡️", layout="wide")
|
| 25 |
+
|
| 26 |
|
| 27 |
|
| 28 |
def inject_custom_css():
|
|
|
|
| 51 |
|
| 52 |
inject_custom_css()
|
| 53 |
st.title("🛡️ 智能知识库助手 v3")
|
|
|
|
| 54 |
|
| 55 |
# =========================
|
| 56 |
# 1.5 Supabase 客户端初始化
|
|
|
|
| 76 |
return _get_supabase()
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
| 79 |
# =========================
|
| 80 |
# 2. 用户管理(Supabase users 表)
|
| 81 |
# =========================
|
|
|
|
| 188 |
users = _load_users()
|
| 189 |
user_info = users.get(username)
|
| 190 |
if not user_info or not isinstance(user_info, dict):
|
| 191 |
+
return False, None, "not_found"
|
| 192 |
if user_info.get("password_hash") != _hash_password(password):
|
| 193 |
+
return False, None, "wrong_password"
|
| 194 |
+
return True, user_info.get("role", "user"), ""
|
| 195 |
|
| 196 |
|
| 197 |
# --- 认证 UI ---
|
|
|
|
| 227 |
if input_username == "" or input_password == "":
|
| 228 |
st.stop()
|
| 229 |
|
| 230 |
+
ok, role, reason = verify_user(input_username, input_password)
|
| 231 |
if not ok:
|
| 232 |
st.session_state.login_attempts += 1
|
| 233 |
remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
|
| 234 |
+
if reason == "not_found":
|
| 235 |
+
st.warning(f"⚠️ 用户不存在,请先注册(剩余 {remaining} 次)")
|
| 236 |
+
else:
|
| 237 |
+
st.warning(f"⚠️ 密码错误(剩余 {remaining} 次)")
|
| 238 |
st.stop()
|
| 239 |
else:
|
| 240 |
st.session_state.login_attempts = 0
|
|
|
|
| 255 |
else:
|
| 256 |
ok, msg = register_user(reg_user, reg_pass, reg_code)
|
| 257 |
if ok:
|
| 258 |
+
st.session_state.current_user = reg_user
|
| 259 |
+
st.session_state.current_role = "user"
|
| 260 |
+
st.session_state.login_attempts = 0
|
| 261 |
+
st.success(f"✅ 注册成功,已自动登录")
|
| 262 |
time.sleep(1)
|
| 263 |
st.rerun()
|
| 264 |
else:
|
|
|
|
| 267 |
|
| 268 |
CURRENT_USER = st.session_state.current_user
|
| 269 |
IS_ADMIN = st.session_state.current_role == "admin"
|
| 270 |
+
|
| 271 |
|
| 272 |
# =========================
|
| 273 |
# 3. 安全配置与 Embedding 策略
|
|
|
|
| 350 |
# 4. Supabase 索引管理(替代本地文件)
|
| 351 |
# =========================
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
def _save_chunks_to_db(scope, chunks, vectors, source_file):
|
| 354 |
"""将新切片批量写入 Supabase documents 表。"""
|
| 355 |
rows = []
|
|
|
|
| 492 |
logger.warning(f"清空文件失败 [scope={scope}]: {e}")
|
| 493 |
|
| 494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
PUBLIC_SCOPE = "public"
|
|
|
|
| 497 |
PRIVATE_SCOPE = CURRENT_USER # 私有库 scope = 用户名
|
|
|
|
|
|
|
| 498 |
|
| 499 |
+
# --- 定时同步:检测其他用户对文档库的修改 ---
|
| 500 |
+
_SYNC_INTERVAL = 30 # 每 30 秒检查一次
|
| 501 |
|
| 502 |
+
def _check_and_sync():
    """Refresh the cached chunk counts for the public and private scopes.

    Does nothing if fewer than ``_SYNC_INTERVAL`` seconds have passed since
    the timestamp stored under ``_sync_last_check``. Otherwise it records the
    new timestamp, re-counts each scope with ``_count_chunks``, logs a
    ``[SYNC]`` line when a previously recorded count differs, and stores the
    latest count under ``_sync_count_<label>``.
    """
    state = st.session_state
    now = time.time()
    if now - state.get("_sync_last_check", 0) < _SYNC_INTERVAL:
        return
    state["_sync_last_check"] = now

    for label, scope in (("public", PUBLIC_SCOPE), ("private", PRIVATE_SCOPE)):
        key = f"_sync_count_{label}"
        latest = _count_chunks(scope)
        # -1 marks "no count recorded yet", so the first pass never logs.
        previous = state.get(key, -1)
        if previous >= 0 and latest != previous:
            logger.info(f"[SYNC] {label} 库变更: {previous} -> {latest}")
        state[key] = latest
|
| 516 |
+
|
| 517 |
+
_check_and_sync()
|
| 518 |
|
| 519 |
|
| 520 |
# =========================
|
|
|
|
| 551 |
if _text_splitter_cache is None:
|
| 552 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 553 |
_text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
|
|
|
|
| 554 |
return _text_splitter_cache
|
| 555 |
|
| 556 |
SYSTEM_PROMPT = (
|
|
|
|
| 675 |
for src_file, (chunks, vecs) in file_groups.items():
|
| 676 |
_save_chunks_to_db(scope, chunks, vecs, src_file)
|
| 677 |
|
|
|
|
|
|
|
|
|
|
| 678 |
st.session_state[fp_key] = file_fingerprint
|
| 679 |
# 递增上传组件 key
|
| 680 |
ukey = f"_upload_ver_{target_prefix}"
|
|
|
|
| 714 |
"🏢 百度文心 (官方)": "ernie-3.5-8k",
|
| 715 |
}
|
| 716 |
|
| 717 |
+
|
| 718 |
with st.sidebar:
|
| 719 |
+
pub_chunk_count = st.session_state.get("_sync_count_public", _count_chunks(PUBLIC_SCOPE))
|
| 720 |
with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
|
| 721 |
st.caption("所有人可搜索")
|
| 722 |
|
|
|
|
| 731 |
if col_del.button("🗑", key=f"delpub_{fname}", help=f"删除 {fname}"):
|
| 732 |
_delete_chunks_by_file(PUBLIC_SCOPE, fname)
|
| 733 |
_delete_uploaded_file_from_storage(PUBLIC_SCOPE, fname)
|
|
|
|
| 734 |
st.success(f"已删除 {fname}")
|
| 735 |
time.sleep(0.5)
|
| 736 |
st.rerun()
|
|
|
|
| 753 |
if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
|
| 754 |
_clear_all_chunks(PUBLIC_SCOPE)
|
| 755 |
_clear_uploaded_files_storage(PUBLIC_SCOPE)
|
|
|
|
| 756 |
st.success("公共知识库已清空。")
|
| 757 |
time.sleep(0.5)
|
| 758 |
st.rerun()
|
|
|
|
| 760 |
st.caption("*仅管理员可维护公共库*")
|
| 761 |
|
| 762 |
# --- 私有知识库 ---
|
| 763 |
+
priv_chunk_count = st.session_state.get("_sync_count_private", _count_chunks(PRIVATE_SCOPE))
|
| 764 |
with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
|
| 765 |
st.caption(f"用户:{CURRENT_USER},仅自己可见")
|
| 766 |
|
|
|
|
| 773 |
if col_del.button("🗑", key=f"delpriv_{fname}", help=f"删除 {fname}"):
|
| 774 |
_delete_chunks_by_file(PRIVATE_SCOPE, fname)
|
| 775 |
_delete_uploaded_file_from_storage(PRIVATE_SCOPE, fname)
|
|
|
|
| 776 |
st.success(f"已删除 {fname}")
|
| 777 |
time.sleep(0.5)
|
| 778 |
st.rerun()
|
|
|
|
| 792 |
if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
|
| 793 |
_clear_all_chunks(PRIVATE_SCOPE)
|
| 794 |
_clear_uploaded_files_storage(PRIVATE_SCOPE)
|
|
|
|
| 795 |
st.success("私有知识库已清空。")
|
| 796 |
time.sleep(0.5)
|
| 797 |
st.rerun()
|
|
|
|
| 816 |
new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
|
| 817 |
new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
|
| 818 |
if st.button("✅ 确认修改", key="btn_change_pass"):
|
| 819 |
+
ok, _, _ = verify_user(CURRENT_USER, old_pass)
|
| 820 |
if not ok:
|
| 821 |
st.error("当前密码错误")
|
| 822 |
elif len(new_pass1) < 4:
|
|
|
|
| 909 |
display_users[k] = v
|
| 910 |
st.json(display_users)
|
| 911 |
|
| 912 |
+
# --- 聊天记录管理 ---
|
| 913 |
+
with st.expander("💬 聊天记录"):
|
| 914 |
+
hist_tab_new, hist_tab_history = st.tabs(["当前对话", "历史记录"])
|
| 915 |
+
|
| 916 |
+
with hist_tab_new:
|
| 917 |
+
st.caption("清空当前对话(数据库记录保留)")
|
| 918 |
+
if st.button("🧹 清空当前对话", use_container_width=True, type="secondary", key="btn_clear_chat"):
|
| 919 |
+
st.session_state.messages = []
|
| 920 |
+
st.rerun()
|
| 921 |
+
|
| 922 |
+
st.caption("清空所有历史记录(不可恢复)")
|
| 923 |
+
if st.button("🗑️ 清空全部记录", use_container_width=True, type="secondary", key="btn_clear_all_hist"):
|
| 924 |
+
_clear_chat_history_db(CURRENT_USER)
|
| 925 |
+
st.session_state.messages = []
|
| 926 |
+
st.success("所有聊天记录已清空")
|
| 927 |
+
time.sleep(0.5)
|
| 928 |
+
st.rerun()
|
| 929 |
+
|
| 930 |
+
with hist_tab_history:
|
| 931 |
+
if st.button("🔄 加载历史记录", use_container_width=True, key="btn_load_hist"):
|
| 932 |
+
st.session_state["_show_history"] = True
|
| 933 |
+
|
| 934 |
+
if st.session_state.get("_show_history"):
|
| 935 |
+
history = _load_chat_history(CURRENT_USER, limit=100)
|
| 936 |
+
if not history:
|
| 937 |
+
st.info("暂无历史记录")
|
| 938 |
+
else:
|
| 939 |
+
st.caption(f"共 {len(history)} 条记录")
|
| 940 |
+
for msg in history:
|
| 941 |
+
ts = msg.get("created_at", "")[:16].replace("T", " ")
|
| 942 |
+
icon = "🧑" if msg["role"] == "user" else "🤖"
|
| 943 |
+
preview = msg["content"][:80].replace("\n", " ")
|
| 944 |
+
st.text(f"{icon} [{ts}] {preview}{'...' if len(msg['content']) > 80 else ''}")
|
| 945 |
|
| 946 |
|
| 947 |
# =========================
|
| 948 |
# 8. 核心搜索逻辑(合并公共库 + 私有库)
|
| 949 |
# =========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 950 |
def search_local(query, top_k, threshold):
    """Search the knowledge base with the database-side ``match_documents`` RPC.

    The query is embedded with ``encode_query`` and matched against both
    ``PUBLIC_SCOPE`` and ``PRIVATE_SCOPE``. Returns the ``content`` of each
    row the RPC yields, or an empty list when there are none. If building the
    request or the RPC itself fails, the error is logged and an empty list is
    returned.
    """
    query_vec = encode_query(query)
    try:
        # The RPC payload needs a plain list; use ``tolist`` when it exists.
        if hasattr(query_vec, 'tolist'):
            embedding = query_vec.tolist()
        else:
            embedding = list(query_vec)
        params = {
            "query_embedding": embedding,
            "match_scopes": [PUBLIC_SCOPE, PRIVATE_SCOPE],
            "match_threshold": float(threshold),
            "match_count": int(top_k),
        }
        rows = _sb().rpc("match_documents", params).execute().data
        if not rows:
            return []
        return [row["content"] for row in rows]
    except Exception as e:
        logger.error(f"pgvector 搜索失败: {e}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
|
| 966 |
|
| 967 |
# =========================
|
|
|
|
| 1054 |
yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
|
| 1055 |
|
| 1056 |
|
| 1057 |
+
# =========================
|
| 1058 |
+
# 9.5 聊天记录持久化(Supabase chat_history 表)
|
| 1059 |
+
# =========================
|
| 1060 |
+
def _save_chat_message(username, role, content, meta=""):
    """Insert a single chat message into the ``chat_history`` table.

    A falsy ``meta`` is stored as ``""``. Any exception is logged as a
    warning and not re-raised.
    """
    try:
        record = {
            "username": username,
            "role": role,
            "content": content,
            "meta": meta if meta else "",
        }
        _sb().table("chat_history").insert(record).execute()
    except Exception as e:
        logger.warning(f"保存聊天记录失败: {e}")
|
| 1071 |
+
|
| 1072 |
+
|
| 1073 |
+
def _load_chat_history(username, limit=50):
    """Load the most recent chat messages stored for ``username``.

    Queries the ``chat_history`` table newest-first, capped at ``limit`` rows,
    then reverses the rows so the result is in chronological order.

    Returns a list of dicts with ``role``, ``content``, ``meta`` and
    ``created_at`` keys. ``meta`` and ``created_at`` are always strings: a
    missing key and a stored null both become ``""``. On any failure a
    warning is logged and an empty list is returned.
    """
    try:
        resp = (
            _sb().table("chat_history")
            .select("role, content, meta, created_at")
            .eq("username", username)
            .order("created_at", desc=True)
            .limit(limit)
            .execute()
        )
        if not resp.data:
            return []
        # Rows arrive newest-first; flip them back into chronological order.
        # ``or ""`` (rather than a ``.get`` default) also covers keys that are
        # present but null, so callers can slice or format the values safely.
        return [
            {
                "role": r["role"],
                "content": r["content"],
                "meta": r.get("meta") or "",
                "created_at": r.get("created_at") or "",
            }
            for r in reversed(resp.data)
        ]
    except Exception as e:
        logger.warning(f"加载聊天记录失败: {e}")
        return []
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
def _clear_chat_history_db(username):
    """Delete every ``chat_history`` row whose ``username`` equals ``username``.

    Any exception is logged as a warning and not re-raised.
    """
    try:
        history_table = _sb().table("chat_history")
        history_table.delete().eq("username", username).execute()
    except Exception as e:
        logger.warning(f"清空聊天记录失败: {e}")
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
# =========================
|
| 1104 |
# 10. 聊天渲染
|
| 1105 |
# =========================
|
| 1106 |
if "messages" not in st.session_state:
    # On first load, restore the most recent conversation from the database.
    restored = []
    for record in _load_chat_history(CURRENT_USER, limit=50):
        message = {"role": record["role"], "content": record["content"]}
        # Only carry ``meta`` over when it holds something.
        if record.get("meta"):
            message["meta"] = record["meta"]
        restored.append(message)
    st.session_state.messages = restored
|
| 1114 |
|
| 1115 |
for m in st.session_state.messages:
|
| 1116 |
with st.chat_message(m["role"]):
|
|
|
|
| 1120 |
|
| 1121 |
if q := st.chat_input("输入问题...", key="chat_input_v3"):
|
| 1122 |
st.session_state.messages.append({"role": "user", "content": q})
|
| 1123 |
+
_save_chat_message(CURRENT_USER, "user", q)
|
| 1124 |
with st.chat_message("user"):
|
| 1125 |
st.markdown(q)
|
| 1126 |
|
|
|
|
| 1143 |
st.session_state.messages.append(
|
| 1144 |
{"role": "assistant", "content": full_response, "meta": meta_info}
|
| 1145 |
)
|
| 1146 |
+
_save_chat_message(CURRENT_USER, "assistant", full_response, meta_info)
|
| 1147 |
except Exception as e:
|
| 1148 |
logger.error(f"模型调用异常: {e}")
|
| 1149 |
container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")
|
| 1150 |
|
|
|