import os

# Improve Streamlit connection stability. These must be set before streamlit is imported.
os.environ.setdefault("STREAMLIT_SERVER_ENABLE_WEBSOCKET_COMPRESSION", "false")
os.environ.setdefault("STREAMLIT_SERVER_MAX_MESSAGE_SIZE", "200")

import streamlit as st
import streamlit.components.v1 as components
import json
import time
import logging
import hashlib
import hmac
import uuid
from datetime import datetime

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# numpy is imported lazily so start-up does not pay for it until embeddings are needed.
_np_module = None


def _get_np():
    """Return the numpy module, importing it on first use."""
    global _np_module
    if _np_module is None:
        import numpy
        _np_module = numpy
    return _np_module


# =========================
# 1. Page config & style injection
# =========================
st.set_page_config(page_title="智答 AI 助手", page_icon="🤖", layout="wide", initial_sidebar_state="collapsed")


@st.cache_data
def _get_custom_css():
    """Return the custom CSS/HTML snippet injected into the page.

    NOTE(review): the snippet is empty in this copy of the file — confirm the intended markup.
    """
    return """ """


def inject_custom_css():
    """Inject the custom CSS into the page through st.markdown (HTML allowed)."""
    st.markdown(_get_custom_css(), unsafe_allow_html=True)


inject_custom_css()


def _render_action_buttons(content, msg_id):
    """Build the HTML for the copy/share/read-aloud buttons under a message (markup only)."""
    return f'''
'''


def _inject_action_js(content, msg_id):
    """Inject the action-button JavaScript; it must be executed through components.html."""
    # Escape the content for a JS template literal. A literal "</script>" must be
    # neutralised or it would terminate the enclosing <script> element early.
    escaped = (
        content.replace("\\", "\\\\")
        .replace("`", "\\`")
        .replace("$", "\\$")
        .replace("</script>", "<\\/script>")
    )
    # NOTE(review): the script body is empty in this copy, so `escaped` is unused here.
    js_code = f''' '''
    components.html(js_code, height=0)


# Global guest flag (session_state does not change within one script run).
IS_GUEST = not bool(st.session_state.get("current_user"))

# Title pinned at the top; guests always see the hint line.
_guest_tip = """
🙋 当前为游客模式,可直接提问体验
""" if IS_GUEST else ""
# Different CSS classes for guests and users so the stylesheet can control spacing.
_title_class = "main-title main-title-guest" if IS_GUEST else "main-title main-title-user"
st.markdown(
    f"""

🤖 智答 AI 助手

{_guest_tip}
""",
    unsafe_allow_html=True,
)

# =========================
# 1.5 Supabase client initialisation
# =========================
from supabase import create_client

SUPABASE_URL = st.secrets.get("SUPABASE_URL", "")
SUPABASE_KEY = st.secrets.get("SUPABASE_KEY", "")  # service_role key (server-side use)
STORAGE_BUCKET = "rag-files"

if not SUPABASE_URL or not SUPABASE_KEY:
    st.error("⚠️ 未配置 SUPABASE_URL 或 SUPABASE_KEY,请在 Secrets 中设置。")
    st.stop()


@st.cache_resource
def _get_supabase():
    """Create the Supabase client once per process."""
    return create_client(SUPABASE_URL, SUPABASE_KEY)


def _sb():
    """Shortcut returning the shared Supabase client."""
    return _get_supabase()


# =========================
# 2. User management (Supabase `users` table)
# =========================
MAX_LOGIN_ATTEMPTS = 10

# Usernames that collide with storage scopes or the guest identity. The private
# scope IS the username, so a user named "public" would own the shared library.
_RESERVED_USERNAMES = frozenset({"public", "_guest_"})


def _hash_password(password):
    """Return the hex SHA-256 digest of *password* encoded as UTF-8.

    NOTE(review): unsalted SHA-256 is weak for password storage. Moving to a salted
    KDF (e.g. hashlib.pbkdf2_hmac) requires a stored-format migration — TODO.
    """
    return hashlib.sha256(password.encode("utf-8")).hexdigest()


@st.cache_data(ttl=30)  # cache for 30 s to avoid repeated queries
def _load_users():
    """Load the users table as {username: {password_hash, role, created_at}}.

    Returns {} on any error. NOTE(review): that empty result is cached for the TTL as well.
    """
    try:
        resp = _sb().table("users").select("*").execute()
        users = {}
        for row in resp.data:
            users[row["username"]] = {
                "password_hash": row["password_hash"],
                "role": row["role"],
                "created_at": row["created_at"][:16] if row.get("created_at") else "未知",
            }
        return users
    except Exception as e:
        logger.error(f"加载用户表失败: {e}")
        return {}


def _insert_user(username, password_hash, role="user"):
    """Insert a brand-new user row and invalidate the user cache.

    Unlike `_save_user` (an upsert), this raises if the username already exists,
    so it can never overwrite an account even when the cached user list is stale.
    """
    _sb().table("users").insert({
        "username": username,
        "password_hash": password_hash,
        "role": role,
    }).execute()
    _load_users.clear()


def _is_unique_violation(exc):
    """Heuristically detect a Postgres unique-constraint violation (SQLSTATE 23505)."""
    return getattr(exc, "code", None) == "23505" or "23505" in str(exc)


def _ensure_admin():
    """On first run, create the admin account from secrets when the users table is empty."""
    users = _load_users()
    if users:
        return
    admin_user = st.secrets.get("ADMIN_USER", "admin")
    admin_pass = st.secrets.get("ADMIN_PASSWORD", "")
    if not admin_pass:
        return
    try:
        # insert, not upsert: if the empty result came from a failed read, an existing
        # admin (possibly with a changed password) must not be silently reset.
        _insert_user(admin_user, _hash_password(admin_pass), "admin")
        logger.info(f"初始管理员 {admin_user} 已创建")
    except Exception as e:
        logger.error(f"创建管理员失败: {e}")


# Check for the admin only once per session (avoids a query on every rerun).
if not st.session_state.get("_admin_ensured"):
    _ensure_admin()
    st.session_state["_admin_ensured"] = True


def _save_user(username, password_hash, role="user"):
    """Insert or update (upsert) a single user, then invalidate the user cache."""
    _sb().table("users").upsert({
        "username": username,
        "password_hash": password_hash,
        "role": role,
    }).execute()
    _load_users.clear()  # invalidate the cached user list


def _delete_user_db(username):
    """Delete the user row for *username*, then invalidate the user cache."""
    _sb().table("users").delete().eq("username", username).execute()
    _load_users.clear()  # invalidate the cached user list


@st.cache_data(ttl=60)  # cache for 60 s
def _get_invite_code():
    """Read the invite code from app_meta, falling back to the INVITE_CODE secret."""
    try:
        resp = _sb().table("app_meta").select("value").eq("key", "invite_code").execute()
        if resp.data:
            return resp.data[0]["value"]
    except Exception as e:
        logger.warning(f"读取邀请码失败,使用 secrets 配置: {e}")
    return st.secrets.get("INVITE_CODE", "")


def _set_invite_code(new_code):
    """Persist a new invite code and invalidate its cache."""
    _sb().table("app_meta").upsert({"key": "invite_code", "value": new_code}).execute()
    _get_invite_code.clear()  # invalidate cache


