# rag-assistant / app.py
# AlauStone's picture
# Upload app.py
# 6d768b1 verified
import os
# 优化 Streamlit 连接稳定性
os.environ.setdefault("STREAMLIT_SERVER_ENABLE_WEBSOCKET_COMPRESSION", "false")
os.environ.setdefault("STREAMLIT_SERVER_MAX_MESSAGE_SIZE", "200")
import streamlit as st
import streamlit.components.v1 as components
import json
import time
import logging
import hashlib
import uuid
from datetime import datetime
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
# Deferred numpy import.
_np_module = None


def _get_np():
    """Return the ``numpy`` module, importing it on the first call only.

    The imported module object is memoised in the module-level
    ``_np_module`` slot and reused by every later call.
    """
    global _np_module
    if _np_module is not None:
        return _np_module
    import numpy as _numpy
    _np_module = _numpy
    return _np_module
# =========================
# 1. 页面配置 & 样式注入
# =========================
st.set_page_config(page_title="智答 AI 助手", page_icon="🤖", layout="wide", initial_sidebar_state="collapsed")
@st.cache_data
def _get_custom_css():
return """
<style>
[data-testid="stSidebarContent"] { padding-top: 1.5rem !important; }
[data-testid="stVerticalBlock"] > div { gap: 0.2rem !important; }
/* 聊天消息间距控制 */
[data-testid="stChatMessage"] { margin-bottom: 0.1rem !important; margin-top: 0.1rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important; }
/* 聊天消息内部 gap 归零 */
[data-testid="stChatMessageContent"] > [data-testid="stVerticalBlock"] > div { gap: 0 !important; }
/* 聊天消息内标题字体大小限制 */
[data-testid="stChatMessage"] h1 { font-size: 1.3rem !important; margin: 0.8rem 0 0.5rem !important; }
[data-testid="stChatMessage"] h2 { font-size: 1.15rem !important; margin: 0.7rem 0 0.4rem !important; }
[data-testid="stChatMessage"] h3 { font-size: 1.05rem !important; margin: 0.6rem 0 0.3rem !important; }
[data-testid="stChatMessage"] h4, [data-testid="stChatMessage"] h5, [data-testid="stChatMessage"] h6 { font-size: 1rem !important; margin: 0.5rem 0 0.3rem !important; }
/* 隐藏 stHtml iframe 容器间距,但保留渲染(JS 可执行) */
[data-testid="stHtml"] { height: 0 !important; min-height: 0 !important; margin: 0 !important; padding: 0 !important; overflow: hidden !important; }
[data-testid="stSelectbox"] input { caret-color: transparent !important; }
/* 隐藏 text_input 的 press enter to apply 提示 */
[data-testid="stTextInput"] div[data-testid="InputInstructions"] { display: none !important; }
[data-testid="stFileUploader"] section > div { display: none; }
[data-testid="stFileUploaderDropzoneInstructions"] { display: none !important; }
[data-testid="stFileUploader"] section::before {
content: "拖拽文档至此";
color: #555; font-size: 14px; display: block; margin-bottom: 10px;
}
[data-testid="stFileUploader"] section::after {
content: "支持格式:TXT, PDF, DOCX";
color: #888; font-size: 12px; display: block; margin-top: 5px;
}
[data-testid="stFileUploader"] button { font-size: 0 !important; }
[data-testid="stFileUploader"] button::after {
content: "选择文件";
font-size: 14px !important;
}
/* 顶部工具栏 */
[data-testid="stHeader"] {
height: 2.5rem !important;
background: transparent !important;
}
/* 左上角侧边栏按钮:替换箭头为「☰ 菜单」- 无边框风格 */
[data-testid="stSidebarCollapsedControl"] {
top: 0.4rem !important;
left: 0.5rem !important;
}
[data-testid="stSidebarCollapsedControl"] button {
background: transparent !important;
border: none !important;
padding: 4px 8px !important;
}
[data-testid="stSidebarCollapsedControl"] button:hover {
background: rgba(0,0,0,0.05) !important;
border-radius: 6px !important;
}
[data-testid="stSidebarCollapsedControl"] button svg { display: none !important; }
[data-testid="stSidebarCollapsedControl"] button::after {
content: "☰ 菜单";
font-size: 14px;
color: #666;
}
/* 侧边栏展开后的收起按钮 - 无边框风格 */
[data-testid="stSidebarCollapseButton"] button {
background: transparent !important;
border: none !important;
padding: 4px 8px !important;
}
[data-testid="stSidebarCollapseButton"] button:hover {
background: rgba(0,0,0,0.05) !important;
border-radius: 6px !important;
}
[data-testid="stSidebarCollapseButton"] button svg { display: none !important; }
[data-testid="stSidebarCollapseButton"] button::after {
content: "✕ 收起";
font-size: 14px;
color: #666;
}
/* 右上角三点菜单:替换为「⚙ 设置」- 无边框风格 */
[data-testid="stMainMenu"] {
position: fixed !important;
top: 0.55rem !important;
right: 0.5rem !important;
}
[data-testid="stMainMenu"] button {
background: transparent !important;
border: none !important;
padding: 4px 8px !important;
}
[data-testid="stMainMenu"] button:hover {
background: rgba(0,0,0,0.05) !important;
border-radius: 6px !important;
}
[data-testid="stMainMenu"] button svg { display: none !important; }
[data-testid="stMainMenu"] button::before {
content: "⚙ 设置";
font-size: 14px;
color: #666;
}
/* 主内容区域:限制在标题和搜索框之间 */
.block-container {
position: fixed !important;
top: 7rem !important;
bottom: 7rem !important;
left: 0 !important;
right: 0 !important;
overflow-y: auto !important;
-webkit-overflow-scrolling: touch !important;
padding: 1rem 1rem 1rem 1rem !important;
}
/* 固定标题在顶部 */
.main-title {
position: fixed !important;
top: 2.5rem !important;
left: 0 !important;
right: 0 !important;
z-index: 9999 !important;
background: white !important;
padding: 0.3rem 1rem !important;
}
/* 固定搜索框在底部 */
[data-testid="stChatInput"] {
position: fixed !important;
bottom: 3.5rem !important;
left: 50% !important;
transform: translateX(-50%) !important;
width: calc(100% - 2rem) !important;
max-width: 730px !important;
z-index: 9999 !important;
background: white !important;
}
/* 搜索框聚焦时去掉红色边框 */
[data-testid="stChatInput"] textarea:focus,
[data-testid="stChatInput"] div:focus-within {
outline: none !important;
border-color: #ddd !important;
box-shadow: none !important;
}
/* 底部工具栏固定在搜索框下方 */
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) {
position: fixed !important;
bottom: 0.3rem !important;
left: 50% !important;
transform: translateX(-50%) !important;
width: calc(100% - 2rem) !important;
max-width: 730px !important;
z-index: 9998 !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stHorizontalBlock"] {
display: flex !important;
flex-wrap: nowrap !important;
gap: 0.5rem !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stHorizontalBlock"] > div {
flex: 1 !important;
min-width: 0 !important;
}
/* 确保底部工具栏按钮占满容器宽度 */
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stHorizontalBlock"] button {
width: 100% !important;
}
/* popover 向上展开 */
[data-testid="stPopoverBody"] {
bottom: 100% !important;
top: auto !important;
margin-bottom: 0.5rem !important;
}
/* 底部工具栏按钮:无边框风格,统一颜色 */
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stPopover"] > div:first-child {
background: transparent !important;
border: none !important;
box-shadow: none !important;
color: #666 !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button:hover,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stPopover"] > div:first-child:hover {
background: rgba(0,0,0,0.05) !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button p,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button span,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stMarkdownContainer"] p,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stPopover"] p {
color: #666 !important;
}
/* 确保聊天容器可滚动 */
[data-testid="stVerticalBlockBorderWrapper"] {
overflow-y: auto !important;
-webkit-overflow-scrolling: touch !important;
}
/* 手机端响应式调整(屏幕宽度 < 768px) */
@media (max-width: 768px) {
.main-title {
padding: 0.2rem 0.5rem !important;
}
.main-title h1 {
font-size: 1.2rem !important;
}
/* 手机端聊天消息标题更小 */
[data-testid="stChatMessage"] h1 { font-size: 1.1rem !important; }
[data-testid="stChatMessage"] h2 { font-size: 1rem !important; }
[data-testid="stChatMessage"] h3, [data-testid="stChatMessage"] h4, [data-testid="stChatMessage"] h5, [data-testid="stChatMessage"] h6 { font-size: 0.95rem !important; }
/* 手机端工具栏:强制两列并排 */
[data-testid="stHorizontalBlock"] {
display: flex !important;
flex-wrap: nowrap !important;
gap: 0.3rem !important;
}
[data-testid="stHorizontalBlock"] > div {
flex: 1 !important;
min-width: 0 !important;
width: auto !important;
}
/* 按钮文字缩小,单行截断 */
[data-testid="stHorizontalBlock"] button {
font-size: 12px !important;
padding: 0.4rem 0.3rem !important;
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
}
[data-testid="stHorizontalBlock"] button p {
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
margin: 0 !important;
}
/* 手机端 popover 居中显示 */
[data-testid="stPopoverBody"] {
left: 1rem !important;
right: 1rem !important;
transform: none !important;
width: auto !important;
max-width: calc(100vw - 2rem) !important;
bottom: auto !important;
top: auto !important;
}
}
/* 游客模式:标题区域 */
body:has(.main-title-guest) .block-container {
top: 6.5rem !important;
}
@media (max-width: 768px) {
body:has(.main-title-guest) .block-container {
top: 6.5rem !important;
padding: 0.5rem !important;
}
}
/* 登录模式:标题较矮,内容区 top 更小 */
body:has(.main-title-user) .block-container {
top: 6rem !important;
}
@media (max-width: 768px) {
body:has(.main-title-user) .block-container {
top: 4.5rem !important;
padding: 0.5rem !important;
}
}
/* 欢迎页:垂直居中,禁止滚动 */
.block-container:has(.welcome-marker) {
overflow: hidden !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
}
/* 快捷问题按钮:无边框链接风格 */
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"] {
background: transparent !important;
border: none !important;
color: #1a73e8 !important;
box-shadow: none !important;
white-space: nowrap !important;
}
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"]:hover {
background: #f0f7ff !important;
}
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"] p {
color: #1a73e8 !important;
white-space: nowrap !important;
}
/* 手机端快捷问题:垂直排列,按钮宽度自适应内容 */
@media (max-width: 768px) {
/* 快捷问题按钮区域:垂直居中排列 */
.block-container:has(.welcome-marker) [data-testid="stVerticalBlockBorderWrapper"] [data-testid="stHorizontalBlock"]:not(:has([data-testid="stPopover"])) {
flex-direction: column !important;
align-items: center !important;
gap: 0.3rem !important;
}
/* 按钮容器:宽度自适应 */
.block-container:has(.welcome-marker) [data-testid="stVerticalBlockBorderWrapper"] [data-testid="stHorizontalBlock"]:not(:has([data-testid="stPopover"])) > div {
width: auto !important;
flex: none !important;
}
/* 按钮本身:宽度自适应内容 */
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"] {
width: auto !important;
min-width: 0 !important;
}
}
/* 侧边栏文件列表:防止文件名溢出遮挡删除按钮 */
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:first-child {
overflow: hidden !important;
text-overflow: ellipsis !important;
white-space: nowrap !important;
}
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:last-child {
flex-shrink: 0 !important;
z-index: 1 !important;
}
/* meta 信息中的附加内容(文件名、降级提示)- 移动端换行 */
.meta-extra {
display: inline;
}
.meta-sep {
display: inline;
}
@media (max-width: 768px) {
.meta-extra {
display: block;
margin-top: 2px;
}
.meta-sep {
display: none;
}
}
/* 移动端侧边栏文件列表布局修复 */
@media (max-width: 768px) {
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] {
display: flex !important;
flex-direction: row !important;
justify-content: space-between !important;
align-items: center !important;
width: 100% !important;
}
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:first-child {
flex: 1 !important;
min-width: 0 !important;
}
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:last-child {
flex: 0 0 auto !important;
margin-left: 8px !important;
}
}
</style>
"""
def inject_custom_css():
    """Emit the app's custom stylesheet into the page as raw HTML."""
    stylesheet = _get_custom_css()
    st.markdown(stylesheet, unsafe_allow_html=True)


