|
|
|
|
|
import streamlit as st |
|
|
import os |
|
|
import io |
|
|
import json |
|
|
import csv |
|
|
import numpy as np |
|
|
import faiss |
|
|
import uuid |
|
|
import time |
|
|
import sys |
|
|
|
|
|
|
|
|
try: |
|
|
from huggingface_hub import InferenceClient |
|
|
except ImportError: |
|
|
st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install huggingface-hub") |
|
|
|
|
|
|
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_core.documents import Document |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_community.vectorstores.utils import DistanceStrategy |
|
|
from langchain_community.docstore.in_memory import InMemoryDocstore |
|
|
|
|
|
|
|
|
try: |
|
|
import pypdf |
|
|
except ImportError: |
|
|
pypdf = None |
|
|
|
|
|
|
|
|
# --- App configuration, constants, and session-state defaults ---

# LLM served through the Hugging Face Inference API, and the sliding-window
# size used when assembling log sequences for batch analysis.
MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
WINDOW_SIZE = 8

# st.set_page_config must be the first Streamlit command executed.
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")

# Fix: the title previously hard-coded "Meta-Llama-3-8B-Instruct", which did
# not match MODEL_ID; derive the displayed model name from MODEL_ID instead.
st.title(f"🛡️ {MODEL_ID.split('/')[-1]} with FAISS RAG & Batch Analysis (Inference Client)")

st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face Inference Client (API)**。支援 JSON/CSV/TXT 執行批量分析。")

# Initialise every session-state key exactly once per session so that the
# rest of the script can rely on the keys being present.
_SESSION_DEFAULTS = {
    'execute_batch_analysis': False,   # armed by the sidebar "run" button
    'batch_results': None,             # accumulated per-log analysis results
    'rag_current_file_key': None,      # identity (name+size) of RAG upload
    'batch_current_file_key': None,    # identity (name+size) of batch upload
    'vector_store': None,              # FAISS store built from the RAG file
    'json_data_for_batch': None,       # parsed log entries awaiting analysis
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
|
# --- Sidebar: configuration, file uploads, prompts, and model parameters ---
with st.sidebar:
    st.header("⚙️ 設定")

    # The Hugging Face Inference API requires HF_TOKEN in the environment.
    if not os.environ.get("HF_TOKEN"):
        st.error("環境變數 **HF_TOKEN** 未設定。請設定後重新啟動應用程式。")

    st.info(f"LLM 模型:**{MODEL_ID}** (Hugging Face Inference API)")
    st.warning("⚠️ **注意**: 該模型使用 Inference API 呼叫,請確保您的 HF Token 具有存取權限。")

    st.divider()
    st.subheader("📂 檔案上傳")

    # Upload 1: the log/alert file that will be analysed entry-by-entry.
    batch_uploaded_file = st.file_uploader(
        "1️⃣ 上傳 **Log/Alert 檔案** (用於批量分析)",
        type=['json', 'csv', 'txt'],
        key="batch_uploader",
        help="支援 JSON (Array), CSV (含標題), TXT (每行一條 Log)"
    )

    # Upload 2: reference material indexed into FAISS for RAG retrieval.
    rag_uploaded_file = st.file_uploader(
        "2️⃣ 上傳 **RAG 參考知識庫** (Logs/PDF/Code 等)",
        type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],
        key="rag_uploader"
    )

    st.divider()

    st.subheader("💡 批量分析指令")
    # Instruction applied independently to every log entry in a batch run.
    analysis_prompt = st.text_area(
        "針對每個 Log/Alert 執行的指令",
        value="You are a security expert in charge of analyzing alerts related to Web Application Attacks and Brute Force & Reconnaissance. Respond with a clear, structured analysis using the following mandatory sections: \n\n- Priority: Provide the overall priority level. (Answer High risk, Medium risk, or Low risk only) \n- Explanation: If this alert is highly related to Web Application Attacks and Brute Force & Reconnaissance, explain the potential impact and why this specific alert requires attention. If not, **omit the explanation section**. \n- Action Plan: If this alert is highly related to Web Application Attacks and Brute Force & Reconnaissance, What should be the immediate steps to address this specific alert? If not, **omit the action plan section**. \n\nStrictly use the information in the provided Log.",
        height=200
    )
    st.markdown("此指令將對檔案中的**每一個 Log 條目**執行一次獨立分析。")

    # The run button only arms a session-state flag; the actual analysis
    # happens later in the main body of the script, after the sidebar.
    if batch_uploaded_file:
        if st.button("🚀 執行批量分析"):
            if not os.environ.get("HF_TOKEN"):
                st.error("無法執行,環境變數 **HF_TOKEN** 未設定。")
            else:
                st.session_state.execute_batch_analysis = True
    else:
        st.info("請上傳 Log 檔案以啟用批量分析按鈕。")

    st.divider()
    st.subheader("🔍 RAG 檢索設定")
    # Minimum cosine similarity for a chunk to be used as RAG context.
    similarity_threshold = st.slider("📐 Cosine Similarity 門檻", 0.0, 1.0, 0.4, 0.01)

    st.divider()
    st.subheader("模型參數")
    system_prompt = st.text_area("System Prompt", value="You are a Senior Security Analyst, named Ernest. You provide expert, authoritative, and concise advice on Information Security. Your analysis must be based strictly on the provided context.", height=100)
    max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
    temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
    top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)

    st.divider()
    # Wipes the entire session state (results, vector store, file keys).
    if st.button("🗑️ 清除所有紀錄"):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()
