ss900371tw committed on
Commit
e528918
·
verified ·
1 Parent(s): 284db77

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +209 -34
src/streamlit_app.py CHANGED
@@ -1,40 +1,215 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import os
5
+ import io
6
+ # 嘗試匯入 pypdf,如果沒有安裝則提示
7
+ try:
8
+ import pypdf
9
+ except ImportError:
10
+ pypdf = None
11
 
12
# --- Page configuration ---
st.set_page_config(page_title="Cybersecurity AI Assistant", page_icon="🛡️", layout="wide")

st.title("🛡️ Foundation-Sec-8B-Instruct Dashboard")
st.markdown("基於 `fdtn-ai/Foundation-Sec-8B-Instruct` 模型的資安專家聊天機器人")

# --- Sidebar settings (generation parameters & HF token) ---
with st.sidebar:
    st.header("⚙️ 設定")

    # Pre-fill the token field from the HF_TOKEN environment variable when set.
    default_token = os.getenv("HF_TOKEN", "")
    hf_token = st.text_input("Hugging Face Token", value=default_token, type="password", help="請輸入您的 HF Token 以存取模型")

    st.divider()

    # === File upload for analysis ===
    st.subheader("📂 上傳分析檔案")
    uploaded_file = st.file_uploader("上傳 Logs", type=['txt', 'py', 'log', 'csv', 'md', 'json', 'pdf'])

    # Warn right away if a PDF was uploaded but the pypdf library is missing.
    if uploaded_file and uploaded_file.type == "application/pdf" and pypdf is None:
        st.warning("如果要支援 PDF,請安裝 pypdf: `pip install pypdf`")

    st.divider()

    st.subheader("模型參數")
    system_prompt = st.text_area("System Prompt", value="You are a cybersecurity expert. If the user provides a file content, analyze it carefully.", height=100)
    max_new_tokens = st.slider("Max New Tokens", min_value=128, max_value=4096, value=1024, step=128)  # ceiling raised to fit long file analyses
    temperature = st.slider("Temperature", min_value=0.0, max_value=1.5, value=0.1, step=0.1, help="數值越低,回答越保守固定;數值越高,回答越有創意。")
    repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)

    # Clearing history triggers a rerun so the emptied transcript renders at once.
    if st.button("清除對話歷史"):
        st.session_state.messages = []
        st.rerun()
46
# --- Hardware detection ---
def get_device():
    """Return the best available torch device name: "cuda", "mps", or "cpu"."""
    # Probe accelerators in preference order; fall through to CPU.
    for device_name, is_available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if is_available():
            return device_name
    return "cpu"
54
+
55
# Resolve the compute device once at startup and surface it in the sidebar.
DEVICE = get_device()
st.sidebar.markdown(f"**目前運算裝置:** `{DEVICE}`")
57
+
58
# --- Model loading (cached so it is not reloaded on every Streamlit rerun) ---
@st.cache_resource
def load_model(model_id, token):
    """Load the tokenizer and causal-LM for `model_id` with the given HF token.

    Returns (tokenizer, model) on success, or (None, None) when the token is
    missing or loading fails; an error is shown in the UI in both cases.
    """
    if not token:
        st.error("請在側邊欄輸入 Hugging Face Token")
        return None, None

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id,
            device_map="auto",           # let accelerate place layers across available devices
            torch_dtype=torch.bfloat16,  # NOTE(review): bfloat16 on a CPU-only host may be slow — confirm target hardware
            token=token,
        )
        return tokenizer, model
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: surface the error
        # to the user instead of crashing the app.
        st.error(f"模型載入失敗: {e}")
        return None, None
77
 
78
# Load the model only once a token has been provided; otherwise halt this run.
if hf_token:
    MODEL_ID = "fdtn-ai/Foundation-Sec-8B-Instruct"
    with st.spinner(f"正在載入模型 {MODEL_ID} ... (這可能需要幾分鐘)"):
        tokenizer, model = load_model(MODEL_ID, hf_token)
else:
    st.warning("請先輸入 Hugging Face Token 才能開始。")
    st.stop()  # stops executing the rest of the script for this rerun
86
 
87
# --- Initialize session state (chat history persists across reruns) ---
if "messages" not in st.session_state:
    st.session_state.messages = []