def register_user(username, password, invite_code):
    """Validate input and create an ordinary user.

    Returns:
        (ok, message): ok is True on success; message is user-facing text.
    """
    if not username or not password:
        return False, "用户名和密码不能为空"
    if len(username) < 2 or len(username) > 20:
        return False, "用户名长度需要 2-20 个字符"
    if len(password) < 4:
        return False, "密码至少 4 个字符"
    if username.startswith("__"):
        return False, "用户名不能以 __ 开头"
    if username in _RESERVED_USERNAMES:
        return False, "该用户名为系统保留,请更换"
    correct_code = _get_invite_code()
    if not correct_code:
        return False, "邀请码未配置,请联系管理员"
    if invite_code != correct_code:
        return False, "邀请码错误"
    users = _load_users()
    if username in users:
        return False, "用户名已存在"
    try:
        # insert (not upsert) so a stale cache can never let this overwrite an account.
        _insert_user(username, _hash_password(password), "user")
        logger.info(f"新用户注册: {username}")
        return True, "注册成功,请登录"
    except Exception as e:
        if _is_unique_violation(e):
            return False, "用户名已存在"
        logger.error(f"注册失败: {e}")
        return False, f"注册失败: {e}"


def verify_user(username, password):
    """Check credentials.

    Returns:
        (ok, role, reason): reason is "" on success, otherwise "not_found" or "wrong_password".
    """
    users = _load_users()
    user_info = users.get(username)
    if not user_info or not isinstance(user_info, dict):
        return False, None, "not_found"
    stored_hash = str(user_info.get("password_hash") or "")
    # Constant-time comparison so response timing does not reveal partial matches.
    if not hmac.compare_digest(stored_hash.encode("utf-8"), _hash_password(password).encode("utf-8")):
        return False, None, "wrong_password"
    return True, user_info.get("role", "user"), ""


# --- Auth UI ---
if "login_attempts" not in st.session_state:
    st.session_state.login_attempts = 0
if "current_user" not in st.session_state:
    st.session_state.current_user = None
if "current_role" not in st.session_state:
    st.session_state.current_role = None
if "auth_mode" not in st.session_state:
    st.session_state.auth_mode = "login"

with st.sidebar:
    # When logged out, show the form directly rather than in an expander
    # (a manually collapsed expander would not re-open by itself).
    if IS_GUEST:
        st.subheader("🔑 账号")
        # NOTE(review): the attempt counter lives in session_state, so a page refresh resets it.
        if st.session_state.login_attempts >= MAX_LOGIN_ATTEMPTS:
            st.error("🚫 尝试次数过多,请刷新页面后重试。")
            st.stop()
        users_data = _load_users()
        if not users_data:
            st.error("⚠️ 未配置管理员,请在 secrets 中设置 ADMIN_USER 和 ADMIN_PASSWORD。")
            st.stop()
        auth_mode = st.radio(
            "操作", ["登录", "注册"], horizontal=True,
            label_visibility="collapsed", key="auth_radio",
        )
        if auth_mode == "登录":
            with st.form("login_form", clear_on_submit=False):
                input_username = st.text_input("用户名", key="login_user")
                input_password = st.text_input("密码", type="password", key="login_pass")
                submitted = st.form_submit_button("🔓 登录", use_container_width=True, type="primary")
                if submitted:
                    if input_username == "" or input_password == "":
                        st.warning("请输入用户名和密码")
                    else:
                        ok, role, reason = verify_user(input_username, input_password)
                        if not ok:
                            st.session_state.login_attempts += 1
                            remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
                            if reason == "not_found":
                                st.warning(f"⚠️ 用户不存在,请先注册(剩余 {remaining} 次)")
                            else:
                                st.warning(f"⚠️ 密码错误(剩余 {remaining} 次)")
                        else:
                            st.session_state.login_attempts = 0
                            st.session_state.current_user = input_username
                            st.session_state.current_role = role
                            st.rerun()
        else:  # registration
            with st.form("register_form", clear_on_submit=False):
                reg_user = st.text_input("用户名", key="reg_user")
                reg_pass = st.text_input("密码", type="password", key="reg_pass")
                reg_pass2 = st.text_input("确认密码", type="password", key="reg_pass2")
                reg_code = st.text_input("邀请码", type="password", key="reg_code")
                reg_submitted = st.form_submit_button("注册", use_container_width=True)
                if reg_submitted:
                    if reg_pass != reg_pass2:
                        st.error("两次密码不一致")
                    else:
                        ok, msg = register_user(reg_user, reg_pass, reg_code)
                        if ok:
                            st.session_state.current_user = reg_user
                            st.session_state.current_role = "user"
                            st.session_state.login_attempts = 0
                            st.rerun()
                        else:
                            st.error(f"❌ {msg}")
    else:
        with st.expander("🔑 账号"):
            role_label = "管理员" if st.session_state.current_role == "admin" else "普通用户"
            st.success(f"✅ {st.session_state.current_user}({role_label})")

# Global identity flags used by the rest of the script.
if st.session_state.get("current_user"):
    CURRENT_USER = st.session_state.current_user
    IS_ADMIN = st.session_state.current_role == "admin"
    IS_GUEST = False
else:
    CURRENT_USER = "_guest_"
    IS_ADMIN = False
    IS_GUEST = True

# Collapse the sidebar once after login (the injected JS clicks the collapse button).
if not IS_GUEST and not st.session_state.get("_sidebar_collapsed_once"):
    st.session_state["_sidebar_collapsed_once"] = True
    components.html(
        """ """,
        height=0,
    )

# =========================
# 3. Secrets & embedding strategy
# =========================
TAVILY_KEY = st.secrets.get("TAVILY_API_KEY", "")
DS_API_KEY = st.secrets.get("DEEPSEEK_API_KEY", "")
BAIDU_TOKEN = st.secrets.get("BAIDU_BEARER_TOKEN", "")
BAIDU_APP_ID = st.secrets.get("BAIDU_APP_ID", "")
OR_KEY = st.secrets.get("OPENROUTER_API_KEY", "")


@st.cache_resource
def _get_embedding_client():
    """Return (client, model_name) for the embedding API, or (None, None) if none is configured.

    Baidu Qianfan is preferred; OpenRouter is the fallback.
    """
    from openai import OpenAI
    if BAIDU_TOKEN and BAIDU_APP_ID:
        return OpenAI(
            api_key=BAIDU_TOKEN,
            base_url="https://qianfan.baidubce.com/v2",
            default_headers={"appid": BAIDU_APP_ID},
        ), "qwen3-embedding-0.6b"  # 8192-token context
    if OR_KEY:
        return OpenAI(
            api_key=OR_KEY,
            base_url="https://openrouter.ai/api/v1",
        ), "BAAI/bge-small-zh"
    return None, None


def _api_encode(texts):
    """Embed *texts* through the remote API in batches of 32.

    Returns a list of numpy vectors aligned with *texts*, or None on any failure
    (including a response whose size does not match the input).
    """
    np = _get_np()
    client, model = _get_embedding_client()
    if client is None:
        return None
    try:
        batch_size = 32
        all_vecs = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            resp = client.embeddings.create(model=model, input=batch)
            all_vecs.extend([np.array(item.embedding) for item in resp.data])
        if len(all_vecs) != len(texts):
            # A short response would misalign vectors with their texts.
            logger.warning(f"API embedding 返回数量不匹配: {len(all_vecs)} != {len(texts)}")
            return None
        return all_vecs
    except Exception as e:
        logger.warning(f"API embedding 失败,回退到本地模型: {e}")
        return None


@st.cache_resource(show_spinner="API 不可用,正在加载本地向量模型(仅首次)...")
def _load_local_model():
    """Load the local SentenceTransformer once per process so all sessions share it."""
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer("BAAI/bge-small-zh")


def _get_local_model():
    """Return the shared local embedding model, or None if sentence-transformers is missing."""
    try:
        return _load_local_model()
    except ImportError:
        logger.error("sentence-transformers 未安装,本地模型不可用")
        return None


def encode_texts(texts):
    """Embed a string or a list of strings: the API first, the local model as fallback.

    Returns [] when *texts* is empty or no embedding backend works.
    """
    if not texts:
        return []
    if isinstance(texts, str):
        texts = [texts]
    result = _api_encode(texts)
    if result is not None:
        return result
    model = _get_local_model()
    if model is None:
        st.error("❌ Embedding 服务不可用:API 调用失败且本地模型未安装。请检查 API Key 配置。")
        return []
    return list(model.encode(texts))