|
|
|
|
|
|
|
|
@st.cache_resource
def load_inference_client(model_id):
    """Create (and cache across reruns) an InferenceClient for *model_id*.

    Returns None when HF_TOKEN is absent from the environment or when the
    client cannot be constructed; failures are surfaced in the UI.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return None
    try:
        hf_client = InferenceClient(model_id, token=hf_token)
        st.success(f"Hugging Face Inference Client **{model_id}** 載入成功。")
        return hf_client
    except Exception as exc:
        st.error(f"Hugging Face Inference Client 載入失敗: {exc}")
        return None
|
|
|
|
|
# --- Connect to the Hugging Face Inference API (requires HF_TOKEN) ---
inference_client = None
_token_present = bool(os.environ.get("HF_TOKEN"))

if _token_present:
    with st.spinner(f"正在連線到 Inference Client: {MODEL_ID}..."):
        inference_client = load_inference_client(MODEL_ID)

# Surface the connection outcome: a warning when the token exists but the
# client could not be built, an error when the token is missing entirely.
if _token_present and inference_client is None:
    st.warning("Hugging Face Inference Client 無法連線。")
elif not _token_present:
    st.error("請在環境變數中設定 HF_TOKEN。")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_embedding_model():
    """Load and cache the BGE embedding model used for all RAG embeddings.

    Runs on CPU. Embeddings are NOT normalised here — L2 normalisation is
    applied explicitly via faiss.normalize_L2 before indexing and search.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-zh-v1.5",
        model_kwargs={'device': 'cpu', 'trust_remote_code': True},
        encode_kwargs={'normalize_embeddings': False},
    )
|
|
|
|
|
# Load the embedding model eagerly at startup (cached across reruns).
with st.spinner("正在載入 Embedding 模型..."):
    embedding_model = load_embedding_model()
|
|
|
|
|
|
|
|
def process_file_to_faiss(uploaded_file):
    """Build an in-memory FAISS vector store from an uploaded reference file.

    Each non-empty line of the file (or of the extracted PDF text) becomes
    one Document/chunk. Embeddings are L2-normalised and indexed with
    IndexFlatIP, so inner-product scores equal cosine similarity.

    Returns:
        (vector_store, message) on success, or (None, error_message) on any
        failure — parsing errors are caught and reported, never raised.
    """
    text_content = ""
    try:
        if uploaded_file.type == "application/pdf":
            if pypdf:
                pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    # Fix: extract_text() may return None (e.g. image-only
                    # pages), which previously raised a TypeError here.
                    text_content += (page.extract_text() or "") + "\n"
            else: return None, "PDF library missing"
        else:
            stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
            text_content = stringio.read()

        if not text_content.strip(): return None, "File is empty"

        # One document per non-empty line (log files are line-oriented).
        events = [line for line in text_content.splitlines() if line.strip()]
        docs = [Document(page_content=e) for e in events]
        if not docs: return None, "No documents created"

        embeddings = embedding_model.embed_documents([d.page_content for d in docs])
        embeddings_np = np.array(embeddings).astype("float32")
        # Normalise so IndexFlatIP inner products are cosine similarities.
        faiss.normalize_L2(embeddings_np)

        dimension = embeddings_np.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings_np)

        # Wire the raw FAISS index into LangChain's FAISS wrapper by hand so
        # we keep full control over index type and normalisation.
        doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
        index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}

        vector_store = FAISS(embedding_function=embedding_model, index=index, docstore=docstore, index_to_docstore_id=index_to_docstore_id, distance_strategy=DistanceStrategy.COSINE)
        return vector_store, f"{len(docs)} chunks created."
    except Exception as e:
        return None, f"Error: {str(e)}"
