ss900371tw committed on
Commit
26764a8
·
verified ·
1 Parent(s): 2918e34

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +390 -65
src/streamlit_app.py CHANGED
@@ -1,342 +1,667 @@
1
  import streamlit as st
 
2
  import os
 
3
  import io
 
4
  import numpy as np
 
5
  import faiss
 
6
  import uuid
 
7
  import time
8
- import google.generativeai as genai
9
- from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold # 引入必要的型別
 
 
10
 
11
  # === RAG 相關套件 ===
12
- import torch
 
 
 
 
13
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
14
  from langchain_core.documents import Document
 
15
  from langchain_community.vectorstores import FAISS
 
16
  from langchain_community.vectorstores.utils import DistanceStrategy
 
17
  from langchain_community.docstore.in_memory import InMemoryDocstore
18
 
 
 
19
  # 嘗試匯入 pypdf
 
20
  try:
 
21
  import pypdf
 
22
  except ImportError:
 
23
  pypdf = None
24
 
 
 
25
  # --- 頁面設定 ---
 
26
  st.set_page_config(page_title="Cybersecurity AI Assistant (Gemini RAG)", page_icon="🛡️", layout="wide")
27
- st.title("🛡️ Gemini-2.5-Flash RAG 資安分析助理")
 
 
28
  st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Google Gemini API**")
29
 
 
 
30
  # --- 側邊欄設定 ---
 
31
  with st.sidebar:
 
32
  st.header("⚙️ 設定")
 
33
 
34
- # Google API Key 輸入
 
 
35
  default_key = os.getenv("GOOGLE_API_KEY", "")
 
36
  google_api_key = st.text_input("Google API Key", value=default_key, type="password")
 
37
 
 
38
  if not google_api_key:
 
39
  st.warning("請輸入 Google API Key 以繼續。")
 
40
 
 
41
  st.divider()
 
42
  st.subheader("📂 上傳分析檔案 (建立 RAG 庫)")
 
43
  uploaded_file = st.file_uploader("上傳 Logs/PDF/Code", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf'])
 
44
 
 
45
  st.divider()
 
46
  st.subheader("🔍 RAG 檢索設定")
 
47
  similarity_threshold = st.slider(
 
48
  "📐 Cosine Similarity 門檻",
 
49
  0.0, 1.0, 0.4, 0.01,
 
50
  help="數值越大越相似。一般建議 0.4~0.7"
 
51
  )
 
52
 
 
53
  st.divider()
 
54
  st.subheader("模型參數")
55
- # 調整 System Prompt 預設值,鼓勵模型提供結構化資安分析
56
- system_prompt = st.text_area(
57
- "System Prompt",
58
- value="You are a Tier 3 Senior Security Analyst. Use the retrieved context to answer the user's question. Specifically, follow the strict analysis framework provided by the user (Ransomware Kill Chain, Timeline Reconstruction) and respond in Traditional Chinese. If no malicious indicators are found, state clearly.",
59
- height=100
60
- )
61
- # 預設 Max Output Tokens 調整到 4096,以避免中斷
62
- max_output_tokens = st.slider("Max Output Tokens", 128, 8192, 4096, 128, help="調高此值可避免回應被截斷 (MAX_TOKENS 錯誤)。")
63
- temperature = st.slider("Temperature", 0.0, 2.0, 0.1, 0.1, help="資安分析建議使用極低的 Temperature (0.1-0.3)。")
64
 
 
65
  st.divider()
 
66
  if st.button("🗑️ 清除對話紀錄"):
 
67
  st.session_state.messages = []
 
68
  st.rerun()
69
 
 
 
70
  # --- 初始化 Gemini ---
 
71
  genai_model = None
 
72
  if google_api_key:
 
73
  try:
 
74
  genai.configure(api_key=google_api_key)
 
75
  # 使用 Flash 模型,速度快且便宜,適合 RAG 大量文本閱讀
76
- genai_model = genai.GenerativeModel('gemini-2.5-flash')
 
 
77
  except Exception as e:
 
78
  st.error(f"Gemini 設定失敗: {e}")
79
 
 
 
80
  # === Embedding 模型 (保留原本的 Jina 或其他 HF 模型) ===
 
 
 
81
  @st.cache_resource
 
82
  def load_embedding_model():
 
83
  model_kwargs = {
84
- 'device': 'cpu',
 
 
85
  'trust_remote_code': True
 
86
  }
 
87
  encode_kwargs = {
 
88
  'normalize_embeddings': False
 
89
  }
 
90
  return HuggingFaceEmbeddings(
 
91
  model_name="jinaai/jina-embeddings-v2-base-code",
 
92
  model_kwargs=model_kwargs,
 
93
  encode_kwargs=encode_kwargs
 
94
  )
95
 
 
 
96
  with st.spinner("正在載入 Embedding 模型..."):
 
97
  embedding_model = load_embedding_model()
98
 
 
 
99
  # === 建立向量庫 (Strict Cosine) - 邏輯維持不變 ===
 
100
  def process_file_to_faiss(uploaded_file):
 
101
  text_content = ""
 
102
  try:
 
103
  if uploaded_file.type == "application/pdf":
 
104
  if pypdf:
 
105
  pdf_reader = pypdf.PdfReader(uploaded_file)
 
106
  for page in pdf_reader.pages:
 
107
  text_content += page.extract_text() + "\n"
 
108
  else:
 
109
  return None, "PDF library missing"
 
110
  else:
 
111
  stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
 
112
  text_content = stringio.read()
 
113
 
 
114
  if not text_content.strip():
 
115
  return None, "File is empty"
116
 
 
 
117
  # 簡單切分
 
118
  events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]
 
