leezhuuu commited on
Commit
597276e
·
verified ·
1 Parent(s): 96aed6c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +866 -258
src/streamlit_app.py CHANGED
@@ -1,125 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  import jieba
5
  import requests
6
  import os
7
- import sys
 
 
 
8
  import subprocess
9
  from openai import OpenAI
10
  from rank_bm25 import BM25Okapi
11
  from sklearn.metrics.pairwise import cosine_similarity
 
12
 
13
- # ================= 1. 全局配置与 CSS注入 =================
14
 
15
- API_KEY = os.getenv("SILICONFLOW_API_KEY")
16
  API_BASE = "https://api.siliconflow.cn/v1"
 
 
 
17
  EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
18
  RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
19
- GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  DATA_FILENAME = "comsol_embedded.parquet"
21
  DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"
22
 
 
23
  st.set_page_config(
24
- page_title="COMSOL Dark Expert",
25
- page_icon="🌌",
26
  layout="wide",
27
  initial_sidebar_state="expanded"
28
  )
29
 
30
- # --- 注入自定义 CSS (保持之前的审美) ---
31
  st.markdown("""
32
  <style>
33
- /* 1. 整体背景 - 空黑 */
34
  .stApp {
35
- background-color: #050505;
36
- background-image: radial-gradient(circle at 50% 0%, #1a1f35 0%, #050505 60%);
37
- color: #e0e0e0;
38
- font-family: 'Inter', system-ui, -apple-system, sans-serif;
39
  }
40
 
41
- /* 2. 隐藏默认组件 */
42
- #MainMenu {visibility: hidden;}
43
- footer {visibility: hidden;}
44
- header {visibility: hidden;}
45
-
46
- /* 3. 聊天气泡 */
47
  [data-testid="stChatMessage"] {
48
- background: rgba(255, 255, 255, 0.03);
49
- border: 1px solid rgba(255, 255, 255, 0.08);
50
- border-radius: 16px;
51
- backdrop-filter: blur(12px);
52
- box-shadow: 0 4px 20px rgba(0,0,0,0.2);
53
- padding: 1.2rem;
54
- }
55
-
56
- /* 用户气泡 */
57
- [data-testid="stChatMessage"][data-testid="user"] {
58
- background: rgba(41, 181, 232, 0.1);
59
- border-color: rgba(41, 181, 232, 0.2);
60
  }
61
 
62
- /* 4. 自定义标题栏 */
63
- .custom-header {
64
- border-bottom: 1px solid rgba(255,255,255,0.1);
65
- padding-bottom: 1rem;
66
- margin-bottom: 2rem;
67
- display: flex;
68
- align-items: center;
69
- gap: 1rem;
70
- }
71
- .glitch-text {
72
- font-size: 2rem;
73
- font-weight: 800;
74
- background: linear-gradient(120deg, #fff, #29B5E8);
75
- -webkit-background-clip: text;
76
- -webkit-text-fill-color: transparent;
77
- letter-spacing: -1px;
78
- }
79
-
80
- /* 5. 快捷按钮 */
81
- div.stButton > button {
82
- background: rgba(255,255,255,0.05);
83
- color: #aaa;
84
- border: 1px solid rgba(255,255,255,0.1);
85
- border-radius: 20px;
86
- padding: 0.5rem 1rem;
87
- font-size: 0.85rem;
88
- transition: all 0.3s;
89
- width: 100%;
90
- }
91
- div.stButton > button:hover {
92
- background: rgba(41, 181, 232, 0.2);
93
- color: #fff;
94
- border-color: #29B5E8;
95
- transform: translateY(-2px);
96
  }
97
 
98
- /* 6. 输入框 */
99
- .stChatInputContainer textarea {
100
- background-color: #0f1115 !important;
101
- border: 1px solid #333 !important;
102
- color: white !important;
103
- border-radius: 12px !important;
 
 
 
 
104
  }
105
-
106
- /* 7. Expander */
107
- .streamlit-expanderHeader {
108
- background-color: rgba(255,255,255,0.02);
109
- border: 1px solid rgba(255,255,255,0.05);
 
 
 
 
 
110
  border-radius: 8px;
111
- color: #bbb;
 
 
 
 
 
 
 
 
 
 
 
 
112
  }
113
  </style>
114
  """, unsafe_allow_html=True)
115
 
116
- # ================= 2. 核心逻辑(数据与RAG) =================
117
-
118
- if not API_KEY:
119
- st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
120
- st.stop()
121
 
122
  def download_with_curl(url, output_path):
 
123
  try:
124
  cmd = [
125
  "curl", "-L",
@@ -129,234 +493,478 @@ def download_with_curl(url, output_path):
129
  url
130
  ]
131
  result = subprocess.run(cmd, capture_output=True, text=True)
132
- if result.returncode != 0: raise Exception(f"Curl failed: {result.stderr}")
 
 
133
  return True
134
  except Exception as e:
135
  print(f"Curl download error: {e}")
136
  return False
137
 
138
  def get_data_file_path():
 
 
139
  possible_paths = [
140
- DATA_FILENAME, os.path.join("/app", DATA_FILENAME),
 
141
  os.path.join("processed_data", DATA_FILENAME),
142
- os.path.join("src", DATA_FILENAME),
143
- os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME
144
  ]
 
145
  for path in possible_paths:
146
- if os.path.exists(path): return path
 
147
 
148
- download_target = "/app/" + DATA_FILENAME
149
- try: os.makedirs(os.path.dirname(download_target), exist_ok=True)
150
- except: download_target = "/tmp/" + DATA_FILENAME
151
-
152
  status_container = st.empty()
153
- status_container.info("📡 正在接入神经元网络... (下载核心数据中)")
154
 
 
155
  if download_with_curl(DATA_URL, download_target):
156
  status_container.empty()
157
  return download_target
158
 
 
159
  try:
160
  headers = {'User-Agent': 'Mozilla/5.0'}
161
  r = requests.get(DATA_URL, headers=headers, stream=True)
162
  r.raise_for_status()
163
  with open(download_target, 'wb') as f:
164
- for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
 
165
  status_container.empty()
166
  return download_target
167
  except Exception as e:
168
- st.error(f"❌ 数据链路中断。Error: {e}")
169
  st.stop()
170
 
171
- class FullRetriever:
172
- def __init__(self, parquet_path):
173
- try: self.df = pd.read_parquet(parquet_path)
174
- except Exception as e: st.error(f"Memory Matrix Load Failed: {e}"); st.stop()
175
- self.documents = self.df['content'].tolist()
176
- self.embeddings = np.stack(self.df['embedding'].values)
177
- self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
 
 
 
 
178
  self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
