ss900371tw commited on
Commit
883a586
·
verified ·
1 Parent(s): d60d172

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +196 -319
src/streamlit_app.py CHANGED
@@ -1,29 +1,30 @@
 
 
1
  import streamlit as st
2
  import os
3
  import io
4
  import json
 
5
  import numpy as np
6
  import faiss
7
  import uuid
8
  import time
9
  import sys
 
10
  # === HuggingFace 模型相關套件 (替換為 InferenceClient) ===
11
  try:
12
  from huggingface_hub import InferenceClient
13
- # 移除本地模型相關導入
14
- # from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
15
- # import torch
16
  except ImportError:
17
  st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install huggingface-hub")
18
- # InferenceClient = None # 保留 InferenceClient
19
-
20
  # === LangChain/RAG 相關套件 (保持不變) ===
21
  from langchain_community.embeddings import HuggingFaceEmbeddings
22
  from langchain_core.documents import Document
23
  from langchain_community.vectorstores import FAISS
24
  from langchain_community.vectorstores.utils import DistanceStrategy
25
  from langchain_community.docstore.in_memory import InMemoryDocstore
26
- # 嘗試匯入 pypdf
 
27
  try:
28
  import pypdf
29
  except ImportError:
@@ -32,29 +33,22 @@ except ImportError:
32
  # --- 頁面設定 ---
33
  st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
34
  st.title("🛡️ Meta-Llama-3-8B-Instruct with FAISS RAG & Batch Analysis (Inference Client)")
35
- st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face Inference Client (API)**。上傳 JSON 執行批量分析,上傳其他檔案作為 RAG 參考庫。")
36
-
37
 
38
  if 'execute_batch_analysis' not in st.session_state:
39
  st.session_state.execute_batch_analysis = False
40
-
41
  if 'batch_results' not in st.session_state:
42
- st.session_state.batch_results = None # 確保初始化
43
-
44
  if 'rag_current_file_key' not in st.session_state:
45
  st.session_state.rag_current_file_key = None
46
-
47
- if 'json_current_file_key' not in st.session_state:
48
- st.session_state.json_current_file_key = None
49
-
50
- # 確保所有用於存儲狀態的變量都已初始化,例如:
51
  if 'vector_store' not in st.session_state:
52
  st.session_state.vector_store = None
53
-
54
- if 'json_data_for_batch' not in st.session_state:
55
  st.session_state.json_data_for_batch = None
56
-
57
- # 設定模型 ID (替換為您指定的模型)
58
  MODEL_ID = "eojin0312/llama2_security_231214"
59
  WINDOW_SIZE = 8
60
 
@@ -62,8 +56,6 @@ WINDOW_SIZE = 8
62
  with st.sidebar:
63
  st.header("⚙️ 設定")
64
 
65
- # === 替換為 Hugging Face 模型名稱顯示 (移除 API Key 輸入) ===
66
- # ⚠️ 注意: HF Token 必須在環境變數 HF_TOKEN 中設定
67
  if not os.environ.get("HF_TOKEN"):
68
  st.error("環境變數 **HF_TOKEN** 未設定。請設定後重新啟動應用程式。")
69
 
@@ -71,52 +63,49 @@ with st.sidebar:
71
  st.warning("⚠️ **注意**: 該模型使用 Inference API 呼叫,請確保您的 HF Token 具有存取權限。")
72
 
73
  st.divider()
74
-
75
  st.subheader("📂 檔案上傳")
76
- # === 1. JSON 批量分析檔案 (新的上傳器) ===
77
- json_uploaded_file = st.file_uploader(
78
- "1️⃣ 上傳 **JSON** Log/Alert 檔案 (用於批量分析)",
79
- type=['json'],
80
- key="json_uploader"
 
 
81
  )
82
- # === 2. RAG 知識庫檔案 (新的上傳器) ===
 
83
  rag_uploaded_file = st.file_uploader(
84
  "2️⃣ 上傳 **RAG 參考知識庫** (Logs/PDF/Code 等)",
85
  type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],
86
  key="rag_uploader"
87
  )
 
88
  st.divider()
89
 
90
- st.subheader("💡 批量分析指令 (針對 JSON 檔案)")
91
  analysis_prompt = st.text_area(
92
  "針對每個 Log/Alert 執行的指令",
93
  value="You are a security expert in charge of analyzing a single alert and prioritizing its criticality. Respond with a clear, structured analysis using the following mandatory sections: \n\n- Criticality/Priority: Is this alert critical? (Answer Yes/No only), and provide the overall priority level. (Answer High, Medium, or Low only) \n- Explanation: If this alert is critical or medium~high priority level, explain the potential impact and why this specific alert requires attention. If not, omit the explanation section. \n- Action Plan: If this alert is critical or medium~high priority level, What should be the immediate steps to address this specific alert? If not, omit the action plan section. \n\nStrictly use the information in the provided Log.",
94
  height=200
95
  )
96
- st.markdown("此指令將對 JSON 檔案中的**每一個 Log 條目**執行一次獨立分析。")
97
 
98
- if json_uploaded_file: # 移除 API Key 檢查
99
  if st.button("🚀 執行批量分析"):
100
  if not os.environ.get("HF_TOKEN"):
101
  st.error("無法執行,環境變數 **HF_TOKEN** 未設定。")
102
- else:
103
- st.session_state.execute_batch_analysis = True
104
  else:
105
- st.info("請上傳 JSON 檔案以啟用批量分析按鈕。")
106
 
107
  st.divider()
108
-
109
  st.subheader("🔍 RAG 檢索設定")
