# transE / app.py
# Last commit: bc02f8a "Update app.py" by amasood (verified)
import streamlit as st
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.decomposition import PCA
from transformers import AutoModel, AutoTokenizer, pipeline, AutoModelForCausalLM
# App Title
# NOTE: this file must be saved as UTF-8 so the emoji in UI strings render
# correctly (they had been mis-decoded into sequences like "πŸš€").
st.title("🚀 Transformer Model Explorer")
st.markdown(
    """
Explore different transformer models, their architectures, tokenization, and attention mechanisms.
"""
)
# Model Selection
model_name = st.selectbox(
    "Choose a Transformer Model:",
    ["bert-base-uncased", "gpt2", "t5-small", "roberta-base"]
)


@st.cache_resource
def load_tokenizer_and_model(name):
    """Return ``(tokenizer, base_model)`` for the checkpoint *name*.

    Streamlit re-executes this whole script on every widget interaction
    (including each edit of the text box below). Without caching, every rerun
    would re-instantiate the model; ``st.cache_resource`` keeps one shared
    instance per checkpoint name for the lifetime of the server process.
    The returned objects are shared, so callers must only use them for
    inference and must not mutate them.
    """
    return AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name)


# Load Tokenizer & Model
st.write(f"Loading model: `{model_name}`...")
tokenizer, model = load_tokenizer_and_model(model_name)
# Display Model Details
st.subheader("🛠 Model Details")
model_config = model.config
st.write(f"Model Type: `{model_config.model_type}`")
# getattr with an "N/A" fallback: the config attribute names are not
# guaranteed to exist for every architecture.
st.write(f"Number of Layers: `{getattr(model_config, 'num_hidden_layers', 'N/A')}`")
st.write(f"Number of Attention Heads: `{getattr(model_config, 'num_attention_heads', 'N/A')}`")
# Count every parameter tensor element, reported in millions.
total_params_millions = sum(p.numel() for p in model.parameters()) / 1e6
st.write(f"Total Parameters: `{total_params_millions:.2f}M`")
# Model Size Comparison
st.subheader("📊 Model Size Comparison")
# Parameter counts in millions for each selectable checkpoint. These are
# hard-coded (presumably to avoid loading every model just to draw the chart);
# keep the keys in sync with the options in the model selectbox.
model_sizes = {
    "bert-base-uncased": 110,
    "gpt2": 124,
    "t5-small": 60,
    "roberta-base": 125,
}
df_size = pd.DataFrame(
    list(model_sizes.items()),
    columns=["Model", "Size (Million Parameters)"],
)
fig = px.bar(df_size, x="Model", y="Size (Million Parameters)", title="Model Size Comparison")
st.plotly_chart(fig)
# Tokenization Section
st.subheader("📝 Tokenization Visualization")
input_text = st.text_input("Enter Text:", "Hello, how are you?")
# tokenize() returns the sub-word strings only; it does not add the model's
# special tokens ([CLS]/[SEP], <s>/</s>, ...) that a full encode inserts.
tokens = tokenizer.tokenize(input_text)
st.write("Tokenized Output:", tokens)
# Token Embeddings Visualization (PCA projection of final hidden states)
st.subheader("🧩 Token Embeddings Visualization")
# Encode with special tokens (the model's real input). truncation=True caps
# the length at the tokenizer's model_max_length so very long text cannot
# overflow the position embeddings. NOTE: `inputs` is module-level and may be
# read by later sections.
inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
# Label every position with the token actually fed to the model, so labels
# line up 1:1 with rows of the hidden-state matrix. The earlier approach
# (tokenize() + trimming) omitted special tokens, which shifted every label
# by one for BERT/RoBERTa and dropped the last rows.
embedding_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

if len(embedding_tokens) < 2:
    # PCA(n_components=2) needs at least two samples. GPT-2 adds no special
    # tokens, so a single word (or empty text) would otherwise raise here.
    st.info("Enter at least two tokens to see the embedding projection.")