179
- # Reranker 初始化移到这里,减少重复调用
180
- self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
181
- self.rerank_url = f"{API_BASE}/rerank"
182
-
183
- def _get_emb(self, q):
184
- try: return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding
185
- except: return [0.0] * 1024
186
-
187
- def hybrid_search(self, query: str, top_k=5):
188
- # 1. Vector
189
- q_emb = self._get_emb(query)
190
- vec_scores = cosine_similarity([q_emb], self.embeddings)[0]
191
- vec_idx = np.argsort(vec_scores)[-100:][::-1]
192
- # 2. Keyword
193
- kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
194
- # 3. RRF Fusion
195
- fused = {}
196
- for r, i in enumerate(vec_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1)
197
- for r, i in enumerate(kw_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1)
198
- c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x:x[1], reverse=True)[:50]]
199
- c_docs = [self.documents[i] for i in c_idxs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
- # 4. Rerank
 
 
202
  try:
203
- payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k}
204
- resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10)
205
- results = resp.json().get('results', [])
206
- except:
207
- results = [{"index": i, "relevance_score": 0.0} for i in range(len(c_docs))][:top_k]
208
-
209
- final_res = []
210
- context = ""
211
- for i, item in enumerate(results):
212
- orig_idx = c_idxs[item['index']]
213
- row = self.df.iloc[orig_idx]
214
- final_res.append({
215
- "score": item['relevance_score'],
216
- "filename": row['filename'],
217
- "content": row['content']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  })
219
- context += f"[文档{i+1}]: {row['content']}\n\n"
220
- return final_res, context
221
 
222
- @st.cache_resource
223
- def load_engine():
224
- real_path = get_data_file_path()
225
- return FullRetriever(real_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- # ================= 3. UI 主程序 =================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  def main():
230
- st.markdown("""
231
- <div class="custom-header">
232
- <div style="font-size: 3rem;">🌌</div>
233
- <div>
234
- <div class="glitch-text">COMSOL DARK EXPERT</div>
235
- <div style="color: #666; font-size: 0.9rem; letter-spacing: 1px;">
236
- NEURAL SIMULATION ASSISTANT <span style="color:#29B5E8">V4.1 Fixed</span>
237
- </div>
238
- </div>
239
- </div>
240
- """, unsafe_allow_html=True)
241
-
242
- retriever = load_engine()
243
 
 
 
 
 
244
  with st.sidebar:
245
- st.markdown("### ⚙️ 控制台")
246
- top_k = st.slider("检索深度", 1, 10, 4)
247
- temp = st.slider("发散度", 0.0, 1.0, 0.3)
248
  st.markdown("---")
249
- if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True):
250
  st.session_state.messages = []
251
- st.session_state.current_refs = []
252
  st.rerun()
253
 
254
- if "messages" not in st.session_state: st.session_state.messages = []
255
- if "current_refs" not in st.session_state: st.session_state.current_refs = []
256
-
257
- col_chat, col_evidence = st.columns([0.65, 0.35], gap="large")
258
-
259
- # ------------------ 处理输入源 ------------------
260
- # 我们定义一个变量 user_input,不管它来自按钮还是输入框
261
- user_input = None
262
-
263
- with col_chat:
264
- # 1. 如果历史为空,显示快捷按钮
265
- if not st.session_state.messages:
266
- st.markdown("##### 💡 初始化提问序列 (Starter Sequence)")
267
- c1, c2, c3 = st.columns(3)
268
- # 点击按钮直接赋值给 user_input
269
- if c1.button("🌊 流固耦合接口设置"):
270
- user_input = "怎么设置流固耦合接口?"
271
- elif c2.button("⚡ 低频电磁场网格"):
272
- user_input = "低频电磁场网格划分有哪些技巧?"
273
- elif c3.button("📉 求解器不收敛"):
274
- user_input = "求解器不收敛通常怎么解决?"
275
 
276
- # 2. 渲染历史消息
277
  for msg in st.session_state.messages:
278
  with st.chat_message(msg["role"]):
279
  st.markdown(msg["content"])
280
 
281
- # 3. 处理底部输入框 (如果有按钮输入,这里会被跳过,因为 user_input 已经有值了)
282
- if not user_input:
283
- user_input = st.chat_input("输入指令或物理参数题...")
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
- # ------------------ 统一处理消息追加 ------------------
286
- if user_input:
287
- st.session_state.messages.append({"role": "user", "content": user_input})
288
- # 强制刷新以立即在 UI 上显示用户的提问(对于按钮点击尤为重要)
289
- st.rerun()
290
 