def encode_query(text):
    """Embed a single query string; returns None when no embedding could be produced."""
    vecs = encode_texts([text])
    # encode_texts returns [] on failure; indexing it would raise IndexError.
    return vecs[0] if vecs else None


# =========================
# 4. Supabase index management
# =========================
def _save_chunks_to_db(scope, chunks, vectors, source_file):
    """Bulk-insert chunk rows (content + embedding) for one source file into `documents`."""
    rows = [
        {
            "scope": scope,
            "source_file": source_file,
            "content": content,
            "embedding": vec.tolist() if hasattr(vec, "tolist") else list(vec),
        }
        for content, vec in zip(chunks, vectors)
    ]
    # Insert in batches of at most 500 rows per request.
    batch_size = 500
    for i in range(0, len(rows), batch_size):
        _sb().table("documents").insert(rows[i:i + batch_size]).execute()


def _delete_chunks_by_file(scope, filename):
    """Delete every chunk of *filename* within *scope*."""
    _sb().table("documents").delete().eq("scope", scope).eq("source_file", filename).execute()


def _clear_all_chunks(scope):
    """Delete every chunk in *scope*."""
    _sb().table("documents").delete().eq("scope", scope).execute()


def _count_chunks(scope):
    """Return the number of chunks in *scope*, or 0 if the query fails."""
    try:
        resp = _sb().table("documents").select("id", count="exact").eq("scope", scope).execute()
        return resp.count or 0
    except Exception as e:
        logger.warning(f"统计切片失败 [scope={scope}]: {e}")
        return 0


# --- Raw file management (Supabase Storage + uploaded_files table) ---
def _safe_storage_path(scope, filename):
    """Build an ASCII-safe Storage path (Supabase Storage rejects CJK/special characters)."""
    ext = os.path.splitext(filename)[1].lower()  # keep the extension
    safe_name = uuid.uuid4().hex[:16] + ext
    return f"{scope}/{safe_name}"


def _save_uploaded_file_to_storage(scope, uploaded_file):
    """Upload the raw file to Supabase Storage and upsert its uploaded_files metadata row."""
    uploaded_file.seek(0)
    file_bytes = uploaded_file.read()
    # Reuse the storage_path of an existing record with the same name, if any.
    try:
        existing = _sb().table("uploaded_files").select("storage_path").eq(
            "scope", scope).eq("filename", uploaded_file.name).execute()
        if existing.data:
            storage_path = existing.data[0]["storage_path"]
        else:
            storage_path = _safe_storage_path(scope, uploaded_file.name)
    except Exception:
        storage_path = _safe_storage_path(scope, uploaded_file.name)
    # Upload to Storage, overwriting any existing object.
    try:
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path, file_bytes,
            file_options={"content-type": "application/octet-stream", "upsert": "true"}
        )
    except Exception as e:
        logger.warning(f"Storage upload fallback: {e}")
        try:
            _sb().storage.from_(STORAGE_BUCKET).remove([storage_path])
        except Exception:
            pass  # best effort: the object may simply not exist yet
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path, file_bytes,
            file_options={"content-type": "application/octet-stream"}
        )
    # Record metadata (filename keeps the original, possibly non-ASCII, name).
    _sb().table("uploaded_files").upsert({
        "scope": scope,
        "filename": uploaded_file.name,
        "file_size": len(file_bytes),
        "storage_path": storage_path,
    }, on_conflict="scope,filename").execute()
    _list_uploaded_files_db.clear()  # invalidate the file-list cache


@st.cache_data(ttl=15)  # cache for 15 s; the file list changes rarely
def _list_uploaded_files_db(scope):
    """List uploaded files in *scope* as [(filename, size_str, storage_path), ...]."""
    try:
        resp = _sb().table("uploaded_files").select(
            "filename, file_size, storage_path"
        ).eq("scope", scope).order("filename").execute()
        result = []
        for row in resp.data:
            size = row.get("file_size", 0) or 0
            if size < 1024:
                size_str = f"{size}B"
            elif size < 1048576:
                size_str = f"{size / 1024:.1f}KB"
            else:
                size_str = f"{size / 1048576:.1f}MB"
            result.append((row["filename"], size_str, row.get("storage_path", "")))
        return result
    except Exception as e:
        logger.error(f"列出文件失败 [scope={scope}]: {e}")
        return []


def _delete_uploaded_file_from_storage(scope, filename):
    """Delete the Storage object and the uploaded_files row for *filename* in *scope*."""
    try:
        # Look up the real storage_path from the database.
        resp = _sb().table("uploaded_files").select("storage_path").eq(
            "scope", scope).eq("filename", filename).execute()
        if resp.data:
            storage_path = resp.data[0]["storage_path"]
            _sb().storage.from_(STORAGE_BUCKET).remove([storage_path])
    except Exception as e:
        logger.warning(f"Storage 删除失败: {e}")
    try:
        _sb().table("uploaded_files").delete().eq("scope", scope).eq("filename", filename).execute()
        _list_uploaded_files_db.clear()  # invalidate the file-list cache
    except Exception as e:
        logger.warning(f"uploaded_files 记录删除失败: {e}")


def _clear_uploaded_files_storage(scope):
    """Delete every uploaded file (Storage objects and metadata rows) in *scope*."""
    try:
        resp = _sb().table("uploaded_files").select("storage_path").eq("scope", scope).execute()
        paths = [row["storage_path"] for row in resp.data]
        if paths:
            _sb().storage.from_(STORAGE_BUCKET).remove(paths)
        _sb().table("uploaded_files").delete().eq("scope", scope).execute()
        _list_uploaded_files_db.clear()  # invalidate the file-list cache
    except Exception as e:
        logger.warning(f"清空文件失败 [scope={scope}]: {e}")


PUBLIC_SCOPE = "public"
PRIVATE_SCOPE = CURRENT_USER  # the private library scope is the username

# --- Periodic sync: notice other users' changes to the libraries ---
_SYNC_INTERVAL = 120  # seconds between checks


def _check_and_sync():
    """At most every _SYNC_INTERVAL seconds, refresh the cached chunk counts and log changes."""
    now = time.time()
    last_check = st.session_state.get("_sync_last_check", 0)
    if now - last_check < _SYNC_INTERVAL:
        return
    st.session_state["_sync_last_check"] = now
    for scope, label in [(PUBLIC_SCOPE, "public"), (PRIVATE_SCOPE, "private")]:
        count_key = f"_sync_count_{label}"
        current_count = _count_chunks(scope)
        prev_count = st.session_state.get(count_key, -1)
        if prev_count >= 0 and current_count != prev_count:
            logger.info(f"[SYNC] {label} 库变更: {prev_count} -> {current_count}")
        st.session_state[count_key] = current_count


_check_and_sync()


# --- For logged-in users, pick the default mode from knowledge-base content ---
def _has_any_kb_content():
    """Return True if the public or the private library holds any chunks (counts cached in session)."""
    pub_count = st.session_state.get("_sync_count_public")
    if pub_count is None:
        pub_count = _count_chunks(PUBLIC_SCOPE)
        st.session_state["_sync_count_public"] = pub_count
    priv_count = st.session_state.get("_sync_count_private")
    if priv_count is None:
        priv_count = _count_chunks(PRIVATE_SCOPE)
        st.session_state["_sync_count_private"] = priv_count
    return (pub_count + priv_count) > 0


if not IS_GUEST and "sel_web" not in st.session_state:
    # First load: default to knowledge-base mode if any content exists, otherwise web mode.
    st.session_state["sel_web"] = not _has_any_kb_content()


