# updated_code.py — Streamlit transformer-visualizer app.
# (Pasted from a transcript; "Spaces: Sleeping" scrape residue and markdown
# table markup removed. Sections marked "same ... as original" are elided.)
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import torch
from transformers import AutoConfig, AutoTokenizer

# Page configuration: wide layout, sidebar open by default.
st.set_page_config(
    page_title="Transformer Visualizer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS styling (unchanged)
# ... [same CSS styles as original] ...

# Model database: display name -> metadata dict. Each entry carries the
# Hugging Face hub repo id under the "model_name" key (used by
# get_model_config and the tokenization tab in main).
MODELS = {
    # ... [same model database as original] ...
}
# Memoized configs keyed by hub repo id. Streamlit re-executes the whole
# script on every widget interaction, so without a cache this would hit the
# Hugging Face hub (or local cache + re-parse) on every rerun.
_CONFIG_CACHE = {}


def get_model_config(model_name):
    """Return the transformers AutoConfig for a model in MODELS.

    Args:
        model_name: a display-name key of the MODELS dict.

    Returns:
        The AutoConfig loaded from the entry's "model_name" hub repo id,
        memoized per repo id for the lifetime of the process.
    """
    repo_id = MODELS[model_name]["model_name"]
    if repo_id not in _CONFIG_CACHE:
        _CONFIG_CACHE[repo_id] = AutoConfig.from_pretrained(repo_id)
    return _CONFIG_CACHE[repo_id]
def plot_model_comparison(selected_model):
    # NOTE(review): body elided in this paste ("same comparison function as
    # original"). Presumably renders a chart comparing `selected_model`
    # against the other MODELS entries — confirm against the original file.
    # ... [same comparison function as original] ...
def visualize_architecture(model_info):
    # NOTE(review): body elided in this paste ("same architecture function as
    # original"). Presumably draws the model architecture from the metadata in
    # `model_info` (a MODELS entry) — confirm against the original file.
    # ... [same architecture function as original] ...
def visualize_attention_patterns():
    # NOTE(review): body elided in this paste ("same attention patterns
    # function as original"). Takes no arguments, so it presumably renders a
    # generic attention illustration — confirm against the original file.
    # ... [same attention patterns function as original] ...
| def main(): | |
| st.title("🧠 Transformer Model Visualizer") | |
| selected_model = st.sidebar.selectbox("Select Model", list(MODELS.keys())) | |
| model_info = MODELS[selected_model] | |
| config = get_model_config(selected_model) | |
| # Metrics columns (unchanged) | |
| col1, col2, col3, col4 = st.columns(4) | |
| # ... [same metrics code as original] ... | |
| # Added 4th tab | |
| tab1, tab2, tab3, tab4 = st.tabs(["Model Structure", "Comparison", "Model Attention", "Model Tokenization"]) | |
| # Existing tabs (unchanged) | |
| # ... [same tab1, tab2, tab3 code as original] ... | |
| # New Tokenization Tab | |
| with tab4: | |
| st.subheader("Text Tokenization") | |
| user_input = st.text_input("Enter Text:", value="My name is Sadia!", key="tokenizer_input") | |
| if st.button("Tokenize", key="tokenize_button"): | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(MODELS[selected_model]["model_name"]) | |
| tokens = tokenizer.tokenize(user_input) | |
| # Format output similar to reference image | |
| tokenized_output = "- [ \n" | |
| for idx, token in enumerate(tokens): | |
| tokenized_output += f" {idx} : \"{token}\" \n" | |
| tokenized_output += "]" | |
| st.markdown("**Tokenized Output:**") | |
| st.markdown(f"```\n{tokenized_output}\n```", unsafe_allow_html=True) | |
| except Exception as e: | |
| st.error(f"Error in tokenization: {str(e)}") | |
# Script entry point (run via `streamlit run updated_code.py`).
if __name__ == "__main__":
    main()