inject_custom_css()
def _render_action_buttons(content, msg_id):
    """Build the copy / share / read-aloud button row shown under a message (HTML only).

    Args:
        content: Message text; not used by the markup itself.
        msg_id: Suffix used to build the three button element ids.

    Returns:
        An HTML fragment containing the three buttons.
    """
    button_style = (
        "background:none; border:none; cursor:pointer; color:#666; "
        "font-size:13px; padding:4px 8px; border-radius:4px;"
    )
    specs = (("copy", "📋 复制"), ("share", "🔗 分享"), ("tts", "🔊 播报"))
    buttons = [
        f'<button id="{kind}_btn_{msg_id}" style="{button_style}">{label}</button>'
        for kind, label in specs
    ]
    wrapper_open = (
        '<div style="display:flex; gap:12px; margin-top:8px; '
        'padding-top:6px; border-top:1px solid #eee;">'
    )
    # Leading and trailing empty items reproduce the newlines that frame the fragment.
    return "\n".join(["", wrapper_open, *buttons, "</div>", ""])
def _js_string_literal(value):
    """Encode ``value`` as a JavaScript string literal safe inside an inline ``<script>``."""
    # json.dumps produces a valid JS string literal (quotes, backslashes, newlines
    # and non-ASCII characters are escaped). "</" and "<!--" are additionally
    # broken up so the HTML parser can never see a closing script tag (in any
    # letter case) or a comment opener inside the message text.
    return json.dumps(str(value)).replace("</", "<\\/").replace("<!--", "<\\!--")


def _inject_action_js(content, msg_id):
    """Attach click handlers to the copy / share / read-aloud buttons of one message.

    The script is rendered through ``components.html`` (zero-height iframe) and
    reaches the buttons through ``window.parent.document``.

    Args:
        content: Message text made available to the handlers. It is treated as
            untrusted and embedded as an escaped JavaScript string literal.
        msg_id: Suffix of the button element ids (``copy_btn_<id>`` etc.).
    """
    msg_literal = _js_string_literal(content)
    copy_id = _js_string_literal(f"copy_btn_{msg_id}")
    share_id = _js_string_literal(f"share_btn_{msg_id}")
    tts_id = _js_string_literal(f"tts_btn_{msg_id}")
    js_code = f'''
<script>
(function() {{
    const msgContent = {msg_literal};
    let speaking = false;
    const parentDoc = window.parent.document;
    // Clipboard helper that also works from inside the iframe.
    function copyToClipboard(text) {{
        // Option 1: modern async clipboard API.
        if (navigator.clipboard && window.isSecureContext) {{
            return navigator.clipboard.writeText(text);
        }}
        // Option 2: legacy hidden-textarea approach.
        return new Promise((resolve, reject) => {{
            const textarea = parentDoc.createElement('textarea');
            textarea.value = text;
            textarea.style.position = 'fixed';
            textarea.style.left = '-9999px';
            textarea.style.top = '-9999px';
            parentDoc.body.appendChild(textarea);
            textarea.focus();
            textarea.select();
            try {{
                const ok = parentDoc.execCommand('copy');
                parentDoc.body.removeChild(textarea);
                ok ? resolve() : reject();
            }} catch (e) {{
                parentDoc.body.removeChild(textarea);
                reject(e);
            }}
        }});
    }}
    const copyBtn = parentDoc.getElementById({copy_id});
    const shareBtn = parentDoc.getElementById({share_id});
    const ttsBtn = parentDoc.getElementById({tts_id});
    if (copyBtn) {{
        copyBtn.onclick = function() {{
            copyToClipboard(msgContent).then(() => {{
                copyBtn.textContent = '✅ 已复制';
                setTimeout(() => {{ copyBtn.textContent = '📋 复制'; }}, 1500);
            }}).catch(() => {{
                // Last resort: let the user copy manually.
                prompt('请手动复制以下内容:', msgContent.substring(0, 500));
            }});
        }};
        copyBtn.onmouseover = function() {{ this.style.background='#f0f0f0'; }};
        copyBtn.onmouseout = function() {{ this.style.background='none'; }};
    }}
    if (shareBtn) {{
        shareBtn.onclick = function() {{
            // Prefer the Web Share API, otherwise fall back to copying.
            if (navigator.share && !window.parent.frames.length) {{
                navigator.share({{ title: '智答AI助手', text: msgContent }}).catch(() => {{}});
            }} else {{
                copyToClipboard(msgContent).then(() => {{
                    shareBtn.textContent = '✅ 已复制';
                    setTimeout(() => {{ shareBtn.textContent = '🔗 分享'; }}, 1500);
                }}).catch(() => {{
                    prompt('请手动复制分享:', msgContent.substring(0, 500));
                }});
            }}
        }};
        shareBtn.onmouseover = function() {{ this.style.background='#f0f0f0'; }};
        shareBtn.onmouseout = function() {{ this.style.background='none'; }};
    }}
    if (ttsBtn) {{
        ttsBtn.onclick = function() {{
            const synth = window.speechSynthesis || window.parent.speechSynthesis;
            if (!synth) {{ alert('浏览器不支持语音播报'); return; }}
            if (speaking) {{
                synth.cancel();
                speaking = false;
                ttsBtn.textContent = '🔊 播报';
            }} else {{
                const utterance = new SpeechSynthesisUtterance(msgContent);
                utterance.lang = 'zh-CN';
                utterance.rate = 1.0;
                utterance.onend = () => {{ speaking = false; ttsBtn.textContent = '🔊 播报'; }};
                utterance.onerror = () => {{ speaking = false; ttsBtn.textContent = '🔊 播报'; }};
                synth.speak(utterance);
                speaking = true;
                ttsBtn.textContent = '⏹️ 停止';
            }}
        }};
        ttsBtn.onmouseover = function() {{ this.style.background='#f0f0f0'; }};
        ttsBtn.onmouseout = function() {{ this.style.background='none'; }};
    }}
}})();
</script>
'''
    components.html(js_code, height=0)
# Global guest flag (session_state does not change within a single render pass).
IS_GUEST = not st.session_state.get("current_user")

# Title pinned to the top of the page; guests always see the hint line.
if IS_GUEST:
    _guest_tip = """<div style="text-align:center; padding:4px 0; color:#666; font-size:13px;">
🙋 当前为<b>游客模式</b>,可直接提问体验
</div>"""
else:
    _guest_tip = ""

# A different CSS class per mode lets the stylesheet adjust the spacing.
_title_class = "main-title main-title-" + ("guest" if IS_GUEST else "user")

_title_html = f"""
<div class="{_title_class}">
<h1 style="white-space:nowrap; overflow:hidden; text-overflow:ellipsis; font-size:clamp(1.4rem, 5vw, 2.2rem); margin:0; text-align:center;">
🤖 智答 AI 助手
</h1>
{_guest_tip}
</div>
"""
st.markdown(_title_html, unsafe_allow_html=True)
# =========================
# 1.5 Supabase 客户端初始化
# =========================
from supabase import create_client
# Supabase connection settings, read from Streamlit secrets (empty string when unset).
SUPABASE_URL = st.secrets.get("SUPABASE_URL", "")
SUPABASE_KEY = st.secrets.get("SUPABASE_KEY", "")  # service_role key (used by the backend)
STORAGE_BUCKET = "rag-files"
# Both values are required: report the problem and stop this script run if either is missing.
if not SUPABASE_URL or not SUPABASE_KEY:
    st.error("⚠️ 未配置 SUPABASE_URL 或 SUPABASE_KEY,请在 Secrets 中设置。")
    st.stop()
@st.cache_resource
def _get_supabase():
    """Create the Supabase client from the configured URL and key (cached as a resource)."""
    client = create_client(SUPABASE_URL, SUPABASE_KEY)
    return client
def _sb():
    """Shorthand accessor for the shared Supabase client."""
    client = _get_supabase()
    return client
# =========================
# 2. 用户管理(Supabase users 表)
# =========================
MAX_LOGIN_ATTEMPTS = 10
def _hash_password(password):
    """Return the hex SHA-256 digest of ``password`` encoded as UTF-8.

    NOTE(review): this is an unsalted fast hash. Stored hashes depend on this
    exact scheme, so changing it requires a migration of the users table.
    """
    digest = hashlib.sha256()
    digest.update(password.encode("utf-8"))
    return digest.hexdigest()
@st.cache_data(ttl=30)  # cache successful reads for 30 s to cut down repeated queries
def _fetch_users():
    """Read every row of the Supabase ``users`` table.

    Raises on any query error so that a failed read is never stored in the cache.
    """
    resp = _sb().table("users").select("*").execute()
    return {
        row["username"]: {
            "password_hash": row["password_hash"],
            "role": row["role"],
            # Keep only the first 16 characters of the timestamp string.
            "created_at": row["created_at"][:16] if row.get("created_at") else "未知",
        }
        for row in resp.data
    }


