Spaces:
Sleeping
Sleeping
import os
# Tune Streamlit connection stability. These are set before `streamlit` is
# imported (presumably so its config picks them up at import time);
# setdefault keeps any value already present in the environment.
os.environ.setdefault("STREAMLIT_SERVER_ENABLE_WEBSOCKET_COMPRESSION", "false")
# NOTE(review): Streamlit documents this option in megabytes — confirm 200 MB is intended.
os.environ.setdefault("STREAMLIT_SERVER_MAX_MESSAGE_SIZE", "200")
import streamlit as st
import streamlit.components.v1 as components
import json
import time
import logging
import hashlib
import uuid
from datetime import datetime
# Module-wide logging: INFO level with timestamp and level prefix.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
| # numpy 延迟导入 | |
| _np_module = None | |
| def _get_np(): | |
| global _np_module | |
| if _np_module is None: | |
| import numpy | |
| _np_module = numpy | |
| return _np_module | |
# =========================
# 1. Page configuration & style injection
# =========================
st.set_page_config(
    page_title="智答 AI 助手",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="collapsed",
)
def _get_custom_css():
    """Return the global ``<style>`` block injected into the page.

    The stylesheet overrides Streamlit's default chrome:
    - pins ``.main-title`` to the top, and the chat input and the popover
      toolbar to the bottom;
    - makes ``.block-container`` the only scrolling region between them;
    - relabels the sidebar toggle, main menu and file uploader with Chinese text;
    - adds ``max-width: 768px`` mobile adjustments.

    Rules target Streamlit ``data-testid`` attributes. The ``.main-title-guest``
    and ``.main-title-user`` marker classes adjust the content offset. The
    ``.welcome-marker`` class is presumably emitted by a welcome view not
    visible here — confirm.

    NOTE(review): the mobile ``[data-testid="stHorizontalBlock"]`` rules are
    not scoped, so they apply to every column layout on small screens.
    ``data-testid`` names are Streamlit-version specific.
    """
    return """
<style>
[data-testid="stSidebarContent"] { padding-top: 1.5rem !important; }
[data-testid="stVerticalBlock"] > div { gap: 0.2rem !important; }
/* 聊天消息间距控制 */
[data-testid="stChatMessage"] { margin-bottom: 0.1rem !important; margin-top: 0.1rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important; }
/* 聊天消息内部 gap 归零 */
[data-testid="stChatMessageContent"] > [data-testid="stVerticalBlock"] > div { gap: 0 !important; }
/* 聊天消息内标题字体大小限制 */
[data-testid="stChatMessage"] h1 { font-size: 1.3rem !important; margin: 0.8rem 0 0.5rem !important; }
[data-testid="stChatMessage"] h2 { font-size: 1.15rem !important; margin: 0.7rem 0 0.4rem !important; }
[data-testid="stChatMessage"] h3 { font-size: 1.05rem !important; margin: 0.6rem 0 0.3rem !important; }
[data-testid="stChatMessage"] h4, [data-testid="stChatMessage"] h5, [data-testid="stChatMessage"] h6 { font-size: 1rem !important; margin: 0.5rem 0 0.3rem !important; }
/* 隐藏 stHtml iframe 容器间距,但保留渲染(JS 可执行) */
[data-testid="stHtml"] { height: 0 !important; min-height: 0 !important; margin: 0 !important; padding: 0 !important; overflow: hidden !important; }
[data-testid="stSelectbox"] input { caret-color: transparent !important; }
/* 隐藏 text_input 的 press enter to apply 提示 */
[data-testid="stTextInput"] div[data-testid="InputInstructions"] { display: none !important; }
[data-testid="stFileUploader"] section > div { display: none; }
[data-testid="stFileUploaderDropzoneInstructions"] { display: none !important; }
[data-testid="stFileUploader"] section::before {
content: "拖拽文档至此";
color: #555; font-size: 14px; display: block; margin-bottom: 10px;
}
[data-testid="stFileUploader"] section::after {
content: "支持格式:TXT, PDF, DOCX";
color: #888; font-size: 12px; display: block; margin-top: 5px;
}
[data-testid="stFileUploader"] button { font-size: 0 !important; }
[data-testid="stFileUploader"] button::after {
content: "选择文件";
font-size: 14px !important;
}
/* 顶部工具栏 */
[data-testid="stHeader"] {
height: 2.5rem !important;
background: transparent !important;
}
/* 左上角侧边栏按钮:替换箭头为「☰ 菜单」- 无边框风格 */
[data-testid="stSidebarCollapsedControl"] {
top: 0.4rem !important;
left: 0.5rem !important;
}
[data-testid="stSidebarCollapsedControl"] button {
background: transparent !important;
border: none !important;
padding: 4px 8px !important;
}
[data-testid="stSidebarCollapsedControl"] button:hover {
background: rgba(0,0,0,0.05) !important;
border-radius: 6px !important;
}
[data-testid="stSidebarCollapsedControl"] button svg { display: none !important; }
[data-testid="stSidebarCollapsedControl"] button::after {
content: "☰ 菜单";
font-size: 14px;
color: #666;
}
/* 侧边栏展开后的收起按钮 - 无边框风格 */
[data-testid="stSidebarCollapseButton"] button {
background: transparent !important;
border: none !important;
padding: 4px 8px !important;
}
[data-testid="stSidebarCollapseButton"] button:hover {
background: rgba(0,0,0,0.05) !important;
border-radius: 6px !important;
}
[data-testid="stSidebarCollapseButton"] button svg { display: none !important; }
[data-testid="stSidebarCollapseButton"] button::after {
content: "✕ 收起";
font-size: 14px;
color: #666;
}
/* 右上角三点菜单:替换为「⚙ 设置」- 无边框风格 */
[data-testid="stMainMenu"] {
position: fixed !important;
top: 0.55rem !important;
right: 0.5rem !important;
}
[data-testid="stMainMenu"] button {
background: transparent !important;
border: none !important;
padding: 4px 8px !important;
}
[data-testid="stMainMenu"] button:hover {
background: rgba(0,0,0,0.05) !important;
border-radius: 6px !important;
}
[data-testid="stMainMenu"] button svg { display: none !important; }
[data-testid="stMainMenu"] button::before {
content: "⚙ 设置";
font-size: 14px;
color: #666;
}
/* 主内容区域:限制在标题和搜索框之间 */
.block-container {
position: fixed !important;
top: 7rem !important;
bottom: 7rem !important;
left: 0 !important;
right: 0 !important;
overflow-y: auto !important;
-webkit-overflow-scrolling: touch !important;
padding: 1rem 1rem 1rem 1rem !important;
}
/* 固定标题在顶部 */
.main-title {
position: fixed !important;
top: 2.5rem !important;
left: 0 !important;
right: 0 !important;
z-index: 9999 !important;
background: white !important;
padding: 0.3rem 1rem !important;
}
/* 固定搜索框在底部 */
[data-testid="stChatInput"] {
position: fixed !important;
bottom: 3.5rem !important;
left: 50% !important;
transform: translateX(-50%) !important;
width: calc(100% - 2rem) !important;
max-width: 730px !important;
z-index: 9999 !important;
background: white !important;
}
/* 搜索框聚焦时去掉红色边框 */
[data-testid="stChatInput"] textarea:focus,
[data-testid="stChatInput"] div:focus-within {
outline: none !important;
border-color: #ddd !important;
box-shadow: none !important;
}
/* 底部工具栏固定在搜索框下方 */
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) {
position: fixed !important;
bottom: 0.3rem !important;
left: 50% !important;
transform: translateX(-50%) !important;
width: calc(100% - 2rem) !important;
max-width: 730px !important;
z-index: 9998 !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stHorizontalBlock"] {
display: flex !important;
flex-wrap: nowrap !important;
gap: 0.5rem !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stHorizontalBlock"] > div {
flex: 1 !important;
min-width: 0 !important;
}
/* 确保底部工具栏按钮占满容器宽度 */
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stHorizontalBlock"] button {
width: 100% !important;
}
/* popover 向上展开 */
[data-testid="stPopoverBody"] {
bottom: 100% !important;
top: auto !important;
margin-bottom: 0.5rem !important;
}
/* 底部工具栏按钮:无边框风格,统一颜色 */
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stPopover"] > div:first-child {
background: transparent !important;
border: none !important;
box-shadow: none !important;
color: #666 !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button:hover,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stPopover"] > div:first-child:hover {
background: rgba(0,0,0,0.05) !important;
}
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button p,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) button span,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stMarkdownContainer"] p,
.stMainBlockContainer > div > div > div:has([data-testid="stPopover"]) [data-testid="stPopover"] p {
color: #666 !important;
}
/* 确保聊天容器可滚动 */
[data-testid="stVerticalBlockBorderWrapper"] {
overflow-y: auto !important;
-webkit-overflow-scrolling: touch !important;
}
/* 手机端响应式调整(屏幕宽度 < 768px) */
@media (max-width: 768px) {
.main-title {
padding: 0.2rem 0.5rem !important;
}
.main-title h1 {
font-size: 1.2rem !important;
}
/* 手机端聊天消息标题更小 */
[data-testid="stChatMessage"] h1 { font-size: 1.1rem !important; }
[data-testid="stChatMessage"] h2 { font-size: 1rem !important; }
[data-testid="stChatMessage"] h3, [data-testid="stChatMessage"] h4, [data-testid="stChatMessage"] h5, [data-testid="stChatMessage"] h6 { font-size: 0.95rem !important; }
/* 手机端工具栏:强制两列并排 */
[data-testid="stHorizontalBlock"] {
display: flex !important;
flex-wrap: nowrap !important;
gap: 0.3rem !important;
}
[data-testid="stHorizontalBlock"] > div {
flex: 1 !important;
min-width: 0 !important;
width: auto !important;
}
/* 按钮文字缩小,单行截断 */
[data-testid="stHorizontalBlock"] button {
font-size: 12px !important;
padding: 0.4rem 0.3rem !important;
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
}
[data-testid="stHorizontalBlock"] button p {
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
margin: 0 !important;
}
/* 手机端 popover 居中显示 */
[data-testid="stPopoverBody"] {
left: 1rem !important;
right: 1rem !important;
transform: none !important;
width: auto !important;
max-width: calc(100vw - 2rem) !important;
bottom: auto !important;
top: auto !important;
}
}
/* 游客模式:标题区域 */
body:has(.main-title-guest) .block-container {
top: 6.5rem !important;
}
@media (max-width: 768px) {
body:has(.main-title-guest) .block-container {
top: 6.5rem !important;
padding: 0.5rem !important;
}
}
/* 登录模式:标题较矮,内容区 top 更小 */
body:has(.main-title-user) .block-container {
top: 6rem !important;
}
@media (max-width: 768px) {
body:has(.main-title-user) .block-container {
top: 4.5rem !important;
padding: 0.5rem !important;
}
}
/* 欢迎页:垂直居中,禁止滚动 */
.block-container:has(.welcome-marker) {
overflow: hidden !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
}
/* 快捷问题按钮:无边框链接风格 */
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"] {
background: transparent !important;
border: none !important;
color: #1a73e8 !important;
box-shadow: none !important;
white-space: nowrap !important;
}
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"]:hover {
background: #f0f7ff !important;
}
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"] p {
color: #1a73e8 !important;
white-space: nowrap !important;
}
/* 手机端快捷问题:垂直排列,按钮宽度自适应内容 */
@media (max-width: 768px) {
/* 快捷问题按钮区域:垂直居中排列 */
.block-container:has(.welcome-marker) [data-testid="stVerticalBlockBorderWrapper"] [data-testid="stHorizontalBlock"]:not(:has([data-testid="stPopover"])) {
flex-direction: column !important;
align-items: center !important;
gap: 0.3rem !important;
}
/* 按钮容器:宽度自适应 */
.block-container:has(.welcome-marker) [data-testid="stVerticalBlockBorderWrapper"] [data-testid="stHorizontalBlock"]:not(:has([data-testid="stPopover"])) > div {
width: auto !important;
flex: none !important;
}
/* 按钮本身:宽度自适应内容 */
.block-container:has(.welcome-marker) [data-testid="stBaseButton-secondary"] {
width: auto !important;
min-width: 0 !important;
}
}
/* 侧边栏文件列表:防止文件名溢出遮挡删除按钮 */
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:first-child {
overflow: hidden !important;
text-overflow: ellipsis !important;
white-space: nowrap !important;
}
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:last-child {
flex-shrink: 0 !important;
z-index: 1 !important;
}
/* meta 信息中的附加内容(文件名、降级提示)- 移动端换行 */
.meta-extra {
display: inline;
}
.meta-sep {
display: inline;
}
@media (max-width: 768px) {
.meta-extra {
display: block;
margin-top: 2px;
}
.meta-sep {
display: none;
}
}
/* 移动端侧边栏文件列表布局修复 */
@media (max-width: 768px) {
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] {
display: flex !important;
flex-direction: row !important;
justify-content: space-between !important;
align-items: center !important;
width: 100% !important;
}
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:first-child {
flex: 1 !important;
min-width: 0 !important;
}
[data-testid="stSidebar"] [data-testid="stHorizontalBlock"] > div:last-child {
flex: 0 0 auto !important;
margin-left: 8px !important;
}
}
</style>
"""
def inject_custom_css():
    """Render the app-wide stylesheet into the page as raw HTML markdown."""
    stylesheet_html = _get_custom_css()
    st.markdown(stylesheet_html, unsafe_allow_html=True)


