edwjin commited on
Commit
eac943c
·
verified ·
1 Parent(s): eb20058

Delete utilities.py

Browse files
Files changed (1) hide show
  1. utilities.py +0 -53
utilities.py DELETED
@@ -1,53 +0,0 @@
1
-
2
- import matplotlib.pyplot as plt
3
- import torch
4
-
5
- class Utilities:
6
- def __init__(self, tokenizer, model):
7
- self.tokenizer = tokenizer
8
- self.model = model
9
-
10
- def sanity_check(self, sentence, block_size):
11
- # Encode the sentence using the tokenizer
12
- wordids = self.tokenizer.encode(sentence)
13
-
14
- # Prepare the padded input for the model
15
- padded_sentence = wordids[:block_size] + [0] * (block_size - len(wordids))
16
- input_tensor = torch.tensor(padded_sentence, dtype=torch.long).unsqueeze(0)
17
-
18
- # Display input tensor shape
19
- print("Input tensor shape:", input_tensor.shape)
20
-
21
- # Process the input tensor through the encoder model
22
- _, attn_maps = self.model(input_tensor)
23
-
24
- # Display the number of attention maps
25
- print("Number of attention maps:", len(attn_maps))
26
-
27
- # Visualize and save the attention maps
28
- for j, attn_map in enumerate(attn_maps):
29
- att_map = attn_map.squeeze(0).detach().cpu().numpy() # Remove batch dimension and convert to NumPy array
30
-
31
- print("map shape", att_map.shape, att_map.ndim)
32
-
33
- # Check if the attention probabilities sum to 1 over rows
34
- total_prob_over_rows = torch.sum(attn_map[0], dim=1)
35
- if torch.any(total_prob_over_rows < 0.99) or torch.any(total_prob_over_rows > 1.01):
36
- print("Failed normalization test: probabilities do not sum to 1.0 over rows")
37
- print("Total probability over rows:", total_prob_over_rows.numpy())
38
-
39
- # Create a heatmap of the attention map
40
- fig, ax = plt.subplots()
41
- cax = ax.imshow(att_map, cmap='hot', interpolation='nearest')
42
- ax.xaxis.tick_top()
43
- fig.colorbar(cax, ax=ax)
44
- plt.title(f"Attention Map {j + 1}")
45
-
46
- # Save the plot
47
- plt.savefig(f"attention_map_{j + 1}.png")
48
-
49
- # Show the plot
50
- # plt.show()
51
-
52
-
53
-