def _load_users():
    """Load all users from the Supabase ``users`` table.

    Returns:
        ``{username: {"password_hash", "role", "created_at"}}``, or ``{}`` when
        the table cannot be read. Failures are logged and not cached, so the
        next call retries instead of serving an empty user list for 30 s.
    """
    try:
        return _fetch_users()
    except Exception as e:
        logger.error(f"加载用户表失败: {e}")
        return {}


# Other code invalidates the cache through ``_load_users.clear()``; keep that working.
_load_users.clear = _fetch_users.clear
def _ensure_admin():
    """Create the initial admin from secrets when the ``users`` table is empty.

    Reads ``ADMIN_USER`` (default ``"admin"``) and ``ADMIN_PASSWORD`` from
    Streamlit secrets and does nothing when no password is configured.

    The table is probed directly, so a failed read is not mistaken for an
    empty table and an existing admin row is never overwritten by the upsert.
    """
    admin_user = st.secrets.get("ADMIN_USER", "admin")
    admin_pass = st.secrets.get("ADMIN_PASSWORD", "")
    if not admin_pass:
        return
    try:
        existing = _sb().table("users").select("username").limit(1).execute()
    except Exception as e:
        logger.error(f"检查用户表失败,跳过管理员初始化: {e}")
        return
    if existing.data:
        return
    try:
        _sb().table("users").upsert({
            "username": admin_user,
            "password_hash": _hash_password(admin_pass),
            "role": "admin",
        }).execute()
        # Drop any cached (empty) user list so the new admin can log in right away.
        _load_users.clear()
        logger.info(f"初始管理员 {admin_user} 已创建")
    except Exception as e:
        logger.error(f"创建管理员失败: {e}")
# Run the admin check only once per session (the flag lives in session_state),
# so that ordinary reruns do not repeat it.
if not st.session_state.get("_admin_ensured"):
    _ensure_admin()
    st.session_state["_admin_ensured"] = True
def _save_user(username, password_hash, role="user"):
    """Insert or update one row of the ``users`` table, then drop the user-list cache."""
    record = {
        "username": username,
        "password_hash": password_hash,
        "role": role,
    }
    users_table = _sb().table("users")
    users_table.upsert(record).execute()
    _load_users.clear()  # invalidate the cached user list
def _delete_user_db(username):
    """Delete the ``users`` row with the given username, then drop the user-list cache."""
    deletion = _sb().table("users").delete().eq("username", username)
    deletion.execute()
    _load_users.clear()  # invalidate the cached user list
@st.cache_data(ttl=60)  # cached for 60 s
def _get_invite_code():
    """Return the current invite code.

    The ``invite_code`` row of the ``app_meta`` table takes precedence. When it
    is missing or the query fails, the ``INVITE_CODE`` secret (default ``""``)
    is returned instead. A query failure is logged rather than silently ignored.
    """
    try:
        resp = _sb().table("app_meta").select("value").eq("key", "invite_code").execute()
        if resp.data:
            return resp.data[0]["value"]
    except Exception as e:
        logger.warning(f"读取邀请码失败,回退到 secrets: {e}")
    return st.secrets.get("INVITE_CODE", "")
def _set_invite_code(new_code):
    """Store ``new_code`` as the ``invite_code`` entry of ``app_meta`` and drop the cached value."""
    entry = {"key": "invite_code", "value": new_code}
    _sb().table("app_meta").upsert(entry).execute()
    _get_invite_code.clear()  # invalidate the cache
def register_user(username, password, invite_code):
    """Validate a registration request and create the user.

    Args:
        username: Requested name, 2-20 characters, not starting with ``__``
            and not one of the reserved scope names.
        password: Plain-text password, at least 4 characters.
        invite_code: Must equal the configured invite code.

    Returns:
        ``(ok, message)``, where ``message`` is the user-facing result text.
    """
    # The username doubles as the private document scope, so names that equal
    # the public scope ("public") or the guest scope ("_guest_") would give
    # the account access to those shared libraries.
    reserved_names = {"public", "_guest_"}

    if not username or not password:
        return False, "用户名和密码不能为空"
    if len(username) < 2 or len(username) > 20:
        return False, "用户名长度需要 2-20 个字符"
    if len(password) < 4:
        return False, "密码至少 4 个字符"
    if username.startswith("__"):
        return False, "用户名不能以 __ 开头"
    if username.lower() in reserved_names:
        return False, "该用户名为系统保留名称,请更换"
    correct_code = _get_invite_code()
    if not correct_code:
        return False, "邀请码未配置,请联系管理员"
    if invite_code != correct_code:
        return False, "邀请码错误"
    users = _load_users()
    if username in users:
        return False, "用户名已存在"
    try:
        _save_user(username, _hash_password(password), "user")
        logger.info(f"新用户注册: {username}")
        return True, "注册成功,请登录"
    except Exception as e:
        logger.error(f"注册失败: {e}")
        return False, f"注册失败: {e}"
def verify_user(username, password):
    """Check a username / password pair against the stored users.

    Returns:
        ``(ok, role, reason)``. On success ``role`` is the stored role
        (default ``"user"``) and ``reason`` is ``""``. On failure ``role`` is
        ``None`` and ``reason`` is ``"not_found"`` or ``"wrong_password"``.
    """
    import hmac

    user_info = _load_users().get(username)
    if not user_info or not isinstance(user_info, dict):
        return False, None, "not_found"
    stored_hash = user_info.get("password_hash")
    computed_hash = _hash_password(password)
    # Constant-time comparison of the digests. A missing or non-string stored
    # hash can never match.
    matches = isinstance(stored_hash, str) and hmac.compare_digest(
        stored_hash.encode("utf-8"), computed_hash.encode("utf-8")
    )
    if not matches:
        return False, None, "wrong_password"
    return True, user_info.get("role", "user"), ""
