Đinh Trác Đức Anh committed on
Commit ·
7c4c75e
1
Parent(s): 7c25540
fix visualize bias_matrix
Browse files
- utils/visualize_bias_matrix.py +26 -14
utils/visualize_bias_matrix.py
CHANGED
|
@@ -1,31 +1,43 @@
|
|
| 1 |
import matplotlib.pyplot as plt
|
| 2 |
import seaborn as sns
|
| 3 |
import torch
|
|
|
|
| 4 |
|
| 5 |
-
def visualize_bias_matrix(bias_matrix, tokens=None, title="Bias Matrix Visualization"):
|
| 6 |
"""
|
| 7 |
-
Hiển thị bias matrix dưới dạng heatmap,
|
| 8 |
|
| 9 |
Args:
|
| 10 |
-
bias_matrix: torch.Tensor
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
title: tiêu đề heatmap
|
| 13 |
"""
|
|
|
|
| 14 |
if isinstance(bias_matrix, torch.Tensor):
|
| 15 |
bias_matrix = bias_matrix.detach().cpu()
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
bias_matrix = bias_matrix.
|
| 22 |
-
|
| 23 |
seq_len = bias_matrix.shape[0]
|
| 24 |
|
| 25 |
-
#
|
| 26 |
if tokens is None:
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
tokens = tokens[:seq_len]
|
| 30 |
|
| 31 |
# Vẽ heatmap
|
|
|
|
| 1 |
import matplotlib.pyplot as plt
|
| 2 |
import seaborn as sns
|
| 3 |
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
|
| 6 |
+
def visualize_bias_matrix(bias_matrix, encoded=None, tokenizer=None, tokens=None, title="Bias Matrix Visualization"):
|
| 7 |
"""
|
| 8 |
+
Hiển thị bias matrix dưới dạng heatmap, gắn nhãn token.
|
| 9 |
|
| 10 |
Args:
|
| 11 |
+
bias_matrix: torch.Tensor, shape [seq_len, seq_len] hoặc [1, num_heads, seq_len, seq_len]
|
| 12 |
+
encoded: dict từ tokenizer, chứa 'input_ids' (tùy chọn)
|
| 13 |
+
tokenizer: tokenizer dùng để convert input_ids sang token (tùy chọn)
|
| 14 |
+
tokens: list[str], nhãn token nếu muốn tự truyền
|
| 15 |
title: tiêu đề heatmap
|
| 16 |
"""
|
| 17 |
+
# Nếu bias_matrix là 4D -> [1, num_heads, seq_len, seq_len]
|
| 18 |
if isinstance(bias_matrix, torch.Tensor):
|
| 19 |
bias_matrix = bias_matrix.detach().cpu()
|
| 20 |
+
if bias_matrix.ndim == 4:
|
| 21 |
+
# trung bình trên head
|
| 22 |
+
bias_matrix = bias_matrix.mean(dim=1).squeeze(0)
|
| 23 |
+
elif bias_matrix.ndim == 3:
|
| 24 |
+
bias_matrix = bias_matrix.squeeze(0)
|
| 25 |
+
bias_matrix = bias_matrix.numpy()
|
| 26 |
+
|
| 27 |
seq_len = bias_matrix.shape[0]
|
| 28 |
|
| 29 |
+
# Lấy tokens từ input_ids nếu chưa có
|
| 30 |
if tokens is None:
|
| 31 |
+
if encoded is not None and tokenizer is not None:
|
| 32 |
+
input_ids = encoded.get("input_ids")
|
| 33 |
+
if isinstance(input_ids, torch.Tensor):
|
| 34 |
+
if input_ids.ndim == 2: # batch
|
| 35 |
+
input_ids = input_ids[0]
|
| 36 |
+
input_ids = input_ids.detach().cpu().tolist()
|
| 37 |
+
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
| 38 |
+
else:
|
| 39 |
+
tokens = [str(i) for i in range(seq_len)]
|
| 40 |
+
else:
|
| 41 |
tokens = tokens[:seq_len]
|
| 42 |
|
| 43 |
# Vẽ heatmap
|