inject_custom_css()
| def _render_action_buttons(content, msg_id): | |
| """生成消息底部的复制/分享/播报按钮(纯HTML部分)""" | |
| return f''' | |
| <div style="display:flex; gap:12px; margin-top:8px; padding-top:6px; border-top:1px solid #eee;"> | |
| <button id="copy_btn_{msg_id}" style="background:none; border:none; cursor:pointer; color:#666; font-size:13px; padding:4px 8px; border-radius:4px;">📋 复制</button> | |
| <button id="share_btn_{msg_id}" style="background:none; border:none; cursor:pointer; color:#666; font-size:13px; padding:4px 8px; border-radius:4px;">🔗 分享</button> | |
| <button id="tts_btn_{msg_id}" style="background:none; border:none; cursor:pointer; color:#666; font-size:13px; padding:4px 8px; border-radius:4px;">🔊 播报</button> | |
| </div> | |
| ''' | |
def _inject_action_js(content, msg_id):
    """Attach copy / share / text-to-speech behaviour to one message's buttons.

    The buttons themselves are rendered in the parent Streamlit document by
    ``_render_action_buttons``. This renders a zero-height ``components.html``
    iframe whose script finds them in ``window.parent.document`` by id and
    installs click and hover handlers.

    Args:
        content: Message text that the buttons copy, share or read aloud.
            ``None`` is treated as an empty string.
        msg_id: Identifier embedded in the button ids (``copy_btn_<msg_id>`` ...).
    """

    def _js_literal(value):
        # json.dumps produces a valid JS string literal. Quotes, backslashes,
        # control characters and every non-ASCII character (including
        # U+2028/U+2029) come out escaped. Escaping <, > and & as \uXXXX also
        # stops any "</script" sequence, in any letter case, from closing the
        # inline <script> element. The hand-rolled escaping this replaces
        # only caught lowercase "</script>".
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    content_js = _js_literal("" if content is None else str(content))
    copy_id_js = _js_literal(f"copy_btn_{msg_id}")
    share_id_js = _js_literal(f"share_btn_{msg_id}")
    tts_id_js = _js_literal(f"tts_btn_{msg_id}")
    js_code = f'''
<script>
(function() {{
    const msgContent = {content_js};
    let speaking = false;
    const parentDoc = window.parent.document;
    // Legacy copy through a hidden textarea + execCommand (no Clipboard API needed).
    function legacyCopy(text) {{
        return new Promise((resolve, reject) => {{
            const textarea = parentDoc.createElement('textarea');
            textarea.value = text;
            textarea.style.position = 'fixed';
            textarea.style.left = '-9999px';
            textarea.style.top = '-9999px';
            parentDoc.body.appendChild(textarea);
            textarea.focus();
            textarea.select();
            try {{
                const ok = parentDoc.execCommand('copy');
                parentDoc.body.removeChild(textarea);
                ok ? resolve() : reject();
            }} catch (e) {{
                parentDoc.body.removeChild(textarea);
                reject(e);
            }}
        }});
    }}
    // Prefer the async Clipboard API, but fall back to the legacy path when it
    // rejects (e.g. the component iframe lacks clipboard-write permission).
    function copyToClipboard(text) {{
        if (navigator.clipboard && window.isSecureContext) {{
            return navigator.clipboard.writeText(text).catch(() => legacyCopy(text));
        }}
        return legacyCopy(text);
    }}
    const copyBtn = parentDoc.getElementById({copy_id_js});
    const shareBtn = parentDoc.getElementById({share_id_js});
    const ttsBtn = parentDoc.getElementById({tts_id_js});
    if (copyBtn) {{
        copyBtn.onclick = function() {{
            copyToClipboard(msgContent).then(() => {{
                copyBtn.textContent = '✅ 已复制';
                setTimeout(() => {{ copyBtn.textContent = '📋 复制'; }}, 1500);
            }}).catch(() => {{
                // Last resort: let the user copy by hand.
                prompt('请手动复制以下内容:', msgContent.substring(0, 500));
            }});
        }};
        copyBtn.onmouseover = function() {{ this.style.background='#f0f0f0'; }};
        copyBtn.onmouseout = function() {{ this.style.background='none'; }};
    }}
    if (shareBtn) {{
        shareBtn.onclick = function() {{
            // Use the Web Share API when available, otherwise copy.
            if (navigator.share && !window.parent.frames.length) {{
                navigator.share({{ title: '智答AI助手', text: msgContent }}).catch(() => {{}});
            }} else {{
                copyToClipboard(msgContent).then(() => {{
                    shareBtn.textContent = '✅ 已复制';
                    setTimeout(() => {{ shareBtn.textContent = '🔗 分享'; }}, 1500);
                }}).catch(() => {{
                    prompt('请手动复制分享:', msgContent.substring(0, 500));
                }});
            }}
        }};
        shareBtn.onmouseover = function() {{ this.style.background='#f0f0f0'; }};
        shareBtn.onmouseout = function() {{ this.style.background='none'; }};
    }}
    if (ttsBtn) {{
        ttsBtn.onclick = function() {{
            const synth = window.speechSynthesis || window.parent.speechSynthesis;
            if (!synth) {{ alert('浏览器不支持语音播报'); return; }}
            if (speaking) {{
                synth.cancel();
                speaking = false;
                ttsBtn.textContent = '🔊 播报';
            }} else {{
                const utterance = new SpeechSynthesisUtterance(msgContent);
                utterance.lang = 'zh-CN';
                utterance.rate = 1.0;
                utterance.onend = () => {{ speaking = false; ttsBtn.textContent = '🔊 播报'; }};
                utterance.onerror = () => {{ speaking = false; ttsBtn.textContent = '🔊 播报'; }};
                synth.speak(utterance);
                speaking = true;
                ttsBtn.textContent = '⏹️ 停止';
            }}
        }};
        ttsBtn.onmouseover = function() {{ this.style.background='#f0f0f0'; }};
        ttsBtn.onmouseout = function() {{ this.style.background='none'; }};
    }}
}})();
</script>
'''
    components.html(js_code, height=0)
# Global guest flag for this run: True when no user is logged in.
IS_GUEST = not st.session_state.get("current_user")
# Fixed page title; guests additionally see a "guest mode" hint. The
# main-title-guest / main-title-user class lets the CSS adjust spacing.
if IS_GUEST:
    _guest_tip = """<div style="text-align:center; padding:4px 0; color:#666; font-size:13px;">
🙋 当前为<b>游客模式</b>,可直接提问体验
</div>"""
else:
    _guest_tip = ""
_title_class = "main-title " + ("main-title-guest" if IS_GUEST else "main-title-user")
_title_html = f"""
<div class="{_title_class}">
<h1 style="white-space:nowrap; overflow:hidden; text-overflow:ellipsis; font-size:clamp(1.4rem, 5vw, 2.2rem); margin:0; text-align:center;">
🤖 智答 AI 助手
</h1>
{_guest_tip}
</div>
"""
st.markdown(_title_html, unsafe_allow_html=True)
# =========================
# 1.5 Supabase client configuration
# =========================
from supabase import create_client

SUPABASE_URL = st.secrets.get("SUPABASE_URL", "")
SUPABASE_KEY = st.secrets.get("SUPABASE_KEY", "")  # service_role key (backend use)
STORAGE_BUCKET = "rag-files"

# Halt the app with an explanation when either credential is missing.
if not (SUPABASE_URL and SUPABASE_KEY):
    st.error("⚠️ 未配置 SUPABASE_URL 或 SUPABASE_KEY,请在 Secrets 中设置。")
    st.stop()
@st.cache_resource
def _get_supabase():
    """Return the Supabase client built from SUPABASE_URL / SUPABASE_KEY.

    Cached with ``st.cache_resource``, so a single client (and its HTTP
    connections) is reused across reruns instead of being rebuilt on every
    ``_sb()`` call. Sharing is acceptable here because the client is built
    from one fixed service key rather than per-user credentials.
    """
    return create_client(SUPABASE_URL, SUPABASE_KEY)
def _sb():
    """Shorthand accessor returning the client from ``_get_supabase()``."""
    client = _get_supabase()
    return client
| # ========================= | |
| # 2. 用户管理(Supabase users 表) | |
| # ========================= | |
| MAX_LOGIN_ATTEMPTS = 10 | |
| def _hash_password(password): | |
| return hashlib.sha256(password.encode("utf-8")).hexdigest() | |
# Cached for 30 seconds to avoid re-querying on every rerun. Writers call
# _load_users.clear(), which only exists on the st.cache_data wrapper.
@st.cache_data(ttl=30)
def _load_users():
    """Load every account from the Supabase ``users`` table.

    Returns:
        dict mapping username -> {"password_hash", "role", "created_at"}.
        ``created_at`` is the stored value truncated to 16 characters, or
        "未知" when missing. An empty dict is returned, and the error logged,
        if the query or row parsing fails.

    NOTE(review): a failure result ({}) is cached for the TTL like any other
    return value.
    """
    try:
        resp = _sb().table("users").select("*").execute()
        users = {}
        for row in resp.data:
            users[row["username"]] = {
                "password_hash": row["password_hash"],
                "role": row["role"],
                "created_at": row["created_at"][:16] if row.get("created_at") else "未知",
            }
        return users
    except Exception as e:
        logger.error(f"加载用户表失败: {e}")
        return {}