|
|
|
|
|
def faiss_cosine_search_all(vector_store, query, threshold):
    """Return every stored doc whose cosine similarity to *query* >= threshold.

    Searches the entire index (k = ntotal) and returns a list of
    (Document, score) tuples sorted by descending similarity.
    """
    index = vector_store.index
    # Fix: guard the empty index — FAISS search with k=0 is invalid.
    if index.ntotal == 0:
        return []

    q_emb = embedding_model.embed_query(query)
    q_emb = np.array([q_emb]).astype("float32")
    # Match the normalisation applied at indexing time so IP == cosine.
    faiss.normalize_L2(q_emb)

    D, I = index.search(q_emb, k=index.ntotal)
    selected = []
    for score, idx in zip(D[0], I[0]):
        if idx == -1: continue  # FAISS pads missing results with -1
        if score >= threshold:
            doc_id = vector_store.index_to_docstore_id[idx]
            doc = vector_store.docstore.search(doc_id)
            selected.append((doc, score))
    selected.sort(key=lambda x: x[1], reverse=True)
    return selected
|
|
|
|
|
|
|
|
def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p):
    """Analyze one log sequence with the LLM, optionally augmented by RAG.

    Retrieves the top-5 reference chunks above *threshold* (when a vector
    store is available), builds a structured chat prompt, and performs a
    non-streaming chat completion.

    Returns (analysis_text, retrieved_context); on failure the first element
    is an error string. NOTE: *model_id* is unused — the model is bound to
    *client* at construction time; the parameter is kept for interface
    compatibility with existing callers.
    """
    if client is None: return "ERROR: Client Error", ""

    context_text = ""
    if vector_store:
        selected = faiss_cosine_search_all(vector_store, log_sequence_text, threshold)
        if selected:
            # Keep only the 5 most similar chunks to bound prompt size.
            retrieved_contents = [f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}" for doc, score in selected[:5]]
            context_text = "\n".join(retrieved_contents)

    # Fix: the section markers were previously concatenated without any
    # newlines, garbling the prompt structure sent to the model.
    rag_instruction = (
        f"=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ===\n"
        f"{context_text if context_text else 'No relevant reference context found.'}\n"
        "=== END REFERENCE CONTEXT ===\n"
        f"ANALYSIS INSTRUCTION: {user_prompt}\n"
        "Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** to detect any continuous attack chains or evolving threats."
    )
    log_content_section = f"=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===\n{log_sequence_text}\n=== END LOG SEQUENCE ==="

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"}
    ]
    try:
        # Non-streaming call: one complete response object per sequence.
        response = client.chat_completion(messages, max_tokens=max_output_tokens, temperature=temperature, top_p=top_p, stream=False)
        if response and response.choices:
            return response.choices[0].message.content.strip(), context_text
        else: return "Format Error", context_text
    except Exception as e: return f"Model Error: {str(e)}", context_text
|
|
|
|
|
|
|
|
|
|
|
# --- RAG knowledge-base lifecycle: (re)build on upload, clear on removal ---
if rag_uploaded_file:
    # Rebuild only when a different file (identified by name+size) arrives
    # or when no vector store has been built yet.
    file_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
    if st.session_state.rag_current_file_key != file_key or st.session_state.get('vector_store') is None:
        with st.spinner(f"正在建立 RAG 參考知識庫 ({rag_uploaded_file.name})..."):
            vs, msg = process_file_to_faiss(rag_uploaded_file)
            if vs:
                st.session_state.vector_store = vs
                st.session_state.rag_current_file_key = file_key
                st.toast(f"RAG 參考知識庫已更新!{msg}", icon="✅")
            else: st.error(msg)