# =========================
# 5. Cached LLM clients
# =========================
@st.cache_resource
def get_or_client():
    """OpenRouter chat client (shared per process)."""
    from openai import OpenAI
    return OpenAI(api_key=OR_KEY, base_url="https://openrouter.ai/api/v1")


@st.cache_resource
def get_ds_client():
    """DeepSeek official chat client (shared per process)."""
    from openai import OpenAI
    return OpenAI(api_key=DS_API_KEY, base_url="https://api.deepseek.com")


@st.cache_resource
def get_baidu_client():
    """Baidu Qianfan chat client (shared per process)."""
    from openai import OpenAI
    return OpenAI(
        api_key=BAIDU_TOKEN,
        base_url="https://qianfan.baidubce.com/v2",
        default_headers={"appid": BAIDU_APP_ID},
    )


# =========================
# 6. Utility functions
# =========================
_text_splitter_cache = None


def _get_text_splitter():
    """Return a lazily built RecursiveCharacterTextSplitter (400 chars, 50 overlap)."""
    global _text_splitter_cache
    if _text_splitter_cache is None:
        from langchain_text_splitters import RecursiveCharacterTextSplitter
        _text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    return _text_splitter_cache


SYSTEM_PROMPT = (
    "你是一个专业的知识问答助手。请基于提供的参考资料回答用户问题。"
    "如果资料中没有相关信息,请诚实说明。回答要准确、有条理、简洁。"
    "不要编造不在资料中的信息。"
)
SYSTEM_PROMPT_WEB = (
    "你是一个智能AI助手,擅长结合互联网搜索结果回答用户问题。"
    "请根据搜索结果提供准确、有条理的回答。"
    "如果搜索结果不足以回答问题,请结合你自身的知识进行补充。"
    "回答要简洁实用,注明信息来源。"
)
SYSTEM_PROMPT_DIRECT = (
    "你是一个智能AI助手。请直接回答用户的问题。"
    "回答要准确、有条理、简洁实用。"
    "可以充分发挥你的知识和能力来帮助用户。"
)


def web_search(query):
    """Search the web via Tavily (top 3 results) and return a text digest of at most 2500 chars."""
    if not TAVILY_KEY:
        return "⚠️ 未配置搜索 Key"
    from tavily import TavilyClient
    tavily = TavilyClient(api_key=TAVILY_KEY)
    current_year = datetime.now().year
    try:
        search_result = tavily.search(
            query=f"{current_year}年 {query}",
            search_depth="advanced",
            max_results=3,
        )
        results = [
            f"来源: {r.get('url')}\n内容: {r.get('content', '')[:700]}"
            for r in search_result["results"]
        ]
        return "\n\n".join(results)[:2500]
    except Exception as e:
        logger.error(f"联网搜索异常: {e}")
        return f"联网搜索异常:{str(e)}"


def estimate_tokens(text):
    """Rough token estimate: 1.5 per CJK ideograph, 0.4 per other character."""
    if not text:
        return 0
    zh_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
    return int(zh_count * 1.5 + (len(text) - zh_count) * 0.4)


def _decode_text_bytes(raw):
    """Decode plain-text bytes: UTF-8 (BOM-aware), then GB18030, finally lossy UTF-8."""
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="ignore")


def extract_text(file):
    """Extract plain text from an uploaded .txt/.pdf/.docx file; returns "" if unsupported or on error."""
    fname = file.name.lower()
    text = ""
    try:
        if fname.endswith(".txt"):
            file.seek(0)
            # Chinese-locale text files are often GBK/GB18030 rather than UTF-8.
            text = _decode_text_bytes(file.read())
        elif fname.endswith(".pdf"):
            import io
            file.seek(0)
            pdf_bytes = file.read()
            if len(pdf_bytes) < 100:
                raise ValueError("PDF 文件过小,可能已损坏")
            if not pdf_bytes[:5] == b"%PDF-":
                raise ValueError("不是有效的 PDF 文件(缺少 %PDF- 头)")
            pdf_stream = io.BytesIO(pdf_bytes)
            pages_text = []
            import pdfplumber
            with pdfplumber.open(pdf_stream) as pdf:
                for i, page in enumerate(pdf.pages):
                    try:
                        page_text = page.extract_text() or ""
                        pages_text.append(page_text)
                    except Exception as page_err:
                        # One bad page should not discard the rest of the document.
                        logger.warning(f"PDF 第{i+1}页解析失败: {page_err}")
                        pages_text.append("")
            text = "\n".join(pages_text)
        elif fname.endswith(".docx"):
            from docx import Document
            file.seek(0)
            doc = Document(file)
            text = "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        logger.error(f"文件解析失败 [{file.name}]: {e}", exc_info=True)
        st.error(f"解析失败: {e}")
    return text


def process_upload(uploaded_files, target_prefix, scope):
    """Parse, chunk, embed and store uploaded files in *scope*.

    *target_prefix* ("public"/"private") selects the session keys for the uploader
    version and the cached chunk count. On success this calls st.rerun(); otherwise
    it returns False.
    """
    if not uploaded_files:
        return False
    try:
        all_new_chunks = []
        all_new_sources = []
        with st.spinner("正在自动解析文档并更新索引..."):
            for f in uploaded_files:
                try:
                    # Keep the original file in Supabase Storage.
                    f.seek(0)
                    _save_uploaded_file_to_storage(scope, f)
                    # Parse the text.
                    f.seek(0)
                    raw_text = extract_text(f)
                    if not raw_text.strip():
                        st.warning(f"文件 {f.name} 内容为空,已跳过。")
                        continue
                    chunks = _get_text_splitter().split_text(raw_text)
                    all_new_chunks.extend(chunks)
                    all_new_sources.extend([f.name] * len(chunks))
                except Exception as file_err:
                    logger.error(f"文件 {f.name} 处理失败: {file_err}", exc_info=True)
                    st.warning(f"⚠️ 文件 {f.name} 处理失败:{str(file_err)[:100]},已跳过。")

            if all_new_chunks:
                # Encode in batches.
                batch_size = 64
                all_vecs = []
                for i in range(0, len(all_new_chunks), batch_size):
                    batch = all_new_chunks[i:i + batch_size]
                    vecs = encode_texts(batch)
                    # A failed batch returns []; continuing would pair later vectors
                    # with the wrong chunks, so abort the whole import instead.
                    if len(vecs) != len(batch):
                        st.error("❌ 向量编码失败,本次导入已取消。")
                        return False
                    all_vecs.extend(vecs)
                # Group by source file before writing to Supabase.
                file_groups = {}
                for chunk, vec, src in zip(all_new_chunks, all_vecs, all_new_sources):
                    group_chunks, group_vecs = file_groups.setdefault(src, ([], []))
                    group_chunks.append(chunk)
                    group_vecs.append(vec)
                for src_file, (chunks, vecs) in file_groups.items():
                    # A same-name upload overwrites the stored file, so replace its
                    # chunks as well rather than accumulating duplicates.
                    _delete_chunks_by_file(scope, src_file)
                    _save_chunks_to_db(scope, chunks, vecs, src_file)
                # Bump the uploader key so the widget resets.
                ukey = f"_upload_ver_{target_prefix}"
                st.session_state[ukey] = st.session_state.get(ukey, 0) + 1
                # Refresh the cached chunk count immediately.
                st.session_state[f"_sync_count_{target_prefix}"] = _count_chunks(scope)
                # Invalidate the file-list cache.
                _list_uploaded_files_db.clear()
                # Switch to knowledge-base mode after a successful upload.
                st.session_state["sel_web"] = False
                st.toast(f"✅ 导入 {len(all_new_chunks)} 个切片")
                st.rerun()
            else:
                st.error("解析失败,未发现有效文字内容。")
    except Exception as e:
        logger.error(f"上传处理异常: {e}", exc_info=True)
        st.error(f"❌ 上传处理出错:{str(e)[:200]}")
    return False


# =========================
# 6.5 Chat-history persistence (Supabase chat_history table)
# =========================
import threading


