AlauStone commited on
Commit
ec2e567
·
verified ·
1 Parent(s): 5ea08e1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -136
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import streamlit as st
2
- import time as _time
3
- _BOOT = _time.time()
4
  import json
5
  import time
6
  import logging
@@ -11,11 +9,6 @@ from datetime import datetime
11
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
12
  logger = logging.getLogger(__name__)
13
 
14
- def _perf(label):
15
- logger.info(f"[PERF] {label}: {_time.time()-_BOOT:.2f}s")
16
-
17
- _perf("stdlib imports done")
18
-
19
  # numpy 延迟导入
20
  _np_module = None
21
  def _get_np():
@@ -23,14 +16,13 @@ def _get_np():
23
  if _np_module is None:
24
  import numpy
25
  _np_module = numpy
26
- _perf("numpy loaded")
27
  return _np_module
28
 
29
  # =========================
30
  # 1. 页面配置 & 样式注入
31
  # =========================
32
  st.set_page_config(page_title="RAG 知识库助手 v3 (HF+Supabase)", page_icon="🛡️", layout="wide")
33
- _perf("page_config done")
34
 
35
 
36
  def inject_custom_css():
@@ -59,7 +51,6 @@ def inject_custom_css():
59
 
60
  inject_custom_css()
61
  st.title("🛡️ 智能知识库助手 v3")
62
- _perf("CSS + title done")
63
 
64
  # =========================
65
  # 1.5 Supabase 客户端初始化
@@ -85,8 +76,6 @@ def _sb():
85
  return _get_supabase()
86
 
87
 
88
- _perf("supabase client ready")
89
-
90
  # =========================
91
  # 2. 用户管理(Supabase users 表)
92
  # =========================
@@ -199,10 +188,10 @@ def verify_user(username, password):
199
  users = _load_users()
200
  user_info = users.get(username)
201
  if not user_info or not isinstance(user_info, dict):
202
- return False, None
203
  if user_info.get("password_hash") != _hash_password(password):
204
- return False, None
205
- return True, user_info.get("role", "user")
206
 
207
 
208
  # --- 认证 UI ---
@@ -238,11 +227,14 @@ with st.sidebar:
238
  if input_username == "" or input_password == "":
239
  st.stop()
240
 
241
- ok, role = verify_user(input_username, input_password)
242
  if not ok:
243
  st.session_state.login_attempts += 1
244
  remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
245
- st.warning(f"⚠️ 用户名或密码错误(剩余 {remaining} 次)")
 
 
 
246
  st.stop()
247
  else:
248
  st.session_state.login_attempts = 0
@@ -263,7 +255,10 @@ with st.sidebar:
263
  else:
264
  ok, msg = register_user(reg_user, reg_pass, reg_code)
265
  if ok:
266
- st.success(f"✅ {msg}")
 
 
 
267
  time.sleep(1)
268
  st.rerun()
269
  else:
@@ -272,7 +267,7 @@ with st.sidebar:
272
 
273
  CURRENT_USER = st.session_state.current_user
274
  IS_ADMIN = st.session_state.current_role == "admin"
275
- _perf("auth done")
276
 
277
  # =========================
278
  # 3. 安全配置与 Embedding 策略
@@ -355,32 +350,6 @@ def encode_query(text):
355
  # 4. Supabase 索引管理(替代本地文件)
356
  # =========================
357
 
