import csv  # used to parse CSV batch files
import io
import json
import os
import re
import sys
import time
import uuid

import faiss
import numpy as np
import streamlit as st

# === Hugging Face model client (InferenceClient) ===
# Streamlit expects st.set_page_config to be the first st.* call, so only record the
# import failure here; the error is rendered right after the page is configured.
try:
    from huggingface_hub import InferenceClient
    _HF_IMPORT_FAILED = False
except ImportError:
    InferenceClient = None
    _HF_IMPORT_FAILED = True

# === LangChain / RAG packages ===
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.docstore.in_memory import InMemoryDocstore

# Optional PDF support for the RAG knowledge base.
try:
    import pypdf
except ImportError:
    pypdf = None

# Chat model served through the Hugging Face Inference API.
MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
# Sliding-window length: the target log plus up to WINDOW_SIZE - 1 logs that precede it.
WINDOW_SIZE = 8
# Maximum number of retrieved reference chunks injected into one prompt.
MAX_CONTEXT_CHUNKS = 5
HF_TOKEN = os.environ.get("HF_TOKEN")

# --- Page setup ---
st.set_page_config(page_title="Cybersecurity AI Assistant (Hugging Face RAG & Batch Analysis)", page_icon="🛡️", layout="wide")
# Derive the title from MODEL_ID so it always names the model actually in use.
st.title(f"🛡️ {MODEL_ID.split('/')[-1]} with FAISS RAG & Batch Analysis (Inference Client)")
st.markdown("已啟用:**IndexFlatIP** + **L2 正規化** + **Hugging Face Inference Client (API)**。支援 JSON/CSV/TXT 執行批量分析。")
if _HF_IMPORT_FAILED:
    st.error("請檢查是否安裝了所有 Hugging Face 相關依賴:pip install huggingface-hub")

# Session-state defaults. Every key always exists after this loop, so "has data"
# checks below compare against None instead of testing key membership.
_SESSION_DEFAULTS = {
    "execute_batch_analysis": False,
    "batch_results": None,
    "rag_current_file_key": None,
    "batch_current_file_key": None,
    "vector_store": None,
    # Name kept for compatibility; holds the parsed batch data whatever the source format.
    "json_data_for_batch": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# --- Sidebar settings ---
with st.sidebar:
    st.header("⚙️ 設定")
    if not HF_TOKEN:
        st.error("環境變數 **HF_TOKEN** 未設定。請設定後重新啟動應用程式。")
    st.info(f"LLM 模型:**{MODEL_ID}** (Hugging Face Inference API)")
    st.warning("⚠️ **注意**: 該模型使用 Inference API 呼叫,請確保您的 HF Token 具有存取權限。")
    st.divider()

    st.subheader("📂 檔案上傳")
    # === 1. Batch analysis file (JSON / CSV / TXT) ===
    batch_uploaded_file = st.file_uploader(
        "1️⃣ 上傳 **Log/Alert 檔案** (用於批量分析)",
        type=['json', 'csv', 'txt'],
        key="batch_uploader",
        help="支援 JSON (Array), CSV (含標題), TXT (每行一條 Log)",
    )
    # === 2. RAG knowledge-base file ===
    rag_uploaded_file = st.file_uploader(
        "2️⃣ 上傳 **RAG 參考知識庫** (Logs/PDF/Code 等)",
        type=['txt', 'py', 'log', 'csv', 'md', 'pdf'],
        key="rag_uploader",
    )
    st.divider()

    st.subheader("💡 批量分析指令")
    analysis_prompt = st.text_area(
        "針對每個 Log/Alert 執行的指令",
        value=(
            "You are a security expert in charge of analyzing alerts related to Web Application Attacks and Brute Force & Reconnaissance. "
            "Respond with a clear, structured analysis using the following mandatory sections: \n\n"
            "- Priority: Provide the overall priority level. (Answer High risk, Medium risk, or Low risk only) \n"
            "- Explanation: If this alert is highly related to Web Application Attacks and Brute Force & Reconnaissance, "
            "explain the potential impact and why this specific alert requires attention. If not, **omit the explanation section**. \n"
            "- Action Plan: If this alert is highly related to Web Application Attacks and Brute Force & Reconnaissance, "
            "What should be the immediate steps to address this specific alert? If not, **omit the action plan section**. \n\n"
            "Strictly use the information in the provided Log."
        ),
        height=200,
    )
    st.markdown("此指令將對檔案中的**每一個 Log 條目**執行一次獨立分析。")
    if batch_uploaded_file:
        if st.button("🚀 執行批量分析"):
            if not HF_TOKEN:
                st.error("無法執行,環境變數 **HF_TOKEN** 未設定。")
            else:
                st.session_state.execute_batch_analysis = True
    else:
        st.info("請上傳 Log 檔案以啟用批量分析按鈕。")
    st.divider()

    st.subheader("🔍 RAG 檢索設定")
    similarity_threshold = st.slider("📐 Cosine Similarity 門檻", 0.0, 1.0, 0.4, 0.01)
    st.divider()

    st.subheader("模型參數")
    system_prompt = st.text_area(
        "System Prompt",
        value=(
            "You are a Senior Security Analyst, named Ernest. "
            "You provide expert, authoritative, and concise advice on Information Security. "
            "Your analysis must be based strictly on the provided context."
        ),
        height=100,
    )
    max_output_tokens = st.slider("Max Output Tokens", 128, 4096, 2048, 128)
    temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.1)
    top_p = st.slider("Top P", 0.1, 1.0, 0.95, 0.05)
    st.divider()
    if st.button("🗑️ 清除所有紀錄"):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()