291
- # ------------------ 统一触发生成 (修复的核心) ------------------
292
- # 检查:如果有消息,且最后一条是 User 发的,说明需要 Assistant 回答
293
- if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
294
-
295
- # 获取最后一条用户消息
296
- last_query = st.session_state.messages[-1]["content"]
297
-
298
- with col_chat: # 确保在聊天栏显示
299
- with st.spinner("🔍 正在扫描向量空间..."):
300
- refs, context = retriever.hybrid_search(last_query, top_k=top_k)
301
- st.session_state.current_refs = refs
302
-
303
- system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。
304
- 要求:
305
- 1. 语气专业、客观,逻辑严密。
306
- 2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。
307
- 3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。
308
-
309
- 参考文档:
310
- {context}
311
- """
312
-
313
  with st.chat_message("assistant"):
314
- resp_cont = st.empty()
315
- full_resp = ""
316
- client = OpenAI(base_url=API_BASE, api_key=API_KEY)
 
 
317
 
318
- try:
319
- stream = client.chat.completions.create(
320
- model=GEN_MODEL_NAME,
321
- messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:], # 除去当前的System
322
- temperature=temp,
323
- stream=True
324
- )
325
- for chunk in stream:
326
- txt = chunk.choices[0].delta.content
327
- if txt:
328
- full_resp += txt
329
- resp_cont.markdown(full_resp + " ")
330
- resp_cont.markdown(full_resp)
331
- st.session_state.messages.append({"role": "assistant", "content": full_resp})
332
- except Exception as e:
333
- st.error(f"Neural Generation Failed: {e}")
334
-
335
- # ------------------ 渲染右侧证据栏 ------------------
336
- with col_evidence:
337
- st.markdown("### 📚 神经记忆 (Evidence)")
338
- if st.session_state.current_refs:
339
- for i, ref in enumerate(st.session_state.current_refs):
340
- score = ref['score']
341
- score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c"
342
-
343
- with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)):
344
- st.markdown(f"""
345
- <div style="margin-bottom:5px;">
346
- <span style="color:#888;">Relevance:</span>
347
- <span style="color:{score_color}; font-weight:bold;">{score:.4f}</span>
348
- </div>
349
- """, unsafe_allow_html=True)
350
- st.code(ref['content'], language="text")
351
  else:
352
- st.info("等待输入指令以检索知识库...")
353
- st.markdown("""
354
- <div style="opacity:0.3; font-size:0.8rem; margin-top:20px;">
355
- Waiting for query signal...<br>
356
- Index Status: Ready<br>
357
- Awaiting Input
358
- </div>
359
- """, unsafe_allow_html=True)
360
 
361
  if __name__ == "__main__":
362
  main()
 
1
+ # import streamlit as st
2
+ # import pandas as pd
3
+ # import numpy as np
4
+ # import jieba
5
+ # import requests
6
+ # import os
7
+ # import sys
8
+ # import subprocess
9
+ # from openai import OpenAI
10
+ # from rank_bm25 import BM25Okapi
11
+ # from sklearn.metrics.pairwise import cosine_similarity
12
+
13
+ # # ================= 1. 全局配置与 CSS注入 =================
14
+
15
+ # API_KEY = os.getenv("SILICONFLOW_API_KEY")
16
+ # API_BASE = "https://api.siliconflow.cn/v1"
17
+ # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
18
+ # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
19
+ # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
20
+ # DATA_FILENAME = "comsol_embedded.parquet"
21
+ # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"
22
+
23
+ # st.set_page_config(
24
+ # page_title="COMSOL Dark Expert",
25
+ # page_icon="🌌",
26
+ # layout="wide",
27
+ # initial_sidebar_state="expanded"
28
+ # )
29
+
30
+ # # --- 注入自定义 CSS (保持之前的审美) ---
31
+ # st.markdown("""
32
+ # <style>
33
+ # /* 1. 整体背景 - 深空黑 */
34
+ # .stApp {
35
+ # background-color: #050505;
36
+ # background-image: radial-gradient(circle at 50% 0%, #1a1f35 0%, #050505 60%);
37
+ # color: #e0e0e0;
38
+ # font-family: 'Inter', system-ui, -apple-system, sans-serif;
39
+ # }
40
+
41
+ # /* 2. 隐藏默认组件 */
42
+ # #MainMenu {visibility: hidden;}
43
+ # footer {visibility: hidden;}
44
+ # header {visibility: hidden;}
45
+
46
+ # /* 3. 聊天气泡 */
47
+ # [data-testid="stChatMessage"] {
48
+ # background: rgba(255, 255, 255, 0.03);
49
+ # border: 1px solid rgba(255, 255, 255, 0.08);
50
+ # border-radius: 16px;
51
+ # backdrop-filter: blur(12px);
52
+ # box-shadow: 0 4px 20px rgba(0,0,0,0.2);
53
+ # padding: 1.2rem;
54
+ # }
55
+
56
+ # /* 用户气泡 */
57
+ # [data-testid="stChatMessage"][data-testid="user"] {
58
+ # background: rgba(41, 181, 232, 0.1);
59
+ # border-color: rgba(41, 181, 232, 0.2);
60
+ # }
61
+
62
+ # /* 4. 自定义标题栏 */
63
+ # .custom-header {
64
+ # border-bottom: 1px solid rgba(255,255,255,0.1);
65
+ # padding-bottom: 1rem;
66
+ # margin-bottom: 2rem;
67
+ # display: flex;
68
+ # align-items: center;
69
+ # gap: 1rem;
70
+ # }
71
+ # .glitch-text {
72
+ # font-size: 2rem;
73
+ # font-weight: 800;
74
+ # background: linear-gradient(120deg, #fff, #29B5E8);
75
+ # -webkit-background-clip: text;
76
+ # -webkit-text-fill-color: transparent;
77
+ # letter-spacing: -1px;
78
+ # }
79
+
80
+ # /* 5. 快捷按钮 */
81
+ # div.stButton > button {
82
+ # background: rgba(255,255,255,0.05);
83
+ # color: #aaa;
84
+ # border: 1px solid rgba(255,255,255,0.1);
85
+ # border-radius: 20px;
86
+ # padding: 0.5rem 1rem;
87
+ # font-size: 0.85rem;
88
+ # transition: all 0.3s;
89
+ # width: 100%;
90
+ # }
91
+ # div.stButton > button:hover {
92
+ # background: rgba(41, 181, 232, 0.2);
93
+ # color: #fff;
94
+ # border-color: #29B5E8;
95
+ # transform: translateY(-2px);
96
+ # }
97
+
98
+ # /* 6. 输入框 */
99
+ # .stChatInputContainer textarea {
100
+ # background-color: #0f1115 !important;
101
+ # border: 1px solid #333 !important;
102
+ # color: white !important;
103
+ # border-radius: 12px !important;
104
+ # }
105
+
106
+ # /* 7. Expander */
107
+ # .streamlit-expanderHeader {
108
+ # background-color: rgba(255,255,255,0.02);
109
+ # border: 1px solid rgba(255,255,255,0.05);
110
+ # border-radius: 8px;
111
+ # color: #bbb;
112
+ # }
113
+ # </style>
114
+ # """, unsafe_allow_html=True)
115
+
116
+ # # ================= 2. 核心逻辑(数据与RAG) =================
117
+
118
+ # if not API_KEY:
119
+ # st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
120
+ # st.stop()
121
+
122
+ # def download_with_curl(url, output_path):
123
+ # try:
124
+ # cmd = [
125
+ # "curl", "-L",
126
+ # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
127
+ # "-o", output_path,
128
+ # "--fail",
129
+ # url
130
+ # ]
131
+ # result = subprocess.run(cmd, capture_output=True, text=True)
132
+ # if result.returncode != 0: raise Exception(f"Curl failed: {result.stderr}")
133
+ # return True
134
+ # except Exception as e:
135
+ # print(f"Curl download error: {e}")
136
+ # return False
137
+
138
+ # def get_data_file_path():
139
+ # possible_paths = [
140
+ # DATA_FILENAME, os.path.join("/app", DATA_FILENAME),
141
+ # os.path.join("processed_data", DATA_FILENAME),
142
+ # os.path.join("src", DATA_FILENAME),
143
+ # os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME
144
+ # ]
145
+ # for path in possible_paths:
146
+ # if os.path.exists(path): return path
147
+
148
+ # download_target = "/app/" + DATA_FILENAME
149
+ # try: os.makedirs(os.path.dirname(download_target), exist_ok=True)
150
+ # except: download_target = "/tmp/" + DATA_FILENAME
151
+
152
+ # status_container = st.empty()
153
+ # status_container.info("📡 正在接入神经元网络... (下载核心数���中)")
154
+
155
+ # if download_with_curl(DATA_URL, download_target):
156
+ # status_container.empty()
157
+ # return download_target
158
+
159
+ # try:
160
+ # headers = {'User-Agent': 'Mozilla/5.0'}
161
+ # r = requests.get(DATA_URL, headers=headers, stream=True)
162
+ # r.raise_for_status()
163
+ # with open(download_target, 'wb') as f:
164
+ # for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
165
+ # status_container.empty()
166
+ # return download_target
167
+ # except Exception as e:
168
+ # st.error(f"❌ 数据链路中断。Error: {e}")
169
+ # st.stop()
170
+
171
+ # class FullRetriever:
172
+ # def __init__(self, parquet_path):
173
+ # try: self.df = pd.read_parquet(parquet_path)
174
+ # except Exception as e: st.error(f"Memory Matrix Load Failed: {e}"); st.stop()
175
+ # self.documents = self.df['content'].tolist()
176
+ # self.embeddings = np.stack(self.df['embedding'].values)
177
+ # self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
178
+ # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
179
+ # # Reranker 初始化移到这里,减少重复调用
180
+ # self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
181
+ # self.rerank_url = f"{API_BASE}/rerank"
182
+
183
+ # def _get_emb(self, q):
184
+ # try: return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding
185
+ # except: return [0.0] * 1024
186
+
187
+ # def hybrid_search(self, query: str, top_k=5):
188
+ # # 1. Vector
189
+ # q_emb = self._get_emb(query)
190
+ # vec_scores = cosine_similarity([q_emb], self.embeddings)[0]
191
+ # vec_idx = np.argsort(vec_scores)[-100:][::-1]
192
+ # # 2. Keyword
193
+ # kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
194
+ # # 3. RRF Fusion
195
+ # fused = {}
196
+ # for r, i in enumerate(vec_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1)
197
+ # for r, i in enumerate(kw_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1)
198
+ # c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x:x[1], reverse=True)[:50]]
199
+ # c_docs = [self.documents[i] for i in c_idxs]
200
+
201
+ # # 4. Rerank
202
+ # try:
203
+ # payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k}
204
+ # resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10)
205
+ # results = resp.json().get('results', [])
206
+ # except:
207
+ # results = [{"index": i, "relevance_score": 0.0} for i in range(len(c_docs))][:top_k]
208
+
209
+ # final_res = []
210
+ # context = ""
211
+ # for i, item in enumerate(results):
212
+ # orig_idx = c_idxs[item['index']]
213
+ # row = self.df.iloc[orig_idx]
214
+ # final_res.append({
215
+ # "score": item['relevance_score'],
216
+ # "filename": row['filename'],
217
+ # "content": row['content']
218
+ # })
219
+ # context += f"[文档{i+1}]: {row['content']}\n\n"
220
+ # return final_res, context
221
+
222
+ # @st.cache_resource
223
+ # def load_engine():
224
+ # real_path = get_data_file_path()
225
+ # return FullRetriever(real_path)
226
+
227
+ # # ================= 3. UI 主程序 =================
228
+
229
+ # def main():
230
+ # st.markdown("""
231
+ # <div class="custom-header">
232
+ # <div style="font-size: 3rem;">🌌</div>
233
+ # <div>
234
+ # <div class="glitch-text">COMSOL DARK EXPERT</div>
235
+ # <div style="color: #666; font-size: 0.9rem; letter-spacing: 1px;">
236
+ # NEURAL SIMULATION ASSISTANT <span style="color:#29B5E8">V4.1 Fixed</span>
237
+ # </div>
238
+ # </div>
239
+ # </div>
240
+ # """, unsafe_allow_html=True)
241
+
242
+ # retriever = load_engine()
243
+
244
+ # with st.sidebar:
245
+ # st.markdown("### ⚙️ 控制台")
246
+ # top_k = st.slider("检索深度", 1, 10, 4)
247
+ # temp = st.slider("发散度", 0.0, 1.0, 0.3)
248
+ # st.markdown("---")
249
+ # if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True):
250
+ # st.session_state.messages = []
251
+ # st.session_state.current_refs = []
252
+ # st.rerun()
253
+
254
+ # if "messages" not in st.session_state: st.session_state.messages = []
255
+ # if "current_refs" not in st.session_state: st.session_state.current_refs = []
256
+
257
+ # col_chat, col_evidence = st.columns([0.65, 0.35], gap="large")
258
+
259
+ # # ------------------ 处理输入源 ------------------
260
+ # # 我们定义一个变量 user_input,不管它来自按钮还是输入框
261
+ # user_input = None
262
+
263
+ # with col_chat:
264
+ # # 1. 如果历史为空,显示快捷按钮
265
+ # if not st.session_state.messages:
266
+ # st.markdown("##### 💡 初始化提问序列 (Starter Sequence)")
267
+ # c1, c2, c3 = st.columns(3)
268
+ # # 点击按钮直接赋值给 user_input
269
+ # if c1.button("🌊 流固耦合接口设置"):
270
+ # user_input = "怎么设置流固耦���接口?"
271
+ # elif c2.button("⚡ 低频电磁场网格"):
272
+ # user_input = "低频电磁场网格划分有哪些技巧?"
273
+ # elif c3.button("📉 求解器不收敛"):
274
+ # user_input = "求解器不收敛通常怎么解决?"
275
+
276
+ # # 2. 渲染历史消息
277
+ # for msg in st.session_state.messages:
278
+ # with st.chat_message(msg["role"]):
279
+ # st.markdown(msg["content"])
280
+
281
+ # # 3. 处理底部输入框 (如果有按钮输入,这里会被跳过,因为 user_input 已经有值了)
282
+ # if not user_input:
283
+ # user_input = st.chat_input("输入指令或物理参数问题...")
284
+
285
+ # # ------------------ 统一处理消息追加 ------------------
286
+ # if user_input:
287
+ # st.session_state.messages.append({"role": "user", "content": user_input})
288
+ # # 强制刷新以立即在 UI 上显示用户的提问(对于按钮点击尤为重要)
289
+ # st.rerun()
290
+
291
+ # # ------------------ 统一触发生成 (修复的核心) ------------------
292
+ # # 检查:如果有消息,且最后一条是 User 发的,说明需要 Assistant 回答
293
+ # if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
294
+
295
+ # # 获取最后一条用户消息
296
+ # last_query = st.session_state.messages[-1]["content"]
297
+
298
+ # with col_chat: # 确保在聊天栏显示
299
+ # with st.spinner("🔍 正在扫描向量空间..."):
300
+ # refs, context = retriever.hybrid_search(last_query, top_k=top_k)
301
+ # st.session_state.current_refs = refs
302
+
303
+ # system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。
304
+ # 要求:
305
+ # 1. 语气专业、客观,逻辑严密。
306
+ # 2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。
307
+ # 3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。
308
+
309
+ # 参考文档:
310
+ # {context}
311
+ # """
312
+
313
+ # with st.chat_message("assistant"):
314
+ # resp_cont = st.empty()
315
+ # full_resp = ""
316
+ # client = OpenAI(base_url=API_BASE, api_key=API_KEY)
317
+
318
+ # try:
319
+ # stream = client.chat.completions.create(
320
+ # model=GEN_MODEL_NAME,
321
+ # messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:], # 除去当前的System
322
+ # temperature=temp,
323
+ # stream=True
324
+ # )
325
+ # for chunk in stream:
326
+ # txt = chunk.choices[0].delta.content
327
+ # if txt:
328
+ # full_resp += txt
329
+ # resp_cont.markdown(full_resp + " ▌")
330
+ # resp_cont.markdown(full_resp)
331
+ # st.session_state.messages.append({"role": "assistant", "content": full_resp})
332
+ # except Exception as e:
333
+ # st.error(f"Neural Generation Failed: {e}")
334
+
335
+ # # ------------------ 渲染右侧证据栏 ------------------
336
+ # with col_evidence:
337
+ # st.markdown("### 📚 神经记忆 (Evidence)")
338
+ # if st.session_state.current_refs:
339
+ # for i, ref in enumerate(st.session_state.current_refs):
340
+ # score = ref['score']
341
+ # score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c"
342
+
343
+ # with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)):
344
+ # st.markdown(f"""
345
+ # <div style="margin-bottom:5px;">
346
+ # <span style="color:#888;">Relevance:</span>
347
+ # <span style="color:{score_color}; font-weight:bold;">{score:.4f}</span>
348
+ # </div>
349
+ # """, unsafe_allow_html=True)
350
+ # st.code(ref['content'], language="text")
351
+ # else:
352
+ # st.info("等待输入指令以检索知识库...")
353
+ # st.markdown("""
354
+ # <div style="opacity:0.3; font-size:0.8rem; margin-top:20px;">
355
+ # Waiting for query signal...<br>
356
+ # Index Status: Ready<br>
357
+ # Awaiting Input
358
+ # </div>
359
+ # """, unsafe_allow_html=True)
360
+
361
+ # if __name__ == "__main__":
362
+ # main()
363
+
364
+
365
+
366
  import streamlit as st
