Đinh Trác Đức Anh commited on
Commit
7c4c75e
·
1 Parent(s): 7c25540

fix visualize bias_matrix

Browse files
Files changed (1) hide show
  1. utils/visualize_bias_matrix.py +26 -14
utils/visualize_bias_matrix.py CHANGED
@@ -1,31 +1,43 @@
1
  import matplotlib.pyplot as plt
2
  import seaborn as sns
3
  import torch
 
4
 
5
- def visualize_bias_matrix(bias_matrix, tokens=None, title="Bias Matrix Visualization"):
6
  """
7
- Hiển thị bias matrix dưới dạng heatmap, gắn nhãn token.
8
 
9
  Args:
10
- bias_matrix: torch.Tensor shape [1, num_heads, seq_len, seq_len] hoặc [seq_len, seq_len]
11
- tokens: list[str], danh sách subword hoặc syllable tương ứng
 
 
12
  title: tiêu đề heatmap
13
  """
 
14
  if isinstance(bias_matrix, torch.Tensor):
15
  bias_matrix = bias_matrix.detach().cpu()
16
-
17
- # Nếu nhiều head → lấy trung bình
18
- if bias_matrix.ndim == 4:
19
- bias_matrix = bias_matrix.mean(dim=1).squeeze(0)
20
- elif bias_matrix.ndim == 3:
21
- bias_matrix = bias_matrix.squeeze(0)
22
-
23
  seq_len = bias_matrix.shape[0]
24
 
25
- # Đảm bảo tokens phù hợp độ dài
26
  if tokens is None:
27
- tokens = [str(i) for i in range(seq_len)]
28
- elif len(tokens) != seq_len:
 
 
 
 
 
 
 
 
29
  tokens = tokens[:seq_len]
30
 
31
  # Vẽ heatmap
 
1
  import matplotlib.pyplot as plt
2
  import seaborn as sns
3
  import torch
4
+ import numpy as np
5
 
6
+ def visualize_bias_matrix(bias_matrix, encoded=None, tokenizer=None, tokens=None, title="Bias Matrix Visualization"):
7
  """
8
+ Hiển thị bias matrix dưới dạng heatmap, gắn nhãn token.
9
 
10
  Args:
11
+ bias_matrix: torch.Tensor, shape [seq_len, seq_len] hoặc [1, num_heads, seq_len, seq_len]
12
+ encoded: dict từ tokenizer, chứa 'input_ids' (tùy chọn)
13
+ tokenizer: tokenizer dùng để convert input_ids sang token (tùy chọn)
14
+ tokens: list[str], nhãn token nếu muốn tự truyền
15
  title: tiêu đề heatmap
16
  """
17
+ # Nếu bias_matrix là 4D -> [1, num_heads, seq_len, seq_len]
18
  if isinstance(bias_matrix, torch.Tensor):
19
  bias_matrix = bias_matrix.detach().cpu()
20
+ if bias_matrix.ndim == 4:
21
+ # trung bình trên head
22
+ bias_matrix = bias_matrix.mean(dim=1).squeeze(0)
23
+ elif bias_matrix.ndim == 3:
24
+ bias_matrix = bias_matrix.squeeze(0)
25
+ bias_matrix = bias_matrix.numpy()
26
+
27
  seq_len = bias_matrix.shape[0]
28
 
29
+ # Lấy tokens từ input_ids nếu chưa có
30
  if tokens is None:
31
+ if encoded is not None and tokenizer is not None:
32
+ input_ids = encoded.get("input_ids")
33
+ if isinstance(input_ids, torch.Tensor):
34
+ if input_ids.ndim == 2: # batch
35
+ input_ids = input_ids[0]
36
+ input_ids = input_ids.detach().cpu().tolist()
37
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
38
+ else:
39
+ tokens = [str(i) for i in range(seq_len)]
40
+ else:
41
  tokens = tokens[:seq_len]
42
 
43
  # Vẽ heatmap