Spaces:
Sleeping
Sleeping
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import plotly.express as px

# Checkpoint to visualize; any causal LM that returns attentions should work.
model_name = 'Qwen/Qwen2-1.5B'
# Prefer GPU when available; attention extraction is slow on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@st.cache_resource(show_spinner='Loading model...')
def load_model():
    """Load the causal LM once and cache it across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without ``st.cache_resource`` the multi-GB model would be reloaded
    on each rerun.

    Returns:
        The pretrained model in bfloat16, moved to ``device`` and set to
        eval mode (inference only — disables dropout and similar layers).
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        token=st.secrets['hf_token'],  # access token for gated/private repos
    ).to(device)
    return model.eval()
@st.cache_resource(show_spinner='Loading tokenizer...')
def load_tokenizer():
    """Load the tokenizer once and cache it across Streamlit reruns.

    Mirrors ``load_model``: without caching, every widget interaction
    would re-fetch the tokenizer files.

    Returns:
        The pretrained tokenizer matching ``model_name``.
    """
    return AutoTokenizer.from_pretrained(
        model_name,
        token=st.secrets['hf_token'],  # access token for gated/private repos
    )
def get_attention_weights_and_tokens(text):
    """Run the model over *text* and return its attention maps plus tokens.

    Args:
        text: raw user input to tokenize and feed through the model.

    Returns:
        attentions: list of per-layer attention tensors (one per hidden
            layer), upcast to float32 since bfloat16 cannot be converted
            to numpy for plotting.
        tokens: the decoded string of each input token, used as the
            row/column labels of the attention heatmap.
    """
    encoded = tokenizer(text, return_tensors='pt')
    # Decode token-by-token so each label aligns with one heatmap axis index.
    tokens = [tokenizer.decode(tok) for tok in encoded.input_ids[0]]
    # no_grad: this is visualization only — building the autograd graph
    # wastes memory and leaves grad-requiring tensors that crash the
    # implicit numpy conversion in plotly downstream.
    with torch.no_grad():
        output = model(**encoded.to(device), output_attentions=True)
    attentions = [attention.to(torch.float32) for attention in output.attentions]
    return attentions, tokens
# Instantiate model and tokenizer up front; the widgets below read the
# model config for their ranges.
model = load_model()
tokenizer = load_tokenizer()

st.title('Attention visualizer')

text = st.text_area('Write your text here and see attention weights.')

# Layers are presented 1-based to the user; convert to a 0-based index.
selected_layer = st.slider(
    'Which layer do you want to see?',
    min_value=1,
    max_value=model.config.num_hidden_layers,
)
layer = selected_layer - 1

# 'Average' means "mean over all heads"; otherwise a 1-based head number.
head = st.select_slider(
    'Which head do you want to see?',
    options=['Average'] + list(range(1, model.config.num_attention_heads + 1)),
)
if text:
    attentions, tokens = get_attention_weights_and_tokens(text)
    # detach(): attention tensors may carry autograd history; a tensor
    # with requires_grad raises on the implicit numpy conversion that
    # plotly performs inside px.imshow.
    layer_attention = attentions[layer][0].detach().cpu()
    if head == 'Average':
        # Mean over the head dimension -> one (seq, seq) matrix.
        weights = layer_attention.mean(dim=0)
    else:
        weights = layer_attention[head - 1]  # head is 1-based in the UI
    fig = px.imshow(weights)
    # Label both axes with the decoded tokens, one tick per position.
    tick_labels = {
        'ticktext': tokens,
        'tickvals': list(range(len(tokens))),
    }
    fig.update_layout(
        xaxis=tick_labels,
        yaxis=dict(tick_labels),  # separate copy; plotly mutates axis dicts
        height=800,
    )
    st.plotly_chart(fig)