# Uploaded by: ss900371tw
# Commit message: Update src/streamlit_app.py
# Commit: 2d7fa4f (verified)
import streamlit as st
import os
import io
import json
import csv # <--- 新增:用於處理 CSV
import numpy as np
import faiss
import uuid
import time
import sys
# === HuggingFace 模型相關套件 (替換為 InferenceClient) ===
try:
from huggingface_hub import InferenceClient
except ImportError:
st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install huggingface-hub")
# === LangChain/RAG 相關套件 (保持不變) ===
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.docstore.in_memory import InMemoryDocstore
# Try to import pypdf (optional; only needed when a PDF is uploaded)
try:
import pypdf
except ImportError:
pypdf = None
# --- Model / analysis constants ---
# Declared ahead of the page header so the title can be derived from MODEL_ID
# rather than hard-coding a model name that silently drifts out of date.
MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
# Number of consecutive log entries (target included) sent in one analysis request.
WINDOW_SIZE = 8

# --- Page setup ---
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
st.title(f"🛡️ {MODEL_ID.split('/')[-1]} with FAISS RAG & Batch Analysis (Inference Client)")
st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face Inference Client (API)**。支援 JSON/CSV/TXT 執行批量分析。")

# --- Session-state defaults ---
# Each key is created only if missing, so values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    'execute_batch_analysis': False,   # one-shot flag set by the sidebar button
    'batch_results': None,             # list of per-log analysis dicts once a run has happened
    'rag_current_file_key': None,      # identity of the file the vector store was built from
    'batch_current_file_key': None,    # identity of the parsed batch file (any supported format)
    'vector_store': None,              # FAISS store built from the RAG reference file
    'json_data_for_batch': None,       # parsed batch data (name kept; may come from CSV/TXT)
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Sidebar: settings, uploads, prompts and model parameters ---
# Default text for the per-log analysis instruction shown in the sidebar.
_DEFAULT_ANALYSIS_PROMPT = (
    "You are a security expert in charge of analyzing alerts related to Web Application Attacks and Brute Force & Reconnaissance. "
    "Respond with a clear, structured analysis using the following mandatory sections: \n\n"
    "- Priority: Provide the overall priority level. (Answer High risk, Medium risk, or Low risk only) \n"
    "- Explanation: If this alert is highly related to Web Application Attacks and Brute Force & Reconnaissance, explain the potential impact and why this specific alert requires attention. If not, **omit the explanation section**. \n"
    "- Action Plan: If this alert is highly related to Web Application Attacks and Brute Force & Reconnaissance, What should be the immediate steps to address this specific alert? If not, **omit the action plan section**. \n\n"
    "Strictly use the information in the provided Log."
)
# Default system prompt sent with every chat completion.
_DEFAULT_SYSTEM_PROMPT = (
    "You are a Senior Security Analyst, named Ernest. "
    "You provide expert, authoritative, and concise advice on Information Security. "
    "Your analysis must be based strictly on the provided context."
)

with st.sidebar:
    st.header("⚙️ 設定")
    _hf_token_missing = not os.environ.get("HF_TOKEN")
    if _hf_token_missing:
        st.error("環境變數 **HF_TOKEN** 未設定。請設定後重新啟動應用程式。")
    st.info(f"LLM 模型:**{MODEL_ID}** (Hugging Face Inference API)")
    st.warning("⚠️ **注意**: 該模型使用 Inference API 呼叫,請確保您的 HF Token 具有存取權限。")

    st.divider()
    st.subheader("📂 檔案上傳")
    # 1. File whose entries are analysed one by one (JSON / CSV / TXT).
    batch_uploaded_file = st.file_uploader(
        "1️⃣ 上傳 **Log/Alert 檔案** (用於批量分析)",
        type=['json', 'csv', 'txt'],
        key="batch_uploader",
        help="支援 JSON (Array), CSV (含標題), TXT (每行一條 Log)",
    )
    # 2. Reference file used to build the RAG knowledge base.
    rag_uploaded_file = st.file_uploader(
        "2️⃣ 上傳 **RAG 參考知識庫** (Logs/PDF/Code 等)",
        type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],
        key="rag_uploader",
    )

    st.divider()
    st.subheader("💡 批量分析指令")
    analysis_prompt = st.text_area(
        "針對每個 Log/Alert 執行的指令",
        value=_DEFAULT_ANALYSIS_PROMPT,
        height=200,
    )
    st.markdown("此指令將對檔案中的**每一個 Log 條目**執行一次獨立分析。")

    # The run button only exists once a batch file is present; clicking it
    # raises the one-shot flag unless the API token is missing.
    if not batch_uploaded_file:
        st.info("請上傳 Log 檔案以啟用批量分析按鈕。")
    elif st.button("🚀 執行批量分析"):
        if _hf_token_missing:
            st.error("無法執行,環境變數 **HF_TOKEN** 未設定。")
        else:
            st.session_state.execute_batch_analysis = True

    st.divider()
    st.subheader("🔍 RAG 檢索設定")
    similarity_threshold = st.slider("📐 Cosine Similarity 門檻", min_value=0.0, max_value=1.0, value=0.4, step=0.01)

    st.divider()
    st.subheader("模型參數")
    system_prompt = st.text_area("System Prompt", value=_DEFAULT_SYSTEM_PROMPT, height=100)
    max_output_tokens = st.slider("Max Output Tokens", min_value=128, max_value=4096, value=2048, step=128)
    temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.1, step=0.1)
    top_p = st.slider("Top P", min_value=0.1, max_value=1.0, value=0.95, step=0.05)

    st.divider()
    if st.button("🗑️ 清除所有紀錄"):
        # Snapshot the keys first: deleting while iterating the live view is unsafe.
        for _stored_key in list(st.session_state.keys()):
            del st.session_state[_stored_key]
        st.rerun()