elif st.session_state.get('vector_store') is not None:
    # Fix: the previous check ('vector_store' in st.session_state) was always
    # true because the key is initialised to None at startup, so this cleanup
    # and its info message fired on every rerun even when no RAG file had
    # ever been uploaded. Checking the value clears the store exactly once.
    st.session_state.vector_store = None
    st.session_state.rag_current_file_key = None
    st.info("RAG 檔案已移除,已清除相關知識庫。")
|
|
|
|
|
|
|
|
|
|
|
# --- Batch log file lifecycle: parse on upload, clear on removal ---
if batch_uploaded_file:
    batch_file_key = f"batch_{batch_uploaded_file.name}_{batch_uploaded_file.size}"

    # Re-parse only when a different file (name+size) arrives or no parsed
    # data is currently held.
    if st.session_state.batch_current_file_key != batch_file_key or st.session_state.get('json_data_for_batch') is None:
        try:
            stringio = io.StringIO(batch_uploaded_file.getvalue().decode("utf-8"))
            parsed_data = None
            name_lower = batch_uploaded_file.name.lower()

            if name_lower.endswith('.json'):
                parsed_data = json.load(stringio)
                st.toast("JSON 檔案載入成功", icon="📄")
            elif name_lower.endswith('.csv'):
                # First row is the header; each subsequent row becomes a dict.
                reader = csv.DictReader(stringio)
                parsed_data = list(reader)
                st.toast("CSV 檔案已轉換為 JSON 結構", icon="📊")
            else:
                # TXT: one log entry per non-empty line.
                lines = stringio.readlines()
                parsed_data = [{"raw_log_entry": line.strip()} for line in lines if line.strip()]
                st.toast("TXT 檔案已轉換為 JSON 結構", icon="📝")

            st.session_state.json_data_for_batch = parsed_data
            st.session_state.batch_current_file_key = batch_file_key

        except Exception as e:
            st.error(f"檔案解析錯誤: {e}")
            # Drop the key entirely so the downstream membership check
            # ('json_data_for_batch' in st.session_state) skips analysis.
            if 'json_data_for_batch' in st.session_state:
                del st.session_state.json_data_for_batch

elif st.session_state.get('json_data_for_batch') is not None:
    # Fix: the previous membership test was always true because the key is
    # initialised to None at startup, so this cleanup and its info message
    # fired on every rerun even when no batch file had ever been uploaded.
    del st.session_state.json_data_for_batch
    del st.session_state.batch_current_file_key
    if "batch_results" in st.session_state:
        del st.session_state.batch_results
    st.info("批量分析檔案已移除,已清除相關數據。")
