Delete utilities.py
utilities.py  +0 -53
DELETED
@@ -1,53 +0,0 @@
-
-import matplotlib.pyplot as plt
-import torch
-
-class Utilities:
-    def __init__(self, tokenizer, model):
-        self.tokenizer = tokenizer
-        self.model = model
-
-    def sanity_check(self, sentence, block_size):
-        # Encode the sentence using the tokenizer
-        wordids = self.tokenizer.encode(sentence)
-
-        # Prepare the padded input for the model
-        padded_sentence = wordids[:block_size] + [0] * (block_size - len(wordids))
-        input_tensor = torch.tensor(padded_sentence, dtype=torch.long).unsqueeze(0)
-
-        # Display input tensor shape
-        print("Input tensor shape:", input_tensor.shape)
-
-        # Process the input tensor through the encoder model
-        _, attn_maps = self.model(input_tensor)
-
-        # Display the number of attention maps
-        print("Number of attention maps:", len(attn_maps))
-
-        # Visualize and save the attention maps
-        for j, attn_map in enumerate(attn_maps):
-            att_map = attn_map.squeeze(0).detach().cpu().numpy()  # Remove batch dimension and convert to NumPy array
-
-            print("map shape", att_map.shape, att_map.ndim)
-
-            # Check if the attention probabilities sum to 1 over rows
-            total_prob_over_rows = torch.sum(attn_map[0], dim=1)
-            if torch.any(total_prob_over_rows < 0.99) or torch.any(total_prob_over_rows > 1.01):
-                print("Failed normalization test: probabilities do not sum to 1.0 over rows")
-                print("Total probability over rows:", total_prob_over_rows.numpy())
-
-            # Create a heatmap of the attention map
-            fig, ax = plt.subplots()
-            cax = ax.imshow(att_map, cmap='hot', interpolation='nearest')
-            ax.xaxis.tick_top()
-            fig.colorbar(cax, ax=ax)
-            plt.title(f"Attention Map {j + 1}")
-
-            # Save the plot
-            plt.savefig(f"attention_map_{j + 1}.png")
-
-            # Show the plot
-            # plt.show()
-
-
-
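For context, the deleted helper was typically driven along the lines of the following sketch. DummyTokenizer and DummyEncoder are hypothetical stand-ins (the Space wires in its own tokenizer and transformer encoder); the only interface assumptions, taken from the deleted code itself, are that tokenizer.encode(str) returns a list of token ids and that model(input_tensor) returns a pair (output, attn_maps) where each attention map is a (1, T, T) tensor of row-normalized probabilities.

# Usage sketch for the (now deleted) Utilities class.
# DummyTokenizer and DummyEncoder are hypothetical stand-ins, just enough
# to exercise sanity_check; the real Space supplies its own tokenizer/model.
import torch
from utilities import Utilities  # the file removed by this commit

class DummyTokenizer:
    def encode(self, sentence):
        # Toy byte-level encoding: map each UTF-8 byte to a nonzero id
        return [b % 255 + 1 for b in sentence.encode("utf-8")]

class DummyEncoder(torch.nn.Module):
    def forward(self, x):
        T = x.shape[1]
        # One fake attention map with uniform attention, so each row sums to 1
        attn = torch.full((1, T, T), 1.0 / T)
        return x, [attn]

utils = Utilities(DummyTokenizer(), DummyEncoder())
utils.sanity_check("The quick brown fox", block_size=32)
# Prints the input shape and normalization check, then writes attention_map_1.png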