# --- Hugging Face LLM client ---
@st.cache_resource
def load_inference_client(model_id):
    """Create an InferenceClient bound to *model_id* (cached across reruns).

    Returns None when HF_TOKEN is unset or huggingface_hub could not be imported;
    also returns None (after rendering an error) if construction raises.
    """
    token = os.environ.get("HF_TOKEN")
    if not token or InferenceClient is None:
        return None
    try:
        client = InferenceClient(model_id, token=token)
        st.success(f"Hugging Face Inference Client **{model_id}** 載入成功。")
        return client
    except Exception as e:
        st.error(f"Hugging Face Inference Client 載入失敗: {e}")
        return None


inference_client = None
if HF_TOKEN:
    with st.spinner(f"正在連線到 Inference Client: {MODEL_ID}..."):
        inference_client = load_inference_client(MODEL_ID)
    if inference_client is None:
        st.warning("Hugging Face Inference Client 無法連線。")
else:
    st.error("請在環境變數中設定 HF_TOKEN。")


# === Embedding model ===
@st.cache_resource
def load_embedding_model():
    """Load the BAAI/bge-large-zh-v1.5 embedding model on CPU (cached across reruns).

    Embeddings are not normalized by the model; callers L2-normalize them before
    using the inner-product FAISS index so that inner product equals cosine.
    """
    model_kwargs = {'device': 'cpu', 'trust_remote_code': True}
    encode_kwargs = {'normalize_embeddings': False}
    return HuggingFaceEmbeddings(model_name="BAAI/bge-large-zh-v1.5", model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)


with st.spinner("正在載入 Embedding 模型..."):
    embedding_model = load_embedding_model()


def _decode_upload(uploaded_file):
    """Return the uploaded file's bytes decoded as UTF-8, dropping a leading BOM if present."""
    # "utf-8-sig" decodes plain UTF-8 identically but also strips a byte-order mark,
    # which would otherwise make json.loads fail and corrupt the first CSV header.
    return uploaded_file.getvalue().decode("utf-8-sig")


# === Vector store construction / search ===
def process_file_to_faiss(uploaded_file):
    """Build a cosine-similarity FAISS vector store from an uploaded reference file.

    Each non-blank line of the file (or of the text extracted from a PDF) becomes one
    Document. Returns (vector_store, status_message) on success and
    (None, error_message) on failure; exceptions are converted to error messages.
    """
    try:
        if uploaded_file.type == "application/pdf":
            if not pypdf:
                return None, "PDF library missing"
            pdf_reader = pypdf.PdfReader(uploaded_file)
            # Defensive: treat a page whose extract_text() yields None as empty.
            text_content = "".join((page.extract_text() or "") + "\n" for page in pdf_reader.pages)
        else:
            text_content = _decode_upload(uploaded_file)
        if not text_content.strip():
            return None, "File is empty"

        docs = [Document(page_content=line) for line in text_content.splitlines() if line.strip()]
        if not docs:
            return None, "No documents created"

        embeddings_np = np.array(embedding_model.embed_documents([d.page_content for d in docs]), dtype="float32")
        # Normalize in place so inner product on IndexFlatIP equals cosine similarity.
        faiss.normalize_L2(embeddings_np)
        index = faiss.IndexFlatIP(embeddings_np.shape[1])
        index.add(embeddings_np)

        doc_ids = [str(uuid.uuid4()) for _ in docs]
        docstore = InMemoryDocstore(dict(zip(doc_ids, docs)))
        index_to_docstore_id = dict(enumerate(doc_ids))
        vector_store = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
            distance_strategy=DistanceStrategy.COSINE,
        )
        return vector_store, f"{len(docs)} chunks created."
    except Exception as e:
        return None, f"Error: {str(e)}"