|
|
|
|
|
|
|
|
# --- Batch analysis: run the LLM over every log with a sliding window ---
if st.session_state.execute_batch_analysis and 'json_data_for_batch' in st.session_state:
    # Disarm immediately so a Streamlit rerun does not repeat the batch.
    st.session_state.execute_batch_analysis = False
    start_time = time.time()
    st.session_state.batch_results = []

    if inference_client is None:
        st.error("Client 未連線,無法執行。")
    else:
        data_to_process = st.session_state.json_data_for_batch
        logs_list = []

        # Normalise the parsed payload into a flat list of log entries:
        # accept a bare list, a dict with an 'alerts'/'logs' list, or wrap
        # anything else as a single-entry list.
        if isinstance(data_to_process, list):
            logs_list = data_to_process
        elif isinstance(data_to_process, dict):
            if 'alerts' in data_to_process and isinstance(data_to_process['alerts'], list):
                logs_list = data_to_process['alerts']
            elif 'logs' in data_to_process and isinstance(data_to_process['logs'], list):
                logs_list = data_to_process['logs']
            else:
                logs_list = [data_to_process]
        else:
            logs_list = [data_to_process]

        if logs_list:
            # Vector store may be None when no RAG file was uploaded.
            vs = st.session_state.get("vector_store", None)

            formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]

            # Build one analysis sequence per log: a window of up to
            # WINDOW_SIZE entries ending at (and targeting) log i.
            analysis_sequences = []
            for i in range(len(formatted_logs)):
                start_index = max(0, i - WINDOW_SIZE + 1)
                end_index = i + 1
                current_window = formatted_logs[start_index:end_index]
                sequence_text = []
                for j, log_str in enumerate(current_window):
                    # The last entry of each window is the one being analysed.
                    is_target = " <<< TARGET LOG TO ANALYZE" if j == len(current_window) - 1 else ""
                    # NOTE(review): the "Log Index" shown here is offset by
                    # one from target_log_id (i+1), and the "(N prior logs)"
                    # count labels the target itself as "1 prior logs" —
                    # looks like an off-by-one in the labels; confirm intent.
                    sequence_text.append(f"--- Log Index {i - len(current_window) + j + 1} ({len(current_window)-j} prior logs){is_target} ---\n{log_str}")
                analysis_sequences.append({
                    "sequence_text": "\n\n".join(sequence_text),
                    "target_log_id": i + 1,
                    "original_log_entry": logs_list[i]
                })

            total_sequences = len(analysis_sequences)
            st.header(f"⚡ 批量分析執行中 (平移視窗 $N={WINDOW_SIZE}$)...")
            progress_bar = st.progress(0, text=f"準備處理 {total_sequences} 個序列...")
            results_container = st.container()
            full_report_chunks = ["## Cybersecurity Batch Analysis Report\n\n"]

            # Analyse each sequence, render the result live, and collect it
            # both for session history and the downloadable report.
            for i, seq_data in enumerate(analysis_sequences):
                log_id = seq_data["target_log_id"]
                progress_bar.progress((i + 1) / total_sequences, text=f"Processing {i + 1}/{total_sequences} (Log #{log_id})...")

                try:
                    response, retrieved_ctx = generate_rag_response_hf_for_log(
                        client=inference_client,
                        model_id=MODEL_ID,
                        log_sequence_text=seq_data["sequence_text"],
                        user_prompt=analysis_prompt,
                        sys_prompt=system_prompt,
                        vector_store=vs,
                        threshold=similarity_threshold,
                        max_output_tokens=max_output_tokens,
                        temperature=temperature,
                        top_p=top_p
                    )
                    item = {
                        "log_id": log_id,
                        "log_content": seq_data["original_log_entry"],
                        "sequence_analyzed": seq_data["sequence_text"],
                        "analysis_result": response,
                        "context": retrieved_ctx
                    }
                    st.session_state.batch_results.append(item)

                    with results_container:
                        st.subheader(f"Log/Alert #{item['log_id']}")
                        with st.expander("序列內容 (JSON Format)"):
                            st.code(item["sequence_analyzed"], language='json')

                        # Highlight high-risk verdicts in red, others in blue.
                        is_high = any(x in response.lower() for x in ['high risk'])
                        if is_high: st.error(item['analysis_result'])
                        else: st.info(item['analysis_result'])
                        if item['context']:
                            with st.expander("參考 RAG 片段"): st.code(item['context'])
                        st.markdown("---")

                    # Escape backticks so the log cannot break the ```json fence.
                    log_content_str_for_report = json.dumps(item["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
                    full_report_chunks.append(f"---\n\n### Log #{item['log_id']}\n```json\n{log_content_str_for_report}\n```\nResult:\n{item['analysis_result']}\n")

                except Exception as e:
                    # A failure on one log does not abort the batch.
                    st.error(f"Error Log {log_id}: {e}")

            end_time = time.time()
            progress_bar.empty()
            st.success(f"完成!耗時 {end_time - start_time:.2f} 秒。")
        else:
            st.error("無法提取有效 Log,請檢查檔案格式。")
|
|
|
|
|
|
|
|
# --- Download past batch-analysis results as a Markdown report ---
if st.session_state.get("batch_results") and not st.session_state.execute_batch_analysis:
    st.header("⚡ 歷史分析結果")

    def _report_section(entry):
        # Escape backticks so log content cannot break the ```json fence.
        escaped_log = json.dumps(entry["log_content"], indent=2, ensure_ascii=False).replace("`", "\\`")
        return f"---\n\n### Log #{entry['log_id']}\n```json\n{escaped_log}\n```\n{entry['analysis_result']}\n"

    report_md = "\n".join(["## Report\n\n"] + [_report_section(e) for e in st.session_state.batch_results])
    st.download_button("📥 下載完整報告 (.md)", report_md, "report.md", "text/markdown")