367
  import pandas as pd
368
  import numpy as np
369
  import jieba
370
  import requests
371
  import os
372
+ import time
373
+ import json
374
+ import re
375
+ import random
376
  import subprocess
377
  from openai import OpenAI
378
  from rank_bm25 import BM25Okapi
379
  from sklearn.metrics.pairwise import cosine_similarity
380
+ from typing import List, Dict, Tuple, Any
381
 
382
+ # ================= 1. 全局配置与样式 =================
383
 
384
+ # API 配置 (从 HF 环境变量获取)
385
  API_BASE = "https://api.siliconflow.cn/v1"
386
+ API_KEY = os.getenv("SILICONFLOW_API_KEY")
387
+
388
+ # 模型名称配置
389
  EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
390
  RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
391
+ GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
392
+ QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"
393
+ SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"
394
+
395
+ # 预置问题池
396
+ PRESET_QUESTIONS = [
397
+ "如何设置流固耦合接口?",
398
+ "求解器不收敛怎么办?",
399
+ "网格划分有哪些技巧?",
400
+ "如何定义随时间变化的边界条件?",
401
+ "计算结果如何导出数据?",
402
+ "什么是完美匹配层 (PML)?",
403
+ "低频电磁场仿真注意事项",
404
+ "如何提高瞬态计算速度?",
405
+ "参数化扫描如何设置?",
406
+ "多物理场耦合的收敛性优化"
407
+ ]
408
+
409
+ # 数据文件配置
410
  DATA_FILENAME = "comsol_embedded.parquet"