def _ensure_admin():
    """Bootstrap the first admin account from secrets when ``users`` is empty.

    Reads ``ADMIN_USER`` (default "admin") and ``ADMIN_PASSWORD`` from
    ``st.secrets``. Does nothing if any user already exists or no admin
    password is configured. A failed upsert is logged, not raised.
    """
    if _load_users():
        return
    admin_user = st.secrets.get("ADMIN_USER", "admin")
    admin_pass = st.secrets.get("ADMIN_PASSWORD", "")
    if not admin_pass:
        return
    try:
        _sb().table("users").upsert({
            "username": admin_user,
            "password_hash": _hash_password(admin_pass),
            "role": "admin",
        }).execute()
    except Exception as e:
        logger.error(f"创建管理员失败: {e}")
        return
    logger.info(f"初始管理员 {admin_user} 已创建")
    # _load_users may be wrapped in st.cache_data, in which case the empty
    # result fetched above is cached. Drop it so the login form sees the new
    # admin now instead of reporting "未配置管理员" until the cache expires.
    clear_cache = getattr(_load_users, "clear", None)
    if callable(clear_cache):
        clear_cache()


# Run the bootstrap once per session so every rerun does not re-query users.
if not st.session_state.get("_admin_ensured"):
    _ensure_admin()
    st.session_state["_admin_ensured"] = True
def _save_user(username, password_hash, role="user"):
    """Insert or update one account row, then invalidate the user-list cache."""
    record = {
        "username": username,
        "password_hash": password_hash,
        "role": role,
    }
    users_table = _sb().table("users")
    users_table.upsert(record).execute()
    _load_users.clear()
def _delete_user_db(username):
    """Delete the ``users`` row for *username* and invalidate the user-list cache."""
    users_table = _sb().table("users")
    users_table.delete().eq("username", username).execute()
    _load_users.clear()
# Cached for 60 seconds. _set_invite_code calls _get_invite_code.clear(),
# which only exists on the st.cache_data wrapper.
@st.cache_data(ttl=60)
def _get_invite_code():
    """Return the registration invite code.

    The value stored under key "invite_code" in the ``app_meta`` table takes
    precedence. If there is no such row, or the lookup fails (the failure is
    logged), fall back to the ``INVITE_CODE`` secret, or "" when unset.
    """
    try:
        resp = _sb().table("app_meta").select("value").eq("key", "invite_code").execute()
        if resp.data:
            return resp.data[0]["value"]
    except Exception as e:
        logger.warning(f"读取邀请码失败,回退到 secrets: {e}")
    return st.secrets.get("INVITE_CODE", "")
def _set_invite_code(new_code):
    """Persist *new_code* as the invite code in ``app_meta`` and drop the cached value."""
    meta_row = {"key": "invite_code", "value": new_code}
    _sb().table("app_meta").upsert(meta_row).execute()
    _get_invite_code.clear()
def register_user(username, password, invite_code):
    """Validate a sign-up request and create a regular ``user`` account.

    Args:
        username: Requested name (2-20 characters, must not start with "__",
            and must not be the reserved guest identity "_guest_").
        password: Plain-text password, at least 4 characters.
        invite_code: Code the user typed; must match ``_get_invite_code()``.

    Returns:
        (ok, message) where message is a user-facing status string.
    """
    import hmac

    if not username or not password:
        return False, "用户名和密码不能为空"
    if not 2 <= len(username) <= 20:
        return False, "用户名长度需要 2-20 个字符"
    if len(password) < 4:
        return False, "密码至少 4 个字符"
    if username.startswith("__"):
        return False, "用户名不能以 __ 开头"
    # "_guest_" is the CURRENT_USER value assigned to anonymous visitors, so an
    # account with that name would be indistinguishable from guests.
    if username == "_guest_":
        return False, "用户名不可用"
    correct_code = _get_invite_code()
    if not correct_code:
        return False, "邀请码未配置,请联系管理员"
    # Constant-time comparison. Encode first because hmac.compare_digest
    # rejects str arguments containing non-ASCII characters.
    if not hmac.compare_digest(
        str(invite_code or "").encode("utf-8"),
        str(correct_code).encode("utf-8"),
    ):
        return False, "邀请码错误"
    # NOTE(review): _load_users() may be cached and _save_user upserts, so two
    # concurrent sign-ups for the same name could both pass this check.
    users = _load_users()
    if username in users:
        return False, "用户名已存在"
    try:
        _save_user(username, _hash_password(password), "user")
        logger.info(f"新用户注册: {username}")
        return True, "注册成功,请登录"
    except Exception as e:
        logger.error(f"注册失败: {e}")
        return False, f"注册失败: {e}"
def verify_user(username, password):
    """Check *username* / *password* against the stored account.

    Returns:
        (ok, role, reason):
        - ``(True, role, "")`` on success; role defaults to "user".
        - ``(False, None, "not_found")`` if there is no usable account record.
        - ``(False, None, "wrong_password")`` if the hash does not match.
    """
    record = _load_users().get(username)
    if not isinstance(record, dict) or not record:
        return False, None, "not_found"
    if _hash_password(password) != record.get("password_hash"):
        return False, None, "wrong_password"
    return True, record.get("role", "user"), ""