def faiss_cosine_search_all(vector_store, query, threshold):
    """Return every (Document, score) in *vector_store* with cosine score >= threshold, best first.

    The query embedding is L2-normalized so the IndexFlatIP inner-product scores are
    cosine similarities. The whole index is searched (k = ntotal).
    """
    index = vector_store.index
    if index.ntotal == 0:
        return []
    q_emb = np.array([embedding_model.embed_query(query)], dtype="float32")
    faiss.normalize_L2(q_emb)
    scores, ids = index.search(q_emb, k=index.ntotal)
    selected = []
    for score, idx in zip(scores[0], ids[0]):
        # FAISS reports missing results with id -1.
        if idx == -1 or score < threshold:
            continue
        doc_id = vector_store.index_to_docstore_id[idx]
        selected.append((vector_store.docstore.search(doc_id), score))
    selected.sort(key=lambda pair: pair[1], reverse=True)
    return selected


# === Single log-window analysis ===
def generate_rag_response_hf_for_log(client, model_id, log_sequence_text, user_prompt, sys_prompt, vector_store, threshold, max_output_tokens, temperature, top_p, rag_query=None):
    """Ask the chat model to analyze one sliding window of logs, optionally grounded with RAG context.

    Args:
        client: InferenceClient already bound to the model, or None.
        model_id: Model identifier; unused here because *client* is already bound to it.
        log_sequence_text: Formatted log window whose last entry is the target log.
        user_prompt: Per-log analysis instruction.
        sys_prompt: System prompt.
        vector_store: FAISS store to retrieve reference chunks from, or None.
        threshold: Minimum cosine similarity for a retrieved chunk.
        max_output_tokens, temperature, top_p: Generation parameters.
        rag_query: Text used to query *vector_store*; defaults to *log_sequence_text*.
            Passing only the target log keeps it from being cut off by the embedding
            model's input-length limit when the window is long.

    Returns:
        (analysis_text, context_text). context_text is "" when nothing was retrieved;
        on failure analysis_text is an error string instead of an analysis.
    """
    if client is None:
        return "ERROR: Client Error", ""

    context_text = ""
    if vector_store is not None:
        query = log_sequence_text if rag_query is None else rag_query
        selected = faiss_cosine_search_all(vector_store, query, threshold)
        if selected:
            context_text = "\n".join(
                f"--- Reference Chunk (sim={score:.3f}) ---\n{doc.page_content}"
                for doc, score in selected[:MAX_CONTEXT_CHUNKS]
            )

    reference_block = context_text if context_text else 'No relevant reference context found.'
    rag_instruction = (
        f"=== RETRIEVED REFERENCE CONTEXT (Cosine ≥ {threshold}) ===\n"
        f"{reference_block}\n"
        "=== END REFERENCE CONTEXT ===\n"
        f"ANALYSIS INSTRUCTION: {user_prompt}\n"
        "Based on the provided LOG SEQUENCE and REFERENCE CONTEXT, you must analyze the **entire sequence** "
        "to detect any continuous attack chains or evolving threats."
    )
    log_content_section = (
        f"=== CURRENT LOG SEQUENCE TO ANALYZE (Window Size: {WINDOW_SIZE}) ===\n"
        f"{log_sequence_text}\n"
        "=== END LOG SEQUENCE ==="
    )
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": f"{rag_instruction}\n\n{log_content_section}"},
    ]

    try:
        response = client.chat_completion(messages, max_tokens=max_output_tokens, temperature=temperature, top_p=top_p, stream=False)
    except Exception as e:
        return f"Model Error: {str(e)}", context_text
    if response and response.choices:
        # Guard against a missing (None) message content before calling strip().
        content = response.choices[0].message.content or ""
        return content.strip(), context_text
    return "Format Error", context_text