119
  if len(events) <= 1:
 
120
  events = [line for line in text_content.split("\n") if line.strip()]
 
121
 
 
122
  docs = [Document(page_content=e) for e in events]
 
123
 
 
124
  if not docs:
 
125
  return None, "No documents created"
126
 
 
 
127
  embeddings = embedding_model.embed_documents([d.page_content for d in docs])
 
128
  embeddings_np = np.array(embeddings).astype("float32")
 
129
  faiss.normalize_L2(embeddings_np)
 
130
 
 
131
  dimension = embeddings_np.shape[1]
 
132
  index = faiss.IndexFlatIP(dimension)
 
133
  index.add(embeddings_np)
 
134
 
 
135
  doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
 
136
  docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
 
137
  index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
 
138
 
 
139
  vector_store = FAISS(
 
140
  embedding_function=embedding_model,
 
141
  index=index,
 
142
  docstore=docstore,
 
143
  index_to_docstore_id=index_to_docstore_id,
 
144
  distance_strategy=DistanceStrategy.COSINE
 
145
  )
 
146
 
147
- return vector_store, f"建立了 {len(docs)} 個日誌片段。"
 
 
148
  except Exception as e:
149
- return None, f"錯誤: {str(e)}"
 
 
 
150
 
151
  # === 檔案處理邏輯 ===
 
152
  if uploaded_file:
 
153
  file_key = f"vs_{uploaded_file.name}_{uploaded_file.size}"
 
154
 
 
155
  if "current_file_key" not in st.session_state or st.session_state.current_file_key != file_key:
 
156
  with st.spinner("偵測到新檔案,正在更新知識庫..."):
 
157
  vs, msg = process_file_to_faiss(uploaded_file)
 
158
  if vs:
 
159
  st.session_state.vector_store = vs
 
160
  st.session_state.current_file_key = file_key
 
161
  st.toast(f"知識庫已更新!{msg}", icon="✅")
 
162
  else:
 
163
  st.error(msg)
 
164
  else:
 
165
  if "vector_store" in st.session_state:
 
166
  del st.session_state.vector_store
 
167
  st.info("檔案已移除,已清除知識庫,回到一般模式。")
 
168
  if "current_file_key" in st.session_state:
169
- if 'vector_store' not in st.session_state: # 避免重複刪除
170
- del st.session_state.current_file_key
 
 
171
 
172
  # === 顯示對話歷史 ===
 
173
  if "messages" not in st.session_state:
 