110
- similarity_threshold = st.slider(
111
- "📐 Cosine Similarity 門檻",
112
- 0.0, 1.0, 0.4, 0.01,
113
- help="數值越大越相似。一般建議 0.4~0.7"
114
- )
115
- st.divider()
116
 
 
117
  st.subheader("模型參數")
118
- # Llama 3 使用 'system' 角色
119
- system_prompt = st.text_area("System Prompt (LLM 使用)", value="You are a Senior Security Analyst, named Ernest. You provide expert, authoritative, and concise advice on Information Security, Network Security, and Cyber Threat Intelligence. Your analysis must be based strictly on the provided context.", height=100)
120
  max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
121
  temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
122
  top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
@@ -124,63 +113,42 @@ with st.sidebar:
124
  st.divider()
125
  if st.button("🗑️ 清除所有紀錄"):
126
  for key in list(st.session_state.keys()):
127
- if key not in []:
128
- del st.session_state[key]
129
  st.rerun()
130
 
131
- # --- 初始化 Hugging Face LLM Client (重大替換) ---
132
  @st.cache_resource
133
  def load_inference_client(model_id):
134
- if not os.environ.get("HF_TOKEN"):
135
- return None
136
-
137
  try:
138
- # 使用 InferenceClient 替換 AutoModelForCausalLM 的載入
139
- client = InferenceClient(
140
- model_id,
141
- token=os.environ.get("HF_TOKEN")
142
- )
143
  st.success(f"Hugging Face Inference Client **{model_id}** 載入成功。")
144
  return client
145
  except Exception as e:
146
  st.error(f"Hugging Face Inference Client 載入失敗: {e}")
147
  return None
148
 
149
- # 在 main 區塊外初始化 client
150
  inference_client = None
151
  if os.environ.get("HF_TOKEN"):
152
  with st.spinner(f"正在連線到 Inference Client: {MODEL_ID}..."):
153
  inference_client = load_inference_client(MODEL_ID)
154
-
155
  if inference_client is None and os.environ.get("HF_TOKEN"):
156
- st.warning("Hugging Face Inference Client 無法連線。請檢查您的 HF Token 和模型存取權限。")
157
- elif not os.environ.get("HF_TOKEN"):
158
- st.error("請在環境變數中設定 HF_TOKEN 以啟用 LLM。")
159
- # =======================================================================
160
 
161
- # === Embedding 模型 (用於 RAG 參考庫) (保持不變) ===
162
  @st.cache_resource
163
  def load_embedding_model():
164
- model_kwargs = {
165
- 'device': 'cpu',
166
- 'trust_remote_code': True
167
- }
168
- encode_kwargs = {
169
- 'normalize_embeddings': False
170
- }
171
- # 選擇一個適合 RAG 的中文 Embedding Model
172
- return HuggingFaceEmbeddings(
173
- model_name="BAAI/bge-large-zh-v1.5",
174
- model_kwargs=model_kwargs,
175
- encode_kwargs=encode_kwargs
176
- )
177
 
178
  with st.spinner("正在載入 Embedding 模型..."):
179
  embedding_model = load_embedding_model()
180
 
181
  # === 建立向量庫 / Search 函數 (保持不變) ===
182
  def process_file_to_faiss(uploaded_file):
183
- # 函數內容保持不變 (與原代碼相同)
184
  text_content = ""
185
  try:
186
  if uploaded_file.type == "application/pdf":
@@ -188,337 +156,246 @@ def process_file_to_faiss(uploaded_file):
188
  pdf_reader = pypdf.PdfReader(uploaded_file)
189
  for page in pdf_reader.pages:
190
  text_content += page.extract_text() + "\n"
191
- else:
192
- return None, "PDF library missing"
193
  else:
194
  stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
195
  text_content = stringio.read()
196
-
197
- if not text_content.strip():
198
- return None, "File is empty"
199
-
200
- # 嘗試以 </Event> 分割 Log,否則以換行符分割
201
  events = [line for line in text_content.splitlines() if line.strip()]
202
-
203
  docs = [Document(page_content=e) for e in events]
204
-
205
- if not docs:
206
- return None, "No documents created"
207
-
208
  embeddings = embedding_model.embed_documents([d.page_content for d in docs])
209
  embeddings_np = np.array(embeddings).astype("float32")
210
- faiss.normalize_L2(embeddings_np) # L2 正規化
211
-
212
  dimension = embeddings_np.shape[1]
213
- index = faiss.IndexFlatIP(dimension) # IndexFlatIP (內積)
214
  index.add(embeddings_np)
215
-
216
  doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
217
  docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
218
  index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
219
-
220
- vector_store = FAISS(
221
- embedding_function=embedding_model,
222
- index=index,
223
- docstore=docstore,
224
- index_to_docstore_id=index_to_docstore_id,
225
- distance_strategy=DistanceStrategy.COSINE # 使用 Cosine 距離 (對應 IndexFlatIP)
226
- )
227
-
228
  return vector_store, f"{len(docs)} chunks created."
229
  except Exception as e:
230
  return None, f"Error: {str(e)}"
231
 
232
  def faiss_cosine_search_all(vector_store, query, threshold):
233
- # 函數內容保持不變 (與原代碼相同)
234
  q_emb = embedding_model.embed_query(query)
235
  q_emb = np.array([q_emb]).astype("float32")
236
  faiss.normalize_L2(q_emb)
237
-
238
  index = vector_store.index
239
  D, I = index.search(q_emb, k=index.ntotal)
240
-
241
  selected = []
242
  for score, idx in zip(D[0], I[0]):
243
  if idx == -1: continue
244
- # IndexFlatIP 輸出內積,與歸一化後的 Cosine Similarity 相同
245
  if score >= threshold:
246
  doc_id = vector_store.index_to_docstore_id[idx]
247
  doc = vector_store.docstore.search(doc_id)
248
  selected.append((doc, score))
249
-
250
  selected.sort(key=lambda x: x[1], reverse=True)
251
  return selected
252
 
253
- # === Hugging Face 生成單一 Log 分析回答 (核心批量處理函數 - 重大替換為 InferenceClient) ===
254
  def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
255
- """
256
- 使用 Hugging Face Inference Client 執行 RAG 增強的 Log 序列分析。
257
- """
258
- if client is None:
259
- return "ERROR: Hugging Face Inference Client 未載入或 HF_TOKEN 未設定。", ""
260
-
261
  context_text = ""
262
- # 1. RAG 檢索邏輯 (保持不變)
263
  if vector_store:
264
  selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
265
  if selected:
266
- retrieved_contents = [
267
- f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}"
268
- for i, (doc, score) in enumerate(selected[:5]) # 限制檢索結果數量
269
- ]
270
  context_text = "\n".join(retrieved_contents)
271
-
272
- # 2. 建構 Llama 3 ChatML 格式的 Messages 列表
273
- rag_instruction = f"""=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ==={context_text if context_text else 'No relevant reference context found.'}=== END REFERENCE CONTEXT ===
274
-
275
- ANALYSIS INSTRUCTION: {user_prompt}
276
- Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats. Focus on the **last log entry in the sequence** to determine its final criticality and priority, considering the preceding {WINDOW_SIZE} logs."""
277
-
278
  log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===\n{log_sequence_text}\n=== END LOG SEQUENCE ==="""
279
 
280
- # 整合 System Prompt、RAG、和 Log 內容到 messages 列表
281
- # Llama 3 標準的 chat 格式
282
  messages = [
283
  {"role": "system", "content": sys_prompt},
284
  {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"}
285
  ]
286
-
287
- # 3. 呼叫 Hugging Face Inference Client
288
  try:
289
- # 使用 client.chat_completion 替換 pipeline 呼叫
290
- response_stream = client.chat_completion(
291
- messages,
292
- max_tokens=max_output_tokens,
293
- temperature=temperature,
294
- top_p=top_p,
295
- stream=False, # 由於是批量分析,不啟用流式輸出,一次性獲得結果
296
- )
297
-
298
- # 處理 chat_completion 的輸出格式 (非流式)
299
  if response_stream and response_stream.choices:
300
- # chat_completion 在非流式下返回一個 ChatCompletionResponse
301
- generated_text = response_stream.choices[0].message.content
302
- return generated_text.strip(), context_text
303
- else:
304
- return f"Hugging Face Inference Client 輸出格式錯誤: {response_stream}", context_text
305
-
306
- except Exception as e:
307
- # 如果模型呼叫失敗,回傳詳細錯誤訊息
308
- return f"Hugging Face Model Error: {str(e)}", context_text
309
-
310
- # === 檔案處理和主執行邏輯 (保持結構,替換 LLM 呼叫) ===
311
 
312
-
 
313
  if rag_uploaded_file:
314
  file_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
315
-
316
  if st.session_state.rag_current_file_key != file_key or 'vector_store' not in st.session_state:
317
- # 偵測到新 RAG 檔案,需要重新建立知識庫
318
  with st.spinner(f"正在建立 RAG 參考知識庫 ({rag_uploaded_file.name})..."):
319
  vs, msg = process_file_to_faiss(rag_uploaded_file)
320
  if vs:
321
  st.session_state.vector_store = vs
322
  st.session_state.rag_current_file_key = file_key
323
  st.toast(f"RAG 參考知識庫已更新!{msg}", icon="✅")
324
- else:
325
- st.error(msg)
326
- # 檔案移除/狀態清理 (如果使用者移除了 RAG 檔案)
327
  elif 'vector_store' in st.session_state:
328
  del st.session_state.vector_store
329
  del st.session_state.rag_current_file_key
330
  st.info("RAG 檔案已移除,已清除相關知識庫。")
331
 
332
-
333
-
334
- if json_uploaded_file:
335
- json_file_key = f"json_{json_uploaded_file.name}_{json_uploaded_file.size}"
336
 
337
- if st.session_state.json_current_file_key != json_file_key or 'json_data_for_batch' not in st.session_state:
338
  try:
339
- # 偵測到新 JSON 檔案
340
- json_data = json.load(io.StringIO(json_uploaded_file.getvalue().decode("utf-8")))
341
- st.session_state.json_data_for_batch = json_data
342
- st.session_state.json_current_file_key = json_file_key
343
- st.toast("JSON Log 檔案已載入,請按 '執行批量分析'。", icon="📄")
344
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  except Exception as e:
346
- st.error(f"JSON 檔案解析錯誤: {e}")
347
  if 'json_data_for_batch' in st.session_state:
348
  del st.session_state.json_data_for_batch
349
- # 檔案移除/狀態清理 (如果使用者移除了 JSON 檔案)
350
  elif 'json_data_for_batch' in st.session_state:
351
  del st.session_state.json_data_for_batch
352
- del st.session_state.json_current_file_key
353
  if "batch_results" in st.session_state:
354
  del st.session_state.batch_results
355
- st.info("JSON 檔案已移除,已清除日誌數據和分析結果。")
356
 
357
- # === 執行批量分析邏輯 (包含顏色控制) ===
358
  if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.session_state:
359
  st.session_state.execute_batch_analysis = False
360
- start_time = time.time() # 開始計時
361
  st.session_state.batch_results = []
362
 
363
  if inference_client is None:
364
- st.error("Hugging Face Inference Client 未載入,請檢查 HF_TOKEN 和網路連線,無法執行批量分析。")
365
- st.session_state.execute_batch_analysis = False
 
 
366
 
367
- data_to_process = st.session_state.json_data_for_batch
368
-
369
- # 提取 Log 列表的邏輯 (保持不變)
370
- logs_list = []
371
- if isinstance(data_to_process, list):
372
- logs_list = data_to_process
373
- elif isinstance(data_to_process, dict):
374
- if all(isinstance(v, (dict, str, list)) for v in data_to_process.values()):
375
- logs_list = list(data_to_process.values())
376
- elif 'alerts' in data_to_process and isinstance(data_to_process['alerts'], list):
377
- logs_list = data_to_process['alerts']
378
- elif 'logs' in data_to_process and isinstance(data_to_process['logs'], list):
379
- logs_list = data_to_process['logs']
380
  else:
381
  logs_list = [data_to_process]
382
- else:
383
- logs_list = [data_to_process]
384
-
385
- if logs_list:
386
- vs = st.session_state.get("vector_store", None)
387
- if vs:
388
- st.success("✅ RAG 知識庫已啟用並用於分析。")
389
- else:
390
- st.warning("⚠️ RAG 知識庫未載入,將單純執行 Log 分析。")
391
-
392
- # --- 新增:創建平移視窗序列 --- (保持不變)
393
- # 將所有 Log 轉換為 JSON 格式化字串列表,以便後續拼接
394
- formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
395
-
396
- # 創建要分析的序列 (Sliding Window) 列表
397
- analysis_sequences = []
398
-
399
- for i in range(len(formatted_logs)):
400
- start_index = max(0, i - WINDOW_SIZE + 1)
401
- end_index = i + 1 # 終點為當前 Log
402
-
403
- current_window = formatted_logs[start_index:end_index]
404
 
405
- sequence_text = []
406
- for j, log_str in enumerate(current_window):
407
- is_target = " <<< TARGET LOG TO ANALYZE" if j == len(current_window) - 1 else ""
408
- # 使用 i-len(current_window)+j+1 來計算原始索引
409
- sequence_text.append(f"--- Log Index {i - len(current_window) + j + 1} ({len(current_window)-j} prior logs){is_target} ---\n{log_str}")
410
 
411
- analysis_sequences.append({
412
- "sequence_text": "\n\n".join(sequence_text),
413
- "target_log_id": i + 1, # 該序列的分析目標是原始列表中的第 i+1 條 Log
414
- "original_log_entry": logs_list[i]
415
- })
416
 
417
- total_sequences = len(analysis_sequences)
418
- if total_sequences < WINDOW_SIZE:
419
- st.warning(f"Log 總數 ({total_sequences}) 少於視窗大小 ({WINDOW_SIZE}),分析的結果可能較不準確。")
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- # --- 執行序列分析 ---
422
- st.header(f"⚡ 批量分析執行中 (平移視窗 $N={WINDOW_SIZE}$)...")
423
- progress_bar = st.progress(0, text=f"準備處理 {total_sequences} 個序列...")
424
- results_container = st.container()
425
- full_report_chunks = ["## Cybersecurity Batch Analysis Report\n\n"]
426
-
427
- priority_keyword = "Criticality/Priority:"
428
-
429
- for i, seq_data in enumerate(analysis_sequences):
430
- log_id = seq_data["target_log_id"]
431
- progress_bar.progress((i + 1) / total_sequences, text=f"已處理 {i + 1}/{total_sequences} 個序列 (目標 Log #{log_id})...")
432
 
433
- try:
434
- # *** 替換為 Inference Client 呼叫函數 ***
435
- response, retrieved_ctx = generate_rag_response_hf_for_log(
436
- client=inference_client, # <--- 新的 Inference Client
437
- model_id=MODEL_ID,
438
- log_sequence_text=seq_data["sequence_text"],
439
- user_prompt=analysis_prompt,
440
- sys_prompt=system_prompt,
441
- vector_store=vs,
442
- threshold=similarity_threshold,
443
- max_output_tokens=max_output_tokens,
444
- temperature=temperature,
445
- top_p=top_p
446
- )
447
-
448
- # 儲存結果
449
- item = {
450
- "log_id": log_id,
451
- "log_content": seq_data["original_log_entry"], # 記錄原始 Log 條目
452
- "sequence_analyzed": seq_data["sequence_text"], # 記錄分析的序列
453
- "analysis_result": response,
454
- "context": retrieved_ctx
455
- }
456
- st.session_state.batch_results.append(item)
457
 
458
- # 結果顯示邏輯 (保持不變)
459
- with results_container:
460
- st.subheader(f"Log/Alert #{item['log_id']} (序列分析完成)")
461
- with st.expander(f"序列內容 (包含 {len(seq_data['sequence_text'].split('--- Log Index'))-1} 條 Log)"):
462
- st.code(item["sequence_analyzed"], language='text')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
 
464
- # 顏色控制:
465
- is_high_priority = False
466
- if 'criticality/priority:' in response.lower():
467
- try:
468
- priority_section = response.split('Criticality/Priority:')[1].split('\n')[0].strip()
469
- if 'high' in priority_section.lower() or 'medium' in priority_section.lower() or 'yes' in priority_section.lower():
470
- is_high_priority = True
471
- except IndexError:
472
- pass
473
-
474
- st.markdown(f"### 🤖 分析結果 (針對 Log #{log_id})")
475
- if is_high_priority:
476
- st.error(item['analysis_result'])
477
- else:
478
- st.info(item['analysis_result'])
479
 
480
- if item['context']:
481
- with st.expander("參考的 RAG 知識庫片段"):
482
- st.code(item['context'])
483
- st.markdown("---")
484
-
485
- # 報告 chunks
486
- log_content_str_for_report = json.dumps(item["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
487
- full_report_chunks.append(f"---\n\n### Log/Alert #{item['log_id']} (序列分析)\n\n#### 分析的序列內容\n```\n{seq_data['sequence_text']}\n```\n\n#### LLM 分析結果\n{item['analysis_result']}\n")
488
-
489
- except Exception as e:
490
- error_message = f"ERROR: Log {log_id} 序列處理失敗: {e}"
491
- st.session_state.batch_results.append({
492
- "log_id": log_id,
493
- "log_content": seq_data["original_log_entry"],
494
- "sequence_analyzed": seq_data["sequence_text"],
495
- "analysis_result": error_message,
496
- "context": ""
497
- })
498
- with results_container:
499
- st.error(error_message)
500
-
501
- end_time = time.time()
502
- progress_bar.empty()
503
- st.success(f"批量分析完成!共處理 {total_sequences} 個 Log 序列,耗時 {end_time - start_time:.2f} 秒。")
504
- st.divider()
505
-
506
- else:
507
- st.error("無法從上傳的 JSON 檔案中提取 Log 列表或有效的 Log 條目。請檢查檔案結構。")
508
 
509
- # === 顯示結果 (歷史紀錄) (保持不變) ===
510
  if st.session_state.get("batch_results") and not st.session_state.execute_batch_analysis:
511
- st.header("⚡ 上次分析結果 (歷史紀錄)")
512
-
513
- full_report_chunks = ["## Cybersecurity Batch Analysis Report\n\n"]
514
  for item in st.session_state.batch_results:
515
  log_content_str_for_report = json.dumps(item["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
516
- full_report_chunks.append(f"---\n\n### Log/Alert #{item['log_id']}\n\n#### 原始內容\n```json\n{log_content_str_for_report}\n```\n\n#### LLM 分析結果\n{item['analysis_result']}\n")
517
-
518
- st.info(f"偵測到 {len(st.session_state.batch_results)} 條 Log 的歷史分析結果。")
519
- st.download_button(
520
- label="📥 下載上次的完整報告 (.md)",
521
- data="\n".join(full_report_chunks),
522
- file_name="security_batch_analysis_report_history.md",
523
- mime="text/markdown"
524
- )
 
1
+
2
+
3
  import streamlit as st
4
  import os
5
  import io
6
  import json
7
+ import csv # <--- 新增:用於處理 CSV
8
  import numpy as np
9
  import faiss
10
  import uuid
11
  import time
12
  import sys
13
+
14
  # === HuggingFace 模型相關套件 (替換為 InferenceClient) ===
15
  try:
16
  from huggingface_hub import InferenceClient
 
 
 
17
  except ImportError:
18
  st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install huggingface-hub")
19
+
 
20
  # === LangChain/RAG 相關套件 (保持不變) ===
21
  from langchain_community.embeddings import HuggingFaceEmbeddings
22
  from langchain_core.documents import Document
23
  from langchain_community.vectorstores import FAISS
24
  from langchain_community.vectorstores.utils import DistanceStrategy
25
  from langchain_community.docstore.in_memory import InMemoryDocstore
26
+
27
+ # 嘗試匯入 pypdftry
28
  try:
29
  import pypdf
30
  except ImportError:
 
33
  # --- 頁面設定 ---
34
  st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
35
  st.title("🛡️ Meta-Llama-3-8B-Instruct with FAISS RAG & Batch Analysis (Inference Client)")
36
+ st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face Inference Client (API)**。支援 JSON/CSV/TXT 執行批量分析。")
 
37
 
38
  if 'execute_batch_analysis' not in st.session_state:
39
  st.session_state.execute_batch_analysis = False
 
40
  if 'batch_results' not in st.session_state:
41
+ st.session_state.batch_results = None
 
42
  if 'rag_current_file_key' not in st.session_state:
43
  st.session_state.rag_current_file_key = None
44
+ if 'batch_current_file_key' not in st.session_state: # 修改變數名稱以反映多格式
45
+ st.session_state.batch_current_file_key = None
 
 
 
46
  if 'vector_store' not in st.session_state:
47
  st.session_state.vector_store = None
48
+ if 'json_data_for_batch' not in st.session_state: # 變數名稱保留,但內容可能是轉換後的 dict
 
49
  st.session_state.json_data_for_batch = None
50
+
51
+ # 設定模型 ID
52
  MODEL_ID = "eojin0312/llama2_security_231214"
53
  WINDOW_SIZE = 8
54
 
 
56
  with st.sidebar:
57
  st.header("⚙️ 設定")
58
 
 
 
59
  if not os.environ.get("HF_TOKEN"):
60
  st.error("環境變數 **HF_TOKEN** 未設定。請設定後重新啟動應用程式。")
61
 
 
63
  st.warning("⚠️ **注意**: 該模型使用 Inference API 呼叫,請確保您的 HF Token 具有存取權限。")
64
 
65
  st.divider()
 
66
  st.subheader("📂 檔案上傳")
67
+
68
+ # === 1. 批量分析檔案 (修改處:支援多種格式) ===
69
+ batch_uploaded_file = st.file_uploader(
70
+ "1️⃣ 上傳 **Log/Alert 檔案** (用於批量分析)",
71
+ type=['json', 'csv', 'txt'], # <--- 修改:新增 csv 和 txt
72
+ key="batch_uploader",
73
+ help="支援 JSON (Array), CSV (含標題), TXT (每行一條 Log)"
74
  )
75
+
76
+ # === 2. RAG 知識庫檔案 ===
77
  rag_uploaded_file = st.file_uploader(
78
  "2️⃣ 上傳 **RAG 參考知識庫** (Logs/PDF/Code 等)",
79
  type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],
80
  key="rag_uploader"
81
  )
82
+
83
  st.divider()
84
 
85
+ st.subheader("💡 批量分析指令")
86
  analysis_prompt = st.text_area(
87
  "針對每個 Log/Alert 執行的指令",
88
  value="You are a security expert in charge of analyzing a single alert and prioritizing its criticality. Respond with a clear, structured analysis using the following mandatory sections: \n\n- Criticality/Priority: Is this alert critical? (Answer Yes/No only), and provide the overall priority level. (Answer High, Medium, or Low only) \n- Explanation: If this alert is critical or medium~high priority level, explain the potential impact and why this specific alert requires attention. If not, omit the explanation section. \n- Action Plan: If this alert is critical or medium~high priority level, What should be the immediate steps to address this specific alert? If not, omit the action plan section. \n\nStrictly use the information in the provided Log.",
89
  height=200
90
  )
91
+ st.markdown("此指令將對檔案中的**每一個 Log 條目**執行一次獨立分析。")
92
 
93
+ if batch_uploaded_file:
94
  if st.button("🚀 執行批量分析"):
95
  if not os.environ.get("HF_TOKEN"):
96
  st.error("無法執行,環境變數 **HF_TOKEN** 未設定。")
97
+ else:
98
+ st.session_state.execute_batch_analysis = True
99
  else:
100
+ st.info("請上傳 Log 檔案以啟用批量分析按鈕。")
101
 
102
  st.divider()
 
103
  st.subheader("🔍 RAG 檢索設定")
104
+ similarity_threshold = st.slider("📐 Cosine Similarity 門檻", 0.0, 1.0, 0.4, 0.01)
 
 
 
 
 
105
 
106
+ st.divider()
107
  st.subheader("模型參數")
108
+ system_prompt = st.text_area("System Prompt", value="You are a Senior Security Analyst, named Ernest. You provide expert, authoritative, and concise advice on Information Security. Your analysis must be based strictly on the provided context.", height=100)
 
109
  max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
110
  temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
111
  top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
 
113
  st.divider()
114
  if st.button("🗑️ 清除所有紀錄"):
115
  for key in list(st.session_state.keys()):
116
+ del st.session_state[key]
 
117
  st.rerun()
118
 
119
+ # --- 初始化 Hugging Face LLM Client ---
120
  @st.cache_resource
121
  def load_inference_client(model_id):
122
+ if not os.environ.get("HF_TOKEN"): return None
 
 
123
  try:
124
+ client = InferenceClient(model_id, token=os.environ.get("HF_TOKEN"))
 
 
 
 
125
  st.success(f"Hugging Face Inference Client **{model_id}** 載入成功。")
126
  return client
127
  except Exception as e:
128
  st.error(f"Hugging Face Inference Client 載入失敗: {e}")
129
  return None
130
 
 
131
  inference_client = None
132
  if os.environ.get("HF_TOKEN"):
133
  with st.spinner(f"正在連線到 Inference Client: {MODEL_ID}..."):
134
  inference_client = load_inference_client(MODEL_ID)
 
135
  if inference_client is None and os.environ.get("HF_TOKEN"):
136
+ st.warning("Hugging Face Inference Client 無法連線。")
137
+ elif not os.environ.get("HF_TOKEN"):
138
+ st.error("請在環境變數中設定 HF_TOKEN。")
 
139
 
140
+ # === Embedding 模型 (保持不變) ===
141
  @st.cache_resource
142
  def load_embedding_model():
143
+ model_kwargs = {'device': 'cpu', 'trust_remote_code': True}
144
+ encode_kwargs = {'normalize_embeddings': False}
145
+ return HuggingFaceEmbeddings(model_name="BAAI/bge-large-zh-v1.5", model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
 
 
 
 
 
 
 
 
 
 
146
 
147
  with st.spinner("正在載入 Embedding 模型..."):
148
  embedding_model = load_embedding_model()
149
 
150
  # === 建立向量庫 / Search 函數 (保持不變) ===
151
  def process_file_to_faiss(uploaded_file):
 
152
  text_content = ""
153
  try:
154
  if uploaded_file.type == "application/pdf":
 
156
  pdf_reader = pypdf.PdfReader(uploaded_file)
157
  for page in pdf_reader.pages:
158
  text_content += page.extract_text() + "\n"
159
+ else: return None, "PDF library missing"
 
160
  else:
161
  stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
162
  text_content = stringio.read()
163
+
164
+ if not text_content.strip(): return None, "File is empty"
165
+
 
 
166
  events = [line for line in text_content.splitlines() if line.strip()]
 
167
  docs = [Document(page_content=e) for e in events]
168
+ if not docs: return None, "No documents created"
169
+
 
 
170
  embeddings = embedding_model.embed_documents([d.page_content for d in docs])
171
  embeddings_np = np.array(embeddings).astype("float32")
172
+ faiss.normalize_L2(embeddings_np)
173
+
174
  dimension = embeddings_np.shape[1]
175
+ index = faiss.IndexFlatIP(dimension)
176
  index.add(embeddings_np)
177
+
178
  doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
179
  docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
180
  index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
181
+
182
+ vector_store = FAISS(embedding_function=embedding_model, index=index, docstore=docstore, index_to_docstore_id=index_to_docstore_id, distance_strategy=DistanceStrategy.COSINE)
 
 
 
 
 
 
 
183
  return vector_store, f"{len(docs)} chunks created."
184
  except Exception as e:
185
  return None, f"Error: {str(e)}"
186
 
187
  def faiss_cosine_search_all(vector_store, query, threshold):
 
188
  q_emb = embedding_model.embed_query(query)
189
  q_emb = np.array([q_emb]).astype("float32")
190
  faiss.normalize_L2(q_emb)
 
191
  index = vector_store.index
192
  D, I = index.search(q_emb, k=index.ntotal)
 
193
  selected = []
194
  for score, idx in zip(D[0], I[0]):
195
  if idx == -1: continue
 
196
  if score >= threshold:
197
  doc_id = vector_store.index_to_docstore_id[idx]
198
  doc = vector_store.docstore.search(doc_id)
199
  selected.append((doc, score))
 
200
  selected.sort(key=lambda x: x[1], reverse=True)
201
  return selected
202
 
203
+ # === Hugging Face 生成單一 Log 分析回答 (保持不變) ===
204
  def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
205
+ if client is None: return "ERROR: Client Error", ""
 
 
 
 
 
206
  context_text = ""
 
207
  if vector_store:
208
  selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
209
  if selected:
210
+ retrieved_contents = [f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}" for i, (doc, score) in enumerate(selected[:5])]
 
 
 
211
  context_text = "\n".join(retrieved_contents)
212
+
213
+ rag_instruction = f"""=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ==={context_text if context_text else 'No relevant reference context found.'}=== END REFERENCE CONTEXT ===\nANALYSIS INSTRUCTION: {user_prompt}\nBased on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats."""
 
 
 
 
 
214
  log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===\n{log_sequence_text}\n=== END LOG SEQUENCE ==="""