# === Batch helpers ===
def _parse_batch_file(uploaded_file):
    """Parse a batch upload by file extension.

    JSON is loaded as-is; CSV becomes a list of dicts keyed by the header row;
    anything else is treated as text with one {"raw_log_entry": line} per non-blank line.
    Returns (parsed_data, toast_message, toast_icon). Parsing errors propagate.
    """
    text = _decode_upload(uploaded_file)
    name = uploaded_file.name.lower()
    if name.endswith('.json'):
        return json.loads(text), "JSON 檔案載入成功", "📄"
    if name.endswith('.csv'):
        return list(csv.DictReader(io.StringIO(text))), "CSV 檔案已轉換為 JSON 結構", "📊"
    entries = [{"raw_log_entry": line.strip()} for line in text.split("\n") if line.strip()]
    return entries, "TXT 檔案已轉換為 JSON 結構", "📝"


def _extract_log_entries(data):
    """Normalize parsed batch data into a list of log entries.

    A list is used as-is; a dict holding an 'alerts' or 'logs' list yields that list;
    any other value is treated as a single entry.
    """
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        for key in ('alerts', 'logs'):
            if isinstance(data.get(key), list):
                return data[key]
    return [data]


def _build_analysis_sequences(logs_list, window_size=WINDOW_SIZE):
    """Build one sliding-window prompt per log entry.

    Each window holds the target entry plus up to window_size - 1 entries before it,
    each serialized as JSON text. Log indices in the labels are 1-based so they
    match 'target_log_id' and the "Log #" numbers in the report.
    """
    window_size = max(1, window_size)
    # Serialize every entry once so the prompt always receives JSON text, whatever
    # the source format (JSON objects, CSV rows, or wrapped TXT lines).
    formatted_logs = [json.dumps(log, indent=2, ensure_ascii=False) for log in logs_list]
    sequences = []
    for i, log_entry in enumerate(logs_list):
        start = max(0, i - window_size + 1)
        parts = []
        for position in range(start, i + 1):
            before_target = i - position
            if before_target == 0:
                header = f"--- Log Index {position + 1} <<< TARGET LOG TO ANALYZE ---"
            else:
                header = f"--- Log Index {position + 1} ({before_target} logs before target) ---"
            parts.append(f"{header}\n{formatted_logs[position]}")
        sequences.append({
            "sequence_text": "\n\n".join(parts),
            "target_log_id": i + 1,
            "original_log_entry": log_entry,
            "target_log_text": formatted_logs[i],
        })
    return sequences


def _code_fence(text):
    """Return a backtick fence longer than any backtick run inside *text* (minimum 3)."""
    longest = max((len(run) for run in re.findall(r"`+", text)), default=0)
    return "`" * max(3, longest + 1)


def _build_report_markdown(results):
    """Render batch results as a single Markdown report."""
    chunks = ["## Cybersecurity Batch Analysis Report\n\n"]
    for item in results:
        log_json = json.dumps(item["log_content"], indent=2, ensure_ascii=False)
        # A fence longer than any backtick run keeps the log verbatim; backslash
        # escapes are not interpreted inside Markdown code blocks.
        fence = _code_fence(log_json)
        chunks.append(f"---\n\n### Log #{item['log_id']}\n{fence}json\n{log_json}\n{fence}\nResult:\n{item['analysis_result']}\n")
    return "\n".join(chunks)


# === RAG file handling ===
if rag_uploaded_file:
    file_key = f"vs_{rag_uploaded_file.name}_{rag_uploaded_file.size}"
    # Rebuild only for a different file or when no store has been built yet.
    if st.session_state.rag_current_file_key != file_key or st.session_state.vector_store is None:
        with st.spinner(f"正在建立 RAG 參考知識庫 ({rag_uploaded_file.name})..."):
            vs, msg = process_file_to_faiss(rag_uploaded_file)
        if vs is not None:
            st.session_state.vector_store = vs
            st.session_state.rag_current_file_key = file_key
            st.toast(f"RAG 參考知識庫已更新!{msg}", icon="✅")
        else:
            st.error(msg)