358
- def _load_library(scope):
359
- """从 Supabase documents 表加载指定 scope 的所有文档切片。
360
- 返回 (docs, embeddings, sources)。"""
361
- np = _get_np()
362
- try:
363
- resp = _sb().table("documents").select(
364
- "content, embedding, source_file"
365
- ).eq("scope", scope).execute()
366
-
367
- docs = []
368
- embeddings = []
369
- sources = []
370
- for row in resp.data:
371
- docs.append(row["content"])
372
- emb = row["embedding"]
373
- # Supabase REST API 可能返回字符串格式的向量,需要解析
374
- if isinstance(emb, str):
375
- emb = json.loads(emb)
376
- embeddings.append(np.array(emb, dtype=np.float32))
377
- sources.append(row["source_file"])
378
- return docs, embeddings, sources
379
- except Exception as e:
380
- logger.error(f"加载索引失败 [scope={scope}]: {e}")
381
- return [], [], []
382
-
383
-
384
  def _save_chunks_to_db(scope, chunks, vectors, source_file):
385
  """将新切片批量写入 Supabase documents 表。"""
386
  rows = []
@@ -523,52 +492,29 @@ def _clear_uploaded_files_storage(scope):
523
  logger.warning(f"清空文件失败 [scope={scope}]: {e}")
524
 
525
 
526
- # --- 初始化 session_state 中的缓存 ---
527
- def _init_library(key_prefix, scope):
528
- """加载 Supabase 中的索引到 session_state。"""
529
- docs_key = f"{key_prefix}_docs"
530
- emb_key = f"{key_prefix}_embeddings"
531
- src_key = f"{key_prefix}_sources"
532
- loaded_key = f"{key_prefix}_loaded"
533
-
534
- if docs_key not in st.session_state or not st.session_state.get(loaded_key):
535
- docs, embeddings, sources = _load_library(scope)
536
- st.session_state[docs_key] = docs
537
- st.session_state[emb_key] = embeddings
538
- st.session_state[src_key] = sources
539
- st.session_state[loaded_key] = True
540
-
541
 
542
- def _refresh_library(key_prefix, scope):
543
- """强制从 Supabase 重新加载索引到 session_state。"""
544
- docs, embeddings, sources = _load_library(scope)
545
- st.session_state[f"{key_prefix}_docs"] = docs
546
- st.session_state[f"{key_prefix}_embeddings"] = embeddings
547
- st.session_state[f"{key_prefix}_sources"] = sources
548
-
549
-
550
- _perf("before init_library")
551
  PUBLIC_SCOPE = "public"
552
- _init_library("public", PUBLIC_SCOPE)
553
  PRIVATE_SCOPE = CURRENT_USER # 私有库 scope = 用户名
554
- _init_library("private", PRIVATE_SCOPE)
555
- _perf("init_library done")
556
 
 
 
557
 
558
- def _get_embeddings_np(key_prefix):
559
- np = _get_np()
560
- np_key = f"{key_prefix}_embeddings_np"
561
- ver_key = f"{key_prefix}_emb_version"
562
- emb_key = f"{key_prefix}_embeddings"
563
- emb_list = st.session_state.get(emb_key, [])
564
- current_ver = id(emb_list)
565
- if np_key not in st.session_state or st.session_state.get(ver_key) != current_ver:
566
- if emb_list:
567
- st.session_state[np_key] = np.array(emb_list)
568
- else:
569
- st.session_state[np_key] = np.array([])
570
- st.session_state[ver_key] = current_ver
571
- return st.session_state[np_key]
 
 
572
 
573
 
574
  # =========================
@@ -605,7 +551,6 @@ def _get_text_splitter():
605
  if _text_splitter_cache is None:
606
  from langchain_text_splitters import RecursiveCharacterTextSplitter
607
  _text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
608
- _perf("text_splitter loaded")
609
  return _text_splitter_cache
610
 