# --- Authentication UI ---
# Seed the session with defaults for every auth-related key that is not set yet.
for _state_key, _state_default in (
    ("login_attempts", 0),
    ("current_user", None),
    ("current_role", None),
    ("auth_mode", "login"),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
with st.sidebar:
    # Guests get the login / register forms directly instead of inside an
    # expander, so the form cannot stay hidden after being collapsed manually.
    if IS_GUEST:
        st.subheader("🔑 账号")
        # Lock-out: once the per-session attempt limit is reached, stop rendering.
        if st.session_state.login_attempts >= MAX_LOGIN_ATTEMPTS:
            st.error("🚫 尝试次数过多,请刷新页面后重试。")
            st.stop()
        users_data = _load_users()
        if not users_data:
            st.error("⚠️ 未配置管理员,请在 secrets 中设置 ADMIN_USER 和 ADMIN_PASSWORD。")
            st.stop()
        auth_mode = st.radio(
            "操作", ["登录", "注册"], horizontal=True,
            label_visibility="collapsed", key="auth_radio",
        )
        if auth_mode == "登录":
            with st.form("login_form", clear_on_submit=False):
                input_username = st.text_input("用户名", key="login_user")
                input_password = st.text_input("密码", type="password", key="login_pass")
                submitted = st.form_submit_button("🔓 登录", use_container_width=True, type="primary")
            if submitted:
                if input_username == "" or input_password == "":
                    st.warning("请输入用户名和密码")
                else:
                    # The attempt limit was already enforced above (st.stop()),
                    # so no second check is needed here.
                    ok, role, reason = verify_user(input_username, input_password)
                    if not ok:
                        st.session_state.login_attempts += 1
                        remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
                        if reason == "not_found":
                            st.warning(f"⚠️ 用户不存在,请先注册(剩余 {remaining} 次)")
                        else:
                            st.warning(f"⚠️ 密码错误(剩余 {remaining} 次)")
                    else:
                        st.session_state.login_attempts = 0
                        st.session_state.current_user = input_username
                        st.session_state.current_role = role
                        st.rerun()
        else:  # register
            with st.form("register_form", clear_on_submit=False):
                reg_user = st.text_input("用户名", key="reg_user")
                reg_pass = st.text_input("密码", type="password", key="reg_pass")
                reg_pass2 = st.text_input("确认密码", type="password", key="reg_pass2")
                reg_code = st.text_input("邀请码", type="password", key="reg_code")
                reg_submitted = st.form_submit_button("注册", use_container_width=True)
            if reg_submitted:
                if reg_pass != reg_pass2:
                    st.error("两次密码不一致")
                else:
                    ok, msg = register_user(reg_user, reg_pass, reg_code)
                    if ok:
                        # A successful registration logs the new user in immediately.
                        st.session_state.current_user = reg_user
                        st.session_state.current_role = "user"
                        st.session_state.login_attempts = 0
                        st.rerun()
                    else:
                        st.error(f"❌ {msg}")
    else:
        with st.expander("🔑 账号"):
            role_label = "管理员" if st.session_state.current_role == "admin" else "普通用户"
            # Fixed: the role label was missing its opening parenthesis.
            st.success(f"✅ {st.session_state.current_user}({role_label})")
# Derive the global identity flags from the session.
_session_user = st.session_state.get("current_user")
IS_GUEST = not _session_user
# Guests share the fixed "_guest_" identity and are never admins.
CURRENT_USER = st.session_state.current_user if not IS_GUEST else "_guest_"
IS_ADMIN = (not IS_GUEST) and st.session_state.current_role == "admin"
# After login, collapse the sidebar automatically by having JS click the
# collapse button. The session flag makes this happen only once per session.
if not IS_GUEST and not st.session_state.get("_sidebar_collapsed_once"):
    st.session_state["_sidebar_collapsed_once"] = True
    # The script is rendered through components.html, so it reaches the
    # button through window.parent.document.
    components.html(
        """
<script>
(function() {
const btn = window.parent.document.querySelector(
'[data-testid="stSidebarCollapseButton"] button'
);
if (btn) btn.click();
})();
</script>
""",
        height=0,
    )
# =========================
# 3. 安全配置与 Embedding 策略
# =========================
TAVILY_KEY = st.secrets.get("TAVILY_API_KEY", "")
DS_API_KEY = st.secrets.get("DEEPSEEK_API_KEY", "")
BAIDU_TOKEN = st.secrets.get("BAIDU_BEARER_TOKEN", "")
BAIDU_APP_ID = st.secrets.get("BAIDU_APP_ID", "")
OR_KEY = st.secrets.get("OPENROUTER_API_KEY", "")
@st.cache_resource
def _get_embedding_client():
    """Pick the embedding backend from the configured credentials.

    Returns:
        ``(client, model_name)``. Baidu Qianfan is used when both its token and
        app id are set, otherwise OpenRouter when its key is set, otherwise
        ``(None, None)``.
    """
    from openai import OpenAI

    if BAIDU_TOKEN and BAIDU_APP_ID:
        baidu_client = OpenAI(
            api_key=BAIDU_TOKEN,
            base_url="https://qianfan.baidubce.com/v2",
            default_headers={"appid": BAIDU_APP_ID},
        )
        return baidu_client, "qwen3-embedding-0.6b"
    if not OR_KEY:
        return None, None
    openrouter_client = OpenAI(
        api_key=OR_KEY,
        base_url="https://openrouter.ai/api/v1",
    )
    return openrouter_client, "BAAI/bge-small-zh"
def _api_encode(texts):
    """Embed ``texts`` through the remote embedding API in batches of 32.

    Returns:
        A list with one numpy array per input text, or ``None`` when no API
        client is configured or any request fails (the failure is logged).
    """
    np = _get_np()
    client, model = _get_embedding_client()
    if client is None:
        return None
    step = 32
    vectors = []
    try:
        for start in range(0, len(texts), step):
            response = client.embeddings.create(model=model, input=texts[start:start + step])
            for item in response.data:
                vectors.append(np.array(item.embedding))
    except Exception as e:
        logger.warning(f"API embedding 失败,回退到本地模型: {e}")
        return None
    return vectors
def _get_local_model():
    """Return the local ``BAAI/bge-small-zh`` sentence-transformers model.

    The model is loaded on first use and kept in ``st.session_state``.
    Returns ``None`` when ``sentence_transformers`` cannot be imported.
    """
    if "_local_emb_model" in st.session_state:
        return st.session_state.get("_local_emb_model")
    try:
        with st.spinner("API 不可用,正在加载本地向量模型(仅首次)..."):
            from sentence_transformers import SentenceTransformer
            st.session_state._local_emb_model = SentenceTransformer("BAAI/bge-small-zh")
    except ImportError:
        logger.error("sentence-transformers 未安装,本地模型不可用")
        return None
    return st.session_state.get("_local_emb_model")
def encode_texts(texts):
    """Embed one string or a list of strings.

    The remote API is tried first and the local model is the fallback.

    Returns:
        A list of vectors, or ``[]`` for empty input or when neither backend is
        available (an error is then shown in the UI).
    """
    if not texts:
        return []
    batch = [texts] if isinstance(texts, str) else texts
    remote_vectors = _api_encode(batch)
    if remote_vectors is not None:
        return remote_vectors
    local_model = _get_local_model()
    if local_model is not None:
        return list(local_model.encode(batch))
    st.error("❌ Embedding 服务不可用:API 调用失败且本地模型未安装。请检查 API Key 配置。")
    return []
def encode_query(text):
    """Embed a single query string.

    Returns:
        The query vector, or ``None`` when no embedding backend is available.
        ``encode_texts`` returns ``[]`` in that case, which previously caused
        an ``IndexError`` here.
    """
    vecs = encode_texts([text])
    if not vecs:
        return None
    return vecs[0]
# =========================
# 4. Supabase 索引管理(替代本地文件)
# =========================
def _save_chunks_to_db(scope, chunks, vectors, source_file):
    """Bulk-insert new chunks and their embeddings into the ``documents`` table.

    Chunks and vectors are paired positionally. Every row carries the same
    ``scope`` and ``source_file``.
    """
    def _as_list(vec):
        # Vectors with a .tolist() method use it; anything else goes through list().
        return vec.tolist() if hasattr(vec, 'tolist') else list(vec)

    rows = [
        {
            "scope": scope,
            "source_file": source_file,
            "content": text,
            "embedding": _as_list(vec),
        }
        for text, vec in zip(chunks, vectors)
    ]
    # Insert in slices of at most 500 rows per request.
    slice_size = 500
    for start in range(0, len(rows), slice_size):
        page = rows[start:start + slice_size]
        _sb().table("documents").insert(page).execute()
def _delete_chunks_by_file(scope, filename):
    """Delete every ``documents`` row matching both ``scope`` and ``source_file``."""
    deletion = _sb().table("documents").delete()
    deletion = deletion.eq("scope", scope).eq("source_file", filename)
    deletion.execute()
def _clear_all_chunks(scope):
    """Delete every ``documents`` row that belongs to ``scope``."""
    deletion = _sb().table("documents").delete().eq("scope", scope)
    deletion.execute()
def _count_chunks(scope):
    """Return the number of ``documents`` rows stored for ``scope``.

    Returns ``0`` when the query fails. The failure is logged so that an
    unreachable database can be told apart from an empty library.
    """
    try:
        resp = _sb().table("documents").select("id", count="exact").eq("scope", scope).execute()
        return resp.count or 0
    except Exception as e:
        logger.warning(f"统计切片数量失败 [scope={scope}]: {e}")
        return 0
# --- 原始文件管理(Supabase Storage + uploaded_files 表)---
def _safe_storage_path(scope, filename):
    """Build an ASCII-only Storage object path of the form ``<scope>/<random><ext>``.

    The file name is replaced by 16 random hex characters and only a plain
    alphanumeric extension is kept (lower-cased). ``scope`` is used verbatim
    when it consists of ASCII letters, digits, ``-`` and ``_``. Any other scope
    (e.g. a non-ASCII user name, or one containing ``/``) is replaced by a
    stable hash-derived token, so the whole path stays ASCII and has exactly
    one directory level.
    """
    def _is_plain_ascii(part):
        return bool(part) and part.isascii() and all(ch.isalnum() or ch in "-_" for ch in part)

    ext = os.path.splitext(filename)[1].lower()  # keep the extension when it is safe
    if not _is_plain_ascii(ext[1:]):
        ext = ""
    if _is_plain_ascii(scope):
        safe_scope = scope
    else:
        scope_digest = hashlib.sha256(str(scope).encode("utf-8")).hexdigest()
        safe_scope = "u_" + scope_digest[:16]
    safe_name = uuid.uuid4().hex[:16] + ext
    return f"{safe_scope}/{safe_name}"
def _save_uploaded_file_to_storage(scope, uploaded_file):
    """Upload the raw file to Supabase Storage and record its metadata.

    A file with the same ``scope`` and name keeps its existing storage path.
    Otherwise a fresh safe path is generated. The metadata row keeps the
    original file name.
    """
    uploaded_file.seek(0)
    payload = uploaded_file.read()

    # Reuse the storage path of an existing record with the same name, if any.
    try:
        lookup = (
            _sb().table("uploaded_files").select("storage_path")
            .eq("scope", scope).eq("filename", uploaded_file.name).execute()
        )
        storage_path = (
            lookup.data[0]["storage_path"]
            if lookup.data
            else _safe_storage_path(scope, uploaded_file.name)
        )
    except Exception:
        storage_path = _safe_storage_path(scope, uploaded_file.name)

    # Upload with overwrite. If that fails, remove the object (best effort)
    # and upload again without the upsert option.
    try:
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path, payload,
            file_options={"content-type": "application/octet-stream", "upsert": "true"}
        )
    except Exception as e:
        logger.warning(f"Storage upload fallback: {e}")
        try:
            _sb().storage.from_(STORAGE_BUCKET).remove([storage_path])
        except Exception:
            pass
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path, payload,
            file_options={"content-type": "application/octet-stream"}
        )

    # Record the metadata row.
    metadata = {
        "scope": scope,
        "filename": uploaded_file.name,
        "file_size": len(payload),
        "storage_path": storage_path,
    }
    _sb().table("uploaded_files").upsert(metadata, on_conflict="scope,filename").execute()
    _list_uploaded_files_db.clear()  # invalidate the cached file list
@st.cache_data(ttl=15)  # cached for 15 s
def _list_uploaded_files_db(scope):
    """List the files uploaded for ``scope``, ordered by file name.

    Returns:
        ``[(filename, size_str, storage_path), ...]``, or ``[]`` when the query
        fails (the failure is logged).
    """
    def _pretty_size(num_bytes):
        if num_bytes < 1024:
            return f"{num_bytes}B"
        if num_bytes < 1048576:
            return f"{num_bytes / 1024:.1f}KB"
        return f"{num_bytes / 1048576:.1f}MB"

    try:
        query = _sb().table("uploaded_files").select(
            "filename, file_size, storage_path"
        )
        resp = query.eq("scope", scope).order("filename").execute()
        return [
            (
                row["filename"],
                _pretty_size(row.get("file_size", 0) or 0),
                row.get("storage_path", ""),
            )
            for row in resp.data
        ]
    except Exception as e:
        logger.error(f"列出文件失败 [scope={scope}]: {e}")
        return []
def _delete_uploaded_file_from_storage(scope, filename):
    """Remove one file from Storage and delete its ``uploaded_files`` row.

    The two steps are independent: a failure in either one is logged and does
    not prevent the other.
    """
    # Step 1: look up the real storage path and remove the object.
    try:
        lookup = (
            _sb().table("uploaded_files").select("storage_path")
            .eq("scope", scope).eq("filename", filename).execute()
        )
        if lookup.data:
            object_path = lookup.data[0]["storage_path"]
            _sb().storage.from_(STORAGE_BUCKET).remove([object_path])
    except Exception as e:
        logger.warning(f"Storage 删除失败: {e}")

    # Step 2: delete the metadata row and refresh the cached listing.
    try:
        row_filter = _sb().table("uploaded_files").delete().eq("scope", scope).eq("filename", filename)
        row_filter.execute()
        _list_uploaded_files_db.clear()  # invalidate the cached file list
    except Exception as e:
        logger.warning(f"uploaded_files 记录删除失败: {e}")
def _clear_uploaded_files_storage(scope):
    """Remove every uploaded file of ``scope`` from Storage and from ``uploaded_files``.

    Any failure is logged as a warning and aborts the remaining steps.
    """
    try:
        listing = _sb().table("uploaded_files").select("storage_path").eq("scope", scope).execute()
        object_paths = []
        for row in listing.data:
            object_paths.append(row["storage_path"])
        if object_paths:
            _sb().storage.from_(STORAGE_BUCKET).remove(object_paths)
        _sb().table("uploaded_files").delete().eq("scope", scope).execute()
        _list_uploaded_files_db.clear()  # invalidate the cached file list
    except Exception as e:
        logger.warning(f"清空文件失败 [scope={scope}]: {e}")
PUBLIC_SCOPE = "public"
PRIVATE_SCOPE = CURRENT_USER # 私有库 scope = 用户名
# --- 定时同步:检测其他用户对文档库的修改 ---
_SYNC_INTERVAL = 120 # 每 120 秒检查一次
def _check_and_sync():
    """Poll the chunk counts of both libraries and log any change.

    Runs at most once every ``_SYNC_INTERVAL`` seconds per session. The latest
    counts are stored under ``_sync_count_public`` / ``_sync_count_private``.
    """
    now = time.time()
    if now - st.session_state.get("_sync_last_check", 0) < _SYNC_INTERVAL:
        return
    st.session_state["_sync_last_check"] = now

    tracked = ((PUBLIC_SCOPE, "public"), (PRIVATE_SCOPE, "private"))
    for scope, label in tracked:
        state_key = f"_sync_count_{label}"
        latest = _count_chunks(scope)
        previous = st.session_state.get(state_key, -1)
        # -1 means "no previous observation", which is not reported as a change.
        changed = previous >= 0 and latest != previous
        if changed:
            logger.info(f"[SYNC] {label} 库变更: {previous} -> {latest}")
        st.session_state[state_key] = latest
