# Spaces: Sleeping  (status banner captured along with this source; not code)
import matplotlib.pyplot as plt
import seaborn as sns
import torch
def plot_attention_heatmap(model, tokenizer, input_seq, layer=-1, head=None):
    """Plot a token-by-token attention heatmap for a single layer.

    Runs ``model.base_model`` on the tokenized ``input_seq`` with
    ``output_attentions=True`` and draws the attention weights of the first
    sequence in the batch as a seaborn heatmap. Both axes are labelled with
    the input tokens.

    Args:
        model: Model exposing ``eval()`` and ``base_model``. The base model
            must accept the tokenizer's outputs as keyword arguments and
            return an object with an ``attentions`` sequence. As a side
            effect, the model is switched to eval mode.
        tokenizer: Callable tokenizer supporting ``return_tensors="pt"`` and
            ``convert_ids_to_tokens``.
        input_seq: Text to tokenize and visualize.
        layer: Index into ``output.attentions`` selecting the layer to plot.
            Defaults to ``-1``, the last entry, as before.
        head: Attention head to plot. ``None`` (the default) averages over
            all heads, because a heatmap needs a 2-D matrix.

    Returns:
        The matplotlib ``Figure`` that contains the heatmap.
    """
    model.eval()
    inputs = tokenizer(input_seq, return_tensors="pt")

    # Keep the inputs on the same device as the model so CUDA models work.
    # NOTE(review): relies on HF-style models exposing `.device`; models
    # without it are left untouched (original behaviour).
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        output = model.base_model(**inputs, output_attentions=True)

    # First sequence of the batch. For HF-style outputs this is still
    # (num_heads, seq_len, seq_len), which seaborn cannot draw, so reduce
    # it to a single 2-D matrix.
    attention = output.attentions[layer][0]
    if attention.dim() == 3:
        attention = attention.mean(dim=0) if head is None else attention[head]
    # Half/bfloat16 and CUDA tensors can't be converted to NumPy directly.
    attention = attention.float().cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(attention, xticklabels=tokens, yticklabels=tokens, cmap="viridis", ax=ax)
    # Configure the axes we created rather than pyplot's "current" axes.
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_title("Attention Heatmap")
    return fig