# --- Authentication UI ---
# Seed per-session auth state on the first run; existing values are kept.
for _state_key, _state_default in (
    ("login_attempts", 0),
    ("current_user", None),
    ("current_role", None),
    ("auth_mode", "login"),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
with st.sidebar:
    # Logged-out visitors see the login/register form directly rather than
    # inside an expander, because a manually collapsed expander would not
    # re-open by itself.
    if IS_GUEST:
        st.subheader("🔑 账号")
        # Hard lockout once this session has reached MAX_LOGIN_ATTEMPTS failures.
        # NOTE(review): the counter lives in st.session_state, so it presumably
        # resets when the page is reloaded (the message itself suggests a
        # refresh). This throttles one session, not a determined attacker.
        if st.session_state.login_attempts >= MAX_LOGIN_ATTEMPTS:
            st.error("🚫 尝试次数过多,请刷新页面后重试。")
            st.stop()
        # An empty user table, or a failed load (which _load_users reports as
        # {}), stops the app with a configuration error.
        users_data = _load_users()
        if not users_data:
            st.error("⚠️ 未配置管理员,请在 secrets 中设置 ADMIN_USER 和 ADMIN_PASSWORD。")
            st.stop()
        # The radio value selects the form; st.session_state.auth_mode is not consulted here.
        auth_mode = st.radio(
            "操作", ["登录", "注册"], horizontal=True,
            label_visibility="collapsed", key="auth_radio",
        )
        if auth_mode == "登录":
            with st.form("login_form", clear_on_submit=False):
                input_username = st.text_input("用户名", key="login_user")
                input_password = st.text_input("密码", type="password", key="login_pass")
                submitted = st.form_submit_button("🔓 登录", use_container_width=True, type="primary")
            if submitted:
                if input_username == "" or input_password == "":
                    st.warning("请输入用户名和密码")
                elif st.session_state.login_attempts >= MAX_LOGIN_ATTEMPTS:
                    # NOTE(review): unreachable. The lockout check above already
                    # called st.stop() for this condition earlier in the run.
                    pass
                else:
                    ok, role, reason = verify_user(input_username, input_password)
                    if not ok:
                        # Count the failure and report the remaining tries.
                        st.session_state.login_attempts += 1
                        remaining = MAX_LOGIN_ATTEMPTS - st.session_state.login_attempts
                        if reason == "not_found":
                            st.warning(f"⚠️ 用户不存在,请先注册(剩余 {remaining} 次)")
                        else:
                            st.warning(f"⚠️ 密码错误(剩余 {remaining} 次)")
                    else:
                        # Success: reset the counter, store the identity, and
                        # rerun so the page rebuilds in logged-in mode.
                        st.session_state.login_attempts = 0
                        st.session_state.current_user = input_username
                        st.session_state.current_role = role
                        st.rerun()
        else:  # Register
            with st.form("register_form", clear_on_submit=False):
                reg_user = st.text_input("用户名", key="reg_user")
                reg_pass = st.text_input("密码", type="password", key="reg_pass")
                reg_pass2 = st.text_input("确认密码", type="password", key="reg_pass2")
                reg_code = st.text_input("邀请码", type="password", key="reg_code")
                reg_submitted = st.form_submit_button("注册", use_container_width=True)
            if reg_submitted:
                if reg_pass != reg_pass2:
                    st.error("两次密码不一致")
                else:
                    # Remaining validation (lengths, invite code, duplicates)
                    # happens inside register_user.
                    # NOTE(review): a wrong invite code does not increment
                    # login_attempts, so invite-code guessing is not rate-limited here.
                    ok, msg = register_user(reg_user, reg_pass, reg_code)
                    if ok:
                        # Log the newly registered user in immediately.
                        st.session_state.current_user = reg_user
                        st.session_state.current_role = "user"
                        st.session_state.login_attempts = 0
                        st.rerun()
                    else:
                        st.error(f"❌ {msg}")
    else:
        # Logged in: show the account name and role inside an expander.
        with st.expander("🔑 账号"):
            role_label = "管理员" if st.session_state.current_role == "admin" else "普通用户"
            st.success(f"✅ {st.session_state.current_user}({role_label})")
# Resolve the global identity flags used by the rest of this run.
# Anonymous visitors share the fixed identity "_guest_".
_session_user = st.session_state.get("current_user")
IS_GUEST = not _session_user
CURRENT_USER = _session_user if _session_user else "_guest_"
IS_ADMIN = (not IS_GUEST) and st.session_state.current_role == "admin"
del _session_user
# Once per session after login, collapse the sidebar by having a zero-height
# component iframe click Streamlit's own collapse button in the parent page.
_collapse_sidebar_now = (not IS_GUEST) and not st.session_state.get("_sidebar_collapsed_once")
if _collapse_sidebar_now:
    st.session_state["_sidebar_collapsed_once"] = True
    _collapse_script = """
<script>
(function() {
const btn = window.parent.document.querySelector(
'[data-testid="stSidebarCollapseButton"] button'
);
if (btn) btn.click();
})();
</script>
"""
    components.html(_collapse_script, height=0)
# =========================
# 3. Security configuration & embedding strategy
# =========================
# API credentials read from st.secrets; each defaults to "" when unset.
TAVILY_KEY = st.secrets.get("TAVILY_API_KEY", "")
DS_API_KEY = st.secrets.get("DEEPSEEK_API_KEY", "")
# Baidu Qianfan credentials; _get_embedding_client uses them when both are set.
BAIDU_TOKEN = st.secrets.get("BAIDU_BEARER_TOKEN", "")
BAIDU_APP_ID = st.secrets.get("BAIDU_APP_ID", "")
# OpenRouter key; _get_embedding_client falls back to it when Baidu is not configured.
OR_KEY = st.secrets.get("OPENROUTER_API_KEY", "")
def _get_embedding_client():
    """Pick an OpenAI-compatible embedding client and its model name.

    Priority:
    1. Baidu Qianfan, when BAIDU_TOKEN and BAIDU_APP_ID are both set.
    2. OpenRouter, when OR_KEY is set.

    Returns:
        (client, model_name), or (None, None) when no provider is configured.

    NOTE(review): the providers (and the local fallback model) use different
    embedding models. If their vector dimensions differ, vectors from
    different providers are not comparable; confirm against the documents
    table schema.
    """
    from openai import OpenAI

    if BAIDU_TOKEN and BAIDU_APP_ID:
        qianfan_client = OpenAI(
            api_key=BAIDU_TOKEN,
            base_url="https://qianfan.baidubce.com/v2",
            default_headers={"appid": BAIDU_APP_ID},
        )
        # The original author's note: upgraded model with an 8192-token context.
        return qianfan_client, "qwen3-embedding-0.6b"
    if not OR_KEY:
        return None, None
    openrouter_client = OpenAI(
        api_key=OR_KEY,
        base_url="https://openrouter.ai/api/v1",
    )
    return openrouter_client, "BAAI/bge-small-zh"
def _api_encode(texts):
    """Embed *texts* through the configured remote API in batches of 32.

    Returns:
        A list of numpy vectors, one per input text, or None when no API
        client is configured or any request fails (the caller then falls back
        to the local model).
    """
    np = _get_np()
    client, model = _get_embedding_client()
    if client is None:
        return None
    step = 32
    try:
        vectors = []
        for start in range(0, len(texts), step):
            response = client.embeddings.create(model=model, input=texts[start:start + step])
            vectors += [np.array(item.embedding) for item in response.data]
        return vectors
    except Exception as e:
        logger.warning(f"API embedding 失败,回退到本地模型: {e}")
        return None
def _get_local_model():
    """Return the per-session local SentenceTransformer model, loading it once.

    The model ("BAAI/bge-small-zh") is stored in ``st.session_state``. Returns
    None, and logs the reason, when sentence-transformers is not installed or
    the model fails to load for any other reason (e.g. a download error), so
    callers can report "embedding unavailable" instead of crashing.
    """
    if "_local_emb_model" not in st.session_state:
        try:
            with st.spinner("API 不可用,正在加载本地向量模型(仅首次)..."):
                from sentence_transformers import SentenceTransformer
                st.session_state._local_emb_model = SentenceTransformer("BAAI/bge-small-zh")
        except ImportError:
            logger.error("sentence-transformers 未安装,本地模型不可用")
            return None
        except Exception as e:
            logger.error(f"本地向量模型加载失败: {e}")
            return None
    return st.session_state.get("_local_emb_model")
def encode_texts(texts):
    """Embed one string or a list of strings.

    Tries the remote API first and falls back to the local model.

    Returns:
        A list of vectors (one per text). Returns [] for empty input, and also
        when no embedding backend is available; in that case an error is
        shown in the UI.
    """
    if not texts:
        return []
    batch = [texts] if isinstance(texts, str) else texts
    api_vectors = _api_encode(batch)
    if api_vectors is not None:
        return api_vectors
    local_model = _get_local_model()
    if local_model is None:
        st.error("❌ Embedding 服务不可用:API 调用失败且本地模型未安装。请检查 API Key 配置。")
        return []
    return list(local_model.encode(batch))
def encode_query(text):
    """Embed a single query string.

    Returns:
        The query vector, or None when no embedding could be produced
        (``encode_texts`` has already shown the error). This replaces the
        IndexError the previous ``vecs[0]`` raised on an empty result.
    """
    vecs = encode_texts([text])
    return vecs[0] if vecs else None
| # ========================= | |
| # 4. Supabase 索引管理(替代本地文件) | |
| # ========================= | |
| def _save_chunks_to_db(scope, chunks, vectors, source_file): | |
| """将新切片批量写入 Supabase documents 表。""" | |
| rows = [] | |
| for content, vec, src in zip(chunks, vectors, [source_file] * len(chunks)): | |
| rows.append({ | |
| "scope": scope, | |
| "source_file": src, | |
| "content": content, | |
| "embedding": vec.tolist() if hasattr(vec, 'tolist') else list(vec), | |
| }) | |
| # Supabase 批量插入(每次最多 500 行) | |
| batch_size = 500 | |
| for i in range(0, len(rows), batch_size): | |
| _sb().table("documents").insert(rows[i:i + batch_size]).execute() | |
def _delete_chunks_by_file(scope, filename):
    """Delete every ``documents`` row whose scope and source_file match."""
    documents = _sb().table("documents")
    documents.delete().eq("scope", scope).eq("source_file", filename).execute()
def _clear_all_chunks(scope):
    """Delete every ``documents`` row belonging to *scope*."""
    documents = _sb().table("documents")
    documents.delete().eq("scope", scope).execute()
def _count_chunks(scope):
    """Return how many ``documents`` rows exist for *scope*.

    Uses an exact server-side count. If the query fails, a warning is logged
    and 0 is returned, so callers treat the index as empty rather than crash.
    """
    try:
        resp = _sb().table("documents").select("id", count="exact").eq("scope", scope).execute()
        return resp.count or 0
    except Exception as e:
        logger.warning(f"统计切片数量失败 [scope={scope}]: {e}")
        return 0
| # --- 原始文件管理(Supabase Storage + uploaded_files 表)--- | |
| def _safe_storage_path(scope, filename): | |
| """生成 ASCII 安全的 Storage 路径(Supabase Storage 不支持中文/特殊字符)。""" | |
| import os | |
| ext = os.path.splitext(filename)[1].lower() # 保留扩展名 | |
| safe_name = uuid.uuid4().hex[:16] + ext | |
| return f"{scope}/{safe_name}" | |
def _save_uploaded_file_to_storage(scope, uploaded_file):
    """Upload a raw file to Supabase Storage and record it in ``uploaded_files``.

    If a (scope, filename) record already has a non-empty ``storage_path``,
    that path is reused so a re-upload overwrites the same object. Otherwise,
    or if the lookup fails (logged), a fresh ASCII-safe path is generated.

    Args:
        scope: Namespace for the file (table key and path prefix).
        uploaded_file: File-like object providing ``seek``, ``read`` and
            ``name`` (e.g. a Streamlit UploadedFile).

    Raises:
        Whatever the Storage client or table upsert raises when the retry
        upload or the metadata write fails.
    """
    uploaded_file.seek(0)
    file_bytes = uploaded_file.read()

    storage_path = None
    try:
        existing = _sb().table("uploaded_files").select("storage_path").eq(
            "scope", scope).eq("filename", uploaded_file.name).execute()
        if existing.data:
            storage_path = existing.data[0]["storage_path"]
    except Exception as e:
        logger.warning(f"查询已有文件记录失败,将生成新路径: {e}")
    if not storage_path:
        storage_path = _safe_storage_path(scope, uploaded_file.name)

    # Upload with upsert. If that fails, best-effort delete any object at the
    # path and retry a plain upload; a second failure propagates.
    try:
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path, file_bytes,
            file_options={"content-type": "application/octet-stream", "upsert": "true"},
        )
    except Exception as e:
        logger.warning(f"Storage upload fallback: {e}")
        try:
            _sb().storage.from_(STORAGE_BUCKET).remove([storage_path])
        except Exception:
            pass  # Best effort: the retry below decides success or failure.
        _sb().storage.from_(STORAGE_BUCKET).upload(
            storage_path, file_bytes,
            file_options={"content-type": "application/octet-stream"},
        )

    # Record metadata; the filename column keeps the original (possibly non-ASCII) name.
    _sb().table("uploaded_files").upsert({
        "scope": scope,
        "filename": uploaded_file.name,
        "file_size": len(file_bytes),
        "storage_path": storage_path,
    }, on_conflict="scope,filename").execute()
    _list_uploaded_files_db.clear()
def _format_file_size(size):
    """Render a byte count as ``<n>B``, ``<x.y>KB`` or ``<x.y>MB`` (1024-based)."""
    if size < 1024:
        return f"{size}B"
    if size < 1048576:
        return f"{size / 1024:.1f}KB"
    return f"{size / 1048576:.1f}MB"


# Cached for 15 seconds because the file list changes rarely. Every helper
# that mutates uploaded files calls `_list_uploaded_files_db.clear()`, and that
# method only exists on a cache-decorated function.
@st.cache_data(ttl=15)
def _list_uploaded_files_db(scope):
    """List the files uploaded to ``scope``, ordered by filename.

    Returns:
        A list of ``(filename, size_str, storage_path)`` tuples. Returns [] when
        the query fails; the failure is logged.
    """
    try:
        resp = (
            _sb()
            .table("uploaded_files")
            .select("filename, file_size, storage_path")
            .eq("scope", scope)
            .order("filename")
            .execute()
        )
        return [
            (
                row["filename"],
                _format_file_size(row.get("file_size", 0) or 0),
                row.get("storage_path", ""),
            )
            for row in (resp.data or [])
        ]
    except Exception as e:
        logger.error(f"列出文件失败 [scope={scope}]: {e}")
        return []
def _delete_uploaded_file_from_storage(scope, filename):
    """Remove ``filename`` of ``scope`` from Storage and from the ``uploaded_files`` table.

    Both steps are best-effort and independent: a failure in either is logged
    as a warning and the other step still runs.
    """
    # Step 1: look up the real object path, then delete the Storage object.
    try:
        lookup = (
            _sb()
            .table("uploaded_files")
            .select("storage_path")
            .eq("scope", scope)
            .eq("filename", filename)
            .execute()
        )
        if lookup.data:
            object_path = lookup.data[0]["storage_path"]
            _sb().storage.from_(STORAGE_BUCKET).remove([object_path])
    except Exception as e:
        logger.warning(f"Storage 删除失败: {e}")

    # Step 2: delete the metadata row and invalidate the cached file list.
    try:
        (
            _sb()
            .table("uploaded_files")
            .delete()
            .eq("scope", scope)
            .eq("filename", filename)
            .execute()
        )
        _list_uploaded_files_db.clear()
    except Exception as e:
        logger.warning(f"uploaded_files 记录删除失败: {e}")
def _clear_uploaded_files_storage(scope):
    """Remove every uploaded file of ``scope`` from Storage and from ``uploaded_files``.

    Best-effort: failures are logged, never raised. Removing the Storage
    objects and deleting the metadata rows are attempted independently. A
    Storage error therefore no longer leaves stale rows listing files whose
    chunks the caller has already cleared.
    """
    try:
        resp = _sb().table("uploaded_files").select("storage_path").eq("scope", scope).execute()
        # Skip empty/NULL paths so one bad row cannot make the bulk remove fail.
        paths = [row["storage_path"] for row in (resp.data or []) if row.get("storage_path")]
        if paths:
            _sb().storage.from_(STORAGE_BUCKET).remove(paths)
    except Exception as e:
        logger.warning(f"清空文件失败 [scope={scope}]: {e}")
    try:
        _sb().table("uploaded_files").delete().eq("scope", scope).execute()
        _list_uploaded_files_db.clear()  # invalidate the cached file list
    except Exception as e:
        logger.warning(f"uploaded_files 记录删除失败 [scope={scope}]: {e}")
PUBLIC_SCOPE = "public"
PRIVATE_SCOPE = CURRENT_USER  # a user's private library is scoped by their username

# --- Periodic sync: notice changes other users made to the document libraries ---
_SYNC_INTERVAL = 120  # seconds between two checks


