import io
import os

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# pypdf is optional; PDF uploads are rejected with a hint when it is missing.
try:
    import pypdf
except ImportError:
    pypdf = None

st.set_page_config(page_title="Cybersecurity AI Assistant", page_icon="🛡️", layout="wide")

st.title("🛡️ Foundation-Sec-8B Dashboard")
st.markdown("A cybersecurity expert chatbot built on the `fdtn-ai/Foundation-Sec-8B` model")
|
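# --- Sidebar: HF token, file upload, and generation settings ---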
with st.sidebar:
    st.header("⚙️ Settings")

    default_token = os.getenv("HF_TOKEN", "")
    hf_token = st.text_input("Hugging Face Token", value=default_token, type="password", help="Enter your HF token to access the model")

    st.divider()

    st.subheader("📂 Upload a File for Analysis")
    uploaded_file = st.file_uploader("Upload logs", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf'])

    if uploaded_file and uploaded_file.type == "application/pdf" and pypdf is None:
        st.warning("PDF support requires pypdf: `pip install pypdf`")

    st.divider()

    st.subheader("Model Parameters")
    system_prompt = st.text_area("System Prompt", value="You are a cybersecurity expert. If the user provides file content, analyze it carefully.", height=100)
    max_new_tokens = st.slider("Max New Tokens", min_value=128, max_value=4096, value=1024, step=128)
    temperature = st.slider("Temperature", min_value=0.0, max_value=1.5, value=0.1, step=0.1, help="Lower values give more conservative, repeatable answers; higher values give more creative ones.")
    repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)

    if st.button("Clear Chat History"):
        st.session_state.messages = []
        st.rerun()
|
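# Pick the best available device: CUDA, then Apple Silicon (MPS), then CPU.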
def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

DEVICE = get_device()
st.sidebar.markdown(f"**Compute device:** `{DEVICE}`")
|
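# st.cache_resource keeps the loaded tokenizer/model alive across Streamlit
# reruns, so the 8B checkpoint is only fetched and loaded once per process.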
@st.cache_resource
def load_model(model_id, token):
    if not token:
        return None, None, "TokenMissing"

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Some tokenizers ship without a pad token; fall back to EOS so that
        # generate() never receives pad_token_id=None.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token

        template_status = "OK"
        if tokenizer.chat_template is None:
            # Fall back to a generic ChatML-style template.
            tokenizer.chat_template = """
{% for message in messages %}
<|im_start|>{{ message['role'] }}
{{ message['content'] }}<|im_end|>
{% endfor %}
<|im_start|>assistant
"""
            template_status = "TemplateSet"

        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            token=token,
        )

        return tokenizer, model, template_status

    except Exception as e:
        return None, None, f"LoadFailed: {e}"
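
# Note: bfloat16 weights for an 8B model need roughly 16 GB of accelerator
# memory. If that is too much, 4-bit quantization is a common alternative
# (a sketch, assuming bitsandbytes is installed):
#
#     from transformers import BitsAndBytesConfig
#     quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#     model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config, token=token)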
|
if hf_token:
    MODEL_ID = "fdtn-ai/Foundation-Sec-8B"
    with st.spinner(f"Loading model {MODEL_ID} ... (this may take a few minutes)"):
        tokenizer, model, status = load_model(MODEL_ID, hf_token)

    if status == "TokenMissing":
        st.error("Please enter a Hugging Face token in the sidebar to get started.")
        st.stop()

    elif status.startswith("LoadFailed"):
        st.error(f"Model loading failed: {status.split(': ', 1)[1]}")
        st.stop()

    elif status == "TemplateSet":
        st.toast("Tokenizer has no chat template; a generic one was set manually.", icon="⚙️")

else:
    st.warning("Please enter a Hugging Face token to get started.")
    st.stop()
|
if "messages" not in st.session_state:
    st.session_state.messages = []
|
def process_file_content(uploaded_file):
    """Read an uploaded file and return its contents as a text string."""
    if uploaded_file is None:
        return None

    file_content = ""
    try:
        if uploaded_file.type == "application/pdf":
            if pypdf:
                pdf_reader = pypdf.PdfReader(uploaded_file)
                for page in pdf_reader.pages:
                    # extract_text() can return None for image-only pages.
                    file_content += (page.extract_text() or "") + "\n"
            else:
                return "[Error] PDF library not installed."
        else:
            stringio = io.StringIO(uploaded_file.getvalue().decode("utf-8"))
            file_content = stringio.read()

        return file_content
    except Exception as e:
        return f"[Error reading file: {str(e)}]"
|
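# Replay the stored conversation; Streamlit re-executes the whole script on
# every interaction, so history must be redrawn from session state each run.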
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
|
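# Build the full chat prompt (system prompt + history + current turn, with any
# file contents inlined), apply the chat template, and run generation.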
def generate_response(prompt, history, sys_prompt, file_context=None):
    messages = [{"role": "system", "content": sys_prompt}]

    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})

    # If a file was uploaded, wrap its contents in delimiters so the model can
    # tell the file content apart from the question itself.
    full_user_input = prompt
    if file_context:
        full_user_input = f"""I have uploaded a file. Here is the content:

=== BEGIN FILE CONTENT ===
{file_context}
=== END FILE CONTENT ===

User Question: {prompt}
"""

    messages.append({"role": "user", "content": full_user_input})

    inputs = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs_tokenized = tokenizer(inputs, return_tensors="pt").to(DEVICE)

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "use_cache": True,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Temperature 0 means greedy decoding; in that case temperature must not
    # be passed at all, since it only applies when sampling.
    if temperature > 0:
        generation_args["do_sample"] = True
        generation_args["temperature"] = temperature
    else:
        generation_args["do_sample"] = False

    with torch.no_grad():
        # Passing the whole BatchEncoding forwards the attention mask as well.
        outputs = model.generate(
            **inputs_tokenized,
            **generation_args,
        )

    # Decode only the newly generated tokens, skipping the prompt.
    response = tokenizer.decode(
        outputs[0][inputs_tokenized["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )

    return response
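
# A possible extension (a sketch, not wired in): stream tokens into the chat
# bubble as they are generated, using transformers' TextIteratorStreamer in a
# background thread:
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#     Thread(target=model.generate, kwargs={**inputs_tokenized, **generation_args, "streamer": streamer}).start()
#     response = st.write_stream(streamer)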
|
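# Chat input: read the prompt, attach any uploaded file, render the user turn,
# generate the assistant reply, and persist both turns in session state.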
if prompt := st.chat_input("Ask a cybersecurity question..."):

    file_text = None
    display_prompt = prompt

    if uploaded_file:
        with st.spinner("Reading file content..."):
            file_text = process_file_content(uploaded_file)
        if file_text:
            display_prompt = f"📄 **[Attached file: {uploaded_file.name}]**\n\n{prompt}"

            if len(file_text) > 20000:
                st.toast("⚠️ The file is large and may exceed the model's context limit.", icon="⚠️")

    st.chat_message("user").markdown(display_prompt)

    if model and tokenizer:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            with st.spinner("Analyzing and thinking..."):
                response = generate_response(prompt, st.session_state.messages, system_prompt, file_context=file_text)
                message_placeholder.markdown(response)

        # Persist both turns only after a response exists; appending outside
        # this guard would raise NameError if the model failed to load.
        st.session_state.messages.append({"role": "user", "content": display_prompt})
        st.session_state.messages.append({"role": "assistant", "content": response})