_check_and_sync()
# --- 登录用户自动模式切换:根据知识库内容决定默认模式 ---
def _has_any_kb_content():
    """Tell whether the public or the private library contains any chunk.

    Counts already stored in the session are reused. Missing ones are queried
    and stored.
    """
    total = 0
    sources = (
        ("_sync_count_public", PUBLIC_SCOPE),
        ("_sync_count_private", PRIVATE_SCOPE),
    )
    for state_key, scope in sources:
        count = st.session_state.get(state_key)
        if count is None:
            count = _count_chunks(scope)
            st.session_state[state_key] = count
        total += count
    return total > 0
if not IS_GUEST and "sel_web" not in st.session_state:
# 首次加载时,有任何知识库内容则默认知识库模式,否则联网模式
st.session_state["sel_web"] = not _has_any_kb_content()
# =========================
# 5. 缓存 LLM 客户端
# =========================
@st.cache_resource
def get_or_client():
    """Return the cached OpenAI-compatible client pointed at OpenRouter."""
    from openai import OpenAI

    client = OpenAI(api_key=OR_KEY, base_url="https://openrouter.ai/api/v1")
    return client
@st.cache_resource
def get_ds_client():
    """Return the cached OpenAI-compatible client pointed at DeepSeek."""
    from openai import OpenAI

    client = OpenAI(api_key=DS_API_KEY, base_url="https://api.deepseek.com")
    return client
@st.cache_resource
def get_baidu_client():
    """Return the cached OpenAI-compatible client pointed at Baidu Qianfan."""
    from openai import OpenAI

    settings = {
        "api_key": BAIDU_TOKEN,
        "base_url": "https://qianfan.baidubce.com/v2",
        "default_headers": {"appid": BAIDU_APP_ID},
    }
    return OpenAI(**settings)
# =========================
# 6. 实用功能函数
# =========================
_text_splitter_cache = None


def _get_text_splitter():
    """Return the shared text splitter (chunk size 400, overlap 50), created on first use."""
    global _text_splitter_cache
    if _text_splitter_cache is not None:
        return _text_splitter_cache
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    _text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    return _text_splitter_cache
SYSTEM_PROMPT = (
"你是一个专业的知识问答助手。请基于提供的参考资料回答用户问题。"
"如果资料中没有相关信息,请诚实说明。回答要准确、有条理、简洁。"
"不要编造不在资料中的信息。"
)
SYSTEM_PROMPT_WEB = (
"你是一个智能AI助手,擅长结合互联网搜索结果回答用户问题。"
"请根据搜索结果提供准确、有条理的回答。"
"如果搜索结果不足以回答问题,请结合你自身的知识进行补充。"
"回答要简洁实用,注明信息来源。"
)
SYSTEM_PROMPT_DIRECT = (
"你是一个智能AI助手。请直接回答用户的问题。"
"回答要准确、有条理、简洁实用。"
"可以充分发挥你的知识和能力来帮助用户。"
)
def web_search(query):
    """Run a Tavily web search and return a text digest of the top results.

    The current year is prepended to the query. Up to three results are
    requested, each result's content is cut to 700 characters, and the joined
    digest is cut to 2500 characters.

    Returns:
        The digest, a notice when no search key is configured, or an error
        message when the search fails.
    """
    if not TAVILY_KEY:
        return "⚠️ 未配置搜索 Key"
    from tavily import TavilyClient

    client = TavilyClient(api_key=TAVILY_KEY)
    year = datetime.now().year
    try:
        response = client.search(
            query=f"{year}{query}",
            search_depth="advanced",
            max_results=3,
        )
        snippets = []
        for hit in response["results"]:
            snippets.append(f"来源: {hit.get('url')}\n内容: {hit.get('content', '')[:700]}")
        digest = "\n\n".join(snippets)
        return digest[:2500]
    except Exception as e:
        logger.error(f"联网搜索异常: {e}")
        return f"联网搜索异常:{str(e)}"
def estimate_tokens(text):
    """Roughly estimate how many tokens *text* occupies.

    Characters in the range U+4E00..U+9FFF are weighted 1.5 tokens each and
    every other character 0.4; the weighted sum is truncated to an int.
    Falsy input (``None`` or ``""``) yields 0.
    """
    if not text:
        return 0
    cjk_chars = 0
    other_chars = 0
    for ch in text:
        if "\u4e00" <= ch <= "\u9fff":
            cjk_chars += 1
        else:
            other_chars += 1
    return int(cjk_chars * 1.5 + other_chars * 0.4)
def _decode_text_bytes(raw):
    """Decode the raw bytes of a plain-text file into ``str``.

    Strict UTF-8 is tried first (``utf-8-sig`` also strips a leading BOM),
    then GB18030, which covers GBK/GB2312 files. If both fail, the bytes are
    decoded as UTF-8 with undecodable bytes dropped, as before.
    """
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="ignore")


def extract_text(file):
    """Extract plain text from an uploaded ``.txt``, ``.pdf`` or ``.docx`` file.

    Args:
        file: Binary file-like object exposing ``name``, ``seek`` and ``read``.

    Returns:
        The extracted text. An empty string is returned for unsupported
        extensions and when parsing fails; failures are logged and shown
        with ``st.error`` rather than raised.
    """
    fname = file.name.lower()
    text = ""
    try:
        # Always parse from the start, whatever the caller read before.
        file.seek(0)
        if fname.endswith(".txt"):
            text = _decode_text_bytes(file.read())
        elif fname.endswith(".pdf"):
            import io

            pdf_bytes = file.read()
            if len(pdf_bytes) < 100:
                raise ValueError("PDF 文件过小,可能已损坏")
            if not pdf_bytes.startswith(b"%PDF-"):
                raise ValueError("不是有效的 PDF 文件(缺少 %PDF- 头)")
            pages_text = []
            import pdfplumber

            with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
                for i, page in enumerate(pdf.pages):
                    # One broken page must not discard the whole document.
                    try:
                        pages_text.append(page.extract_text() or "")
                    except Exception as page_err:
                        logger.warning(f"PDF 第{i+1}页解析失败: {page_err}")
                        pages_text.append("")
            text = "\n".join(pages_text)
        elif fname.endswith(".docx"):
            from docx import Document

            doc = Document(file)
            text = "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        logger.error(f"文件解析失败 [{file.name}]: {e}", exc_info=True)
        st.error(f"解析失败: {e}")
    return text
def process_upload(uploaded_files, target_prefix, scope):
    """Ingest uploaded files: store, parse, split, embed and persist them.

    Each file is saved to storage, converted to text and split into chunks.
    A file that is empty or fails is skipped with a warning. All chunks are
    embedded in batches of 64 and written per source file with
    ``_save_chunks_to_db``. On success the uploader key version and the
    cached chunk count are refreshed, the file-list cache is cleared, the UI
    is switched to knowledge-base mode and ``st.rerun()`` is called.

    Args:
        uploaded_files: Uploaded file objects; a falsy value is a no-op.
        target_prefix: Suffix of the session-state keys to update
            (``_upload_ver_<prefix>``, ``_sync_count_<prefix>``).
        scope: Knowledge-base scope the chunks are stored under.

    Returns:
        False whenever the function returns normally.
    """
    if not uploaded_files:
        return False
    try:
        pieces = []
        origins = []
        with st.spinner("正在自动解析文档并更新索引..."):
            for upload in uploaded_files:
                try:
                    # Keep the original file in Supabase Storage.
                    upload.seek(0)
                    _save_uploaded_file_to_storage(scope, upload)
                    # Parse the text from the start of the stream.
                    upload.seek(0)
                    body = extract_text(upload)
                    if not body.strip():
                        st.warning(f"文件 {upload.name} 内容为空,已跳过。")
                        continue
                    file_pieces = _get_text_splitter().split_text(body)
                    pieces.extend(file_pieces)
                    origins.extend([upload.name] * len(file_pieces))
                except Exception as file_err:
                    logger.error(f"文件 {upload.name} 处理失败: {file_err}", exc_info=True)
                    st.warning(f"⚠️ 文件 {upload.name} 处理失败:{str(file_err)[:100]},已跳过。")
            if not pieces:
                st.error("解析失败,未发现有效文字内容。")
            else:
                # Embed in fixed-size batches.
                batch_size = 64
                vectors = []
                for start in range(0, len(pieces), batch_size):
                    vectors.extend(encode_texts(pieces[start:start + batch_size]))
                # Group chunks and vectors by source file before writing.
                per_file = {}
                for piece, vector, origin in zip(pieces, vectors, origins):
                    bucket = per_file.get(origin)
                    if bucket is None:
                        bucket = per_file[origin] = ([], [])
                    bucket[0].append(piece)
                    bucket[1].append(vector)
                for origin, (file_pieces, file_vectors) in per_file.items():
                    _save_chunks_to_db(scope, file_pieces, file_vectors, origin)
                # Bump the uploader widget key so the widget is recreated.
                version_key = f"_upload_ver_{target_prefix}"
                st.session_state[version_key] = st.session_state.get(version_key, 0) + 1
                # Refresh the cached chunk count right away.
                st.session_state[f"_sync_count_{target_prefix}"] = _count_chunks(scope)
                # Drop the cached file listing.
                _list_uploaded_files_db.clear()
                # A successful upload switches to knowledge-base mode.
                st.session_state["sel_web"] = False
                st.toast(f"✅ 导入 {len(pieces)} 个切片")
                st.rerun()
    except Exception as e:
        logger.error(f"上传处理异常: {e}", exc_info=True)
        st.error(f"❌ 上传处理出错:{str(e)[:200]}")
    return False
# =========================
# 6.5 聊天记录持久化(Supabase chat_history 表)
# =========================
import threading
def _async_run(fn, *args):
    """Start ``fn(*args)`` on a daemon thread and return without waiting for it."""
    threading.Thread(target=fn, args=args, daemon=True).start()
def _save_chat_message_sync(username, role, content, meta=""):
    """Insert one chat message into the ``chat_history`` table (blocking).

    A falsy ``meta`` is stored as an empty string. Failures are logged as
    warnings and never raised.
    """
    record = {
        "username": username,
        "role": role,
        "content": content,
        "meta": meta or "",
    }
    try:
        _sb().table("chat_history").insert(record).execute()
    except Exception as e:
        logger.warning(f"保存聊天记录失败: {e}")