def _check_and_sync():
    """Refresh cached chunk counts at most once per ``_SYNC_INTERVAL`` seconds.

    Stores the counts in ``st.session_state["_sync_count_public" / "_private"]``
    and logs whenever a count differs from the previously cached one.
    """
    now = time.time()
    if now - st.session_state.get("_sync_last_check", 0) < _SYNC_INTERVAL:
        return
    st.session_state["_sync_last_check"] = now

    watched = (("public", PUBLIC_SCOPE), ("private", PRIVATE_SCOPE))
    for label, scope in watched:
        state_key = "_sync_count_" + label
        latest = _count_chunks(scope)
        previous = st.session_state.get(state_key, -1)
        if previous >= 0 and latest != previous:
            logger.info(f"[SYNC] {label} 库变更: {previous} -> {latest}")
        st.session_state[state_key] = latest


_check_and_sync()
# --- Logged-in users: choose the default answering mode from knowledge-base content ---
def _has_any_kb_content():
    """Return True if the public or the private library holds at least one chunk.

    Uses the counts cached in session_state when present; otherwise fetches
    them with ``_count_chunks`` and caches them.
    """
    def _cached_count(label, scope):
        state_key = f"_sync_count_{label}"
        value = st.session_state.get(state_key)
        if value is None:
            value = _count_chunks(scope)
            st.session_state[state_key] = value
        return value

    public_total = _cached_count("public", PUBLIC_SCOPE)
    private_total = _cached_count("private", PRIVATE_SCOPE)
    return (public_total + private_total) > 0


# On first load, default to knowledge-base mode when any library has content,
# otherwise to web mode.
if not IS_GUEST and "sel_web" not in st.session_state:
    st.session_state["sel_web"] = not _has_any_kb_content()
# =========================
# 5. Cached LLM clients
# =========================
# st.cache_resource keeps one client per process. Without it, every llm_answer
# call constructed three new clients, each with its own HTTP connection pool.
@st.cache_resource
def get_or_client():
    """Return the shared OpenRouter client (OpenAI-compatible API)."""
    from openai import OpenAI
    return OpenAI(api_key=OR_KEY, base_url="https://openrouter.ai/api/v1")


@st.cache_resource
def get_ds_client():
    """Return the shared client for DeepSeek's official API."""
    from openai import OpenAI
    return OpenAI(api_key=DS_API_KEY, base_url="https://api.deepseek.com")


@st.cache_resource
def get_baidu_client():
    """Return the shared Baidu Qianfan client; the app id is sent as the ``appid`` header."""
    from openai import OpenAI
    return OpenAI(
        api_key=BAIDU_TOKEN,
        base_url="https://qianfan.baidubce.com/v2",
        default_headers={"appid": BAIDU_APP_ID},
    )
# =========================
# 6. Utility functions
# =========================
_text_splitter_cache = None


def _get_text_splitter():
    """Return the shared splitter (chunk_size=400, chunk_overlap=50), creating it on first use."""
    global _text_splitter_cache
    if _text_splitter_cache is not None:
        return _text_splitter_cache
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    _text_splitter_cache = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
    return _text_splitter_cache
# System prompt for knowledge-base answers: answer from the supplied reference
# material, say so honestly when it has nothing relevant, and do not invent facts.
SYSTEM_PROMPT = (
    "你是一个专业的知识问答助手。请基于提供的参考资料回答用户问题。"
    "如果资料中没有相关信息,请诚实说明。回答要准确、有条理、简洁。"
    "不要编造不在资料中的信息。"
)
# System prompt for web mode: answer from the search results, supplement with
# the model's own knowledge when they are insufficient, and cite sources.
SYSTEM_PROMPT_WEB = (
    "你是一个智能AI助手,擅长结合互联网搜索结果回答用户问题。"
    "请根据搜索结果提供准确、有条理的回答。"
    "如果搜索结果不足以回答问题,请结合你自身的知识进行补充。"
    "回答要简洁实用,注明信息来源。"
)
# System prompt for direct answers (no retrieval context), also used as the
# fallback when knowledge-base search finds nothing.
SYSTEM_PROMPT_DIRECT = (
    "你是一个智能AI助手。请直接回答用户的问题。"
    "回答要准确、有条理、简洁实用。"
    "可以充分发挥你的知识和能力来帮助用户。"
)
def web_search(query):
    """Search the web via Tavily and return a plain-text digest of the top results.

    The current year is prefixed to the query. Up to 3 results are rendered as
    "来源/内容" pairs (content truncated to 700 chars each), and the joined
    digest is capped at 2500 chars. On failure the error is logged and an error
    message string is returned instead of raising.
    """
    if not TAVILY_KEY:
        return "⚠️ 未配置搜索 Key"
    from tavily import TavilyClient
    client = TavilyClient(api_key=TAVILY_KEY)
    this_year = datetime.now().year
    try:
        payload = client.search(
            query=f"{this_year}年 {query}",
            search_depth="advanced",
            max_results=3,
        )
        digest = "\n\n".join(
            f"来源: {hit.get('url')}\n内容: {hit.get('content', '')[:700]}"
            for hit in payload["results"]
        )
        return digest[:2500]
    except Exception as e:
        logger.error(f"联网搜索异常: {e}")
        return f"联网搜索异常:{str(e)}"
def estimate_tokens(text):
    """Roughly estimate the token count of ``text``.

    Heuristic: each CJK Unified Ideograph (U+4E00..U+9FFF) counts as 1.5 tokens
    and every other character as 0.4. The sum is truncated to an int. Empty or
    None input yields 0.
    """
    if not text:
        return 0
    cjk_chars = len([ch for ch in text if "\u4e00" <= ch <= "\u9fff"])
    other_chars = len(text) - cjk_chars
    return int(cjk_chars * 1.5 + other_chars * 0.4)
def _decode_text_bytes(raw):
    """Decode the bytes of a plain-text upload.

    Tries strict UTF-8 (stripping a BOM if present), then GB18030, a superset of
    GBK in which many Chinese .txt files are saved. Only if both fail does it
    fall back to lossy UTF-8. Previously, decoding GBK bytes as UTF-8 with
    errors="ignore" silently dropped every Chinese character.
    """
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("utf-8", errors="ignore")


def extract_text(file):
    """Extract plain text from an uploaded .txt, .pdf or .docx file.

    Args:
        file: file-like upload exposing ``name``, ``read`` and ``seek``.

    Returns:
        The extracted text. Returns "" for other extensions, and "" when parsing
        fails (the error is logged and shown via ``st.error``).
    """
    fname = file.name.lower()
    text = ""
    try:
        if fname.endswith(".txt"):
            file.seek(0)
            text = _decode_text_bytes(file.read())
        elif fname.endswith(".pdf"):
            import io
            file.seek(0)
            pdf_bytes = file.read()
            # Cheap sanity checks before handing the bytes to the PDF parser.
            if len(pdf_bytes) < 100:
                raise ValueError("PDF 文件过小,可能已损坏")
            if not pdf_bytes[:5] == b"%PDF-":
                raise ValueError("不是有效的 PDF 文件(缺少 %PDF- 头)")
            pdf_stream = io.BytesIO(pdf_bytes)
            pages_text = []
            import pdfplumber
            with pdfplumber.open(pdf_stream) as pdf:
                for i, page in enumerate(pdf.pages):
                    # A single unreadable page contributes "" instead of failing the file.
                    try:
                        pages_text.append(page.extract_text() or "")
                    except Exception as page_err:
                        logger.warning(f"PDF 第{i+1}页解析失败: {page_err}")
                        pages_text.append("")
            text = "\n".join(pages_text)
        elif fname.endswith(".docx"):
            from docx import Document
            file.seek(0)
            doc = Document(file)
            text = "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        logger.error(f"文件解析失败 [{file.name}]: {e}", exc_info=True)
        st.error(f"解析失败: {e}")
    return text
def process_upload(uploaded_files, target_prefix, scope):
    """Ingest uploaded files into the ``scope`` knowledge base.

    For each file: store the raw file in Supabase Storage, extract its text and
    split it into chunks. Then embed all chunks in batches of 64 and write them
    to ``documents``, grouped by source file. Chunks from a previous upload with
    the same file name are deleted first, so re-uploading replaces a document
    instead of duplicating it.

    Args:
        uploaded_files: upload objects exposing ``name``, ``read`` and ``seek``.
        target_prefix: "public" or "private". Selects the session-state keys
            ``_upload_ver_<prefix>`` and ``_sync_count_<prefix>`` to update.
        scope: Supabase scope the files and chunks belong to.

    Returns:
        False when there is nothing to ingest or ingestion fails. On success it
        calls ``st.rerun()`` to restart the script run.
    """
    if not uploaded_files:
        return False
    try:
        all_new_chunks = []
        all_new_sources = []
        with st.spinner("正在自动解析文档并更新索引..."):
            for f in uploaded_files:
                try:
                    # Keep the original file in Supabase Storage.
                    f.seek(0)
                    _save_uploaded_file_to_storage(scope, f)
                    # Rewind: the storage upload consumed the stream.
                    f.seek(0)
                    raw_text = extract_text(f)
                    if not raw_text.strip():
                        st.warning(f"文件 {f.name} 内容为空,已跳过。")
                        continue
                    chunks = _get_text_splitter().split_text(raw_text)
                    all_new_chunks.extend(chunks)
                    all_new_sources.extend([f.name] * len(chunks))
                except Exception as file_err:
                    # One bad file must not abort the rest of the batch.
                    logger.error(f"文件 {f.name} 处理失败: {file_err}", exc_info=True)
                    st.warning(f"⚠️ 文件 {f.name} 处理失败:{str(file_err)[:100]},已跳过。")
            if all_new_chunks:
                # Embed in batches of 64.
                batch_size = 64
                all_vecs = []
                for i in range(0, len(all_new_chunks), batch_size):
                    all_vecs.extend(encode_texts(all_new_chunks[i:i + batch_size]))
                # Group chunks and their vectors by source file.
                file_groups = {}
                for chunk, vec, src in zip(all_new_chunks, all_vecs, all_new_sources):
                    chunk_list, vec_list = file_groups.setdefault(src, ([], []))
                    chunk_list.append(chunk)
                    vec_list.append(vec)
                for src_file, (chunks, vecs) in file_groups.items():
                    # A same-named re-upload overwrites the Storage object, so drop
                    # the previous version's chunks; otherwise search would return
                    # stale or duplicate passages.
                    _delete_chunks_by_file(scope, src_file)
                    _save_chunks_to_db(scope, chunks, vecs, src_file)
                # Bump the uploader widget key so it resets to empty.
                ukey = f"_upload_ver_{target_prefix}"
                st.session_state[ukey] = st.session_state.get(ukey, 0) + 1
                # Refresh the cached chunk count and the cached file list now.
                st.session_state[f"_sync_count_{target_prefix}"] = _count_chunks(scope)
                _list_uploaded_files_db.clear()
                # A successful upload switches to knowledge-base mode.
                st.session_state["sel_web"] = False
                st.toast(f"✅ 导入 {len(all_new_chunks)} 个切片")
                st.rerun()
            else:
                st.error("解析失败,未发现有效文字内容。")
    except Exception as e:
        logger.error(f"上传处理异常: {e}", exc_info=True)
        st.error(f"❌ 上传处理出错:{str(e)[:200]}")
    return False
