| | import streamlit as st |
| | import pandas as pd |
| | import numpy as np |
| | import jieba |
| | import requests |
| | import os |
| | import time |
| | import json |
| | import re |
| | import random |
| | import subprocess |
| | import logging |
| | import psutil |
| | import traceback |
| | from openai import OpenAI, RateLimitError, APIStatusError |
| | from rank_bm25 import BM25Okapi |
| | from sklearn.metrics.pairwise import cosine_similarity |
| | from typing import List, Dict, Any |
| |
|
| | |
| |
|
# Configure root logging once at import time; all module loggers inherit this format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
| |
|
def log_memory():
    """Log the current process RSS (MB) to guard against HF Space OOM kills."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / (1024 * 1024)
    logger.info(f"💾 Memory Usage: {rss_mb:.2f} MB")
    return rss_mb
| |
|
| | |
| |
|
| | |
# SiliconFlow OpenAI-compatible endpoint; the key is read from the environment
# (HF Space secret `SILICONFLOW_API_KEY`).
API_BASE = "https://api.siliconflow.cn/v1"
API_KEY = os.getenv("SILICONFLOW_API_KEY")

# Model names: dense retrieval embedding and reranker ...
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"

# ... main answer-generation model (supports reasoning + tool calls) ...
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"

# ... and a lighter model for follow-up question suggestions.
SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"

# Pre-embedded knowledge base: local filename and remote download fallback.
DATA_FILENAME = "comsol_embedded.parquet"
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"

# Seed questions shown as suggestion buttons before any conversation starts.
PRESET_QUESTIONS = [
    "如何设置流固耦合接口?",
    "求解器不收敛通常怎么解决?",
    "低频电磁场网格划分有哪些技巧?",
    "如何定义随时间变化的边界条件?"
]
| |
|
# Must be the first Streamlit command executed on each script run.
st.set_page_config(
    page_title="COMSOL Agentic Expert",
    page_icon="🌌",
    layout="centered",
    initial_sidebar_state="expanded"
)
| |
|
| | |
# Global dark-theme CSS: chat bubbles, the agent "thinking" block, the
# tool-call log card, and the gradient header. Re-injected on every rerun.
st.markdown("""
<style>
    /* 全局背景与字体 */
    .stApp {
        background-color: #0E1117;
        color: #E0E0E0;
    }

    /* 聊天气泡 */
    [data-testid="stChatMessage"] {
        background-color: #1E1E1E;
        border: 1px solid #333;
        border-radius: 12px;
        padding: 1rem;
    }
    [data-testid="stChatMessage"][data-testid="user"] {
        background-color: #262730;
    }

    /* 思考过程 (Thinking) - 17特有 */
    .thinking-block {
        color: #8B949E;
        font-style: italic;
        font-size: 0.9rem;
        border-left: 3px solid #333;
        padding-left: 10px;
        margin: 5px 0 15px 0;
        background-color: rgba(255,255,255,0.02);
    }

    /* 工具调用日志 (Tool Logs) - 17特有 */
    .tool-block {
        font-family: 'Consolas', monospace;
        color: #29B5E8;
        font-size: 0.85rem;
        background-color: #0D1117;
        padding: 8px;
        border-radius: 6px;
        border: 1px solid #30363D;
        margin: 5px 0;
    }

    /* 标题栏 */
    .main-header {
        background: linear-gradient(90deg, #0f2027 0%, #203a43 50%, #2c5364 100%);
        padding: 1.5rem;
        border-radius: 15px;
        text-align: center;
        margin-bottom: 2rem;
        border: 1px solid #333;
        box-shadow: 0 4px 15px rgba(0,0,0,0.5);
    }
</style>
""", unsafe_allow_html=True)
| |
|
| | |
| |
|
def download_with_curl(url, output_path):
    """Download *url* to *output_path* via the system curl binary.

    Uses a browser User-Agent (some CDNs reject curl's default one) and
    ``--fail`` so HTTP error statuses produce a non-zero exit code instead
    of saving the error page.

    Returns True on success, False on any failure (non-zero exit, missing
    curl binary, or timeout).
    """
    try:
        cmd = [
            "curl", "-L",
            "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
            "-o", output_path,
            "--fail",
            url
        ]
        # Bounded timeout so a stalled transfer cannot hang the app forever.
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
        if result.returncode != 0:
            logger.warning(f"Curl stderr: {result.stderr}")
            return False
        return True
    except subprocess.TimeoutExpired:
        logger.error(f"Curl timed out downloading {url}")
        return False
    except Exception as e:
        logger.error(f"Curl error: {e}")
        return False
| |
|
def get_data_file_path():
    """Locate the parquet knowledge base on disk, downloading it if absent.

    Checks a few candidate paths first; otherwise downloads into the CWD,
    trying curl first and falling back to `requests`. Calls ``st.stop()``
    (does not return) if every strategy fails.
    """
    possible_paths = [
        DATA_FILENAME,
        os.path.join(os.getcwd(), DATA_FILENAME),
        "/app/" + DATA_FILENAME
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    download_target = os.path.join(os.getcwd(), DATA_FILENAME)
    status_container = st.empty()
    status_container.info("📡 正在接入神经元网络... (下载核心知识库)")

    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target

    # Fallback: stream with requests. raise_for_status() prevents silently
    # saving an HTML error page as the parquet file, and the timeout keeps a
    # dead server from hanging the UI; `with` ensures the connection closes.
    try:
        with requests.get(DATA_URL, stream=True, timeout=(10, 300)) as r:
            r.raise_for_status()
            with open(download_target, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        status_container.empty()
        return download_target
    except Exception as e:
        st.error(f"❌ 数据下载失败: {e}")
        st.stop()
| |
|
| | |
| |
|
class RAGController:
    """Hybrid retrieval engine over the pre-embedded COMSOL knowledge base.

    Recall combines dense cosine similarity with BM25, fuses both rankings
    with Reciprocal Rank Fusion (RRF), then reranks the fused candidates via
    the remote rerank endpoint (falling back to the fused order on failure).
    """

    def __init__(self):
        # Fail fast in the UI if the API key secret is missing.
        if not API_KEY:
            st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
            st.stop()
        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
        self._load_data()

    def _load_data(self):
        """Load docs + embeddings from the parquet file and build the BM25 index."""
        real_path = get_data_file_path()
        try:
            logger.info("Initializing Engine...")
            self.df = pd.read_parquet(real_path)
            self.documents = self.df['content'].tolist()
            self.filenames = self.df['filename'].tolist()
            self.embeddings = np.stack(self.df['embedding'].values)

            # jieba tokenization for Chinese-aware BM25.
            tokenized = [jieba.lcut(str(d).lower()) for d in self.documents]
            self.bm25 = BM25Okapi(tokenized)
            logger.info(f"✅ Loaded {len(self.documents)} docs.")
            log_memory()
        except Exception as e:
            st.error(f"Engine Load Failed: {e}")
            st.stop()

    def execute_retrieval(self, query: str, limit: int = 5) -> Dict[str, Any]:
        """Retrieval tool exposed to the agent.

        Returns a dict with keys "query", "docs" (list of
        {filename, content, score}) and "meta"; on embedding failure the
        dict carries an "error" key and an empty doc list instead.
        """
        start_t = time.time()

        try:
            q_vec = self.client.embeddings.create(model=EMBEDDING_MODEL, input=[query]).data[0].embedding
        except Exception as e:
            # Narrowed from a bare `except:`; log so API outages are visible.
            logger.error(f"Embedding API error: {e}")
            return {"query": query, "docs": [], "error": "Embedding API Failed"}

        vec_sim = cosine_similarity([q_vec], self.embeddings)[0]
        bm25_scores = self.bm25.get_scores(jieba.lcut(query.lower()))

        # Reciprocal Rank Fusion over both channels: rank 0 = best doc.
        # The previous implementation used list(argsort).index(idx), which
        # (a) inverted the ranking — worse candidates got HIGHER fusion
        # scores — and (b) was O(n^2). Sorting descending and enumerating
        # fixes both.
        scores: Dict[int, float] = {}
        for channel_scores in (vec_sim, bm25_scores):
            order = np.argsort(channel_scores)[::-1][:100]  # best first
            for rank, idx in enumerate(order):
                idx = int(idx)
                scores[idx] = scores.get(idx, 0.0) + 1.0 / (60 + rank)

        top_idxs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:50]
        if not top_idxs:
            return {"query": query, "docs": []}

        recall_docs = [self.documents[i] for i, _ in top_idxs]

        def _fused_fallback():
            # Fused order without rerank scores when the rerank API is down.
            return [{"filename": self.filenames[i], "content": self.documents[i], "score": 0}
                    for i, _ in top_idxs[:limit]]

        try:
            payload = {"model": RERANK_MODEL, "query": query, "documents": recall_docs, "top_n": limit}
            headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
            resp = requests.post(f"{API_BASE}/rerank", json=payload, headers=headers, timeout=10)

            if resp.status_code == 200:
                final_docs = []
                for item in resp.json().get('results', []):
                    # item['index'] points into recall_docs; map it back to
                    # the corpus index.
                    orig_idx = top_idxs[item['index']][0]
                    final_docs.append({
                        "filename": self.filenames[orig_idx],
                        "content": self.documents[orig_idx],
                        "score": item['relevance_score']
                    })
            else:
                final_docs = _fused_fallback()
        except Exception as e:
            logger.warning(f"Rerank failed, using fused order: {e}")
            final_docs = _fused_fallback()

        return {
            "query": query,
            "docs": final_docs,
            "meta": {"time": time.time() - start_t}
        }
| |
|
@st.cache_resource
def get_engine():
    """Build the RAGController once per server process (Streamlit resource cache)."""
    return RAGController()
| |
|
| | |
| |
|
# OpenAI function-calling schema advertising the single retrieval tool to
# the generation model; dispatched in main() by function name.
tools_schema = [{
    "type": "function",
    "function": {
        "name": "search_knowledge_base",
        "description": "Search for COMSOL technical documentation. Use this whenever user asks technical questions.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Specific technical keywords"},
                "limit": {"type": "integer", "default": 4}
            },
            "required": ["query"]
        }
    }
}]
| |
|
def stream_and_parse_flat(client, messages, container):
    """Stream one chat completion, live-rendering thinking, tool calls and answer.

    Returns a tuple ``(reasoning, content, tool_calls_buffer)`` where
    ``tool_calls_buffer`` maps tool-call stream index -> {id, name, args};
    returns ``(None, None, None)`` when the request or stream fails.
    """
    full_reasoning = ""
    full_content = ""
    tool_calls_buffer = {}

    # Placeholders are created lazily so empty sections render nothing.
    reasoning_ph = None
    content_ph = None

    try:
        # Retry transient API failures with linear backoff (2s, 4s, 6s).
        stream = None
        for attempt in range(3):
            try:
                stream = client.chat.completions.create(
                    model=GEN_MODEL_NAME,
                    messages=messages,
                    tools=tools_schema,
                    tool_choice="auto",
                    temperature=0.1,
                    stream=True
                )
                break
            except (RateLimitError, APIStatusError):
                time.sleep(2 * (attempt + 1))

        if not stream:
            logger.error("Chat completion failed after 3 attempts.")
            return None, None, None

        for chunk in stream:
            # Some providers emit keep-alive/usage chunks with an empty
            # choices list; indexing those would raise IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            # Reasoning stream (provider-specific field; absent on standard
            # deltas, hence getattr).
            r_content = getattr(delta, 'reasoning_content', None)
            if r_content:
                if reasoning_ph is None:
                    reasoning_ph = container.empty()
                full_reasoning += r_content
                # Trailing ▌ acts as a streaming cursor.
                reasoning_ph.markdown(f"<div class='thinking-block'>Thinking: {full_reasoning} ▌</div>", unsafe_allow_html=True)

            # Tool-call fragments arrive incrementally; accumulate by index.
            if delta.tool_calls:
                for tc in delta.tool_calls:
                    idx = tc.index
                    if idx not in tool_calls_buffer:
                        tool_calls_buffer[idx] = {"name": "", "args": "", "id": ""}
                    if tc.id: tool_calls_buffer[idx]["id"] = tc.id
                    if tc.function.name: tool_calls_buffer[idx]["name"] = tc.function.name
                    if tc.function.arguments: tool_calls_buffer[idx]["args"] += tc.function.arguments

            # Final-answer tokens: freeze the thinking block once, then
            # stream the answer text.
            if delta.content:
                if reasoning_ph:
                    reasoning_ph.markdown(f"<div class='thinking-block'>Thinking: {full_reasoning}</div>", unsafe_allow_html=True)
                    reasoning_ph = None
                if content_ph is None:
                    content_ph = container.empty()
                full_content += delta.content
                content_ph.markdown(full_content + "▌")

        # Stream ended: remove the streaming cursors.
        if reasoning_ph:
            reasoning_ph.markdown(f"<div class='thinking-block'>Thinking: {full_reasoning}</div>", unsafe_allow_html=True)
        if content_ph:
            content_ph.markdown(full_content)

        return full_reasoning, full_content, tool_calls_buffer

    except Exception:
        logger.error(f"Stream Error: {traceback.format_exc()}")
        return None, None, None
| |
|
def generate_suggestions(client, query, answer):
    """Generate up to 3 follow-up COMSOL questions via the lightweight model.

    Best-effort: returns [] on any API or parse failure so the caller keeps
    its previous suggestions.
    """
    try:
        prompt = f"""Generate 3 follow-up COMSOL questions based on: Q: {query} A: {answer[:500]}... JSON Array only."""
        resp = client.chat.completions.create(model=SUGGEST_MODEL_NAME, messages=[{"role": "user", "content": prompt}])
        # Extract the first JSON-array-looking span from the reply; the model
        # may wrap it in prose or code fences.
        match = re.search(r'\[.*\]', resp.choices[0].message.content, re.DOTALL)
        if not match:
            return []
        suggestions = json.loads(match.group())
        # Guard against the model returning a non-list JSON value.
        return suggestions[:3] if isinstance(suggestions, list) else []
    except Exception as e:
        # Narrowed from a bare `except:`; log so repeated failures are visible.
        logger.warning(f"Suggestion generation failed: {e}")
        return []
| |
|
| | |
| |
|
def render_history():
    """Re-render the full chat history on each Streamlit rerun.

    The history uses custom roles beyond the OpenAI ones: "thought"
    (agent reasoning shown in a collapsed expander) and "tool_log"
    (a summary card for one retrieval call, whose content is a dict
    {"query": str, "docs": list}).
    """
    for msg in st.session_state.messages:
        role = msg["role"]
        content = msg["content"]

        if role == "user":
            with st.chat_message("user"): st.markdown(content)

        elif role == "assistant":
            with st.chat_message("assistant"): st.markdown(content)

        elif role == "thought":
            # Collapsed by default so long reasoning doesn't dominate the page.
            with st.chat_message("assistant"):
                with st.expander("💭 Agent Thought Process"):
                    st.markdown(f"<div class='thinking-block'>{content}</div>", unsafe_allow_html=True)

        elif role == "tool_log":
            # Renders the stored retrieval summary; score of the top doc (or
            # 0 when the search returned nothing).
            with st.chat_message("assistant"):
                st.markdown(f"""
                <div class='tool-block'>
                🔧 <b>Tool:</b> {content['query']} <br>
                📚 <b>Result:</b> Found {len(content['docs'])} docs (Score: {content['docs'][0]['score'] if content['docs'] else 0:.2f})
                </div>
                """, unsafe_allow_html=True)
| |
|
def main():
    """Page entry point: header, sidebar, history, input, and the agent loop."""
    # Gradient header banner.
    st.markdown("""
    <div class="main-header">
        <h1 style="color:white; margin:0;">🌌 COMSOL Agentic Expert</h1>
        <p style="color:#aaa; margin:5px;">Autonomous RAG Agent · Self-Correcting · Deep Reasoning</p>
    </div>
    """, unsafe_allow_html=True)

    # Session state: chat history and the current suggestion buttons.
    if "messages" not in st.session_state: st.session_state.messages = []
    if "suggestions" not in st.session_state: st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3)

    engine = get_engine()

    # Sidebar controls.
    with st.sidebar:
        st.write("### ⚙️ 控制台")
        if st.button("🗑️ 清空对话记忆", use_container_width=True):
            st.session_state.messages = []
            st.rerun()
        st.info(f"📚 知识库状态: {len(engine.documents)} 篇文档已加载")

    render_history()

    # Suggestion buttons — only shown when no user turn is pending.
    if not st.session_state.messages or st.session_state.messages[-1]["role"] != "user":
        st.write("💡 **您可能想问:**")
        cols = st.columns(len(st.session_state.suggestions))
        for i, sug in enumerate(st.session_state.suggestions):
            if cols[i].button(sug, key=f"sug_{i}"):
                st.session_state.messages.append({"role": "user", "content": sug})
                st.rerun()

    user_input = st.chat_input("输入 COMSOL 仿真问题 (例如: 流固耦合如何设置?)...")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.rerun()

    # Agent turn: the last history entry is an unanswered user message.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        # Capture the query NOW. The loop below appends "thought"/"tool_log"
        # entries, so by the time the answer arrives messages[-2] no longer
        # points at the user turn (the old code passed a thought — or even a
        # tool_log dict — to generate_suggestions).
        user_query = st.session_state.messages[-1]["content"]

        # Only real user/assistant turns are replayed to the LLM.
        llm_history = []
        for m in st.session_state.messages:
            if m["role"] in ["user", "assistant"]:
                llm_history.append({"role": m["role"], "content": m["content"]})

        sys_prompt = """You are a COMSOL Expert Agent.
1. FIRST, THINK. Use `reasoning_content` (or output inner thoughts) to analyze the user's request.
2. IF you need technical details, you MUST call `search_knowledge_base`.
3. CHECK the search results. If irrelevant, try searching again with different keywords.
4. ANSWER professionally using the retrieved context. Use LaTeX for math.
"""

        # Keep only the last 6 turns to bound the prompt size.
        current_messages = [{"role": "system", "content": sys_prompt}] + llm_history[-6:]

        with st.chat_message("assistant"):
            container = st.container()

            loop_count = 0
            MAX_LOOPS = 5  # hard cap on tool-call rounds to prevent infinite loops

            while loop_count < MAX_LOOPS:
                loop_count += 1

                reasoning, content, tool_calls = stream_and_parse_flat(engine.client, current_messages, container)

                # Total stream failure: surface it instead of storing None.
                if reasoning is None and content is None and tool_calls is None:
                    container.error("❌ 模型响应失败,请稍后重试。")
                    break

                if reasoning:
                    st.session_state.messages.append({"role": "thought", "content": reasoning})

                # No tool calls -> this is the final answer for this turn.
                if not tool_calls:
                    st.session_state.messages.append({"role": "assistant", "content": content or ""})

                    # Refresh the follow-up suggestions (best-effort).
                    new_sugs = generate_suggestions(engine.client, user_query, content or "")
                    if new_sugs: st.session_state.suggestions = new_sugs
                    st.rerun()
                    break

                # Record the assistant tool-call turn. "" rather than None
                # keeps the message valid for the completions API.
                current_messages.append({
                    "role": "assistant",
                    "content": content or "",
                    "tool_calls": [{"id": v["id"], "type": "function", "function": {"name": v["name"], "arguments": v["args"]}} for v in tool_calls.values()]
                })

                for idx, tc_data in tool_calls.items():
                    try:
                        func_name = tc_data["name"]
                        args = json.loads(tc_data["args"])

                        if func_name == "search_knowledge_base":
                            q = args.get("query")
                            ret = engine.execute_retrieval(q)

                            # Live tool-call card in the UI.
                            top_score = ret['docs'][0]['score'] if ret['docs'] else 0
                            container.markdown(f"""
                            <div class='tool-block'>
                            🔧 <b>Searching:</b> "{q}" <br>
                            📄 <b>Found:</b> {len(ret['docs'])} refs (Top Score: {top_score:.3f})
                            </div>
                            """, unsafe_allow_html=True)

                            # Persist the tool call so render_history can replay it.
                            st.session_state.messages.append({"role": "tool_log", "content": {"query": q, "docs": ret['docs']}})

                            doc_context = "\n".join([f"[Doc {i}] (Score {d['score']:.2f}): {d['content']}" for i, d in enumerate(ret['docs'])])
                            current_messages.append({
                                "role": "tool",
                                "tool_call_id": tc_data["id"],
                                "content": f"Search Results:\n{doc_context}"
                            })
                        else:
                            # Every tool_call id must receive a tool reply,
                            # otherwise the next completion request is rejected.
                            current_messages.append({
                                "role": "tool",
                                "tool_call_id": tc_data["id"],
                                "content": f"Error: unknown tool '{func_name}'"
                            })

                    except Exception as e:
                        logger.error(f"Tool Execution Error: {e}")
                        current_messages.append({
                            "role": "tool",
                            "tool_call_id": tc_data["id"],
                            "content": f"Error executing tool: {str(e)}"
                        })
| |
|
# Script entry point (Streamlit re-executes the whole file on every rerun).
if __name__ == "__main__":
    main()