def _async_run(fn, *args):
    """Run fn(*args) in a daemon thread so the UI is not blocked."""
    t = threading.Thread(target=fn, args=args, daemon=True)
    t.start()


def _save_chat_message_sync(username, role, content, meta=""):
    """Insert one chat message row synchronously (internal; errors are logged)."""
    try:
        _sb().table("chat_history").insert({
            "username": username,
            "role": role,
            "content": content,
            "meta": meta or "",
        }).execute()
    except Exception as e:
        logger.warning(f"保存聊天记录失败: {e}")


def _save_chat_message(username, role, content, meta=""):
    """Save a chat message in a background daemon thread so the UI is not blocked."""
    t = threading.Thread(
        target=_save_chat_message_sync,
        args=(username, role, content, meta),
        daemon=True,
    )
    t.start()


@st.cache_data(ttl=10)  # cache for 10 s to reduce repeated queries
def _load_chat_history(username, limit=50):
    """Return the user's latest *limit* messages, oldest first; [] on error."""
    try:
        resp = _sb().table("chat_history").select(
            "role, content, meta, created_at"
        ).eq("username", username).order(
            "created_at", desc=True
        ).limit(limit).execute()
        if not resp.data:
            return []
        rows = list(reversed(resp.data))
        return [
            {"role": r["role"], "content": r["content"], "meta": r.get("meta", ""),
             "created_at": r.get("created_at", "")}
            for r in rows
        ]
    except Exception as e:
        logger.warning(f"加载聊天记录失败: {e}")
        return []


def _clear_chat_history_db(username):
    """Delete all chat_history rows for *username* and invalidate the history cache."""
    try:
        _sb().table("chat_history").delete().eq("username", username).execute()
        _load_chat_history.clear()
    except Exception as e:
        logger.warning(f"清空聊天记录失败: {e}")


# =========================
# 7. Sidebar UI & logic
# =========================
# Guest models (free, via OpenRouter).
_GUEST_MODELS = {
    "⭐ Step-3.5 (首选)": "stepfun/step-3.5-flash:free",
    "🌐 OR-Auto (避堵)": "openrouter/free",
    "🧠 GLM-4.5 (推理)": "z-ai/glm-4.5-air:free",
    "🔥 Gemma-3-27B (旗舰)": "google/gemma-3-27b-it:free",
    "🐋 Nemotron (120B)": "nvidia/nemotron-3-super-120b-a12b:free",
    "⚡ Trinity-L (极速)": "arcee-ai/trinity-large-preview:free",
    "💭 Liquid-Think (思维链)": "liquid/lfm-2.5-1.2b-thinking:free",
    "🏎️ Liquid-Ins (1.0s)": "liquid/lfm-2.5-1.2b-instruct:free",
    "⚖️ Gemma-3-12B (平衡)": "google/gemma-3-12b-it:free",
    "💎 Gemma-3n-e4b (稳)": "google/gemma-3n-e4b-it:free",
    "🤖 Nemotron-Nano (混)": "nvidia/nemotron-3-nano-30b-a3b:free",
    "📉 Trinity-M (1.8s)": "arcee-ai/trinity-mini:free",
    "🍃 Nemotron-9B": "nvidia/nemotron-nano-9b-v2:free",
    "🪶 Gemma-3-4B": "google/gemma-3-4b-it:free",
    "🫧 Gemma-3n-e2b": "google/gemma-3n-e2b-it:free",
    "📷 Nemotron-VL": "nvidia/nemotron-nano-12b-v2-vl:free",
}

# Logged-in models (paid, via Baidu Qianfan).
_USER_MODELS = {
    "🧠 文心5.0思维链": "ernie-5.0-thinking-latest",
    "🔥 文心5.0": "ernie-5.0",
    "⚡ 文心4.5-Turbo": "ernie-4.5-turbo-128k",
    "🐋 DeepSeek-V3.2思维": "deepseek-v3.2-think",
    "💎 DeepSeek-V3.2": "deepseek-v3.2",
    "🔮 DeepSeek-R1": "deepseek-r1",
    "🌟 千问3.5-397B": "qwen3.5-397b-a17b",
    "💭 千问3-235B思维": "qwen3-235b-a22b-thinking-2507",
    "📊 千问3-32B": "qwen3-32b",
    "🎯 GLM-5": "glm-5",
    "🌙 Kimi-K2.5": "kimi-k2.5",
    "✨ MiniMax-M2.5": "minimax-m2.5",
    "🛡️ DeepSeek官方": "deepseek-chat",
    "🏢 文心3.5": "ernie-3.5-8k",
}

# Choose the model list from the user's state.
if IS_GUEST:
    model_mapping = _GUEST_MODELS
    _active_models = _GUEST_MODELS
else:
    model_mapping = _USER_MODELS
    _active_models = _USER_MODELS