215
 
 
 
216
  messages = [
217
  {"role": "system", "content": sys_prompt},
218
  {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"}
219
  ]
 
 
220
  try:
221
+ response_stream = client.chat_completion(messages, max_tokens=max_output_tokens, temperature=temperature, top_p=top_p, stream=False)
 
 
 
 
 
 
 
 
 
222
  if response_stream and response_stream.choices:
223
+ return response_stream.choices[0].message.content.strip(), context_text
224
+ else: return "Format Error", context_text
225
+ except Exception as e: return f"Model Error: {str(e)}", context_text
 
 
 
 
 
 
 
 
226
 
227
# =======================================================================
# === File handling: RAG reference document upload ===
if rag_uploaded_file:
    current_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
    # Rebuild the FAISS store only when the uploaded file changed or no
    # store exists yet (Streamlit reruns this script on every interaction).
    needs_rebuild = (
        st.session_state.rag_current_file_key != current_key
        or 'vector_store' not in st.session_state
    )
    if needs_rebuild:
        with st.spinner(f"正在建立 RAG 參考知識庫 ({rag_uploaded_file.name})..."):
            store, status_msg = process_file_to_faiss(rag_uploaded_file)
            if store:
                st.session_state.vector_store = store
                st.session_state.rag_current_file_key = current_key
                st.toast(f"RAG 參考知識庫已更新!{status_msg}", icon="✅")
            else:
                st.error(status_msg)
elif 'vector_store' in st.session_state:
    # Uploader is now empty: drop the stale index and its file key.
    del st.session_state.vector_store
    del st.session_state.rag_current_file_key
    st.info("RAG 檔案已移除,已清除相關知識庫。")
243
 
244
# === File handling: batch analysis upload ===
# Supports JSON, CSV and TXT; all formats are normalized into a
# JSON-compatible structure (list of dicts) for the analysis loop.
import csv  # NOTE(review): csv.DictReader is used below but `import csv` is not visible among the file's top imports — local import keeps this block self-contained; confirm against the full file.

if batch_uploaded_file:
    batch_file_key = f"batch_{batch_uploaded_file.name}_{batch_uploaded_file.size}"

    # Use .get(): plain attribute access raises if the key was never
    # initialized in session_state (only `json_current_file_key` is
    # visibly initialized at the top of the file).
    if st.session_state.get("batch_current_file_key") != batch_file_key or 'json_data_for_batch' not in st.session_state:
        try:
            stringio = io.StringIO(batch_uploaded_file.getvalue().decode("utf-8"))
            lower_name = batch_uploaded_file.name.lower()

            # --- Case 1: JSON (kept as parsed: list or dict) ---
            if lower_name.endswith('.json'):
                parsed_data = json.load(stringio)
                st.toast("JSON 檔案載入成功", icon="📄")

            # --- Case 2: CSV (each row becomes a dict keyed by the header) ---
            elif lower_name.endswith('.csv'):
                parsed_data = list(csv.DictReader(stringio))
                st.toast("CSV 檔案已轉換為 JSON 結構", icon="📊")

            # --- Case 3: anything else is treated as TXT, one entry per non-blank line ---
            else:
                parsed_data = [
                    {"raw_log_entry": line.strip()}
                    for line in stringio
                    if line.strip()
                ]
                st.toast("TXT 檔案已轉換為 JSON 結構", icon="📝")

            st.session_state.json_data_for_batch = parsed_data
            st.session_state.batch_current_file_key = batch_file_key
            # A new batch file invalidates results computed from the previous
            # one — otherwise the history section would show stale analyses.
            st.session_state.pop("batch_results", None)

        except Exception as e:
            st.error(f"檔案解析錯誤: {e}")
            st.session_state.pop("json_data_for_batch", None)

elif 'json_data_for_batch' in st.session_state:
    # Uploader is now empty: clear parsed data and any derived results.
    del st.session_state.json_data_for_batch
    st.session_state.pop("batch_current_file_key", None)
    st.session_state.pop("batch_results", None)
    st.info("批量分析檔案已移除,已清除相關數據。")
288
 
289
# === Batch analysis execution ===

def _coerce_logs_list(data):
    """Normalize parsed upload data into a flat list of log entries.

    Accepts a list (used as-is), a dict wrapping the logs under a common
    key ('alerts' or 'logs'), a bare dict (treated as a single entry),
    or any other value (wrapped in a one-element list).
    """
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        for key in ('alerts', 'logs'):
            if isinstance(data.get(key), list):
                return data[key]
        return [data]
    return [data]


def _build_analysis_sequences(logs_list, window_size):
    """Build sliding-window analysis payloads, one per log entry.

    Each payload carries the target log plus up to ``window_size - 1``
    preceding logs, each serialized as pretty-printed JSON, with the last
    entry in the window marked as the analysis target.
    """
    # Serializing here guarantees the prompt always receives JSON text,
    # regardless of whether the source was JSON, CSV or TXT.
    formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
    sequences = []
    for i in range(len(formatted_logs)):
        start_index = max(0, i - window_size + 1)
        current_window = formatted_logs[start_index:i + 1]
        sequence_text = []
        for j, log_str in enumerate(current_window):
            is_target = " <<< TARGET LOG TO ANALYZE" if j == len(current_window) - 1 else ""
            sequence_text.append(
                f"--- Log Index {i - len(current_window) + j + 1} "
                f"({len(current_window)-j} prior logs){is_target} ---\n{log_str}"
            )
        sequences.append({
            "sequence_text": "\n\n".join(sequence_text),
            "target_log_id": i + 1,
            "original_log_entry": logs_list[i],
        })
    return sequences


if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.session_state:
    # One-shot trigger: clear the flag immediately so a rerun does not
    # restart the batch.
    st.session_state.execute_batch_analysis = False

    if inference_client is None:
        # Fix: do NOT clobber previous results when the client is down.
        st.error("Client 未連線,無法執行。")
    else:
        logs_list = _coerce_logs_list(st.session_state.json_data_for_batch)

        if not logs_list:
            st.error("無法提取有效 Log,請檢查檔案格式。")
        else:
            start_time = time.time()
            # Reset results only once we are actually going to produce new ones.
            st.session_state.batch_results = []
            vs = st.session_state.get("vector_store", None)
            analysis_sequences = _build_analysis_sequences(logs_list, WINDOW_SIZE)

            total_sequences = len(analysis_sequences)
            st.header(f"⚡ 批量分析執行中 (平移視窗 $N={WINDOW_SIZE}$)...")
            progress_bar = st.progress(0, text=f"準備處理 {total_sequences} 個序列...")
            results_container = st.container()
            # NOTE: the downloadable Markdown report is rebuilt from
            # st.session_state.batch_results by the history section on the
            # same script pass, so it is intentionally not accumulated here.

            for i, seq_data in enumerate(analysis_sequences):
                log_id = seq_data["target_log_id"]
                progress_bar.progress(
                    (i + 1) / total_sequences,
                    text=f"Processing {i + 1}/{total_sequences} (Log #{log_id})...",
                )

                try:
                    response, retrieved_ctx = generate_rag_response_hf_for_log(
                        client=inference_client,
                        model_id=MODEL_ID,
                        log_sequence_text=seq_data["sequence_text"],
                        user_prompt=analysis_prompt,
                        sys_prompt=system_prompt,
                        vector_store=vs,
                        threshold=similarity_threshold,
                        max_output_tokens=max_output_tokens,
                        temperature=temperature,
                        top_p=top_p,
                    )
                    item = {
                        "log_id": log_id,
                        "log_content": seq_data["original_log_entry"],
                        "sequence_analyzed": seq_data["sequence_text"],
                        "analysis_result": response,
                        "context": retrieved_ctx,
                    }
                    st.session_state.batch_results.append(item)

                    with results_container:
                        st.subheader(f"Log/Alert #{item['log_id']}")
                        with st.expander("序列內容 (JSON Format)"):
                            st.code(item["sequence_analyzed"], language='json')

                        # Heuristic severity flag parsed from the model's
                        # free-text verdict; only honored when the response
                        # actually carries a criticality/priority field.
                        lowered = response.lower()
                        is_high = (
                            any(x in lowered for x in ('high', 'critical', 'yes'))
                            and 'criticality/priority:' in lowered
                        )
                        if is_high:
                            st.error(item['analysis_result'])
                        else:
                            st.info(item['analysis_result'])
                        if item['context']:
                            with st.expander("參考 RAG 片段"):
                                st.code(item['context'])
                        st.markdown("---")

                except Exception as e:
                    # Keep going: one failed sequence must not abort the batch.
                    st.error(f"Error Log {log_id}: {e}")

            progress_bar.empty()
            st.success(f"完成!耗時 {time.time() - start_time:.2f} 秒。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
# === Results display (history) ===
# Render the downloadable Markdown report whenever results exist and no
# batch run is currently pending.
if st.session_state.get("batch_results") and not st.session_state.execute_batch_analysis:
    st.header("⚡ 歷史分析結果")
    report_parts = ["## Report\n\n"]
    for entry in st.session_state.batch_results:
        # Escape backticks so the embedded JSON cannot break the fenced block.
        escaped_log = json.dumps(entry["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
        report_parts.append(
            f"---\n\n### Log #{entry['log_id']}\n```json\n{escaped_log}\n```\n{entry['analysis_result']}\n"
        )
    st.download_button("📥 下載完整報告 (.md)", "\n".join(report_parts), "report.md", "text/markdown")