import streamlit as st
import os
import io
import json
import numpy as np
import faiss
import uuid
import time
import sys
# === Hugging Face model dependencies ===
try:
    # Import lazily so a missing dependency does not hard-crash GPU-less environments
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    import torch
    # Recommended extras for large local models:
    # from accelerate import Accelerator
    # import bitsandbytes
except ImportError:
    # Do not call st.error() here: st.set_page_config() has not run yet and must be the
    # first Streamlit command; load_huggingface_llm() reports the missing dependencies later.
    AutoModelForCausalLM, AutoTokenizer, pipeline, torch = None, None, None, None
# === LangChain/RAG dependencies (unchanged) ===
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.docstore.in_memory import InMemoryDocstore
# Optional pypdf import
try:
import pypdf
except ImportError:
pypdf = None
# --- Page setup ---
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
# Model ID (any Hugging Face instruct model; the app was originally built around
# fdtn-ai/Foundation-Sec-1.1-8B-Instruct)
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
WINDOW_SIZE = 8
st.title(f"🛡️ {MODEL_ID} with FAISS RAG & Batch Analysis")
st.markdown("Enabled: **IndexFlatIP** + **L2 normalization** + **Hugging Face LLM**. Upload a JSON file for batch analysis; upload other files to build the RAG reference library.")
# --- Sidebar settings ---
with st.sidebar:
    st.header("⚙️ Settings")
    # === Show the Hugging Face model name (API-key input removed) ===
    st.info(f"LLM model: **{MODEL_ID}** (Hugging Face Model)")
    st.warning("⚠️ **Note**: 7B/8B-class models need substantial RAM/VRAM and compute. Runs may be slow or fail.")
    st.divider()
    st.subheader("📂 File Upload")
    # === 1. JSON file for batch analysis ===
    json_uploaded_file = st.file_uploader(
        "1️⃣ Upload a **JSON** log/alert file (for batch analysis)",
        type=['json'],
        key="json_uploader"
    )
    # === 2. RAG knowledge-base file ===
    rag_uploaded_file = st.file_uploader(
        "2️⃣ Upload a **RAG reference knowledge base** (logs/PDF/code, etc.)",
        type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],
        key="rag_uploader"
    )
    st.divider()
    st.subheader("💡 Batch-Analysis Instruction (for the JSON file)")
    analysis_prompt = st.text_area(
        "Instruction applied to each log/alert",
        value="You are a security expert in charge of analyzing a single alert and prioritizing its criticality. Respond with a clear, structured analysis using the following mandatory sections: \n\n- Criticality/Priority: Is this alert critical? (Answer Yes/No only), and provide the overall priority level. (Answer High, Medium, or Low only) \n- Explanation: If this alert is critical or has a medium-to-high priority level, explain the potential impact and why this specific alert requires attention. If not, omit the explanation section. \n- Action Plan: If this alert is critical or has a medium-to-high priority level, what are the immediate steps to address this specific alert? If not, omit the action plan section. \n\nStrictly use the information in the provided Log.",
        height=200
    )
st.markdown("此指令將對 JSON 檔案中的**每一個 Log 條目**執行一次獨立分析。")
if json_uploaded_file: # 移除 API Key 檢查
if st.button("🚀 執行批量分析"):
st.session_state.execute_batch_analysis = True
else:
st.info("請上傳 JSON 檔案以啟用批量分析按鈕。")
    st.divider()
    st.subheader("🔍 RAG Retrieval Settings")
    similarity_threshold = st.slider(
        "📐 Cosine similarity threshold",
        0.0, 1.0, 0.4, 0.01,
        help="Higher values mean more similar. 0.4–0.7 is a reasonable range."
    )
    st.divider()
    st.subheader("Model Parameters")
    system_prompt = st.text_area("System Prompt (used by the LLM)", value="You are a Senior Security Analyst. Be professional.", height=100)
    max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
    temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
    top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
    st.divider()
if st.button("🗑️ 清除所有紀錄"):
for key in list(st.session_state.keys()):
if key not in []:
del st.session_state[key]
st.rerun()
# --- Initialize the Hugging Face LLM client (replaces the API-based client) ---
@st.cache_resource
def load_huggingface_llm(model_id):
    if AutoModelForCausalLM is None:
        st.error("Hugging Face dependencies are missing. Install them with: pip install transformers torch accelerate bitsandbytes")
        return None
    try:
        # 4-bit quantization is the usual way to fit an 8B-class model in limited memory
        # (see the sketch below this function)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
            device_map="auto",  # let accelerate place the model across available devices
            trust_remote_code=True,
        )
        # Wrap in a pipeline to simplify generation calls
        llm_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            # No device= argument here: device_map="auto" already placed the model,
            # and passing both raises an error.
        )
        st.success(f"Hugging Face model **{model_id}** loaded successfully.")
        return llm_pipeline
    except Exception as e:
        st.error(f"Failed to load the Hugging Face model: {e}")
        return None
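# A minimal 4-bit quantization sketch, assuming bitsandbytes and a CUDA GPU are
# available (illustrative only; load_huggingface_llm above runs without it):
#
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,                      # store weights as 4-bit NF4
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16 for quality
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True
#     )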
# Initialize the pipeline outside the main block
llm_pipeline = None
if AutoModelForCausalLM is not None:
    with st.spinner(f"Loading LLM model: {MODEL_ID}... (this may take several minutes)"):
        llm_pipeline = load_huggingface_llm(MODEL_ID)
if llm_pipeline is None:
    st.warning("The Hugging Face LLM could not be loaded. Check dependencies and environment resources.")
# =======================================================================
# === Embedding model (for the RAG reference library) (unchanged) ===
@st.cache_resource
def load_embedding_model():
model_kwargs = {
'device': 'cpu',
'trust_remote_code': True
}
    encode_kwargs = {
        'normalize_embeddings': False  # we L2-normalize manually with faiss.normalize_L2 below
    }
    # A Chinese embedding model well suited to RAG
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-zh-v1.5",
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )
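# Note: for an English-only reference corpus, "BAAI/bge-large-en-v1.5" should be a
# drop-in alternative (same 1024-dim output); the rest of the pipeline is model-agnostic.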
with st.spinner("正在載入 Embedding 模型..."):
embedding_model = load_embedding_model()
# === 建立向量庫 / Search 函數 (保持不變) ===
def process_file_to_faiss(uploaded_file):
text_content = ""
try:
if uploaded_file.type == "application/pdf":
if pypdf:
pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    # extract_text() may return None on image-only pages
                    text_content += (page.extract_text() or "") + "\n"
else:
return None, "PDF library missing"
else:
stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
text_content = stringio.read()
if not text_content.strip():
return None, "File is empty"
        # Try splitting logs on </Event>; otherwise fall back to newline splitting
events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]
if len(events) <= 1:
events = [line for line in text_content.split("\n") if line.strip()]
docs = [Document(page_content=e) for e in events]
if not docs:
return None, "No documents created"
embeddings = embedding_model.embed_documents([d.page_content for d in docs])
embeddings_np = np.array(embeddings).astype("float32")
        faiss.normalize_L2(embeddings_np)  # L2-normalize so inner product equals cosine similarity
        dimension = embeddings_np.shape[1]
        index = faiss.IndexFlatIP(dimension)  # exact inner-product index
index.add(embeddings_np)
doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
vector_store = FAISS(
embedding_function=embedding_model,
index=index,
docstore=docstore,
index_to_docstore_id=index_to_docstore_id,
            distance_strategy=DistanceStrategy.COSINE  # cosine strategy (matches IndexFlatIP on normalized vectors)
)
return vector_store, f"{len(docs)} chunks created."
except Exception as e:
return None, f"Error: {str(e)}"
def faiss_cosine_search_all(vector_store, query, threshold):
q_emb = embedding_model.embed_query(query)
q_emb = np.array([q_emb]).astype("float32")
faiss.normalize_L2(q_emb)
index = vector_store.index
D, I = index.search(q_emb, k=index.ntotal)
selected = []
for score, idx in zip(D[0], I[0]):
if idx == -1: continue
        # IndexFlatIP returns inner products, which equal cosine similarity after L2 normalization
if score >= threshold:
doc_id = vector_store.index_to_docstore_id[idx]
doc = vector_store.docstore.search(doc_id)
selected.append((doc, score))
selected.sort(key=lambda x: x[1], reverse=True)
return selected
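# Quick sanity check of the inner-product == cosine equivalence used above (illustrative):
#
#     v = np.array([[3.0, 4.0]], dtype="float32")
#     faiss.normalize_L2(v)       # v becomes [[0.6, 0.8]]
#     idx = faiss.IndexFlatIP(2)
#     idx.add(v)
#     D, I = idx.search(v, k=1)   # D[0][0] == 1.0: identical vectors score 1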
# === Hugging Face analysis of a single log sequence (core batch-processing function) ===
def generate_rag_response_hf_for_log(llm_pipeline, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
    """
    Run a RAG-augmented analysis of one log sequence with the Hugging Face LLM.
    """
    if llm_pipeline is None:
        return "ERROR: Hugging Face LLM pipeline is not loaded.", ""
    context_text = ""
    # 1. RAG retrieval
    if vector_store:
        selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
        if selected:
            retrieved_contents = [
                f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}"
                for doc, score in selected[:5]  # cap the number of retrieved chunks
            ]
            context_text = "\n".join(retrieved_contents)
    # 2. Build the RAG and instruction sections of the prompt (for HF instruct models)
    rag_instruction = f"""=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ===
{context_text if context_text else 'No relevant reference context found.'}
=== END REFERENCE CONTEXT ===
ANALYSIS INSTRUCTION: {user_prompt}
Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats. Focus on the **last log entry in the sequence** to determine its final criticality and priority, considering up to {WINDOW_SIZE - 1} preceding logs."""
log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===
{log_sequence_text}
=== END LOG SEQUENCE ==="""
    # Combine the system prompt, RAG instruction, and log content.
    # Note: fdtn-ai/Foundation-Sec-1.1-8B-Instruct follows the ChatML format; a simplified
    # instruction-tuning format is used here instead (see the chat-template sketch below).
    full_prompt = (
        f"**SYSTEM INSTRUCTION**: {sys_prompt}\n\n"
        f"**RAG & ANALYSIS INSTRUCTION**:\n{rag_instruction}\n\n"
        f"**LOG DATA**:\n{log_content_section}\n\n"
        f"**RESPONSE**:"
    )
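    # A sketch of chat-template prompting, usually more faithful for chat-tuned models
    # (assumes the model's tokenizer ships a chat template; not wired in here):
    #
    #     messages = [
    #         {"role": "system", "content": sys_prompt},
    #         {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"},
    #     ]
    #     full_prompt = llm_pipeline.tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )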
    # 3. Call the Hugging Face pipeline
    try:
        # HF generation raises on temperature == 0 with do_sample=True, so fall back to greedy decoding
        do_sample = temperature > 0
        response = llm_pipeline(
            full_prompt,
            max_new_tokens=max_output_tokens,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            do_sample=do_sample,
            return_full_text=False  # return only the newly generated text
        )
        # Unpack the pipeline output
        if response and isinstance(response, list) and 'generated_text' in response[0]:
            return response[0]['generated_text'].strip(), context_text
        else:
            return f"Unexpected Hugging Face pipeline output format: {response}", context_text
    except Exception as e:
        # Return a detailed error message if the model call fails
        return f"Hugging Face Model Error: {str(e)}", context_text
# === File handling and main execution logic (same structure; LLM call replaced) ===
# Initialize session state
if 'execute_batch_analysis' not in st.session_state:
    st.session_state.execute_batch_analysis = False
if 'batch_results' not in st.session_state:
    st.session_state.batch_results = None
# --- 1. Handle the RAG knowledge-base file (rag_uploaded_file) ---
if 'rag_current_file_key' not in st.session_state:
    st.session_state.rag_current_file_key = None
if rag_uploaded_file:
    file_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
    if st.session_state.rag_current_file_key != file_key or 'vector_store' not in st.session_state:
        # A new RAG file was detected; rebuild the knowledge base
        with st.spinner(f"Building the RAG reference knowledge base ({rag_uploaded_file.name})..."):
            vs, msg = process_file_to_faiss(rag_uploaded_file)
            if vs:
                st.session_state.vector_store = vs
                st.session_state.rag_current_file_key = file_key
                st.toast(f"RAG reference knowledge base updated! {msg}", icon="✅")
            else:
                st.error(msg)
# File removal / state cleanup (when the user removes the RAG file)
elif 'vector_store' in st.session_state:
    del st.session_state.vector_store
    del st.session_state.rag_current_file_key
    st.info("RAG file removed; the knowledge base has been cleared.")
# --- 2. Handle the JSON batch-analysis file (json_uploaded_file) ---
if 'json_current_file_key' not in st.session_state:
    st.session_state.json_current_file_key = None
if json_uploaded_file:
    json_file_key = f"json_{json_uploaded_file.name}_{json_uploaded_file.size}"
    if st.session_state.json_current_file_key != json_file_key or 'json_data_for_batch' not in st.session_state:
        try:
            # A new JSON file was detected
            json_data = json.load(io.StringIO(json_uploaded_file.getvalue().decode("utf-8")))
            st.session_state.json_data_for_batch = json_data
            st.session_state.json_current_file_key = json_file_key
            st.toast("JSON log file loaded; press 'Run Batch Analysis'.", icon="📄")
        except Exception as e:
            st.error(f"Failed to parse the JSON file: {e}")
            if 'json_data_for_batch' in st.session_state:
                del st.session_state.json_data_for_batch
# File removal / state cleanup (when the user removes the JSON file)
elif 'json_data_for_batch' in st.session_state:
    del st.session_state.json_data_for_batch
    del st.session_state.json_current_file_key
    # Reset rather than delete: batch_results is read unconditionally further down
    st.session_state.batch_results = None
    st.info("JSON file removed; log data and analysis results have been cleared.")
# === Batch-analysis execution (including color-coded output) ===
if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.session_state:
    st.session_state.execute_batch_analysis = False
    start_time = time.time()  # start timing
    st.session_state.batch_results = []
    if llm_pipeline is None:
        st.error("The Hugging Face LLM pipeline is not loaded; batch analysis cannot run. Check dependencies and environment resources.")
        # st.stop() only halts the current script run; Streamlit reruns on the next
        # interaction, so the user can still adjust settings afterwards.
        st.stop()
    data_to_process = st.session_state.json_data_for_batch
    # Extract the list of logs
    logs_list = []
    if isinstance(data_to_process, list):
        logs_list = data_to_process
    elif isinstance(data_to_process, dict):
        # Check the well-known wrapper keys first; the generic values() fallback below
        # would otherwise swallow {"alerts": [...]} as a single nested entry
        if 'alerts' in data_to_process and isinstance(data_to_process['alerts'], list):
            logs_list = data_to_process['alerts']
        elif 'logs' in data_to_process and isinstance(data_to_process['logs'], list):
            logs_list = data_to_process['logs']
        elif all(isinstance(v, (dict, str, list)) for v in data_to_process.values()):
            logs_list = list(data_to_process.values())
        else:
            logs_list = [data_to_process]
    else:
        logs_list = [data_to_process]
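    # Accepted JSON shapes (illustrative):
    #   [ {...}, {...} ]               -> analyzed as-is
    #   { "alerts": [ {...}, ... ] }   -> the "alerts" list
    #   { "logs":   [ {...}, ... ] }   -> the "logs" list
    #   { "id1": {...}, "id2": {...} } -> the dict's values
    #   anything else                  -> treated as a single log entry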
    if logs_list:
        vs = st.session_state.get("vector_store", None)
        if vs:
            st.success("✅ RAG knowledge base enabled and used for analysis.")
        else:
            st.warning("⚠️ No RAG knowledge base loaded; running plain log analysis.")
        # --- Build the sliding-window sequences ---
        # Pretty-print every log as a JSON string so windows can be concatenated
        formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
        # Build the list of sequences to analyze (sliding window)
        analysis_sequences = []
        for i in range(len(formatted_logs)):
            start_index = max(0, i - WINDOW_SIZE + 1)
            end_index = i + 1  # the window ends at the current log
            current_window = formatted_logs[start_index:end_index]
            sequence_text = []
            for j, log_str in enumerate(current_window):
                is_target = " <<< TARGET LOG TO ANALYZE" if j == len(current_window) - 1 else ""
                # start_index + j + 1 is the 1-based index in the original list, consistent
                # with target_log_id below
                sequence_text.append(f"--- Log Index {start_index + j + 1} ({len(current_window) - 1 - j} logs before target){is_target} ---\n{log_str}")
            analysis_sequences.append({
                "sequence_text": "\n\n".join(sequence_text),
                "target_log_id": i + 1,  # this sequence's target is log i+1 of the original list
                "original_log_entry": logs_list[i]
            })
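        # Example with WINDOW_SIZE = 8: for i = 2 the window is logs[0..2] (3 entries);
        # for i = 10 it is logs[3..10] (8 entries), with logs[10] as the target.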
        total_sequences = len(analysis_sequences)
        if total_sequences < WINDOW_SIZE:
            st.warning(f"The total number of logs ({total_sequences}) is smaller than the window size ({WINDOW_SIZE}); results may be less reliable.")
        # --- Run the sequence analyses ---
        st.header(f"⚡ Batch analysis running (sliding window N={WINDOW_SIZE})...")
        progress_bar = st.progress(0, text=f"Preparing to process {total_sequences} sequences...")
        results_container = st.container()
        full_report_chunks = ["## Cybersecurity Batch Analysis Report\n\n"]
        priority_keyword = "Criticality/Priority:"
        for i, seq_data in enumerate(analysis_sequences):
            log_id = seq_data["target_log_id"]
            progress_bar.progress((i + 1) / total_sequences, text=f"Processed {i + 1}/{total_sequences} sequences (target Log #{log_id})...")
            try:
                # *** Hugging Face call ***
                response, retrieved_ctx = generate_rag_response_hf_for_log(
                    llm_pipeline=llm_pipeline,
                    model_id=MODEL_ID,
                    log_sequence_text=seq_data["sequence_text"],
                    user_prompt=analysis_prompt,
                    sys_prompt=system_prompt,
                    vector_store=vs,
                    threshold=similarity_threshold,
                    max_output_tokens=max_output_tokens,
                    temperature=temperature,
                    top_p=top_p
                )
                # Store the result
                item = {
                    "log_id": log_id,
                    "log_content": seq_data["original_log_entry"],  # the original log entry
                    "sequence_analyzed": seq_data["sequence_text"],  # the analyzed sequence
                    "analysis_result": response,
                    "context": retrieved_ctx
                }
                st.session_state.batch_results.append(item)
                # Render the result
                with results_container:
                    st.subheader(f"Log/Alert #{item['log_id']} (sequence analysis complete)")
                    with st.expander(f"Sequence content ({len(seq_data['sequence_text'].split('--- Log Index')) - 1} logs)"):
                        st.code(item["sequence_analyzed"], language='text')
                    # Color-code the result based on the parsed priority
                    is_high_priority = False
                    lower_resp = response.lower()
                    if priority_keyword.lower() in lower_resp:
                        try:
                            # Case-insensitive parse: find the keyword in the lowered text, slice the original
                            start = lower_resp.index(priority_keyword.lower()) + len(priority_keyword)
                            priority_section = response[start:].split('\n')[0].strip()
                            if any(k in priority_section.lower() for k in ('high', 'medium', 'yes')):
                                is_high_priority = True
                        except (IndexError, ValueError):
                            pass
                    st.markdown(f"### 🤖 Analysis Result (Log #{log_id})")
                    if is_high_priority:
                        st.error(item['analysis_result'])
                    else:
                        st.info(item['analysis_result'])
                    if item['context']:
                        with st.expander("Retrieved RAG knowledge-base chunks"):
                            st.code(item['context'])
                    st.markdown("---")
                # Append this item's report chunk (the history section below rebuilds the
                # original-log JSON dump when generating the downloadable report)
                full_report_chunks.append(f"---\n\n### Log/Alert #{item['log_id']} (sequence analysis)\n\n#### Analyzed sequence\n```\n{seq_data['sequence_text']}\n```\n\n#### LLM analysis result\n{item['analysis_result']}\n")
            except Exception as e:
                error_message = f"ERROR: processing the sequence for Log {log_id} failed: {e}"
                st.session_state.batch_results.append({
                    "log_id": log_id,
                    "log_content": seq_data["original_log_entry"],
                    "sequence_analyzed": seq_data["sequence_text"],
                    "analysis_result": error_message,
                    "context": ""
                })
                with results_container:
                    st.error(error_message)
        end_time = time.time()
        progress_bar.empty()
        st.success(f"Batch analysis complete! Processed {total_sequences} log sequences in {end_time - start_time:.2f} seconds.")
        st.divider()
    else:
        st.error("Could not extract a log list or valid log entries from the uploaded JSON file. Check the file structure.")
# === Display results (history) (unchanged) ===
if st.session_state.batch_results and not st.session_state.execute_batch_analysis:
    st.header("⚡ Previous Analysis Results (History)")
    full_report_chunks = ["## Cybersecurity Batch Analysis Report\n\n"]
    for item in st.session_state.batch_results:
        log_content_str_for_report = json.dumps(item["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
        full_report_chunks.append(f"---\n\n### Log/Alert #{item['log_id']}\n\n#### Original content\n```json\n{log_content_str_for_report}\n```\n\n#### LLM analysis result\n{item['analysis_result']}\n")
    st.info(f"Found previous analysis results for {len(st.session_state.batch_results)} logs.")
    st.download_button(
        label="📥 Download the previous full report (.md)",
        data="\n".join(full_report_chunks),
        file_name="security_batch_analysis_report_history.md",
        mime="text/markdown"
    )