with st.sidebar:
    # --- Retrieval parameters (logged-in users only) ---
    if not IS_GUEST:
        with st.expander("⚙️ 检索参数"):
            c1, c2 = st.columns(2)
            with c1:
                ui_top_k = st.slider("Top-K", min_value=1, max_value=15, value=5, key="sel_topk")
            with c2:
                ui_threshold = st.slider("阈值", min_value=0.0, max_value=1.0, value=0.25, step=0.05,
                                         key="sel_threshold")

    # --- The panels below are for logged-in users only ---
    if not IS_GUEST:
        pub_chunk_count = st.session_state.get("_sync_count_public")
        if pub_chunk_count is None:
            pub_chunk_count = _count_chunks(PUBLIC_SCOPE)
            st.session_state["_sync_count_public"] = pub_chunk_count
        with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
            st.caption("所有人可搜索")
            # File list
            pub_file_list = _list_uploaded_files_db(PUBLIC_SCOPE)
            if pub_file_list:
                # Handle a pending delete click before rendering, to avoid a stale list.
                delete_target = None
                if IS_ADMIN:
                    for fname, _, _ in pub_file_list:
                        if st.session_state.get(f"delpub_{fname}"):
                            delete_target = fname
                            break
                if delete_target:
                    _delete_chunks_by_file(PUBLIC_SCOPE, delete_target)
                    _delete_uploaded_file_from_storage(PUBLIC_SCOPE, delete_target)
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_public"] = _count_chunks(PUBLIC_SCOPE)
                    # Switch to web mode if every library is now empty.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast(f"已删除 {delete_target}")
                    st.rerun()
                st.caption(f"📎 已上传 {len(pub_file_list)} 个文件:")
                for fname, size_str, _ in pub_file_list:
                    if IS_ADMIN:
                        col_name, col_del = st.columns([4, 1])
                        col_name.text(f"📄 {fname} ({size_str})")
                        col_del.button("🗑", key=f"delpub_{fname}")
                    else:
                        st.text(f"📄 {fname} ({size_str})")
            if IS_ADMIN:
                pub_upload_key = f"upload_public_{st.session_state.get('_upload_ver_public', 0)}"

                # on_change marks that a new upload is waiting to be processed.
                def _on_pub_upload_change():
                    st.session_state["_pending_upload_public"] = True

                pub_files = st.file_uploader(
                    "上传到公共库",
                    type=["txt", "pdf", "docx"],
                    accept_multiple_files=True,
                    label_visibility="collapsed",
                    key=pub_upload_key,
                    on_change=_on_pub_upload_change,
                )
                # Process only when on_change flagged a new upload.
                if pub_files and st.session_state.pop("_pending_upload_public", False):
                    process_upload(pub_files, "public", PUBLIC_SCOPE)
                if pub_chunk_count > 0 and len(pub_file_list) >= 2:
                    if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
                        _clear_all_chunks(PUBLIC_SCOPE)
                        _clear_uploaded_files_storage(PUBLIC_SCOPE)
                        _list_uploaded_files_db.clear()
                        st.session_state["_sync_count_public"] = 0
                        # Switch to web mode if every library is now empty.
                        if not _has_any_kb_content():
                            st.session_state["sel_web"] = True
                        st.toast("公共知识库已清空")
                        st.rerun()
            else:
                st.caption("*仅管理员可维护公共库*")

        # --- Private library ---
        priv_chunk_count = st.session_state.get("_sync_count_private")
        if priv_chunk_count is None:
            priv_chunk_count = _count_chunks(PRIVATE_SCOPE)
            st.session_state["_sync_count_private"] = priv_chunk_count
        with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
            st.caption(f"用户:{CURRENT_USER},仅自己可见")
            priv_file_list = _list_uploaded_files_db(PRIVATE_SCOPE)
            if priv_file_list:
                # Handle a pending delete click before rendering, to avoid a stale list.
                delete_target = None
                for fname, _, _ in priv_file_list:
                    if st.session_state.get(f"delpriv_{fname}"):
                        delete_target = fname
                        break
                if delete_target:
                    _delete_chunks_by_file(PRIVATE_SCOPE, delete_target)
                    _delete_uploaded_file_from_storage(PRIVATE_SCOPE, delete_target)
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_private"] = _count_chunks(PRIVATE_SCOPE)
                    # Switch to web mode if every library is now empty.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast(f"已删除 {delete_target}")
                    st.rerun()
                st.caption(f"📎 已上传 {len(priv_file_list)} 个文件:")
                for fname, size_str, _ in priv_file_list:
                    col_name, col_del = st.columns([4, 1])
                    col_name.text(f"📄 {fname} ({size_str})")
                    col_del.button("🗑", key=f"delpriv_{fname}")
            priv_upload_key = f"upload_private_{st.session_state.get('_upload_ver_private', 0)}"

            # on_change marks that a new upload is waiting to be processed.
            def _on_priv_upload_change():
                st.session_state["_pending_upload_private"] = True

            priv_files = st.file_uploader(
                "上传到私有库",
                type=["txt", "pdf", "docx"],
                accept_multiple_files=True,
                label_visibility="collapsed",
                key=priv_upload_key,
                on_change=_on_priv_upload_change,
            )
            # Process only when on_change flagged a new upload.
            if priv_files and st.session_state.pop("_pending_upload_private", False):
                process_upload(priv_files, "private", PRIVATE_SCOPE)
            if priv_chunk_count > 0 and len(priv_file_list) >= 2:
                if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
                    _clear_all_chunks(PRIVATE_SCOPE)
                    _clear_uploaded_files_storage(PRIVATE_SCOPE)
                    st.session_state["_sync_count_private"] = 0
                    # Switch to web mode if every library is now empty.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast("私有知识库已清空")
                    st.rerun()

        # --- Change password ---
        with st.expander("🔐 修改密码"):
            with st.form("change_password_form", clear_on_submit=False):
                old_pass = st.text_input("当前密码", type="password", key="self_old_pass")
                new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
                new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
                submitted = st.form_submit_button("✅ 确认修改", use_container_width=True)
                if submitted:
                    ok, _, _ = verify_user(CURRENT_USER, old_pass)
                    if not ok:
                        st.error("当前密码错误")
                    elif len(new_pass1) < 4:
                        st.error("新密码至少 4 个字符")
                    elif new_pass1 != new_pass2:
                        st.error("两次新密码不一致")
                    else:
                        _save_user(CURRENT_USER, _hash_password(new_pass1), st.session_state.current_role)
                        st.success("密码修改成功")

        # --- Admin panels ---
        if IS_ADMIN:
            with st.expander("👥 用户管理"):
                all_users = _load_users()
                user_list = [(u, info) for u, info in all_users.items() if isinstance(info, dict)]
                st.caption(f"共 **{len(user_list)}** 个用户")
                for uname, uinfo in user_list:
                    role_tag = "👑" if uinfo.get("role") == "admin" else "👤"
                    created = uinfo.get("created_at", "未知")
                    st.text(f"{role_tag} {uname}({created})")
                deletable = [u for u, _ in user_list if u != CURRENT_USER]
                if deletable:
                    del_target = st.selectbox("选择要删除的用户", deletable, key="del_user_select")
                    if st.button("❌ 删除该用户", key="btn_del_user"):
                        _delete_user_db(del_target)
                        # Remove the user's private library and chat history as well.
                        _clear_all_chunks(del_target)
                        _clear_uploaded_files_storage(del_target)
                        _clear_chat_history_db(del_target)
                        st.toast(f"用户 {del_target} 已删除")
                        st.rerun()
                resetable = [u for u, _ in user_list if u != CURRENT_USER]
                if resetable:
                    reset_target = st.selectbox("选择要重置密码的用户", resetable, key="reset_user_select")
                    with st.form("reset_password_form", clear_on_submit=False):
                        new_pass = st.text_input("新密码", type="password", key="reset_new_pass")
                        reset_submitted = st.form_submit_button("🔄 重置密码", use_container_width=True)
                        if reset_submitted:
                            if len(new_pass) < 4:
                                st.error("密码至少 4 个字符")
                            else:
                                target_role = all_users[reset_target].get("role", "user")
                                _save_user(reset_target, _hash_password(new_pass), target_role)
                                st.toast(f"用户 {reset_target} 密码已重置")

            with st.expander("📩 邀请码管理"):
                current_code = _get_invite_code()
                st.text(f"当前邀请码:{current_code if current_code else '未设置'}")
                with st.form("invite_code_form", clear_on_submit=False):
                    new_code = st.text_input("新邀请码", key="new_invite_code")
                    code_submitted = st.form_submit_button("✏️ 更新邀请码", use_container_width=True)
                    if code_submitted:
                        if new_code.strip():
                            _set_invite_code(new_code.strip())
                            st.toast("邀请码已更新")
                            st.rerun()
                        else:
                            st.error("邀请码不能为空")

            with st.expander("🛠️ 数据库概览"):
                st.caption("Supabase 数据统计")
                try:
                    pub_cnt = _count_chunks(PUBLIC_SCOPE)
                    st.text(f"📚 公共库切片数: {pub_cnt}")
                    # Per-user private library counts.
                    for uname, _ in user_list:
                        cnt = _count_chunks(uname)
                        if cnt > 0:
                            st.text(f"🔒 {uname} 私有库: {cnt} 切片")
                except Exception as e:
                    st.warning(f"统计失败: {e}")
                st.divider()
                st.caption("📋 用户列表")
                display_users = {}
                for k, v in all_users.items():
                    if isinstance(v, dict) and "password_hash" in v:
                        v_copy = dict(v)
                        # Show only a short hash prefix; guard against a NULL column.
                        v_copy["password_hash"] = (v_copy["password_hash"] or "")[:8] + "..."
                        display_users[k] = v_copy
                    else:
                        display_users[k] = v
                st.json(display_users)

    # --- Chat history management (logged-in users only) ---
    if not IS_GUEST:
        with st.expander("💬 聊天记录"):
            hist_tab_new, hist_tab_history = st.tabs(["当前对话", "历史记录"])
            with hist_tab_new:
                st.caption("清空当前对话(数据库记录保留)")
                if st.button("🧹 清空当前对话", use_container_width=True, type="secondary", key="btn_clear_chat"):
                    st.session_state.messages = []
                    st.session_state["_chat_cleared"] = True  # marks an explicit clear
                    st.rerun()
                st.caption("清空所有历史记录(不可恢复)")
                if st.button("🗑️ 清空全部记录", use_container_width=True, type="secondary", key="btn_clear_all_hist"):
                    # Synchronous on purpose: the history cache is invalidated only after
                    # the rows are gone, so the next rerun cannot re-cache stale rows.
                    _clear_chat_history_db(CURRENT_USER)
                    st.session_state.messages = []
                    st.session_state["_chat_cleared"] = True  # marks an explicit clear
                    st.toast("所有聊天记录已清空")
                    st.rerun()
            with hist_tab_history:
                if st.button("🔄 加载历史记录", use_container_width=True, key="btn_load_hist"):
                    st.session_state["_show_history"] = True
                if st.session_state.get("_show_history"):
                    history = _load_chat_history(CURRENT_USER, limit=100)
                    if not history:
                        st.info("暂无历史记录")
                    else:
                        st.caption(f"共 {len(history)} 条记录")
                        for msg in history:
                            ts = (msg.get("created_at") or "")[:16].replace("T", " ")
                            icon = "🧑" if msg["role"] == "user" else "🤖"
                            preview = msg["content"][:80].replace("\n", " ")
                            st.text(f"{icon} [{ts}] {preview}{'...' if len(msg['content']) > 80 else ''}")

