# src/streamlit_app.py — Hugging Face Space by ss900371tw (commit a0f3479, 8.77 kB)
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import io
# 嘗試匯入 pypdf,如果沒有安裝則提示
try:
import pypdf
except ImportError:
pypdf = None
# --- Page configuration ---
st.set_page_config(page_title="Cybersecurity AI Assistant", page_icon="🛡️", layout="wide")
st.title("🛡️ Foundation-Sec-8B Dashboard")
st.markdown("基於 `fdtn-ai/Foundation-Sec-8B` 模型的資安專家聊天機器人")

# --- Sidebar: credentials, file upload, and generation parameters ---
with st.sidebar:
    st.header("⚙️ 設定")

    # Hugging Face access token, pre-filled from the environment when available.
    default_token = os.getenv("HF_TOKEN", "")
    hf_token = st.text_input(
        "Hugging Face Token",
        value=default_token,
        type="password",
        help="請輸入您的 HF Token 以存取模型",
    )
    st.divider()

    # Optional file upload for log/document analysis.
    st.subheader("📂 上傳分析檔案")
    uploaded_file = st.file_uploader(
        "上傳 Logs", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf']
    )
    # Warn early if a PDF was supplied but the optional pypdf dependency is absent.
    if uploaded_file and uploaded_file.type == "application/pdf" and pypdf is None:
        st.warning("如果要支援 PDF,請安裝 pypdf: `pip install pypdf`")
    st.divider()

    # Generation parameters consumed by generate_response() below.
    st.subheader("模型參數")
    system_prompt = st.text_area(
        "System Prompt",
        value="You are a cybersecurity expert. If the user provides a file content, analyze it carefully.",
        height=100,
    )
    # Upper bound raised to 4096 to leave room for long file analyses.
    max_new_tokens = st.slider("Max New Tokens", min_value=128, max_value=4096, value=1024, step=128)
    temperature = st.slider(
        "Temperature",
        min_value=0.0, max_value=1.5, value=0.1, step=0.1,
        help="數值越低,回答越保守固定;數值越高,回答越有創意。",
    )
    repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)

    if st.button("清除對話歷史"):
        st.session_state.messages = []
        st.rerun()
# --- 硬體偵測 ---
def get_device() -> str:
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
# Resolve the compute device once at startup and surface it in the sidebar.
DEVICE = get_device()
st.sidebar.markdown(f"**目前運算裝置:** `{DEVICE}`")
# --- Model loading (cached so reruns don't reload the weights) ---
@st.cache_resource
def load_model(model_id, token):
    """Load the tokenizer and model once per session.

    Returns a (tokenizer, model, status) triple, where status is one of:
    "OK", "TemplateSet" (a generic chat template was installed),
    "TokenMissing", or "LoadFailed: <reason>".  Errors are reported through
    the status string rather than raised, so the caller can render them.
    """
    if not token:
        return None, None, "TokenMissing"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        template_status = "OK"
        # Base models often ship without a chat template; install a generic
        # ChatML-style one so apply_chat_template() works downstream.
        if tokenizer.chat_template is None:
            tokenizer.chat_template = """
{% for message in messages %}
<|im_start|>{{ message['role'] }}
{{ message['content'] }}<|im_end|>
{% endfor %}
<|im_start|>assistant
"""
            template_status = "TemplateSet"
        # Fix: many causal LMs define no pad token, which made the later
        # generate(pad_token_id=tokenizer.pad_token_id) receive None.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            token=token,
        )
        return tokenizer, model, template_status
    except Exception as e:
        return None, None, f"LoadFailed: {e}"
# --- Gate the app on a token and a successfully loaded model ---
if hf_token:
    MODEL_ID = "fdtn-ai/Foundation-Sec-8B"
    with st.spinner(f"正在載入模型 {MODEL_ID} ... (這可能需要幾分鐘)"):
        tokenizer, model, status = load_model(MODEL_ID, hf_token)
    # Status is rendered outside load_model so errors show on every rerun,
    # not only on the first (uncached) call.
    if status == "TokenMissing":
        st.error("請先在側邊欄輸入 Hugging Face Token 才能開始。")
        st.stop()
    elif status.startswith("LoadFailed"):
        # Fix: split(': ')[1] truncated any error message that itself
        # contained ': '; maxsplit=1 keeps the full reason.
        st.error(f"模型載入失敗: {status.split(': ', 1)[1]}")
        st.stop()
    elif status == "TemplateSet":
        st.toast("Tokenizer 缺乏模板,已手動設定通用對話模板。", icon="⚙️")