def _save_chat_message(username, role, content, meta=""):
    """Persist one chat message on a background daemon thread so the UI is not blocked."""
    _async_run(_save_chat_message_sync, username, role, content, meta)
@st.cache_data(ttl=10)  # cached for 10 seconds to avoid repeated database queries
def _load_chat_history(username, limit=50):
    """Load the most recent chat messages of *username*, oldest first.

    The newest ``limit`` rows are fetched (ordered by ``created_at``
    descending) and then reversed.

    Returns:
        A list of dicts with the keys ``role``, ``content``, ``meta`` and
        ``created_at``. An empty list is returned when there are no rows or
        when the query fails (the failure is logged as a warning).
    """
    try:
        request = (
            _sb()
            .table("chat_history")
            .select("role, content, meta, created_at")
            .eq("username", username)
            .order("created_at", desc=True)
            .limit(limit)
        )
        resp = request.execute()
        if not resp.data:
            return []
        history = []
        for row in reversed(resp.data):
            history.append({
                "role": row["role"],
                "content": row["content"],
                "meta": row.get("meta", ""),
                "created_at": row.get("created_at", ""),
            })
        return history
    except Exception as e:
        logger.warning(f"加载聊天记录失败: {e}")
        return []
def _clear_chat_history_db(username):
    """Delete every ``chat_history`` row that belongs to *username*.

    After a successful delete the ``st.cache_data`` cache of
    ``_load_chat_history`` is cleared. Without this, a history view opened
    within the cache's 10-second TTL would still list the deleted messages.
    Failures are logged as warnings and never raised.
    """
    try:
        _sb().table("chat_history").delete().eq("username", username).execute()
        # Invalidate only after the delete, so stale rows cannot be re-cached.
        _load_chat_history.clear()
    except Exception as e:
        logger.warning(f"清空聊天记录失败: {e}")
# =========================
# 7. 侧边栏 UI & 逻辑
# =========================
# 游客模式模型(免费,通过 OpenRouter)
# Guest-mode models (free tier, served through OpenRouter):
# display name -> model id.
_GUEST_MODELS = {
    "⭐ Step-3.5 (首选)": "stepfun/step-3.5-flash:free",
    "🌐 OR-Auto (避堵)": "openrouter/free",
    "🧠 GLM-4.5 (推理)": "z-ai/glm-4.5-air:free",
    "🔥 Gemma-3-27B (旗舰)": "google/gemma-3-27b-it:free",
    "🐋 Nemotron (120B)": "nvidia/nemotron-3-super-120b-a12b:free",
    "⚡ Trinity-L (极速)": "arcee-ai/trinity-large-preview:free",
    "💭 Liquid-Think (思维链)": "liquid/lfm-2.5-1.2b-thinking:free",
    "🏎️ Liquid-Ins (1.0s)": "liquid/lfm-2.5-1.2b-instruct:free",
    "⚖️ Gemma-3-12B (平衡)": "google/gemma-3-12b-it:free",
    "💎 Gemma-3n-e4b (稳)": "google/gemma-3n-e4b-it:free",
    "🤖 Nemotron-Nano (混)": "nvidia/nemotron-3-nano-30b-a3b:free",
    "📉 Trinity-M (1.8s)": "arcee-ai/trinity-mini:free",
    "🍃 Nemotron-9B": "nvidia/nemotron-nano-9b-v2:free",
    "🪶 Gemma-3-4B": "google/gemma-3-4b-it:free",
    "🫧 Gemma-3n-e2b": "google/gemma-3n-e2b-it:free",
    "📷 Nemotron-VL": "nvidia/nemotron-nano-12b-v2-vl:free",
}

# Logged-in models (paid, served through the Baidu Qianfan API):
# display name -> model id.
_USER_MODELS = {
    "🧠 文心5.0思维链": "ernie-5.0-thinking-latest",
    "🔥 文心5.0": "ernie-5.0",
    "⚡ 文心4.5-Turbo": "ernie-4.5-turbo-128k",
    "🐋 DeepSeek-V3.2思维": "deepseek-v3.2-think",
    "💎 DeepSeek-V3.2": "deepseek-v3.2",
    "🔮 DeepSeek-R1": "deepseek-r1",
    "🌟 千问3.5-397B": "qwen3.5-397b-a17b",
    "💭 千问3-235B思维": "qwen3-235b-a22b-thinking-2507",
    "📊 千问3-32B": "qwen3-32b",
    "🎯 GLM-5": "glm-5",
    "🌙 Kimi-K2.5": "kimi-k2.5",
    "✨ MiniMax-M2.5": "minimax-m2.5",
    "🛡️ DeepSeek官方": "deepseek-chat",
    "🏢 文心3.5": "ernie-3.5-8k",
}