# =========================
# 8.
# Core search logic (merges the public and private knowledge bases)
# =========================
def search_local(query, top_k, threshold):
    """Run a pgvector similarity search over the public and private scopes.

    The query is embedded with ``encode_query`` and matched server-side by the
    ``match_documents`` Supabase RPC, restricted to ``PUBLIC_SCOPE`` and
    ``PRIVATE_SCOPE``.

    Args:
        query: User question text.
        top_k: Maximum number of chunks to return (sent as ``match_count``).
        threshold: Similarity cut-off (sent as ``match_threshold``).

    Returns:
        ``(docs, files)``: the matched chunk texts in RPC order, and the
        de-duplicated ``source_file`` names in first-seen order. Both lists are
        empty when nothing matches or when embedding / the RPC fails (the error
        is logged, never raised).
    """
    try:
        # Embedding is inside the try so a failure degrades to "no matches":
        # the chat fragment calls this function outside its own try block.
        query_vec = encode_query(query)
        embedding = query_vec.tolist() if hasattr(query_vec, "tolist") else list(query_vec)
        resp = _sb().rpc("match_documents", {
            "query_embedding": embedding,
            "match_scopes": [PUBLIC_SCOPE, PRIVATE_SCOPE],
            "match_threshold": float(threshold),
            "match_count": int(top_k),
        }).execute()
        rows = resp.data or []
        docs = [row["content"] for row in rows]
        # dict.fromkeys de-duplicates while keeping first-seen (relevance)
        # order; set() would give a different order from run to run.
        files = list(dict.fromkeys(row["source_file"] for row in rows))
        return docs, files
    except Exception as e:
        logger.error(f"pgvector 搜索失败: {e}")
        return [], []


# =========================
# 9. LLM answer logic
# =========================
def llm_answer(query, context_docs, selected_display_name, web_enabled, kb_mode=False, source_files=None):
    """Stream an answer from the selected chat model (generator of text chunks).

    Prompt selection (context is truncated to 6500 characters):
      * ``web_enabled``: context is ``web_search(query)``; knowledge-base docs
        are ignored; uses ``SYSTEM_PROMPT_WEB``.
      * ``context_docs`` non-empty: docs are joined as reference material;
        uses ``SYSTEM_PROMPT``.
      * ``kb_mode`` with no docs: direct answer with ``SYSTEM_PROMPT_DIRECT``,
        and ``st.session_state["kb_fallback"]`` is set to True.
      * otherwise: direct answer with ``SYSTEM_PROMPT_DIRECT``.

    Routing:
      * Logged-in users call only the selected model: ``deepseek-chat`` via
        the DeepSeek client, every other model via the Baidu client.
      * Guests walk a retry queue (selected model, free OpenRouter fallbacks,
        then paid DeepSeek / ERNIE fallbacks) until one route streams content.

    Session-state side effects: ``kb_fallback``, ``kb_source_files`` and
    ``last_meta`` (status line for the finished answer). ``last_meta`` is
    reset to ``""`` up front, so a failed call never shows the previous
    answer's status line.

    Yields:
        Text chunks of the answer, or a single "❌ ..." message on failure.
    """
    curr_time = datetime.now().strftime("%Y-%m-%d %H:%M")
    st.session_state["kb_fallback"] = False  # reset the fallback flag
    st.session_state["kb_source_files"] = source_files or []  # files cited in the answer
    st.session_state["last_meta"] = ""  # avoid leaking the previous answer's meta

    if web_enabled:
        # Web mode: web search results only, no knowledge base.
        search_res = web_search(query)
        all_context = f"【互联网搜索结果】:\n{search_res}"
        prompt_content = f"当前时间:{curr_time}\n\n{all_context[:6500]}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_WEB
    elif context_docs:
        # Knowledge-base mode with matching material.
        all_context = "【知识库资料】:\n" + "\n".join(context_docs)
        prompt_content = f"当前时间:{curr_time}\n\n参考资料:\n{all_context[:6500]}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT
    else:
        if kb_mode:
            # Knowledge-base mode without matches: answer directly but flag it.
            st.session_state["kb_fallback"] = True
        # Direct mode: no web, no knowledge base.
        prompt_content = f"当前时间:{curr_time}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_DIRECT

    input_tokens = estimate_tokens(prompt_content)

    or_client = get_or_client()
    ds_client = get_ds_client()
    baidu_client = get_baidu_client()

    selected_id = model_mapping[selected_display_name]
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt_content},
    ]

    # Logged-in users: call the selected model directly, no retry queue.
    if not IS_GUEST:
        # deepseek-chat uses the official DeepSeek API; all others go through Baidu.
        client = ds_client if selected_id == "deepseek-chat" else baidu_client
        label = f"🔐 {selected_display_name}"
        logger.info(f"[{CURRENT_USER}] 登录用户调用: {label}")
        try:
            response = client.chat.completions.create(
                model=selected_id,
                messages=messages,
                stream=True,
                timeout=60,  # logged-in users get a longer timeout
            )
            full_text = ""
            has_content = False
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_text += content
                    has_content = True
                    yield content
            if not has_content:
                yield "❌ 模型返回空响应,请稍后重试。"
                return
            fallback_hint = ' | ⚠️ 知识库无匹配,使用通用知识' if st.session_state.get("kb_fallback") else ""
            cited_files = st.session_state.get("kb_source_files", [])
            files_hint = f' | 📄 {", ".join(cited_files)}' if cited_files else ""
            st.session_state["last_meta"] = (
                f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens{files_hint}{fallback_hint}"
            )
            return
        except Exception as e:
            err_msg = str(e)
            logger.error(f"{label} 调用失败: {err_msg[:200]}")
            yield f"❌ 调用失败:{err_msg[:100]}"
            return

    # Guests: retry queue with free fallbacks, then paid fallbacks.
    retry_queue = [(or_client, selected_id, f"首选-{selected_display_name}")]
    if selected_id != "stepfun/step-3.5-flash:free":
        retry_queue.append((or_client, "stepfun/step-3.5-flash:free", "⚡ 快速备选-Step3.5"))
    if selected_id != "openrouter/free":
        retry_queue.append((or_client, "openrouter/free", "OR-Auto 免费避堵"))

    # Paid fallbacks for guests.
    paid_backups = [
        ("deepseek-chat", "🛡️ DeepSeek 官方", ds_client),
        ("ernie-3.5-8k", "🏢 百度文心", baidu_client),
    ]
    for p_id, p_label, p_client in paid_backups:
        if selected_id != p_id:
            retry_queue.append((p_client, p_id, f"💰 收费兜底-{p_label}"))

    for client, m_id, label in retry_queue:
        logger.info(f"[{CURRENT_USER}] 尝试链路: {label}")
        # Reset per attempt so the except branch knows whether *this* route
        # has already put text on screen.
        full_text = ""
        has_content = False
        try:
            extra_h = (
                {"HTTP-Referer": "https://streamlit.io", "X-Title": "RAG_v3"}
                if client is or_client
                else None
            )
            response = client.chat.completions.create(
                model=m_id,
                messages=messages,
                stream=True,
                extra_headers=extra_h,
                timeout=25,
            )
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_text += content
                    has_content = True
                    yield content
            if has_content:
                st.session_state["last_meta"] = (
                    f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens"
                )
                return
        except Exception as e:
            err_msg = str(e)
            logger.warning(f"{label} 失败: {err_msg[:100]}")
            if has_content:
                # Part of this route's answer has already been streamed; moving
                # on to another model would append a second, unrelated answer.
                yield "\n\n⚠️ 回答中断,请重新提问。"
                return
            if "429" in err_msg:
                st.toast(f"{label} 拥堵,切换备选...", icon="⏳")
                time.sleep(1.5)
            continue

    yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"