90
+
91
# --- File handling ---
def process_file_content(uploaded_file):
    """Convert an uploaded file into a plain-text string.

    Returns None when no file is given, the extracted text on success, or an
    "[Error ...]" marker string when the file cannot be read — callers treat
    the return value as opaque text either way.
    """
    if uploaded_file is None:
        return None

    try:
        # PDF: extract text page by page.
        if uploaded_file.type == "application/pdf":
            if pypdf is None:
                return "[Error] PDF library not installed."
            pdf_reader = pypdf.PdfReader(uploaded_file)
            # extract_text() can return None for image-only pages; coerce to ""
            # (previously this raised TypeError on `None + "\n"`).
            return "".join((page.extract_text() or "") + "\n" for page in pdf_reader.pages)
        # Plain text / code / logs: decode the raw bytes directly.
        # (The previous io.StringIO round-trip added nothing.)
        return uploaded_file.getvalue().decode("utf-8")
    except Exception as e:
        return f"[Error reading file: {str(e)}]"
115
+
116
# --- Render the chat transcript stored in session state ---
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
120
+
121
# --- Inference logic ---
def generate_response(prompt, history, sys_prompt, file_context=None):
    """Generate an assistant reply for `prompt`.

    Builds a chat-template conversation from `sys_prompt`, the prior turns in
    `history` ([{"role": ..., "content": ...}, ...]), optional `file_context`
    text, and the current user prompt, then runs generation on the globally
    loaded `model`/`tokenizer`. Returns the decoded reply text only (the
    prompt tokens are stripped).
    """
    # Build the message list in chat-template format.
    messages = [{"role": "system", "content": sys_prompt}]

    # Include prior turns so the model has conversational context.
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})

    # If file content was provided, fold it into this user turn.
    full_user_input = prompt
    if file_context:
        full_user_input = f"""I have uploaded a file. Here is the content:

=== BEGIN FILE CONTENT ===
{file_context}
=== END FILE CONTENT ===
User Question: {prompt}
"""

    # Append the current user input.
    messages.append({"role": "user", "content": full_user_input})

    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # NOTE: a very long file may exceed the model's context window; production
    # code should truncate here.
    encoded = tokenizer(chat_text, return_tensors="pt")
    input_ids = encoded["input_ids"].to(DEVICE)
    # Pass the attention mask explicitly instead of letting generate() infer
    # it (silences the transformers warning and is correct with padding).
    attention_mask = encoded["attention_mask"].to(DEVICE)

    # temperature == 0 means greedy decoding; transformers rejects
    # temperature=0 when do_sample=True, so disable sampling instead.
    do_sample = temperature != 0
    current_temp = temperature if do_sample else None

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "temperature": current_temp,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "use_cache": True,
        "eos_token_id": tokenizer.eos_token_id,
        # Llama-family tokenizers often define no pad token; fall back to EOS
        # so generate() does not fail or warn on pad_token_id=None.
        "pad_token_id": tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    }

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_args,
        )

    # Decode only the newly generated tokens (skip the prompt portion).
    response = tokenizer.decode(
        outputs[0][input_ids.shape[1]:],
        skip_special_tokens=True
    )

    return response
180
+
181
# --- Handle user input ---
if prompt := st.chat_input("請輸入關於資安的問題..."):

    # 1. Process any uploaded file.
    file_text = None
    display_prompt = prompt  # text shown in the transcript

    if uploaded_file:
        with st.spinner("正在讀取檔案內容..."):
            file_text = process_file_content(uploaded_file)
        if file_text:
            # Show a small attachment marker instead of dumping the whole
            # file into the transcript.
            display_prompt = f"📄 **[已附加檔案: {uploaded_file.name}]**\n\n{prompt}"
            # Rough length check: warn when the file may exceed the model's
            # context window.
            if len(file_text) > 20000:
                st.toast("⚠️ 檔案內容較長,可能會超過模型處理上限。", icon="⚠️")

    # 2. Show the user's message.
    st.chat_message("user").markdown(display_prompt)

    # 3. Generate the assistant's reply.
    if model and tokenizer:
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            with st.spinner("正在分析與思考中..."):
                # Pass file_text as extra context for this turn only.
                response = generate_response(prompt, st.session_state.messages, system_prompt, file_context=file_text)
            message_placeholder.markdown(response)

        # 4. Update the chat history — only after a successful generation.
        #    (Previously this ran even when the model failed to load, raising
        #    NameError on the undefined `response`.)
        # We store display_prompt so the transcript shows the attachment
        # marker, but deliberately do not store the full file content in
        # history to save context/memory; the model only sees the file on
        # the turn it was uploaded.
        st.session_state.messages.append({"role": "user", "content": display_prompt})
        st.session_state.messages.append({"role": "assistant", "content": response})