411
  DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"
412
 
413
+ # 页面配置
414
  st.set_page_config(
415
+ page_title="COMSOL RAG 策略控制台",
416
+ page_icon="🎛️",
417
  layout="wide",
418
  initial_sidebar_state="expanded"
419
  )
420
 
421
+ # 自定义CSS样式
422
  st.markdown("""
423
  <style>
424
+ /* 深色主题 */
425
  .stApp {
426
+ background-color: #0E1117;
427
+ color: #E0E0E0;
 
 
428
  }
429
 
430
+ /* 聊天消息样式 */
 
 
 
 
 
431
  [data-testid="stChatMessage"] {
432
+ background-color: #1E1E1E;
433
+ border: 1px solid #333;
434
+ border-radius: 10px;
435
+ box-shadow: 0 2px 4px rgba(0,0,0,0.3);
 
 
 
 
 
 
 
 
436
  }
437
 
438
+ /* 侧边样式 */
439
+ [data-testid="stSidebar"] {
440
+ background-color: #161B22;
441
+ border-right: 1px solid #30363D;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  }
443
 
444
+ /* 策略标签 */
445
+ .strat-tag {
446
+ font-size: 0.75rem;
447
+ padding: 3px 8px;
448
+ border-radius: 4px;
449
+ margin-right: 6px;
450
+ font-weight: bold;
451
+ display: inline-block;
452
+ margin-bottom: 4px;
453
+ border: 1px solid rgba(255,255,255,0.2);
454
  }
455
+ .tag-vec { background-color: rgba(31, 119, 180, 0.3); color: #4EA8DE; border-color: #1f77b4; }
456
+ .tag-bm25 { background-color: rgba(255, 127, 14, 0.3); color: #FFAB5E; border-color: #ff7f0e; }
457
+ .tag-qe { background-color: rgba(44, 160, 44, 0.3); color: #69DB7C; border-color: #2ca02c; }
458
+ .tag-rerank { background-color: rgba(214, 39, 40, 0.3); color: #FF6B6B; border-color: #d62728; }
459
+
460
+ /* 过程展示框 */
461
+ .process-box {
462
+ background-color: #0D1117;
463
+ border: 1px solid #30363D;
464
+ padding: 15px;
465
  border-radius: 8px;
466
+ font-size: 0.9rem;
467
+ color: #8B949E;
468
+ margin-bottom: 15px;
469
+ }
470
+
471
+ /* 策略矩阵标题 */
472
+ .strategy-title {
473
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
474
+ -webkit-background-clip: text;
475
+ -webkit-text-fill-color: transparent;
476
+ background-clip: text;
477
+ font-weight: bold;
478
+ font-size: 1.2rem;
479
  }
480
  </style>
481
  """, unsafe_allow_html=True)