| # ========================= | |
| # 6.5 聊天记录持久化(Supabase chat_history 表) | |
| # ========================= | |
| import threading | |
| def _async_run(fn, *args): | |
| """在后台线程执行耗时操作,不阻塞 UI。""" | |
| t = threading.Thread(target=fn, args=args, daemon=True) | |
| t.start() | |
def _save_chat_message_sync(username, role, content, meta=""):
    """Insert one chat message into ``chat_history`` (blocking; internal use).

    Errors are logged as warnings and swallowed so a failed save never breaks the chat.
    """
    record = {
        "username": username,
        "role": role,
        "content": content,
        "meta": meta or "",
    }
    try:
        _sb().table("chat_history").insert(record).execute()
    except Exception as e:
        logger.warning(f"保存聊天记录失败: {e}")
def _save_chat_message(username, role, content, meta=""):
    """Persist a chat message on a background daemon thread so the UI is not blocked."""
    _async_run(_save_chat_message_sync, username, role, content, meta)
# NOTE(review): the original comment here said "cached for 10 seconds", but no
# caching decorator is visible on this function — confirm the intent.
def _load_chat_history(username, limit=50):
    """Return the user's ``limit`` most recent chat messages, oldest first.

    Each item is a dict with keys role, content, meta and created_at. Returns []
    when there are no rows or the query fails; the failure is logged.
    """
    try:
        resp = (
            _sb()
            .table("chat_history")
            .select("role, content, meta, created_at")
            .eq("username", username)
            .order("created_at", desc=True)
            .limit(limit)
            .execute()
        )
        if not resp.data:
            return []
        # The query returns newest first; flip to chronological order.
        history = []
        for row in reversed(resp.data):
            history.append({
                "role": row["role"],
                "content": row["content"],
                "meta": row.get("meta", ""),
                "created_at": row.get("created_at", ""),
            })
        return history
    except Exception as e:
        logger.warning(f"加载聊天记录失败: {e}")
        return []
def _clear_chat_history_db(username):
    """Delete all ``chat_history`` rows of ``username``; errors are logged, not raised."""
    try:
        query = _sb().table("chat_history").delete().eq("username", username)
        query.execute()
    except Exception as e:
        logger.warning(f"清空聊天记录失败: {e}")
# =========================
# 7. Sidebar UI & logic
# =========================
# Guest-mode models (free, routed through OpenRouter).
# Keys are display names; values are provider model ids.
_GUEST_MODELS = {
    "⭐ Step-3.5 (首选)": "stepfun/step-3.5-flash:free",
    "🌐 OR-Auto (避堵)": "openrouter/free",
    "🧠 GLM-4.5 (推理)": "z-ai/glm-4.5-air:free",
    "🔥 Gemma-3-27B (旗舰)": "google/gemma-3-27b-it:free",
    "🐋 Nemotron (120B)": "nvidia/nemotron-3-super-120b-a12b:free",
    "⚡ Trinity-L (极速)": "arcee-ai/trinity-large-preview:free",
    "💭 Liquid-Think (思维链)": "liquid/lfm-2.5-1.2b-thinking:free",
    "🏎️ Liquid-Ins (1.0s)": "liquid/lfm-2.5-1.2b-instruct:free",
    "⚖️ Gemma-3-12B (平衡)": "google/gemma-3-12b-it:free",
    "💎 Gemma-3n-e4b (稳)": "google/gemma-3n-e4b-it:free",
    "🤖 Nemotron-Nano (混)": "nvidia/nemotron-3-nano-30b-a3b:free",
    "📉 Trinity-M (1.8s)": "arcee-ai/trinity-mini:free",
    "🍃 Nemotron-9B": "nvidia/nemotron-nano-9b-v2:free",
    "🪶 Gemma-3-4B": "google/gemma-3-4b-it:free",
    "🫧 Gemma-3n-e2b": "google/gemma-3n-e2b-it:free",
    "📷 Nemotron-VL": "nvidia/nemotron-nano-12b-v2-vl:free",
}
# Logged-in models (paid). llm_answer sends "deepseek-chat" to DeepSeek's
# official API and every other id to Baidu Qianfan.
_USER_MODELS = {
    "🧠 文心5.0思维链": "ernie-5.0-thinking-latest",
    "🔥 文心5.0": "ernie-5.0",
    "⚡ 文心4.5-Turbo": "ernie-4.5-turbo-128k",
    "🐋 DeepSeek-V3.2思维": "deepseek-v3.2-think",
    "💎 DeepSeek-V3.2": "deepseek-v3.2",
    "🔮 DeepSeek-R1": "deepseek-r1",
    "🌟 千问3.5-397B": "qwen3.5-397b-a17b",
    "💭 千问3-235B思维": "qwen3-235b-a22b-thinking-2507",
    "📊 千问3-32B": "qwen3-32b",
    "🎯 GLM-5": "glm-5",
    "🌙 Kimi-K2.5": "kimi-k2.5",
    "✨ MiniMax-M2.5": "minimax-m2.5",
    "🛡️ DeepSeek官方": "deepseek-chat",
    "🏢 文心3.5": "ernie-3.5-8k",
}
# Model catalogue for the current user type; both names refer to the same dict.
model_mapping = _GUEST_MODELS if IS_GUEST else _USER_MODELS
_active_models = model_mapping
with st.sidebar:
    # --- Retrieval parameters (logged-in users only) ---
    # The chat area reads these back through their session_state keys
    # (sel_topk / sel_threshold).
    if not IS_GUEST:
        with st.expander("⚙️ 检索参数"):
            c1, c2 = st.columns(2)
            with c1:
                ui_top_k = st.slider("Top-K", min_value=1, max_value=15, value=5, key="sel_topk")
            with c2:
                ui_threshold = st.slider("阈值", min_value=0.0, max_value=1.0, value=0.25, step=0.05, key="sel_threshold")

    # --- The panels below are visible to logged-in users only ---
    if not IS_GUEST:
        # --- Public knowledge base ---
        # Reuse the cached chunk count to avoid a DB round-trip on every rerun.
        pub_chunk_count = st.session_state.get("_sync_count_public")
        if pub_chunk_count is None:
            pub_chunk_count = _count_chunks(PUBLIC_SCOPE)
            st.session_state["_sync_count_public"] = pub_chunk_count
        with st.expander(f"📚 公共知识库({pub_chunk_count} 切片)"):
            st.caption("所有人可搜索")
            pub_file_list = _list_uploaded_files_db(PUBLIC_SCOPE)
            if pub_file_list:
                # Handle a pending delete click before drawing the list, so the
                # deleted file is never rendered again.
                delete_target = None
                if IS_ADMIN:
                    for fname, _, _ in pub_file_list:
                        if st.session_state.get(f"delpub_{fname}"):
                            delete_target = fname
                            break
                if delete_target:
                    _delete_chunks_by_file(PUBLIC_SCOPE, delete_target)
                    _delete_uploaded_file_from_storage(PUBLIC_SCOPE, delete_target)
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_public"] = _count_chunks(PUBLIC_SCOPE)
                    # Fall back to web mode once every library is empty.
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast(f"已删除 {delete_target}")
                    st.rerun()
                st.caption(f"📎 已上传 {len(pub_file_list)} 个文件:")
                for fname, size_str, _ in pub_file_list:
                    if IS_ADMIN:
                        col_name, col_del = st.columns([4, 1])
                        col_name.text(f"📄 {fname} ({size_str})")
                        col_del.button("🗑", key=f"delpub_{fname}")
                    else:
                        st.text(f"📄 {fname} ({size_str})")
            if IS_ADMIN:
                pub_upload_key = f"upload_public_{st.session_state.get('_upload_ver_public', 0)}"

                # on_change flags that a new selection is waiting to be ingested.
                def _on_pub_upload_change():
                    st.session_state["_pending_upload_public"] = True

                pub_files = st.file_uploader(
                    "上传到公共库",
                    type=["txt", "pdf", "docx"],
                    accept_multiple_files=True,
                    label_visibility="collapsed",
                    key=pub_upload_key,
                    on_change=_on_pub_upload_change,
                )
                # Ingest only when on_change raised the flag (pop also clears it).
                if pub_files and st.session_state.pop("_pending_upload_public", False):
                    process_upload(pub_files, "public", PUBLIC_SCOPE)
                if pub_chunk_count > 0 and len(pub_file_list) >= 2:
                    if st.button("🗑️ 清空公共库", use_container_width=True, type="secondary", key="clear_pub"):
                        _clear_all_chunks(PUBLIC_SCOPE)
                        _clear_uploaded_files_storage(PUBLIC_SCOPE)
                        _list_uploaded_files_db.clear()
                        st.session_state["_sync_count_public"] = 0
                        if not _has_any_kb_content():
                            st.session_state["sel_web"] = True
                        st.toast("公共知识库已清空")
                        st.rerun()
            else:
                st.caption("*仅管理员可维护公共库*")

        # --- Private knowledge base ---
        priv_chunk_count = st.session_state.get("_sync_count_private")
        if priv_chunk_count is None:
            priv_chunk_count = _count_chunks(PRIVATE_SCOPE)
            st.session_state["_sync_count_private"] = priv_chunk_count
        with st.expander(f"🔒 我的私有库({priv_chunk_count} 切片)"):
            st.caption(f"用户:{CURRENT_USER},仅自己可见")
            priv_file_list = _list_uploaded_files_db(PRIVATE_SCOPE)
            if priv_file_list:
                # Handle a pending delete click before drawing the list.
                delete_target = None
                for fname, _, _ in priv_file_list:
                    if st.session_state.get(f"delpriv_{fname}"):
                        delete_target = fname
                        break
                if delete_target:
                    _delete_chunks_by_file(PRIVATE_SCOPE, delete_target)
                    _delete_uploaded_file_from_storage(PRIVATE_SCOPE, delete_target)
                    _list_uploaded_files_db.clear()
                    st.session_state["_sync_count_private"] = _count_chunks(PRIVATE_SCOPE)
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast(f"已删除 {delete_target}")
                    st.rerun()
                st.caption(f"📎 已上传 {len(priv_file_list)} 个文件:")
                for fname, size_str, _ in priv_file_list:
                    col_name, col_del = st.columns([4, 1])
                    col_name.text(f"📄 {fname} ({size_str})")
                    col_del.button("🗑", key=f"delpriv_{fname}")
            priv_upload_key = f"upload_private_{st.session_state.get('_upload_ver_private', 0)}"

            # on_change flags that a new selection is waiting to be ingested.
            def _on_priv_upload_change():
                st.session_state["_pending_upload_private"] = True

            priv_files = st.file_uploader(
                "上传到私有库",
                type=["txt", "pdf", "docx"],
                accept_multiple_files=True,
                label_visibility="collapsed",
                key=priv_upload_key,
                on_change=_on_priv_upload_change,
            )
            if priv_files and st.session_state.pop("_pending_upload_private", False):
                process_upload(priv_files, "private", PRIVATE_SCOPE)
            if priv_chunk_count > 0 and len(priv_file_list) >= 2:
                if st.button("🗑️ 清空我的私有库", use_container_width=True, type="secondary", key="clear_priv"):
                    _clear_all_chunks(PRIVATE_SCOPE)
                    _clear_uploaded_files_storage(PRIVATE_SCOPE)
                    st.session_state["_sync_count_private"] = 0
                    if not _has_any_kb_content():
                        st.session_state["sel_web"] = True
                    st.toast("私有知识库已清空")
                    st.rerun()

        # --- Change password ---
        with st.expander("🔐 修改密码"):
            with st.form("change_password_form", clear_on_submit=False):
                old_pass = st.text_input("当前密码", type="password", key="self_old_pass")
                new_pass1 = st.text_input("新密码", type="password", key="self_new_pass1")
                new_pass2 = st.text_input("确认新密码", type="password", key="self_new_pass2")
                submitted = st.form_submit_button("✅ 确认修改", use_container_width=True)
            if submitted:
                ok, _, _ = verify_user(CURRENT_USER, old_pass)
                if not ok:
                    st.error("当前密码错误")
                elif len(new_pass1) < 4:
                    st.error("新密码至少 4 个字符")
                elif new_pass1 != new_pass2:
                    st.error("两次新密码不一致")
                else:
                    _save_user(CURRENT_USER, _hash_password(new_pass1),
                               st.session_state.current_role)
                    st.success("密码修改成功")

        # --- Admin panels ---
        if IS_ADMIN:
            with st.expander("👥 用户管理"):
                all_users = _load_users()
                user_list = [(u, info) for u, info in all_users.items() if isinstance(info, dict)]
                st.caption(f"共 **{len(user_list)}** 个用户")
                for uname, uinfo in user_list:
                    role_tag = "👑" if uinfo.get("role") == "admin" else "👤"
                    created = uinfo.get("created_at", "未知")
                    st.text(f"{role_tag} {uname}({created})")
                # Admins may delete or reset anyone except themselves.
                deletable = [u for u, _ in user_list if u != CURRENT_USER]
                if deletable:
                    del_target = st.selectbox("选择要删除的用户", deletable, key="del_user_select")
                    if st.button("❌ 删除该用户", key="btn_del_user"):
                        _delete_user_db(del_target)
                        # A user's private library lives under scope == username.
                        # Never wipe the shared library, even for a user whose
                        # name happens to equal the public scope.
                        if del_target != PUBLIC_SCOPE:
                            _clear_all_chunks(del_target)
                            _clear_uploaded_files_storage(del_target)
                        # Chat history is keyed by username too. Without this, an
                        # account later registered under the same name could load
                        # the deleted user's conversations.
                        _clear_chat_history_db(del_target)
                        st.toast(f"用户 {del_target} 已删除")
                        st.rerun()
                resetable = deletable
                if resetable:
                    reset_target = st.selectbox("选择要重置密码的用户", resetable, key="reset_user_select")
                    with st.form("reset_password_form", clear_on_submit=False):
                        new_pass = st.text_input("新密码", type="password", key="reset_new_pass")
                        reset_submitted = st.form_submit_button("🔄 重置密码", use_container_width=True)
                    if reset_submitted:
                        if len(new_pass) < 4:
                            st.error("密码至少 4 个字符")
                        else:
                            # Keep the target user's existing role.
                            target_role = all_users[reset_target].get("role", "user")
                            _save_user(reset_target, _hash_password(new_pass), target_role)
                            st.toast(f"用户 {reset_target} 密码已重置")

            with st.expander("📩 邀请码管理"):
                current_code = _get_invite_code()
                st.text(f"当前邀请码:{current_code if current_code else '未设置'}")
                with st.form("invite_code_form", clear_on_submit=False):
                    new_code = st.text_input("新邀请码", key="new_invite_code")
                    code_submitted = st.form_submit_button("✏️ 更新邀请码", use_container_width=True)
                if code_submitted:
                    if new_code.strip():
                        _set_invite_code(new_code.strip())
                        st.toast("邀请码已更新")
                        st.rerun()
                    else:
                        st.error("邀请码不能为空")

            with st.expander("🛠️ 数据库概览"):
                st.caption("Supabase 数据统计")
                try:
                    pub_cnt = _count_chunks(PUBLIC_SCOPE)
                    st.text(f"📚 公共库切片数: {pub_cnt}")
                    # Per-user private-library sizes (only non-empty ones).
                    for uname, _ in user_list:
                        cnt = _count_chunks(uname)
                        if cnt > 0:
                            st.text(f"🔒 {uname} 私有库: {cnt} 切片")
                except Exception as e:
                    st.warning(f"统计失败: {e}")
                st.divider()
                st.caption("📋 用户列表")
                # Show only an 8-char prefix of each password hash.
                display_users = {}
                for k, v in all_users.items():
                    if isinstance(v, dict) and "password_hash" in v:
                        v_copy = dict(v)
                        v_copy["password_hash"] = v_copy["password_hash"][:8] + "..."
                        display_users[k] = v_copy
                    else:
                        display_users[k] = v
                st.json(display_users)

    # --- Chat-history management (logged-in users only) ---
    if not IS_GUEST:
        with st.expander("💬 聊天记录"):
            hist_tab_new, hist_tab_history = st.tabs(["当前对话", "历史记录"])
            with hist_tab_new:
                st.caption("清空当前对话(数据库记录保留)")
                if st.button("🧹 清空当前对话", use_container_width=True, type="secondary", key="btn_clear_chat"):
                    st.session_state.messages = []
                    st.session_state["_chat_cleared"] = True  # mark as a deliberate clear
                    st.rerun()
                st.caption("清空所有历史记录(不可恢复)")
                if st.button("🗑️ 清空全部记录", use_container_width=True, type="secondary", key="btn_clear_all_hist"):
                    _async_run(_clear_chat_history_db, CURRENT_USER)
                    st.session_state.messages = []
                    st.session_state["_chat_cleared"] = True  # mark as a deliberate clear
                    st.toast("所有聊天记录已清空")
                    st.rerun()
            with hist_tab_history:
                if st.button("🔄 加载历史记录", use_container_width=True, key="btn_load_hist"):
                    st.session_state["_show_history"] = True
                if st.session_state.get("_show_history"):
                    history = _load_chat_history(CURRENT_USER, limit=100)
                    if not history:
                        st.info("暂无历史记录")
                    else:
                        st.caption(f"共 {len(history)} 条记录")
                        for msg in history:
                            # created_at/content may be NULL in the DB; `or ""`
                            # keeps the slicing below from raising TypeError.
                            ts = (msg.get("created_at") or "")[:16].replace("T", " ")
                            icon = "🧑" if msg["role"] == "user" else "🤖"
                            body = msg["content"] or ""
                            preview = body[:80].replace("\n", " ")
                            st.text(f"{icon} [{ts}] {preview}{'...' if len(body) > 80 else ''}")