# =========================
# 10. Chat rendering (uses @st.fragment to avoid full-page reruns)
# =========================
if "messages" not in st.session_state:
    # First load: restore the recent conversation from the database (guests start empty).
    if not IS_GUEST:
        saved = _load_chat_history(CURRENT_USER, limit=50)
        st.session_state.messages = [
            {"role": m["role"], "content": m["content"], **({"meta": m["meta"]} if m.get("meta") else {})}
            for m in saved
        ]
    else:
        st.session_state.messages = []
elif st.session_state.get("_chat_cleared"):
    # The user cleared the chat on purpose: do not restore from the database.
    del st.session_state["_chat_cleared"]


@st.fragment
def _chat_fragment():
    """Render the chat area as its own fragment so a new message reruns only this region.

    Reads the model / mode / retrieval settings from session state, renders the
    welcome panel (when there are no messages), the message history, and
    handles a newly submitted or quick-picked question: it optionally searches
    the knowledge base, streams the answer from ``llm_answer``, appends both
    turns to ``st.session_state.messages`` and, for logged-in users, persists
    them with ``_save_chat_message``.
    """
    _model_name = st.session_state.get("sel_model", list(model_mapping.keys())[0])
    _web_on = st.session_state.get("sel_web", True)  # web mode is the default
    _top_k = st.session_state.get("sel_topk", 5)
    _threshold = st.session_state.get("sel_threshold", 0.25)

    # Welcome panel (only when there are no messages); lives inside the
    # fragment so clearing it does not rerun the whole page.
    _welcome_hero = st.empty()
    if not st.session_state.messages:
        with _welcome_hero.container():
            st.markdown(
                """
🤖
欢迎使用智答 AI 助手
🌐 联网模式 —— 大模型 + 网络搜索,实时回答
📚 知识库模式 —— 上传文档,基于私有知识回答
💡 试试这些问题
""",
                unsafe_allow_html=True,
            )
            # Quick questions: clicking one only reruns this fragment.
            _quick_questions = [
                "今天有什么热点新闻",
                "用Python写快速排序",
                "解释RAG技术是什么",
            ]
            _q_cols = st.columns(3)
            for i, qq in enumerate(_quick_questions):
                with _q_cols[i]:
                    if st.button(f"💬 {qq}", key=f"quick_q{i}", use_container_width=True, type="secondary"):
                        st.session_state["_quick_question"] = qq

    # Message container: no fixed height; scrolling is handled by outer CSS.
    chat_box = st.container()

    q = st.chat_input("问问智答AI助手", key="chat_input_v3")

    # A clicked quick question takes the place of typed input.
    if st.session_state.get("_quick_question"):
        q = st.session_state.pop("_quick_question")

    with chat_box:
        for idx, m in enumerate(st.session_state.messages):
            with st.chat_message(m["role"]):
                if m["role"] == "assistant":
                    meta_html = f'\n\n{m["meta"]}' if m.get("meta") else ""
                    action_btns = _render_action_buttons(m["content"], f"hist_{idx}")
                    st.markdown(
                        m["content"] + meta_html + action_btns,
                        unsafe_allow_html=True,
                    )
                    _inject_action_js(m["content"], f"hist_{idx}")
                else:
                    st.markdown(m["content"])

    if q:
        _welcome_hero.empty()  # hide the welcome panel once a message is sent
        st.session_state.messages.append({"role": "user", "content": q})
        if not IS_GUEST:
            _save_chat_message(CURRENT_USER, "user", q)
        with st.chat_message("user"):
            st.markdown(q)

        with st.chat_message("assistant"):
            response_container = st.empty()
            # Scroll script sits inside the assistant block so it does not add
            # spacing between the question and the answer.
            components.html(
                """""",
                height=0,
            )

            # Guest: always the LLM (plus web search when web mode is on).
            # Logged in + web: LLM + web search.
            # Logged in + no web: LLM + knowledge-base retrieval.
            if IS_GUEST:
                # Guests have no knowledge base.
                if _web_on:
                    response_container.markdown("*🌐 正在联网搜索...*")
                else:
                    response_container.markdown("*🤔 正在思考...*")
                relevant_docs, source_files = [], []
            elif _web_on:
                response_container.markdown("*🌐 正在联网搜索...*")
                relevant_docs, source_files = [], []
            else:
                response_container.markdown("*🔍 正在搜索知识库...*")
                relevant_docs, source_files = search_local(q, _top_k, _threshold)
                response_container.markdown("*🤔 正在组织语言...*")

            # Logged-in user without web = knowledge-base mode.
            _kb_mode = (not IS_GUEST) and (not _web_on)

            try:
                full_response = response_container.write_stream(
                    llm_answer(q, relevant_docs, _model_name, _web_on,
                               kb_mode=_kb_mode, source_files=source_files)
                )
                meta_info = st.session_state.get("last_meta", "")

                # Action-button id uses the current message count.
                msg_idx = len(st.session_state.messages)
                action_btns = _render_action_buttons(full_response, f"new_{msg_idx}")
                meta_html = f'\n\n{meta_info}' if meta_info else ""
                response_container.markdown(
                    full_response + meta_html + action_btns,
                    unsafe_allow_html=True,
                )
                _inject_action_js(full_response, f"new_{msg_idx}")

                st.session_state.messages.append(
                    {"role": "assistant", "content": full_response, "meta": meta_info}
                )
                if not IS_GUEST:
                    _save_chat_message(CURRENT_USER, "assistant", full_response, meta_info)
            except Exception as e:
                logger.error(f"模型调用异常: {e}")
                response_container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")


# =========================
# 11. Tool buttons below the chat input (fragment for partial reruns)
# =========================
@st.fragment
def _toolbar_fragment():
    """Render the model picker popover and the web / knowledge-base mode toggle.

    Selections are written to ``st.session_state["sel_model"]`` and
    ``st.session_state["sel_web"]``, followed by ``st.rerun()`` so the rest of
    the page (including the chat fragment) reads the new values.
    """
    _current_model = st.session_state.get("sel_model", list(model_mapping.keys())[0])
    _web_status = st.session_state.get("sel_web", True)
    _model_short = _current_model.split("(")[0].strip()

    _btn_cols = st.columns(2)
    with _btn_cols[0]:
        with st.popover(_model_short, use_container_width=True):
            st.caption("选择模型")
            _model_list = list(_active_models.keys())
            for m in _model_list:
                _btn_type = "primary" if m == _current_model else "secondary"
                if st.button(m, key=f"mdl_{m}", use_container_width=True, type=_btn_type):
                    if m != _current_model:
                        st.session_state["sel_model"] = m
                        st.rerun()
    with _btn_cols[1]:
        if IS_GUEST:
            _web_label = "🌐 联网模式" if _web_status else "💬 直接回答"
        else:
            _web_label = "🌐 联网模式" if _web_status else "📚 知识库"
        if st.button(_web_label, key="toggle_web", use_container_width=True):
            st.session_state["sel_web"] = not _web_status
            st.rerun()


_toolbar_fragment()
_chat_fragment()