482
 
483
+ # ================= 2. 数据下载工具 (HF 适配) =================
 
 
 
 
484
 
485
  def download_with_curl(url, output_path):
486
+ """使用 curl 下载文件,增加鲁棒性"""
487
  try:
488
  cmd = [
489
  "curl", "-L",
 
493
  url
494
  ]
495
  result = subprocess.run(cmd, capture_output=True, text=True)
496
+ if result.returncode != 0:
497
+ print(f"Curl stderr: {result.stderr}")
498
+ return False
499
  return True
500
  except Exception as e:
501
  print(f"Curl download error: {e}")
502
  return False
503
 
504
  def get_data_file_path():
505
+ """获取数据文件路径,如果不存在则自动下载"""
506
+ # 优先检查本地可能存在的路径
507
  possible_paths = [
508
+ DATA_FILENAME,
509
+ os.path.join("/app", DATA_FILENAME),
510
  os.path.join("processed_data", DATA_FILENAME),
511
+ os.path.join(os.getcwd(), DATA_FILENAME)
 
512
  ]
513
+
514
  for path in possible_paths:
515
+ if os.path.exists(path):
516
+ return path
517
 
518
+ # 如果都没找到,准备下载
519
+ # HF Spaces 通常在 /home/user/app 下运行,直接下载到当前目录
520
+ download_target = os.path.join(os.getcwd(), DATA_FILENAME)
521
+
522
  status_container = st.empty()
523
+ status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)")
524
 
525
+ # 尝试 Curl 下载
526
  if download_with_curl(DATA_URL, download_target):
527
  status_container.empty()
528
  return download_target
529
 
530
+ # 降级尝试 Requests 下载
531
  try:
532
  headers = {'User-Agent': 'Mozilla/5.0'}
533
  r = requests.get(DATA_URL, headers=headers, stream=True)
534
  r.raise_for_status()
535
  with open(download_target, 'wb') as f:
536
+ for chunk in r.iter_content(chunk_size=8192):
537
+ f.write(chunk)
538
  status_container.empty()
539
  return download_target
540
  except Exception as e:
541
+ st.error(f"❌ 数据下载失败。Error: {e}")
542
  st.stop()
543
 
544
+ # ================= 3. 核心 RAG 控制器 =================
545
+
546
+ class RAGController:
547
+ """RAG系统控制器 - 实现策略矩阵"""
548
+
549
+ def __init__(self):
550
+ """初始化控制器"""
551
+ if not API_KEY:
552
+ st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
553
+ st.stop()
554
+
555
  self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
556
+ self.df = None
557
+ self.documents = []
558
+ self.embeddings = None
559
+ self.bm25 = None
560
+ self.filenames = []
561
+ self._load_data()
562
+
563
+ def _load_data(self):
564
+ """加载COMSOL文档数据"""
565
+ real_path = get_data_file_path()
566
+
567
+ try:
568
+ # 加载数据
569
+ self.df = pd.read_parquet(real_path)
570
+ self.documents = self.df['content'].tolist()
571
+ self.filenames = self.df['filename'].tolist()
572
+
573
+ # 加载向量嵌入
574
+ self.embeddings = np.stack(self.df['embedding'].values)
575
+
576
+ # 初始化BM25
577
+ tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents]
578
+ self.bm25 = BM25Okapi(tokenized_corpus)
579
+
580
+ st.success(f"✅ 成功加载 {len(self.documents)} 条文档")
581
+
582
+ except Exception as e:
583
+ st.error(f"❌ 数据加载失败: {str(e)}")
584
+ st.stop()
585
+
586
+ def get_embedding(self, text: str) -> List[float]:
587
+ """获取文本向量嵌入"""
588
+ try:
589
+ resp = self.client.embeddings.create(
590
+ model=EMBEDDING_MODEL,
591
+ input=[text]
592
+ )
593
+ return resp.data[0].embedding
594
+ except Exception as e:
595
+ st.warning(f"向量获取失败: {e}")
596
+ return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback
597
+
598
+ def expand_query(self, query: str) -> Tuple[str, float]:
599
+ """查询扩展 - 使用LLM优化查询"""
600
+ prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。
601
+
602
+ 要求:
603
+ 1. 补充COMSOL专业术语(物理场、模块、边界条件等)
604
+ 2. 保持问题核心意图不变
605
+ 3. 输出简洁,仅返回改写后的查询
606
 
