import matplotlib.pyplot as plt
import seaborn as sns
import torch


def plot_attention_heatmap(model, tokenizer, input_seq, head=None):
    """Plot a heatmap of the model's last-layer self-attention for *input_seq*.

    Args:
        model: a model exposing ``.base_model`` whose forward pass accepts
            ``output_attentions=True`` (e.g. a HuggingFace transformer).
            NOTE(review): assumes ``.base_model`` exists — confirm against
            the model class used by callers.
        tokenizer: callable tokenizer returning PyTorch tensors and
            providing ``convert_ids_to_tokens``.
        input_seq: raw input text to tokenize and feed through the model.
        head: optional attention-head index to plot. When None (default),
            the last layer's heads are averaged into a single map.

    Returns:
        The matplotlib ``Figure`` containing the heatmap.
    """
    model.eval()
    inputs = tokenizer(input_seq, return_tensors="pt")
    with torch.no_grad():
        output = model.base_model(**inputs, output_attentions=True)

    # output.attentions[-1] has shape (batch, num_heads, seq, seq);
    # index batch 0, leaving (num_heads, seq, seq).
    last_layer = output.attentions[-1][0]

    # BUG FIX: the original passed this 3-D (heads, seq, seq) tensor straight
    # to sns.heatmap, which requires 2-D data — reduce over the head axis.
    if head is None:
        attention = last_layer.mean(dim=0)
    else:
        attention = last_layer[head]

    # BUG FIX: move off any accelerator and convert to NumPy; seaborn/
    # matplotlib expect array-like data, not a (possibly CUDA) torch.Tensor.
    attention = attention.cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        attention,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        ax=ax,
    )
    plt.xticks(rotation=90)
    plt.title("Attention Heatmap")
    fig.tight_layout()  # keep the rotated x-labels inside the figure bounds
    return fig