174
  st.session_state.messages = []
175
 
 
 
176
  for idx, message in enumerate(st.session_state.messages):
 
177
  with st.chat_message(message["role"]):
 
178
  st.markdown(message["content"])
 
179
  if message.get("context"):
180
- # 確保只顯示最近一次對話的 expander
181
- if idx == len(st.session_state.messages) - 1:
182
- is_expanded = True
183
- else:
184
- is_expanded = False
185
-
186
- with st.expander("查看參考片段", expanded=is_expanded):
187
- st.code(message["context"], language="log")
188
  st.download_button(
 
189
  label="📥 下載此參考內容 (.txt)",
 
190
  data=message["context"],
 
191
  file_name=f"rag_context_{idx}.txt",
 
192
  mime="text/plain",
 
193
  key=f"dl_btn_{idx}"
 
194
  )
195
 
 
 
196
  # === Search 函數 ===
 
197
  def faiss_cosine_search_all(vector_store, query, threshold):
 
198
  q_emb = embedding_model.embed_query(query)
 
199
  q_emb = np.array([q_emb]).astype("float32")
 
200
  faiss.normalize_L2(q_emb)
 
201
 
 
202
  index = vector_store.index
 
203
  D, I = index.search(q_emb, k=index.ntotal)
 
204
 
 
205
  selected = []
206
- # 這裡只取相似度高於門檻的片段
207
  for score, idx in zip(D[0], I[0]):
 
208
  if idx == -1: continue
 
209
  if score >= threshold:
 
210
  doc_id = vector_store.index_to_docstore_id[idx]
 
211
  doc = vector_store.docstore.search(doc_id)
 
212
  selected.append((doc, score))
 
213
 
 
214
  selected.sort(key=lambda x: x[1], reverse=True)
 
215
  return selected
216
 
 
 
217
  # === Gemini 產生回答 ===
 
218
  def generate_rag_response_gemini(prompt, history, sys_prompt, vector_store=None, threshold=0.5):
 
219
  context_text = ""
 
 
 
220
 
 
221
  # 1. 檢索
 
222
  if vector_store:
223
- # 為了資安分析,我們需要擷取所有相關的 Log,所以將 threshold 作為篩選標準
224
  selected = faiss_cosine_search_all(vector_store, prompt, threshold)
225
-
226
  if selected:
227
- # 取前 50 個片段,以利用 Gemini-Flash 的大上下文視窗
228
- top_k_selected = selected[:50]
 
 
 
229
  retrieved_contents = [
 
230
  f"--- Chunk (sim={score:.3f}) ---\n{doc.page_content}"
231
- for (doc, score) in top_k_selected
 
 
232
  ]
 
233
  context_text = "\n".join(retrieved_contents)
234
 
 
 
235
  # 2. 構建 Prompt
236
- # 我們將系統指令與上下文合併,作為使用者訊息的第一部分,以確保指令被嚴格遵循
237
  if context_text:
 
238
  full_user_input = f"""
239
- System Instruction (CRITICAL: Adhere to the following framework and respond in Traditional Chinese):
240
- {sys_prompt}
 
 
241
 
242
  === RETRIEVED CONTEXT (Cosine ≥ {threshold}) ===
 
243
  {context_text}
 
244
  === END CONTEXT ===
245
 
 
 
246
  Question: {prompt}
247
- Analyze the question based strictly on the context and output using the required 4-part Chinese structure.
 
 
248
  """
 
249
  else:
 
250
  full_user_input = f"""
 
251
  System Instruction: {sys_prompt}
252
 
 
 
253
  Question: {prompt}
 
254
  """
255
 
 
 
256
  # 3. 轉換歷史訊息格式 (Streamlit -> Gemini)
 
 
 
257
  gemini_history = []
 
258
  for msg in history:
259
- # 由於我們在 full_user_input 塞入了 System Prompt 和 Context,這裡只傳遞純對話以避免上下文重複
260
  role = "user" if msg["role"] == "user" else "model"
261
- gemini_history.append({"role": role, "parts": [{"text": msg["content"]}]})
 
 
 
 
 
 
 
 
 
