Donlagon007 commited on
Commit
b6e983e
·
verified ·
1 Parent(s): 8bb282f

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +6 -19
  2. app.py +427 -0
  3. apt.txt +3 -0
  4. huggingface.yaml +8 -0
  5. requirements.txt +16 -3
README.md CHANGED
@@ -1,20 +1,7 @@
1
- ---
2
- title: Transformer
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Streamlit template space
12
- license: mit
13
- ---
14
 
15
- # Welcome to Streamlit!
16
-
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
1
+ # 中文詞級 Transformer 可視化 (Streamlit)
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ 可視化 Token / Position / Attention。支援:
4
+ - Jieba 分詞
5
+ - Token Word 對照表
6
+ - Attention Heatmap(可選層與頭)
7
+ - Token/Position 向量下載成 CSV
 
app.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import pandas as pd
4
+ import numpy as np
5
+ import seaborn as sns
6
+ import matplotlib.pyplot as plt
7
+ import re
8
+ import jieba
9
+ import matplotlib
10
+ import matplotlib.font_manager as fm
11
+ from transformers import AutoTokenizer, AutoModel
12
+ import os
13
+ import warnings
14
+
15
+
16
+ # ===============================
17
+ # 中文字體設定(跨平台支持)
18
+ # ===============================
19
+ def setup_chinese_font():
20
+ """設置中文字體,支持 Windows, Mac 和 Linux"""
21
+ font_paths = [
22
+ # Windows 字體
23
+ "C:/Windows/Fonts/msjh.ttc",
24
+ "C:/Windows/Fonts/simsun.ttc",
25
+ # Mac 字體
26
+ "/System/Library/Fonts/PingFang.ttc",
27
+ # Linux 字體
28
+ "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
29
+ "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf"
30
+ ]
31
+
32
+ # 嘗試加載第一個可用字體
33
+ for path in font_paths:
34
+ if os.path.exists(path):
35
+ zh_font = fm.FontProperties(fname=path)
36
+ font_name = os.path.basename(path).split('.')[0]
37
+ matplotlib.rcParams["font.sans-serif"] = [font_name, "DejaVu Sans"]
38
+ matplotlib.rcParams["axes.unicode_minus"] = False
39
+ return zh_font
40
+
41
+ # 如果沒有找到任何字體,使用系統默認
42
+ warnings.warn("無法找到中文字體,使用系統默認字體")
43
+ zh_font = fm.FontProperties()
44
+ matplotlib.rcParams["font.sans-serif"] = ["DejaVu Sans"]
45
+ matplotlib.rcParams["axes.unicode_minus"] = False
46
+ return zh_font
47
+
48
+
49
+ zh_font = setup_chinese_font()
50
+
51
+ # ===============================
52
+ # 頁面設定
53
+ # ===============================
54
+ st.set_page_config(page_title="中文詞級 Transformer 可視化", layout="wide")
55
+ st.title("🧠 中文詞級 Transformer Token / Position / Attention 可視化工具")
56
+
57
+ # ===============================
58
+ # 模型選擇與載入
59
+ # ===============================
60
+ model_options = {
61
+ "Chinese RoBERTa (WWM-ext)": "hfl/chinese-roberta-wwm-ext",
62
+ "BERT-base-Chinese": "bert-base-chinese",
63
+ "Chinese MacBERT-base": "hfl/chinese-macbert-base"
64
+ }
65
+
66
+ selected_model = st.selectbox(
67
+ "選擇模型",
68
+ list(model_options.keys()),
69
+ index=0
70
+ )
71
+
72
+ model_name = model_options[selected_model]
73
+
74
+
75
+ @st.cache_resource
76
+ def load_model(name):
77
+ with st.spinner(f"載入模型 {name} 中..."):
78
+ try:
79
+ tokenizer = AutoTokenizer.from_pretrained(name)
80
+ model = AutoModel.from_pretrained(name, output_attentions=True)
81
+ return tokenizer, model, None
82
+ except Exception as e:
83
+ return None, None, str(e)
84
+
85
+
86
+ tokenizer, model, error = load_model(model_name)
87
+
88
+ if error:
89
+ st.error(f"模型載入失敗: {error}")
90
+ st.stop()
91
+
92
+ # ===============================
93
+ # 使用者輸入
94
+ # ===============================
95
+ text = st.text_area(
96
+ "請輸入中文句子:",
97
+ "我今年35歲,目前在科技業工作,作息略不規律。",
98
+ help="輸入您想分析的中文文本。將使用 Jieba 進行分詞,然後用 Transformer 模型分析。"
99
+ )
100
+
101
+
102
+ def normalize_text(s):
103
+ """移除特殊符號與全形字"""
104
+ s = re.sub(r"[^\u4e00-\u9fa5A-Za-z0-9,。、;:?!%%\s]", "", s)
105
+ s = s.replace("%", "%").replace("。", "。 ")
106
+ return s.strip()
107
+
108
+
109
+ # ===============================
110
+ # 主流程
111
+ # ===============================
112
+ if st.button("開始分析", type="primary"):
113
+ if not text.strip():
114
+ st.warning("請輸入有效的中文句子")
115
+ st.stop()
116
+
117
+ # 文本清理與分詞
118
+ text = normalize_text(text)
119
+ words = list(jieba.cut(text))
120
+ st.write("🔹 Jieba 分詞結果:", words)
121
+
122
+ # 不使用空格連接,直接使用原始文本
123
+ # 這樣可以避免空格導致的詞-token不匹配問題
124
+ tokenized_result = tokenizer(text, return_tensors="pt")
125
+ token_ids = tokenized_result["input_ids"][0].tolist()
126
+ tokens = tokenizer.convert_ids_to_tokens(token_ids)
127
+
128
+ # 為了更準確地映射詞和token,我們需要找出每個token在原始文本中的位置
129
+ # 創建更穩健的詞-token映射
130
+ char_to_word = {}
131
+ current_pos = 0
132
+
133
+ # 為每個字符創建映射到對應詞的索引
134
+ for word_idx, word in enumerate(words):
135
+ for _ in range(len(word)):
136
+ char_to_word[current_pos] = word_idx
137
+ current_pos += 1
138
+
139
+ # 創建token到字符位置的映射
140
+ # 注意:這個方法適用於基於字符的中文模型,如BERT/RoBERTa中文模型
141
+ # 對於某些模型可能需要調整
142
+
143
+ # 首先找出特殊標記
144
+ special_tokens = []
145
+ for i, token in enumerate(tokens):
146
+ if token in ['[CLS]', '[SEP]', '<s>', '</s>', '<cls>', '<sep>']:
147
+ special_tokens.append(i)
148
+
149
+ # 找出原始文本中每個token的起始位置
150
+ chars = list(text) # 將文本轉換為字符列表
151
+ token_to_char_mapping = []
152
+ token_to_word_mapping = []
153
+
154
+ # 處理特殊標記
155
+ char_pos = 0
156
+ for i, token in enumerate(tokens):
157
+ if i in special_tokens:
158
+ token_to_char_mapping.append(-1) # 特殊標記沒有對應的字符位置
159
+ token_to_word_mapping.append("特殊標記")
160
+ else:
161
+ # 對於中文字符,大多數模型是一個字符一個token
162
+ # 這個邏輯可能需要根據具體模型調整
163
+ if token.startswith('##'): # BERT風格的子詞
164
+ actual_token = token[2:]
165
+ elif token.startswith('▁') or token.startswith('Ġ'): # 其他模型風格
166
+ actual_token = token[1:]
167
+ else:
168
+ actual_token = token
169
+
170
+ # 注意:中文BERT通常每個token就是一個字符
171
+ # 所以這裡可以直接映射
172
+ if char_pos < len(chars):
173
+ token_to_char_mapping.append(char_pos)
174
+ if char_pos in char_to_word:
175
+ word_idx = char_to_word[char_pos]
176
+ token_to_word_mapping.append(words[word_idx])
177
+ else:
178
+ token_to_word_mapping.append("未知詞")
179
+ char_pos += len(actual_token)
180
+ else:
181
+ token_to_char_mapping.append(-1)
182
+ token_to_word_mapping.append("未知詞")
183
+
184
+ # 創建詞到token的映射
185
+ word_to_tokens = [[] for _ in range(len(words))]
186
+ for i, word_idx in enumerate(char_to_word.values()):
187
+ if i < len(chars):
188
+ # 找出對應這個字符位置的token
189
+ for j, char_pos in enumerate(token_to_char_mapping):
190
+ if char_pos == i:
191
+ word_to_tokens[word_idx].append(j)
192
+ break
193
+
194
+ # 創建token-word對照表
195
+ token_word_df = pd.DataFrame({
196
+ "Token": tokens,
197
+ "Token_ID": token_ids,
198
+ "Word": token_to_word_mapping
199
+ })
200
+
201
+ # 創建word-tokens對照表
202
+ word_token_map = []
203
+ for i, word in enumerate(words):
204
+ token_indices = word_to_tokens[i]
205
+ token_list = [tokens[idx] for idx in token_indices if idx < len(tokens)]
206
+ word_token_map.append({
207
+ "Word": word,
208
+ "Tokens": " ".join(token_list) if token_list else "無對應Token"
209
+ })
210
+
211
+ word_token_df = pd.DataFrame(word_token_map)
212
+
213
+ # 模型前向運算
214
+ with torch.no_grad():
215
+ try:
216
+ outputs = model(**tokenized_result)
217
+
218
+ hidden_states = outputs.last_hidden_state.squeeze(0)
219
+ attentions = outputs.attentions
220
+
221
+ # Position & Token embeddings
222
+ position_ids = torch.arange(0, tokenized_result["input_ids"].size(1)).unsqueeze(0)
223
+ pos_embeddings = model.embeddings.position_embeddings(position_ids).squeeze(0)
224
+ tok_embeddings = model.embeddings.word_embeddings(tokenized_result["input_ids"]).squeeze(0)
225
+
226
+ # ===============================
227
+ # 顯示 Token-Word 映射
228
+ # ===============================
229
+ st.subheader("🔤 Token與詞的對應關係")
230
+
231
+ # 顯示詞-Token映射
232
+ st.write("詞對應的Tokens:")
233
+ st.dataframe(word_token_df, use_container_width=True)
234
+
235
+ # 顯示Token-詞映射
236
+ st.write("每個Token對應的詞:")
237
+ st.dataframe(token_word_df, use_container_width=True)
238
+
239
+ # ===============================
240
+ # 顯示 Embedding(前10維)
241
+ # ===============================
242
+ st.subheader("🧩 Token Embedding(前10維)")
243
+ tok_df = pd.DataFrame(tok_embeddings[:, :10].detach().numpy(),
244
+ columns=[f"dim_{i}" for i in range(10)])
245
+ tok_df.insert(0, "Token", tokens)
246
+ tok_df.insert(1, "Word", token_word_df["Word"])
247
+ st.dataframe(tok_df, use_container_width=True)
248
+
249
+ st.subheader("📍 Position Embedding(前10維)")
250
+ pos_df = pd.DataFrame(pos_embeddings[:, :10].detach().numpy(),
251
+ columns=[f"dim_{i}" for i in range(10)])
252
+ pos_df.insert(0, "Token", tokens)
253
+ pos_df.insert(1, "Word", token_word_df["Word"])
254
+ st.dataframe(pos_df, use_container_width=True)
255
+
256
+ # ===============================
257
+ # Attention 可視化
258
+ # ===============================
259
+ num_layers = len(attentions)
260
+ num_heads = attentions[0].shape[1]
261
+
262
+ col1, col2 = st.columns(2)
263
+ with col1:
264
+ layer_idx = st.slider("選擇 Attention 層數", 1, num_layers, num_layers)
265
+ with col2:
266
+ head_idx = st.slider("選擇 Attention Head", 1, num_heads, 1)
267
+
268
+ # 取得該層、該頭的注意力矩陣
269
+ selected_attention = attentions[layer_idx - 1][0, head_idx - 1].detach().numpy()
270
+ mean_attention = attentions[layer_idx - 1][0].mean(0).detach().numpy()
271
+
272
+ # 添加標註信息
273
+ token_labels = [f"{t}\n({w})" if w != "特殊標記" else t
274
+ for t, w in zip(tokens, token_word_df["Word"])]
275
+
276
+ # 單頭 Attention Heatmap
277
+ st.subheader(f"🔥 Attention Heatmap(第 {layer_idx} 層,第 {head_idx} 頭)")
278
+ fig, ax = plt.subplots(figsize=(12, 10))
279
+ sns.heatmap(selected_attention, xticklabels=token_labels, yticklabels=token_labels,
280
+ cmap="YlGnBu", ax=ax)
281
+ plt.title(f"Attention - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
282
+ plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
283
+ plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
284
+ st.pyplot(fig, clear_figure=True, use_container_width=True)
285
+
286
+ # 平均所有頭
287
+ st.subheader(f"🌈 平均所有頭(第 {layer_idx} 層)")
288
+ fig2, ax2 = plt.subplots(figsize=(12, 10))
289
+ sns.heatmap(mean_attention, xticklabels=token_labels, yticklabels=token_labels,
290
+ cmap="rocket_r", ax=ax2)
291
+ plt.title(f"Mean Attention - Layer {layer_idx}", fontproperties=zh_font)
292
+ plt.xticks(rotation=90, fontsize=10, fontproperties=zh_font)
293
+ plt.yticks(rotation=0, fontsize=10, fontproperties=zh_font)
294
+ st.pyplot(fig2, clear_figure=True, use_container_width=True)
295
+
296
+ # ===============================
297
+ # 詞的平均注意力可視化
298
+ # ===============================
299
+ st.subheader("📊 詞級別注意力熱圖")
300
+
301
+ # 創建詞彙列表(去除特殊標記和未知詞)
302
+ unique_words = [w for w in words if w.strip()]
303
+
304
+ if len(unique_words) > 1: # 確保有足夠的詞進行可視化
305
+ # 創建詞-詞注意力矩陣
306
+ word_attention = np.zeros((len(unique_words), len(unique_words)))
307
+
308
+ # 使用之前建立的映射來聚合token級別的注意力到詞級別
309
+ for i, word_i in enumerate(unique_words):
310
+ # 找出屬於word_i的所有token
311
+ tokens_i = []
312
+ for j, w in enumerate(token_word_df["Word"]):
313
+ if w == word_i:
314
+ tokens_i.append(j)
315
+
316
+ for j, word_j in enumerate(unique_words):
317
+ # 找出屬於word_j的所有token
318
+ tokens_j = []
319
+ for k, w in enumerate(token_word_df["Word"]):
320
+ if w == word_j:
321
+ tokens_j.append(k)
322
+
323
+ # 計算這兩個詞之間的所有token對的平均注意力
324
+ if tokens_i and tokens_j: # 確保兩個詞都有對應的token
325
+ attention_sum = 0
326
+ count = 0
327
+ for ti in tokens_i:
328
+ for tj in tokens_j:
329
+ if ti < len(selected_attention) and tj < len(selected_attention[0]):
330
+ attention_sum += selected_attention[ti, tj]
331
+ count += 1
332
+
333
+ if count > 0:
334
+ word_attention[i, j] = attention_sum / count
335
+
336
+ # 繪製詞級別注意力熱圖
337
+ fig3, ax3 = plt.subplots(figsize=(10, 8))
338
+ sns.heatmap(word_attention, xticklabels=unique_words, yticklabels=unique_words,
339
+ cmap="viridis", ax=ax3)
340
+ plt.title(f"詞級別注意力 - Layer {layer_idx}, Head {head_idx}", fontproperties=zh_font)
341
+ plt.xticks(rotation=45, fontsize=12, fontproperties=zh_font)
342
+ plt.yticks(rotation=0, fontsize=12, fontproperties=zh_font)
343
+ st.pyplot(fig3, clear_figure=True, use_container_width=True)
344
+ else:
345
+ st.info("詞數量不足,無法生成詞級別注意力熱圖")
346
+
347
+ # ===============================
348
+ # 下載 CSV
349
+ # ===============================
350
+ merged_df = pd.concat([tok_df, pos_df.add_prefix("pos_").iloc[:, 2:]], axis=1)
351
+ st.download_button(
352
+ label="💾 下載 Token + Position 向量 CSV",
353
+ data=merged_df.to_csv(index=False).encode("utf-8-sig"),
354
+ file_name="embeddings.csv",
355
+ mime="text/csv"
356
+ )
357
+
358
+ # 詞級別平均 embeddings
359
+ st.subheader("📑 詞級別平均 Embeddings(前10維)")
360
+
361
+ word_embeddings = {}
362
+ for word in unique_words:
363
+ # 找出屬於該��的所有token索引
364
+ token_indices = [i for i, w in enumerate(token_word_df["Word"]) if w == word]
365
+
366
+ if token_indices:
367
+ # 計算該詞的平均 embedding
368
+ word_emb = tok_embeddings[token_indices].mean(dim=0)
369
+ word_embeddings[word] = word_emb[:10].detach().numpy()
370
+
371
+ if word_embeddings:
372
+ word_emb_df = pd.DataFrame.from_dict(
373
+ {word: values for word, values in word_embeddings.items()},
374
+ orient='index',
375
+ columns=[f"dim_{i}" for i in range(10)]
376
+ )
377
+ word_emb_df = word_emb_df.reset_index().rename(columns={"index": "Word"})
378
+ st.dataframe(word_emb_df, use_container_width=True)
379
+
380
+ # 下載詞級別 embeddings
381
+ st.download_button(
382
+ label="💾 下載詞級別向量 CSV",
383
+ data=word_emb_df.to_csv(index=False).encode("utf-8-sig"),
384
+ file_name="word_embeddings.csv",
385
+ mime="text/csv"
386
+ )
387
+
388
+ except Exception as e:
389
+ st.error(f"處理時發生錯誤: {str(e)}")
390
+ import traceback
391
+
392
+ st.code(traceback.format_exc(), language="python")
393
+
394
+ # ===============================
395
+ # 說明與幫助
396
+ # ===============================
397
+ with st.expander("📖 使用說明"):
398
+ st.markdown("""
399
+ ### 工具功能
400
+
401
+ 這個工具可以幫助您理解 Transformer 模型如何處理中文文本:
402
+
403
+ 1. **分詞與映射**:使用 Jieba 將文本分詞,然後映射到 Transformer 模型的 token
404
+ 2. **Embedding 可視化**:查看每個 token 和位置的 embedding 向量前10維
405
+ 3. **Attention 可視化**:查看不同層和頭的注意力模式
406
+ 4. **詞級別分析**:整合 token 級別信息,得到詞級別的 embedding 和注意力模式
407
+
408
+ ### 使用方法
409
+
410
+ 1. 選擇一個預訓練的中文模型
411
+ 2. 輸入您想分析的中文文本
412
+ 3. 點擊"開始分析"按鈕
413
+ 4. 使用滑塊選擇不同的層和注意力頭進行可視化
414
+ 5. 下載 CSV 文件以進一步分析
415
+
416
+ ### 技術細節
417
+
418
+ - **詞-Token映射**:中文字符通常會被映射到單個Token,而詞通常由多個Token組成
419
+ - **注意力機制**:每一層的每個注意力頭都關注不同的模式
420
+ - **注意力熱圖**:顏色越深表示注意力越強
421
+
422
+ ### 注意事項
423
+
424
+ - Transformer 模型可能會將一個詞切分成多個 token
425
+ - 特殊標記(如 [CLS], [SEP])會被排除在詞級別分析之外
426
+ - 較長的文本可能需要更多處理時間
427
+ """)
apt.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fonts-wqy-zenhei
2
+ fonts-wqy-microhei
3
+ fonts-noto-cjk
huggingface.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # แจ้งให้ Spaces รู้ว่าเป็น Streamlit และไฟล์หลักคือ app.py
2
+ # ถ้าคุณยังใช้ชื่อ app2.py ให้ตั้ง app_file: app2.py
3
+ ---
4
+ title: Chinese Transformer Visualizer
5
+ sdk: streamlit
6
+ app_file: app.py
7
+ python_version: "3.10"
8
+ # sdk_version: "1.39.0" # จะระบุหรือไม่ก็ได้
requirements.txt CHANGED
@@ -1,3 +1,16 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core app
2
+ streamlit>=1.39
3
+ pandas>=2.2
4
+ numpy>=1.26
5
+ matplotlib>=3.8
6
+ seaborn>=0.13
7
+ jieba>=0.42
8
+
9
+ # ML stack
10
+ torch==2.4.1
11
+ transformers>=4.44
12
+ safetensors>=0.4
13
+ accelerate>=0.34
14
+
15
+ # (ทางเลือก) ถ้าเจอปัญหา tokenizer/protobuf ให้เปิดคอมเมนต์บรรทัดล่าง
16
+ # protobuf>=4.25