Spaces:
Running
Running
| # # # import streamlit as st | |
| # # # import pandas as pd | |
| # # # import numpy as np | |
| # # # import jieba | |
| # # # import requests | |
| # # # import os | |
| # # # import sys | |
| # # # import subprocess | |
| # # # from openai import OpenAI | |
| # # # from rank_bm25 import BM25Okapi | |
| # # # from sklearn.metrics.pairwise import cosine_similarity | |
| # # # # ================= 1. 全局配置与 CSS注入 ================= | |
| # # # API_KEY = os.getenv("SILICONFLOW_API_KEY") | |
| # # # API_BASE = "https://api.siliconflow.cn/v1" | |
| # # # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" | |
| # # # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" | |
| # # # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" | |
| # # # DATA_FILENAME = "comsol_embedded.parquet" | |
| # # # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" | |
| # # # st.set_page_config( | |
| # # # page_title="COMSOL Dark Expert", | |
| # # # page_icon="🌌", | |
| # # # layout="wide", | |
| # # # initial_sidebar_state="expanded" | |
| # # # ) | |
| # # # # --- 注入自定义 CSS (保持之前的审美) --- | |
| # # # st.markdown(""" | |
| # # # <style> | |
| # # # /* 1. 整体背景 - 深空黑 */ | |
| # # # .stApp { | |
| # # # background-color: #050505; | |
| # # # background-image: radial-gradient(circle at 50% 0%, #1a1f35 0%, #050505 60%); | |
| # # # color: #e0e0e0; | |
| # # # font-family: 'Inter', system-ui, -apple-system, sans-serif; | |
| # # # } | |
| # # # /* 2. 隐藏默认组件 */ | |
| # # # #MainMenu {visibility: hidden;} | |
| # # # footer {visibility: hidden;} | |
| # # # header {visibility: hidden;} | |
| # # # /* 3. 聊天气泡 */ | |
| # # # [data-testid="stChatMessage"] { | |
| # # # background: rgba(255, 255, 255, 0.03); | |
| # # # border: 1px solid rgba(255, 255, 255, 0.08); | |
| # # # border-radius: 16px; | |
| # # # backdrop-filter: blur(12px); | |
| # # # box-shadow: 0 4px 20px rgba(0,0,0,0.2); | |
| # # # padding: 1.2rem; | |
| # # # } | |
| # # # /* 用户气泡 */ | |
| # # # [data-testid="stChatMessage"][data-testid="user"] { | |
| # # # background: rgba(41, 181, 232, 0.1); | |
| # # # border-color: rgba(41, 181, 232, 0.2); | |
| # # # } | |
| # # # /* 4. 自定义标题栏 */ | |
| # # # .custom-header { | |
| # # # border-bottom: 1px solid rgba(255,255,255,0.1); | |
| # # # padding-bottom: 1rem; | |
| # # # margin-bottom: 2rem; | |
| # # # display: flex; | |
| # # # align-items: center; | |
| # # # gap: 1rem; | |
| # # # } | |
| # # # .glitch-text { | |
| # # # font-size: 2rem; | |
| # # # font-weight: 800; | |
| # # # background: linear-gradient(120deg, #fff, #29B5E8); | |
| # # # -webkit-background-clip: text; | |
| # # # -webkit-text-fill-color: transparent; | |
| # # # letter-spacing: -1px; | |
| # # # } | |
| # # # /* 5. 快捷按钮 */ | |
| # # # div.stButton > button { | |
| # # # background: rgba(255,255,255,0.05); | |
| # # # color: #aaa; | |
| # # # border: 1px solid rgba(255,255,255,0.1); | |
| # # # border-radius: 20px; | |
| # # # padding: 0.5rem 1rem; | |
| # # # font-size: 0.85rem; | |
| # # # transition: all 0.3s; | |
| # # # width: 100%; | |
| # # # } | |
| # # # div.stButton > button:hover { | |
| # # # background: rgba(41, 181, 232, 0.2); | |
| # # # color: #fff; | |
| # # # border-color: #29B5E8; | |
| # # # transform: translateY(-2px); | |
| # # # } | |
| # # # /* 6. 输入框 */ | |
| # # # .stChatInputContainer textarea { | |
| # # # background-color: #0f1115 !important; | |
| # # # border: 1px solid #333 !important; | |
| # # # color: white !important; | |
| # # # border-radius: 12px !important; | |
| # # # } | |
| # # # /* 7. Expander */ | |
| # # # .streamlit-expanderHeader { | |
| # # # background-color: rgba(255,255,255,0.02); | |
| # # # border: 1px solid rgba(255,255,255,0.05); | |
| # # # border-radius: 8px; | |
| # # # color: #bbb; | |
| # # # } | |
| # # # </style> | |
| # # # """, unsafe_allow_html=True) | |
| # # # # ================= 2. 核心逻辑(数据与RAG) ================= | |
| # # # if not API_KEY: | |
| # # # st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") | |
| # # # st.stop() | |
| # # # def download_with_curl(url, output_path): | |
| # # # try: | |
| # # # cmd = [ | |
| # # # "curl", "-L", | |
| # # # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", | |
| # # # "-o", output_path, | |
| # # # "--fail", | |
| # # # url | |
| # # # ] | |
| # # # result = subprocess.run(cmd, capture_output=True, text=True) | |
| # # # if result.returncode != 0: raise Exception(f"Curl failed: {result.stderr}") | |
| # # # return True | |
| # # # except Exception as e: | |
| # # # print(f"Curl download error: {e}") | |
| # # # return False | |
| # # # def get_data_file_path(): | |
| # # # possible_paths = [ | |
| # # # DATA_FILENAME, os.path.join("/app", DATA_FILENAME), | |
| # # # os.path.join("processed_data", DATA_FILENAME), | |
| # # # os.path.join("src", DATA_FILENAME), | |
| # # # os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME | |
| # # # ] | |
| # # # for path in possible_paths: | |
| # # # if os.path.exists(path): return path | |
| # # # download_target = "/app/" + DATA_FILENAME | |
| # # # try: os.makedirs(os.path.dirname(download_target), exist_ok=True) | |
| # # # except: download_target = "/tmp/" + DATA_FILENAME | |
| # # # status_container = st.empty() | |
| # # # status_container.info("📡 正在接入神经元网络... (下载核心数据中)") | |
| # # # if download_with_curl(DATA_URL, download_target): | |
| # # # status_container.empty() | |
| # # # return download_target | |
| # # # try: | |
| # # # headers = {'User-Agent': 'Mozilla/5.0'} | |
| # # # r = requests.get(DATA_URL, headers=headers, stream=True) | |
| # # # r.raise_for_status() | |
| # # # with open(download_target, 'wb') as f: | |
| # # # for chunk in r.iter_content(chunk_size=8192): f.write(chunk) | |
| # # # status_container.empty() | |
| # # # return download_target | |
| # # # except Exception as e: | |
| # # # st.error(f"❌ 数据链路中断。Error: {e}") | |
| # # # st.stop() | |
| # # # class FullRetriever: | |
| # # # def __init__(self, parquet_path): | |
| # # # try: self.df = pd.read_parquet(parquet_path) | |
| # # # except Exception as e: st.error(f"Memory Matrix Load Failed: {e}"); st.stop() | |
| # # # self.documents = self.df['content'].tolist() | |
| # # # self.embeddings = np.stack(self.df['embedding'].values) | |
| # # # self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents]) | |
| # # # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) | |
| # # # # Reranker 初始化移到这里,减少重复调用 | |
| # # # self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"} | |
| # # # self.rerank_url = f"{API_BASE}/rerank" | |
| # # # def _get_emb(self, q): | |
| # # # try: return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding | |
| # # # except: return [0.0] * 1024 | |
| # # # def hybrid_search(self, query: str, top_k=5): | |
| # # # # 1. Vector | |
| # # # q_emb = self._get_emb(query) | |
| # # # vec_scores = cosine_similarity([q_emb], self.embeddings)[0] | |
| # # # vec_idx = np.argsort(vec_scores)[-100:][::-1] | |
| # # # # 2. Keyword | |
| # # # kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1] | |
| # # # # 3. RRF Fusion | |
| # # # fused = {} | |
| # # # for r, i in enumerate(vec_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1) | |
| # # # for r, i in enumerate(kw_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1) | |
| # # # c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x:x[1], reverse=True)[:50]] | |
| # # # c_docs = [self.documents[i] for i in c_idxs] | |
| # # # # 4. Rerank | |
| # # # try: | |
| # # # payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k} | |
| # # # resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10) | |
| # # # results = resp.json().get('results', []) | |
| # # # except: | |
| # # # results = [{"index": i, "relevance_score": 0.0} for i in range(len(c_docs))][:top_k] | |
| # # # final_res = [] | |
| # # # context = "" | |
| # # # for i, item in enumerate(results): | |
| # # # orig_idx = c_idxs[item['index']] | |
| # # # row = self.df.iloc[orig_idx] | |
| # # # final_res.append({ | |
| # # # "score": item['relevance_score'], | |
| # # # "filename": row['filename'], | |
| # # # "content": row['content'] | |
| # # # }) | |
| # # # context += f"[文档{i+1}]: {row['content']}\n\n" | |
| # # # return final_res, context | |
| # # # @st.cache_resource | |
| # # # def load_engine(): | |
| # # # real_path = get_data_file_path() | |
| # # # return FullRetriever(real_path) | |
| # # # # ================= 3. UI 主程序 ================= | |
| # # # def main(): | |
| # # # st.markdown(""" | |
| # # # <div class="custom-header"> | |
| # # # <div style="font-size: 3rem;">🌌</div> | |
| # # # <div> | |
| # # # <div class="glitch-text">COMSOL DARK EXPERT</div> | |
| # # # <div style="color: #666; font-size: 0.9rem; letter-spacing: 1px;"> | |
| # # # NEURAL SIMULATION ASSISTANT <span style="color:#29B5E8">V4.1 Fixed</span> | |
| # # # </div> | |
| # # # </div> | |
| # # # </div> | |
| # # # """, unsafe_allow_html=True) | |
| # # # retriever = load_engine() | |
| # # # with st.sidebar: | |
| # # # st.markdown("### ⚙️ 控制台") | |
| # # # top_k = st.slider("检索深度", 1, 10, 4) | |
| # # # temp = st.slider("发散度", 0.0, 1.0, 0.3) | |
| # # # st.markdown("---") | |
| # # # if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True): | |
| # # # st.session_state.messages = [] | |
| # # # st.session_state.current_refs = [] | |
| # # # st.rerun() | |
| # # # if "messages" not in st.session_state: st.session_state.messages = [] | |
| # # # if "current_refs" not in st.session_state: st.session_state.current_refs = [] | |
| # # # col_chat, col_evidence = st.columns([0.65, 0.35], gap="large") | |
| # # # # ------------------ 处理输入源 ------------------ | |
| # # # # 我们定义一个变量 user_input,不管它来自按钮还是输入框 | |
| # # # user_input = None | |
| # # # with col_chat: | |
| # # # # 1. 如果历史为空,显示快捷按钮 | |
| # # # if not st.session_state.messages: | |
| # # # st.markdown("##### 💡 初始化提问序列 (Starter Sequence)") | |
| # # # c1, c2, c3 = st.columns(3) | |
| # # # # 点击按钮直接赋值给 user_input | |
| # # # if c1.button("🌊 流固耦合接口设置"): | |
| # # # user_input = "怎么设置流固耦合接口?" | |
| # # # elif c2.button("⚡ 低频电磁场网格"): | |
| # # # user_input = "低频电磁场网格划分有哪些技巧?" | |
| # # # elif c3.button("📉 求解器不收敛"): | |
| # # # user_input = "求解器不收敛通常怎么解决?" | |
| # # # # 2. 渲染历史消息 | |
| # # # for msg in st.session_state.messages: | |
| # # # with st.chat_message(msg["role"]): | |
| # # # st.markdown(msg["content"]) | |
| # # # # 3. 处理底部输入框 (如果有按钮输入,这里会被跳过,因为 user_input 已经有值了) | |
| # # # if not user_input: | |
| # # # user_input = st.chat_input("输入指令或物理参数问题...") | |
| # # # # ------------------ 统一处理消息追加 ------------------ | |
| # # # if user_input: | |
| # # # st.session_state.messages.append({"role": "user", "content": user_input}) | |
| # # # # 强制刷新以立即在 UI 上显示用户的提问(对于按钮点击尤为重要) | |
| # # # st.rerun() | |
| # # # # ------------------ 统一触发生成 (修复的核心) ------------------ | |
| # # # # 检查:如果有消息,且最后一条是 User 发的,说明需要 Assistant 回答 | |
| # # # if st.session_state.messages and st.session_state.messages[-1]["role"] == "user": | |
| # # # # 获取最后一条用户消息 | |
| # # # last_query = st.session_state.messages[-1]["content"] | |
| # # # with col_chat: # 确保在聊天栏显示 | |
| # # # with st.spinner("🔍 正在扫描向量空间..."): | |
| # # # refs, context = retriever.hybrid_search(last_query, top_k=top_k) | |
| # # # st.session_state.current_refs = refs | |
| # # # system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。 | |
| # # # 要求: | |
| # # # 1. 语气专业、客观,逻辑严密。 | |
| # # # 2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。 | |
| # # # 3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。 | |
| # # # 参考文档: | |
| # # # {context} | |
| # # # """ | |
| # # # with st.chat_message("assistant"): | |
| # # # resp_cont = st.empty() | |
| # # # full_resp = "" | |
| # # # client = OpenAI(base_url=API_BASE, api_key=API_KEY) | |
| # # # try: | |
| # # # stream = client.chat.completions.create( | |
| # # # model=GEN_MODEL_NAME, | |
| # # # messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:], # 除去当前的System | |
| # # # temperature=temp, | |
| # # # stream=True | |
| # # # ) | |
| # # # for chunk in stream: | |
| # # # txt = chunk.choices[0].delta.content | |
| # # # if txt: | |
| # # # full_resp += txt | |
| # # # resp_cont.markdown(full_resp + " ▌") | |
| # # # resp_cont.markdown(full_resp) | |
| # # # st.session_state.messages.append({"role": "assistant", "content": full_resp}) | |
| # # # except Exception as e: | |
| # # # st.error(f"Neural Generation Failed: {e}") | |
| # # # # ------------------ 渲染右侧证据栏 ------------------ | |
| # # # with col_evidence: | |
| # # # st.markdown("### 📚 神经记忆 (Evidence)") | |
| # # # if st.session_state.current_refs: | |
| # # # for i, ref in enumerate(st.session_state.current_refs): | |
| # # # score = ref['score'] | |
| # # # score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c" | |
| # # # with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)): | |
| # # # st.markdown(f""" | |
| # # # <div style="margin-bottom:5px;"> | |
| # # # <span style="color:#888;">Relevance:</span> | |
| # # # <span style="color:{score_color}; font-weight:bold;">{score:.4f}</span> | |
| # # # </div> | |
| # # # """, unsafe_allow_html=True) | |
| # # # st.code(ref['content'], language="text") | |
| # # # else: | |
| # # # st.info("等待输入指令以检索知识库...") | |
| # # # st.markdown(""" | |
| # # # <div style="opacity:0.3; font-size:0.8rem; margin-top:20px;"> | |
| # # # Waiting for query signal...<br> | |
| # # # Index Status: Ready<br> | |
| # # # Awaiting Input | |
| # # # </div> | |
| # # # """, unsafe_allow_html=True) | |
| # # # if __name__ == "__main__": | |
| # # # main() | |
| # # import streamlit as st | |
| # # import pandas as pd | |
| # # import numpy as np | |
| # # import jieba | |
| # # import requests | |
| # # import os | |
| # # import time | |
| # # import json | |
| # # import re | |
| # # import random | |
| # # import subprocess | |
| # # from openai import OpenAI | |
| # # from rank_bm25 import BM25Okapi | |
| # # from sklearn.metrics.pairwise import cosine_similarity | |
| # # from typing import List, Dict, Tuple, Any | |
| # # # ================= 1. 全局配置与样式 ================= | |
| # # # API 配置 (从 HF 环境变量获取) | |
| # # API_BASE = "https://api.siliconflow.cn/v1" | |
| # # API_KEY = os.getenv("SILICONFLOW_API_KEY") | |
| # # # 模型名称配置 | |
| # # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" | |
| # # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" | |
| # # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" | |
| # # QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" | |
| # # SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" | |
| # # # 预置问题池 | |
| # # PRESET_QUESTIONS = [ | |
| # # "如何设置流固耦合接口?", | |
| # # "求解器不收敛怎么办?", | |
| # # "网格划分有哪些技巧?", | |
| # # "如何定义随时间变化的边界条件?", | |
| # # "计算结果如何导出数据?", | |
| # # "什么是完美匹配层 (PML)?", | |
| # # "低频电磁场仿真注意事项", | |
| # # "如何提高瞬态计算速度?", | |
| # # "参数化扫描如何设置?", | |
| # # "多物理场耦合的收敛性优化" | |
| # # ] | |
| # # # 数据文件配置 | |
| # # DATA_FILENAME = "comsol_embedded.parquet" | |
| # # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" | |
| # # # 页面配置 | |
| # # st.set_page_config( | |
| # # page_title="COMSOL RAG 策略控制台", | |
| # # page_icon="🎛️", | |
| # # layout="wide", | |
| # # initial_sidebar_state="expanded" | |
| # # ) | |
| # # # 自定义CSS样式 | |
| # # st.markdown(""" | |
| # # <style> | |
| # # /* 深色主题 */ | |
| # # .stApp { | |
| # # background-color: #0E1117; | |
| # # color: #E0E0E0; | |
| # # } | |
| # # /* 聊天消息样式 */ | |
| # # [data-testid="stChatMessage"] { | |
| # # background-color: #1E1E1E; | |
| # # border: 1px solid #333; | |
| # # border-radius: 10px; | |
| # # box-shadow: 0 2px 4px rgba(0,0,0,0.3); | |
| # # } | |
| # # /* 侧边栏样式 */ | |
| # # [data-testid="stSidebar"] { | |
| # # background-color: #161B22; | |
| # # border-right: 1px solid #30363D; | |
| # # } | |
| # # /* 策略标签 */ | |
| # # .strat-tag { | |
| # # font-size: 0.75rem; | |
| # # padding: 3px 8px; | |
| # # border-radius: 4px; | |
| # # margin-right: 6px; | |
| # # font-weight: bold; | |
| # # display: inline-block; | |
| # # margin-bottom: 4px; | |
| # # border: 1px solid rgba(255,255,255,0.2); | |
| # # } | |
| # # .tag-vec { background-color: rgba(31, 119, 180, 0.3); color: #4EA8DE; border-color: #1f77b4; } | |
| # # .tag-bm25 { background-color: rgba(255, 127, 14, 0.3); color: #FFAB5E; border-color: #ff7f0e; } | |
| # # .tag-qe { background-color: rgba(44, 160, 44, 0.3); color: #69DB7C; border-color: #2ca02c; } | |
| # # .tag-rerank { background-color: rgba(214, 39, 40, 0.3); color: #FF6B6B; border-color: #d62728; } | |
| # # /* 过程展示框 */ | |
| # # .process-box { | |
| # # background-color: #0D1117; | |
| # # border: 1px solid #30363D; | |
| # # padding: 15px; | |
| # # border-radius: 8px; | |
| # # font-size: 0.9rem; | |
| # # color: #8B949E; | |
| # # margin-bottom: 15px; | |
| # # } | |
| # # /* 策略矩阵标题 */ | |
| # # .strategy-title { | |
| # # background: linear-gradient(45deg, #667eea 0%, #764ba2 100%); | |
| # # -webkit-background-clip: text; | |
| # # -webkit-text-fill-color: transparent; | |
| # # background-clip: text; | |
| # # font-weight: bold; | |
| # # font-size: 1.2rem; | |
| # # } | |
| # # </style> | |
| # # """, unsafe_allow_html=True) | |
| # # # ================= 2. 数据下载工具 (HF 适配) ================= | |
| # # def download_with_curl(url, output_path): | |
| # # """使用 curl 下载文件,增加鲁棒性""" | |
| # # try: | |
| # # cmd = [ | |
| # # "curl", "-L", | |
| # # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", | |
| # # "-o", output_path, | |
| # # "--fail", | |
| # # url | |
| # # ] | |
| # # result = subprocess.run(cmd, capture_output=True, text=True) | |
| # # if result.returncode != 0: | |
| # # print(f"Curl stderr: {result.stderr}") | |
| # # return False | |
| # # return True | |
| # # except Exception as e: | |
| # # print(f"Curl download error: {e}") | |
| # # return False | |
| # # def get_data_file_path(): | |
| # # """获取数据文件路径,如果不存在则自动下载""" | |
| # # # 优先检查本地可能存在的路径 | |
| # # possible_paths = [ | |
| # # DATA_FILENAME, | |
| # # os.path.join("/app", DATA_FILENAME), | |
| # # os.path.join("processed_data", DATA_FILENAME), | |
| # # os.path.join(os.getcwd(), DATA_FILENAME) | |
| # # ] | |
| # # for path in possible_paths: | |
| # # if os.path.exists(path): | |
| # # return path | |
| # # # 如果都没找到,准备下载 | |
| # # # HF Spaces 通常在 /home/user/app 下运行,直接下载到当前目录 | |
| # # download_target = os.path.join(os.getcwd(), DATA_FILENAME) | |
| # # status_container = st.empty() | |
| # # status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)") | |
| # # # 尝试 Curl 下载 | |
| # # if download_with_curl(DATA_URL, download_target): | |
| # # status_container.empty() | |
| # # return download_target | |
| # # # 降级尝试 Requests 下载 | |
| # # try: | |
| # # headers = {'User-Agent': 'Mozilla/5.0'} | |
| # # r = requests.get(DATA_URL, headers=headers, stream=True) | |
| # # r.raise_for_status() | |
| # # with open(download_target, 'wb') as f: | |
| # # for chunk in r.iter_content(chunk_size=8192): | |
| # # f.write(chunk) | |
| # # status_container.empty() | |
| # # return download_target | |
| # # except Exception as e: | |
| # # st.error(f"❌ 数据下载失败。Error: {e}") | |
| # # st.stop() | |
| # # # ================= 3. 核心 RAG 控制器 ================= | |
| # # class RAGController: | |
| # # """RAG系统控制器 - 实现策略矩阵""" | |
| # # def __init__(self): | |
| # # """初始化控制器""" | |
| # # if not API_KEY: | |
| # # st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") | |
| # # st.stop() | |
| # # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) | |
| # # self.df = None | |
| # # self.documents = [] | |
| # # self.embeddings = None | |
| # # self.bm25 = None | |
| # # self.filenames = [] | |
| # # self._load_data() | |
| # # def _load_data(self): | |
| # # """加载COMSOL文档数据""" | |
| # # real_path = get_data_file_path() | |
| # # try: | |
| # # # 加载数据 | |
| # # self.df = pd.read_parquet(real_path) | |
| # # self.documents = self.df['content'].tolist() | |
| # # self.filenames = self.df['filename'].tolist() | |
| # # # 加载向量嵌入 | |
| # # self.embeddings = np.stack(self.df['embedding'].values) | |
| # # # 初始化BM25 | |
| # # tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents] | |
| # # self.bm25 = BM25Okapi(tokenized_corpus) | |
| # # st.success(f"✅ 成功加载 {len(self.documents)} 条文档") | |
| # # except Exception as e: | |
| # # st.error(f"❌ 数据加载失败: {str(e)}") | |
| # # st.stop() | |
| # # def get_embedding(self, text: str) -> List[float]: | |
| # # """获取文本向量嵌入""" | |
| # # try: | |
| # # resp = self.client.embeddings.create( | |
| # # model=EMBEDDING_MODEL, | |
| # # input=[text] | |
| # # ) | |
| # # return resp.data[0].embedding | |
| # # except Exception as e: | |
| # # st.warning(f"向量获取失败: {e}") | |
| # # return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback | |
| # # def expand_query(self, query: str) -> Tuple[str, float]: | |
| # # """查询扩展 - 使用LLM优化查询""" | |
| # # prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。 | |
| # # 要求: | |
| # # 1. 补充COMSOL专业术语(物理场、模块、边界条件等) | |
| # # 2. 保持问题核心意图不变 | |
| # # 3. 输出简洁,仅返回改写后的查询 | |
| # # 用户问题: {query} | |
| # # 专业查询:""" | |
| # # try: | |
| # # start_time = time.time() | |
| # # resp = self.client.chat.completions.create( | |
| # # model=QE_MODEL_NAME, | |
| # # messages=[{"role": "user", "content": prompt}], | |
| # # temperature=0.3 | |
| # # ) | |
| # # expanded = resp.choices[0].message.content.strip() | |
| # # elapsed = time.time() - start_time | |
| # # return expanded, elapsed | |
| # # except Exception as e: | |
| # # print(f"QE Error: {e}") | |
| # # return query, 0 | |
| # # def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: | |
| # # """向量检索""" | |
| # # q_vec = self.get_embedding(query) | |
| # # similarities = cosine_similarity([q_vec], self.embeddings)[0] | |
| # # top_indices = np.argsort(similarities)[-top_k:][::-1] | |
| # # return [(idx, similarities[idx]) for idx in top_indices] | |
| # # def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: | |
| # # """BM25关键词检索""" | |
| # # tokenized_query = jieba.lcut(query.lower()) | |
| # # scores = self.bm25.get_scores(tokenized_query) | |
| # # top_indices = np.argsort(scores)[-top_k:][::-1] | |
| # # return [(idx, scores[idx]) for idx in top_indices] | |
| # # def reciprocal_rank_fusion(self, vector_results: List[Tuple[int, float]], | |
| # # bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]: | |
| # # """RRF融合算法""" | |
| # # scores = {} | |
| # # for rank, (idx, score) in enumerate(vector_results): | |
| # # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) | |
| # # for rank, (idx, score) in enumerate(bm25_results): | |
| # # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) | |
| # # return scores | |
| # # def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]: | |
| # # """使用重排序模型""" | |
| # # if not documents: return [], 0 | |
| # # url = f"{API_BASE}/rerank" | |
| # # headers = { | |
| # # "Authorization": f"Bearer {API_KEY}", | |
| # # "Content-Type": "application/json" | |
| # # } | |
| # # # 截断文档内容以符合 Context Window | |
| # # docs_content = [doc["content"][:2048] for doc in documents] | |
| # # payload = { | |
| # # "model": RERANK_MODEL, | |
| # # "query": query, | |
| # # "documents": docs_content, | |
| # # "top_n": top_n | |
| # # } | |
| # # try: | |
| # # start_time = time.time() | |
| # # response = requests.post(url, headers=headers, json=payload, timeout=20) | |
| # # elapsed = time.time() - start_time | |
| # # if response.status_code == 200: | |
| # # results = response.json().get("results", []) | |
| # # reranked_docs = [] | |
| # # for result in results: | |
| # # original_doc = documents[result["index"]] | |
| # # original_doc["rerank_score"] = result["relevance_score"] | |
| # # original_doc["final_score"] = result["relevance_score"] | |
| # # reranked_docs.append(original_doc) | |
| # # return reranked_docs, elapsed | |
| # # else: | |
| # # print(f"Rerank API Error: {response.text}") | |
| # # return documents[:top_n], elapsed | |
| # # except Exception as e: | |
| # # print(f"Rerank Exception: {e}") | |
| # # return documents[:top_n], 0 | |
| # # def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]: | |
| # # """执行策略矩阵""" | |
| # # start_time = time.time() | |
| # # result = { | |
| # # 'original_query': query, | |
| # # 'final_query': query, | |
| # # 'documents': [], | |
| # # 'steps': [], | |
| # # 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0}, | |
| # # 'strategy_tags': [] | |
| # # } | |
| # # # 1. 查询扩展 | |
| # # if config['use_qe']: | |
| # # expanded_q, qe_time = self.expand_query(query) | |
| # # result['final_query'] = expanded_q | |
| # # result['metrics']['qe_time'] = qe_time | |
| # # result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**") | |
| # # result['strategy_tags'].append("QE") | |
| # # # 2. 检索 | |
| # # retrieval_start = time.time() | |
| # # query_to_search = result['final_query'] | |
| # # if config['strategy'] == 'Vector': | |
| # # results = self.vector_search(query_to_search) | |
| # # result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选") | |
| # # result['strategy_tags'].append("Vector") | |
| # # elif config['strategy'] == 'BM25': | |
| # # results = self.bm25_search(query_to_search) | |
| # # result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选") | |
| # # result['strategy_tags'].append("BM25") | |
| # # elif config['strategy'] == 'Hybrid': | |
| # # vec_results = self.vector_search(query_to_search) | |
| # # bm25_results = self.bm25_search(query_to_search) | |
| # # fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results) | |
| # # results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) | |
| # # results = [(idx, score) for idx, score in results] | |
| # # result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选") | |
| # # result['strategy_tags'].extend(["Vector", "BM25"]) | |
| # # result['metrics']['retrieval_time'] = time.time() - retrieval_start | |
| # # # 3. 构建候选列表 | |
| # # recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k'] | |
| # # top_results = results[:recall_k] | |
| # # documents = [] | |
| # # for idx, score in top_results: | |
| # # documents.append({ | |
| # # 'content': self.documents[idx], | |
| # # 'filename': self.filenames[idx], | |
| # # 'retrieval_score': score, | |
| # # 'final_score': score, | |
| # # 'type': 'retrieval' | |
| # # }) | |
| # # # 4. 重排序 | |
| # # if config['use_rerank']: | |
| # # reranked_docs, rerank_time = self.rerank_documents( | |
| # # result['final_query'], documents, config['top_k'] | |
| # # ) | |
| # # result['documents'] = reranked_docs | |
| # # result['metrics']['rerank_time'] = rerank_time | |
| # # result['steps'].append(f"⚖️ 重排序 ({rerank_time:.2f}s): 精选 Top-{config['top_k']}") | |
| # # result['strategy_tags'].append("Rerank") | |
| # # else: | |
| # # result['documents'] = documents[:config['top_k']] | |
| # # result['metrics']['total_time'] = time.time() - start_time | |
| # # result['steps'].append(f"⏱️ 总耗时: {result['metrics']['total_time']:.2f}s") | |
| # # return result | |
| # # def generate_suggestions(controller, query: str, answer: str) -> List[str]: | |
| # # """生成3个后续引导问题""" | |
| # # prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。 | |
| # # 用户问题:{query} | |
| # # 专家回答:{answer[:800]}... | |
| # # 要求: | |
| # # 1. 问题简短(15字以内)。 | |
| # # 2. 紧扣当前话题。 | |
| # # 3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。 | |
| # # 4. 不要包含任何 Markdown 标记。 | |
| # # """ | |
| # # try: | |
| # # resp = controller.client.chat.completions.create( | |
| # # model=SUGGEST_MODEL_NAME, | |
| # # messages=[{"role": "user", "content": prompt}], | |
| # # temperature=0.5 | |
| # # ) | |
| # # content = resp.choices[0].message.content.strip() | |
| # # match = re.search(r'\[.*\]', content, re.DOTALL) | |
| # # if match: | |
| # # sugs = json.loads(match.group()) | |
| # # return sugs[:3] | |
| # # return [] | |
| # # except Exception as e: | |
| # # print(f"Suggestion Error: {e}") | |
| # # return [] | |
| # # def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str: | |
| # # """流式生成回答""" | |
| # # if not documents: | |
| # # return "抱歉,没有找到相关的文档来回答您的问题。" | |
| # # context_text = "\n\n".join([f"[文档{i+1}] {doc['content'][:800]}..." for i, doc in enumerate(documents)]) | |
| # # system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。 | |
| # # 要求: | |
| # # 1. 语气专业,使用COMSOL术语。 | |
| # # 2. 物理公式使用 LaTeX(如 $E=mc^2$)。 | |
| # # 3. 如果文档信息不足,请如实告知,不要编造。 | |
| # # 【参考文档】: | |
| # # {context_text} | |
| # # """ | |
| # # # 构建历史记录 | |
| # # keep_messages = max_rounds * 2 | |
| # # history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else [] | |
| # # api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}] | |
| # # try: | |
| # # response = controller.client.chat.completions.create( | |
| # # model=GEN_MODEL_NAME, | |
| # # messages=api_messages, | |
| # # temperature=0.3, | |
| # # stream=True | |
| # # ) | |
| # # answer = "" | |
| # # placeholder = st.empty() | |
| # # for chunk in response: | |
| # # if chunk.choices[0].delta.content: | |
| # # answer += chunk.choices[0].delta.content | |
| # # placeholder.markdown(answer + "▌") | |
| # # placeholder.markdown(answer) | |
| # # return answer | |
| # # except Exception as e: | |
| # # return f"生成遇到错误: {e}" | |
| # # # ================= 4. 初始化与组件渲染 ================= | |
| # # @st.cache_resource(show_spinner="🚀 正在初始化 RAG 引擎...") | |
| # # def initialize_controller(): | |
| # # return RAGController() | |
| # # def render_strategy_matrix(): | |
| # # st.markdown('<p class="strategy-title">🎯 策略矩阵配置</p>', unsafe_allow_html=True) | |
| # # st.markdown("""<div style="background-color: #161B22; padding: 10px; border-radius: 8px; margin-bottom: 20px;"> | |
| # # <p style="font-size: 0.85rem; color: #8B949E; margin: 0;">⚙️ <b>参数调节</b>:控制检索片段数量和模型记忆深度。</p> | |
| # # </div>""", unsafe_allow_html=True) | |
| # # col1, col2 = st.columns(2) | |
| # # with col1: | |
| # # use_qe = st.toggle("🔄 查询扩展 (QE)", value=False) | |
| # # use_rerank = st.toggle("⚖️ 深度重排序 (Rerank)", value=True) | |
| # # max_history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数") | |
| # # with col2: | |
| # # strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2) | |
| # # top_k = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量") | |
| # # return {'use_qe': use_qe, 'strategy': strategy, 'use_rerank': use_rerank, 'top_k': top_k, 'max_history_rounds': max_history_rounds} | |
| # # def render_metrics(metrics): | |
| # # st.markdown("### 📊 性能指标") | |
| # # cols = st.columns(4) | |
| # # with cols[0]: st.metric("查询扩展", f"{metrics['qe_time']:.2f}s" if metrics['qe_time']>0 else "N/A", delta="QE" if metrics['qe_time']>0 else None) | |
| # # with cols[1]: st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s") | |
| # # with cols[2]: st.metric("重排序", f"{metrics['rerank_time']:.2f}s" if metrics['rerank_time']>0 else "N/A", delta="Rerank" if metrics['rerank_time']>0 else None) | |
| # # with cols[3]: st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡") | |
| # # def render_documents(documents, strategy_tags): | |
| # # st.markdown("### 📄 检索结果") | |
| # # if not documents: | |
| # # st.warning("未找到相关文档") | |
| # # return | |
| # # tags_html = "".join([f'<span class="strat-tag {{"QE":"tag-qe","Vector":"tag-vec","BM25":"tag-bm25","Rerank":"tag-rerank"}}.get(t, "")}}">{t}</span>' for t in strategy_tags]) # Simplified for brevity, use full logic if copying | |
| # # # Manual mapping for safety | |
| # # html_tags = "" | |
| # # for tag in strategy_tags: | |
| # # cls = "tag-vec" if tag=="Vector" else "tag-bm25" if tag=="BM25" else "tag-qe" if tag=="QE" else "tag-rerank" | |
| # # html_tags += f'<span class="strat-tag {cls}">{tag}</span>' | |
| # # st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True) | |
| # # for i, doc in enumerate(documents): | |
| # # score = doc.get('final_score', 0) | |
| # # with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {doc['filename'][:40]}...", expanded=i<2): | |
| # # st.code(doc['content'], language="markdown") | |
| # # # ================= 5. 主程序 ================= | |
| # # def main(): | |
| # # # 状态初始化 | |
| # # if "messages" not in st.session_state: st.session_state.messages = [] | |
| # # if "last_result" not in st.session_state: st.session_state.last_result = None | |
| # # if "suggestions" not in st.session_state: st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3) | |
| # # if "prompt_trigger" not in st.session_state: st.session_state.prompt_trigger = None | |
| # # # 加载控制器 | |
| # # controller = initialize_controller() | |
| # # # 侧边栏 | |
| # # with st.sidebar: | |
| # # config = render_strategy_matrix() | |
| # # st.markdown("---") | |
| # # if st.button("🗑️ 清空当前对话", use_container_width=True): | |
| # # st.session_state.messages = [] | |
| # # st.session_state.last_result = None | |
| # # st.rerun() | |
| # # # 主界面布局 | |
| # # main_col, debug_col = st.columns([0.6, 0.4], gap="large") | |
| # # with main_col: | |
| # # st.markdown("### 💬 智能仿真问答") | |
| # # # 1. 历史消息 | |
| # # for msg in st.session_state.messages: | |
| # # with st.chat_message(msg["role"]): | |
| # # st.markdown(msg["content"]) | |
| # # # 2. 建议区 | |
| # # if st.session_state.suggestions: | |
| # # st.markdown("##### 💡 您可能想问:") | |
| # # cols = st.columns(3) | |
| # # for i, sug in enumerate(st.session_state.suggestions): | |
| # # if cols[i].button(sug, use_container_width=True, key=f"sug_{i}"): | |
| # # st.session_state.prompt_trigger = sug | |
| # # st.rerun() | |
| # # # 3. 输入处理 | |
| # # user_input = None | |
| # # if st.session_state.prompt_trigger: | |
| # # user_input = st.session_state.prompt_trigger | |
| # # st.session_state.prompt_trigger = None | |
| # # else: | |
| # # user_input = st.chat_input("请输入您关于 COMSOL 的问题...") | |
| # # # 4. 执行逻辑 | |
| # # if user_input: | |
| # # st.session_state.messages.append({"role": "user", "content": user_input}) | |
| # # with st.chat_message("user"): st.markdown(user_input) | |
| # # # 检索 | |
| # # with st.spinner("🔍 检索知识库中..."): | |
| # # result = controller.execute_strategy(user_input, config) | |
| # # st.session_state.last_result = result | |
| # # # 生成 | |
| # # with st.chat_message("assistant"): | |
| # # answer = generate_answer( | |
| # # controller, user_input, result['documents'], | |
| # # st.session_state.messages, config['max_history_rounds'] | |
| # # ) | |
| # # st.session_state.messages.append({"role": "assistant", "content": answer}) | |
| # # # 生成新建议 | |
| # # new_sugs = generate_suggestions(controller, user_input, answer) | |
| # # st.session_state.suggestions = new_sugs if new_sugs else random.sample(PRESET_QUESTIONS, 3) | |
| # # st.rerun() | |
| # # with debug_col: | |
| # # st.markdown("### 🔍 系统调试视图") | |
| # # if st.session_state.last_result: | |
| # # res = st.session_state.last_result | |
| # # st.info(f"当前查询: {res.get('original_query', 'N/A')}") | |
| # # render_metrics(res['metrics']) | |
| # # with st.expander("🔧 检索链路详情", expanded=True): | |
| # # for step in res['steps']: st.markdown(f"- {step}") | |
| # # render_documents(res['documents'], res['strategy_tags']) | |
| # # else: | |
| # # st.info("等待交互...") | |
| # # if __name__ == "__main__": | |
| # # main() | |
| # import streamlit as st | |
| # import pandas as pd | |
| # import numpy as np | |
| # import jieba | |
| # import requests | |
| # import os | |
| # import time | |
| # import json | |
| # import re | |
| # import random | |
| # import subprocess | |
| # from openai import OpenAI | |
| # from rank_bm25 import BM25Okapi | |
| # from sklearn.metrics.pairwise import cosine_similarity | |
| # from typing import List, Dict, Tuple, Any | |
| # # ================= 1. 全局配置与样式 ================= | |
| # # API 配置 (从 HF 环境变量获取) | |
| # API_BASE = "https://api.siliconflow.cn/v1" | |
| # API_KEY = os.getenv("SILICONFLOW_API_KEY") | |
| # # 模型名称配置 | |
| # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" | |
| # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" | |
| # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" | |
| # # QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" | |
| # # SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" | |
| # QE_MODEL_NAME = "MiniMaxAI/MiniMax-M2" | |
| # SUGGEST_MODEL_NAME = "MiniMaxAI/MiniMax-M2" | |
| # # 预置问题池 | |
| # PRESET_QUESTIONS = [ | |
| # "如何设置流固耦合接口?", | |
| # "求解器不收敛怎么办?", | |
| # "网格划分有哪些技巧?", | |
| # "如何定义随时间变化的边界条件?", | |
| # "计算结果如何导出数据?", | |
| # "什么是完美匹配层 (PML)?", | |
| # "低频电磁场仿真注意事项", | |
| # "如何提高瞬态计算速度?", | |
| # "参数化扫描如何设置?", | |
| # "多物理场耦合的收敛性优化" | |
| # ] | |
| # # 数据文件配置 | |
| # DATA_FILENAME = "comsol_embedded.parquet" | |
| # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" | |
| # # 页面配置 | |
| # st.set_page_config( | |
| # page_title="COMSOL RAG 策略控制台", | |
| # page_icon="🎛️", | |
| # layout="wide", | |
| # initial_sidebar_state="expanded" | |
| # ) | |
| # # 自定义CSS样式 (适配深色/浅色模式) | |
| # st.markdown(""" | |
| # <style> | |
| # /* 移除强制背景色,改用透明或半透明,从而适配系统主题 */ | |
| # /* 聊天消息样式 - 使用半透明背景以适配两种模式 */ | |
| # [data-testid="stChatMessage"] { | |
| # background-color: rgba(128, 128, 128, 0.1); /* 10%透明度的灰色,深浅模式都适用 */ | |
| # border: 1px solid rgba(128, 128, 128, 0.2); | |
| # border-radius: 10px; | |
| # } | |
| # /* 策略标签 - 保持原有的 rgba 设置,因为它们是半透明的,在浅色模式下也好看 */ | |
| # .strat-tag { | |
| # font-size: 0.75rem; | |
| # padding: 3px 8px; | |
| # border-radius: 4px; | |
| # margin-right: 6px; | |
| # font-weight: bold; | |
| # display: inline-block; | |
| # margin-bottom: 4px; | |
| # border: 1px solid rgba(128, 128, 128, 0.2); | |
| # } | |
| # /* 调整标签颜色适配 */ | |
| # .tag-vec { background-color: rgba(31, 119, 180, 0.15); color: #1f77b4; border-color: #1f77b4; } | |
| # .tag-bm25 { background-color: rgba(255, 127, 14, 0.15); color: #d66a00; border-color: #ff7f0e; } | |
| # .tag-qe { background-color: rgba(44, 160, 44, 0.15); color: #2ca02c; border-color: #2ca02c; } | |
| # .tag-rerank { background-color: rgba(214, 39, 40, 0.15); color: #d62728; border-color: #d62728; } | |
| # /* 过程展示框 */ | |
| # .process-box { | |
| # background-color: rgba(128, 128, 128, 0.05); /* 极淡的背景 */ | |
| # border: 1px solid rgba(128, 128, 128, 0.2); | |
| # padding: 15px; | |
| # border-radius: 8px; | |
| # font-size: 0.9rem; | |
| # margin-bottom: 15px; | |
| # } | |
| # /* 策略矩阵标题 - 渐变色文字通常在两种背景下都可见,微调一下 */ | |
| # .strategy-title { | |
| # background: linear-gradient(45deg, #4A90E2 0%, #9013FE 100%); | |
| # -webkit-background-clip: text; | |
| # -webkit-text-fill-color: transparent; | |
| # background-clip: text; | |
| # font-weight: bold; | |
| # font-size: 1.2rem; | |
| # } | |
| # /* 性能指标框 */ | |
| # .metric-box { | |
| # background-color: rgba(128, 128, 128, 0.05); | |
| # border: 1px solid rgba(128, 128, 128, 0.2); | |
| # border-radius: 6px; | |
| # padding: 10px; | |
| # margin: 5px 0; | |
| # text-align: center; | |
| # } | |
| # /* 动画效果保持不变 */ | |
| # @keyframes pulse { | |
| # 0% { opacity: 1; } | |
| # 50% { opacity: 0.7; } | |
| # 100% { opacity: 1; } | |
| # } | |
| # .processing { | |
| # animation: pulse 1.5s infinite; | |
| # } | |
| # </style> | |
| # """, unsafe_allow_html=True) | |
| # # ================= 2. 数据下载工具 (HF 适配) ================= | |
| # def download_with_curl(url, output_path): | |
| # """使用 curl 下载文件,增加鲁棒性""" | |
| # try: | |
| # cmd = [ | |
| # "curl", "-L", | |
| # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", | |
| # "-o", output_path, | |
| # "--fail", | |
| # url | |
| # ] | |
| # result = subprocess.run(cmd, capture_output=True, text=True) | |
| # if result.returncode != 0: | |
| # print(f"Curl stderr: {result.stderr}") | |
| # return False | |
| # return True | |
| # except Exception as e: | |
| # print(f"Curl download error: {e}") | |
| # return False | |
| # def get_data_file_path(): | |
| # """获取数据文件路径,如果不存在则自动下载""" | |
| # # 优先检查本地可能存在的路径 | |
| # possible_paths = [ | |
| # DATA_FILENAME, | |
| # os.path.join("/app", DATA_FILENAME), | |
| # os.path.join("processed_data", DATA_FILENAME), | |
| # os.path.join(os.getcwd(), DATA_FILENAME) | |
| # ] | |
| # for path in possible_paths: | |
| # if os.path.exists(path): | |
| # return path | |
| # # 如果都没找到,准备下载 | |
| # # HF Spaces 通常在 /home/user/app 下运行,直接下载到当前目录 | |
| # download_target = os.path.join(os.getcwd(), DATA_FILENAME) | |
| # status_container = st.empty() | |
| # status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)") | |
| # # 尝试 Curl 下载 | |
| # if download_with_curl(DATA_URL, download_target): | |
| # status_container.empty() | |
| # return download_target | |
| # # 降级尝试 Requests 下载 | |
| # try: | |
| # headers = {'User-Agent': 'Mozilla/5.0'} | |
| # r = requests.get(DATA_URL, headers=headers, stream=True) | |
| # r.raise_for_status() | |
| # with open(download_target, 'wb') as f: | |
| # for chunk in r.iter_content(chunk_size=8192): | |
| # f.write(chunk) | |
| # status_container.empty() | |
| # return download_target | |
| # except Exception as e: | |
| # st.error(f"❌ 数据下载失败。Error: {e}") | |
| # st.stop() | |
| # # ================= 3. 核心 RAG 控制器 ================= | |
| # class RAGController: | |
| # """RAG系统控制器 - 实现策略矩阵""" | |
| # def __init__(self): | |
| # """初始化控制器""" | |
| # if not API_KEY: | |
| # st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") | |
| # st.stop() | |
| # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) | |
| # self.df = None | |
| # self.documents = [] | |
| # self.embeddings = None | |
| # self.bm25 = None | |
| # self.filenames = [] | |
| # self._load_data() | |
| # def _load_data(self): | |
| # """加载COMSOL文档数据""" | |
| # real_path = get_data_file_path() | |
| # try: | |
| # # 加载数据 | |
| # self.df = pd.read_parquet(real_path) | |
| # self.documents = self.df['content'].tolist() | |
| # self.filenames = self.df['filename'].tolist() | |
| # # 加载向量嵌入 | |
| # self.embeddings = np.stack(self.df['embedding'].values) | |
| # # 初始化BM25 | |
| # tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents] | |
| # self.bm25 = BM25Okapi(tokenized_corpus) | |
| # st.success(f"✅ 成功加载 {len(self.documents)} 条文档") | |
| # except Exception as e: | |
| # st.error(f"❌ 数据加载失败: {str(e)}") | |
| # st.stop() | |
| # def get_embedding(self, text: str) -> List[float]: | |
| # """获取文本向量嵌入""" | |
| # try: | |
| # resp = self.client.embeddings.create( | |
| # model=EMBEDDING_MODEL, | |
| # input=[text] | |
| # ) | |
| # return resp.data[0].embedding | |
| # except Exception as e: | |
| # st.warning(f"向量获取失败: {e}") | |
| # return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback | |
| # def expand_query(self, query: str) -> Tuple[str, float]: | |
| # """查询扩展 - 使用LLM优化查询""" | |
| # prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。 | |
| # 要求: | |
| # 1. 补充COMSOL专业术语(物理场、模块、边界条件等) | |
| # 2. 保持问题核心意图不变 | |
| # 3. 输出简洁,仅返回改写后的查询 | |
| # 用户问题: {query} | |
| # 专业查询:""" | |
| # try: | |
| # start_time = time.time() | |
| # resp = self.client.chat.completions.create( | |
| # model=QE_MODEL_NAME, | |
| # messages=[{"role": "user", "content": prompt}], | |
| # temperature=0.3 | |
| # ) | |
| # expanded = resp.choices[0].message.content.strip() | |
| # elapsed = time.time() - start_time | |
| # return expanded, elapsed | |
| # except Exception as e: | |
| # print(f"QE Error: {e}") | |
| # return query, 0 | |
| # def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: | |
| # """向量检索""" | |
| # q_vec = self.get_embedding(query) | |
| # similarities = cosine_similarity([q_vec], self.embeddings)[0] | |
| # top_indices = np.argsort(similarities)[-top_k:][::-1] | |
| # return [(idx, similarities[idx]) for idx in top_indices] | |
| # def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: | |
| # """BM25关键词检索""" | |
| # tokenized_query = jieba.lcut(query.lower()) | |
| # scores = self.bm25.get_scores(tokenized_query) | |
| # top_indices = np.argsort(scores)[-top_k:][::-1] | |
| # return [(idx, scores[idx]) for idx in top_indices] | |
| # def reciprocal_rank_fusion(self, vector_results: List[Tuple[int, float]], | |
| # bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]: | |
| # """RRF融合算法""" | |
| # scores = {} | |
| # for rank, (idx, score) in enumerate(vector_results): | |
| # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) | |
| # for rank, (idx, score) in enumerate(bm25_results): | |
| # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) | |
| # return scores | |
| # def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]: | |
| # """使用重排序模型""" | |
| # if not documents: return [], 0 | |
| # url = f"{API_BASE}/rerank" | |
| # headers = { | |
| # "Authorization": f"Bearer {API_KEY}", | |
| # "Content-Type": "application/json" | |
| # } | |
| # # 截断文档内容以符合 Context Window | |
| # docs_content = [doc["content"][:2048] for doc in documents] | |
| # payload = { | |
| # "model": RERANK_MODEL, | |
| # "query": query, | |
| # "documents": docs_content, | |
| # "top_n": top_n | |
| # } | |
| # try: | |
| # start_time = time.time() | |
| # response = requests.post(url, headers=headers, json=payload, timeout=20) | |
| # elapsed = time.time() - start_time | |
| # if response.status_code == 200: | |
| # results = response.json().get("results", []) | |
| # reranked_docs = [] | |
| # for result in results: | |
| # original_doc = documents[result["index"]] | |
| # original_doc["rerank_score"] = result["relevance_score"] | |
| # original_doc["final_score"] = result["relevance_score"] | |
| # reranked_docs.append(original_doc) | |
| # return reranked_docs, elapsed | |
| # else: | |
| # print(f"Rerank API Error: {response.text}") | |
| # return documents[:top_n], elapsed | |
| # except Exception as e: | |
| # print(f"Rerank Exception: {e}") | |
| # return documents[:top_n], 0 | |
| # def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]: | |
| # """执行策略矩阵""" | |
| # start_time = time.time() | |
| # result = { | |
| # 'original_query': query, | |
| # 'final_query': query, | |
| # 'documents': [], | |
| # 'steps': [], | |
| # 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0}, | |
| # 'strategy_tags': [] | |
| # } | |
| # # 1. 查询扩展 | |
| # if config['use_qe']: | |
| # expanded_q, qe_time = self.expand_query(query) | |
| # result['final_query'] = expanded_q | |
| # result['metrics']['qe_time'] = qe_time | |
| # result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**") | |
| # result['strategy_tags'].append("QE") | |
| # # 2. 检索 | |
| # retrieval_start = time.time() | |
| # query_to_search = result['final_query'] | |
| # if config['strategy'] == 'Vector': | |
| # results = self.vector_search(query_to_search) | |
| # result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选") | |
| # result['strategy_tags'].append("Vector") | |
| # elif config['strategy'] == 'BM25': | |
| # results = self.bm25_search(query_to_search) | |
| # result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选") | |
| # result['strategy_tags'].append("BM25") | |
| # elif config['strategy'] == 'Hybrid': | |
| # vec_results = self.vector_search(query_to_search) | |
| # bm25_results = self.bm25_search(query_to_search) | |
| # fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results) | |
| # results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) | |
| # results = [(idx, score) for idx, score in results] | |
| # result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选") | |
| # result['strategy_tags'].extend(["Vector", "BM25"]) | |
| # result['metrics']['retrieval_time'] = time.time() - retrieval_start | |
| # # 3. 构建候选列表 | |
| # recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k'] | |
| # top_results = results[:recall_k] | |
| # documents = [] | |
| # for idx, score in top_results: | |
| # documents.append({ | |
| # 'content': self.documents[idx], | |
| # 'filename': self.filenames[idx], | |
| # 'retrieval_score': score, | |
| # 'final_score': score, | |
| # 'type': 'retrieval' | |
| # }) | |
| # # 4. 重排序 | |
| # if config['use_rerank']: | |
| # reranked_docs, rerank_time = self.rerank_documents( | |
| # result['final_query'], documents, config['top_k'] | |
| # ) | |
| # result['documents'] = reranked_docs | |
| # result['metrics']['rerank_time'] = rerank_time | |
| # result['steps'].append(f"⚖️ 重排序 ({rerank_time:.2f}s): 精选 Top-{config['top_k']}") | |
| # result['strategy_tags'].append("Rerank") | |
| # else: | |
| # result['documents'] = documents[:config['top_k']] | |
| # result['metrics']['total_time'] = time.time() - start_time | |
| # result['steps'].append(f"⏱️ 总耗时: {result['metrics']['total_time']:.2f}s") | |
| # return result | |
| # def generate_suggestions(controller, query: str, answer: str) -> List[str]: | |
| # """生成3个后续引导问题""" | |
| # prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。 | |
| # 用户问题:{query} | |
| # 专家回答:{answer[:800]}... | |
| # 要求: | |
| # 1. 问题简短(15字以内)。 | |
| # 2. 紧扣当前话题。 | |
| # 3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。 | |
| # 4. 不要包含任何 Markdown 标记。 | |
| # """ | |
| # try: | |
| # resp = controller.client.chat.completions.create( | |
| # model=SUGGEST_MODEL_NAME, | |
| # messages=[{"role": "user", "content": prompt}], | |
| # temperature=0.5 | |
| # ) | |
| # content = resp.choices[0].message.content.strip() | |
| # match = re.search(r'\[.*\]', content, re.DOTALL) | |
| # if match: | |
| # sugs = json.loads(match.group()) | |
| # return sugs[:3] | |
| # return [] | |
| # except Exception as e: | |
| # print(f"Suggestion Error: {e}") | |
| # return [] | |
| # def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str: | |
| # """流式生成回答""" | |
| # if not documents: | |
| # return "抱歉,没有找到相关的文档来回答您的问题。" | |
| # context_text = "\n\n".join([f"[文档{i+1}] {doc['content'][:800]}..." for i, doc in enumerate(documents)]) | |
| # system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。 | |
| # 要求: | |
| # 1. 语气专业,使用COMSOL术语。 | |
| # 2. 物理公式使用 LaTeX(如 $E=mc^2$)。 | |
| # 3. 如果文档信息不足,请如实告知,不要编造。 | |
| # 【参考文档】: | |
| # {context_text} | |
| # """ | |
| # # 构建历史记录 | |
| # keep_messages = max_rounds * 2 | |
| # history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else [] | |
| # api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}] | |
| # try: | |
| # response = controller.client.chat.completions.create( | |
| # model=GEN_MODEL_NAME, | |
| # messages=api_messages, | |
| # temperature=0.3, | |
| # stream=True | |
| # ) | |
| # answer = "" | |
| # placeholder = st.empty() | |
| # for chunk in response: | |
| # if chunk.choices[0].delta.content: | |
| # answer += chunk.choices[0].delta.content | |
| # placeholder.markdown(answer + "▌") | |
| # placeholder.markdown(answer) | |
| # return answer | |
| # except Exception as e: | |
| # return f"生成遇到错误: {e}" | |
| # # ================= 4. 初始化与组件渲染 ================= | |
| # @st.cache_resource(show_spinner="🚀 正在初始化 RAG 引擎...") | |
| # def initialize_controller(): | |
| # return RAGController() | |
| # def render_strategy_matrix(): | |
| # st.markdown('<p class="strategy-title">🎯 策略矩阵配置</p>', unsafe_allow_html=True) | |
| # st.markdown("""<div style="background-color: rgba(128, 128, 128, 0.05); padding: 10px; border-radius: 8px; margin-bottom: 20px; border: 1px solid rgba(128, 128, 128, 0.2);"> | |
| # <p style="font-size: 0.85rem; margin: 0;">⚙️ <b>参数调节</b>:控制检索片段数量和模型记忆深度。</p> | |
| # </div>""", unsafe_allow_html=True) | |
| # col1, col2 = st.columns(2) | |
| # with col1: | |
| # use_qe = st.toggle("🔄 查询扩展 (QE)", value=False) | |
| # use_rerank = st.toggle("⚖️ 深度重排序 (Rerank)", value=True) | |
| # max_history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数") | |
| # with col2: | |
| # strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2) | |
| # top_k = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量") | |
| # return {'use_qe': use_qe, 'strategy': strategy, 'use_rerank': use_rerank, 'top_k': top_k, 'max_history_rounds': max_history_rounds} | |
| # def render_metrics(metrics): | |
| # st.markdown("### 📊 性能指标") | |
| # cols = st.columns(4) | |
| # with cols[0]: st.metric("查询扩展", f"{metrics['qe_time']:.2f}s" if metrics['qe_time']>0 else "N/A", delta="QE" if metrics['qe_time']>0 else None) | |
| # with cols[1]: st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s") | |
| # with cols[2]: st.metric("重排序", f"{metrics['rerank_time']:.2f}s" if metrics['rerank_time']>0 else "N/A", delta="Rerank" if metrics['rerank_time']>0 else None) | |
| # with cols[3]: st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡") | |
| # def render_documents(documents, strategy_tags): | |
| # st.markdown("### 📄 检索结果") | |
| # if not documents: | |
| # st.warning("未找到相关文档") | |
| # return | |
| # tags_html = "".join([f'<span class="strat-tag {{"QE":"tag-qe","Vector":"tag-vec","BM25":"tag-bm25","Rerank":"tag-rerank"}}.get(t, "")}}">{t}</span>' for t in strategy_tags]) # Simplified for brevity, use full logic if copying | |
| # # Manual mapping for safety | |
| # html_tags = "" | |
| # for tag in strategy_tags: | |
| # cls = "tag-vec" if tag=="Vector" else "tag-bm25" if tag=="BM25" else "tag-qe" if tag=="QE" else "tag-rerank" | |
| # html_tags += f'<span class="strat-tag {cls}">{tag}</span>' | |
| # st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True) | |
| # for i, doc in enumerate(documents): | |
| # score = doc.get('final_score', 0) | |
| # with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {doc['filename'][:40]}...", expanded=i<2): | |
| # st.code(doc['content'], language="markdown") | |
| # # ================= 5. 主程序 ================= | |
| # def main(): | |
| # # 状态初始化 | |
| # if "messages" not in st.session_state: st.session_state.messages = [] | |
| # if "last_result" not in st.session_state: st.session_state.last_result = None | |
| # if "suggestions" not in st.session_state: st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3) | |
| # if "prompt_trigger" not in st.session_state: st.session_state.prompt_trigger = None | |
| # # 加载控制器 | |
| # controller = initialize_controller() | |
| # # 侧边栏 | |
| # with st.sidebar: | |
| # config = render_strategy_matrix() | |
| # st.markdown("---") | |
| # if st.button("🗑️ 清空当前对话", use_container_width=True): | |
| # st.session_state.messages = [] | |
| # st.session_state.last_result = None | |
| # st.rerun() | |
| # # 主界面布局 | |
| # main_col, debug_col = st.columns([0.6, 0.4], gap="large") | |
| # with main_col: | |
| # st.markdown("### 💬 智能仿真问答") | |
| # # 1. 历史消息 | |
| # for msg in st.session_state.messages: | |
| # with st.chat_message(msg["role"]): | |
| # st.markdown(msg["content"]) | |
| # # 2. 建议区 | |
| # if st.session_state.suggestions: | |
| # st.markdown("##### 💡 您可能想问:") | |
| # cols = st.columns(3) | |
| # for i, sug in enumerate(st.session_state.suggestions): | |
| # if cols[i].button(sug, use_container_width=True, key=f"sug_{i}"): | |
| # st.session_state.prompt_trigger = sug | |
| # st.rerun() | |
| # # 3. 输入处理 | |
| # user_input = None | |
| # if st.session_state.prompt_trigger: | |
| # user_input = st.session_state.prompt_trigger | |
| # st.session_state.prompt_trigger = None | |
| # else: | |
| # user_input = st.chat_input("请输入您关于 COMSOL 的问题...") | |
| # # 4. 执行逻辑 | |
| # if user_input: | |
| # st.session_state.messages.append({"role": "user", "content": user_input}) | |
| # with st.chat_message("user"): st.markdown(user_input) | |
| # # 检索 | |
| # with st.spinner("🔍 检索知识库中..."): | |
| # result = controller.execute_strategy(user_input, config) | |
| # st.session_state.last_result = result | |
| # # 生成 | |
| # with st.chat_message("assistant"): | |
| # answer = generate_answer( | |
| # controller, user_input, result['documents'], | |
| # st.session_state.messages, config['max_history_rounds'] | |
| # ) | |
| # st.session_state.messages.append({"role": "assistant", "content": answer}) | |
| # # 生成新建议 | |
| # new_sugs = generate_suggestions(controller, user_input, answer) | |
| # st.session_state.suggestions = new_sugs if new_sugs else random.sample(PRESET_QUESTIONS, 3) | |
| # st.rerun() | |
| # with debug_col: | |
| # st.markdown("### 🔍 系统调试视图") | |
| # if st.session_state.last_result: | |
| # res = st.session_state.last_result | |
| # st.info(f"当前查询: {res.get('original_query', 'N/A')}") | |
| # render_metrics(res['metrics']) | |
| # with st.expander("🔧 检索链路详情", expanded=True): | |
| # for step in res['steps']: st.markdown(f"- {step}") | |
| # render_documents(res['documents'], res['strategy_tags']) | |
| # else: | |
| # st.info("等待交互...") | |
| # if __name__ == "__main__": | |
| # main() | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import jieba | |
| import requests | |
| import os | |
| import time | |
| import json | |
| import re | |
| import random | |
| import subprocess | |
| import logging | |
| import psutil | |
| from openai import OpenAI | |
| from rank_bm25 import BM25Okapi | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from typing import List, Dict, Tuple, Any | |
# ================= 0. Logging and memory monitoring =================
# Configure the log record format. Per the original note, this output is what
# shows up in the Hugging Face Space "Logs" tab.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def log_memory():
    """Log and return the current process's resident memory, in MB.

    Returns:
        The ``rss`` value psutil reports for this process, divided by
        1024 * 1024.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / (1024 * 1024)
    logger.info(f"💾 Current Memory Usage: {rss_mb:.2f} MB")
    return rss_mb
# ================= 1. Global configuration and styling =================
# API settings; the key is read from the environment (an HF Space secret,
# according to the original note).
API_BASE = "https://api.siliconflow.cn/v1"
API_KEY = os.environ.get("SILICONFLOW_API_KEY")

# Model identifiers used for embedding, reranking, answer generation,
# query expansion and follow-up suggestions.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"
SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"

# Pool of preset starter questions.
PRESET_QUESTIONS = [
    "如何设置流固耦合接口?",
    "求解器不收敛怎么办?",
    "网格划分有哪些技巧?",
    "如何定义随时间变化的边界条件?",
    "计算结果如何导出数据?",
    "什么是完美匹配层 (PML)?",
    "低频电磁场仿真注意事项",
    "如何提高瞬态计算速度?",
    "参数化扫描如何设置?",
    "多物理场耦合的收敛性优化",
]

# Data file: local file name and the URL it is downloaded from when absent.
DATA_FILENAME = "comsol_embedded.parquet"
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"

# Streamlit page setup.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="COMSOL RAG 策略控制台",
    page_icon="🎛️",
)
# Custom CSS (intended to work in both dark and light themes). Kept in a
# private constant so the injection call below stays short.
_CUSTOM_CSS = """
<style>
/* 移除强制背景色,改用透明或半透明,从而适配系统主题 */
/* 聊天消息样式 - 使用半透明背景以适配两种模式 */
[data-testid="stChatMessage"] {
background-color: rgba(128, 128, 128, 0.1); /* 10%透明度的灰色,深浅模式都适用 */
border: 1px solid rgba(128, 128, 128, 0.2);
border-radius: 10px;
}
/* 策略标签 - 保持原有的 rgba 设置,因为它们是半透明的,在浅色模式下也好看 */
.strat-tag {
font-size: 0.75rem;
padding: 3px 8px;
border-radius: 4px;
margin-right: 6px;
font-weight: bold;
display: inline-block;
margin-bottom: 4px;
border: 1px solid rgba(128, 128, 128, 0.2);
}
/* 调整标签颜色适配 */
.tag-vec { background-color: rgba(31, 119, 180, 0.15); color: #1f77b4; border-color: #1f77b4; }
.tag-bm25 { background-color: rgba(255, 127, 14, 0.15); color: #d66a00; border-color: #ff7f0e; }
.tag-qe { background-color: rgba(44, 160, 44, 0.15); color: #2ca02c; border-color: #2ca02c; }
.tag-rerank { background-color: rgba(214, 39, 40, 0.15); color: #d62728; border-color: #d62728; }
/* 过程展示框 */
.process-box {
background-color: rgba(128, 128, 128, 0.05); /* 极淡的背景 */
border: 1px solid rgba(128, 128, 128, 0.2);
padding: 15px;
border-radius: 8px;
font-size: 0.9rem;
margin-bottom: 15px;
}
/* 策略矩阵标题 - 渐变色文字通常在两种背景下都可见,微调一下 */
.strategy-title {
background: linear-gradient(45deg, #4A90E2 0%, #9013FE 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
font-weight: bold;
font-size: 1.2rem;
}
/* 性能指标框 */
.metric-box {
background-color: rgba(128, 128, 128, 0.05);
border: 1px solid rgba(128, 128, 128, 0.2);
border-radius: 6px;
padding: 10px;
margin: 5px 0;
text-align: center;
}
/* 动画效果保持不变 */
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.7; }
100% { opacity: 1; }
}
.processing {
animation: pulse 1.5s infinite;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
| # ================= 2. 数据下载工具 (HF 适配) ================= | |
def download_with_curl(url, output_path):
    """Download ``url`` to ``output_path`` by invoking the ``curl`` binary.

    curl writes to a sibling ``<output_path>.part`` file which is moved onto
    ``output_path`` only after curl exits with status 0. An interrupted or
    failed transfer therefore never leaves a truncated file at
    ``output_path`` (``get_data_file_path`` treats the mere existence of the
    data file as "ready to load").

    Args:
        url: Source URL. Redirects are followed (``-L``) and HTTP errors are
            turned into a non-zero exit status (``--fail``).
        output_path: Destination file path.

    Returns:
        True if the file was downloaded and moved into place, False on any
        failure (non-zero curl exit, curl not installed, filesystem error).
    """
    partial_path = f"{output_path}.part"
    try:
        cmd = [
            "curl", "-L",
            "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
            "-o", partial_path,
            "--fail",
            url,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.warning(f"Curl stderr: {result.stderr}")
            return False
        # Atomic on the same filesystem: readers see either no file or the
        # complete one. Raises if curl produced no output file at all.
        os.replace(partial_path, output_path)
        return True
    except Exception as e:
        logger.error(f"Curl download error: {e}")
        return False
    finally:
        # After a successful os.replace the partial file no longer exists,
        # so this only fires on the failure paths.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial download {partial_path}: {cleanup_error}")
def get_data_file_path():
    """Return a path to the parquet data file, downloading it if it is missing.

    Candidate locations are checked in order: ``DATA_FILENAME`` as a relative
    path, under ``/app``, under ``processed_data``, and under the current
    working directory. If none exists, the file is fetched from ``DATA_URL``
    into the working directory, first with ``download_with_curl`` and then,
    if that fails, with ``requests``.

    The ``requests`` fallback streams into a temporary ``.tmp`` file and moves
    it into place only when the transfer completed, so a failed download is
    never picked up as a valid data file by a later call.

    Returns:
        Path of the existing or freshly downloaded file.

    On failure of both download methods an error is shown with ``st.error``
    and the script run is halted with ``st.stop()``.
    """
    # Prefer any copy that is already on disk.
    possible_paths = [
        DATA_FILENAME,
        os.path.join("/app", DATA_FILENAME),
        os.path.join("processed_data", DATA_FILENAME),
        os.path.join(os.getcwd(), DATA_FILENAME),
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    # Nothing found: download into the working directory (original note: HF
    # Spaces usually run from /home/user/app).
    download_target = os.path.join(os.getcwd(), DATA_FILENAME)
    status_container = st.empty()
    status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)")

    # First attempt: curl.
    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target

    # Fallback: requests, streamed to a temporary file.
    partial_target = f"{download_target}.tmp"
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # (connect, read) timeouts: the read timeout bounds each socket read,
        # not the whole transfer, so large files still download.
        with requests.get(DATA_URL, headers=headers, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(partial_target, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_target, download_target)
        status_container.empty()
        return download_target
    except Exception as e:
        # Drop whatever was written so far; it must not be mistaken for data.
        try:
            os.remove(partial_target)
        except OSError:
            pass  # nothing was written, or it is already gone
        st.error(f"❌ 数据下载失败。Error: {e}")
        st.stop()
| # ================= 3. 核心 RAG 控制器 ================= | |
| class RAGController: | |
| """RAG系统控制器 - 实现策略矩阵""" | |
def __init__(self):
    """Create the API client, reset the index attributes and load the data.

    If ``API_KEY`` is empty an error is shown with ``st.error`` and the
    script run is halted with ``st.stop()``.
    """
    if not API_KEY:
        st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
        st.stop()

    self.client = OpenAI(api_key=API_KEY, base_url=API_BASE)

    # Placeholders; _load_data() fills these in.
    self.df = self.embeddings = self.bm25 = None
    self.documents, self.filenames = [], []

    self._load_data()
def _load_data(self):
    """Load the document table and build the embedding matrix and BM25 index.

    Reads the parquet file returned by ``get_data_file_path()``, stores the
    ``content`` and ``filename`` columns as lists, stacks the ``embedding``
    column into one array, and builds a ``BM25Okapi`` index over the
    jieba-tokenized, lower-cased documents. Memory usage is logged after
    each stage.

    Any exception is logged with its traceback, reported with ``st.error``,
    and the script run is halted with ``st.stop()``.
    """
    parquet_path = get_data_file_path()
    try:
        logger.info("🚀 Initializing RAG Controller...")
        log_memory()

        # Stage 1: read the table and pull out the text columns.
        logger.info(f"📂 Loading parquet from: {parquet_path}")
        frame = pd.read_parquet(parquet_path)
        self.df = frame
        self.documents = frame['content'].tolist()
        self.filenames = frame['filename'].tolist()
        logger.info(f"✅ Dataframe loaded. Shape: {frame.shape}")
        log_memory()

        # Stage 2: stack the per-row embeddings into a single array.
        logger.info("🧠 Stacking embeddings matrix...")
        matrix = np.stack(frame['embedding'].values)
        self.embeddings = matrix
        logger.info(f"✅ Embeddings stacked. Shape: {matrix.shape}")
        log_memory()

        # Stage 3: tokenize every document and build the BM25 index.
        logger.info("🔍 Initializing BM25 index (Tokenizing)...")
        corpus_tokens = [jieba.lcut(str(text).lower()) for text in self.documents]
        self.bm25 = BM25Okapi(corpus_tokens)
        logger.info("✅ BM25 Index ready.")
        log_memory()

        st.success(f"✅ 成功加载 {len(self.documents)} 条文档")
    except Exception as e:
        # logger.exception == logger.error(..., exc_info=True)
        logger.exception(f"❌ Critical error during data load: {e!s}")
        st.error(f"❌ 数据加载失败: {e!s}")
        st.stop()
| def get_embedding(self, text: str) -> List[float]: | |
| """获取文本向量嵌入""" | |
| try: | |
| resp = self.client.embeddings.create( | |
| model=EMBEDDING_MODEL, | |
| input=[text] | |
| ) | |
| return resp.data[0].embedding | |
| except Exception as e: | |
| st.warning(f"向量获取失败: {e}") | |
| return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback | |
| def expand_query(self, query: str) -> Tuple[str, float]: | |
| """查询扩展 - 使用LLM优化查询""" | |
| prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。 | |
| 要求: | |
| 1. 补充COMSOL专业术语(物理场、模块、边界条件等) | |
| 2. 保持问题核心意图不变 | |
| 3. 输出简洁,仅返回改写后的查询 | |
| 用户问题: {query} | |
| 专业查询:""" | |
| try: | |
| start_time = time.time() | |
| resp = self.client.chat.completions.create( | |
| model=QE_MODEL_NAME, | |
| messages=[{"role": "user", "content": prompt}], | |
| temperature=0.3 | |
| ) | |
| expanded = resp.choices[0].message.content.strip() | |
| elapsed = time.time() - start_time | |
| logger.info(f"🔧 QE completed in {elapsed:.2f}s") | |
| return expanded, elapsed | |
| except Exception as e: | |
| logger.error(f"❌ QE Error: {e}") | |
| return query, 0 | |
| def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: | |
| """向量检索""" | |
| q_vec = self.get_embedding(query) | |
| similarities = cosine_similarity([q_vec], self.embeddings)[0] | |
| top_indices = np.argsort(similarities)[-top_k:][::-1] | |
| return [(idx, similarities[idx]) for idx in top_indices] | |
| def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: | |
| """BM25关键词检索""" | |
| tokenized_query = jieba.lcut(query.lower()) | |
| scores = self.bm25.get_scores(tokenized_query) | |
| top_indices = np.argsort(scores)[-top_k:][::-1] | |
| return [(idx, scores[idx]) for idx in top_indices] | |
| def reciprocal_rank_fusion(self, vector_results: List[Tuple[int, float]], | |
| bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]: | |
| """RRF融合算法""" | |
| scores = {} | |
| for rank, (idx, score) in enumerate(vector_results): | |
| scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) | |
| for rank, (idx, score) in enumerate(bm25_results): | |
| scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) | |
| return scores | |
| def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]: | |
| """使用重排序模型""" | |
| if not documents: return [], 0 | |
| url = f"{API_BASE}/rerank" | |
| headers = { | |
| "Authorization": f"Bearer {API_KEY}", | |
| "Content-Type": "application/json" | |
| } | |
| # 截断文档内容以符合 Context Window | |
| docs_content = [doc["content"][:2048] for doc in documents] | |
| payload = { | |
| "model": RERANK_MODEL, | |
| "query": query, | |
| "documents": docs_content, | |
| "top_n": top_n | |
| } | |
| try: | |
| start_time = time.time() | |
| # 设置 timeout 为 15 秒,防止长时间挂起导致 WebSocket 断开 | |
| response = requests.post(url, headers=headers, json=payload, timeout=15) | |
| elapsed = time.time() - start_time | |
| if response.status_code == 200: | |
| results = response.json().get("results", []) | |
| reranked_docs = [] | |
| for result in results: | |
| original_doc = documents[result["index"]] | |
| original_doc["rerank_score"] = result["relevance_score"] | |
| original_doc["final_score"] = result["relevance_score"] | |
| reranked_docs.append(original_doc) | |
| return reranked_docs, elapsed | |
| else: | |
| logger.warning(f"Rerank API Error: {response.text}") | |
| return documents[:top_n], elapsed | |
| except requests.exceptions.Timeout: | |
| logger.warning("⚠️ Rerank API timed out, falling back to original order.") | |
| return documents[:top_n], 0 | |
| except Exception as e: | |
| logger.error(f"❌ Rerank error: {e}") | |
| return documents[:top_n], 0 | |
| def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]: | |
| """执行策略矩阵""" | |
| start_time = time.time() | |
| result = { | |
| 'original_query': query, | |
| 'final_query': query, | |
| 'documents': [], | |
| 'steps': [], | |
| 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0}, | |
| 'strategy_tags': [] | |
| } | |
| # 1. 查询扩展 | |
| if config['use_qe']: | |
| expanded_q, qe_time = self.expand_query(query) | |
| result['final_query'] = expanded_q | |
| result['metrics']['qe_time'] = qe_time | |
| result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**") | |
| result['strategy_tags'].append("QE") | |
| # 2. 检索 | |
| retrieval_start = time.time() | |
| query_to_search = result['final_query'] | |
| if config['strategy'] == 'Vector': | |
| results = self.vector_search(query_to_search) | |
| result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选") | |
| result['strategy_tags'].append("Vector") | |
| elif config['strategy'] == 'BM25': | |
| results = self.bm25_search(query_to_search) | |
| result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选") | |
| result['strategy_tags'].append("BM25") | |
| elif config['strategy'] == 'Hybrid': | |
| vec_results = self.vector_search(query_to_search) | |
| bm25_results = self.bm25_search(query_to_search) | |
| fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results) | |
| results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) | |
| results = [(idx, score) for idx, score in results] | |
| result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选") | |
| result['strategy_tags'].extend(["Vector", "BM25"]) | |
| result['metrics']['retrieval_time'] = time.time() - retrieval_start | |
| # 3. 构建候选列表 | |
| recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k'] | |
| top_results = results[:recall_k] | |
| documents = [] | |
| for idx, score in top_results: | |
| documents.append({ | |
| 'content': self.documents[idx], | |
| 'filename': self.filenames[idx], | |
| 'retrieval_score': score, | |
| 'final_score': score, | |
| 'type': 'retrieval' | |
| }) | |
| # 4. 重排序 | |
| if config['use_rerank']: | |
| reranked_docs, rerank_time = self.rerank_documents( | |
| result['final_query'], documents, config['top_k'] | |
| ) | |
| result['documents'] = reranked_docs | |
| result['metrics']['rerank_time'] = rerank_time | |
| result['steps'].append(f"⚖️ 重排序 ({rerank_time:.2f}s): 精选 Top-{config['top_k']}") | |
| result['strategy_tags'].append("Rerank") | |
| else: | |
| result['documents'] = documents[:config['top_k']] | |
| result['metrics']['total_time'] = time.time() - start_time | |
| result['steps'].append(f"⏱️ 总耗时: {result['metrics']['total_time']:.2f}s") | |
| return result | |
def generate_suggestions(controller, query: str, answer: str) -> List[str]:
    """Ask the LLM for up to three follow-up questions.

    Args:
        controller: Object exposing an OpenAI-compatible ``client``.
        query: The user's question.
        answer: The generated answer (only the first 800 characters are sent).

    Returns:
        Up to three non-empty suggestion strings. An empty list is returned
        when the model output contains no JSON array or on any error.
    """
    prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。
用户问题:{query}
专家回答:{answer[:800]}...
要求:
1. 问题简短(15字以内)。
2. 紧扣当前话题。
3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。
4. 不要包含任何 Markdown 标记。
"""
    try:
        resp = controller.client.chat.completions.create(
            model=SUGGEST_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.5
        )
        content = (resp.choices[0].message.content or "").strip()
        # Take the outermost [...] span so surrounding chatter is ignored.
        match = re.search(r'\[.*\]', content, re.DOTALL)
        if not match:
            return []
        parsed = json.loads(match.group())
        # The suggestions become button labels, so keep only non-empty strings.
        suggestions = [
            item.strip()
            for item in parsed
            if isinstance(item, str) and item.strip()
        ]
        return suggestions[:3]
    except Exception as e:
        logger.error(f"Suggestion Error: {e}")
        return []
def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str:
    """Stream an answer grounded in the retrieved documents and render it.

    Args:
        controller: Object exposing an OpenAI-compatible ``client``.
        query: The current user question.
        documents: Retrieved documents; each needs a ``content`` entry.
        history: Chat messages so far. The last entry is dropped before
            sending because ``query`` is appended separately.
        max_rounds: Number of past conversation rounds (user + assistant
            pairs) to send; ``0`` sends no history.

    Returns:
        The full answer text, a fixed notice when ``documents`` is empty, or
        an error message if generation fails. Every returned text is also
        rendered in the UI, so the user sees it in the current run.
    """
    if not documents:
        fallback = "抱歉,没有找到相关的文档来回答您的问题。"
        st.markdown(fallback)
        return fallback

    context_text = "\n\n".join(
        f"[文档{i+1}] {str(doc['content'])[:800]}..." for i, doc in enumerate(documents)
    )
    system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。
要求:
1. 语气专业,使用COMSOL术语。
2. 物理公式使用 LaTeX(如 $E=mc^2$)。
3. 如果文档信息不足,请如实告知,不要编造。
【参考文档】:
{context_text}
"""
    # Build the history window: one round is a user message plus a reply.
    keep_messages = max_rounds * 2
    history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else []
    api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}]

    # Created up front so an error can replace any partially streamed text.
    placeholder = st.empty()
    answer = ""
    try:
        response = controller.client.chat.completions.create(
            model=GEN_MODEL_NAME,
            messages=api_messages,
            temperature=0.3,
            stream=True
        )
        for chunk in response:
            # Some stream chunks carry no choices; indexing them would fail.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece:
                answer += piece
                placeholder.markdown(answer + "▌")
        placeholder.markdown(answer)
        return answer
    except Exception as e:
        error_text = f"生成遇到错误: {e}"
        placeholder.markdown(error_text)
        return error_text
| # ================= 4. 初始化与组件渲染 ================= | |
def initialize_controller():
    """Return the shared ``RAGController``, building it at most once.

    Streamlit re-executes the whole script on every interaction. Without
    caching, each rerun would reload the data file and rebuild the BM25
    index. ``st.cache_resource`` keeps a single controller per server
    process and hands it back on later reruns.
    """
    # Decorated inside the function so the wrapper is created lazily.
    @st.cache_resource(show_spinner=False)
    def _build_controller():
        return RAGController()

    return _build_controller()
def render_strategy_matrix():
    """Render the strategy controls and return the chosen settings.

    Returns:
        Dict with keys ``use_qe``, ``strategy``, ``use_rerank``, ``top_k``
        and ``max_history_rounds``.
    """
    st.markdown('<p class="strategy-title">🎯 策略矩阵配置</p>', unsafe_allow_html=True)
    hint_html = """<div style="background-color: rgba(128, 128, 128, 0.05); padding: 10px; border-radius: 8px; margin-bottom: 20px; border: 1px solid rgba(128, 128, 128, 0.2);">
<p style="font-size: 0.85rem; margin: 0;">⚙️ <b>参数调节</b>:控制检索片段数量和模型记忆深度。</p>
</div>"""
    st.markdown(hint_html, unsafe_allow_html=True)

    left_col, right_col = st.columns(2)

    # Left column: pipeline switches and conversation memory depth.
    with left_col:
        qe_enabled = st.toggle("🔄 查询扩展 (QE)", value=False)
        rerank_enabled = st.toggle("⚖️ 深度重排序 (Rerank)", value=True)
        history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数")

    # Right column: retrieval strategy and number of retrieved chunks.
    with right_col:
        chosen_strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2)
        retrieve_count = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量")

    return dict(
        use_qe=qe_enabled,
        strategy=chosen_strategy,
        use_rerank=rerank_enabled,
        top_k=retrieve_count,
        max_history_rounds=history_rounds,
    )
def render_metrics(metrics):
    """Render the four timing metrics of a retrieval run side by side.

    Query-expansion and rerank timings show "N/A" with no delta badge when
    their value is not greater than zero.
    """
    st.markdown("### 📊 性能指标")
    qe_col, retrieval_col, rerank_col, total_col = st.columns(4)

    with qe_col:
        if metrics['qe_time'] > 0:
            st.metric("查询扩展", f"{metrics['qe_time']:.2f}s", delta="QE")
        else:
            st.metric("查询扩展", "N/A", delta=None)

    with retrieval_col:
        st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s")

    with rerank_col:
        if metrics['rerank_time'] > 0:
            st.metric("重排序", f"{metrics['rerank_time']:.2f}s", delta="Rerank")
        else:
            st.metric("重排序", "N/A", delta=None)

    with total_col:
        st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡")
def render_documents(documents, strategy_tags):
    """Render the strategy tags and one expander per retrieved document.

    Args:
        documents: Dicts with ``content`` and ``filename``; ``final_score``
            is optional and defaults to 0.
        strategy_tags: Names of the pipeline stages that were used.

    The first two documents are expanded by default.
    """
    st.markdown("### 📄 检索结果")
    if not documents:
        st.warning("未找到相关文档")
        return

    # CSS class per tag; any tag not listed is styled as a rerank tag.
    tag_classes = {"Vector": "tag-vec", "BM25": "tag-bm25", "QE": "tag-qe"}
    tag_spans = []
    for tag in strategy_tags:
        css_class = tag_classes.get(tag, "tag-rerank")
        tag_spans.append(f'<span class="strat-tag {css_class}">{tag}</span>')
    html_tags = "".join(tag_spans)
    st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True)

    for i, doc in enumerate(documents):
        score = doc.get('final_score', 0)
        filename = str(doc['filename'])
        # Add the ellipsis only when the name was actually shortened.
        label = f"{filename[:40]}..." if len(filename) > 40 else filename
        with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {label}", expanded=i < 2):
            st.code(doc['content'], language="markdown")
| # ================= 5. 主程序 ================= | |
def main():
    """Streamlit entry point: sidebar config, chat column and debug column."""
    state = st.session_state

    # Session-state defaults.
    if "messages" not in state:
        state.messages = []
    if "last_result" not in state:
        state.last_result = None
    if "suggestions" not in state:
        state.suggestions = random.sample(PRESET_QUESTIONS, 3)
    if "prompt_trigger" not in state:
        state.prompt_trigger = None

    # Load the controller.
    controller = initialize_controller()

    # Sidebar: strategy settings and the "clear conversation" button.
    with st.sidebar:
        config = render_strategy_matrix()
        st.markdown("---")
        clear_clicked = st.button("🗑️ 清空当前对话", use_container_width=True)
        if clear_clicked:
            state.messages = []
            state.last_result = None
            st.rerun()

    # Main layout: chat on the left, debug view on the right.
    main_col, debug_col = st.columns([0.6, 0.4], gap="large")

    with main_col:
        st.markdown("### 💬 智能仿真问答")

        # 1. Replay the conversation so far.
        for message in state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # 2. Suggested questions; a click is queued and handled on the next run.
        if state.suggestions:
            st.markdown("##### 💡 您可能想问:")
            suggestion_cols = st.columns(3)
            for position, suggestion in enumerate(state.suggestions):
                clicked = suggestion_cols[position].button(
                    suggestion, use_container_width=True, key=f"sug_{position}"
                )
                if clicked:
                    logger.info(f"🔘 Triggered by button: {suggestion}")
                    state.prompt_trigger = suggestion
                    st.rerun()

        # 3. Resolve the input: a queued suggestion wins over the chat box.
        queued_prompt = state.prompt_trigger
        if queued_prompt:
            user_input = queued_prompt
            state.prompt_trigger = None  # clear at once to avoid re-triggering
            logger.info(f"🔘 Triggered by button: {user_input}")
        else:
            user_input = st.chat_input("请输入您关于 COMSOL 的问题...")
            if user_input:
                logger.info(f"⌨️ Triggered by chat input: {user_input}")

        # 4. Retrieval, answer generation, follow-up suggestions.
        if user_input:
            state.messages.append({"role": "user", "content": user_input})
            with st.chat_message("user"):
                st.markdown(user_input)

            # Retrieval
            with st.spinner("🔍 检索知识库中..."):
                logger.info(f"🔎 Starting retrieval for: {user_input[:50]}...")
                result = controller.execute_strategy(user_input, config)
                state.last_result = result
                logger.info(f"✅ Retrieval done in {result['metrics']['total_time']:.2f}s")

            # Generation
            with st.chat_message("assistant"):
                logger.info("🤖 Generating answer...")
                answer = generate_answer(
                    controller,
                    user_input,
                    result['documents'],
                    state.messages,
                    config['max_history_rounds'],
                )
            state.messages.append({"role": "assistant", "content": answer})

            # New suggestions (no explicit rerun is issued here).
            logger.info("✨ Generating follow-up questions...")
            follow_ups = generate_suggestions(controller, user_input, answer)
            if follow_ups:
                state.suggestions = follow_ups
            else:
                state.suggestions = random.sample(PRESET_QUESTIONS, 3)
            logger.info(f"✅ Response completed in {result['metrics']['total_time']:.2f}s")

    with debug_col:
        st.markdown("### 🔍 系统调试视图")
        last_result = state.last_result
        if not last_result:
            st.info("等待交互...")
        else:
            st.info(f"当前查询: {last_result.get('original_query', 'N/A')}")
            render_metrics(last_result['metrics'])
            with st.expander("🔧 检索链路详情", expanded=True):
                for step in last_result['steps']:
                    st.markdown(f"- {step}")
            render_documents(last_result['documents'], last_result['strategy_tags'])
# Script entry point: run the app only when this file is executed directly.
if __name__ == "__main__":
    main()