# Pick the model table that matches the visitor type; both names refer to the
# same dict.
model_mapping = _active_models = _GUEST_MODELS if IS_GUEST else _USER_MODELS
# Sidebar: retrieval settings, public/private knowledge-base management,
# password change, admin panels and chat-history tools.
# Changes in this version:
#   * the user list renders "name (created_at)" with balanced parentheses;
#   * clearing the private library also clears the cached file listing,
#     matching the public-library branch;
#   * deleting a user also deletes that user's chat history, so a later
#     account with the same name cannot inherit it.
with st.sidebar:
    # --- Retrieval parameters (logged-in users only) ---
    if not IS_GUEST:
        with st.expander("⚙️ 检索参数"):
            c1, c2 = st.columns(2)
            with c1:
                ui_top_k = st.slider("Top-K", min_value=1, max_value=15, value=5, key="sel_topk")
            with c2:
                ui_threshold = st.slider("阈值", min_value=0.0, max_value=1.0, value=0.25, step=0.05, key="sel_threshold")

    # --- The panels below are shown to logged-in users only ---
    if not IS_GUEST:
        pub_chunk_count = st.session_state.get("_sync_count_public")
        if pub_chunk_count is None:
            pub_chunk_count = _count_chunks(PUBLIC_SCOPE)
            st.session_state["_sync_count_public"] = pub_chunk_count
        with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
            st.caption("所有人可搜索")
            # File listing
            pub_file_list = _list_uploaded_files_db(PUBLIC_SCOPE)
            if pub_file_list:
                # Handle a pending delete click before rendering the rows, so
                # the list is not drawn twice.
                delete_target = None
                if IS_ADMIN:
                    for fname, _, _ in pub_file_list:
                        if st.session_state.get(f"delpub_{fname}"):
                            delete_target = fname
                            break
                if delete_target:
                    _delete_chunks_by_file(PUBLIC_SCOPE, delete_target)
                    _delete_uploaded_file_from_storage(PUBLIC_SCOPE, delete_target)
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_public"] = _count_chunks(PUBLIC_SCOPE)
                    # With every knowledge base empty, fall back to web mode.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast(f"已删除 {delete_target}")
                    st.rerun()
                st.caption(f"📎 已上传 {len(pub_file_list)} 个文件:")
                for fname, size_str, _ in pub_file_list:
                    if IS_ADMIN:
                        col_name, col_del = st.columns([4, 1])
                        col_name.text(f"📄 {fname} ({size_str})")
                        col_del.button("🗑", key=f"delpub_{fname}")
                    else:
                        st.text(f"📄 {fname} ({size_str})")
            if IS_ADMIN:
                pub_upload_key = f"upload_public_{st.session_state.get('_upload_ver_public', 0)}"

                # on_change callback: flag that new files were just selected.
                def _on_pub_upload_change():
                    st.session_state["_pending_upload_public"] = True

                pub_files = st.file_uploader(
                    "上传到公共库",
                    type=["txt", "pdf", "docx"],
                    accept_multiple_files=True,
                    label_visibility="collapsed",
                    key=pub_upload_key,
                    on_change=_on_pub_upload_change,
                )
                # Process only uploads flagged by the on_change callback.
                if pub_files and st.session_state.pop("_pending_upload_public", False):
                    process_upload(pub_files, "public", PUBLIC_SCOPE)
                if pub_chunk_count > 0 and len(pub_file_list) >= 2:
                    if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
                        _clear_all_chunks(PUBLIC_SCOPE)
                        _clear_uploaded_files_storage(PUBLIC_SCOPE)
                        _list_uploaded_files_db.clear()
                        st.session_state["_sync_count_public"] = 0
                        # With every knowledge base empty, fall back to web mode.
                        if not _has_any_kb_content():
                            st.session_state["sel_web"] = True
                        st.toast("公共知识库已清空")
                        st.rerun()
            else:
                st.caption("*仅管理员可维护公共库*")

        # --- Private knowledge base ---
        priv_chunk_count = st.session_state.get("_sync_count_private")
        if priv_chunk_count is None:
            priv_chunk_count = _count_chunks(PRIVATE_SCOPE)
            st.session_state["_sync_count_private"] = priv_chunk_count
        with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
            st.caption(f"用户:{CURRENT_USER},仅自己可见")
            priv_file_list = _list_uploaded_files_db(PRIVATE_SCOPE)
            if priv_file_list:
                # Handle a pending delete click before rendering the rows.
                delete_target = None
                for fname, _, _ in priv_file_list:
                    if st.session_state.get(f"delpriv_{fname}"):
                        delete_target = fname
                        break
                if delete_target:
                    _delete_chunks_by_file(PRIVATE_SCOPE, delete_target)
                    _delete_uploaded_file_from_storage(PRIVATE_SCOPE, delete_target)
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_private"] = _count_chunks(PRIVATE_SCOPE)
                    # With every knowledge base empty, fall back to web mode.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast(f"已删除 {delete_target}")
                    st.rerun()
                st.caption(f"📎 已上传 {len(priv_file_list)} 个文件:")
                for fname, size_str, _ in priv_file_list:
                    col_name, col_del = st.columns([4, 1])
                    col_name.text(f"📄 {fname} ({size_str})")
                    col_del.button("🗑", key=f"delpriv_{fname}")
            priv_upload_key = f"upload_private_{st.session_state.get('_upload_ver_private', 0)}"

            # on_change callback: flag that new files were just selected.
            def _on_priv_upload_change():
                st.session_state["_pending_upload_private"] = True

            priv_files = st.file_uploader(
                "上传到私有库",
                type=["txt", "pdf", "docx"],
                accept_multiple_files=True,
                label_visibility="collapsed",
                key=priv_upload_key,
                on_change=_on_priv_upload_change,
            )
            # Process only uploads flagged by the on_change callback.
            if priv_files and st.session_state.pop("_pending_upload_private", False):
                process_upload(priv_files, "private", PRIVATE_SCOPE)
            if priv_chunk_count > 0 and len(priv_file_list) >= 2:
                if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
                    _clear_all_chunks(PRIVATE_SCOPE)
                    _clear_uploaded_files_storage(PRIVATE_SCOPE)
                    # Fix: drop the cached listing so the removed files are
                    # not shown again after the rerun.
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_private"] = 0
                    # With every knowledge base empty, fall back to web mode.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast("私有知识库已清空")
                    st.rerun()

        # --- Change password ---
        with st.expander("🔐 修改密码"):
            with st.form("change_password_form", clear_on_submit=False):
                old_pass = st.text_input("当前密码", type="password", key="self_old_pass")
                new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
                new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
                submitted = st.form_submit_button("✅ 确认修改", use_container_width=True)
                if submitted:
                    ok, _, _ = verify_user(CURRENT_USER, old_pass)
                    if not ok:
                        st.error("当前密码错误")
                    elif len(new_pass1) < 4:
                        st.error("新密码至少 4 个字符")
                    elif new_pass1 != new_pass2:
                        st.error("两次新密码不一致")
                    else:
                        _save_user(CURRENT_USER, _hash_password(new_pass1),
                                   st.session_state.current_role)
                        st.success("密码修改成功")

        # --- Admin panels ---
        if IS_ADMIN:
            with st.expander("👥 用户管理"):
                all_users = _load_users()
                user_list = [(u, info) for u, info in all_users.items() if isinstance(info, dict)]
                st.caption(f"共 **{len(user_list)}** 个用户")
                for uname, uinfo in user_list:
                    role_tag = "👑" if uinfo.get("role") == "admin" else "👤"
                    created = uinfo.get("created_at", "未知")
                    # Fix: the opening parenthesis was missing.
                    st.text(f"{role_tag} {uname} ({created})")
                deletable = [u for u, _ in user_list if u != CURRENT_USER]
                if deletable:
                    del_target = st.selectbox("选择要删除的用户", deletable, key="del_user_select")
                    if st.button("❌ 删除该用户", key="btn_del_user"):
                        _delete_user_db(del_target)
                        # Remove the user's private library ...
                        _clear_all_chunks(del_target)
                        _clear_uploaded_files_storage(del_target)
                        # ... and the chat history, so a future account with
                        # the same name cannot see it.
                        _clear_chat_history_db(del_target)
                        st.toast(f"用户 {del_target} 已删除")
                        st.rerun()
                resetable = [u for u, _ in user_list if u != CURRENT_USER]
                if resetable:
                    reset_target = st.selectbox("选择要重置密码的用户", resetable, key="reset_user_select")
                    with st.form("reset_password_form", clear_on_submit=False):
                        new_pass = st.text_input("新密码", type="password", key="reset_new_pass")
                        reset_submitted = st.form_submit_button("🔄 重置密码", use_container_width=True)
                        if reset_submitted:
                            if len(new_pass) < 4:
                                st.error("密码至少 4 个字符")
                            else:
                                target_role = all_users[reset_target].get("role", "user")
                                _save_user(reset_target, _hash_password(new_pass), target_role)
                                st.toast(f"用户 {reset_target} 密码已重置")
            with st.expander("📩 邀请码管理"):
                current_code = _get_invite_code()
                st.text(f"当前邀请码:{current_code if current_code else '未设置'}")
                with st.form("invite_code_form", clear_on_submit=False):
                    new_code = st.text_input("新邀请码", key="new_invite_code")
                    code_submitted = st.form_submit_button("✏️ 更新邀请码", use_container_width=True)
                    if code_submitted:
                        if new_code.strip():
                            _set_invite_code(new_code.strip())
                            st.toast("邀请码已更新")
                            st.rerun()
                        else:
                            st.error("邀请码不能为空")
            with st.expander("🛠️ 数据库概览"):
                st.caption("Supabase 数据统计")
                try:
                    pub_cnt = _count_chunks(PUBLIC_SCOPE)
                    st.text(f"📚 公共库切片数: {pub_cnt}")
                    # Per-user private library sizes (the scope is the user name).
                    for uname, _ in user_list:
                        cnt = _count_chunks(uname)
                        if cnt > 0:
                            st.text(f"🔒 {uname} 私有库: {cnt} 切片")
                except Exception as e:
                    st.warning(f"统计失败: {e}")
                st.divider()
                st.caption("📋 用户列表")
                # Show only the first 8 characters of each password hash.
                display_users = {}
                for k, v in all_users.items():
                    if isinstance(v, dict) and "password_hash" in v:
                        v_copy = dict(v)
                        v_copy["password_hash"] = v_copy["password_hash"][:8] + "..."
                        display_users[k] = v_copy
                    else:
                        display_users[k] = v
                st.json(display_users)

    # --- Chat history management (logged-in users only) ---
    if not IS_GUEST:
        with st.expander("💬 聊天记录"):
            hist_tab_new, hist_tab_history = st.tabs(["当前对话", "历史记录"])
            with hist_tab_new:
                st.caption("清空当前对话(数据库记录保留)")
                if st.button("🧹 清空当前对话", use_container_width=True, type="secondary", key="btn_clear_chat"):
                    st.session_state.messages = []
                    st.session_state["_chat_cleared"] = True  # cleared on purpose
                    st.rerun()
                st.caption("清空所有历史记录(不可恢复)")
                if st.button("🗑️ 清空全部记录", use_container_width=True, type="secondary", key="btn_clear_all_hist"):
                    _async_run(_clear_chat_history_db, CURRENT_USER)
                    st.session_state.messages = []
                    st.session_state["_chat_cleared"] = True  # cleared on purpose
                    st.toast("所有聊天记录已清空")
                    st.rerun()
            with hist_tab_history:
                if st.button("🔄 加载历史记录", use_container_width=True, key="btn_load_hist"):
                    st.session_state["_show_history"] = True
                if st.session_state.get("_show_history"):
                    history = _load_chat_history(CURRENT_USER, limit=100)
                    if not history:
                        st.info("暂无历史记录")
                    else:
                        st.caption(f"共 {len(history)} 条记录")
                        for msg in history:
                            ts = msg.get("created_at", "")[:16].replace("T", " ")
                            icon = "🧑" if msg["role"] == "user" else "🤖"
                            preview = msg["content"][:80].replace("\n", " ")
                            st.text(f"{icon} [{ts}] {preview}{'...' if len(msg['content']) > 80 else ''}")
# =========================
# 8. 核心搜索逻辑(合并公共库 + 私有库)
# =========================
def search_local(query, top_k, threshold):
    """Search the public and private knowledge bases through the ``match_documents`` RPC.

    Args:
        query: User question; embedded with ``encode_query``.
        top_k: Maximum number of chunks to return.
        threshold: Similarity threshold passed to the RPC.

    Returns:
        ``(docs, files)``: the matched chunk texts in the order returned by
        the RPC, and the distinct source-file names in first-seen order.
        Both are empty when nothing matches or when embedding or the query
        fails; the failure is logged, not raised.
    """
    scopes = [PUBLIC_SCOPE, PRIVATE_SCOPE]
    try:
        # Embedding sits inside the try so that an encoder failure degrades
        # to "no match" instead of crashing the caller.
        query_vec = encode_query(query)
        resp = _sb().rpc("match_documents", {
            "query_embedding": query_vec.tolist() if hasattr(query_vec, 'tolist') else list(query_vec),
            "match_scopes": scopes,
            "match_threshold": float(threshold),
            "match_count": int(top_k),
        }).execute()
        rows = resp.data or []
        docs = [row["content"] for row in rows]
        # dict.fromkeys de-duplicates while keeping a stable order (a set
        # does not); rows without a source file are left out so the names
        # can be joined safely later.
        files = list(dict.fromkeys(
            row["source_file"] for row in rows if row.get("source_file")
        ))
        return docs, files
    except Exception as e:
        logger.error(f"pgvector 搜索失败: {e}")
        return [], []
