Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +104 -117
src/streamlit_app.py
CHANGED
|
@@ -7,27 +7,22 @@ import faiss
|
|
| 7 |
import uuid
|
| 8 |
import time
|
| 9 |
import sys
|
| 10 |
-
|
| 11 |
-
# === HuggingFace 模型相關套件 (新增) ===
|
| 12 |
try:
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
import
|
| 16 |
-
#
|
| 17 |
-
# from accelerate import Accelerator # 建議安裝
|
| 18 |
-
# import bitsandbytes # 建議安裝
|
| 19 |
except ImportError:
|
| 20 |
-
st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
| 24 |
# === LangChain/RAG 相關套件 (保持不變) ===
|
| 25 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 26 |
from langchain_core.documents import Document
|
| 27 |
from langchain_community.vectorstores import FAISS
|
| 28 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
| 29 |
from langchain_community.docstore.in_memory import InMemoryDocstore
|
| 30 |
-
|
| 31 |
# 嘗試匯入 pypdf
|
| 32 |
try:
|
| 33 |
import pypdf
|
|
@@ -36,11 +31,11 @@ except ImportError:
|
|
| 36 |
|
| 37 |
# --- 頁面設定 ---
|
| 38 |
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
|
| 39 |
-
st.title("🛡️
|
| 40 |
-
st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face
|
| 41 |
|
| 42 |
-
# 設定模型 ID (
|
| 43 |
-
MODEL_ID = "
|
| 44 |
WINDOW_SIZE = 8
|
| 45 |
|
| 46 |
# --- 側邊欄設定 ---
|
|
@@ -48,8 +43,12 @@ with st.sidebar:
|
|
| 48 |
st.header("⚙️ 設定")
|
| 49 |
|
| 50 |
# === 替換為 Hugging Face 模型名稱顯示 (移除 API Key 輸入) ===
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
st.divider()
|
| 55 |
|
|
@@ -78,7 +77,10 @@ with st.sidebar:
|
|
| 78 |
|
| 79 |
if json_uploaded_file: # 移除 API Key 檢查
|
| 80 |
if st.button("🚀 執行批量分析"):
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
| 82 |
else:
|
| 83 |
st.info("請上傳 JSON 檔案以啟用批量分析按鈕。")
|
| 84 |
|
|
@@ -93,7 +95,8 @@ with st.sidebar:
|
|
| 93 |
st.divider()
|
| 94 |
|
| 95 |
st.subheader("模型參數")
|
| 96 |
-
|
|
|
|
| 97 |
max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
|
| 98 |
temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
|
| 99 |
top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
|
|
@@ -107,44 +110,34 @@ with st.sidebar:
|
|
| 107 |
|
| 108 |
# --- 初始化 Hugging Face LLM Client (重大替換) ---
|
| 109 |
@st.cache_resource
|
| 110 |
-
def
|
| 111 |
-
if
|
| 112 |
-
st.error("無法載入 Hugging Face 依賴,請安裝:pip install transformers torch accelerate bitsandbytes")
|
| 113 |
return None
|
|
|
|
| 114 |
try:
|
| 115 |
-
#
|
| 116 |
-
|
| 117 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 118 |
model_id,
|
| 119 |
-
|
| 120 |
-
device_map="auto", # <--- 讓 accelerate 管理裝置
|
| 121 |
-
trust_remote_code=True,
|
| 122 |
-
# load_in_4bit=True # 如果需要 4-bit 量化
|
| 123 |
-
)
|
| 124 |
-
# 使用 pipeline 簡化呼叫
|
| 125 |
-
llm_pipeline = pipeline(
|
| 126 |
-
"text-generation",
|
| 127 |
-
model=model,
|
| 128 |
-
tokenizer=tokenizer,
|
| 129 |
-
# device=(0 if torch.cuda.is_available() else -1) # <--- **移除此參數**
|
| 130 |
)
|
| 131 |
-
st.success(f"Hugging Face
|
| 132 |
-
return
|
| 133 |
except Exception as e:
|
| 134 |
-
st.error(f"Hugging Face
|
| 135 |
return None
|
| 136 |
|
| 137 |
-
# 在 main 區塊外初始化
|
| 138 |
-
|
| 139 |
-
if
|
| 140 |
-
with st.spinner(f"
|
| 141 |
-
|
| 142 |
|
| 143 |
-
if
|
| 144 |
-
st.warning("Hugging Face
|
|
|
|
|
|
|
| 145 |
# =======================================================================
|
| 146 |
|
| 147 |
-
|
| 148 |
# === Embedding 模型 (用於 RAG 參考庫) (保持不變) ===
|
| 149 |
@st.cache_resource
|
| 150 |
def load_embedding_model():
|
|
@@ -167,6 +160,7 @@ with st.spinner("正在載入 Embedding 模型..."):
|
|
| 167 |
|
| 168 |
# === 建立向量庫 / Search 函數 (保持不變) ===
|
| 169 |
def process_file_to_faiss(uploaded_file):
|
|
|
|
| 170 |
text_content = ""
|
| 171 |
try:
|
| 172 |
if uploaded_file.type == "application/pdf":
|
|
@@ -179,32 +173,32 @@ def process_file_to_faiss(uploaded_file):
|
|
| 179 |
else:
|
| 180 |
stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
|
| 181 |
text_content = stringio.read()
|
| 182 |
-
|
| 183 |
if not text_content.strip():
|
| 184 |
return None, "File is empty"
|
| 185 |
-
|
| 186 |
# 嘗試以 </Event> 分割 Log,否則以換行符分割
|
| 187 |
events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]
|
| 188 |
if len(events) <= 1:
|
| 189 |
events = [line for line in text_content.split("\n") if line.strip()]
|
| 190 |
-
|
| 191 |
docs = [Document(page_content=e) for e in events]
|
| 192 |
-
|
| 193 |
if not docs:
|
| 194 |
return None, "No documents created"
|
| 195 |
-
|
| 196 |
embeddings = embedding_model.embed_documents([d.page_content for d in docs])
|
| 197 |
embeddings_np = np.array(embeddings).astype("float32")
|
| 198 |
faiss.normalize_L2(embeddings_np) # L2 正規化
|
| 199 |
-
|
| 200 |
dimension = embeddings_np.shape[1]
|
| 201 |
index = faiss.IndexFlatIP(dimension) # IndexFlatIP (內積)
|
| 202 |
index.add(embeddings_np)
|
| 203 |
-
|
| 204 |
doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
|
| 205 |
docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
|
| 206 |
index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
|
| 207 |
-
|
| 208 |
vector_store = FAISS(
|
| 209 |
embedding_function=embedding_model,
|
| 210 |
index=index,
|
|
@@ -212,19 +206,20 @@ def process_file_to_faiss(uploaded_file):
|
|
| 212 |
index_to_docstore_id=index_to_docstore_id,
|
| 213 |
distance_strategy=DistanceStrategy.COSINE # 使用 Cosine 距離 (對應 IndexFlatIP)
|
| 214 |
)
|
| 215 |
-
|
| 216 |
return vector_store, f"{len(docs)} chunks created."
|
| 217 |
except Exception as e:
|
| 218 |
return None, f"Error: {str(e)}"
|
| 219 |
|
| 220 |
def faiss_cosine_search_all(vector_store, query, threshold):
|
|
|
|
| 221 |
q_emb = embedding_model.embed_query(query)
|
| 222 |
q_emb = np.array([q_emb]).astype("float32")
|
| 223 |
faiss.normalize_L2(q_emb)
|
| 224 |
-
|
| 225 |
index = vector_store.index
|
| 226 |
D, I = index.search(q_emb, k=index.ntotal)
|
| 227 |
-
|
| 228 |
selected = []
|
| 229 |
for score, idx in zip(D[0], I[0]):
|
| 230 |
if idx == -1: continue
|
|
@@ -233,20 +228,20 @@ def faiss_cosine_search_all(vector_store, query, threshold):
|
|
| 233 |
doc_id = vector_store.index_to_docstore_id[idx]
|
| 234 |
doc = vector_store.docstore.search(doc_id)
|
| 235 |
selected.append((doc, score))
|
| 236 |
-
|
| 237 |
selected.sort(key=lambda x: x[1], reverse=True)
|
| 238 |
return selected
|
| 239 |
|
| 240 |
-
# === Hugging Face 生成單一 Log 分析回答 (核心批量處理函數
|
| 241 |
-
def generate_rag_response_hf_for_log(
|
| 242 |
"""
|
| 243 |
-
使用 Hugging Face
|
| 244 |
"""
|
| 245 |
-
if
|
| 246 |
-
return "ERROR: Hugging Face
|
| 247 |
-
|
| 248 |
context_text = ""
|
| 249 |
-
# 1. RAG 檢索邏輯
|
| 250 |
if vector_store:
|
| 251 |
selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
|
| 252 |
if selected:
|
|
@@ -255,58 +250,53 @@ def generate_rag_response_hf_for_log(llm_pipeline, model_id, log_sequence_text,
|
|
| 255 |
for i, (doc, score) in enumerate(selected[:5]) # 限制檢索結果數量
|
| 256 |
]
|
| 257 |
context_text = "\n".join(retrieved_contents)
|
| 258 |
-
|
| 259 |
-
# 2. 建構
|
| 260 |
-
rag_instruction = f"""=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ===
|
| 261 |
-
|
| 262 |
-
=== END REFERENCE CONTEXT ===
|
| 263 |
ANALYSIS INSTRUCTION: {user_prompt}
|
| 264 |
Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats. Focus on the **last log entry in the sequence** to determine its final criticality and priority, considering the preceding {WINDOW_SIZE} logs."""
|
| 265 |
-
|
| 266 |
-
log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===
|
| 267 |
-
{log_sequence_text}
|
| 268 |
-
=== END LOG SEQUENCE ==="""
|
| 269 |
|
| 270 |
-
|
| 271 |
-
# 注意:fdtn-ai/Foundation-Sec-1.1-8B-Instruct 遵循 ChatML 格式,但此處使用簡化的 instruction-tuning 格式
|
| 272 |
-
full_prompt = (
|
| 273 |
-
f"**SYSTEM INSTRUCTION**: {sys_prompt}\n\n"
|
| 274 |
-
f"**RAG & ANALYSIS INSTRUCTION**:\n{rag_instruction}\n\n"
|
| 275 |
-
f"**LOG DATA**:\n{log_content_section}\n\n"
|
| 276 |
-
f"**RESPONSE**:"
|
| 277 |
-
)
|
| 278 |
|
| 279 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
try:
|
| 281 |
-
#
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
temperature=temperature,
|
| 286 |
top_p=top_p,
|
| 287 |
-
|
| 288 |
-
return_full_text=False # 只返回生成的文本
|
| 289 |
)
|
| 290 |
-
|
| 291 |
-
# 處理 pipeline 的輸出格式
|
| 292 |
-
if response and isinstance(response, list) and 'generated_text' in response[0]:
|
| 293 |
-
return response[0]['generated_text'].strip(), context_text
|
| 294 |
-
else:
|
| 295 |
-
return f"Hugging Face Pipeline 輸出格式錯誤: {response}", context_text
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
except Exception as e:
|
| 298 |
# 如果模型呼叫失敗,回傳詳細錯誤訊息
|
| 299 |
return f"Hugging Face Model Error: {str(e)}", context_text
|
| 300 |
|
| 301 |
-
|
| 302 |
# === 檔案處理和主執行邏輯 (保持結構,替換 LLM 呼叫) ===
|
| 303 |
-
# 初始化 Session State
|
| 304 |
if 'execute_batch_analysis' not in st.session_state:
|
| 305 |
st.session_state.execute_batch_analysis = False
|
| 306 |
if 'batch_results' not in st.session_state:
|
| 307 |
st.session_state.batch_results = None
|
| 308 |
-
|
| 309 |
-
# --- 1. 處理 RAG 知識庫檔案 (rag_uploaded_file) ---
|
| 310 |
if 'rag_current_file_key' not in st.session_state:
|
| 311 |
st.session_state.rag_current_file_key = None
|
| 312 |
|
|
@@ -328,8 +318,8 @@ elif 'vector_store' in st.session_state:
|
|
| 328 |
del st.session_state.vector_store
|
| 329 |
del st.session_state.rag_current_file_key
|
| 330 |
st.info("RAG 檔案已移除,已清除相關知識庫。")
|
| 331 |
-
|
| 332 |
-
# --- 2. 處理 JSON 批量分析檔案 (json_uploaded_file) ---
|
| 333 |
if 'json_current_file_key' not in st.session_state:
|
| 334 |
st.session_state.json_current_file_key = None
|
| 335 |
|
|
@@ -343,12 +333,11 @@ if json_uploaded_file:
|
|
| 343 |
st.session_state.json_data_for_batch = json_data
|
| 344 |
st.session_state.json_current_file_key = json_file_key
|
| 345 |
st.toast("JSON Log 檔案已載入,請按 '執行批量分析'。", icon="📄")
|
| 346 |
-
|
| 347 |
except Exception as e:
|
| 348 |
st.error(f"JSON 檔案解析錯誤: {e}")
|
| 349 |
if 'json_data_for_batch' in st.session_state:
|
| 350 |
del st.session_state.json_data_for_batch
|
| 351 |
-
|
| 352 |
# 檔案移除/狀態清理 (如果使用者移除了 JSON 檔案)
|
| 353 |
elif 'json_data_for_batch' in st.session_state:
|
| 354 |
del st.session_state.json_data_for_batch
|
|
@@ -363,11 +352,10 @@ if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.sessi
|
|
| 363 |
start_time = time.time() # 開始計時
|
| 364 |
st.session_state.batch_results = []
|
| 365 |
|
| 366 |
-
if
|
| 367 |
-
st.error("Hugging Face
|
| 368 |
-
# 由於這是一個 Streamlit App,我們不直接 st.stop(),讓使用者可以檢查設定
|
| 369 |
st.session_state.execute_batch_analysis = False
|
| 370 |
-
|
| 371 |
data_to_process = st.session_state.json_data_for_batch
|
| 372 |
|
| 373 |
# 提取 Log 列表的邏輯 (保持不變)
|
|
@@ -385,7 +373,7 @@ if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.sessi
|
|
| 385 |
logs_list = [data_to_process]
|
| 386 |
else:
|
| 387 |
logs_list = [data_to_process]
|
| 388 |
-
|
| 389 |
if logs_list:
|
| 390 |
vs = st.session_state.get("vector_store", None)
|
| 391 |
if vs:
|
|
@@ -393,8 +381,7 @@ if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.sessi
|
|
| 393 |
else:
|
| 394 |
st.warning("⚠️ RAG 知識庫未載入,將單純執行 Log 分析。")
|
| 395 |
|
| 396 |
-
# --- 新增:創建平移視窗序列 ---
|
| 397 |
-
|
| 398 |
# 將所有 Log 轉換為 JSON 格式化字串列表,以便後續拼接
|
| 399 |
formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
|
| 400 |
|
|
@@ -436,9 +423,9 @@ if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.sessi
|
|
| 436 |
progress_bar.progress((i + 1) / total_sequences, text=f"已處理 {i + 1}/{total_sequences} 個序列 (目標 Log #{log_id})...")
|
| 437 |
|
| 438 |
try:
|
| 439 |
-
# *** 替換為
|
| 440 |
response, retrieved_ctx = generate_rag_response_hf_for_log(
|
| 441 |
-
|
| 442 |
model_id=MODEL_ID,
|
| 443 |
log_sequence_text=seq_data["sequence_text"],
|
| 444 |
user_prompt=analysis_prompt,
|
|
@@ -460,7 +447,7 @@ if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.sessi
|
|
| 460 |
}
|
| 461 |
st.session_state.batch_results.append(item)
|
| 462 |
|
| 463 |
-
# 結果顯示邏輯
|
| 464 |
with results_container:
|
| 465 |
st.subheader(f"Log/Alert #{item['log_id']} (序列分析完成)")
|
| 466 |
with st.expander(f"序列內容 (包含 {len(seq_data['sequence_text'].split('--- Log Index'))-1} 條 Log)"):
|
|
@@ -519,7 +506,7 @@ if st.session_state.batch_results and not st.session_state.execute_batch_analysi
|
|
| 519 |
for item in st.session_state.batch_results:
|
| 520 |
log_content_str_for_report = json.dumps(item["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
|
| 521 |
full_report_chunks.append(f"---\n\n### Log/Alert #{item['log_id']}\n\n#### 原始內容\n```json\n{log_content_str_for_report}\n```\n\n#### LLM 分析結果\n{item['analysis_result']}\n")
|
| 522 |
-
|
| 523 |
st.info(f"偵測到 {len(st.session_state.batch_results)} 條 Log 的歷史分析結果。")
|
| 524 |
st.download_button(
|
| 525 |
label="📥 下載上次的完整報告 (.md)",
|
|
|
|
| 7 |
import uuid
|
| 8 |
import time
|
| 9 |
import sys
|
| 10 |
+
# === HuggingFace 模型相關套件 (替換為 InferenceClient) ===
|
|
|
|
| 11 |
try:
|
| 12 |
+
from huggingface_hub import InferenceClient
|
| 13 |
+
# 移除本地模型相關導入
|
| 14 |
+
# from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 15 |
+
# import torch
|
|
|
|
|
|
|
| 16 |
except ImportError:
|
| 17 |
+
st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install huggingface-hub")
|
| 18 |
+
# InferenceClient = None # 保留 InferenceClient
|
| 19 |
+
|
|
|
|
| 20 |
# === LangChain/RAG 相關套件 (保持不變) ===
|
| 21 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 22 |
from langchain_core.documents import Document
|
| 23 |
from langchain_community.vectorstores import FAISS
|
| 24 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
| 25 |
from langchain_community.docstore.in_memory import InMemoryDocstore
|
|
|
|
| 26 |
# 嘗試匯入 pypdf
|
| 27 |
try:
|
| 28 |
import pypdf
|
|
|
|
| 31 |
|
| 32 |
# --- 頁面設定 ---
|
| 33 |
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
|
| 34 |
+
st.title("🛡️ Meta-Llama-3-8B-Instruct with FAISS RAG & Batch Analysis (Inference Client)")
|
| 35 |
+
st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face Inference Client (API)**。上傳 JSON 執行批量分析,上傳其他檔案作為 RAG 參考庫。")
|
| 36 |
|
| 37 |
+
# 設定模型 ID (替換為您指定的模型)
|
| 38 |
+
MODEL_ID = "fdtn-ai/Foundation-Sec-8B-Instruct"
|
| 39 |
WINDOW_SIZE = 8
|
| 40 |
|
| 41 |
# --- 側邊欄設定 ---
|
|
|
|
| 43 |
st.header("⚙️ 設定")
|
| 44 |
|
| 45 |
# === 替換為 Hugging Face 模型名稱顯示 (移除 API Key 輸入) ===
|
| 46 |
+
# ⚠️ 注意: HF Token 必須在環境變數 HF_TOKEN 中設定
|
| 47 |
+
if not os.environ.get("HF_TOKEN"):
|
| 48 |
+
st.error("環境變數 **HF_TOKEN** 未設定。請設定後重新啟動應用程式。")
|
| 49 |
+
|
| 50 |
+
st.info(f"LLM 模型:**{MODEL_ID}** (Hugging Face Inference API)")
|
| 51 |
+
st.warning("⚠️ **注意**: 該模型使用 Inference API 呼叫,請確保您的 HF Token 具有存取權限。")
|
| 52 |
|
| 53 |
st.divider()
|
| 54 |
|
|
|
|
| 77 |
|
| 78 |
if json_uploaded_file: # 移除 API Key 檢查
|
| 79 |
if st.button("🚀 執行批量分析"):
|
| 80 |
+
if not os.environ.get("HF_TOKEN"):
|
| 81 |
+
st.error("無法執行,環境變數 **HF_TOKEN** 未設定。")
|
| 82 |
+
else:
|
| 83 |
+
st.session_state.execute_batch_analysis = True
|
| 84 |
else:
|
| 85 |
st.info("請上傳 JSON 檔案以啟用批量分析按鈕。")
|
| 86 |
|
|
|
|
| 95 |
st.divider()
|
| 96 |
|
| 97 |
st.subheader("模型參數")
|
| 98 |
+
# Llama 3 使用 'system' 角色
|
| 99 |
+
system_prompt = st.text_area("System Prompt (LLM 使用)", value="You are a Senior Security Analyst, named Ernest. You provide expert, authoritative, and concise advice on Information Security, Network Security, and Cyber Threat Intelligence. Your analysis must be based strictly on the provided context.", height=100)
|
| 100 |
max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
|
| 101 |
temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
|
| 102 |
top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
|
|
|
|
| 110 |
|
| 111 |
# --- 初始化 Hugging Face LLM Client (重大替換) ---
|
| 112 |
@st.cache_resource
|
| 113 |
+
def load_inference_client(model_id):
|
| 114 |
+
if not os.environ.get("HF_TOKEN"):
|
|
|
|
| 115 |
return None
|
| 116 |
+
|
| 117 |
try:
|
| 118 |
+
# 使用 InferenceClient 替換 AutoModelForCausalLM 的載入
|
| 119 |
+
client = InferenceClient(
|
|
|
|
| 120 |
model_id,
|
| 121 |
+
token=os.environ.get("HF_TOKEN")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
+
st.success(f"Hugging Face Inference Client **{model_id}** 載入成功。")
|
| 124 |
+
return client
|
| 125 |
except Exception as e:
|
| 126 |
+
st.error(f"Hugging Face Inference Client 載入失敗: {e}")
|
| 127 |
return None
|
| 128 |
|
| 129 |
+
# 在 main 區塊外初始化 client
|
| 130 |
+
inference_client = None
|
| 131 |
+
if os.environ.get("HF_TOKEN"):
|
| 132 |
+
with st.spinner(f"正在連線到 Inference Client: {MODEL_ID}..."):
|
| 133 |
+
inference_client = load_inference_client(MODEL_ID)
|
| 134 |
|
| 135 |
+
if inference_client is None and os.environ.get("HF_TOKEN"):
|
| 136 |
+
st.warning("Hugging Face Inference Client 無法連線。請檢查您的 HF Token 和模型存取權限。")
|
| 137 |
+
elif not os.environ.get("HF_TOKEN"):
|
| 138 |
+
st.error("請在環境變數中設定 HF_TOKEN 以啟用 LLM。")
|
| 139 |
# =======================================================================
|
| 140 |
|
|
|
|
| 141 |
# === Embedding 模型 (用於 RAG 參考庫) (保持不變) ===
|
| 142 |
@st.cache_resource
|
| 143 |
def load_embedding_model():
|
|
|
|
| 160 |
|
| 161 |
# === 建立向量庫 / Search 函數 (保持不變) ===
|
| 162 |
def process_file_to_faiss(uploaded_file):
|
| 163 |
+
# 函數內容保持不變 (與原代碼相同)
|
| 164 |
text_content = ""
|
| 165 |
try:
|
| 166 |
if uploaded_file.type == "application/pdf":
|
|
|
|
| 173 |
else:
|
| 174 |
stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
|
| 175 |
text_content = stringio.read()
|
| 176 |
+
|
| 177 |
if not text_content.strip():
|
| 178 |
return None, "File is empty"
|
| 179 |
+
|
| 180 |
# 嘗試以 </Event> 分割 Log,否則以換行符分割
|
| 181 |
events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]
|
| 182 |
if len(events) <= 1:
|
| 183 |
events = [line for line in text_content.split("\n") if line.strip()]
|
| 184 |
+
|
| 185 |
docs = [Document(page_content=e) for e in events]
|
| 186 |
+
|
| 187 |
if not docs:
|
| 188 |
return None, "No documents created"
|
| 189 |
+
|
| 190 |
embeddings = embedding_model.embed_documents([d.page_content for d in docs])
|
| 191 |
embeddings_np = np.array(embeddings).astype("float32")
|
| 192 |
faiss.normalize_L2(embeddings_np) # L2 正規化
|
| 193 |
+
|
| 194 |
dimension = embeddings_np.shape[1]
|
| 195 |
index = faiss.IndexFlatIP(dimension) # IndexFlatIP (內積)
|
| 196 |
index.add(embeddings_np)
|
| 197 |
+
|
| 198 |
doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
|
| 199 |
docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
|
| 200 |
index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
|
| 201 |
+
|
| 202 |
vector_store = FAISS(
|
| 203 |
embedding_function=embedding_model,
|
| 204 |
index=index,
|
|
|
|
| 206 |
index_to_docstore_id=index_to_docstore_id,
|
| 207 |
distance_strategy=DistanceStrategy.COSINE # 使用 Cosine 距離 (對應 IndexFlatIP)
|
| 208 |
)
|
| 209 |
+
|
| 210 |
return vector_store, f"{len(docs)} chunks created."
|
| 211 |
except Exception as e:
|
| 212 |
return None, f"Error: {str(e)}"
|
| 213 |
|
| 214 |
def faiss_cosine_search_all(vector_store, query, threshold):
|
| 215 |
+
# 函數內容保持不變 (與原代碼相同)
|
| 216 |
q_emb = embedding_model.embed_query(query)
|
| 217 |
q_emb = np.array([q_emb]).astype("float32")
|
| 218 |
faiss.normalize_L2(q_emb)
|
| 219 |
+
|
| 220 |
index = vector_store.index
|
| 221 |
D, I = index.search(q_emb, k=index.ntotal)
|
| 222 |
+
|
| 223 |
selected = []
|
| 224 |
for score, idx in zip(D[0], I[0]):
|
| 225 |
if idx == -1: continue
|
|
|
|
| 228 |
doc_id = vector_store.index_to_docstore_id[idx]
|
| 229 |
doc = vector_store.docstore.search(doc_id)
|
| 230 |
selected.append((doc, score))
|
| 231 |
+
|
| 232 |
selected.sort(key=lambda x: x[1], reverse=True)
|
| 233 |
return selected
|
| 234 |
|
| 235 |
+
# === Hugging Face 生成單一 Log 分析回答 (核心批量處理函數 - 重大替換為 InferenceClient) ===
|
| 236 |
+
def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
|
| 237 |
"""
|
| 238 |
+
使用 Hugging Face Inference Client 執行 RAG 增強的 Log 序列分析。
|
| 239 |
"""
|
| 240 |
+
if client is None:
|
| 241 |
+
return "ERROR: Hugging Face Inference Client 未載入或 HF_TOKEN 未設定。", ""
|
| 242 |
+
|
| 243 |
context_text = ""
|
| 244 |
+
# 1. RAG 檢索邏輯 (保持不變)
|
| 245 |
if vector_store:
|
| 246 |
selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
|
| 247 |
if selected:
|
|
|
|
| 250 |
for i, (doc, score) in enumerate(selected[:5]) # 限制檢索結果數量
|
| 251 |
]
|
| 252 |
context_text = "\n".join(retrieved_contents)
|
| 253 |
+
|
| 254 |
+
# 2. 建構 Llama 3 的 ChatML 格式的 Messages 列表
|
| 255 |
+
rag_instruction = f"""=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ==={context_text if context_text else 'No relevant reference context found.'}=== END REFERENCE CONTEXT ===
|
| 256 |
+
|
|
|
|
| 257 |
ANALYSIS INSTRUCTION: {user_prompt}
|
| 258 |
Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats. Focus on the **last log entry in the sequence** to determine its final criticality and priority, considering the preceding {WINDOW_SIZE} logs."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
+
log_content_section = f"""=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===\n{log_sequence_text}\n=== END LOG SEQUENCE ==="""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
+
# 整合 System Prompt、RAG、和 Log 內容到 messages 列表
|
| 263 |
+
# Llama 3 標準的 chat 格式
|
| 264 |
+
messages = [
|
| 265 |
+
{"role": "system", "content": sys_prompt},
|
| 266 |
+
{"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"}
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
# 3. 呼叫 Hugging Face Inference Client
|
| 270 |
try:
|
| 271 |
+
# 使用 client.chat_completion 替換 pipeline 呼叫
|
| 272 |
+
response_stream = client.chat_completion(
|
| 273 |
+
messages,
|
| 274 |
+
max_tokens=max_output_tokens,
|
| 275 |
temperature=temperature,
|
| 276 |
top_p=top_p,
|
| 277 |
+
stream=False, # 由於是批量分析,不啟用流式輸出,一次性獲得結果
|
|
|
|
| 278 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
+
# 處理 chat_completion 的輸出格式 (非流式)
|
| 281 |
+
if response_stream and response_stream.choices:
|
| 282 |
+
# chat_completion 在非流式下返回一個 ChatCompletionResponse
|
| 283 |
+
generated_text = response_stream.choices[0].message.content
|
| 284 |
+
return generated_text.strip(), context_text
|
| 285 |
+
else:
|
| 286 |
+
return f"Hugging Face Inference Client 輸出格式錯誤: {response_stream}", context_text
|
| 287 |
+
|
| 288 |
except Exception as e:
|
| 289 |
# 如果模型呼叫失敗,回傳詳細錯誤訊息
|
| 290 |
return f"Hugging Face Model Error: {str(e)}", context_text
|
| 291 |
|
|
|
|
| 292 |
# === 檔案處理和主執行邏輯 (保持結構,替換 LLM 呼叫) ===
|
| 293 |
+
# 初始化 Session State (保持不變)
|
| 294 |
if 'execute_batch_analysis' not in st.session_state:
|
| 295 |
st.session_state.execute_batch_analysis = False
|
| 296 |
if 'batch_results' not in st.session_state:
|
| 297 |
st.session_state.batch_results = None
|
| 298 |
+
|
| 299 |
+
# --- 1. 處理 RAG 知識庫檔案 (rag_uploaded_file) --- (保持不變)
|
| 300 |
if 'rag_current_file_key' not in st.session_state:
|
| 301 |
st.session_state.rag_current_file_key = None
|
| 302 |
|
|
|
|
| 318 |
del st.session_state.vector_store
|
| 319 |
del st.session_state.rag_current_file_key
|
| 320 |
st.info("RAG 檔案已移除,已清除相關知識庫。")
|
| 321 |
+
|
| 322 |
+
# --- 2. 處理 JSON 批量分析檔案 (json_uploaded_file) --- (保持不變)
|
| 323 |
if 'json_current_file_key' not in st.session_state:
|
| 324 |
st.session_state.json_current_file_key = None
|
| 325 |
|
|
|
|
| 333 |
st.session_state.json_data_for_batch = json_data
|
| 334 |
st.session_state.json_current_file_key = json_file_key
|
| 335 |
st.toast("JSON Log 檔案已載入,請按 '執行批量分析'。", icon="📄")
|
| 336 |
+
|
| 337 |
except Exception as e:
|
| 338 |
st.error(f"JSON 檔案解析錯誤: {e}")
|
| 339 |
if 'json_data_for_batch' in st.session_state:
|
| 340 |
del st.session_state.json_data_for_batch
|
|
|
|
| 341 |
# 檔案移除/狀態清理 (如果使用者移除了 JSON 檔案)
|
| 342 |
elif 'json_data_for_batch' in st.session_state:
|
| 343 |
del st.session_state.json_data_for_batch
|
|
|
|
| 352 |
start_time = time.time() # 開始計時
|
| 353 |
st.session_state.batch_results = []
|
| 354 |
|
| 355 |
+
if inference_client is None:
|
| 356 |
+
st.error("Hugging Face Inference Client 未載入,請檢查 HF_TOKEN 和網路連線,無法執行批量分析。")
|
|
|
|
| 357 |
st.session_state.execute_batch_analysis = False
|
| 358 |
+
|
| 359 |
data_to_process = st.session_state.json_data_for_batch
|
| 360 |
|
| 361 |
# 提取 Log 列表的邏輯 (保持不變)
|
|
|
|
| 373 |
logs_list = [data_to_process]
|
| 374 |
else:
|
| 375 |
logs_list = [data_to_process]
|
| 376 |
+
|
| 377 |
if logs_list:
|
| 378 |
vs = st.session_state.get("vector_store", None)
|
| 379 |
if vs:
|
|
|
|
| 381 |
else:
|
| 382 |
st.warning("⚠️ RAG 知識庫未載入,將單純執行 Log 分析。")
|
| 383 |
|
| 384 |
+
# --- 新增:創建平移視窗序列 --- (保持不變)
|
|
|
|
| 385 |
# 將所有 Log 轉換為 JSON 格式化字串列表,以便後續拼接
|
| 386 |
formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
|
| 387 |
|
|
|
|
| 423 |
progress_bar.progress((i + 1) / total_sequences, text=f"已處理 {i + 1}/{total_sequences} 個序列 (目標 Log #{log_id})...")
|
| 424 |
|
| 425 |
try:
|
| 426 |
+
# *** 替換為 Inference Client 呼叫函數 ***
|
| 427 |
response, retrieved_ctx = generate_rag_response_hf_for_log(
|
| 428 |
+
client=inference_client, # <--- 新的 Inference Client
|
| 429 |
model_id=MODEL_ID,
|
| 430 |
log_sequence_text=seq_data["sequence_text"],
|
| 431 |
user_prompt=analysis_prompt,
|
|
|
|
| 447 |
}
|
| 448 |
st.session_state.batch_results.append(item)
|
| 449 |
|
| 450 |
+
# 結果顯示邏輯 (保持不變)
|
| 451 |
with results_container:
|
| 452 |
st.subheader(f"Log/Alert #{item['log_id']} (序列分析完成)")
|
| 453 |
with st.expander(f"序列內容 (包含 {len(seq_data['sequence_text'].split('--- Log Index'))-1} 條 Log)"):
|
|
|
|
| 506 |
for item in st.session_state.batch_results:
|
| 507 |
log_content_str_for_report = json.dumps(item["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
|
| 508 |
full_report_chunks.append(f"---\n\n### Log/Alert #{item['log_id']}\n\n#### 原始內容\n```json\n{log_content_str_for_report}\n```\n\n#### LLM 分析結果\n{item['analysis_result']}\n")
|
| 509 |
+
|
| 510 |
st.info(f"偵測到 {len(st.session_state.batch_results)} 條 Log 的歷史分析結果。")
|
| 511 |
st.download_button(
|
| 512 |
label="📥 下載上次的完整報告 (.md)",
|