262
 
263
  # 4. 呼叫 Gemini
 
264
  try:
 
 
 
 
 
265
  # 設定生成參數
266
- generation_config = GenerationConfig(
 
 
267
  candidate_count=1,
 
268
  max_output_tokens=max_output_tokens,
 
269
  temperature=temperature,
 
270
  )
 
271
 
272
- # 安全設定 (設為 BLOCK_NONE,這是解決資安敏感內容被阻擋的關鍵)
 
 
273
  safety_settings = [
274
- {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
275
- {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
276
- {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
277
- {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  ]
279
 
280
- chat = genai_model.start_chat(history=gemini_history)
281
-
282
  response = chat.send_message(
 
283
  full_user_input,
 
284
  generation_config=generation_config,
 
285
  safety_settings=safety_settings
 
286
  )
287
-
288
- # 檢查是否有安全阻擋或錯誤
289
- if response.prompt_feedback.block_reason or not response.candidates:
290
- # 安全性阻擋的錯誤處理
291
- reason = response.prompt_feedback.block_reason.name if response.prompt_feedback.block_reason else "Unknown"
292
- return f"Gemini API 錯誤: 由於安全原因,回應被阻擋。原因: {reason}", context_text
293
-
294
  return response.text, context_text
295
 
 
 
296
  except Exception as e:
297
- return f"Gemini API 錯誤: {str(e)}", context_text
 
 
 
298
 
299
  # === 處理使用者輸入 ===
 
300
  if prompt := st.chat_input("請輸入問題..."):
301
- if not google_api_key:
 
 
302
  st.error("請先輸入有效的 Google API Key")
303
- elif not genai_model:
304
- st.error("Gemini 模型初始化失敗,請檢查 API Key")
305
  else:
 
306
  vs = st.session_state.get("vector_store", None)
 
307
  display_prompt = prompt
 
 
 
 
 
308
 
309
- st.chat_message("user").markdown(f"🔍 **[RAG]** {prompt}" if vs else prompt)
 
 
310
 
 
311
  with st.chat_message("assistant"):
 
312
  msg_placeholder = st.empty()
 
313
 
314
- with st.spinner("Gemini Thinking... 正在進行日誌分析與 RAG 檢索"):
 
 
315
  response, retrieved_ctx = generate_rag_response_gemini(
 
316
  prompt,
 
317
  st.session_state.messages,
 
318
  system_prompt,
 
319
  vector_store=vs,
 
320
  threshold=similarity_threshold,
 
321
  )
 
322
 
 
323
  msg_placeholder.markdown(response)
 
324
 
 
325
  if retrieved_ctx:
326
- # 再次顯示擴展器,確保當前回合的參考資料可見
327
  with st.expander("查看檢索到的參考片段"):
328
- st.code(retrieved_ctx, language="log")
 
 
329
  st.download_button(
 
330
  label="📥 下載此參考內容 (.txt)",
 
331
  data=retrieved_ctx,
 
332
  file_name=f"rag_context_current.txt",
 
333
  mime="text/plain"
 
334
  )
335
 
336
- # 更新歷史 (將原始 prompt 和回應存入 session state)
337
- st.session_state.messages.append({"role": "user", "content": f"🔍 **[RAG]** {prompt}" if vs else prompt})
 
 
 
 
338
  st.session_state.messages.append({
 
339
  "role": "assistant",
 
340
  "content": response,
 
341
  "context": retrieved_ctx
 
342
  })
 
1
  import streamlit as st
2
+
3
  import os
4
+
5
  import io
6
+
7
  import numpy as np
8
+
9
  import faiss
10
+
11
  import uuid
12
+
13
  import time
14
+
15
+ import google.generativeai as genai # <--- 新增 Google SDK
16
+
17
+
18
 
19
  # === RAG 相關套件 ===
20
+
21
+ # 這裡保留 Torch 和 HuggingFaceEmbeddings 是為了向量化 (Embedding),這部分吃資源很少
22
+
23
+ import torch
24
+
25
  from langchain_community.embeddings import HuggingFaceEmbeddings
26
+
27
  from langchain_core.documents import Document
28
+
29
  from langchain_community.vectorstores import FAISS
30
+
31
  from langchain_community.vectorstores.utils import DistanceStrategy
32
+
33
  from langchain_community.docstore.in_memory import InMemoryDocstore
34
 
35
+
36
+
37
  # 嘗試匯入 pypdf
38
+
39
  try:
40
+
41
  import pypdf
42
+
43
  except ImportError:
44
+
45
  pypdf = None
46
 
47
+
48
+
49
# --- Page configuration ---
st.set_page_config(page_title="Cybersecurity AI Assistant (Gemini RAG)", page_icon="🛡️", layout="wide")
st.title("🛡️ Gemini-1.5-Flash with FAISS RAG")
st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Google Gemini API**")

# --- Sidebar: API key, file upload, retrieval and model parameters ---
with st.sidebar:
    st.header("⚙️ 設定")

    # Google API key: pre-filled from the environment when available.
    default_key = os.getenv("GOOGLE_API_KEY", "")
    google_api_key = st.text_input("Google API Key", value=default_key, type="password")

    if not google_api_key:
        st.warning("請輸入 Google API Key 以繼續。")

    st.divider()
    st.subheader("📂 上傳分析檔案 (建立 RAG 庫)")
    uploaded_file = st.file_uploader("上傳 Logs/PDF/Code", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf'])

    st.divider()
    st.subheader("🔍 RAG 檢索設定")
    # Minimum cosine similarity a chunk must reach to be retrieved.
    similarity_threshold = st.slider(
        "📐 Cosine Similarity 門檻",
        0.0, 1.0, 0.4, 0.01,
        help="數值越大越相似。一般建議 0.4~0.7"
    )

    st.divider()
    st.subheader("模型參數")
    system_prompt = st.text_area("System Prompt", value="You are a Senior Security Analyst. Use the retrieved context to answer the user's question. Every claim you make MUST be supported by a specific Event Record ID from the retrieved context.", height=100)
    # Gemini does not need max tokens to cap memory, only to cap output length.
    max_output_tokens = st.slider("Max Output Tokens", 128, 8192, 2048, 128)
    temperature = st.slider("Temperature", 0.0, 2.0, 0.1, 0.1)

    st.divider()
    if st.button("🗑️ 清除對話紀錄"):
        st.session_state.messages = []
        st.rerun()
126
 
127
+
128
+
129
# --- Initialize Gemini client ---
genai_model = None
if google_api_key:
    try:
        genai.configure(api_key=google_api_key)
        # NOTE(review): the upstream comment says "Flash", but the model actually
        # instantiated here is gemini-2.5-pro — confirm which is intended.
        genai_model = genai.GenerativeModel('gemini-2.5-pro')
    except Exception as e:
        st.error(f"Gemini 設定失敗: {e}")
146
 
147
+
148
+
149
# === Embedding model (keeps the original Jina / HF model for vectorization) ===
@st.cache_resource
def load_embedding_model():
    """Build and cache the Jina code-embedding model used to vectorize chunks."""
    return HuggingFaceEmbeddings(
        model_name="jinaai/jina-embeddings-v2-base-code",
        model_kwargs={
            'device': 'cpu',  # CPU is usually enough for embedding; switch to 'cuda' if a GPU is available
            'trust_remote_code': True,
        },
        # L2 normalization is applied explicitly later via faiss.normalize_L2,
        # so the encoder itself must not normalize.
        encode_kwargs={'normalize_embeddings': False},
    )


with st.spinner("正在載入 Embedding 模型..."):
    embedding_model = load_embedding_model()
186
 
187
+
188
+
189
# === Build the vector store (strict cosine) — logic unchanged ===
def process_file_to_faiss(uploaded_file):
    """Parse an uploaded file into chunks, embed them, and build a cosine FAISS store.

    Args:
        uploaded_file: a Streamlit UploadedFile (PDF or UTF-8 text).

    Returns:
        (vector_store, message) — vector_store is a langchain FAISS store on
        success, or None with an error message on failure.
    """
    text_content = ""
    try:
        if uploaded_file.type == "application/pdf":
            if pypdf:
                pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    # FIX: extract_text() can return None for image-only pages;
                    # "None + '\n'" would raise TypeError and abort the whole file.
                    text_content += (page.extract_text() or "") + "\n"
            else:
                return None, "PDF library missing"
        else:
            stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
            text_content = stringio.read()

        if not text_content.strip():
            return None, "File is empty"

        # Simple chunking: prefer </Event>-delimited records (Windows event log
        # XML style); fall back to one chunk per non-empty line.
        events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]
        if len(events) <= 1:
            events = [line for line in text_content.split("\n") if line.strip()]

        docs = [Document(page_content=e) for e in events]

        if not docs:
            return None, "No documents created"

        embeddings = embedding_model.embed_documents([d.page_content for d in docs])
        embeddings_np = np.array(embeddings).astype("float32")
        # Unit-normalize so inner product (IndexFlatIP) equals cosine similarity.
        faiss.normalize_L2(embeddings_np)

        dimension = embeddings_np.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings_np)

        doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
        index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}

        vector_store = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
            distance_strategy=DistanceStrategy.COSINE
        )

        return vector_store, f"{len(docs)} chunks created."

    except Exception as e:
        return None, f"Error: {str(e)}"
290
+
291
+
292
 
293
# === File handling: (re)build or tear down the knowledge base ===
if uploaded_file:
    # Key on name+size so re-uploading the same file does not rebuild the store.
    file_key = f"vs_{uploaded_file.name}_{uploaded_file.size}"

    if "current_file_key" not in st.session_state or st.session_state.current_file_key != file_key:
        with st.spinner("偵測到新檔案,正在更新知識庫..."):
            vs, msg = process_file_to_faiss(uploaded_file)
            if vs:
                st.session_state.vector_store = vs
                st.session_state.current_file_key = file_key
                st.toast(f"知識庫已更新!{msg}", icon="✅")
            else:
                st.error(msg)
else:
    # File removed: drop the store and fall back to plain-chat mode.
    if "vector_store" in st.session_state:
        del st.session_state.vector_store
        st.info("檔案已移除,已清除知識庫,回到一般模式。")
    if "current_file_key" in st.session_state:
        del st.session_state.current_file_key
330
+
331
+
332
 
333
# === Render chat history ===
if "messages" not in st.session_state:
    st.session_state.messages = []

for idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        # Assistant turns carry the retrieved context they were answered from.
        if message.get("context"):
            with st.expander(f"查看參考片段 (Turn {idx})"):
                st.code(message["context"])

            st.download_button(
                label="📥 下載此參考內容 (.txt)",
                data=message["context"],
                file_name=f"rag_context_{idx}.txt",
                mime="text/plain",
                key=f"dl_btn_{idx}"  # unique key per turn, required by Streamlit
            )
366
 
367
+
368
+
369
# === Search function ===
def faiss_cosine_search_all(vector_store, query, threshold):
    """Return every (doc, score) whose cosine similarity to *query* meets *threshold*.

    Searches the whole index (k = ntotal) and returns results sorted best-first.
    """
    query_vec = np.array([embedding_model.embed_query(query)]).astype("float32")
    faiss.normalize_L2(query_vec)  # unit vector so IP scores are cosine similarities

    index = vector_store.index
    scores, positions = index.search(query_vec, k=index.ntotal)

    selected = []
    for score, pos in zip(scores[0], positions[0]):
        if pos == -1:
            continue  # FAISS pads missing results with -1
        if score < threshold:
            continue
        doc_id = vector_store.index_to_docstore_id[pos]
        selected.append((vector_store.docstore.search(doc_id), score))

    selected.sort(key=lambda pair: pair[1], reverse=True)
    return selected
406
 
407
+
408
+
409
# === Generate the answer with Gemini ===
def generate_rag_response_gemini(prompt, history, sys_prompt, vector_store=None, threshold=0.5):
    """Retrieve context (when a vector store exists) and query Gemini.

    Args:
        prompt: the user's question.
        history: Streamlit-style message dicts ({"role", "content", ...}).
        sys_prompt: system instruction text, prepended to the user turn.
        vector_store: optional langchain FAISS store for retrieval.
        threshold: minimum cosine similarity for retrieved chunks.

    Returns:
        (answer_text, context_text); on API failure answer_text is an error string.
    """
    context_text = ""
    top_k_selected = []

    # 1. Retrieval
    if vector_store:
        selected = faiss_cosine_search_all(vector_store, prompt, threshold)
        if selected:
            top_k_selected = selected
            # Keep up to 30 chunks — Gemini's context window is large enough.
            retrieved_contents = [
                f"--- Chunk (sim={score:.3f}) ---\n{doc.page_content}"
                for i, (doc, score) in enumerate(top_k_selected[:30])
            ]
            context_text = "\n".join(retrieved_contents)

    # 2. Prompt construction
    if context_text:
        full_user_input = f"""
System Instruction: {sys_prompt}

=== RETRIEVED CONTEXT (Cosine ≥ {threshold}) ===
{context_text}
=== END CONTEXT ===

Question: {prompt}
Answer the question strictly based on the provided context.
"""
    else:
        full_user_input = f"""
System Instruction: {sys_prompt}

Question: {prompt}
"""

    # 3. Convert history (Streamlit -> Gemini)
    # Gemini format: [{'role': 'user', 'parts': [...]}, {'role': 'model', 'parts': [...]}]
    gemini_history = []
    for msg in history:
        role = "user" if msg["role"] == "user" else "model"
        # Only the plain conversation text is replayed; earlier retrieved
        # contexts are NOT re-sent, to avoid context-window blow-up/confusion.
        content_text = msg["content"]
        gemini_history.append({"role": role, "parts": [content_text]})

    # 4. Call Gemini
    try:
        chat = genai_model.start_chat(history=gemini_history)

        generation_config = genai.types.GenerationConfig(
            candidate_count=1,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
        )

        # Safety settings: BLOCK_NONE so security logs are not misclassified as harmful.
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]

        response = chat.send_message(
            full_user_input,
            generation_config=generation_config,
            safety_settings=safety_settings
        )

        # FIX: response.text raises ValueError when the prompt was blocked or no
        # candidate was returned — check explicitly and report the reason instead
        # of surfacing a cryptic exception message.
        if response.prompt_feedback.block_reason or not response.candidates:
            reason = response.prompt_feedback.block_reason.name if response.prompt_feedback.block_reason else "Unknown"
            return f"Gemini API Error: response blocked (reason: {reason})", context_text

        return response.text, context_text

    except Exception as e:
        return f"Gemini API Error: {str(e)}", context_text