# =========================
# 8. Core search (public library + private library combined)
# =========================
def search_local(query, top_k, threshold):
    """Vector-search the public and the current user's private library via pgvector.

    Calls the ``match_documents`` RPC with the query embedding, both scopes, the
    similarity threshold and ``top_k``.

    Returns:
        ``(docs, files)``: chunk texts in RPC result order, and the distinct
        source file names in first-seen order, so the cited-file list is stable
        across runs (``set`` ordering was arbitrary). Returns ``([], [])`` when
        nothing matches or the RPC fails; the failure is logged.
    """
    query_vec = encode_query(query)
    scopes = [PUBLIC_SCOPE, PRIVATE_SCOPE]
    try:
        embedding = query_vec.tolist() if hasattr(query_vec, "tolist") else list(query_vec)
        resp = _sb().rpc("match_documents", {
            "query_embedding": embedding,
            "match_scopes": scopes,
            "match_threshold": float(threshold),
            "match_count": int(top_k),
        }).execute()
        rows = resp.data or []
        docs = [row["content"] for row in rows]
        files = list(dict.fromkeys(row["source_file"] for row in rows))  # ordered de-dup
        return docs, files
    except Exception as e:
        logger.error(f"pgvector 搜索失败: {e}")
        return [], []
# =========================
# 9. LLM answering
# =========================
def llm_answer(query, context_docs, selected_display_name, web_enabled, kb_mode=False, source_files=None):
    """Stream an LLM answer to ``query`` as a generator of text fragments.

    Prompt selection:
      * ``web_enabled``: run ``web_search`` and answer from its results only.
      * ``context_docs`` non-empty: answer from the knowledge-base passages.
      * ``kb_mode`` without passages: answer directly and set
        ``st.session_state["kb_fallback"]`` so the UI can flag the fallback.
      * otherwise: answer directly.

    Logged-in users call the selected model once ("deepseek-chat" via DeepSeek's
    official API, everything else via Baidu Qianfan). Guests walk a retry queue:
    the chosen free OpenRouter model, free backups, then paid fallbacks.

    Side effects: ``st.session_state["last_meta"]`` is reset to "" at the start
    and set to a status line (route, token estimates, cited files) only on
    success. A failed call therefore never shows the previous answer's
    metadata.
    """
    from html import escape

    curr_time = datetime.now().strftime("%Y-%m-%d %H:%M")
    st.session_state["kb_fallback"] = False  # reset the "no KB match" flag
    st.session_state["kb_source_files"] = source_files or []  # files cited by this answer
    # The caller reads last_meta after the stream ends, even when all routes failed.
    st.session_state["last_meta"] = ""

    if web_enabled:
        # Web mode: use web-search results only, never the knowledge base.
        search_res = web_search(query)
        all_context = f"【互联网搜索结果】:\n{search_res}"
        prompt_content = f"当前时间:{curr_time}\n\n{all_context[:6500]}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_WEB
    elif context_docs:
        # Knowledge-base mode with matching passages.
        all_context = "【知识库资料】:\n" + "\n".join(context_docs)
        prompt_content = f"当前时间:{curr_time}\n\n参考资料:\n{all_context[:6500]}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT
    elif kb_mode:
        # Knowledge-base mode without matches: answer directly, but flag the fallback.
        st.session_state["kb_fallback"] = True
        prompt_content = f"当前时间:{curr_time}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_DIRECT
    else:
        # Direct mode: no web search, no knowledge base.
        prompt_content = f"当前时间:{curr_time}\n\n用户问题:{query}"
        system_prompt = SYSTEM_PROMPT_DIRECT

    input_tokens = estimate_tokens(prompt_content)
    or_client = get_or_client()
    ds_client = get_ds_client()
    baidu_client = get_baidu_client()
    selected_id = model_mapping[selected_display_name]
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt_content},
    ]

    # Logged-in users: a single call to the selected model, no retry queue.
    if not IS_GUEST:
        client = ds_client if selected_id == "deepseek-chat" else baidu_client
        label = f"🔐 {selected_display_name}"
        logger.info(f"[{CURRENT_USER}] 登录用户调用: {label}")
        try:
            response = client.chat.completions.create(
                model=selected_id,
                messages=messages,
                stream=True,
                timeout=60,  # longer timeout for logged-in (paid) models
            )
            full_text = ""
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_text += content
                    yield content
            if not full_text:
                yield "❌ 模型返回空响应,请稍后重试。"
                return
            fallback_hint = '<span class="meta-extra"><span class="meta-sep"> | </span>⚠️ 知识库无匹配,使用通用知识</span>' if st.session_state.get("kb_fallback") else ""
            cited_files = st.session_state.get("kb_source_files", [])
            # File names come from user uploads and the meta line is rendered as
            # HTML, so escape them.
            cited_text = escape(", ".join(cited_files))
            files_hint = f'<span class="meta-extra"><span class="meta-sep"> | </span>📄 {cited_text}</span>' if cited_files else ""
            st.session_state["last_meta"] = (
                f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens{files_hint}{fallback_hint}"
            )
            return
        except Exception as e:
            err_msg = str(e)
            logger.error(f"{label} 调用失败: {err_msg[:200]}")
            yield f"❌ 调用失败:{err_msg[:100]}"
            return

    # Guests: try the selected free model, then free backups, then paid fallbacks.
    retry_queue = [(or_client, selected_id, f"首选-{selected_display_name}")]
    if selected_id != "stepfun/step-3.5-flash:free":
        retry_queue.append((or_client, "stepfun/step-3.5-flash:free", "⚡ 快速备选-Step3.5"))
    if selected_id != "openrouter/free":
        retry_queue.append((or_client, "openrouter/free", "OR-Auto 免费避堵"))
    paid_backups = [
        ("deepseek-chat", "🛡️ DeepSeek 官方", ds_client),
        ("ernie-3.5-8k", "🏢 百度文心", baidu_client),
    ]
    for p_id, p_label, p_client in paid_backups:
        if selected_id != p_id:
            retry_queue.append((p_client, p_id, f"💰 收费兜底-{p_label}"))

    for client, m_id, label in retry_queue:
        logger.info(f"[{CURRENT_USER}] 尝试链路: {label}")
        full_text = ""
        try:
            extra_h = (
                {"HTTP-Referer": "https://streamlit.io", "X-Title": "RAG_v3"}
                if client is or_client
                else None
            )
            response = client.chat.completions.create(
                model=m_id,
                messages=messages,
                stream=True,
                extra_headers=extra_h,
                timeout=25,
            )
            for chunk in response:
                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_text += content
                    yield content
            if full_text:
                st.session_state["last_meta"] = (
                    f"🟢 {label} | 📊 ~{input_tokens}/{estimate_tokens(full_text)} Tokens"
                )
                return
            # Empty stream: silently try the next route.
        except Exception as e:
            err_msg = str(e)
            logger.warning(f"{label} 失败: {err_msg[:100]}")
            if full_text:
                # Part of this answer is already on screen. Switching models now
                # would splice a second, unrelated answer onto it, so stop here.
                yield f"\n\n❌ 回答中断:{err_msg[:100]}"
                return
            if "429" in err_msg:
                st.toast(f"{label} 拥堵,切换备选...", icon="⏳")
                time.sleep(1.5)
            continue
    yield "❌ 抱歉,所有免费和收费线路均暂时不可用。"
