File size: 2,392 Bytes
a00e3e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned classifier and its tokenizer a single time, at import,
# so every request handled by classify_and_plot_attention reuses them.
model_name = "alusci/distilbert-smsafe"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# output_attentions=True makes each forward pass also return the per-layer
# attention weights (read as `outputs.attentions` by the visualiser).
# NOTE(review): some transformers releases only populate attention weights
# with the "eager" attention implementation — confirm with the pinned version.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    output_attentions=True,
).eval()  # .eval() returns the module itself; puts it in inference mode

def _min_max_normalize(matrix):
    """Linearly rescale *matrix* so its values span [0, 1].

    A constant matrix (e.g. a single remaining token) has zero range; it is
    mapped to all zeros instead of the NaNs a plain 0/0 division produces.
    """
    lo = matrix.min()
    span = matrix.max() - lo
    if span == 0:
        return np.zeros_like(matrix)
    return (matrix - lo) / span


def _plot_attention(norm_attn, tokens):
    """Render *norm_attn* as a heatmap labelled with *tokens* on both axes.

    A standalone ``matplotlib.figure.Figure`` is used instead of pyplot so that
    each call owns its figure: nothing is registered in pyplot's global figure
    list (pyplot keeps every figure alive until ``plt.close``), and no call
    relies on pyplot's shared "current figure", which another concurrent
    request could otherwise change.
    """
    from matplotlib.figure import Figure  # only needed for plotting

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    sns.heatmap(norm_attn, xticklabels=tokens, yticklabels=tokens,
                cmap="viridis", square=True, ax=ax, cbar=True)
    ax.set_title("Normalized Attention Map")
    ax.set_xlabel("Input Tokens")
    ax.set_ylabel("Output Tokens")
    for label in ax.get_xticklabels():
        label.set_rotation(45)
    fig.tight_layout()
    return fig


# Main function
def classify_and_plot_attention(text):
    """Classify an SMS message and visualise the model's averaged attention.

    The text is tokenized (truncated to the tokenizer's maximum length so an
    over-long paste cannot overflow the model's position embeddings) and run
    through the module-level ``model`` once. The attention weights from that
    pass are averaged over all layers and heads.

    Args:
        text: Raw message text from the Gradio textbox.

    Returns:
        A ``(summary, figure)`` tuple. ``summary`` is
        ``"Prediction: <label> (Score: <prob>)"``, where ``<prob>`` is the
        softmax probability of the predicted class rounded to 4 decimals.
        ``figure`` is a matplotlib Figure with a min-max normalised heatmap of
        the token-to-token attention, excluding the CLS/SEP tokens. It is
        ``None`` when no other tokens remain (e.g. empty input).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # A single string was tokenized, so the batch has one row: use row 0.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    pred_idx = int(torch.argmax(probs))
    pred_label = model.config.id2label[pred_idx]
    pred_score = round(probs[pred_idx].item(), 4)
    summary = f"Prediction: {pred_label} (Score: {pred_score})"

    # One tensor per layer; stacked: (layers, batch, heads, seq_len, seq_len).
    # Average over layers (dim 0) and heads (dim 2), then take batch row 0.
    all_attn = torch.stack(outputs.attentions)
    mean_attn = all_attn.mean(dim=(0, 2))[0].cpu().numpy()

    # Drop the CLS/SEP tokens the tokenizer wraps around the message.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    special = {tokenizer.cls_token, tokenizer.sep_token}
    keep = [i for i, tok in enumerate(tokens) if tok not in special]
    if not keep:
        # Empty / whitespace-only input: the trimmed matrix would be 0x0 and
        # min()/max() on it would raise, so report the prediction without a plot.
        return summary, None

    real_tokens = [tokens[i] for i in keep]
    trimmed_attn = mean_attn[np.ix_(keep, keep)]
    return summary, _plot_attention(_min_max_normalize(trimmed_attn), real_tokens)

# Gradio UI: a three-line textbox in; prediction text plus attention heatmap out.
demo = gr.Interface(
    title="SMS OTP Spam Classifier + Attention Visualizer",
    description="Enter an SMS OTP message to classify it and view the attention matrix.",
    fn=classify_and_plot_attention,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Paste your SMS OTP message here...",
    ),
    outputs=["text", "plot"],
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode` — confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()