AlauStone committed on
Commit
ba534ff
·
verified ·
1 Parent(s): 421dcad

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1123 -0
  2. requirements.txt +8 -3
app.py ADDED
@@ -0,0 +1,1123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time as _time
3
+ _BOOT = _time.time()
4
+ import json
5
+ import time
6
+ import logging
7
+ import hashlib
8
+ from datetime import datetime
9
+
10
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
11
+ logger = logging.getLogger(__name__)
12
+
13
def _perf(label):
    """Log how many seconds have passed since ``_BOOT`` under a ``[PERF]`` tag.

    Args:
        label: Short name of the milestone that was just reached.
    """
    elapsed = _time.time() - _BOOT
    logger.info(f"[PERF] {label}: {elapsed:.2f}s")
15
+
16
+ _perf("stdlib imports done")
17
+
18
+ # numpy 延迟导入
19
_np_module = None


def _get_np():
    """Return the numpy module, importing it lazily on the first call.

    The imported module is kept in the ``_np_module`` global so later calls
    return it directly; the first load is reported through ``_perf``.
    """
    global _np_module
    if _np_module is not None:
        return _np_module
    import numpy as _numpy
    _np_module = _numpy
    _perf("numpy loaded")
    return _np_module
27
+
28
+ # =========================
29
+ # 1. 页面配置 & 样式注入
30
+ # =========================
31
+ st.set_page_config(page_title="RAG 知识库助手 v3 (HF+Supabase)", page_icon="🛡️", layout="wide")
32
+ _perf("page_config done")
33
+
34
+
35
def inject_custom_css():
    """Inject the app's custom stylesheet into the page.

    The CSS adjusts sidebar padding and block gaps, and replaces the file
    uploader's built-in texts with Chinese captions via pseudo-elements.
    """
    css = """
    <style>
    [data-testid="stSidebarContent"] { padding-top: 1.5rem !important; }
    [data-testid="stVerticalBlock"] > div { gap: 0.8rem !important; }
    [data-testid="stFileUploader"] section > div { display: none; }
    [data-testid="stFileUploaderDropzoneInstructions"] { display: none !important; }
    [data-testid="stFileUploader"] section::before {
        content: "拖拽文档至此";
        color: #555; font-size: 14px; display: block; margin-bottom: 10px;
    }
    [data-testid="stFileUploader"] section::after {
        content: "支持格式:TXT, PDF, DOCX";
        color: #888; font-size: 12px; display: block; margin-top: 5px;
    }
    [data-testid="stFileUploader"] button { font-size: 0 !important; }
    [data-testid="stFileUploader"] button::after {
        content: "选择文件";
        font-size: 14px !important;
    }
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
57
+
58
+
59
+ inject_custom_css()
60
+ st.title("🛡️ 智能知识库助手 v3")
61
+ _perf("CSS + title done")
62
+
63
+ # =========================
64
+ # 1.5 Supabase 客户端初始化
65
+ # =========================
66
+ from supabase import create_client
67
+
68
+ SUPABASE_URL = st.secrets.get("SUPABASE_URL", "")
69
+ SUPABASE_KEY = st.secrets.get("SUPABASE_KEY", "") # service_role key (后端使用)
70
+ STORAGE_BUCKET = "rag-files"
71
+
72
+ if not SUPABASE_URL or not SUPABASE_KEY:
73
+ st.error("⚠️ 未配置 SUPABASE_URL 或 SUPABASE_KEY,请在 Secrets 中设置。")
74
+ st.stop()
75
+
76
+
77
@st.cache_resource
def _get_supabase():
    """Create the Supabase client from ``SUPABASE_URL`` and ``SUPABASE_KEY``.

    Decorated with ``st.cache_resource``, so the client object is cached by
    Streamlit instead of being rebuilt on every call.
    """
    client = create_client(SUPABASE_URL, SUPABASE_KEY)
    return client
80
+
81
+
82
def _sb():
    """Shortcut that returns the cached Supabase client."""
    client = _get_supabase()
    return client
85
+
86
+
87
+ _perf("supabase client ready")
88
+
89
+ # =========================
90
+ # 2. 用户管理(Supabase users 表)
91
+ # =========================
92
+ MAX_LOGIN_ATTEMPTS = 10
93
+
94
+
95
def _hash_password(password):
    """Return the hexadecimal SHA-256 digest of ``password`` (UTF-8 encoded).

    NOTE(review): the digest is unsalted; changing the scheme would need a
    migration of the stored ``password_hash`` values.
    """
    digest = hashlib.sha256()
    digest.update(password.encode("utf-8"))
    return digest.hexdigest()
97
+
98
+
99
def _load_users():
    """Load every row of the Supabase ``users`` table.

    Returns:
        dict: ``{username: {"password_hash", "role", "created_at"}}``.
        ``created_at`` is cut to its first 16 characters, or ``"未知"`` when
        the column is empty. An empty dict is returned (and the error logged)
        when the query fails.
    """
    try:
        rows = _sb().table("users").select("*").execute().data
        return {
            row["username"]: {
                "password_hash": row["password_hash"],
                "role": row["role"],
                "created_at": row["created_at"][:16] if row.get("created_at") else "未知",
            }
            for row in rows
        }
    except Exception as e:
        logger.error(f"加载用户表失败: {e}")
        return {}
114
+
115
+
116
def _ensure_admin():
    """Create the initial administrator from secrets when no users are loaded.

    Does nothing when ``_load_users()`` returns any user or when the
    ``ADMIN_PASSWORD`` secret is empty. Failures of the upsert are logged,
    not raised.
    """
    if _load_users():
        return

    admin_user = st.secrets.get("ADMIN_USER", "admin")
    admin_pass = st.secrets.get("ADMIN_PASSWORD", "")
    if admin_pass:
        try:
            admin_row = {
                "username": admin_user,
                "password_hash": _hash_password(admin_pass),
                "role": "admin",
            }
            _sb().table("users").upsert(admin_row).execute()
            logger.info(f"初始管理员 {admin_user} 已创建")
        except Exception as e:
            logger.error(f"创建管理员失败: {e}")
134
+
135
+
136
+ _ensure_admin()
137
+
138
+
139
def _save_user(username, password_hash, role="user"):
    """Insert or update a single row of the ``users`` table.

    Args:
        username: Value of the ``username`` column.
        password_hash: Already-hashed password to store.
        role: Role string to store; defaults to ``"user"``.
    """
    record = {
        "username": username,
        "password_hash": password_hash,
        "role": role,
    }
    _sb().table("users").upsert(record).execute()
146
+
147
+
148
def _delete_user_db(username):
    """Delete the ``users`` row whose ``username`` column equals ``username``."""
    users_table = _sb().table("users")
    users_table.delete().eq("username", username).execute()
151
+
152
+
153
def _get_invite_code():
    """Return the current invite code.

    The code is read from the ``app_meta`` table (row with key
    ``invite_code``). When that row is missing or the query fails, the
    ``INVITE_CODE`` secret is used instead (empty string when unset).

    Returns:
        The invite code, or ``""`` when none is configured.
    """
    try:
        resp = _sb().table("app_meta").select("value").eq("key", "invite_code").execute()
        if resp.data:
            return resp.data[0]["value"]
    except Exception as e:
        # Still best-effort, but record the failure instead of hiding it so a
        # broken table/connection is visible in the logs.
        logger.warning(f"读取邀请码失败,回退到 secrets: {e}")
    return st.secrets.get("INVITE_CODE", "")
162
+
163
+
164
def _set_invite_code(new_code):
    """Store ``new_code`` as the ``invite_code`` row of the ``app_meta`` table."""
    meta_row = {"key": "invite_code", "value": new_code}
    _sb().table("app_meta").upsert(meta_row).execute()
166
+
167
+
168
def register_user(username, password, invite_code):
    """Validate a sign-up request and create the account.

    Args:
        username: Requested user name (2-20 characters).
        password: Plain-text password (at least 4 characters).
        invite_code: Invite code typed by the user.

    Returns:
        tuple: ``(ok, message)`` where ``ok`` is a bool and ``message`` is the
        text to show to the user.
    """
    import hmac

    if not username or not password:
        return False, "用户名和密码不能为空"
    if len(username) < 2 or len(username) > 20:
        return False, "用户名长度需要 2-20 个字符"
    if len(password) < 4:
        return False, "密码至少 4 个字符"
    if username.startswith("__"):
        return False, "用户名不能以 __ 开头"
    # The user name doubles as the private-library scope and as the storage
    # path prefix. "public" is the scope of the shared library, so allowing it
    # as a user name would let that user write into the public library.
    if username.casefold() == "public":
        return False, "该用户名为系统保留名称"
    # Path separators would change the storage path layout "<scope>/<file>".
    if "/" in username or "\\" in username:
        return False, "用户名不能包含 / 或 \\"

    correct_code = _get_invite_code()
    if not correct_code:
        return False, "邀请码未配置,请联系管理员"
    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(
        str(invite_code).encode("utf-8"), str(correct_code).encode("utf-8")
    ):
        return False, "邀请码错误"

    users = _load_users()
    if username in users:
        return False, "用户名已存在"

    try:
        _save_user(username, _hash_password(password), "user")
        logger.info(f"新用户注册: {username}")
        return True, "注册成功,请登录"
    except Exception as e:
        logger.error(f"注册失败: {e}")
        return False, f"注册失败: {e}"
195
+
196
+
197
def verify_user(username, password):
    """Check a user name / password pair against the stored users.

    Returns:
        tuple: ``(True, role)`` when the stored hash equals the hash of
        ``password`` (role defaults to ``"user"``), otherwise ``(False, None)``.
    """
    record = _load_users().get(username)
    if record and isinstance(record, dict):
        if record.get("password_hash") == _hash_password(password):
            return True, record.get("role", "user")
    return False, None
205
+
206
+
207
+ # --- 认证 UI ---
208
+ if "login_attempts" not in st.session_state:
209
+ st.session_state.login_attempts = 0
210
+ if "current_user" not in st.session_state:
211
+ st.session_state.current_user = None
212
+ if "current_role" not in st.session_state:
213
+ st.session_state.current_role = None
214
+ if "auth_mode" not in st.session_state:
215
+ st.session_state.auth_mode = "login"
216
+
217
+ with st.sidebar:
218
+ with st.expander("🔑 账号"):
219
+ if st.session_state.login_attempts >= MAX_LOGIN_ATTEMPTS:
220
+ st.error("🚫 尝试次数过多,请刷新页面后重试。")
221
+ st.stop()
222
+
223
+ users_data = _load_users()
224
+ if not users_data:
225
+ st.error("⚠️ 未配置管理员,请在 secrets 中设置 ADMIN_USER 和 ADMIN_PASSWORD。")
226
+ st.stop()
227
+
228
+ auth_mode = st.radio(
229
+ "操作", ["登录", "注册"], horizontal=True,
230
+ label_visibility="collapsed", key="auth_radio",
231
+ )
232
+
233
+ if auth_mode == "登录":
234
+ input_username = st.text_input("用户名", key="login_user")
235
+ input_password = st.text_input("密码", type="password", key="login_pass")
236
+
237
+ if input_username == "" or input_password == "":
238
+ st.stop()
239
+
240
+ ok, role = verify_user(input_username, input_password)
241
+ if not ok:
242
+ st.session_state.login_attempts += 1
243
+ remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
244
+ st.warning(f"⚠️ 用户名或密码错误(剩余 {remaining} 次)")
245
+ st.stop()
246
+ else:
247
+ st.session_state.login_attempts = 0
248
+ st.session_state.current_user = input_username
249
+ st.session_state.current_role = role
250
+ role_label = "管理员" if role == "admin" else "普通用户"
251
+ st.success(f"✅ {input_username}({role_label})")
252
+
253
+ else: # 注册
254
+ reg_user = st.text_input("用户名", key="reg_user")
255
+ reg_pass = st.text_input("密码", type="password", key="reg_pass")
256
+ reg_pass2 = st.text_input("确认密码", type="password", key="reg_pass2")
257
+ reg_code = st.text_input("邀请码", type="password", key="reg_code")
258
+
259
+ if st.button("注册", use_container_width=True, key="btn_register"):
260
+ if reg_pass != reg_pass2:
261
+ st.error("两次密码不一致")
262
+ else:
263
+ ok, msg = register_user(reg_user, reg_pass, reg_code)
264
+ if ok:
265
+ st.success(f"✅ {msg}")
266
+ time.sleep(1)
267
+ st.rerun()
268
+ else:
269
+ st.error(f"❌ {msg}")
270
+ st.stop()
271
+
272
+ CURRENT_USER = st.session_state.current_user
273
+ IS_ADMIN = st.session_state.current_role == "admin"
274
+ _perf("auth done")
275
+
276
+ # =========================
277
+ # 3. 安全配置与 Embedding 策略
278
+ # =========================
279
+ TAVILY_KEY = st.secrets.get("TAVILY_API_KEY", "")
280
+ DS_API_KEY = st.secrets.get("DEEPSEEK_API_KEY", "")
281
+ BAIDU_TOKEN = st.secrets.get("BAIDU_BEARER_TOKEN", "")
282
+ BAIDU_APP_ID = st.secrets.get("BAIDU_APP_ID", "")
283
+ OR_KEY = st.secrets.get("OPENROUTER_API_KEY", "")
284
+
285
+
286
@st.cache_resource
def _get_embedding_client():
    """Pick the embedding API client according to the configured secrets.

    Returns:
        tuple: ``(client, model_name)``. The Baidu Qianfan endpoint is used
        when both ``BAIDU_TOKEN`` and ``BAIDU_APP_ID`` are set, otherwise
        OpenRouter when ``OR_KEY`` is set, otherwise ``(None, None)``.
    """
    from openai import OpenAI

    if BAIDU_TOKEN and BAIDU_APP_ID:
        baidu_client = OpenAI(
            api_key=BAIDU_TOKEN,
            base_url="https://qianfan.baidubce.com/v2",
            default_headers={"appid": BAIDU_APP_ID},
        )
        return baidu_client, "bge-large-zh"

    if not OR_KEY:
        return None, None

    openrouter_client = OpenAI(
        api_key=OR_KEY,
        base_url="https://openrouter.ai/api/v1",
    )
    return openrouter_client, "BAAI/bge-small-zh"
301
+
302
+
303
def _api_encode(texts):
    """Embed ``texts`` through the remote embedding API.

    Requests are sent in batches of 32 texts.

    Returns:
        A list with one numpy array per returned embedding, or ``None`` when
        no client is configured or any request fails (the failure is logged).
    """
    np = _get_np()
    client, model = _get_embedding_client()
    if client is None:
        return None

    batch_size = 32
    vectors = []
    try:
        for start in range(0, len(texts), batch_size):
            resp = client.embeddings.create(
                model=model, input=texts[start:start + batch_size]
            )
            vectors += [np.array(item.embedding) for item in resp.data]
    except Exception as e:
        logger.warning(f"API embedding 失败,回退到本地模型: {e}")
        return None
    return vectors
319
+
320
+
321
def _get_local_model():
    """Return the local SentenceTransformer model kept in the session state.

    The model is created on first use and stored under the
    ``_local_emb_model`` session key. Returns ``None`` (after logging) when
    ``sentence_transformers`` cannot be imported.
    """
    state = st.session_state
    if "_local_emb_model" in state:
        return state.get("_local_emb_model")

    try:
        with st.spinner("API 不可用,正在加载本地向量模型(仅首次)..."):
            from sentence_transformers import SentenceTransformer
            state._local_emb_model = SentenceTransformer("BAAI/bge-small-zh")
    except ImportError:
        logger.error("sentence-transformers 未安装,本地模型不可用")
        return None
    return state.get("_local_emb_model")
331
+
332
+
333
def encode_texts(texts):
    """Embed one text or a list of texts.

    The remote API is tried first; when it yields ``None`` the local model is
    used instead.

    Args:
        texts: A string or a list of strings.

    Returns:
        list: One vector per text, or ``[]`` for empty input or when neither
        the API nor the local model is available (an error is shown).
    """
    if not texts:
        return []

    batch = [texts] if isinstance(texts, str) else texts
    vectors = _api_encode(batch)
    if vectors is None:
        local_model = _get_local_model()
        if local_model is None:
            st.error("❌ Embedding 服务不可用:API 调用失败且本地模型未安装。请检查 API Key 配置。")
            return []
        vectors = list(local_model.encode(batch))
    return vectors
346
+
347
+
348
def encode_query(text):
    """Embed a single query string and return its vector.

    NOTE(review): raises ``IndexError`` when ``encode_texts`` returns an
    empty list (embedding service unavailable) — confirm callers handle it.
    """
    return encode_texts([text])[0]
351
+
352
+
353
+ # =========================
354
+ # 4. Supabase 索引管理(替代本地文件)
355
+ # =========================
356
+
357
def _load_library(scope):
    """Load every document chunk of ``scope`` from the ``documents`` table.

    Rows are fetched page by page because a single select is capped by the
    API's maximum row count, which would silently truncate large libraries.

    Args:
        scope: Library scope to load.

    Returns:
        tuple: ``(docs, embeddings, sources)`` — three parallel lists holding
        the chunk text, its embedding as a numpy array and the source file
        name. Three empty lists are returned when the query fails.
    """
    import json

    np = _get_np()
    page_size = 1000
    docs, embeddings, sources = [], [], []
    try:
        start = 0
        while True:
            resp = (
                _sb().table("documents")
                .select("content, embedding, source_file")
                .eq("scope", scope)
                .order("id")  # stable order so pages neither overlap nor skip rows
                .range(start, start + page_size - 1)
                .execute()
            )
            rows = resp.data or []
            if not rows:
                break
            for row in rows:
                raw_vec = row["embedding"]
                # A vector column may come back serialized as "[0.1,0.2,...]".
                if isinstance(raw_vec, str):
                    raw_vec = json.loads(raw_vec)
                docs.append(row["content"])
                embeddings.append(np.array(raw_vec))
                sources.append(row["source_file"])
            # Advance by what was actually returned so a server-side cap
            # smaller than page_size still works.
            start += len(rows)
        return docs, embeddings, sources
    except Exception as e:
        logger.error(f"加载索引失败 [scope={scope}]: {e}")
        return [], [], []
377
+
378
+
379
def _save_chunks_to_db(scope, chunks, vectors, source_file):
    """Insert new chunks of one source file into the ``documents`` table.

    Args:
        scope: Library scope the chunks belong to.
        chunks: List of chunk texts.
        vectors: List of embeddings, parallel to ``chunks``; items may be
            numpy arrays or plain sequences.
        source_file: File name stored with every chunk.

    Raises:
        ValueError: If ``chunks`` and ``vectors`` differ in length. Pairing
            them anyway would silently drop chunks.
    """
    if len(chunks) != len(vectors):
        raise ValueError(
            f"chunks/vectors length mismatch: {len(chunks)} != {len(vectors)}"
        )

    rows = [
        {
            "scope": scope,
            "source_file": source_file,
            "content": content,
            "embedding": vec.tolist() if hasattr(vec, "tolist") else list(vec),
        }
        for content, vec in zip(chunks, vectors)
    ]

    # Insert in batches of at most 500 rows per request.
    batch_size = 500
    for start in range(0, len(rows), batch_size):
        _sb().table("documents").insert(rows[start:start + batch_size]).execute()
393
+
394
+
395
def _delete_chunks_by_file(scope, filename):
    """Delete every ``documents`` row matching both ``scope`` and ``filename``."""
    query = _sb().table("documents").delete()
    query = query.eq("scope", scope).eq("source_file", filename)
    query.execute()
398
+
399
+
400
def _clear_all_chunks(scope):
    """Delete every ``documents`` row that belongs to ``scope``."""
    documents = _sb().table("documents")
    documents.delete().eq("scope", scope).execute()
403
+
404
+
405
def _count_chunks(scope):
    """Return the number of ``documents`` rows stored for ``scope``.

    Returns:
        int: The exact count reported by the query, or ``0`` when the count
        is missing or the query fails (the failure is logged).
    """
    try:
        resp = _sb().table("documents").select("id", count="exact").eq("scope", scope).execute()
        return resp.count or 0
    except Exception as e:
        # Keep the "0 on failure" contract but leave a trace in the logs.
        logger.warning(f"统计切片数量失败 [scope={scope}]: {e}")
        return 0
412
+
413
+
414
+ # --- 原始文件管理(Supabase Storage + uploaded_files 表)---
415
+
416
def _save_uploaded_file_to_storage(scope, uploaded_file):
    """Upload the raw file to Supabase Storage and record its metadata.

    The object is stored at ``"<scope>/<file name>"``. When the first upload
    (with the upsert option) raises, the object is removed and uploaded again
    without that option. Finally a row is upserted into ``uploaded_files``.

    Args:
        scope: Library scope, used as the path prefix.
        uploaded_file: File-like object exposing ``name``, ``seek`` and ``read``.
    """
    storage_path = f"{scope}/{uploaded_file.name}"
    uploaded_file.seek(0)
    payload = uploaded_file.read()
    content_type = "application/octet-stream"

    # First attempt: upload and overwrite if the object already exists.
    try:
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path,
            payload,
            file_options={"content-type": content_type, "upsert": "true"},
        )
    except Exception as e:
        # Some supabase-py versions need delete-then-upload instead of upsert.
        logger.warning(f"Storage upload fallback: {e}")
        try:
            _sb().storage.from_(STORAGE_BUCKET).remove([storage_path])
        except Exception:
            pass
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path,
            payload,
            file_options={"content-type": content_type},
        )

    # Record the metadata in the uploaded_files table.
    metadata = {
        "scope": scope,
        "filename": uploaded_file.name,
        "file_size": len(payload),
        "storage_path": storage_path,
    }
    _sb().table("uploaded_files").upsert(metadata, on_conflict="scope,filename").execute()
447
+
448
+
449
def _list_uploaded_files_db(scope):
    """List the files recorded for ``scope`` in ``uploaded_files``.

    Returns:
        list: ``(filename, size_str, storage_path)`` tuples ordered by file
        name, where ``size_str`` is formatted in B / KB / MB. An empty list is
        returned (and the error logged) when the query fails.
    """

    def _human_size(num_bytes):
        # Same thresholds as the table view: < 1 KiB, < 1 MiB, otherwise MB.
        if num_bytes < 1024:
            return f"{num_bytes}B"
        if num_bytes < 1048576:
            return f"{num_bytes / 1024:.1f}KB"
        return f"{num_bytes / 1048576:.1f}MB"

    try:
        query = _sb().table("uploaded_files").select("filename, file_size, storage_path")
        resp = query.eq("scope", scope).order("filename").execute()
        return [
            (
                row["filename"],
                _human_size(row.get("file_size", 0) or 0),
                row.get("storage_path", ""),
            )
            for row in resp.data
        ]
    except Exception as e:
        logger.error(f"列出文件失败 [scope={scope}]: {e}")
        return []
470
+
471
+
472
def _delete_uploaded_file_from_storage(scope, filename):
    """Remove a file from Storage and delete its ``uploaded_files`` row.

    Both steps are attempted independently; a failure of either one is only
    logged as a warning.
    """
    storage_path = f"{scope}/{filename}"

    def _remove_object():
        _sb().storage.from_(STORAGE_BUCKET).remove([storage_path])

    def _remove_record():
        _sb().table("uploaded_files").delete().eq("scope", scope).eq("filename", filename).execute()

    steps = (
        (_remove_object, "Storage 删除失败"),
        (_remove_record, "uploaded_files 记录删除失败"),
    )
    for action, failure_label in steps:
        try:
            action()
        except Exception as e:
            logger.warning(f"{failure_label}: {e}")
483
+
484
+
485
def _clear_uploaded_files_storage(scope):
    """Remove every uploaded file of ``scope`` from Storage and the database.

    The storage paths are read from the ``uploaded_files`` table, the objects
    are removed from the bucket, and then the table rows are deleted. Any
    failure is logged as a warning rather than raised.
    """
    # The earlier pre-pass over `_list_uploaded_files_db` indexed its tuples
    # with a string key and raised TypeError whenever the scope had files; the
    # query below already provides the paths, so that pre-pass is gone.
    try:
        resp = _sb().table("uploaded_files").select("storage_path").eq("scope", scope).execute()
        paths = [row["storage_path"] for row in resp.data if row.get("storage_path")]
        if paths:
            _sb().storage.from_(STORAGE_BUCKET).remove(paths)
        _sb().table("uploaded_files").delete().eq("scope", scope).execute()
    except Exception as e:
        logger.warning(f"清空文件失败 [scope={scope}]: {e}")
498
+
499
+
500
+ # --- 初始化 session_state 中的缓存 ---
501
def _init_library(key_prefix, scope):
    """Load the index of ``scope`` into the session state once.

    Nothing happens when ``<prefix>_docs`` is already present and
    ``<prefix>_loaded`` is truthy. Otherwise the library is loaded and stored
    under ``<prefix>_docs``, ``<prefix>_embeddings`` and ``<prefix>_sources``,
    and ``<prefix>_loaded`` is set to ``True``.
    """
    state = st.session_state
    docs_key = f"{key_prefix}_docs"
    loaded_key = f"{key_prefix}_loaded"
    if docs_key in state and state.get(loaded_key):
        return

    docs, embeddings, sources = _load_library(scope)
    state[docs_key] = docs
    state[f"{key_prefix}_embeddings"] = embeddings
    state[f"{key_prefix}_sources"] = sources
    state[loaded_key] = True
514
+
515
+
516
def _refresh_library(key_prefix, scope):
    """Reload the index of ``scope`` and overwrite the session-state copies.

    Writes ``<prefix>_docs``, ``<prefix>_embeddings`` and ``<prefix>_sources``.
    """
    loaded = _load_library(scope)
    for suffix, value in zip(("docs", "embeddings", "sources"), loaded):
        st.session_state[f"{key_prefix}_{suffix}"] = value
522
+
523
+
524
+ _perf("before init_library")
525
+ PUBLIC_SCOPE = "public"
526
+ _init_library("public", PUBLIC_SCOPE)
527
+ PRIVATE_SCOPE = CURRENT_USER # 私有库 scope = 用户名
528
+ _init_library("private", PRIVATE_SCOPE)
529
+ _perf("init_library done")
530
+
531
+
532
def _get_embeddings_np(key_prefix):
    """Return the embeddings of a library as one numpy array.

    The array is kept in the session state under ``<prefix>_embeddings_np``
    and rebuilt whenever ``id()`` of the ``<prefix>_embeddings`` list differs
    from the value remembered under ``<prefix>_emb_version``.
    """
    np = _get_np()
    state = st.session_state
    np_key = f"{key_prefix}_embeddings_np"
    ver_key = f"{key_prefix}_emb_version"

    emb_list = state.get(f"{key_prefix}_embeddings", [])
    current_ver = id(emb_list)

    is_stale = np_key not in state or state.get(ver_key) != current_ver
    if is_stale:
        state[np_key] = np.array(emb_list) if emb_list else np.array([])
        state[ver_key] = current_ver
    return state[np_key]
546
+
547
+
548
+ # =========================
549
+ # 5. 缓存 LLM 客户端
550
+ # =========================
551
@st.cache_resource
def get_or_client():
    """Return an OpenAI-compatible client for OpenRouter, keyed by ``OR_KEY``."""
    from openai import OpenAI

    client = OpenAI(api_key=OR_KEY, base_url="https://openrouter.ai/api/v1")
    return client
555
+
556
+
557
@st.cache_resource
def get_ds_client():
    """Return an OpenAI-compatible client for DeepSeek, keyed by ``DS_API_KEY``."""
    from openai import OpenAI

    client = OpenAI(api_key=DS_API_KEY, base_url="https://api.deepseek.com")
    return client
561
+
562
+
563
@st.cache_resource
def get_baidu_client():
    """Return an OpenAI-compatible client for Baidu Qianfan.

    The app id is sent as the ``appid`` default header.
    """
    from openai import OpenAI

    client_kwargs = {
        "api_key": BAIDU_TOKEN,
        "base_url": "https://qianfan.baidubce.com/v2",
        "default_headers": {"appid": BAIDU_APP_ID},
    }
    return OpenAI(**client_kwargs)
571
+
572
+
573
+ # =========================
574
+ # 6. 实用功能函数
575
+ # =========================
576
_text_splitter_cache = None


def _get_text_splitter():
    """Return the shared text splitter, creating it on first use.

    The splitter is a ``RecursiveCharacterTextSplitter`` with a chunk size of
    400 and an overlap of 50; it is kept in ``_text_splitter_cache``.
    """
    global _text_splitter_cache
    if _text_splitter_cache is not None:
        return _text_splitter_cache

    from langchain_text_splitters import RecursiveCharacterTextSplitter
    _text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    _perf("text_splitter loaded")
    return _text_splitter_cache
584
+
585
+ SYSTEM_PROMPT = (
586
+ "你是一个专业的知识问答助手。请基于提供的参考资料回答用户问题。"
587
+ "如果资料中没有相关信息,请诚实说明。回答要准确、有条理、简洁。"
588
+ "不要编造不在资料中的信息。"
589
+ )
590
+
591
+
592
def web_search(query):
    """Run a Tavily web search and return a condensed text summary.

    The current year is prepended to the query. At most 3 results are
    requested; each result contributes its URL plus the first 700 characters
    of its content, and the joined text is cut to 2500 characters.

    Returns:
        str: The summary, a notice when ``TAVILY_KEY`` is empty, or an error
        message when the search raises.
    """
    if not TAVILY_KEY:
        return "⚠️ 未配置搜索 Key"

    from tavily import TavilyClient
    tavily = TavilyClient(api_key=TAVILY_KEY)
    current_year = datetime.now().year
    try:
        search_result = tavily.search(
            query=f"{current_year}年 {query}",
            search_depth="advanced",
            max_results=3,
        )
        snippets = []
        for r in search_result["results"]:
            url = r.get('url')
            content = r.get('content', '')[:700]
            snippets.append(f"来源: {url}\n内容: {content}")
        combined = "\n\n".join(snippets)
        return combined[:2500]
    except Exception as e:
        logger.error(f"联网搜索异常: {e}")
        return f"联网搜索异常:{str(e)}"
612
+
613
+
614
def estimate_tokens(text):
    """Roughly estimate the token count of ``text``.

    Characters in the range U+4E00..U+9FFF are weighted 1.5, every other
    character 0.4; the sum is truncated to an int. Falsy input yields 0.
    """
    if not text:
        return 0

    zh_count = 0
    for ch in text:
        if "\u4e00" <= ch <= "\u9fff":
            zh_count += 1
    other_count = len(text) - zh_count
    return int(zh_count * 1.5 + other_count * 0.4)
619
+
620
+
621
def _decode_text_bytes(raw):
    """Decode plain-text bytes, trying UTF-8 (BOM-aware) and then GB18030."""
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    # Last resort: keep whatever is valid UTF-8 and drop the rest.
    return raw.decode("utf-8", errors="ignore")


def extract_text(file):
    """Extract the text of an uploaded TXT, PDF or DOCX file.

    Args:
        file: File-like object exposing ``name``, ``seek`` and ``read``.

    Returns:
        str: The extracted text. An empty string is returned for unsupported
        extensions or when parsing fails (the error is logged and shown).
    """
    fname = file.name.lower()
    text = ""
    try:
        if fname.endswith(".txt"):
            # Rewind first, like the PDF branch, so an already-read stream
            # does not yield empty content.
            file.seek(0)
            # Strict UTF-8 with errors ignored would silently drop most of a
            # GBK/GB18030 encoded Chinese file; try both encodings.
            text = _decode_text_bytes(file.read())
        elif fname.endswith(".pdf"):
            import io
            file.seek(0)
            pdf_bytes = file.read()
            if len(pdf_bytes) < 100:
                raise ValueError("PDF 文件过小,可能已损坏")
            if pdf_bytes[:5] != b"%PDF-":
                raise ValueError("不是有效的 PDF 文件(缺少 %PDF- 头)")
            pdf_stream = io.BytesIO(pdf_bytes)
            pages_text = []
            import pdfplumber
            with pdfplumber.open(pdf_stream) as pdf:
                for i, page in enumerate(pdf.pages):
                    # A single broken page must not abort the whole document.
                    try:
                        page_text = page.extract_text() or ""
                        pages_text.append(page_text)
                    except Exception as page_err:
                        logger.warning(f"PDF 第{i+1}页解析失败: {page_err}")
                        pages_text.append("")
            text = "\n".join(pages_text)
        elif fname.endswith(".docx"):
            from docx import Document
            file.seek(0)
            doc = Document(file)
            text = "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        logger.error(f"文件解析失败 [{file.name}]: {e}", exc_info=True)
        st.error(f"解析失败: {e}")
    return text
655
+
656
+
657
def process_upload(uploaded_files, target_prefix, scope):
    """Handle uploaded files: parse, split, embed and write to Supabase.

    A fingerprint of the (name, size) pairs is kept in the session state so
    the same selection is not processed again on the next script run.

    Args:
        uploaded_files: Uploaded file objects exposing ``name``, ``size``,
            ``seek`` and ``read``.
        target_prefix: Session-state key prefix of the target library
            (callers pass ``"public"`` or ``"private"``).
        scope: Library scope the chunks are stored under.

    Returns:
        bool: Always ``False``. On success the function does not return
        normally because it ends with ``st.rerun()``.
    """
    if not uploaded_files:
        return False
    file_fingerprint = str(sorted((f.name, f.size) for f in uploaded_files))
    fp_key = f"_last_upload_fp_{target_prefix}"
    if file_fingerprint == st.session_state.get(fp_key):
        return False

    try:
        all_new_chunks = []
        all_new_sources = []
        with st.spinner("正在自动解析文档并更新索引..."):
            for f in uploaded_files:
                try:
                    # Keep the original file in Supabase Storage.
                    f.seek(0)
                    _save_uploaded_file_to_storage(scope, f)

                    # Parse the text.
                    f.seek(0)
                    raw_text = extract_text(f)
                    if not raw_text.strip():
                        st.warning(f"文件 {f.name} 内容为空,已跳过。")
                        continue
                    chunks = _get_text_splitter().split_text(raw_text)
                    all_new_chunks.extend(chunks)
                    all_new_sources.extend([f.name] * len(chunks))
                except Exception as file_err:
                    logger.error(f"文件 {f.name} 处理失败: {file_err}", exc_info=True)
                    st.warning(f"⚠️ 文件 {f.name} 处理失败:{str(file_err)[:100]},已跳过。")

            if all_new_chunks:
                # Embed in batches. Each batch must yield exactly one vector
                # per chunk; a short or empty batch would shift every later
                # vector onto the wrong chunk when they are zipped below.
                batch_size = 64
                all_vecs = []
                for i in range(0, len(all_new_chunks), batch_size):
                    batch = all_new_chunks[i:i + batch_size]
                    batch_vecs = encode_texts(batch)
                    if len(batch_vecs) != len(batch):
                        raise RuntimeError("向量编码数量与切片数量不一致,已中止写入")
                    all_vecs.extend(batch_vecs)

                # Group chunks and vectors by source file.
                file_groups = {}
                for chunk, vec, src in zip(all_new_chunks, all_vecs, all_new_sources):
                    group = file_groups.setdefault(src, ([], []))
                    group[0].append(chunk)
                    group[1].append(vec)

                for src_file, (chunks, vecs) in file_groups.items():
                    # Storage overwrites a re-uploaded file, so its old chunks
                    # must go as well or every re-upload duplicates the index.
                    _delete_chunks_by_file(scope, src_file)
                    _save_chunks_to_db(scope, chunks, vecs, src_file)

                # Refresh the session-state cache.
                _refresh_library(target_prefix, scope)

                st.session_state[fp_key] = file_fingerprint
                # Bump the uploader widget key so the widget is reset.
                ukey = f"_upload_ver_{target_prefix}"
                st.session_state[ukey] = st.session_state.get(ukey, 0) + 1
                st.success(f"自动导入 {len(all_new_chunks)} 个知识切片!")
                time.sleep(1)
                st.rerun()
            else:
                st.error("解析失败,未发现有效文字内容。")
    except Exception as e:
        logger.error(f"上传处理异常: {e}", exc_info=True)
        st.error(f"❌ 上传处理出错:{str(e)[:200]}")
    return False
723
+
724
+
725
+ # =========================
726
+ # 7. 侧边栏 UI & 逻辑
727
+ # =========================
728
+ model_mapping = {
729
+ "⭐ Step-3.5 (首选)": "stepfun/step-3.5-flash:free",
730
+ "🌐 OR-Auto (避堵)": "openrouter/free",
731
+ "🧠 GLM-4.5 (推理)": "z-ai/glm-4.5-air:free",
732
+ "🔥 Gemma-3-27B (旗舰)": "google/gemma-3-27b-it:free",
733
+ "🐋 Nemotron (120B)": "nvidia/nemotron-3-super-120b-a12b:free",
734
+ "⚡ Trinity-L (极速)": "arcee-ai/trinity-large-preview:free",
735
+ "💭 Liquid-Think (思维链)": "liquid/lfm-2.5-1.2b-thinking:free",
736
+ "🏎️ Liquid-Ins (1.0s)": "liquid/lfm-2.5-1.2b-instruct:free",
737
+ "⚖️ Gemma-3-12B (平衡)": "google/gemma-3-12b-it:free",
738
+ "💎 Gemma-3n-e4b (稳)": "google/gemma-3n-e4b-it:free",
739
+ "🤖 Nemotron-Nano (混)": "nvidia/nemotron-3-nano-30b-a3b:free",
740
+ "📉 Trinity-M (1.8s)": "arcee-ai/trinity-mini:free",
741
+ "🍃 Nemotron-9B": "nvidia/nemotron-nano-9b-v2:free",
742
+ "🪶 Gemma-3-4B": "google/gemma-3-4b-it:free",
743
+ "🫧 Gemma-3n-e2b": "google/gemma-3n-e2b-it:free",
744
+ "📷 Nemotron-VL": "nvidia/nemotron-nano-12b-v2-vl:free",
745
+ "🛡️ DeepSeek (官方)": "deepseek-chat",
746
+ "🏢 百度文心 (官方)": "ernie-3.5-8k",
747
+ }
748
+
749
+ _perf("before sidebar UI")
750
+ with st.sidebar:
751
+ pub_chunk_count = len(st.session_state.get("public_docs", []))
752
+ with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
753
+ st.caption("所有人可搜索")
754
+
755
+ # 文件列表
756
+ pub_file_list = _list_uploaded_files_db(PUBLIC_SCOPE)
757
+ if pub_file_list:
758
+ st.caption(f"📎 已上传 {len(pub_file_list)} 个文件:")
759
+ for fname, size_str, _ in pub_file_list:
760
+ if IS_ADMIN:
761
+ col_name, col_del = st.columns([4, 1])
762
+ col_name.text(f"📄 {fname} ({size_str})")
763
+ if col_del.button("🗑", key=f"delpub_{fname}", help=f"删除 {fname}"):
764
+ _delete_chunks_by_file(PUBLIC_SCOPE, fname)
765
+ _delete_uploaded_file_from_storage(PUBLIC_SCOPE, fname)
766
+ _refresh_library("public", PUBLIC_SCOPE)
767
+ st.success(f"已删除 {fname}")
768
+ time.sleep(0.5)
769
+ st.rerun()
770
+ else:
771
+ st.text(f"📄 {fname} ({size_str})")
772
+
773
+ if IS_ADMIN:
774
+ pub_upload_key = f"upload_public_{st.session_state.get('_upload_ver_public', 0)}"
775
+ pub_files = st.file_uploader(
776
+ "上传到公共库",
777
+ type=["txt", "pdf", "docx"],
778
+ accept_multiple_files=True,
779
+ label_visibility="collapsed",
780
+ key=pub_upload_key,
781
+ )
782
+ if pub_files:
783
+ process_upload(pub_files, "public", PUBLIC_SCOPE)
784
+
785
+ if pub_chunk_count > 0 and len(pub_file_list) >= 2:
786
+ if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
787
+ _clear_all_chunks(PUBLIC_SCOPE)
788
+ _clear_uploaded_files_storage(PUBLIC_SCOPE)
789
+ _refresh_library("public", PUBLIC_SCOPE)
790
+ st.success("公共知识库已清空。")
791
+ time.sleep(0.5)
792
+ st.rerun()
793
+ else:
794
+ st.caption("*仅管理员可维护公共库*")
795
+
796
+ # --- 私有知识库 ---
797
+ priv_chunk_count = len(st.session_state.get("private_docs", []))
798
+ with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
799
+ st.caption(f"用户:{CURRENT_USER},仅自己可见")
800
+
801
+ priv_file_list = _list_uploaded_files_db(PRIVATE_SCOPE)
802
+ if priv_file_list:
803
+ st.caption(f"📎 已上传 {len(priv_file_list)} 个文件:")
804
+ for fname, size_str, _ in priv_file_list:
805
+ col_name, col_del = st.columns([4, 1])
806
+ col_name.text(f"📄 {fname} ({size_str})")
807
+ if col_del.button("🗑", key=f"delpriv_{fname}", help=f"删除 {fname}"):
808
+ _delete_chunks_by_file(PRIVATE_SCOPE, fname)
809
+ _delete_uploaded_file_from_storage(PRIVATE_SCOPE, fname)
810
+ _refresh_library("private", PRIVATE_SCOPE)
811
+ st.success(f"已删除 {fname}")
812
+ time.sleep(0.5)
813
+ st.rerun()
814
+
815
+ priv_upload_key = f"upload_private_{st.session_state.get('_upload_ver_private', 0)}"
816
+ priv_files = st.file_uploader(
817
+ "上传到私有库",
818
+ type=["txt", "pdf", "docx"],
819
+ accept_multiple_files=True,
820
+ label_visibility="collapsed",
821
+ key=priv_upload_key,
822
+ )
823
+ if priv_files:
824
+ process_upload(priv_files, "private", PRIVATE_SCOPE)
825
+
826
+ if priv_chunk_count > 0 and len(priv_file_list) >= 2:
827
+ if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
828
+ _clear_all_chunks(PRIVATE_SCOPE)
829
+ _clear_uploaded_files_storage(PRIVATE_SCOPE)
830
+ _refresh_library("private", PRIVATE_SCOPE)
831
+ st.success("私有知识库已清空。")
832
+ time.sleep(0.5)
833
+ st.rerun()
834
+
835
+ # --- 模型设置 ---
836
+ with st.expander("⚙️ 模型设置"):
837
+ selected_display_name = st.selectbox(
838
+ "模型", list(model_mapping.keys()), index=0, label_visibility="collapsed"
839
+ )
840
+
841
+ web_on = st.toggle("🌐 联网增强", value=False)
842
+
843
+ c1, c2 = st.columns(2)
844
+ with c1:
845
+ ui_top_k = st.number_input("Top-K", 1, 15, 5)
846
+ with c2:
847
+ ui_threshold = st.number_input("阈值", 0.0, 1.0, 0.25, step=0.05)
848
+
849
+ # --- 修改密码 ---
850
+ with st.expander("🔐 修改密码"):
851
+ old_pass = st.text_input("当前密码", type="password", key="self_old_pass")
852
+ new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
853
+ new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
854
+ if st.button("✅ 确认修改", key="btn_change_pass"):
855
+ ok, _ = verify_user(CURRENT_USER, old_pass)
856
+ if not ok:
857
+ st.error("当前密码错误")
858
+ elif len(new_pass1) < 4:
859
+ st.error("新密码至少 4 个字符")
860
+ elif new_pass1 != new_pass2:
861
+ st.error("两次新密码不一致")
862
+ else:
863
+ _save_user(CURRENT_USER, _hash_password(new_pass1),
864
+ st.session_state.current_role)
865
+ st.success("密码修改成功,请重新登录")
866
+ time.sleep(1)
867
+ st.rerun()
868
+
869
+ # --- 管理员面板 ---
870
+ if IS_ADMIN:
871
+ with st.expander("👥 用户管理"):
872
+ all_users = _load_users()
873
+ user_list = [(u, info) for u, info in all_users.items() if isinstance(info, dict)]
874
+
875
+ st.caption(f"共 **{len(user_list)}** 个用户")
876
+ for uname, uinfo in user_list:
877
+ role_tag = "👑" if uinfo.get("role") == "admin" else "👤"
878
+ created = uinfo.get("created_at", "未知")
879
+ st.text(f"{role_tag} {uname}({created})")
880
+
881
+ deletable = [u for u, _ in user_list if u != CURRENT_USER]
882
+ if deletable:
883
+ del_target = st.selectbox("选择要删除的用户", deletable, key="del_user_select")
884
+ if st.button("❌ 删除该用户", key="btn_del_user"):
885
+ _delete_user_db(del_target)
886
+ # 清除该用户的私有库
887
+ _clear_all_chunks(del_target)
888
+ _clear_uploaded_files_storage(del_target)
889
+ st.success(f"用户 {del_target} 已删除")
890
+ time.sleep(0.5)
891
+ st.rerun()
892
+
893
+ resetable = [u for u, _ in user_list if u != CURRENT_USER]
894
+ if resetable:
895
+ reset_target = st.selectbox("选择要重置密码的用户", resetable, key="reset_user_select")
896
+ new_pass = st.text_input("新密码", type="password", key="reset_new_pass")
897
+ if st.button("🔄 重置密码", key="btn_reset_pass"):
898
+ if len(new_pass) < 4:
899
+ st.error("密码至少 4 个字符")
900
+ else:
901
+ target_role = all_users[reset_target].get("role", "user")
902
+ _save_user(reset_target, _hash_password(new_pass), target_role)
903
+ st.success(f"用户 {reset_target} 密码已重置")
904
+ time.sleep(0.5)
905
+ st.rerun()
906
+
907
+ with st.expander("📩 邀请码管理"):
908
+ current_code = _get_invite_code()
909
+ st.text(f"当前邀请码:{current_code if current_code else '未设置'}")
910
+ new_code = st.text_input("新邀请码", key="new_invite_code")
911
+ if st.button("✏️ 更新邀请码", key="btn_update_code"):
912
+ if new_code.strip():
913
+ _set_invite_code(new_code.strip())
914
+ st.success("邀请码已更新")
915
+ time.sleep(0.5)
916
+ st.rerun()
917
+ else:
918
+ st.error("邀请码不能为空")
919
+
920
+ with st.expander("🛠️ 数据库概览"):
921
+ st.caption("Supabase 数据统计")
922
+ try:
923
+ pub_cnt = _count_chunks(PUBLIC_SCOPE)
924
+ st.text(f"📚 公共库切片数: {pub_cnt}")
925
+
926
+ # 统计所有 scope
927
+ resp = _sb().rpc("", {}).execute() if False else None # placeholder
928
+ # 简单统计各用户私有库
929
+ for uname, _ in user_list:
930
+ cnt = _count_chunks(uname)
931
+ if cnt > 0:
932
+ st.text(f"🔒 {uname} 私有库: {cnt} 切片")
933
+ except Exception as e:
934
+ st.warning(f"统计失败: {e}")
935
+
936
+ st.divider()
937
+ st.caption("📋 用户列表")
938
+ display_users = {}
939
+ for k, v in all_users.items():
940
+ if isinstance(v, dict) and "password_hash" in v:
941
+ v_copy = dict(v)
942
+ v_copy["password_hash"] = v_copy["password_hash"][:8] + "..."
943
+ display_users[k] = v_copy
944
+ else:
945
+ display_users[k] = v
946
+ st.json(display_users)
947
+
948
+ # --- 清空聊天记录 ---
949
+ with st.expander("🧹 清空聊天记录"):
950
+ st.caption("清空后不可恢复")
951
+ if st.button("确认清空", use_container_width=True, type="secondary", key="btn_clear_chat"):
952
+ st.session_state.messages = []
953
+ st.rerun()
954
+
955
+
956
+ # =========================
957
+ # 8. 核心搜索逻辑(合并公共库 + 私有库)
958
+ # =========================
959
def _cosine_scores(query_vec, matrix):
    """Return the cosine similarity between ``query_vec`` and every row of ``matrix``.

    A query vector whose norm is below 1e-10 yields an all-zero score vector
    (one entry per matrix row). Row norms are clamped to at least 1e-10 so the
    division never hits zero.
    """
    np = _get_np()
    eps = 1e-10

    q_len = np.linalg.norm(query_vec)
    if q_len < eps:
        return np.zeros(matrix.shape[0])

    row_lens = np.maximum(np.linalg.norm(matrix, axis=1), eps)
    dots = matrix @ query_vec
    return dots / (row_lens * q_len)
967
+
968
+
969
def search_local(query, top_k, threshold):
    """Return up to ``top_k`` documents whose similarity to ``query`` exceeds ``threshold``.

    The public and private document sets are scored separately against the
    encoded query, the hits are merged, and the merged list is ordered by
    descending score before being truncated.
    """
    q_vec = encode_query(query)
    hits = []

    # Public set first, then the private set; each pairs a session-state key
    # holding the documents with the scope name of its embedding matrix.
    for docs_key, scope in (("public_docs", "public"), ("private_docs", "private")):
        docs = st.session_state.get(docs_key, [])
        emb = _get_embeddings_np(scope)
        if not docs or emb.size <= 0:
            continue

        sims = _cosine_scores(q_vec, emb)
        hits.extend(
            (float(sim), docs[pos]) for pos, sim in enumerate(sims) if sim > threshold
        )

    hits.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in hits[:top_k]]
991
+
992
+
993
+ # =========================
994
+ # 9. LLM 回答逻辑
995
+ # =========================
996
def llm_answer(query, context_docs, selected_display_name, web_enabled):
    """Stream an answer to ``query`` as text fragments, with model fallback.

    Builds a single prompt from the retrieved documents (plus the output of
    ``web_search`` when ``web_enabled`` is truthy), then tries the selected
    model followed by the backup models until one of them streams content.
    A summary line for the turn is stored in ``st.session_state["last_meta"]``.

    Args:
        query: The user's question.
        context_docs: Retrieved text chunks to use as reference; may be empty.
        selected_display_name: Key into ``model_mapping`` naming the preferred model.
        web_enabled: When truthy, ``web_search(query)`` output is appended to
            the reference material.

    Yields:
        Text fragments of the answer. If a route fails after it has already
        streamed some text, a short interruption notice is yielded and the
        generator stops. If every route fails before producing text, a single
        apology message is yielded.
    """
    # Reset per-turn metadata so a failed turn never shows the previous
    # turn's caption.
    st.session_state["last_meta"] = ""

    all_context = ""
    curr_time = datetime.now().strftime("%Y-%m-%d %H:%M")

    if context_docs:
        all_context += "【知识库资料】:\n" + "\n".join(context_docs) + "\n"

    if web_enabled:
        search_res = web_search(query)
        all_context += f"\n【互联网实时资料】:\n{search_res}"

    # The reference material is capped at 6500 characters before being embedded
    # in the prompt.
    prompt_content = f"当前时间:{curr_time}\n\n参考资料:\n{all_context[:6500]}\n\n用户问题:{query}"
    input_tokens = estimate_tokens(prompt_content)

    or_client = get_or_client()
    ds_client = get_ds_client()
    baidu_client = get_baidu_client()

    # Models served by a dedicated client; everything else goes through or_client.
    special_clients = {"deepseek-chat": ds_client, "ernie-3.5-8k": baidu_client}
    selected_id = model_mapping[selected_display_name]

    # Ordered fallback chain: selected model, free backups, then paid backups.
    retry_queue = [
        (special_clients.get(selected_id, or_client), selected_id, f"首选-{selected_display_name}")
    ]

    if selected_id != "stepfun/step-3.5-flash:free":
        retry_queue.append((or_client, "stepfun/step-3.5-flash:free", "⚡ 快速备选-Step3.5"))

    if selected_id != "openrouter/free":
        retry_queue.append((or_client, "openrouter/free", "OR-Auto 免费避堵"))

    paid_backups = [
        ("deepseek-chat", "🛡️ DeepSeek 官方", ds_client),
        ("ernie-3.5-8k", "🏢 百度文心", baidu_client),
    ]
    for p_id, p_label, p_client in paid_backups:
        if selected_id != p_id:
            retry_queue.append((p_client, p_id, f"💰 收费兜底-{p_label}"))

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt_content},
    ]

    last_idx = len(retry_queue) - 1
    for idx, (client, m_id, label) in enumerate(retry_queue):
        logger.info("[%s] 尝试链路: %s", CURRENT_USER, label)
        full_text = ""
        try:
            # The extra headers are only sent on requests made through or_client.
            extra_h = (
                {"HTTP-Referer": "https://streamlit.io", "X-Title": "RAG_v3"}
                if client is or_client
                else None
            )
            response = client.chat.completions.create(
                model=m_id,
                messages=messages,
                stream=True,
                extra_headers=extra_h,
                timeout=25,
            )

            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_text += content
                    yield content

        except Exception as e:
            err_msg = str(e)
            logger.warning("%s 失败: %s", label, err_msg[:100])

            if full_text:
                # Text already reached the UI and cannot be retracted; falling
                # back to another model now would append a second answer to
                # the partial one, so stop here instead.
                st.session_state["last_meta"] = (
                    f"🟡 {label} | ⚠️ 输出中断 | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens"
                )
                yield "\n\n⚠️ 回答生成中断,请重试。"
                return

            if "429" in err_msg:
                st.toast(f"{label} 拥堵,切换备选...", icon="⏳")
            # Back off before the next route; pointless after the last one.
            if idx < last_idx:
                time.sleep(1.5)
            continue

        if full_text:
            st.session_state["last_meta"] = (
                f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens"
            )
            return
        # The stream ended without any content: fall through to the next route.

    yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
1081
+
1082
+
1083
+ # =========================
1084
+ # 10. 聊天渲染
1085
+ # =========================
1086
# Make sure the chat history exists before rendering it.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation, including each message's metadata caption
# when it has one.
for m in st.session_state.messages:
    with st.chat_message(m["role"]):
        st.markdown(m["content"])
        if "meta" in m:
            st.caption(m["meta"])

if q := st.chat_input("输入问题...", key="chat_input_v3"):
    st.session_state.messages.append({"role": "user", "content": q})
    with st.chat_message("user"):
        st.markdown(q)

    with st.chat_message("assistant"):
        # Show the placeholder before retrieval starts so the user gets
        # immediate feedback.
        container = st.empty()
        container.markdown("*🤔 正在组织语言...*")

        try:
            # Retrieval runs inside the try so a failure there is reported
            # through the same friendly error path as a model failure.
            relevant_docs = search_local(q, ui_top_k, ui_threshold)

            if web_on:
                with st.status("🌐 正在抓取实时网络数据...", expanded=False) as s:
                    time.sleep(0.1)
                    s.update(label="✅ 网络资料已就绪", state="complete")

            # Clear the previous turn's metadata so it cannot be attached to
            # this answer.
            st.session_state["last_meta"] = ""
            full_response = container.write_stream(
                llm_answer(q, relevant_docs, selected_display_name, web_on)
            )
            meta_info = st.session_state.get("last_meta", "")
            st.caption(meta_info)
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response, "meta": meta_info}
            )
        except Exception as e:
            # logger.exception keeps the traceback for later diagnosis.
            logger.exception("模型调用异常: %s", e)
            container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")

_perf("script execution complete")
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ openai
4
+ langchain-text-splitters
5
+ python-docx
6
+ tavily-python
7
+ pdfplumber
8
+ supabase