import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import io
import numpy as np
import faiss
import uuid
import time

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.docstore.in_memory import InMemoryDocstore

try:
    import pypdf
except ImportError:
    pypdf = None

st.set_page_config(page_title="Cybersecurity AI Assistant (RAG)", page_icon="🛡️", layout="wide")
st.title("🛡️ Foundation-Sec-8B with FAISS RAG")
st.markdown("Enabled: **IndexFlatIP** + **L2 normalization** + **context download**")

with st.sidebar:
    st.header("⚙️ Settings")
    default_token = os.getenv("HF_TOKEN", "")
    hf_token = st.text_input("Hugging Face Token", value=default_token, type="password")

    st.divider()
    st.subheader("📂 Upload a file to analyze (builds the RAG store)")

    uploaded_file = st.file_uploader("Upload logs/PDF/code", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf'])

    st.divider()
    st.subheader("🔍 RAG retrieval settings")
    similarity_threshold = st.slider(
        "📐 Cosine similarity threshold",
        0.0, 1.0, 0.4, 0.01,
        help="Higher values require closer matches; 0.4-0.7 is a reasonable range."
    )

    st.divider()
    st.subheader("Model parameters")
    system_prompt = st.text_area("System Prompt", value="You are a Senior Security Analyst. Use the retrieved context to answer the user's question. Every claim you make MUST be supported by a specific Event Record ID from the retrieved context.", height=100)

    max_new_tokens = st.slider("Max New Tokens", 128, 4096, 1024, 128)
    temperature = st.slider("Temperature", 0.0, 1.5, 0.1, 0.1)

    st.divider()

    if st.button("🗑️ Clear chat history"):
        st.session_state.messages = []
        st.rerun()

def get_device():
    if torch.cuda.is_available(): return "cuda"
    elif torch.backends.mps.is_available(): return "mps"
    else: return "cpu"


DEVICE = get_device()
st.sidebar.markdown(f"**LLM Device:** `{DEVICE}`")

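# Model loading. @st.cache_resource keeps the 8B model in memory across
# Streamlit reruns instead of reloading it on every interaction. If the
# tokenizer ships without a chat template, a generic ChatML template is
# installed so apply_chat_template() still works.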
@st.cache_resource
def load_model(model_id, token):
    if not token: return None, None, "TokenMissing"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        template_status = "OK"
        if tokenizer.chat_template is None:
            # Fall back to a standard ChatML template, with the newlines the
            # format expects between role, content, and turns.
            tokenizer.chat_template = (
                "{% for message in messages %}"
                "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
                "{% endfor %}"
                "<|im_start|>assistant\n"
            )
            template_status = "TemplateSet"

        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            token=token,
        )
        return tokenizer, model, template_status
    except Exception as e:
        return None, None, f"LoadFailed: {e}"

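# Embedding model. jina-embeddings-v2-base-code is code-aware, which suits
# logs and source files. normalize_embeddings stays False because the vectors
# are L2-normalized explicitly with faiss.normalize_L2 below; the inner
# product of unit vectors equals their cosine similarity.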
@st.cache_resource
def load_embedding_model():
    model_kwargs = {
        'device': 'cpu',
        'trust_remote_code': True
    }
    encode_kwargs = {
        # Keep raw vectors; normalization is applied manually with FAISS.
        'normalize_embeddings': False
    }
    return HuggingFaceEmbeddings(
        model_name="jinaai/jina-embeddings-v2-base-code",
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )

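# Gate the app on a Hugging Face token: load the LLM when one is provided,
# otherwise stop the script until the user supplies it.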
if hf_token:
    MODEL_ID = "fdtn-ai/Foundation-Sec-8B-Instruct"
    with st.spinner(f"Loading LLM {MODEL_ID}..."):
        tokenizer, model, status = load_model(MODEL_ID, hf_token)
    if status == "TokenMissing":
        st.error("Please enter a token.")
        st.stop()
    elif status.startswith("LoadFailed"):
        st.error(f"Model loading failed: {status}")
        st.stop()
else:
    st.warning("Please enter a token.")
    st.stop()

with st.spinner("正在載入 Embedding 模型..."): |
|
|
embedding_model = load_embedding_model() |
|
|
|
|
|
|
|
|
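# Build the FAISS store: chunk the upload, embed each chunk, L2-normalize,
# and index with IndexFlatIP so inner-product search returns cosine scores.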
def process_file_to_faiss(uploaded_file):
    text_content = ""
    try:
        if uploaded_file.type == "application/pdf":
            if pypdf:
                pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    # extract_text() can return None on image-only pages.
                    text_content += (page.extract_text() or "") + "\n"
            else:
                return None, "PDF library missing"
        else:
            stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
            text_content = stringio.read()

        if not text_content.strip():
            return None, "File is empty"

        # Prefer event-level chunks for Windows event log XML; re-append the
        # closing tag that split() consumed.
        events = [e + "</Event>" for e in text_content.split("</Event>") if e.strip()]

        # Fall back to line-level chunks when the file has no <Event> markup.
        if len(events) <= 1:
            events = [line for line in text_content.split("\n") if line.strip()]

        docs = [Document(page_content=e) for e in events]

        if not docs:
            return None, "No documents created"

        embeddings = embedding_model.embed_documents([d.page_content for d in docs])
        embeddings_np = np.array(embeddings).astype("float32")
        faiss.normalize_L2(embeddings_np)

        dimension = embeddings_np.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings_np)

        doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
        docstore = InMemoryDocstore({_id: doc for _id, doc in zip(doc_ids, docs)})
        index_to_docstore_id = {i: _id for i, _id in enumerate(doc_ids)}

        vector_store = FAISS(
            embedding_function=embedding_model,
            index=index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore_id,
            distance_strategy=DistanceStrategy.COSINE
        )

        return vector_store, f"{len(docs)} chunks created."
    except Exception as e:
        return None, f"Error: {str(e)}"

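# Rebuild the knowledge base only when the upload actually changes; the
# file name + size key avoids re-embedding the same file on every rerun.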
if uploaded_file:
    file_key = f"vs_{uploaded_file.name}_{uploaded_file.size}"

    if "current_file_key" not in st.session_state or st.session_state.current_file_key != file_key:
        with st.spinner("New file detected, updating the knowledge base..."):
            vs, msg = process_file_to_faiss(uploaded_file)
            if vs:
                st.session_state.vector_store = vs
                st.session_state.current_file_key = file_key
                st.toast(f"Knowledge base updated! {msg}", icon="✅")
            else:
                st.error(msg)
else:
    if "vector_store" in st.session_state:
        del st.session_state.vector_store
        st.info("File removed; knowledge base cleared, returning to plain chat mode.")
    if "current_file_key" in st.session_state:
        del st.session_state.current_file_key

if "messages" not in st.session_state: |
|
|
st.session_state.messages = [] |
|
|
|
|
|
|
|
|
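# Re-render the full conversation on each rerun. Assistant turns that used
# retrieval also carry their context, shown in an expander with a download button.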
for idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        if message.get("context"):
            with st.expander(f"View retrieved context (turn {idx})"):
                st.code(message["context"])
                st.download_button(
                    label="📥 Download this context (.txt)",
                    data=message["context"],
                    file_name=f"rag_context_{idx}.txt",
                    mime="text/plain",
                    key=f"dl_btn_{idx}"
                )

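# Threshold-based retrieval: score every chunk against the query (k=ntotal)
# and keep everything at or above the cosine-similarity threshold, rather
# than a fixed top-k.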
def faiss_cosine_search_all(vector_store, query, threshold):
    q_emb = embedding_model.embed_query(query)
    q_emb = np.array([q_emb]).astype("float32")
    faiss.normalize_L2(q_emb)

    index = vector_store.index
    D, I = index.search(q_emb, k=index.ntotal)

    selected = []
    for score, idx in zip(D[0], I[0]):
        # FAISS pads missing results with -1.
        if idx == -1: continue
        if score >= threshold:
            doc_id = vector_store.index_to_docstore_id[idx]
            doc = vector_store.docstore.search(doc_id)
            selected.append((doc, score))

    selected.sort(key=lambda x: x[1], reverse=True)
    return selected

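# RAG generation: retrieve context, wrap it and the question into a chat
# conversation, left-truncate to fit the context window, and generate.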
def generate_rag_response(prompt, history, sys_prompt, vector_store=None, threshold=0.5):
    context_text = ""
    top_k_selected = []

    if vector_store:
        selected = faiss_cosine_search_all(vector_store, prompt, threshold)
        if selected:
            top_k_selected = selected
            retrieved_contents = [
                f"--- Chunk (sim={score:.3f}) ---\n{doc.page_content}"
                for doc, score in top_k_selected
            ]
            context_text = "\n".join(retrieved_contents)

    if context_text:
        full_user_input = f"""Use the following retrieved context to answer the question.

=== RETRIEVED CONTEXT (Cosine ≥ {threshold}) ===
{context_text}
=== END CONTEXT ===

Question: {prompt}"""
    else:
        full_user_input = f"Question: {prompt}"

    messages = [{"role": "system", "content": sys_prompt}]
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": full_user_input})

    inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs_tokenized = tokenizer(inputs, return_tensors="pt")
    input_ids = inputs_tokenized["input_ids"].to(DEVICE)

    # Left-truncate so the prompt leaves room for generation within the window.
    MAX_CONTEXT = 4096
    if input_ids.shape[1] > (MAX_CONTEXT - max_new_tokens):
        input_ids = input_ids[:, -(MAX_CONTEXT - max_new_tokens):]

    do_sample = temperature > 0

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature if do_sample else None,
            do_sample=do_sample,
            eos_token_id=tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response, context_text

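# Chat loop: prefix the displayed prompt with a RAG badge when a knowledge
# base exists, generate the answer, and persist both turns (plus any
# retrieved context) in session state.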
if prompt := st.chat_input("Ask a question..."):
    vs = st.session_state.get("vector_store", None)
    display_prompt = prompt
    if vs:
        display_prompt = f"🔍 **[RAG]** {prompt}"

    st.chat_message("user").markdown(display_prompt)

    if model and tokenizer:
        with st.chat_message("assistant"):
            msg_placeholder = st.empty()

            with st.spinner("Analyzing..."):
                response, retrieved_ctx = generate_rag_response(
                    prompt,
                    st.session_state.messages,
                    system_prompt,
                    vector_store=vs,
                    threshold=similarity_threshold,
                )

            msg_placeholder.markdown(response)

            if retrieved_ctx:
                with st.expander("View retrieved context"):
                    st.code(retrieved_ctx)
                    st.download_button(
                        label="📥 Download this context (.txt)",
                        data=retrieved_ctx,
                        file_name="rag_context_current.txt",
                        mime="text/plain"
                    )

        # Persist both turns only after a response was generated.
        st.session_state.messages.append({"role": "user", "content": display_prompt})
        st.session_state.messages.append({
            "role": "assistant",
            "content": response,
            "context": retrieved_ctx
        })