# --- Initialize the Hugging Face LLM client ---
@st.cache_resource
def load_inference_client(model_id):
    """Create an ``InferenceClient`` for ``model_id`` using the ``HF_TOKEN`` env var.

    Returns:
        The client on success; ``None`` when ``HF_TOKEN`` is unset/empty or
        when constructing the client raises (an error banner is shown then).
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return None
    try:
        hf_client = InferenceClient(model_id, token=hf_token)
        st.success(f"Hugging Face Inference Client **{model_id}** 載入成功。")
        return hf_client
    except Exception as exc:
        st.error(f"Hugging Face Inference Client 載入失敗: {exc}")
        return None
# Connect only when a token is configured; otherwise tell the user what is missing.
inference_client = None
if os.environ.get("HF_TOKEN"):
    with st.spinner(f"正在連線到 Inference Client: {MODEL_ID}..."):
        inference_client = load_inference_client(MODEL_ID)
    if inference_client is None:
        st.warning("Hugging Face Inference Client 無法連線。")
else:
    st.error("請在環境變數中設定 HF_TOKEN。")
# === Embedding model ===
@st.cache_resource
def load_embedding_model():
    """Return the cached ``HuggingFaceEmbeddings`` instance for BAAI/bge-large-zh-v1.5.

    The model runs on CPU and returns un-normalized vectors
    (``normalize_embeddings`` is False).
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-zh-v1.5",
        model_kwargs={'device': 'cpu', 'trust_remote_code': True},
        encode_kwargs={'normalize_embeddings': False},
    )


with st.spinner("正在載入 Embedding 模型..."):
    embedding_model = load_embedding_model()
# === Vector-store construction ===
def process_file_to_faiss(uploaded_file):
    """Build an in-memory FAISS inner-product store from an uploaded file.

    PDF uploads are read page by page with pypdf; every other type is decoded
    as UTF-8 text. Each non-blank line becomes one document. Embeddings are
    L2-normalized before being added to an ``IndexFlatIP`` index, so the
    inner-product scores it returns are cosine similarities.

    Args:
        uploaded_file: Streamlit upload exposing ``type`` and ``getvalue()``.

    Returns:
        ``(vector_store, message)`` on success, or ``(None, reason)`` when the
        file is empty, the PDF library is missing, or any step raises.
    """
    try:
        if uploaded_file.type == "application/pdf":
            if not pypdf:
                return None, "PDF library missing"
            pdf_reader = pypdf.PdfReader(uploaded_file)
            # `or ""` keeps pages with no extractable text from breaking the join.
            text_content = "".join((page.extract_text() or "") + "\n" for page in pdf_reader.pages)
        else:
            # utf-8-sig strips a leading BOM so it neither defeats the emptiness
            # check nor ends up inside the first chunk; plain UTF-8 is unaffected.
            text_content = uploaded_file.getvalue().decode("utf-8-sig")

        if not text_content.strip():
            return None, "File is empty"

        docs = [Document(page_content=line) for line in text_content.splitlines() if line.strip()]
        if not docs:
            return None, "No documents created"

        embeddings = embedding_model.embed_documents([d.page_content for d in docs])
        embeddings_np = np.array(embeddings).astype("float32")
        # In-place normalization: inner product on unit vectors == cosine similarity.
        faiss.normalize_L2(embeddings_np)

        index = faiss.IndexFlatIP(embeddings_np.shape[1])
        index.add(embeddings_np)

        doc_ids = [str(uuid.uuid4()) for _ in docs]
        docstore = InMemoryDocstore(dict(zip(doc_ids, docs)))
        index_to_docstore_id = dict(enumerate(doc_ids))
        vector_store = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
            distance_strategy=DistanceStrategy.COSINE,
        )
        return vector_store, f"{len(docs)} chunks created."
    except Exception as e:
        # UI boundary: report the failure to the caller instead of crashing the app.
        return None, f"Error: {str(e)}"