# =========================
# 9. LLM 回答逻辑
# =========================
def llm_answer(query, context_docs, selected_display_name, web_enabled, kb_mode=False, source_files=None):
    """Stream an answer for *query* from the selected model (generator).

    Prompt selection:
      * ``web_enabled``: web-search results form the context;
      * ``context_docs`` present: knowledge-base chunks form the context;
      * ``kb_mode`` with no chunks: direct answer, flagged as a fallback;
      * otherwise: direct answer.

    Logged-in users call the selected model once. Guests go through a retry
    queue: the selected model, free alternatives, then paid backups.

    Session state written: ``kb_fallback``, ``kb_source_files`` and
    ``last_meta``. ``last_meta`` is reset on every call and filled only when
    an answer completes.

    Args:
        query: The user question.
        context_docs: Knowledge-base chunks to ground the answer on.
        selected_display_name: Key into ``model_mapping``; an unknown name
            falls back to the first configured model.
        web_enabled: Whether web-search mode is active.
        kb_mode: Whether the caller is in knowledge-base mode.
        source_files: Names of the files the chunks came from.

    Yields:
        Text fragments of the answer, or a single error message.
    """
    curr_time = datetime.now().strftime("%Y-%m-%d %H:%M")
    st.session_state["kb_fallback"] = False  # reset the fallback flag
    st.session_state["kb_source_files"] = source_files or []  # files cited by this answer
    # Fix: clear the previous answer's meta line, so a failed or empty call
    # is not labelled with an earlier model and token count.
    st.session_state["last_meta"] = ""

    if web_enabled:
        # Web mode: search results only, the knowledge base is not used.
        search_res = web_search(query)
        all_context = f"【互联网搜索结果】:\n{search_res}"
        prompt_content = f"当前时间:{curr_time}\n\n{all_context[:6500]}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_WEB
    elif context_docs:
        # Knowledge-base mode with matching material.
        all_context = "【知识库资料】:\n" + "\n".join(context_docs)
        prompt_content = f"当前时间:{curr_time}\n\n参考资料:\n{all_context[:6500]}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT
    else:
        # No context at all. In knowledge-base mode this is a fallback and
        # is flagged so the meta line can say so.
        if kb_mode:
            st.session_state["kb_fallback"] = True
        prompt_content = f"当前时间:{curr_time}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_DIRECT

    input_tokens = estimate_tokens(prompt_content)

    # A stale selection (for example one kept from another model table)
    # must not raise KeyError; fall back to the first configured model.
    if selected_display_name not in model_mapping:
        logger.warning(f"未知模型 {selected_display_name},回退到默认模型")
        selected_display_name = next(iter(model_mapping))
    selected_id = model_mapping[selected_display_name]

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt_content},
    ]

    # Logged-in users: call the selected model directly, no retry queue.
    if not IS_GUEST:
        label = f"🔐 {selected_display_name}"
        logger.info(f"[{CURRENT_USER}] 登录用户调用: {label}")
        try:
            # Every logged-in model goes through Baidu Qianfan, except
            # deepseek-chat, which uses the official DeepSeek API. Only the
            # client that is needed is created.
            client = get_ds_client() if selected_id == "deepseek-chat" else get_baidu_client()
            response = client.chat.completions.create(
                model=selected_id,
                messages=messages,
                stream=True,
                timeout=60,  # logged-in users get a longer timeout
            )
            pieces = []
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    pieces.append(content)
                    yield content
            if pieces:
                full_text = "".join(pieces)
                fallback_hint = '<span class="meta-extra"><span class="meta-sep"> | </span>⚠️ 知识库无匹配,使用通用知识</span>' if st.session_state.get("kb_fallback") else ""
                cited_files = st.session_state.get("kb_source_files", [])
                files_hint = f'<span class="meta-extra"><span class="meta-sep"> | </span>📄 {", ".join(cited_files)}</span>' if cited_files else ""
                st.session_state["last_meta"] = (
                    f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens{files_hint}{fallback_hint}"
                )
            else:
                yield "❌ 模型返回空响应,请稍后重试。"
        except Exception as e:
            err_msg = str(e)
            logger.error(f"{label} 调用失败: {err_msg[:200]}")
            yield f"❌ 调用失败:{err_msg[:100]}"
        return

    # Guests: retry queue with fallbacks. Entries hold client *factories*,
    # so a missing key for a backup provider only disables that backup.
    retry_queue = [(get_or_client, selected_id, f"首选-{selected_display_name}")]
    if selected_id != "stepfun/step-3.5-flash:free":
        retry_queue.append((get_or_client, "stepfun/step-3.5-flash:free", "⚡ 快速备选-Step3.5"))
    if selected_id != "openrouter/free":
        retry_queue.append((get_or_client, "openrouter/free", "OR-Auto 免费避堵"))
    # Paid fallbacks for guests.
    paid_backups = [
        ("deepseek-chat", "🛡️ DeepSeek 官方", get_ds_client),
        ("ernie-3.5-8k", "🏢 百度文心", get_baidu_client),
    ]
    for p_id, p_label, p_factory in paid_backups:
        if selected_id != p_id:
            retry_queue.append((p_factory, p_id, f"💰 收费兜底-{p_label}"))

    for client_factory, m_id, label in retry_queue:
        logger.info(f"[{CURRENT_USER}] 尝试链路: {label}")
        pieces = []
        try:
            client = client_factory()
            extra_h = (
                {"HTTP-Referer": "https://streamlit.io", "X-Title": "RAG_v3"}
                if client_factory is get_or_client
                else None
            )
            response = client.chat.completions.create(
                model=m_id,
                messages=messages,
                stream=True,
                extra_headers=extra_h,
                timeout=25,
            )
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    pieces.append(content)
                    yield content
            if pieces:
                st.session_state["last_meta"] = (
                    f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(''.join(pieces))} Tokens"
                )
                return
        except Exception as e:
            err_msg = str(e)
            logger.warning(f"{label} 失败: {err_msg[:100]}")
            if pieces:
                # Fix: part of an answer is already on screen. Retrying would
                # append another model's full answer to it, so stop here.
                yield "\n\n⚠️ 回答中断,请重试。"
                return
            if "429" in err_msg:
                st.toast(f"{label} 拥堵,切换备选...", icon="⏳")
                time.sleep(1.5)
            continue
    yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
# =========================
# 10. 聊天渲染(使用 @st.fragment 避免整页刷新)
# =========================
# Initialise the in-session message list.
if "messages" not in st.session_state:
    # First load: restore the latest conversation from the database.
    # Guests start with an empty list.
    restored = []
    if not IS_GUEST:
        for m in _load_chat_history(CURRENT_USER, limit=50):
            entry = {"role": m["role"], "content": m["content"]}
            if m.get("meta"):
                entry["meta"] = m["meta"]
            restored.append(entry)
    st.session_state.messages = restored
elif st.session_state.get("_chat_cleared"):
    # The user cleared the chat on purpose: drop the marker and do not
    # restore anything from the database.
    del st.session_state["_chat_cleared"]
@st.fragment
def _chat_fragment():
    """Render the chat area as a fragment, so sending a message reruns only this part.

    Shows the welcome block with quick-question buttons while there are no
    messages, replays the stored messages, and for a new question streams
    the model answer, records it in the session and, for logged-in users,
    saves it to the database.
    """
    session = st.session_state
    model_name = session.get("sel_model", list(model_mapping.keys())[0])
    web_on = session.get("sel_web", True)  # web mode by default
    top_k = session.get("sel_topk", 5)
    threshold = session.get("sel_threshold", 0.25)

    # Welcome block: shown only while the conversation is empty.
    welcome_slot = st.empty()
    if not session.messages:
        with welcome_slot.container():
            st.markdown(
                """
<div class="welcome-marker" style="text-align:center; color:#888;">
<div style="font-size:48px; margin-bottom:12px;">🤖</div>
<div style="font-size:18px; font-weight:600; margin-bottom:8px;">欢迎使用智答 AI 助手</div>
<div style="font-size:14px; line-height:1.8; margin-bottom:20px;">
🌐 <b>联网模式</b> —— 大模型 + 网络搜索,实时回答<br>
📚 <b>知识库模式</b> —— 上传文档,基于私有知识回答
</div>
<div style="color:#999; font-size:13px; margin-bottom:12px;">💡 试试这些问题</div>
</div>
""",
                unsafe_allow_html=True,
            )
            # Quick questions live inside the fragment, so a click reruns
            # only the fragment.
            samples = [
                "今天有什么热点新闻",
                "用Python写快速排序",
                "解释RAG技术是什么",
            ]
            sample_cols = st.columns(3)
            for pos, sample in enumerate(samples):
                with sample_cols[pos]:
                    picked = st.button(f"💬 {sample}", key=f"quick_q{pos}", use_container_width=True, type="secondary")
                    if picked:
                        session["_quick_question"] = sample

    # Message container: no fixed height, scrolling is handled by outer CSS.
    chat_area = st.container()
    question = st.chat_input("问问智答AI助手", key="chat_input_v3")
    # A clicked quick question takes the place of typed input.
    if session.get("_quick_question"):
        question = session.pop("_quick_question")

    with chat_area:
        for pos, past in enumerate(session.messages):
            with st.chat_message(past["role"]):
                if past["role"] != "assistant":
                    st.markdown(past["content"])
                    continue
                if past.get("meta"):
                    meta_span = f'\n\n<span style="color:#999;font-size:12px;">{past["meta"]}</span>'
                else:
                    meta_span = ""
                buttons_html = _render_action_buttons(past["content"], f"hist_{pos}")
                st.markdown(
                    past["content"] + meta_span + buttons_html,
                    unsafe_allow_html=True,
                )
                _inject_action_js(past["content"], f"hist_{pos}")

        if question:
            welcome_slot.empty()  # drop the welcome block once a message is sent
            session.messages.append({"role": "user", "content": question})
            if not IS_GUEST:
                _save_chat_message(CURRENT_USER, "user", question)
            with st.chat_message("user"):
                st.markdown(question)
            with st.chat_message("assistant"):
                answer_slot = st.empty()
                # The scroll script sits inside the assistant block, so it
                # does not change the gap between question and answer.
                components.html(
                    """<script>
setTimeout(function(){
var c = window.parent.document.querySelector('.block-container');
if(c) c.scrollTop = c.scrollHeight;
}, 80);
</script>""",
                    height=0,
                )
                # Web mode (guest or logged-in): model + web search.
                # Guest without web: the model answers directly.
                # Logged-in without web: model + knowledge-base retrieval.
                if web_on:
                    answer_slot.markdown("*🌐 正在联网搜索...*")
                    relevant_docs, source_files = [], []
                elif IS_GUEST:
                    answer_slot.markdown("*🤔 正在思考...*")
                    relevant_docs, source_files = [], []
                else:
                    answer_slot.markdown("*🔍 正在搜索知识库...*")
                    relevant_docs, source_files = search_local(question, top_k, threshold)
                    answer_slot.markdown("*🤔 正在组织语言...*")
                # Logged-in and not in web mode means knowledge-base mode.
                kb_mode = not (IS_GUEST or web_on)
                try:
                    answer = answer_slot.write_stream(
                        llm_answer(question, relevant_docs, model_name, web_on, kb_mode=kb_mode, source_files=source_files)
                    )
                    meta_info = session.get("last_meta", "")
                    # Action buttons are keyed by the current message count.
                    new_id = f"new_{len(session.messages)}"
                    buttons_html = _render_action_buttons(answer, new_id)
                    if meta_info:
                        meta_span = f'\n\n<span style="color:#999;font-size:12px;">{meta_info}</span>'
                    else:
                        meta_span = ""
                    answer_slot.markdown(
                        answer + meta_span + buttons_html,
                        unsafe_allow_html=True,
                    )
                    _inject_action_js(answer, new_id)
                    session.messages.append(
                        {"role": "assistant", "content": answer, "meta": meta_info}
                    )
                    if not IS_GUEST:
                        _save_chat_message(CURRENT_USER, "assistant", answer, meta_info)
                except Exception as e:
                    logger.error(f"模型调用异常: {e}")
                    answer_slot.error(f"❌ 抱歉,连接模型时出错了: {str(e)}")
# =========================
# 11. 搜索框下方的工具按钮(使用 fragment 局部刷新)
# =========================
@st.fragment
def _toolbar_fragment():
    """Render the two-button toolbar: a model picker popover and the mode toggle.

    Picking a different model stores it in ``st.session_state["sel_model"]``.
    The toggle flips ``st.session_state["sel_web"]``. Both actions end with
    ``st.rerun()``.
    """
    active_name = st.session_state.get("sel_model", list(model_mapping.keys())[0])
    web_active = st.session_state.get("sel_web", True)
    # Only the part before the first "(" is shown on the popover button.
    short_name = active_name.split("(")[0].strip()
    model_col, mode_col = st.columns(2)
    with model_col:
        with st.popover(short_name, use_container_width=True):
            st.caption("选择模型")
            for name in list(_active_models.keys()):
                is_active = name == active_name
                clicked = st.button(
                    name,
                    key=f"mdl_{name}",
                    use_container_width=True,
                    type="primary" if is_active else "secondary",
                )
                if clicked and not is_active:
                    st.session_state["sel_model"] = name
                    st.rerun()
    with mode_col:
        if web_active:
            mode_label = "🌐 联网模式"
        elif IS_GUEST:
            mode_label = "💬 直接回答"
        else:
            mode_label = "📚 知识库"
        if st.button(mode_label, key="toggle_web", use_container_width=True):
            st.session_state["sel_web"] = not web_active
            st.rerun()
# Render the page body: the toolbar fragment first, then the chat fragment.
_toolbar_fragment()
_chat_fragment()