611
  SYSTEM_PROMPT = (
@@ -730,9 +675,6 @@ def process_upload(uploaded_files, target_prefix, scope):
730
  for src_file, (chunks, vecs) in file_groups.items():
731
  _save_chunks_to_db(scope, chunks, vecs, src_file)
732
 
733
- # 刷新 session_state 缓存
734
- _refresh_library(target_prefix, scope)
735
-
736
  st.session_state[fp_key] = file_fingerprint
737
  # 递增上传组件 key
738
  ukey = f"_upload_ver_{target_prefix}"
@@ -772,9 +714,9 @@ model_mapping = {
772
  "🏢 百度文心 (官方)": "ernie-3.5-8k",
773
  }
774
 
775
- _perf("before sidebar UI")
776
  with st.sidebar:
777
- pub_chunk_count = len(st.session_state.get("public_docs", []))
778
  with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
779
  st.caption("所有人可搜索")
780
 
@@ -789,7 +731,6 @@ with st.sidebar:
789
  if col_del.button("🗑", key=f"delpub_{fname}", help=f"删除 {fname}"):
790
  _delete_chunks_by_file(PUBLIC_SCOPE, fname)
791
  _delete_uploaded_file_from_storage(PUBLIC_SCOPE, fname)
792
- _refresh_library("public", PUBLIC_SCOPE)
793
  st.success(f"已删除 {fname}")
794
  time.sleep(0.5)
795
  st.rerun()
@@ -812,7 +753,6 @@ with st.sidebar:
812
  if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
813
  _clear_all_chunks(PUBLIC_SCOPE)
814
  _clear_uploaded_files_storage(PUBLIC_SCOPE)
815
- _refresh_library("public", PUBLIC_SCOPE)
816
  st.success("公共知识库已清空。")
817
  time.sleep(0.5)
818
  st.rerun()
@@ -820,7 +760,7 @@ with st.sidebar:
820
  st.caption("*仅管理员可维护公共库*")
821
 
822
  # --- 私有知识库 ---
823
- priv_chunk_count = len(st.session_state.get("private_docs", []))
824
  with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
825
  st.caption(f"用户:{CURRENT_USER},仅自己可见")
826
 
@@ -833,7 +773,6 @@ with st.sidebar:
833
  if col_del.button("🗑", key=f"delpriv_{fname}", help=f"删除 {fname}"):
834
  _delete_chunks_by_file(PRIVATE_SCOPE, fname)
835
  _delete_uploaded_file_from_storage(PRIVATE_SCOPE, fname)
836
- _refresh_library("private", PRIVATE_SCOPE)
837
  st.success(f"已删除 {fname}")
838
  time.sleep(0.5)
839
  st.rerun()
@@ -853,7 +792,6 @@ with st.sidebar:
853
  if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
854
  _clear_all_chunks(PRIVATE_SCOPE)
855
  _clear_uploaded_files_storage(PRIVATE_SCOPE)
856
- _refresh_library("private", PRIVATE_SCOPE)
857
  st.success("私有知识库已清空。")
858
  time.sleep(0.5)
859
  st.rerun()
@@ -878,7 +816,7 @@ with st.sidebar:
878
  new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
879
  new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
880
  if st.button("✅ 确认修改", key="btn_change_pass"):
881
- ok, _ = verify_user(CURRENT_USER, old_pass)
882
  if not ok:
883
  st.error("当前密码错误")
884
  elif len(new_pass1) < 4:
@@ -971,49 +909,59 @@ with st.sidebar:
971
  display_users[k] = v
972
  st.json(display_users)
973
 
974
- # --- 清空聊天记录 ---
975
- with st.expander("🧹 清空聊天记录"):
976
- st.caption("清空后不可恢复")
977
- if st.button("确认清空", use_container_width=True, type="secondary", key="btn_clear_chat"):
978
- st.session_state.messages = []
979
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980
 
981
 
982
  # =========================
983
  # 8. 核心搜索逻辑(合并公共库 + 私有库)
984
  # =========================
985
- def _cosine_scores(query_vec, matrix):
986
- np = _get_np()
987
- query_norm = np.linalg.norm(query_vec)
988
- if query_norm < 1e-10:
989
- return np.zeros(matrix.shape[0])
990
- mat_norms = np.linalg.norm(matrix, axis=1)
991
- mat_norms = np.maximum(mat_norms, 1e-10)
992
- return (matrix @ query_vec) / (mat_norms * query_norm)
993
-
994
-
995
  def search_local(query, top_k, threshold):
 
996
  query_vec = encode_query(query)
997
- all_results = []
998
-
999
- pub_docs = st.session_state.get("public_docs", [])
1000
- pub_np = _get_embeddings_np("public")
1001
- if pub_docs and pub_np.size > 0:
1002
- scores = _cosine_scores(query_vec, pub_np)
1003
- for i, s in enumerate(scores):
1004
- if s > threshold:
1005
- all_results.append((float(s), pub_docs[i]))
1006
-
1007
- priv_docs = st.session_state.get("private_docs", [])
1008
- priv_np = _get_embeddings_np("private")
1009
- if priv_docs and priv_np.size > 0:
1010
- scores = _cosine_scores(query_vec, priv_np)
1011
- for i, s in enumerate(scores):
1012
- if s > threshold:
1013
- all_results.append((float(s), priv_docs[i]))
1014
-
1015
- all_results.sort(key=lambda x: x[0], reverse=True)
1016
- return [doc for _, doc in all_results[:top_k]]
1017
 
1018
 
1019
  # =========================
@@ -1106,11 +1054,63 @@ def llm_answer(query, context_docs, selected_display_name, web_enabled):
1106
  yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
1107
 
1108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1109
  # =========================
1110
  # 10. 聊天渲染
1111
  # =========================
1112
  if "messages" not in st.session_state:
1113
- st.session_state.messages = []
 
 
 
 
 
 
1114
 
1115
  for m in st.session_state.messages:
1116
  with st.chat_message(m["role"]):
@@ -1120,6 +1120,7 @@ for m in st.session_state.messages:
1120
 
1121
  if q := st.chat_input("输入问题...", key="chat_input_v3"):
1122
  st.session_state.messages.append({"role": "user", "content": q})
 
1123
  with st.chat_message("user"):
1124
  st.markdown(q)
1125
 
@@ -1142,8 +1143,8 @@ if q := st.chat_input("输入问题...", key="chat_input_v3"):
1142
  st.session_state.messages.append(
1143
  {"role": "assistant", "content": full_response, "meta": meta_info}
1144
  )
 
1145
  except Exception as e:
1146
  logger.error(f"模型调用异常: {e}")
1147
  container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")
1148
 
1149
- _perf("script execution complete")
 
1
  import streamlit as st
 
 
2
  import json
3
  import time
4
  import logging
 
9
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
 
 
12
  # numpy 延迟导入
13
  _np_module = None
14
  def _get_np():
 
16
  if _np_module is None:
17
  import numpy
18
  _np_module = numpy
 
19
  return _np_module
20
 
21
  # =========================
22
  # 1. 页面配置 & 样式注入
23
  # =========================
24
  st.set_page_config(page_title="RAG 知识库助手 v3 (HF+Supabase)", page_icon="🛡️", layout="wide")
25
+
26
 
27
 
28
  def inject_custom_css():
 
51
 
52
  inject_custom_css()
53
  st.title("🛡️ 智能知识库助手 v3")
 
54
 
55
  # =========================
56
  # 1.5 Supabase 客户端初始化
 
76
  return _get_supabase()
77
 
78
 
 
 
79
  # =========================
80
  # 2. 用户管理(Supabase users 表)
81
  # =========================
 
188
  users = _load_users()
189
  user_info = users.get(username)
190
  if not user_info or not isinstance(user_info, dict):
191
+ return False, None, "not_found"
192
  if user_info.get("password_hash") != _hash_password(password):
193
+ return False, None, "wrong_password"
194
+ return True, user_info.get("role", "user"), ""
195
 
196
 
197
  # --- 认证 UI ---
 
227
  if input_username == "" or input_password == "":
228
  st.stop()
229
 
230
+ ok, role, reason = verify_user(input_username, input_password)
231
  if not ok:
232
  st.session_state.login_attempts += 1
233
  remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
234
+ if reason == "not_found":
235
+ st.warning(f"⚠️ 用户不存在,请先注册(剩余 {remaining} 次)")
236
+ else:
237
+ st.warning(f"⚠️ 密码错误(剩余 {remaining} 次)")
238
  st.stop()
239
  else:
240
  st.session_state.login_attempts = 0
 
255
  else:
256
  ok, msg = register_user(reg_user, reg_pass, reg_code)
257
  if ok:
258
+ st.session_state.current_user = reg_user
259
+ st.session_state.current_role = "user"
260
+ st.session_state.login_attempts = 0
261
+ st.success(f"✅ 注册成功,已自动登录")
262
  time.sleep(1)
263
  st.rerun()
264
  else:
 
267
 
268
  CURRENT_USER = st.session_state.current_user
269
  IS_ADMIN = st.session_state.current_role == "admin"
270
+
271
 
272
  # =========================
273
  # 3. 安全配置与 Embedding 策略
 
350
  # 4. Supabase 索引管理(替代本地文件)
351
  # =========================
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  def _save_chunks_to_db(scope, chunks, vectors, source_file):
354
  """将新切片批量写入 Supabase documents 表。"""
355
  rows = []
 
492
  logger.warning(f"清空文件失败 [scope={scope}]: {e}")
493
 
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
 
 
 
 
 
 
 
 
 
 
496
  PUBLIC_SCOPE = "public"
 
497
  PRIVATE_SCOPE = CURRENT_USER # 私有库 scope = 用户名
 
 
498
 
499
+ # --- 定时同步:检测其他用户对文档库的修改 ---
500
+ _SYNC_INTERVAL = 30 # 每 30 秒检查一次
501
 
502
+ def _check_and_sync():
503
+ """检测文档数量变化,用于多用户同步感知。"""
504
+ now = time.time()
505
+ last_check = st.session_state.get("_sync_last_check", 0)
506
+ if now - last_check < _SYNC_INTERVAL:
507
+ return
508
+ st.session_state["_sync_last_check"] = now
509
+ for scope, label in [(PUBLIC_SCOPE, "public"), (PRIVATE_SCOPE, "private")]:
510
+ count_key = f"_sync_count_{label}"
511
+ current_count = _count_chunks(scope)
512
+ prev_count = st.session_state.get(count_key, -1)
513
+ if prev_count >= 0 and current_count != prev_count:
514
+ logger.info(f"[SYNC] {label} 库变更: {prev_count} -> {current_count}")
515
+ st.session_state[count_key] = current_count
516
+
517
+ _check_and_sync()
518
 
519
 
520
  # =========================
 
551
  if _text_splitter_cache is None:
552
  from langchain_text_splitters import RecursiveCharacterTextSplitter
553
  _text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
 
554
  return _text_splitter_cache
555
 
556
  SYSTEM_PROMPT = (
 
675
  for src_file, (chunks, vecs) in file_groups.items():
676
  _save_chunks_to_db(scope, chunks, vecs, src_file)
677
 
 
 
 
678
  st.session_state[fp_key] = file_fingerprint
679
  # 递增上传组件 key
680
  ukey = f"_upload_ver_{target_prefix}"
 
714
  "🏢 百度文心 (官方)": "ernie-3.5-8k",
715
  }
716
 
717
+
718
  with st.sidebar:
719
+ pub_chunk_count = st.session_state.get("_sync_count_public", _count_chunks(PUBLIC_SCOPE))
720
  with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
721
  st.caption("所有人可搜索")
722
 
 
731
  if col_del.button("🗑", key=f"delpub_{fname}", help=f"删除 {fname}"):
732
  _delete_chunks_by_file(PUBLIC_SCOPE, fname)
733
  _delete_uploaded_file_from_storage(PUBLIC_SCOPE, fname)
 
734
  st.success(f"已删除 {fname}")
735
  time.sleep(0.5)
736
  st.rerun()
 
753
  if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
754
  _clear_all_chunks(PUBLIC_SCOPE)
755
  _clear_uploaded_files_storage(PUBLIC_SCOPE)
 
756
  st.success("公共知识库已清空。")
757
  time.sleep(0.5)
758
  st.rerun()
 
760
  st.caption("*仅管理员可维护公共库*")
761
 
762
  # --- 私有知识库 ---
763
+ priv_chunk_count = st.session_state.get("_sync_count_private", _count_chunks(PRIVATE_SCOPE))
764
  with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
765
  st.caption(f"用户:{CURRENT_USER},仅自己可见")
766
 
 
773
  if col_del.button("🗑", key=f"delpriv_{fname}", help=f"删除 {fname}"):
774
  _delete_chunks_by_file(PRIVATE_SCOPE, fname)
775
  _delete_uploaded_file_from_storage(PRIVATE_SCOPE, fname)
 
776
  st.success(f"已删除 {fname}")
777
  time.sleep(0.5)
778
  st.rerun()
 
792
  if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
793
  _clear_all_chunks(PRIVATE_SCOPE)
794
  _clear_uploaded_files_storage(PRIVATE_SCOPE)
 
795
  st.success("私有知识库已清空。")
796
  time.sleep(0.5)
797
  st.rerun()
 
816
  new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
817
  new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
818
  if st.button("✅ 确认修改", key="btn_change_pass"):
819
+ ok, _, _ = verify_user(CURRENT_USER, old_pass)
820
  if not ok:
821
  st.error("当前密码错误")
822
  elif len(new_pass1) < 4:
 
909
  display_users[k] = v
910
  st.json(display_users)
911
 
912
+ # --- 聊天记录管理 ---
913
+ with st.expander("💬 聊天记录"):
914
+ hist_tab_new, hist_tab_history = st.tabs(["当前对话", "历史记录"])
915
+
916
+ with hist_tab_new:
917
+ st.caption("清空当前对话(数据库记录保留)")
918
+ if st.button("🧹 清空当前对话", use_container_width=True, type="secondary", key="btn_clear_chat"):
919
+ st.session_state.messages = []
920
+ st.rerun()
921
+
922
+ st.caption("清空所有历史记录(不可恢复)")
923
+ if st.button("🗑️ 清空全部记录", use_container_width=True, type="secondary", key="btn_clear_all_hist"):
924
+ _clear_chat_history_db(CURRENT_USER)
925
+ st.session_state.messages = []
926
+ st.success("所有聊天记录已清空")
927
+ time.sleep(0.5)
928
+ st.rerun()
929
+
930
+ with hist_tab_history:
931
+ if st.button("🔄 加载历史记录", use_container_width=True, key="btn_load_hist"):
932
+ st.session_state["_show_history"] = True
933
+
934
+ if st.session_state.get("_show_history"):
935
+ history = _load_chat_history(CURRENT_USER, limit=100)
936
+ if not history:
937
+ st.info("暂无历史记录")
938
+ else:
939
+ st.caption(f"共 {len(history)} 条记录")
940
+ for msg in history:
941
+ ts = msg.get("created_at", "")[:16].replace("T", " ")
942
+ icon = "🧑" if msg["role"] == "user" else "🤖"
943
+ preview = msg["content"][:80].replace("\n", " ")
944
+ st.text(f"{icon} [{ts}] {preview}{'...' if len(msg['content']) > 80 else ''}")
945
 
946
 
947
  # =========================
948
  # 8. 核心搜索逻辑(合并公共库 + 私有库)
949
  # =========================
 
 
 
 
 
 
 
 
 
 
950
  def search_local(query, top_k, threshold):
951
+ """使用 pgvector 数据库端向量搜索(替代内存计算)。"""
952
  query_vec = encode_query(query)
953
+ scopes = [PUBLIC_SCOPE, PRIVATE_SCOPE]
954
+ try:
955
+ resp = _sb().rpc("match_documents", {
956
+ "query_embedding": query_vec.tolist() if hasattr(query_vec, 'tolist') else list(query_vec),
957
+ "match_scopes": scopes,
958
+ "match_threshold": float(threshold),
959
+ "match_count": int(top_k),
960
+ }).execute()
961
+ return [row["content"] for row in resp.data] if resp.data else []
962
+ except Exception as e:
963
+ logger.error(f"pgvector 搜索失败: {e}")
964
+ return []
 
 
 
 
 
 
 
 
965
 
966
 
967
  # =========================
 
1054
  yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
1055
 
1056
 
1057
+ # =========================
1058
+ # 9.5 聊天记录持久化(Supabase chat_history 表)
1059
+ # =========================
1060
+ def _save_chat_message(username, role, content, meta=""):
1061
+ """保存单条聊天消息到数据库。"""
1062
+ try:
1063
+ _sb().table("chat_history").insert({
1064
+ "username": username,
1065
+ "role": role,
1066
+ "content": content,
1067
+ "meta": meta or "",
1068
+ }).execute()
1069
+ except Exception as e:
1070
+ logger.warning(f"保存聊天记录失败: {e}")
1071
+
1072
+
1073
+ def _load_chat_history(username, limit=50):
1074
+ """加载用户最近的聊天记录。"""
1075
+ try:
1076
+ resp = _sb().table("chat_history").select(
1077
+ "role, content, meta, created_at"
1078
+ ).eq("username", username).order(
1079
+ "created_at", desc=True
1080
+ ).limit(limit).execute()
1081
+ if not resp.data:
1082
+ return []
1083
+ # 反转回时间正序
1084
+ rows = list(reversed(resp.data))
1085
+ return [
1086
+ {"role": r["role"], "content": r["content"],
1087
+ "meta": r.get("meta", ""), "created_at": r.get("created_at", "")}
1088
+ for r in rows
1089
+ ]
1090
+ except Exception as e:
1091
+ logger.warning(f"加载聊天记录失败: {e}")
1092
+ return []
1093
+
1094
+
1095
+ def _clear_chat_history_db(username):
1096
+ """清空用户在数据库中的所有聊天记录。"""
1097
+ try:
1098
+ _sb().table("chat_history").delete().eq("username", username).execute()
1099
+ except Exception as e:
1100
+ logger.warning(f"清空聊天记录失败: {e}")
1101
+
1102
+
1103
  # =========================
1104
  # 10. 聊天渲染
1105
  # =========================
1106
  if "messages" not in st.session_state:
1107
+ # 首次加载时从数据库恢复最近对话
1108
+ saved = _load_chat_history(CURRENT_USER, limit=50)
1109
+ st.session_state.messages = [
1110
+ {"role": m["role"], "content": m["content"],
1111
+ **({"meta": m["meta"]} if m.get("meta") else {})}
1112
+ for m in saved
1113
+ ]
1114
 
1115
  for m in st.session_state.messages:
1116
  with st.chat_message(m["role"]):
 
1120
 
1121
  if q := st.chat_input("输入问题...", key="chat_input_v3"):
1122
  st.session_state.messages.append({"role": "user", "content": q})
1123
+ _save_chat_message(CURRENT_USER, "user", q)
1124
  with st.chat_message("user"):
1125
  st.markdown(q)
1126
 
 
1143
  st.session_state.messages.append(
1144
  {"role": "assistant", "content": full_response, "meta": meta_info}
1145
  )
1146
+ _save_chat_message(CURRENT_USER, "assistant", full_response, meta_info)
1147
  except Exception as e:
1148
  logger.error(f"模型调用异常: {e}")
1149
  container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")
1150