import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import io
import numpy as np
import faiss
import uuid
import time
# === RAG-related packages ===
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.docstore.in_memory import InMemoryDocstore
# Try to import pypdf
try:
    import pypdf
except ImportError:
    pypdf = None
# --- Page setup ---
st.set_page_config(page_title="Cybersecurity AI Assistant (RAG)", page_icon="🛡️", layout="wide")
st.title("🛡️ Foundation-Sec-8B with FAISS RAG")
st.markdown("Enabled: **IndexFlatIP** + **L2 normalization** + **context download**")
# --- Sidebar setup ---
with st.sidebar:
    st.header("⚙️ Settings")
    default_token = os.getenv("HF_TOKEN", "")
    hf_token = st.text_input("Hugging Face Token", value=default_token, type="password")
    st.divider()
    st.subheader("📂 Upload a file for analysis (builds the RAG store)")
    # Use a key to keep the uploader state correct across reruns
    uploaded_file = st.file_uploader("Upload Logs/PDF/Code", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf'])
    st.divider()
    st.subheader("🔍 RAG Retrieval Settings")
    similarity_threshold = st.slider(
        "📐 Cosine Similarity Threshold",
        0.0, 1.0, 0.4, 0.01,
        help="Higher values require closer matches; 0.4-0.7 is generally recommended."
    )
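    # Note: because document embeddings are L2-normalized before indexing (see
    # process_file_to_faiss below), IndexFlatIP returns true cosine similarity
    # in [-1, 1], so this slider value can be compared against the raw scores.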
    st.divider()
    st.subheader("Model Parameters")
    system_prompt = st.text_area("System Prompt", value="You are a Senior Security Analyst. Use the retrieved context to answer the user's question. Every claim you make MUST be supported by a specific Event Record ID from the retrieved context.", height=100)
    max_new_tokens = st.slider("Max New Tokens", 128, 4096, 1024, 128)
    temperature = st.slider("Temperature", 0.0, 1.5, 0.1, 0.1)
    st.divider()
    # Modification 1: the clear button only clears the conversation, not the knowledge base
    if st.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        # Note: vector_store is intentionally kept so the RAG state survives
        st.rerun()
# --- Device ---
def get_device():
    if torch.cuda.is_available(): return "cuda"
    elif torch.backends.mps.is_available(): return "mps"
    else: return "cpu"
DEVICE = get_device()
st.sidebar.markdown(f"**LLM Device:** `{DEVICE}`")
# --- Model loading ---
@st.cache_resource
def load_model(model_id, token):
    if not token: return None, None, "TokenMissing"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        template_status = "OK"
        if tokenizer.chat_template is None:
            # Fallback ChatML-style template (with the standard newlines between role and content)
            tokenizer.chat_template = """{% for message in messages %}<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n{% endfor %}<|im_start|>assistant\n"""
            template_status = "TemplateSet"
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            token=token,
        )
        return tokenizer, model, template_status
    except Exception as e:
        return None, None, f"LoadFailed: {e}"
# === Embedding model ===
@st.cache_resource
def load_embedding_model():
    model_kwargs = {
        'device': 'cpu',
        'trust_remote_code': True  # Key fix: required to run Jina's custom model code
    }
    encode_kwargs = {
        'normalize_embeddings': False  # We L2-normalize ourselves before building the FAISS index
    }
    return HuggingFaceEmbeddings(
        model_name="jinaai/jina-embeddings-v2-base-code",
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )
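# A minimal sanity check (illustrative only, not executed by the app): after
# faiss.normalize_L2, each vector has unit length, so the inner product of a
# vector with itself is ~1.0 and IndexFlatIP scores behave as cosine similarity.
#
#   emb = load_embedding_model().embed_query("test")
#   v = np.array([emb], dtype="float32")
#   faiss.normalize_L2(v)
#   assert abs(float(v @ v.T) - 1.0) < 1e-4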
# === Initialize the LLM ===
if hf_token:
    MODEL_ID = "fdtn-ai/Foundation-Sec-8B-Instruct"
    with st.spinner(f"Loading LLM {MODEL_ID}..."):
        tokenizer, model, status = load_model(MODEL_ID, hf_token)
    if status == "TokenMissing":
        st.error("Please enter a token.")
        st.stop()
    elif status.startswith("LoadFailed"):
        st.error(f"Model load failed: {status}")
        st.stop()
else:
    st.warning("Please enter a token.")
    st.stop()
# === Initialize embeddings ===
with st.spinner("Loading embedding model..."):
    embedding_model = load_embedding_model()
# === Build the vector store (strict cosine) ===
def process_file_to_faiss(uploaded_file):
    text_content = ""
    try:
        if uploaded_file.type == "application/pdf":
            if pypdf:
                pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    # extract_text() can return None on image-only pages
                    text_content += (page.extract_text() or "") + "\n"
            else:
                return None, "PDF library missing"
        else:
            stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
            text_content = stringio.read()
        if not text_content.strip():
            return None, "File is empty"
        # Simple chunking: split on </Event> closing tags (e.g., Windows event log XML)
        events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]
        # If the file is not XML, fall back to line-based splitting
        if len(events) <= 1:
            events = [line for line in text_content.split("\n") if line.strip()]
        docs = [Document(page_content=e) for e in events]
        if not docs:
            return None, "No documents created"
        embeddings = embedding_model.embed_documents([d.page_content for d in docs])
        embeddings_np = np.array(embeddings).astype("float32")
        faiss.normalize_L2(embeddings_np)
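        # With unit-length vectors, inner product == cosine similarity, so an
        # IndexFlatIP over the normalized matrix performs exact (brute-force)
        # cosine search with scores in [-1, 1].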
        dimension = embeddings_np.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings_np)
        doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
        index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}
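        # FAISS returns positional row ids from search(); index_to_docstore_id
        # maps those back to the UUID keys used by the in-memory docstore.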
        vector_store = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
            distance_strategy=DistanceStrategy.COSINE
        )
        return vector_store, f"{len(docs)} chunks created."
    except Exception as e:
        return None, f"Error: {str(e)}"
# === Modification 2: file handling logic (auto-sync) ===
if uploaded_file:
    file_key = f"vs_{uploaded_file.name}_{uploaded_file.size}"  # Include size to catch same-name files with different content
    # Check whether this is a new file (or the first upload)
    if "current_file_key" not in st.session_state or st.session_state.current_file_key != file_key:
        with st.spinner("New file detected, updating the knowledge base..."):
            vs, msg = process_file_to_faiss(uploaded_file)
            if vs:
                st.session_state.vector_store = vs
                st.session_state.current_file_key = file_key
                st.toast(f"Knowledge base updated! {msg}", icon="✅")
            else:
                st.error(msg)
else:
    # If the user removed the file (clicked X), clear the knowledge base
    if "vector_store" in st.session_state:
        del st.session_state.vector_store
        st.info("File removed; knowledge base cleared. Back to normal mode.")
    if "current_file_key" in st.session_state:
        del st.session_state.current_file_key
# === Render chat history ===
if "messages" not in st.session_state:
    st.session_state.messages = []
# Modification 3 Part A: add a download button when rendering past messages
for idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        # Check for a context field (RAG retrieval content)
        if message.get("context"):
            with st.expander(f"View retrieved chunks (Turn {idx})"):
                st.code(message["context"])
                # Download button
                st.download_button(
                    label="📥 Download this context (.txt)",
                    data=message["context"],
                    file_name=f"rag_context_{idx}.txt",
                    mime="text/plain",
                    key=f"dl_btn_{idx}"  # A unique key is required
                )
# === Search function ===
def faiss_cosine_search_all(vector_store, query, threshold):
    q_emb = embedding_model.embed_query(query)
    q_emb = np.array([q_emb]).astype("float32")
    faiss.normalize_L2(q_emb)
    index = vector_store.index
    D, I = index.search(q_emb, k=index.ntotal)
    selected = []
    for score, idx in zip(D[0], I[0]):
        if idx == -1: continue
        if score >= threshold:
            doc_id = vector_store.index_to_docstore_id[idx]
            doc = vector_store.docstore.search(doc_id)
            selected.append((doc, score))
    selected.sort(key=lambda x: x[1], reverse=True)
    return selected
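# Note: k=index.ntotal scans every vector, which is acceptable here because
# IndexFlatIP is brute-force anyway; the similarity threshold, not k, does the
# filtering. For large corpora, capping k (e.g., k=min(50, index.ntotal)) would
# be the usual alternative.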
# === Generate a response ===
def generate_rag_response(prompt, history, sys_prompt, vector_store=None, threshold=0.5):
    context_text = ""
    top_k_selected = []
    if vector_store:
        selected = faiss_cosine_search_all(vector_store, prompt, threshold)
        if selected:
            top_k_selected = selected
            retrieved_contents = [
                f"--- Chunk (sim={score:.3f}) ---\n{doc.page_content}"
                for doc, score in top_k_selected
            ]
            context_text = "\n".join(retrieved_contents)
    if context_text:
        full_user_input = f"""Use the following retrieved context to answer the question.

=== RETRIEVED CONTEXT (Cosine ≥ {threshold}) ===
{context_text}
=== END CONTEXT ===

Question: {prompt}"""
    else:
        full_user_input = f"""Question: {prompt}"""
    # Build the prompt
    messages = [{"role": "system", "content": sys_prompt}]
    # Build history from the content field only, so the context field is never sent to the LLM
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": full_user_input})
    inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs_tokenized = tokenizer(inputs, return_tensors="pt")
    input_ids = inputs_tokenized["input_ids"].to(DEVICE)
    # Context window guard
    MAX_CONTEXT = 4096
    if input_ids.shape[1] > (MAX_CONTEXT - max_new_tokens):
        input_ids = input_ids[:, -(MAX_CONTEXT - max_new_tokens):]
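    # Note: left-truncation keeps the most recent tokens, preserving the latest
    # question and retrieved context, but on very long conversations it can
    # silently drop the system prompt from the front of the sequence.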
    do_sample = temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature if do_sample else None,
            do_sample=do_sample,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response, context_text
# === Handle user input ===
if prompt := st.chat_input("Ask a question..."):
    vs = st.session_state.get("vector_store", None)
    display_prompt = prompt
    if vs:
        display_prompt = f"🔍 **[RAG]** {prompt}"
    st.chat_message("user").markdown(display_prompt)
    if model and tokenizer:
        with st.chat_message("assistant"):
            msg_placeholder = st.empty()
            with st.spinner("Analyzing..."):
                response, retrieved_ctx = generate_rag_response(
                    prompt,
                    st.session_state.messages,
                    system_prompt,
                    vector_store=vs,
                    threshold=similarity_threshold,
                )
            msg_placeholder.markdown(response)
            # Modification 3 Part B: show the expander and download button for the current reply
            if retrieved_ctx:
                with st.expander("View retrieved chunks"):
                    st.code(retrieved_ctx)
                    st.download_button(
                        label="📥 Download this context (.txt)",
                        data=retrieved_ctx,
                        file_name="rag_context_current.txt",
                        mime="text/plain"
                    )
        # Modification 3 Part C: store the retrieved context alongside the message in session state
        st.session_state.messages.append({"role": "user", "content": display_prompt})
        st.session_state.messages.append({
            "role": "assistant",
            "content": response,
            "context": retrieved_ctx  # Save the retrieved context so the history loop above can render it
        })
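# To run this app locally (assuming the file path shown in this Space's repo):
#   streamlit run src/streamlit_app.py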