elif st.session_state.vector_store is not None:
    # The RAG file was removed from the uploader: clear the stale knowledge base once.
    st.session_state.vector_store = None
    st.session_state.rag_current_file_key = None
    st.info("RAG 檔案已移除,已清除相關知識庫。")

# === Batch file handling (JSON / CSV / TXT, normalized to Python data) ===
if batch_uploaded_file:
    batch_file_key = f"batch_{batch_uploaded_file.name}_{batch_uploaded_file.size}"
    if st.session_state.batch_current_file_key != batch_file_key:
        try:
            parsed_data, toast_msg, toast_icon = _parse_batch_file(batch_uploaded_file)
        except Exception as e:
            st.error(f"檔案解析錯誤: {e}")
            # Leave no data or key behind so the file is re-parsed on the next rerun.
            st.session_state.json_data_for_batch = None
            st.session_state.batch_current_file_key = None
        else:
            st.toast(toast_msg, icon=toast_icon)
            st.session_state.json_data_for_batch = parsed_data
            st.session_state.batch_current_file_key = batch_file_key
elif st.session_state.batch_current_file_key is not None or st.session_state.batch_results:
    # The batch file was removed from the uploader: clear its data and results once.
    st.session_state.json_data_for_batch = None
    st.session_state.batch_current_file_key = None
    st.session_state.batch_results = None
    st.info("批量分析檔案已移除,已清除相關數據。")

# === Batch analysis ===
# Consume the run request unconditionally so it cannot linger (for example after a
# parse failure) and silently start an analysis on a later rerun.
run_requested = st.session_state.execute_batch_analysis
st.session_state.execute_batch_analysis = False
if run_requested and st.session_state.batch_current_file_key is not None:
    start_time = time.time()
    st.session_state.batch_results = []
    if inference_client is None:
        st.error("Client 未連線,無法執行。")
    else:
        logs_list = _extract_log_entries(st.session_state.json_data_for_batch)
        if not logs_list:
            st.error("無法提取有效 Log,請檢查檔案格式。")
        else:
            vs = st.session_state.get("vector_store")
            analysis_sequences = _build_analysis_sequences(logs_list)
            total_sequences = len(analysis_sequences)
            st.header(f"⚡ 批量分析執行中 (平移視窗 $N={WINDOW_SIZE}$)...")
            progress_bar = st.progress(0, text=f"準備處理 {total_sequences} 個序列...")
            results_container = st.container()

            for done, seq_data in enumerate(analysis_sequences, start=1):
                log_id = seq_data["target_log_id"]
                progress_bar.progress(done / total_sequences, text=f"Processing {done}/{total_sequences} (Log #{log_id})...")
                try:
                    response, retrieved_ctx = generate_rag_response_hf_for_log(
                        client=inference_client,
                        model_id=MODEL_ID,
                        log_sequence_text=seq_data["sequence_text"],
                        user_prompt=analysis_prompt,
                        sys_prompt=system_prompt,
                        vector_store=vs,
                        threshold=similarity_threshold,
                        max_output_tokens=max_output_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        # Retrieve on the target log only; the full window can exceed
                        # the embedding model's input limit and lose the target.
                        rag_query=seq_data["target_log_text"],
                    )
                    item = {
                        "log_id": log_id,
                        "log_content": seq_data["original_log_entry"],
                        "sequence_analyzed": seq_data["sequence_text"],
                        "analysis_result": response,
                        "context": retrieved_ctx,
                    }
                    st.session_state.batch_results.append(item)
                    with results_container:
                        st.subheader(f"Log/Alert #{item['log_id']}")
                        with st.expander("序列內容 (JSON Format)"):
                            st.code(item["sequence_analyzed"], language='json')
                        # Highlight results the model labels as high risk.
                        if 'high risk' in response.lower():
                            st.error(item['analysis_result'])
                        else:
                            st.info(item['analysis_result'])
                        if item['context']:
                            with st.expander("參考 RAG 片段"):
                                st.code(item['context'])
                        st.markdown("---")
                except Exception as e:
                    st.error(f"Error Log {log_id}: {e}")

            progress_bar.empty()
            st.success(f"完成!耗時 {time.time() - start_time:.2f} 秒。")

# === Results download (history) ===
if st.session_state.get("batch_results"):
    st.header("⚡ 歷史分析結果")
    st.download_button("📥 下載完整報告 (.md)", _build_report_markdown(st.session_state.batch_results), "report.md", "text/markdown")