607
+ 用户问题: {query}
608
+ 专业查询:"""
609
+
610
  try:
611
+ start_time = time.time()
612
+ resp = self.client.chat.completions.create(
613
+ model=QE_MODEL_NAME,
614
+ messages=[{"role": "user", "content": prompt}],
615
+ temperature=0.3
616
+ )
617
+ expanded = resp.choices[0].message.content.strip()
618
+ elapsed = time.time() - start_time
619
+ return expanded, elapsed
620
+ except Exception as e:
621
+ print(f"QE Error: {e}")
622
+ return query, 0
623
+
624
+ def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]:
625
+ """向量检索"""
626
+ q_vec = self.get_embedding(query)
627
+ similarities = cosine_similarity([q_vec], self.embeddings)[0]
628
+ top_indices = np.argsort(similarities)[-top_k:][::-1]
629
+ return [(idx, similarities[idx]) for idx in top_indices]
630
+
631
+ def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]:
632
+ """BM25关键词检索"""
633
+ tokenized_query = jieba.lcut(query.lower())
634
+ scores = self.bm25.get_scores(tokenized_query)
635
+ top_indices = np.argsort(scores)[-top_k:][::-1]
636
+ return [(idx, scores[idx]) for idx in top_indices]
637
+
638
+ def reciprocal_rank_fusion(self, vector_results: List[Tuple[int, float]],
639
+ bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]:
640
+ """RRF融合算法"""
641
+ scores = {}
642
+ for rank, (idx, score) in enumerate(vector_results):
643
+ scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1)
644
+ for rank, (idx, score) in enumerate(bm25_results):
645
+ scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1)
646
+ return scores
647
+
648
+ def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]:
649
+ """使用重排序模型"""
650
+ if not documents: return [], 0
651
+
652
+ url = f"{API_BASE}/rerank"
653
+ headers = {
654
+ "Authorization": f"Bearer {API_KEY}",
655
+ "Content-Type": "application/json"
656
+ }
657
+
658
+ # 截断文档内容以符合 Context Window
659
+ docs_content = [doc["content"][:2048] for doc in documents]
660
+
661
+ payload = {
662
+ "model": RERANK_MODEL,
663
+ "query": query,
664
+ "documents": docs_content,
665
+ "top_n": top_n
666
+ }
667
+
668
+ try:
669
+ start_time = time.time()
670
+ response = requests.post(url, headers=headers, json=payload, timeout=20)
671
+ elapsed = time.time() - start_time
672
+
673
+ if response.status_code == 200:
674
+ results = response.json().get("results", [])
675
+ reranked_docs = []
676
+ for result in results:
677
+ original_doc = documents[result["index"]]
678
+ original_doc["rerank_score"] = result["relevance_score"]
679
+ original_doc["final_score"] = result["relevance_score"]
680
+ reranked_docs.append(original_doc)
681
+ return reranked_docs, elapsed
682
+ else:
683
+ print(f"Rerank API Error: {response.text}")
684
+ return documents[:top_n], elapsed
685
+ except Exception as e:
686
+ print(f"Rerank Exception: {e}")
687
+ return documents[:top_n], 0
688
+
689
+ def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]:
690
+ """执行策略矩阵"""
691
+ start_time = time.time()
692
+ result = {
693
+ 'original_query': query,
694
+ 'final_query': query,
695
+ 'documents': [],
696
+ 'steps': [],
697
+ 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0},
698
+ 'strategy_tags': []
699
+ }
700
+
701
+ # 1. 查询扩展
702
+ if config['use_qe']:
703
+ expanded_q, qe_time = self.expand_query(query)
704
+ result['final_query'] = expanded_q
705
+ result['metrics']['qe_time'] = qe_time
706
+ result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**")
707
+ result['strategy_tags'].append("QE")
708
+
709
+ # 2. 检索
710
+ retrieval_start = time.time()
711
+ query_to_search = result['final_query']
712
+
713
+ if config['strategy'] == 'Vector':
714
+ results = self.vector_search(query_to_search)
715
+ result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选")
716
+ result['strategy_tags'].append("Vector")
717
+ elif config['strategy'] == 'BM25':
718
+ results = self.bm25_search(query_to_search)
719
+ result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选")
720
+ result['strategy_tags'].append("BM25")
721
+ elif config['strategy'] == 'Hybrid':
722
+ vec_results = self.vector_search(query_to_search)
723
+ bm25_results = self.bm25_search(query_to_search)
724
+ fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results)
725
+ results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
726
+ results = [(idx, score) for idx, score in results]
727
+ result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选")
728
+ result['strategy_tags'].extend(["Vector", "BM25"])
729
+
730
+ result['metrics']['retrieval_time'] = time.time() - retrieval_start
731
+
732
+ # 3. 构建候选列表
733
+ recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k']
734
+ top_results = results[:recall_k]
735
+
736
+ documents = []
737
+ for idx, score in top_results:
738
+ documents.append({
739
+ 'content': self.documents[idx],
740
+ 'filename': self.filenames[idx],
741
+ 'retrieval_score': score,
742
+ 'final_score': score,
743
+ 'type': 'retrieval'
744
  })
 
 
745
 
746
+ # 4. 重排序
747
+ if config['use_rerank']:
748
+ reranked_docs, rerank_time = self.rerank_documents(
749
+ result['final_query'], documents, config['top_k']
750
+ )
751
+ result['documents'] = reranked_docs
752
+ result['metrics']['rerank_time'] = rerank_time
753
+ result['steps'].append(f"⚖️ 重排序 ({rerank_time:.2f}s): 精选 Top-{config['top_k']}")
754
+ result['strategy_tags'].append("Rerank")
755
+ else:
756
+ result['documents'] = documents[:config['top_k']]
757
+
758
+ result['metrics']['total_time'] = time.time() - start_time
759
+ result['steps'].append(f"⏱️ 总耗时: {result['metrics']['total_time']:.2f}s")
760
+ return result
761
+
762
+ def generate_suggestions(controller, query: str, answer: str) -> List[str]:
763
+ """生成3个后续引导问题"""
764
+ prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。
765
+ 用户问题:{query}
766
+ 专家回答:{answer[:800]}...
767
+
768
+ 要求:
769
+ 1. 问题简短(15字以内)。
770
+ 2. 紧扣当前话题。
771
+ 3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。
772
+ 4. 不要包含任何 Markdown 标记。
773
+ """
774
+
775
+ try:
776
+ resp = controller.client.chat.completions.create(
777
+ model=SUGGEST_MODEL_NAME,
778
+ messages=[{"role": "user", "content": prompt}],
779
+ temperature=0.5
780
+ )
781
+ content = resp.choices[0].message.content.strip()
782
+ match = re.search(r'\[.*\]', content, re.DOTALL)
783
+ if match:
784
+ sugs = json.loads(match.group())
785
+ return sugs[:3]
786
+ return []
787
+ except Exception as e:
788
+ print(f"Suggestion Error: {e}")
789
+ return []
790
+
791
+ def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str:
792
+ """流式生成回答"""
793
+ if not documents:
794
+ return "抱歉,没有找到相关的文档来回答您的问题。"
795
+
796
+ context_text = "\n\n".join([f"[文档{i+1}] {doc['content'][:800]}..." for i, doc in enumerate(documents)])
797
+
798
+ system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。
799
+ 要求:
800
+ 1. 语气专业,使用COMSOL术语。
801
+ 2. 物理公式使用 LaTeX(如 $E=mc^2$)。
802
+ 3. 如果文档信息不足,请如实告知,不要编造。
803
+
804
+ 【参考文档】:
805
+ {context_text}
806
+ """
807
+
808
+ # 构建历史记录
809
+ keep_messages = max_rounds * 2
810
+ history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else []
811
+
812
+ api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}]
813
+
814
+ try:
815
+ response = controller.client.chat.completions.create(
816
+ model=GEN_MODEL_NAME,
817
+ messages=api_messages,
818
+ temperature=0.3,
819
+ stream=True
820
+ )
821
+
822
+ answer = ""
823
+ placeholder = st.empty()
824
+ for chunk in response:
825
+ if chunk.choices[0].delta.content:
826
+ answer += chunk.choices[0].delta.content
827
+ placeholder.markdown(answer + "▌")
828
+ placeholder.markdown(answer)
829
+ return answer
830
+ except Exception as e:
831
+ return f"生成遇到错误: {e}"
832
+
833
+ # ================= 4. 初始化与组件渲染 =================
834
+
835
+ @st.cache_resource(show_spinner="🚀 正在初始化 RAG 引擎...")
836
+ def initialize_controller():
837
+ return RAGController()
838
+
839
+ def render_strategy_matrix():
840
+ st.markdown('<p class="strategy-title">🎯 策略矩阵配置</p>', unsafe_allow_html=True)
841
+ st.markdown("""<div style="background-color: #161B22; padding: 10px; border-radius: 8px; margin-bottom: 20px;">
842
+ <p style="font-size: 0.85rem; color: #8B949E; margin: 0;">⚙️ <b>参数调节</b>:控制检索片段数量和模型记忆深度。</p>
843
+ </div>""", unsafe_allow_html=True)
844
+
845
+ col1, col2 = st.columns(2)
846
+ with col1:
847
+ use_qe = st.toggle("🔄 查询扩展 (QE)", value=False)
848
+ use_rerank = st.toggle("⚖️ 深度重排序 (Rerank)", value=True)
849
+ max_history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数")
850
+
851
+ with col2:
852
+ strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2)
853
+ top_k = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量")
854
+
855
+ return {'use_qe': use_qe, 'strategy': strategy, 'use_rerank': use_rerank, 'top_k': top_k, 'max_history_rounds': max_history_rounds}
856
+
857
+ def render_metrics(metrics):
858
+ st.markdown("### 📊 性能指标")
859
+ cols = st.columns(4)
860
+ with cols[0]: st.metric("查询扩展", f"{metrics['qe_time']:.2f}s" if metrics['qe_time']>0 else "N/A", delta="QE" if metrics['qe_time']>0 else None)
861
+ with cols[1]: st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s")
862
+ with cols[2]: st.metric("重排序", f"{metrics['rerank_time']:.2f}s" if metrics['rerank_time']>0 else "N/A", delta="Rerank" if metrics['rerank_time']>0 else None)
863
+ with cols[3]: st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡")
864
 
865
+ def render_documents(documents, strategy_tags):
866
+ st.markdown("### 📄 检索结果")
867
+ if not documents:
868
+ st.warning("未找到相关文档")
869
+ return
870
+
871
+ tags_html = "".join([f'<span class="strat-tag {{"QE":"tag-qe","Vector":"tag-vec","BM25":"tag-bm25","Rerank":"tag-rerank"}}.get(t, "")}}">{t}</span>' for t in strategy_tags]) # Simplified for brevity, use full logic if copying
872
+ # Manual mapping for safety
873
+ html_tags = ""
874
+ for tag in strategy_tags:
875
+ cls = "tag-vec" if tag=="Vector" else "tag-bm25" if tag=="BM25" else "tag-qe" if tag=="QE" else "tag-rerank"
876
+ html_tags += f'<span class="strat-tag {cls}">{tag}</span>'
877
+
878
+ st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True)
879
+
880
+ for i, doc in enumerate(documents):
881
+ score = doc.get('final_score', 0)
882
+ with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {doc['filename'][:40]}...", expanded=i<2):
883
+ st.code(doc['content'], language="markdown")
884
+
885
+ # ================= 5. 主程序 =================
886
 
887
  def main():
888
+ # 状态初始化
889
+ if "messages" not in st.session_state: st.session_state.messages = []
890
+ if "last_result" not in st.session_state: st.session_state.last_result = None
891
+ if "suggestions" not in st.session_state: st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3)
892
+ if "prompt_trigger" not in st.session_state: st.session_state.prompt_trigger = None
 
 
 
 
 
 
 
 
893
 
894
+ # 加载控制器
895
+ controller = initialize_controller()
896
+
897
+ # 侧边栏
898
  with st.sidebar:
899
+ config = render_strategy_matrix()
 
 
900
  st.markdown("---")
901
+ if st.button("🗑️ 清空当前对话", use_container_width=True):
902
  st.session_state.messages = []
903
+ st.session_state.last_result = None
904
  st.rerun()
905
 
906
+ # 主界面布局
907
+ main_col, debug_col = st.columns([0.6, 0.4], gap="large")
908
+
909
+ with main_col:
910
+ st.markdown("### 💬 智能仿真问答")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
 
912
+ # 1. 历史消息
913
  for msg in st.session_state.messages:
914
  with st.chat_message(msg["role"]):
915
  st.markdown(msg["content"])
916
 
917
+ # 2. 建议区
918
+ if st.session_state.suggestions:
919
+ st.markdown("##### 💡 您可能想")
920
+ cols = st.columns(3)
921
+ for i, sug in enumerate(st.session_state.suggestions):
922
+ if cols[i].button(sug, use_container_width=True, key=f"sug_{i}"):
923
+ st.session_state.prompt_trigger = sug
924
+ st.rerun()
925
+
926
+ # 3. 输入处理
927
+ user_input = None
928
+ if st.session_state.prompt_trigger:
929
+ user_input = st.session_state.prompt_trigger
930
+ st.session_state.prompt_trigger = None
931
+ else:
932
+ user_input = st.chat_input("请输入您关于 COMSOL 的问题...")
933
 
934
+ # 4. 执行逻辑
935
+ if user_input:
936
+ st.session_state.messages.append({"role": "user", "content": user_input})
937
+ with st.chat_message("user"): st.markdown(user_input)
 
938
 
939
+ # 检索
940
+ with st.spinner("🔍 检索知识库中..."):
941
+ result = controller.execute_strategy(user_input, config)
942
+ st.session_state.last_result = result
943
+
944
+ # 生成
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
945
  with st.chat_message("assistant"):
946
+ answer = generate_answer(
947
+ controller, user_input, result['documents'],
948
+ st.session_state.messages, config['max_history_rounds']
949
+ )
950
+ st.session_state.messages.append({"role": "assistant", "content": answer})
951
 
952
+ # 生成新建议
953
+ new_sugs = generate_suggestions(controller, user_input, answer)
954
+ st.session_state.suggestions = new_sugs if new_sugs else random.sample(PRESET_QUESTIONS, 3)
955
+ st.rerun()
956
+
957
+ with debug_col:
958
+ st.markdown("### 🔍 系统调试视图")
959
+ if st.session_state.last_result:
960
+ res = st.session_state.last_result
961
+ st.info(f"当前查询: {res.get('original_query', 'N/A')}")
962
+ render_metrics(res['metrics'])
963
+ with st.expander("🔧 检索链路详情", expanded=True):
964
+ for step in res['steps']: st.markdown(f"- {step}")
965
+ render_documents(res['documents'], res['strategy_tags'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
966
  else:
967
+ st.info("等待交互...")
 
 
 
 
 
 
 
968
 
969
  if __name__ == "__main__":
970
  main()