else:
    st.warning("請先輸入 Hugging Face Token 才能開始。")
    st.stop()

# --- Conversation history lives in session state across Streamlit reruns ---
if "messages" not in st.session_state:
    st.session_state.messages = []
# --- File handling ---
def process_file_content(uploaded_file):
    """Return the uploaded file's content as a text string.

    Returns None when no file was given, the extracted text on success, or an
    "[Error ...]" string on failure (the caller treats any string as context).
    PDFs are extracted page by page via pypdf; everything else is decoded as
    UTF-8 text.
    """
    if uploaded_file is None:
        return None
    try:
        # PDF path (optional dependency).
        if uploaded_file.type == "application/pdf":
            if pypdf is None:
                return "[Error] PDF library not installed."
            pdf_reader = pypdf.PdfReader(uploaded_file)
            # Fix: extract_text() may return None for image-only pages,
            # which previously raised TypeError on `None + "\n"`.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )
        # Plain text / code / logs: decode the raw bytes directly
        # (the StringIO round-trip added nothing).
        return uploaded_file.getvalue().decode("utf-8")
    except Exception as e:
        return f"[Error reading file: {str(e)}]"
# --- Replay the stored conversation so it survives Streamlit reruns ---
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# --- Inference ---
def generate_response(prompt, history, sys_prompt, file_context=None):
    """Generate an assistant reply for `prompt`.

    Args:
        prompt: the user's current question.
        history: prior turns as [{"role": ..., "content": ...}, ...].
        sys_prompt: system instruction prepended to the conversation.
        file_context: optional uploaded-file text folded into this turn only.

    Returns the decoded model response with the prompt tokens stripped.
    Reads module globals: tokenizer, model, DEVICE, max_new_tokens,
    temperature, repetition_penalty.
    """
    # Build the message list in chat-template format.
    messages = [{"role": "system", "content": sys_prompt}]
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})

    # Fold the uploaded file (if any) into the current user turn.
    full_user_input = prompt
    if file_context:
        full_user_input = f"""I have uploaded a file. Here is the content:
=== BEGIN FILE CONTENT ===
{file_context}
=== END FILE CONTENT ===
User Question: {prompt}
"""
    messages.append({"role": "user", "content": full_user_input})

    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): very long files may exceed the model's context window;
    # production code should truncate here.
    encoded = tokenizer(chat_text, return_tensors="pt")
    input_ids = encoded["input_ids"].to(DEVICE)
    # Fix: forward the attention mask so generate() does not have to infer
    # it from pad tokens (silences the transformers warning and avoids
    # wrong masking when pad == eos).
    attention_mask = encoded["attention_mask"].to(DEVICE)

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "use_cache": True,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Temperature 0 means greedy decoding. Fix: omit sampling parameters
    # entirely instead of passing temperature=None alongside do_sample=False.
    if temperature == 0:
        generation_args["do_sample"] = False
    else:
        generation_args["do_sample"] = True
        generation_args["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_args,
        )
    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True,
    )
    return response
# --- Handle a new user message ---
if prompt := st.chat_input("請輸入關於資安的問題..."):
    # 1. Read the uploaded file, if any.
    file_text = None
    display_prompt = prompt  # what gets rendered in the chat pane
    if uploaded_file:
        with st.spinner("正在讀取檔案內容..."):
            file_text = process_file_content(uploaded_file)
        if file_text:
            # Show a small attachment marker instead of dumping the whole
            # file into the chat pane.
            display_prompt = f"📄 **[已附加檔案: {uploaded_file.name}]**\n\n{prompt}"
            # Rough length check: warn when the file may blow the context window.
            if len(file_text) > 20000:
                st.toast("⚠️ 檔案內容較長,可能會超過模型處理上限。", icon="⚠️")

    # 2. Render the user's message.
    st.chat_message("user").markdown(display_prompt)

    # 3. Run the model with the file text as extra context for this turn.
    if model and tokenizer:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            with st.spinner("正在分析與思考中..."):
                response = generate_response(
                    prompt, st.session_state.messages, system_prompt, file_context=file_text
                )
            message_placeholder.markdown(response)

    # 4. Persist the turn. We store display_prompt (attachment marker, not
    # the full file content) to keep history small — the model already saw
    # the full text this turn, but will not see it again in later turns.
    st.session_state.messages.append({"role": "user", "content": display_prompt})
    st.session_state.messages.append({"role": "assistant", "content": response})