Donlagon007 committed on
Commit 985e0f6 · verified · 1 Parent(s): d9d69d7

Upload 3 files

Files changed (3)
  1. Dockerfile +25 -0
  2. app.py +447 -404
  3. requirements.txt +1 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
1
+ FROM python:3.13.5-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ git \
9
+ fontconfig \
10
+ fonts-wqy-microhei \
11
+ fonts-wqy-zenhei \
12
+ fonts-noto-cjk \
13
+ && fc-cache -f \
14
+ && rm -rf /var/lib/apt/lists/*
15
+
16
+ COPY requirements.txt ./
17
+ COPY src/ ./src/
18
+
19
+ RUN pip3 install -r requirements.txt
20
+
21
+ EXPOSE 8501
22
+
23
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
24
+
25
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py CHANGED
@@ -1,432 +1,475 @@
1
- import streamlit as st
2
- import torch
3
- import pandas as pd
4
- import numpy as np
5
- import seaborn as sns
6
- import matplotlib.pyplot as plt
7
- import re
8
- import jieba
9
- import matplotlib
10
- import matplotlib.font_manager as fm
11
- from transformers import AutoTokenizer, AutoModel
12
- import os
13
- import warnings
14
 
15
-
16
- # ===============================
17
- # 中文字體設定(跨平台支持)
18
- # ===============================
19
  def setup_chinese_font():
20
- # เส้นทา่พบบ่อใน Ubuntu/HF Spaces
21
  candidate_paths = [
22
  "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
23
  "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
24
- "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", # Debian/Ubuntu
25
- "/usr/share/fonts/opentype/noto/NotoSansCJK-Sc-Regular.otf", # SC
26
- "/usr/share/fonts/opentype/noto/NotoSansCJK-TC-Regular.otf", # TC
27
  "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
28
  ]
29
-
30
  for p in candidate_paths:
31
  if os.path.exists(p):
32
- # ลงทะเบียนฟอนต์เข้า fontManager (สำคัญ)
33
  fm.fontManager.addfont(p)
34
  prop = fm.FontProperties(fname=p)
35
- # ตั้งค่าสำรองให้ใช้ชื่อฟอนต์ที่เพิ่ง add เข้าไป
36
  matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
37
  matplotlib.rcParams["axes.unicode_minus"] = False
 
38
  return prop
39
 
40
- # สแกนทั้งระบบเผื่อชื่อไฟล์ต่าง distribution
41
  for p in fm.findSystemFonts(fontpaths=["/usr/share/fonts", "/usr/local/share/fonts"]):
42
  if any(k in p.lower() for k in ["wqy", "noto", "cjk", "droid"]):
43
  fm.fontManager.addfont(p)
44
  prop = fm.FontProperties(fname=p)
45
  matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
46
  matplotlib.rcParams["axes.unicode_minus"] = False
 
47
  return prop
48
 
 
49
  warnings.warn("ไม่พบฟอนต์จีน ใช้ DejaVu Sans ชั่วคราว")
50
  matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
51
  matplotlib.rcParams["axes.unicode_minus"] = False
52
  return fm.FontProperties()
53
-
54
- zh_font = setup_chinese_font()
55
-
56
- # ===============================
57
- # 頁面設定
58
- # ===============================
59
- st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")
60
- st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")
61
-
62
- # ===============================
63
- # 模型選擇與載入
64
- # ===============================
65
- model_options = {
66
- "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
67
- "BERT-base-Chinese": "bert-base-chinese",
68
- "Chinese MacBERT-base": "hfl/chinese-macbert-base"
69
- }
70
-
71
- selected_model = st.selectbox(
72
- "選擇模型",
73
- list(model_options.keys()),
74
- index=0
75
- )
76
-
77
- model_name = model_options[selected_model]
78
-
79
-
80
- @st.cache_resource
81
- def load_model(name):
82
- with st.spinner(f"載入模型 {name} 中..."):
83
- try:
84
- tokenizer = AutoTokenizer.from_pretrained(name)
85
- model = AutoModel.from_pretrained(name, output_attentions=True)
86
- return tokenizer, model, None
87
- except Exception as e:
88
- return None, None, str(e)
89
-
90
-
91
- tokenizer, model, error = load_model(model_name)
92
-
93
- if error:
94
- st.error(f"模型載入失敗: {error}")
95
- st.stop()
96
-
97
- # ===============================
98
- # 使用者輸入
99
- # ===============================
100
- text = st.text_area(
101
- "請輸入中文句子:",
102
- "我今年35歲,目前在科技業工作,作息略不規律。",
103
- help="輸入您想分析的中文文本。將使用 Jieba 進行分詞然後用 Transformer 模型分析。"
104
- )
105
-
106
-
107
- def normalize_text(s):
108
- """移除特殊符號與全形字"""
109
- s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%%\s]", "", s)
110
- s = s.replace("%", "%").replace("。", "。 ")
111
- return s.strip()
112
-
113
-
114
- # ===============================
115
- # 主流程
116
- # ===============================
117
- if st.button("開始分析", type="primary"):
118
- if not text.strip():
119
- st.warning("請輸入有效的中文句子")
120
- st.stop()
121
-
122
- # 文本清理與分詞
123
- text = normalize_text(text)
124
- words = list(jieba.cut(text))
125
- st.write("🔹 Jieba 分詞結果:", words)
126
-
127
- # 不使用空格連接,直接使用原始文本
128
- # 這樣可以避免空格導致的詞-token不匹配問題
129
- tokenized_result = tokenizer(text, return_tensors="pt")
130
- token_ids = tokenized_result["input_ids"][0].tolist()
131
- tokens = tokenizer.convert_ids_to_tokens(token_ids)
132
-
133
- # 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
134
- # 創建穩健的詞-token映射
135
- char_to_word = {}
136
- current_pos = 0
137
-
138
- # 為每個字符創建映射到對應詞的索引
139
- for word_idx, word in enumerate(words):
140
- for _ in range(len(word)):
141
- char_to_word[current_pos] = word_idx
142
- current_pos += 1
143
-
144
- # 創建token到字符位置的映射
145
- # 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
146
- # 某些模型可能需要調整
147
-
148
- # 首先找出特殊標記
149
- special_tokens = []
150
- for i, token in enumerate(tokens):
151
- if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
152
- special_tokens.append(i)
153
-
154
- # 找出原始文本中每個token的起始位置
155
- chars = list(text) # 文本轉換為字符列表
156
- token_to_char_mapping = []
157
- token_to_word_mapping = []
158
-
159
- # 處理特殊標記
160
- char_pos = 0
161
- for i, token in enumerate(tokens):
162
- if i in special_tokens:
163
- token_to_char_mapping.append(-1) # 特殊標記沒有對應的字符位置
164
- token_to_word_mapping.append("特殊標記")
165
- else:
166
- # 對於中文字符,大多數模型是一個字符一個token
167
- # 這個邏輯可能需要根據具體模型調整
168
- if token.startswith('##'): # BERT風格的子詞
169
- actual_token = token[2:]
170
- elif token.startswith('▁') or token.startswith('Ġ'): # 其他模型風格
171
- actual_token = token[1:]
172
- else:
173
- actual_token = token
174
-
175
- # 注意:中文BERT通常每個token就是一個字符
176
- # 所以這裡可以直接映射
177
- if char_pos < len(chars):
178
- token_to_char_mapping.append(char_pos)
179
- if char_pos in char_to_word:
180
- word_idx = char_to_word[char_pos]
181
- token_to_word_mapping.append(words[word_idx])
182
- else:
183
- token_to_word_mapping.append("未知詞")
184
- char_pos += len(actual_token)
185
- else:
186
- token_to_char_mapping.append(-1)
187
- token_to_word_mapping.append("未知詞")
188
-
189
- # 創建詞到token的映射
190
- word_to_tokens = [[] for _ in range(len(words))]
191
- for i, word_idx in enumerate(char_to_word.values()):
192
- if i < len(chars):
193
- # 找出對應這個字符位置的token
194
- for j, char_pos in enumerate(token_to_char_mapping):
195
- if char_pos == i:
196
- word_to_tokens[word_idx].append(j)
197
- break
198
-
199
- # 創建token-word對照表
200
- token_word_df = pd.DataFrame({
201
- "Token": tokens,
202
- "Token_ID": token_ids,
203
- "Word": token_to_word_mapping
204
- })
205
-
206
- # 創建word-tokens對照表
207
- word_token_map = []
208
- for i, word in enumerate(words):
209
- token_indices = word_to_tokens[i]
210
- token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
211
- word_token_map.append({
212
- "Word": word,
213
- "Tokens": " ".join(token_list) if token_list else "無對應Token"
214
- })
215
-
216
- word_token_df = pd.DataFrame(word_token_map)
217
-
218
- # 模型前向運算
219
- with torch.no_grad():
220
- try:
221
- outputs = model(**tokenized_result)
222
-
223
- hidden_states = outputs.last_hidden_state.squeeze(0)
224
- attentions = outputs.attentions
225
-
226
- # Position & Token embeddings
227
- position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
228
- pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
229
- tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
230
-
231
- # ===============================
232
- # 顯示 Token-Word 映射
233
- # ===============================
234
- st.subheader("🔤 Token與詞的對應關係")
235
-
236
- # 顯示詞-Token映射
237
- st.write("對應的Tokens:")
238
- st.dataframe(word_token_df, use_container_width=True)
239
-
240
- # 顯示Token-詞映射
241
- st.write("每個Token對應的:")
242
- st.dataframe(token_word_df, use_container_width=True)
243
-
244
- # ===============================
245
- # 顯示 Embedding(前10維)
246
- # ===============================
247
- st.subheader("🧩 Token Embedding(前10維)")
248
- tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
249
- columns=[f"dim_{i}" for i in range(10)])
250
- tok_df.insert(0, "Token", tokens)
251
- tok_df.insert(1, "Word", token_word_df["Word"])
252
- st.dataframe(tok_df, use_container_width=True)
253
-
254
- st.subheader("📍 Position Embedding(前10維)")
255
- pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
256
- columns=[f"dim_{i}" for i in range(10)])
257
- pos_df.insert(0, "Token", tokens)
258
- pos_df.insert(1, "Word", token_word_df["Word"])
259
- st.dataframe(pos_df, use_container_width=True)
260
-
261
- # ===============================
262
- # Attention 可視化
263
- # ===============================
264
- num_layers = len(attentions)
265
- num_heads = attentions[0].shape[1]
266
-
267
- col1, col2 = st.columns(2)
268
- with col1:
269
- layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
270
- with col2:
271
- head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)
272
-
273
- # 取得該層、該頭的注意力矩陣
274
- selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
275
- mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
276
-
277
- # 添加標註信息
278
- token_labels = [f"{t}\n({w})" if w != "特殊記" else t
279
- for t, w in zip(tokens, token_word_df["Word"])]
280
-
281
- # 單頭 Attention Heatmap
282
- st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
283
- fig, ax = plt.subplots(figsize=(12, 10))
284
- sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
285
- cmap="YlGnBu", ax=ax)
286
- plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
287
- plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
288
- plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
289
- st.pyplot(fig, clear_figure=True, use_container_width=True)
290
-
291
- # 平均所有頭
292
- st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
293
- fig2, ax2 = plt.subplots(figsize=(12, 10))
294
- sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
295
- cmap="rocket_r", ax=ax2)
296
- plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
297
- plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
298
- plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
299
- st.pyplot(fig2, clear_figure=True, use_container_width=True)
300
-
301
- # ===============================
302
- # 詞的平均注意力可視化
303
- # ===============================
304
- st.subheader("📊 詞級別注意力熱圖")
305
-
306
- # 創建詞彙列表(去除特殊標記和未知詞)
307
- unique_words = [w for w in words if w.strip()]
308
-
309
- if len(unique_words) > 1: # 確保有足夠的詞進行可視化
310
- # 創建-詞注意力矩陣
311
- word_attention = np.zeros((len(unique_words), len(unique_words)))
312
-
313
- # 使用之前建立的映射來聚合token級別的注意力到詞級別
314
- for i, word_i in enumerate(unique_words):
315
- # 找出屬於word_i的所有token
316
- tokens_i = []
317
- for j, w in enumerate(token_word_df["Word"]):
318
- if w == word_i:
319
- tokens_i.append(j)
320
-
321
- for j, word_j in enumerate(unique_words):
322
- # 找出屬於word_j的所有token
323
- tokens_j = []
324
- for k, w in enumerate(token_word_df["Word"]):
325
- if w == word_j:
326
- tokens_j.append(k)
327
-
328
- # 計算這兩個詞之間的所有token對的平均注意力
329
- if tokens_i and tokens_j: # 確保兩個詞有對token
330
- attention_sum = 0
331
- count = 0
332
- for ti in tokens_i:
333
- for tj in tokens_j:
334
- if ti < len(selected_attention) and tj < len(selected_attention[0]):
335
- attention_sum += selected_attention[ti, tj]
336
- count += 1
337
-
338
- if count > 0:
339
- word_attention[i, j] = attention_sum / count
340
-
341
- # 繪製詞級別注意力熱圖
342
- fig3, ax3 = plt.subplots(figsize=(10, 8))
343
- sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
344
- cmap="viridis", ax=ax3)
345
- plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
346
- plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
347
- plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
348
- st.pyplot(fig3, clear_figure=True, use_container_width=True)
349
- else:
350
- st.info("詞數量不足,無法生成詞級別注意力熱圖")
351
-
352
- # ===============================
353
- # 下載 CSV
354
- # ===============================
355
- merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
356
- st.download_button(
357
- label="💾 下載 Token + Position 向量 CSV",
358
- data=merged_df.to_csv(index=False).encode("utf-8-sig"),
359
- file_name="embeddings.csv",
360
- mime="text/csv"
361
- )
362
-
363
- # 詞級別平均 embeddings
364
- st.subheader("📑 詞級別平均 Embeddings(前10維)")
365
-
366
- word_embeddings = {}
367
- for word in unique_words:
368
- # 找出屬於該詞的所有token索引
369
- token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
370
-
371
- if token_indices:
372
- # 計算該詞的平均 embedding
373
- word_emb = tok_embeddings[token_indices].mean(dim=0)
374
- word_embeddings[word] = word_emb[:10].detach().numpy()
375
-
376
- if word_embeddings:
377
- word_emb_df = pd.DataFrame.from_dict(
378
- {word: values for word, values in word_embeddings.items()},
379
- orient='index',
380
- columns=[f"dim_{i}" for i in range(10)]
381
- )
382
- word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
383
- st.dataframe(word_emb_df, use_container_width=True)
384
-
385
- # 下載詞級別 embeddings
386
- st.download_button(
387
- label="💾 下載詞級別向量 CSV",
388
- data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
389
- file_name="word_embeddings.csv",
390
- mime="text/csv"
391
- )
392
-
393
- except Exception as e:
394
- st.error(f"處理時發生錯誤: {str(e)}")
395
- import traceback
396
-
397
- st.code(traceback.format_exc(), language="python")
398
-
399
- # ===============================
400
- # 說明與幫助
401
- # ===============================
402
- with st.expander("📖 使用說明"):
403
- st.markdown("""
404
- ### 工具功能
405
-
406
- 這個工具可以幫助您理解 Transformer 模型如何處理中文文本:
407
-
408
- 1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
409
- 2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
410
- 3. **Attention 可視化**:查看不同層注意力模式
411
- 4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式
412
-
413
- ### 使用方法
414
-
415
- 1. 選擇一個預訓練的中文模型
416
- 2. 輸入您想分析的中文文本
417
- 3. 點擊"開始分析"按鈕
418
- 4. 使用滑塊選擇不同的層和注意力頭進行可視化
419
- 5. 下載 CSV 文件以一步分析
420
-
421
- ### 技術細節
422
-
423
- - **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
424
- - **注意力機制**:每一層的每注意力頭都關注不同的模式
425
- - **注意力熱圖**:顏色越深表示注意力越強
426
-
427
- ### 注意事項
428
-
429
- - Transformer 模型可能會將一個詞切分成多個 token
430
- - 特殊標記(如 [CLS], [SEP])被排除在級別析之外
431
- - 較長的文本可能需要更多處理時間
 
432
  """)
 
1
+ from pathlib import Path
2
+ import requests
+ import streamlit as st  # imported early: st.cache_resource decorates ensure_local_cjk_font below
3
 
4
+ @st.cache_resource
5
+ def ensure_local_cjk_font(font_filename="NotoSansCJKtc-Regular.otf", variant="tc"):
6
+ """
7
+ Download the Noto Sans CJK font on the first run and register it for matplotlib to use.
8
+ variant: "tc" = Traditional Chinese (Taiwan/HK), "sc" = Simplified Chinese (CN)
9
+ """
10
+ here = Path(__file__).resolve().parent
11
+ fonts_dir = here / "fonts"
12
+ fonts_dir.mkdir(exist_ok=True)
13
+ dest = fonts_dir / font_filename
14
+
15
+ if not dest.exists():
16
+ url = {
17
+ "tc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/TraditionalChinese/NotoSansCJKtc-Regular.otf",
18
+ "sc": "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf",
19
+ }[variant]
20
+ r = requests.get(url, timeout=60)
21
+ r.raise_for_status()
22
+ dest.write_bytes(r.content)
23
+ print("⬇️ Downloaded font to", dest)
24
+
25
+ fm.fontManager.addfont(str(dest))
26
+ prop = fm.FontProperties(fname=str(dest))
27
+ family = prop.get_name()
28
+ matplotlib.rcParams["font.sans-serif"] = [family, "DejaVu Sans"]
29
+ matplotlib.rcParams["axes.unicode_minus"] = False
30
+ print("✅ Using CJK font (local):", dest.name, "->", family)
31
+ return prop, str(dest), family
32
+
33
+ import streamlit as st
34
+ import torch
35
+ import pandas as pd
36
+ import numpy as np
37
+ import seaborn as sns
38
+ import matplotlib.pyplot as plt
39
+ import re
40
+ import jieba
41
+ import matplotlib
42
+ import matplotlib.font_manager as fm
43
+ from transformers import AutoTokenizer, AutoModel
44
+ import os
45
+ import warnings
46
+
47
+
48
+ # ===============================
49
+ # 中文字體設定(跨平台支持)
50
+ # ===============================
51
  def setup_chinese_font():
52
+ # 0) Prefer the locally downloaded font first (resolved in one pass, no dependence on system fonts)
53
+ try:
54
+ prop, path, family = ensure_local_cjk_font(
55
+ font_filename="NotoSansCJKtc-Regular.otf", # or NotoSansCJKsc-Regular.otf
56
+ variant="tc" # หรือ "sc"
57
+ )
58
+ return prop
59
+ except Exception as _:
60
+ pass # if the download fails, fall through to the system font paths below
61
+
62
+ # 1) Next: try the usual system paths (in case fonts were already installed via apt/Docker)
63
  candidate_paths = [
64
  "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
65
  "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
66
+ "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
67
+ "/usr/share/fonts/opentype/noto/NotoSansCJK-Sc-Regular.otf",
68
+ "/usr/share/fonts/opentype/noto/NotoSansCJK-TC-Regular.otf",
69
  "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf",
70
  ]
 
71
  for p in candidate_paths:
72
  if os.path.exists(p):
 
73
  fm.fontManager.addfont(p)
74
  prop = fm.FontProperties(fname=p)
 
75
  matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
76
  matplotlib.rcParams["axes.unicode_minus"] = False
77
+ print("✅ Using system CJK font:", p, "->", prop.get_name())
78
  return prop
79
 
80
+ # 2) Scan the whole system (filenames differ across distros)
81
  for p in fm.findSystemFonts(fontpaths=["/usr/share/fonts", "/usr/local/share/fonts"]):
82
  if any(k in p.lower() for k in ["wqy", "noto", "cjk", "droid"]):
83
  fm.fontManager.addfont(p)
84
  prop = fm.FontProperties(fname=p)
85
  matplotlib.rcParams["font.sans-serif"] = [prop.get_name(), "DejaVu Sans"]
86
  matplotlib.rcParams["axes.unicode_minus"] = False
87
+ print("✅ Using scanned CJK font:", p, "->", prop.get_name())
88
  return prop
89
 
90
+ # 3) Only then warn and fall back
91
  warnings.warn("ไม่พบฟอนต์จีน ใช้ DejaVu Sans ชั่วคราว")
92
  matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
93
  matplotlib.rcParams["axes.unicode_minus"] = False
94
  return fm.FontProperties()
95
+
96
+
97
+ zh_font = setup_chinese_font()
98
+
99
+ # ===============================
100
+ # 頁面設定
101
+ # ===============================
102
+ st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")
103
+ st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")
104
+
105
+ # ===============================
106
+ # 模型選擇與載入
107
+ # ===============================
108
+ model_options = {
109
+ "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
110
+ "BERT-base-Chinese": "bert-base-chinese",
111
+ "Chinese MacBERT-base": "hfl/chinese-macbert-base"
112
+ }
113
+
114
+ selected_model = st.selectbox(
115
+ "選擇模型",
116
+ list(model_options.keys()),
117
+ index=0
118
+ )
119
+
120
+ model_name = model_options[selected_model]
121
+
122
+
123
+ @st.cache_resource
124
+ def load_model(name):
125
+ with st.spinner(f"載入模型 {name} 中..."):
126
+ try:
127
+ tokenizer = AutoTokenizer.from_pretrained(name)
128
+ model = AutoModel.from_pretrained(name, output_attentions=True)
129
+ return tokenizer, model, None
130
+ except Exception as e:
131
+ return None, None, str(e)
132
+
133
+
134
+ tokenizer, model, error = load_model(model_name)
135
+
136
+ if error:
137
+ st.error(f"模型載入失敗: {error}")
138
+ st.stop()
139
+
140
+ # ===============================
141
+ # 使用者輸入
142
+ # ===============================
143
+ text = st.text_area(
144
+ "請輸入中文句子:",
145
+ "我今年35歲目前在科技業工作,作息略不規律。",
146
+ help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
147
+ )
148
+
149
+
150
+ def normalize_text(s):
151
+ """移除特殊符號與全形字"""
152
+ s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%\s]", "", s)
153
+ s = s.replace("%", "%").replace("。", "。 ")
154
+ return s.strip()
155
+
156
+
157
+ # ===============================
158
+ # 主流程
159
+ # ===============================
160
+ if st.button("開始分析", type="primary"):
161
+ if not text.strip():
162
+ st.warning("請輸入有效的中文句子")
163
+ st.stop()
164
+
165
+ # 文本清理與分詞
166
+ text = normalize_text(text)
167
+ words = list(jieba.cut(text))
168
+ st.write("🔹 Jieba 分詞結果:", words)
169
+
170
+ # 不使用空格連接,直接使用原始文本
171
+ # 這樣可以避免空格導致的詞-token不匹配問題
172
+ tokenized_result = tokenizer(text, return_tensors="pt")
173
+ token_ids = tokenized_result["input_ids"][0].tolist()
174
+ tokens = tokenizer.convert_ids_to_tokens(token_ids)
175
+
176
+ # 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
177
+ # 創建更穩健的詞-token映射
178
+ char_to_word = {}
179
+ current_pos = 0
180
+
181
+ # 為每個字符創建映射到對應詞的索引
182
+ for word_idx, word in enumerate(words):
183
+ for _ in range(len(word)):
184
+ char_to_word[current_pos] = word_idx
185
+ current_pos += 1
186
+
187
+ # 創建token到字符位置的映射
188
+ # 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
189
+ # 對於某些模型可能需要調整
190
+
191
+ # 首先找出特殊標記
192
+ special_tokens = []
193
+ for i, token in enumerate(tokens):
194
+ if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
195
+ special_tokens.append(i)
196
+
197
+ # 找出原始文本中每個token的起始位置
198
+ chars = list(text) # 將文本轉換為字符列表
199
+ token_to_char_mapping = []
200
+ token_to_word_mapping = []
201
+
202
+ # 處理特殊標記
203
+ char_pos = 0
204
+ for i, token in enumerate(tokens):
205
+ if i in special_tokens:
206
+ token_to_char_mapping.append(-1) # 特殊標記沒有對應的字符位置
207
+ token_to_word_mapping.append("特殊標記")
208
+ else:
209
+ # 對於中文字符,大多數模型是一個字符一個token
210
+ # 這個邏輯可能需要根據具體模型調整
211
+ if token.startswith('##'): # BERT風格的子詞
212
+ actual_token = token[2:]
213
+ elif token.startswith('▁') or token.startswith('Ġ'): # 其他模型風格
214
+ actual_token = token[1:]
215
+ else:
216
+ actual_token = token
217
+
218
+ # 注意:中文BERT通常每個token就是一個字符
219
+ # 所以這裡可以直接映射
220
+ if char_pos < len(chars):
221
+ token_to_char_mapping.append(char_pos)
222
+ if char_pos in char_to_word:
223
+ word_idx = char_to_word[char_pos]
224
+ token_to_word_mapping.append(words[word_idx])
225
+ else:
226
+ token_to_word_mapping.append("未知詞")
227
+ char_pos += len(actual_token)
228
+ else:
229
+ token_to_char_mapping.append(-1)
230
+ token_to_word_mapping.append("未知詞")
231
+
232
+ # 創建詞到token的映射
233
+ word_to_tokens = [[] for _ in range(len(words))]
234
+ for i, word_idx in enumerate(char_to_word.values()):
235
+ if i < len(chars):
236
+ # 找出對應這個字符位置的token
237
+ for j, char_pos in enumerate(token_to_char_mapping):
238
+ if char_pos == i:
239
+ word_to_tokens[word_idx].append(j)
240
+ break
241
+
242
+ # 創建token-word對照表
243
+ token_word_df = pd.DataFrame({
244
+ "Token": tokens,
245
+ "Token_ID": token_ids,
246
+ "Word": token_to_word_mapping
247
+ })
248
+
249
+ # 創建word-tokens對照表
250
+ word_token_map = []
251
+ for i, word in enumerate(words):
252
+ token_indices = word_to_tokens[i]
253
+ token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
254
+ word_token_map.append({
255
+ "Word": word,
256
+ "Tokens": " ".join(token_list) if token_list else "無對應Token"
257
+ })
258
+
259
+ word_token_df = pd.DataFrame(word_token_map)
260
+
261
+ # 模型前向運算
262
+ with torch.no_grad():
263
+ try:
264
+ outputs = model(**tokenized_result)
265
+
266
+ hidden_states = outputs.last_hidden_state.squeeze(0)
267
+ attentions = outputs.attentions
268
+
269
+ # Position & Token embeddings
270
+ position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
271
+ pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
272
+ tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
273
+
274
+ # ===============================
275
+ # 顯示 Token-Word 映射
276
+ # ===============================
277
+ st.subheader("🔤 Token與詞的對應關係")
278
+
279
+ # 顯示詞-Token映射
280
+ st.write("詞對應的Tokens:")
281
+ st.dataframe(word_token_df, use_container_width=True)
282
+
283
+ # 顯示Token-詞映射
284
+ st.write("每個Token對應的詞:")
285
+ st.dataframe(token_word_df, use_container_width=True)
286
+
287
+ # ===============================
288
+ # 顯示 Embedding(前10維)
289
+ # ===============================
290
+ st.subheader("🧩 Token Embedding(前10維)")
291
+ tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
292
+ columns=[f"dim_{i}" for i in range(10)])
293
+ tok_df.insert(0, "Token", tokens)
294
+ tok_df.insert(1, "Word", token_word_df["Word"])
295
+ st.dataframe(tok_df, use_container_width=True)
296
+
297
+ st.subheader("📍 Position Embedding(前10維)")
298
+ pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
299
+ columns=[f"dim_{i}" for i in range(10)])
300
+ pos_df.insert(0, "Token", tokens)
301
+ pos_df.insert(1, "Word", token_word_df["Word"])
302
+ st.dataframe(pos_df, use_container_width=True)
303
+
304
+ # ===============================
305
+ # Attention 可視化
306
+ # ===============================
307
+ num_layers = len(attentions)
308
+ num_heads = attentions[0].shape[1]
309
+
310
+ col1, col2 = st.columns(2)
311
+ with col1:
312
+ layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
313
+ with col2:
314
+ head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)
315
+
316
+ # 取得該層、該頭的注意力矩陣
317
+ selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
318
+ mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
319
+
320
+ # 添加標註信息
321
+ token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
322
+ for t, w in zip(tokens, token_word_df["Word"])]
323
+
324
+ # 單頭 Attention Heatmap
325
+ st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層��第 {head_idx} 頭)")
326
+ fig, ax = plt.subplots(figsize=(12, 10))
327
+ sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
328
+ cmap="YlGnBu", ax=ax)
329
+ plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
330
+ plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
331
+ plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
332
+ st.pyplot(fig, clear_figure=True, use_container_width=True)
333
+
334
+ # 平均所有頭
335
+ st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
336
+ fig2, ax2 = plt.subplots(figsize=(12, 10))
337
+ sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
338
+ cmap="rocket_r", ax=ax2)
339
+ plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
340
+ plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
341
+ plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
342
+ st.pyplot(fig2, clear_figure=True, use_container_width=True)
343
+
344
+ # ===============================
345
+ # 詞的平均注意力可視化
346
+ # ===============================
347
+ st.subheader("📊 詞級別注意力熱圖")
348
+
349
+ # 創建詞彙列表(去除特殊標記和未知詞)
350
+ unique_words = [w for w in words if w.strip()]
351
+
352
+ if len(unique_words) > 1: # 確保有足夠的詞進行可視化
353
+ # 創建詞-詞注意力矩陣
354
+ word_attention = np.zeros((len(unique_words), len(unique_words)))
355
+
356
+ # 使用之前建立的映射來聚合token級別的注意力到詞級別
357
+ for i, word_i in enumerate(unique_words):
358
+ # 找出屬於word_i的所有token
359
+ tokens_i = []
360
+ for j, w in enumerate(token_word_df["Word"]):
361
+ if w == word_i:
362
+ tokens_i.append(j)
363
+
364
+ for j, word_j in enumerate(unique_words):
365
+ # 找出屬於word_j的所有token
366
+ tokens_j = []
367
+ for k, w in enumerate(token_word_df["Word"]):
368
+ if w == word_j:
369
+ tokens_j.append(k)
370
+
371
+ # 計算這兩個詞之間的所有token對的平均注意力
372
+ if tokens_i and tokens_j: # 確保兩個詞都有對應的token
373
+ attention_sum = 0
374
+ count = 0
375
+ for ti in tokens_i:
376
+ for tj in tokens_j:
377
+ if ti < len(selected_attention) and tj < len(selected_attention[0]):
378
+ attention_sum += selected_attention[ti, tj]
379
+ count += 1
380
+
381
+ if count > 0:
382
+ word_attention[i, j] = attention_sum / count
383
+
384
+ # 繪製詞級別注意力熱圖
385
+ fig3, ax3 = plt.subplots(figsize=(10, 8))
386
+ sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
387
+ cmap="viridis", ax=ax3)
388
+ plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
389
+ plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
390
+ plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
391
+ st.pyplot(fig3, clear_figure=True, use_container_width=True)
392
+ else:
393
+ st.info("詞數量不足,無法生成詞級別注意力熱圖")
394
+
395
+ # ===============================
396
+ # 下載 CSV
397
+ # ===============================
398
+ merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
399
+ st.download_button(
400
+ label="💾 下載 Token + Position 向量 CSV",
401
+ data=merged_df.to_csv(index=False).encode("utf-8-sig"),
402
+ file_name="embeddings.csv",
403
+ mime="text/csv"
404
+ )
405
+
406
+ # 詞級別平均 embeddings
407
+ st.subheader("📑 詞級別平均 Embeddings(前10維)")
408
+
409
+ word_embeddings = {}
410
+ for word in unique_words:
411
+ # 找出屬於該詞的所有token索引
412
+ token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
413
+
414
+ if token_indices:
415
+ # 計算該詞的平均 embedding
416
+ word_emb = tok_embeddings[token_indices].mean(dim=0)
417
+ word_embeddings[word] = word_emb[:10].detach().numpy()
418
+
419
+ if word_embeddings:
420
+ word_emb_df = pd.DataFrame.from_dict(
421
+ {word: values for word, values in word_embeddings.items()},
422
+ orient='index',
423
+ columns=[f"dim_{i}" for i in range(10)]
424
+ )
425
+ word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
426
+ st.dataframe(word_emb_df, use_container_width=True)
427
+
428
+ # 下載詞級別 embeddings
429
+ st.download_button(
430
+ label="💾 下載詞級別向量 CSV",
431
+ data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
432
+ file_name="word_embeddings.csv",
433
+ mime="text/csv"
434
+ )
435
+
436
+ except Exception as e:
437
+ st.error(f"處理時發生錯誤: {str(e)}")
438
+ import traceback
439
+
440
+ st.code(traceback.format_exc(), language="python")
441
+
442
+ # ===============================
443
+ # 說明與幫助
444
+ # ===============================
445
+ with st.expander("📖 使用說明"):
446
+ st.markdown("""
447
+ ### 工具功能
448
+
449
+ 這個工具可以幫助您理解 Transformer 模型如何處理中文文本:
450
+
451
+ 1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
452
+ 2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
453
+ 3. **Attention 可視化**:查看不同層、不同注意力頭的注意力模式
454
+ 4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式
455
+
456
+ ### 使用方法
457
+
458
+ 1. 選擇一個預訓練的中文模型
459
+ 2. 輸入您想分析的中文文本
460
+ 3. 點擊"開始分析"按鈕
461
+ 4. 使用滑塊選擇不同的層和注意力頭進行可視化
462
+ 5. 下載 CSV 文件以進一步分析
463
+
464
+ ### 技術細節
465
+
466
+ - **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
467
+ - **注意力機制**:每一層的每個注意力頭都關注不同的模式
468
+ - **注意力熱圖**:顏色越深表示注意力越強
469
+
470
+ ### 注意事項
471
+
472
+ - Transformer 模型可能會將一個詞切分成多個 token
473
+ - 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
474
+ - 較長的文本可能需要更多處理時間
475
  """)
requirements.txt CHANGED
@@ -12,5 +12,6 @@ transformers
12
  safetensors
13
  accelerate
14
 
 
15
  # (ทางเลือก) ถ้าเจอปัญหา tokenizer/protobuf ให้เปิดคอมเมนต์บรรทัดล่าง
16
  # protobuf>=4.25
 
12
  safetensors
13
  accelerate
14
 
15
+ requests
16
  # (ทางเลือก) ถ้าเจอปัญหา tokenizer/protobuf ให้เปิดคอมเมนต์บรรทัดล่าง
17
  # protobuf>=4.25
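requirements.txt now also pulls in requests, which app.py uses to download the Noto Sans CJK font on first run. If the Space misbehaves, one quick sanity check is to print the versions that actually got installed; this is only a sketch, with the package list assembled from this file plus app.py's imports:

```python
from importlib import metadata

# PyPI distribution names, taken from requirements.txt and app.py's imports.
packages = ["streamlit", "torch", "transformers", "safetensors",
            "accelerate", "jieba", "matplotlib", "seaborn", "requests"]

for name in packages:
    try:
        print(f"{name:>12}: {metadata.version(name)}")
    except metadata.PackageNotFoundError:
        print(f"{name:>12}: not installed")
```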