Create visualize.py
visualize.py (ADDED, +90 -0)
# visualize.py - Contains functions to draw:
#   - Attention matrix
#   - Tokenization preview
#   - Embedding heatmaps
#   - Model comparison chart

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from sklearn.decomposition import PCA


def plot_attention(tokens, attn_matrix):
    # Draw a square token-by-token attention heatmap with a colorbar.
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attn_matrix, cmap="viridis")
    fig.colorbar(cax)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Attention Map")
    plt.tight_layout()
    return fig
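
# Quick sanity check for plot_attention (a minimal sketch; the token list
# and matrix below are made up): any square matrix whose side equals
# len(tokens) works, e.g. random data:
#
#   toks = ["[CLS]", "hello", "world", "[SEP]"]
#   plot_attention(toks, np.random.rand(len(toks), len(toks))).savefig("demo.png")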


def visualize_attention(tokenizer, model, text, layer_index, head_index):
    # Run the model once and plot a single attention head from one layer.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # output_attentions=True is required; otherwise outputs.attentions is None.
        outputs = model(**inputs, output_attentions=True)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    attn = outputs.attentions[layer_index][0, head_index].detach().numpy()
    return plot_attention(tokens, attn)
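
# Valid indices depend on the checkpoint: for BERT-style configs,
# layer_index runs over model.config.num_hidden_layers and head_index over
# model.config.num_attention_heads (assumed standard config fields; some
# architectures name them differently), e.g.:
#
#   fig = visualize_attention(tokenizer, model, "Hello world", 0, 0)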


def show_tokenization(tokenizer, text):
    # Render the subword tokens as a single row of colored cells.
    tokens = tokenizer.tokenize(text)
    fig, ax = plt.subplots(figsize=(8, 1))
    ax.imshow([[0] * len(tokens)], cmap="Pastel2", aspect="auto")
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks([])
    ax.set_title("Tokenization")
    return fig
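
# Note: tokenizer.tokenize() returns subword strings without the special
# tokens ([CLS]/[SEP] etc.) that tokenizer(text) adds, so this preview can
# show fewer cells than the attention plot. A minimal call:
#
#   show_tokenization(tokenizer, "unbelievable tokenization").savefig("tok.png")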


def show_embeddings(tokenizer, model, text):
    # Project per-token hidden states down to 2-D with PCA and scatter them.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    embeddings = outputs.last_hidden_state[0].detach().numpy()
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    fig, ax = plt.subplots()
    ax.scatter(reduced[:, 0], reduced[:, 1])

    for i, token in enumerate(tokens):
        ax.annotate(token, (reduced[i, 0], reduced[i, 1]))

    ax.set_title("Token Embeddings (PCA)")
    return fig
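
# last_hidden_state has shape (batch, seq_len, hidden_size), so the PCA step
# reduces each token's hidden_size-dimensional vector to 2-D; it needs
# seq_len >= 2, which holds once special tokens are added. For example:
#
#   show_embeddings(tokenizer, model, "kings and queens").savefig("emb.png")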


def compare_model_sizes():
    from model_utils import MODEL_OPTIONS
    from transformers import AutoModel

    model_names = list(MODEL_OPTIONS.values())
    sizes = []

    for name in model_names:
        try:
            model = AutoModel.from_pretrained(name)
            size = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
            sizes.append(size)
        except Exception:
            # NaN keeps the bar chart aligned with its labels; a bare None
            # would make ax.bar() fail.
            sizes.append(float("nan"))

    fig, ax = plt.subplots()
    ax.bar(list(MODEL_OPTIONS.keys()), sizes, color="skyblue")
    ax.set_ylabel("Parameters (Millions)")
    ax.set_title("Model Size Comparison")
    ax.tick_params(axis='x', rotation=45)
    plt.tight_layout()
    return fig
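

# A minimal smoke test for this module (a sketch: "distilbert-base-uncased"
# is an assumed example checkpoint, not one this Space pins; the repo's own
# model_utils.MODEL_OPTIONS is left untouched here).
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModel

    name = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)

    text = "Transformers turn tokens into contextual vectors."
    visualize_attention(tokenizer, model, text, layer_index=0, head_index=0).savefig("attention.png")
    show_tokenization(tokenizer, text).savefig("tokens.png")
    show_embeddings(tokenizer, model, text).savefig("embeddings.png")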