File size: 4,547 Bytes
f3ff1d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import streamlit as st
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer
from models.huggingface_model import SentimentClassifierForHuggingFace
import numpy as np
import io
from PIL import Image

@st.cache_resource
def load_model():
    """Load the sentiment classifier and its tokenizer from the current directory.

    Decorated with ``st.cache_resource`` so Streamlit caches the returned
    objects instead of reloading them on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    source_dir = "./"
    classifier = SentimentClassifierForHuggingFace.from_pretrained(source_dir)
    text_tokenizer = AutoTokenizer.from_pretrained(source_dir)
    return classifier, text_tokenizer

def _render_attention_heatmap(weights, token_labels, title):
    """Plot per-token weights as a single-row heatmap and return it as a PIL image.

    Args:
        weights: Array-like of attention weights; flattened to one row, so it
            should hold one value per entry in ``token_labels``.
        token_labels: Token strings used as x-axis tick labels.
        title: Figure title.

    Returns:
        PIL.Image.Image: The rendered PNG (150 dpi).
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        sns.heatmap(
            np.asarray(weights).reshape(1, -1),
            cmap="YlOrRd",
            annot=True,
            fmt=".2f",
            cbar=False,
            xticklabels=token_labels,
            yticklabels=["Attention"],
            ax=ax,
        )
        # Configure this figure's own Axes/Figure rather than pyplot's implicit,
        # process-global "current figure", so the render cannot be affected by
        # (or affect) any other figure that happens to be current.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        ax.set_title(title)
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting/saving fails, so failed
        # requests don't leave figures open in pyplot's registry.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def predict_sentiment(text, model, tokenizer):
    """Classify ``text`` and visualize the model's attention over its tokens.

    Args:
        text: Input string to classify (truncated to 128 tokens).
        model: Classifier called as ``model(input_ids, return_attention=True,
            return_dict=True)``; its output must provide ``"logits"`` and
            ``"attention_weights"``.
        tokenizer: Hugging Face tokenizer matching ``model``.

    Returns:
        tuple: ``(sentiment, confidence, image)`` where ``sentiment`` is
        ``"Positive"`` for class index 1 and ``"Negative"`` otherwise,
        ``confidence`` is the softmax probability of the predicted class, and
        ``image`` is a PIL image of the attention heatmap.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = encoded["input_ids"]

    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, return_attention=True, return_dict=True)

    logits = outputs["logits"]
    attention_weights = outputs["attention_weights"]

    # Softmax over the class dimension; report the top class and its probability.
    probs = torch.nn.functional.softmax(logits, dim=1)
    prediction = torch.argmax(probs, dim=1).item()
    confidence = probs[0][prediction].item()
    sentiment = "Positive" if prediction == 1 else "Negative"

    # Convert the whole id sequence in one call (a list in -> list of str out).
    tokens_list = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # NOTE(review): assumes attention_weights has a leading batch dim of 1 and
    # one weight per token after flattening — confirm against the model.
    weights = attention_weights.squeeze(0).cpu().numpy()
    title = f"Prediction: {sentiment} (Confidence: {confidence:.4f})"
    img = _render_attention_heatmap(weights, tokens_list, title)

    return sentiment, confidence, img

# Streamlit app
def _render_results(sentiment, confidence, viz_img):
    """Show the prediction badge, confidence, attention heatmap and an explanation."""
    col1, col2 = st.columns([1, 3])

    with col1:
        st.subheader("Prediction:")
        sentiment_color = "#5FD068" if sentiment == "Positive" else "#D21312"
        # ``sentiment`` is one of two fixed labels, so interpolating it into HTML is safe.
        st.markdown(
            f"<div style='background-color:{sentiment_color}; padding:10px; border-radius:5px;"
            f"color:white; text-align:center; font-size:24px;'>{sentiment}</div>",
            unsafe_allow_html=True
        )
        st.metric("Confidence", f"{confidence:.2%}")

    with col2:
        st.subheader("Attention Visualization:")
        st.image(viz_img, use_column_width=True)
        st.caption("The heatmap shows which words the model focused on most when making its prediction.")

    st.markdown("---")
    st.subheader("How to interpret the visualization:")
    st.write(
        "The attention heatmap shows the weight assigned to each token in the text. "
        "Darker colors indicate where the model focused more attention when making its prediction. "
        "This can help identify which parts of the text were most influential for sentiment classification."
    )


def main():
    """Build the page: load the model, read the input text, and show the analysis."""
    # Must be the first Streamlit call of the script run.
    st.set_page_config(
        page_title="Sentiment Analysis with Attention",
        page_icon="🧠",
        layout="wide"
    )

    st.title("Sentiment Analysis with Attention Visualization")
    st.write("This model classifies text sentiment as positive or negative and visualizes which parts of the text it focused on using an attention mechanism.")

    # Load model and tokenizer; report failures in the UI instead of crashing.
    try:
        model, tokenizer = load_model()
        model_loaded = True
    except Exception as e:
        st.error(f"Error loading model: {e}")
        model_loaded = False

    text_input = st.text_area(
        "Enter text to analyze:",
        value="I absolutely loved this movie! The acting was superb.",
        height=100,
    )

    # Disable the button when loading failed, so clicking it isn't a silent no-op.
    analyze_clicked = st.button("Analyze Sentiment", disabled=not model_loaded)
    if not (analyze_clicked and model_loaded):
        return

    if not text_input.strip():
        st.warning("Please enter some text to analyze.")
        return

    with st.spinner("Analyzing..."):
        try:
            sentiment, confidence, viz_img = predict_sentiment(text_input, model, tokenizer)
        except Exception as e:
            # Same UI-boundary handling as model loading above.
            st.error(f"Error during analysis: {e}")
            return

    _render_results(sentiment, confidence, viz_img)

# Build the app when this module is executed as the main script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()