580
+
581
+
582
 
583
# === Handle user input ===
if prompt := st.chat_input("請輸入問題..."):
    if not genai_model:
        st.error("請先輸入有效的 Google API Key")
    else:
        vs = st.session_state.get("vector_store", None)
        # Prefix the displayed prompt when RAG mode is active.
        display_prompt = f"🔍 **[RAG]** {prompt}" if vs else prompt

        st.chat_message("user").markdown(display_prompt)

        with st.chat_message("assistant"):
            msg_placeholder = st.empty()

            with st.spinner("Gemini Thinking..."):
                response, retrieved_ctx = generate_rag_response_gemini(
                    prompt,
                    st.session_state.messages,
                    system_prompt,
                    vector_store=vs,
                    threshold=similarity_threshold,
                )

            msg_placeholder.markdown(response)

            if retrieved_ctx:
                with st.expander("查看檢索到的參考片段"):
                    st.code(retrieved_ctx)
                    st.download_button(
                        label="📥 下載此參考內容 (.txt)",
                        data=retrieved_ctx,
                        file_name=f"rag_context_current.txt",
                        mime="text/plain"
                    )

        # Persist the turn (display prompt + answer with its context).
        st.session_state.messages.append({"role": "user", "content": display_prompt})
        st.session_state.messages.append({
            "role": "assistant",
            "content": response,
            "context": retrieved_ctx
        })