else:
    # Encoder-decoder models (e.g. t5-small) cannot run a full forward pass
    # without decoder inputs, so project the encoder's hidden states instead.
    encoder = model.get_encoder() if model.config.is_encoder_decoder else model
    with torch.no_grad():
        outputs = encoder(**inputs)
    if hasattr(outputs, "last_hidden_state"):
        # Batch of one: take row 0 -> one hidden vector per input token.
        embeddings = outputs.last_hidden_state[0].numpy()
        reduced_embeddings = PCA(n_components=2).fit_transform(embeddings)
        df_embeddings = pd.DataFrame(reduced_embeddings, columns=["PCA1", "PCA2"])
        df_embeddings["Token"] = embedding_tokens
        fig = px.scatter(df_embeddings, x="PCA1", y="PCA2", text="Token",
                         title="Token Embeddings (PCA Projection)")
        st.plotly_chart(fig)
# Attention Visualization (for BERT & RoBERTa models)
if "bert" in model_name or "roberta" in model_name:
    st.subheader("🔍 Attention Map")
    # Encode locally (with special tokens) so the heatmap axes use exactly the
    # tokens the model attends over. `tokens` from tokenize() omits
    # [CLS]/[SEP] (<s>/</s>), which shifted every tick label by one.
    attn_inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    attn_tokens = tokenizer.convert_ids_to_tokens(attn_inputs["input_ids"][0].tolist())
    with torch.no_grad():
        outputs = model(**attn_inputs, output_attentions=True)
    # attentions[-1] is the last layer. Per the transformers docs each entry is
    # (batch, heads, seq, seq). Index batch 0 explicitly; squeeze() would also
    # drop any other size-1 dimension.
    attention = outputs.attentions[-1][0].numpy()
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(attention[0], cmap="viridis",
                xticklabels=attn_tokens, yticklabels=attn_tokens, ax=ax)
    ax.set_title("Last layer, head 0")
    st.pyplot(fig)
    # Release the figure: pyplot keeps every open figure alive, and Streamlit
    # reruns this script on every interaction.
    plt.close(fig)
# Text Generation Demo (for GPT-like models)
if "gpt" in model_name:
    st.subheader("✍️ Text Generation & Token Probabilities")

    @st.cache_resource
    def load_causal_lm(name):
        """Return the language-modeling-head variant of checkpoint *name*.

        Cached so the LM is loaded once per server process instead of on
        every Streamlit rerun. It is shared by the pipeline and the
        probability chart below, so only one extra copy of the model is held.
        """
        return AutoModelForCausalLM.from_pretrained(name)

    model_gen = load_causal_lm(model_name)
    # Build the pipeline from the cached model and the already-loaded
    # tokenizer, so pipeline() does not instantiate yet another copy.
    generator = pipeline("text-generation", model=model_gen, tokenizer=tokenizer)
    # Do NOT pass return_tensors=True here: it switches the output key to
    # "generated_token_ids", so ["generated_text"] raised KeyError.
    # max_new_tokens (unlike max_length) does not count the prompt, so long
    # prompts still produce a continuation.
    generated_output = generator(input_text, max_new_tokens=50, return_full_text=False)
    st.write("Generated Output:", generated_output[0]["generated_text"])

    # Token Probability Visualization: distribution of the next token after
    # the prompt.
    gen_inputs = tokenizer(input_text, return_tensors="pt")
    if gen_inputs["input_ids"].shape[1] == 0:
        # GPT-2 adds no special tokens, so empty text yields an empty sequence
        # with no "last position" to read logits from.
        st.info("Enter some text to see next-token probabilities.")
    else:
        with torch.no_grad():
            next_token_logits = model_gen(**gen_inputs).logits[0, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)
        top = torch.topk(probs, k=10)  # Top 10 tokens, highest first
        # Build from parallel lists rather than a dict keyed by decoded text:
        # distinct ids can decode to the same string and would collapse.
        df_probs = pd.DataFrame({
            "Token": [tokenizer.decode([idx]) for idx in top.indices.tolist()],
            "Probability": top.values.tolist(),
        })
        fig_prob = px.bar(df_probs, x="Token", y="Probability", title="Top Token Predictions")
        st.plotly_chart(fig_prob)
# Footer (emoji restored from mis-decoded bytes; file must be saved as UTF-8).
st.markdown("💡 *Explore more about Transformer models!*")