import streamlit as st
import torch
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel
# Load the pre-trained BERT model and tokenizer from the Hugging Face Hub.
# output_attentions=True makes the model return per-layer attention weights.
@st.cache_resource  # cache across Streamlit reruns so the model loads only once
def load_model():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
    model.eval()  # inference only; disables dropout
    return tokenizer, model

tokenizer, model = load_model()
# App title and description
st.title("BERT Attention Map Visualizer")
st.write(r"""
## Introduction
This application visualizes the attention mechanism of the BERT model for a given input sentence.
The attention mechanism allows BERT to focus on different parts of the sentence when encoding each token,
providing insight into how the model understands the context and the relationships between words.
The app shows the attention maps produced by a pre-trained BERT model.
### Attention Mechanism
The attention mechanism lets the model focus on the important parts of the input sequence.
It computes a weighted sum of values ($V$) based on the similarity between queries ($Q$) and keys ($K$):
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$
where:
- $Q$ (Query): the current token for which attention is being computed.
- $K$ (Key): the tokens in the input sequence that the query is compared against.
- $V$ (Value): the actual values used to compute the attention-weighted sum.
- $d_k$: the dimension of the key vectors, used for scaling.
### Key, Query, and Value
- **Query (Q)**: captures the essence of the word/token we are focusing on.
- **Key (K)**: represents all words/tokens the query is compared against.
- **Value (V)**: holds the information of all tokens, aggregated according to the attention scores.
This mechanism lets the model dynamically adjust its focus on different parts of the sentence, improving contextual understanding.
""")
# Input sentence from the user
sentence = st.text_input("Enter a sentence:", "The cat is on the mat")
# Tokenize and encode the sentence (adds the [CLS] and [SEP] special tokens)
inputs = tokenizer(sentence, return_tensors='pt', add_special_tokens=True)
# Run BERT; no gradients are needed for inference
with torch.no_grad():
    outputs = model(**inputs)
attention = outputs.attentions  # tuple with one tensor per layer, each (batch, heads, seq_len, seq_len)
attention_weights = attention[-1].squeeze(0)  # last layer, shape (heads, seq_len, seq_len)
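# Illustrative check (added for exposition, not in the original app): every row
# of every head comes out of a softmax, so the weights along the last axis sum to 1.
row_sums = attention_weights.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-4)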
# Plot the attention weights as a heatmap over token pairs
def visualize_attention(tokens, attention_weights):
    attention_weights = attention_weights.detach().numpy()
    fig, ax = plt.subplots(figsize=(8, 8))
    cax = ax.matshow(attention_weights, cmap='viridis')
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)
    fig.colorbar(cax)
    ax.set_title("Attention Map")
    st.pyplot(fig)
# Extract the tokens, including the [CLS] and [SEP] special tokens
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
# Drop the special tokens for visualization; for a single sentence these are
# exactly the first ([CLS]) and last ([SEP]) positions, matching the 1:-1 slice below
tokens_vis = [token for token in tokens if token not in tokenizer.all_special_tokens]
# Visualize the first attention head of the last layer, excluding special tokens
visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1])
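# Optional variant (a sketch added here, not part of the original app):
# averaging the weights over all heads often yields a smoother, easier-to-read map.
if st.checkbox("Show head-averaged attention map"):
    mean_attention = attention_weights.mean(dim=0)  # average over the head axis
    visualize_attention(tokens_vis, mean_attention[1:-1, 1:-1])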
st.write("""
### About BERT
BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model designed to understand the context of words in a sentence. It uses the attention mechanism to weigh the importance of different words when generating word embeddings, which is crucial for tasks such as language translation and sentiment analysis.
""")