# =========================
# 10. Chat rendering
# =========================
if "messages" in st.session_state:
    # The user deliberately cleared the chat: consume the flag and do not
    # restore anything from the database.
    if st.session_state.get("_chat_cleared"):
        del st.session_state["_chat_cleared"]
elif IS_GUEST:
    # Guests have no persisted history.
    st.session_state.messages = []
else:
    # First load for a logged-in user: restore the 50 most recent messages,
    # keeping "meta" only when it is non-empty.
    saved = _load_chat_history(CURRENT_USER, limit=50)
    restored = []
    for saved_msg in saved:
        entry = {"role": saved_msg["role"], "content": saved_msg["content"]}
        if saved_msg.get("meta"):
            entry["meta"] = saved_msg["meta"]
        restored.append(entry)
    st.session_state.messages = restored
| def _chat_fragment(): | |
| """聊天区域独立 fragment,提交消息时只刷新此区域,不刷新整个页面。""" | |
| _model_name = st.session_state.get("sel_model", list(model_mapping.keys())[0]) | |
| _web_on = st.session_state.get("sel_web", True) # 默认联网模式 | |
| _top_k = st.session_state.get("sel_topk", 5) | |
| _threshold = st.session_state.get("sel_threshold", 0.25) | |
| # 欢迎页(无消息时显示,放在 fragment 内避免整页刷新) | |
| _welcome_hero = st.empty() | |
| if not st.session_state.messages: | |
| with _welcome_hero.container(): | |
| st.markdown( | |
| """ | |
| <div class="welcome-marker" style="text-align:center; color:#888;"> | |
| <div style="font-size:48px; margin-bottom:12px;">🤖</div> | |
| <div style="font-size:18px; font-weight:600; margin-bottom:8px;">欢迎使用智答 AI 助手</div> | |
| <div style="font-size:14px; line-height:1.8; margin-bottom:20px;"> | |
| 🌐 <b>联网模式</b> —— 大模型 + 网络搜索,实时回答<br> | |
| 📚 <b>知识库模式</b> —— 上传文档,基于私有知识回答 | |
| </div> | |
| <div style="color:#999; font-size:13px; margin-bottom:12px;">💡 试试这些问题</div> | |
| </div> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |
| # 快捷问题 - 在 fragment 内部,点击只触发 fragment 刷新 | |
| _quick_questions = [ | |
| "今天有什么热点新闻", | |
| "用Python写快速排序", | |
| "解释RAG技术是什么", | |
| ] | |
| _q_cols = st.columns(3) | |
| for i, qq in enumerate(_quick_questions): | |
| with _q_cols[i]: | |
| if st.button(f"💬 {qq}", key=f"quick_q{i}", use_container_width=True, type="secondary"): | |
| st.session_state["_quick_question"] = qq | |
| # 消息容器 - 不设固定高度,由外层CSS控制滚动 | |
| chat_box = st.container() | |
| q = st.chat_input("问问智答AI助手", key="chat_input_v3") | |
| # 检查快捷问题 | |
| if st.session_state.get("_quick_question"): | |
| q = st.session_state.pop("_quick_question") | |
| with chat_box: | |
| for idx, m in enumerate(st.session_state.messages): | |
| with st.chat_message(m["role"]): | |
| if m["role"] == "assistant": | |
| meta_html = f'\n\n<span style="color:#999;font-size:12px;">{m["meta"]}</span>' if m.get("meta") else "" | |
| action_btns = _render_action_buttons(m["content"], f"hist_{idx}") | |
| st.markdown( | |
| m["content"] + meta_html + action_btns, | |
| unsafe_allow_html=True, | |
| ) | |
| _inject_action_js(m["content"], f"hist_{idx}") | |
| else: | |
| st.markdown(m["content"]) | |
| if q: | |
| _welcome_hero.empty() # 发送消息后清空欢迎页 | |
| st.session_state.messages.append({"role": "user", "content": q}) | |
| if not IS_GUEST: | |
| _save_chat_message(CURRENT_USER, "user", q) | |
| with st.chat_message("user"): | |
| st.markdown(q) | |
| with st.chat_message("assistant"): | |
| response_container = st.empty() | |
| # 滚动JS放在assistant块内部,不影响问题和回答之间的间距 | |
| components.html( | |
| """<script> | |
| setTimeout(function(){ | |
| var c = window.parent.document.querySelector('.block-container'); | |
| if(c) c.scrollTop = c.scrollHeight; | |
| }, 80); | |
| </script>""", | |
| height=0, | |
| ) | |
| # 游客模式:始终用大模型(联网时加网络搜索,不联网直接回答) | |
| # 登录模式 + 联网:大模型 + 网络搜索 | |
| # 登录模式 + 知识库:大模型 + 知识库检索 | |
| if IS_GUEST: | |
| # 游客没有知识库,直接用大模型 | |
| if _web_on: | |
| response_container.markdown("*🌐 正在联网搜索...*") | |
| else: | |
| response_container.markdown("*🤔 正在思考...*") | |
| relevant_docs, source_files = [], [] | |
| elif _web_on: | |
| response_container.markdown("*🌐 正在联网搜索...*") | |
| relevant_docs, source_files = [], [] | |
| else: | |
| response_container.markdown("*🔍 正在搜索知识库...*") | |
| relevant_docs, source_files = search_local(q, _top_k, _threshold) | |
| response_container.markdown("*🤔 正在组织语言...*") | |
| # 登录用户 + 非联网 = 知识库模式 | |
| _kb_mode = (not IS_GUEST) and (not _web_on) | |
| try: | |
| full_response = response_container.write_stream( | |
| llm_answer(q, relevant_docs, _model_name, _web_on, kb_mode=_kb_mode, source_files=source_files) | |
| ) | |
| meta_info = st.session_state.get("last_meta", "") | |
| # 生成工具按钮(使用当前消息数作为ID) | |
| msg_idx = len(st.session_state.messages) | |
| action_btns = _render_action_buttons(full_response, f"new_{msg_idx}") | |
| meta_html = f'\n\n<span style="color:#999;font-size:12px;">{meta_info}</span>' if meta_info else "" | |
| response_container.markdown( | |
| full_response + meta_html + action_btns, | |
| unsafe_allow_html=True, | |
| ) | |
| _inject_action_js(full_response, f"new_{msg_idx}") | |
| st.session_state.messages.append( | |
| {"role": "assistant", "content": full_response, "meta": meta_info} | |
| ) | |
| if not IS_GUEST: | |
| _save_chat_message(CURRENT_USER, "assistant", full_response, meta_info) | |
| except Exception as e: | |
| logger.error(f"模型调用异常: {e}") | |
| response_container.error(f"❌ 抱歉,连接模型时出错了: {str(e)}") | |
# =========================
# 11. 搜索框下方的工具按钮(模型选择 / 联网切换;状态变化后调用 st.rerun() 刷新整页)
# =========================
def _toolbar_fragment():
    """Render the toolbar: a model picker popover and a web/offline mode toggle.

    State read:
      * ``st.session_state["sel_model"]``: the selected model name. It falls back
        to the first key of ``model_mapping`` when unset.
      * ``st.session_state["sel_web"]``: whether web mode is on. It defaults to ``True``.

    Picking a model different from the current one, or clicking the mode
    toggle, writes the new value back into session state and calls
    ``st.rerun()``.
    """
    selected = st.session_state.get("sel_model", list(model_mapping.keys())[0])
    web_enabled = st.session_state.get("sel_web", True)
    # Popover label: the model name cut at the first ASCII "(" and trimmed.
    short_name = selected.split("(")[0].strip()

    model_col, mode_col = st.columns(2)

    # Left column: popover listing every active model. The current one is
    # highlighted as a "primary" button.
    with model_col, st.popover(short_name, use_container_width=True):
        st.caption("选择模型")
        for candidate in _active_models:
            is_current = candidate == selected
            pressed = st.button(
                candidate,
                key=f"mdl_{candidate}",
                use_container_width=True,
                type="primary" if is_current else "secondary",
            )
            # Clicking the already-selected model is a no-op; only a real change reruns.
            if pressed and not is_current:
                st.session_state["sel_model"] = candidate
                st.rerun()

    # Right column: web/offline toggle. The "off" label differs for guests,
    # who have no knowledge base and get direct LLM answers instead.
    offline_label = "💬 直接回答" if IS_GUEST else "📚 知识库"
    toggle_label = "🌐 联网模式" if web_enabled else offline_label
    with mode_col:
        if st.button(toggle_label, key="toggle_web", use_container_width=True):
            st.session_state["sel_web"] = not web_enabled
            st.rerun()
# Page entry point: draw the model/mode toolbar first, then the chat area.
# The toolbar stores the user's picks in st.session_state["sel_model"] and
# st.session_state["sel_web"]. Presumably _chat_fragment reads those same keys
# to choose the model and the web/knowledge-base mode (confirm in its definition).
_toolbar_fragment()
_chat_fragment()