def faiss_cosine_search_all(vector_store, query, threshold):
    """Return every stored document whose cosine similarity to ``query`` meets ``threshold``.

    The query embedding is L2-normalized and searched against the store's raw
    FAISS index across all ``index.ntotal`` vectors.

    Args:
        vector_store: Object exposing ``index``, ``index_to_docstore_id`` and ``docstore``.
        query: Text to embed and search for.
        threshold: Minimum score (inclusive) a hit must reach.

    Returns:
        List of ``(document, score)`` tuples sorted by descending score;
        empty when the index holds no vectors or nothing passes the threshold.
    """
    index = vector_store.index
    # Searching with k == 0 is invalid, so short-circuit on an empty index.
    if index.ntotal == 0:
        return []

    q_emb = np.array([embedding_model.embed_query(query)]).astype("float32")
    faiss.normalize_L2(q_emb)
    scores, positions = index.search(q_emb, k=index.ntotal)

    selected = []
    for score, idx in zip(scores[0], positions[0]):
        # FAISS pads missing results with -1.
        if idx == -1:
            continue
        if score >= threshold:
            doc_id = vector_store.index_to_docstore_id[idx]
            selected.append((vector_store.docstore.search(doc_id), score))
    selected.sort(key=lambda pair: pair[1], reverse=True)
    return selected
# === Generate the analysis for one log sequence via the Hugging Face client ===
def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
    """Analyse one log window with the chat model, optionally grounded by RAG context.

    When ``vector_store`` is truthy, the five best-scoring chunks at or above
    ``threshold`` are placed in the prompt as reference context.

    Args:
        client: Chat-capable inference client, or ``None``.
        model_id: Accepted for interface compatibility; not used here.
        log_sequence_text: The formatted log window to analyse.
        user_prompt: Analysis instruction inserted into the user message.
        sys_prompt: System message content.
        vector_store: Store to retrieve reference chunks from, or falsy to skip retrieval.
        threshold: Minimum similarity for a chunk to be used.
        max_output_tokens, temperature, top_p: Forwarded to ``chat_completion``.

    Returns:
        ``(answer_text, context_text)``. ``answer_text`` is
        ``"ERROR: Client Error"`` when ``client`` is None, ``"Format Error"``
        when the response has no usable content, or ``"Model Error: ..."``
        when the API call raises.
    """
    if client is None:
        return "ERROR: Client Error", ""

    context_text = ""
    if vector_store:
        selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
        if selected:
            context_text = "\n".join(
                f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}"
                for doc, score in selected[:5]
            )

    reference_block = context_text if context_text else 'No relevant reference context found.'
    # Header, context and footer each sit on their own line; previously the
    # context was glued directly onto the "===" markers.
    rag_instruction = (
        f"=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ===\n"
        f"{reference_block}\n"
        "=== END REFERENCE CONTEXT ===\n"
        f"ANALYSIS INSTRUCTION: {user_prompt}\n"
        "Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats."
    )
    log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===\n{log_sequence_text}\n=== END LOG SEQUENCE ==="""
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"},
    ]
    try:
        response = client.chat_completion(messages, max_tokens=max_output_tokens, temperature=temperature, top_p=top_p, stream=False)
        if response and response.choices:
            content = response.choices[0].message.content
            # A missing message body is a malformed response, not a model failure.
            if content is None:
                return "Format Error", context_text
            return content.strip(), context_text
        return "Format Error", context_text
    except Exception as e:
        # Per-request boundary: surface the failure as text so the batch continues.
        return f"Model Error: {str(e)}", context_text
# =======================================================================
# === File handling: RAG reference file ===
if rag_uploaded_file:
    file_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
    # Rebuild when the file changed or no store exists yet. The store key is
    # always present in session state (it defaults to None), so test the value
    # rather than key membership.
    if st.session_state.get("rag_current_file_key") != file_key or st.session_state.get("vector_store") is None:
        with st.spinner(f"正在建立 RAG 參考知識庫 ({rag_uploaded_file.name})..."):
            vs, msg = process_file_to_faiss(rag_uploaded_file)
        if vs is not None:
            st.session_state.vector_store = vs
            st.session_state.rag_current_file_key = file_key
            st.toast(f"RAG 參考知識庫已更新!{msg}", icon="✅")
        else:
            # Drop any store built from a previous file so analysis never
            # silently uses a knowledge base that no longer matches the upload.
            st.session_state.vector_store = None
            st.session_state.rag_current_file_key = None
            st.error(msg)
else:
    # Announce the cleanup only when a knowledge base had actually been loaded;
    # otherwise this notice would appear on every run without a RAG file.
    _had_knowledge_base = st.session_state.get("vector_store") is not None
    st.session_state.vector_store = None
    st.session_state.rag_current_file_key = None
    if _had_knowledge_base:
        st.info("RAG 檔案已移除,已清除相關知識庫。")
# === File handling: batch-analysis file ===
# JSON, CSV and TXT uploads are all converted to JSON-serializable Python data.
if batch_uploaded_file:
    batch_file_key = f"batch_{batch_uploaded_file.name}_{batch_uploaded_file.size}"
    # Parse when the file changed or nothing is loaded. The data key defaults
    # to None, so test the value rather than key membership.
    if st.session_state.get("batch_current_file_key") != batch_file_key or st.session_state.get("json_data_for_batch") is None:
        try:
            # utf-8-sig strips a leading BOM, which would otherwise make
            # json parsing fail and corrupt the first CSV column name.
            file_text = batch_uploaded_file.getvalue().decode("utf-8-sig")
            lowered_name = batch_uploaded_file.name.lower()
            if lowered_name.endswith('.json'):
                parsed_data = json.loads(file_text)
                st.toast("JSON 檔案載入成功", icon="📄")
            elif lowered_name.endswith('.csv'):
                # newline="" lets the csv module do its own line-ending handling.
                parsed_data = list(csv.DictReader(io.StringIO(file_text, newline="")))
                st.toast("CSV 檔案已轉換為 JSON 結構", icon="📊")
            else:
                # Anything else is treated as TXT: one entry per non-blank line.
                parsed_data = [{"raw_log_entry": line.strip()} for line in io.StringIO(file_text) if line.strip()]
                st.toast("TXT 檔案已轉換為 JSON 結構", icon="📝")
            st.session_state.json_data_for_batch = parsed_data
            st.session_state.batch_current_file_key = batch_file_key
        except Exception as e:
            # UI boundary: report the problem and leave no half-loaded data behind.
            st.error(f"檔案解析錯誤: {e}")
            if 'json_data_for_batch' in st.session_state:
                del st.session_state.json_data_for_batch
            # A run requested for an unparseable file must not stay pending.
            st.session_state.execute_batch_analysis = False
else:
    # Announce the cleanup only when a file had actually been loaded;
    # otherwise this notice would appear on every run without a batch file.
    _had_batch_data = (
        st.session_state.get("json_data_for_batch") is not None
        or st.session_state.get("batch_current_file_key") is not None
    )
    for _stale_key in ("json_data_for_batch", "batch_current_file_key", "batch_results"):
        if _stale_key in st.session_state:
            del st.session_state[_stale_key]
    if _had_batch_data:
        st.info("批量分析檔案已移除,已清除相關數據。")
# === Batch analysis execution ===
def _extract_log_entries(data):
    """Normalize parsed batch data into a list of log entries.

    A list is used as-is; a dict contributes its ``alerts`` or ``logs`` list
    (checked in that order) when present; anything else is wrapped in a list.
    """
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        for container_key in ('alerts', 'logs'):
            if isinstance(data.get(container_key), list):
                return data[container_key]
    return [data]


def _build_log_windows(formatted_logs, window_size):
    """Return, for every log, the text of the sliding window that ends at it.

    Each window holds up to ``window_size`` entries. Every entry is labelled
    with its 1-based position in the file (matching the "Log #" ids used in
    the results) and the number of window entries that precede it; the last
    entry is marked as the target.
    """
    windows = []
    for target_pos in range(len(formatted_logs)):
        first_pos = max(0, target_pos - window_size + 1)
        window = formatted_logs[first_pos:target_pos + 1]
        last_offset = len(window) - 1
        parts = []
        for offset, log_str in enumerate(window):
            target_marker = " <<< TARGET LOG TO ANALYZE" if offset == last_offset else ""
            parts.append(f"--- Log Index {first_pos + offset + 1} ({offset} prior logs){target_marker} ---\n{log_str}")
        windows.append("\n\n".join(parts))
    return windows


if st.session_state.execute_batch_analysis:
    # Consume the one-shot flag unconditionally so a failed request cannot
    # stay pending and hide the history section on later reruns.
    st.session_state.execute_batch_analysis = False
    data_to_process = st.session_state.get("json_data_for_batch")
    if data_to_process is None:
        st.error("無法提取有效 Log,請檢查檔案格式。")
    else:
        start_time = time.time()
        st.session_state.batch_results = []
        if inference_client is None:
            st.error("Client 未連線,無法執行。")
        else:
            logs_list = _extract_log_entries(data_to_process)
            if logs_list:
                knowledge_base = st.session_state.get("vector_store", None)
                # Whatever the source format, every entry reaches the prompt as JSON text.
                formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
                window_texts = _build_log_windows(formatted_logs, WINDOW_SIZE)
                total_sequences = len(window_texts)

                st.header(f"⚡ 批量分析執行中 (平移視窗 $N={WINDOW_SIZE}$)...")
                progress_bar = st.progress(0, text=f"準備處理 {total_sequences} 個序列...")
                results_container = st.container()

                for log_id, (log_entry, sequence_text) in enumerate(zip(logs_list, window_texts), start=1):
                    progress_bar.progress(log_id / total_sequences, text=f"Processing {log_id}/{total_sequences} (Log #{log_id})...")
                    try:
                        response, retrieved_ctx = generate_rag_response_hf_for_log(
                            client=inference_client,
                            model_id=MODEL_ID,
                            log_sequence_text=sequence_text,
                            user_prompt=analysis_prompt,
                            sys_prompt=system_prompt,
                            vector_store=knowledge_base,
                            threshold=similarity_threshold,
                            max_output_tokens=max_output_tokens,
                            temperature=temperature,
                            top_p=top_p,
                        )
                        item = {
                            "log_id": log_id,
                            "log_content": log_entry,
                            "sequence_analyzed": sequence_text,
                            "analysis_result": response,
                            "context": retrieved_ctx,
                        }
                        st.session_state.batch_results.append(item)
                        with results_container:
                            st.subheader(f"Log/Alert #{item['log_id']}")
                            with st.expander("序列內容 (JSON Format)"):
                                st.code(item["sequence_analyzed"], language='json')
                            # High-risk verdicts are rendered as errors to stand out.
                            if 'high risk' in response.lower():
                                st.error(item['analysis_result'])
                            else:
                                st.info(item['analysis_result'])
                            if item['context']:
                                with st.expander("參考 RAG 片段"):
                                    st.code(item['context'])
                            st.markdown("---")
                    except Exception as e:
                        # Per-entry boundary: one failing log must not abort the batch.
                        st.error(f"Error Log {log_id}: {e}")

                end_time = time.time()
                progress_bar.empty()
                st.success(f"完成!耗時 {end_time - start_time:.2f} 秒。")
            else:
                st.error("無法提取有效 Log,請檢查檔案格式。")
# === Show stored results and offer the Markdown report ===
_stored_results = st.session_state.get("batch_results")
if _stored_results and not st.session_state.execute_batch_analysis:
    st.header("⚡ 歷史分析結果")

    def _report_section(entry):
        """Render one stored result as a Markdown section of the report."""
        # Backticks are backslash-escaped before the log is placed in the fence.
        escaped_log = json.dumps(entry["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
        return f"---\n\n### Log #{entry['log_id']}\n```json\n{escaped_log}\n```\n{entry['analysis_result']}\n"

    report_sections = ["## Report\n\n"]
    report_sections.extend(_report_section(entry) for entry in _stored_results)
    st.download_button(
        label="📥 下載完整報告 (.md)",
        data="\n".join(report_sections),
        file_name="report.md",
        mime="text/markdown",
    )