Initial model upload
Browse files- .gitattributes +3 -34
- MODEL_CARD.md +147 -0
- README.md +121 -0
- app.py +129 -0
- assets/attention_visualization.png +3 -0
- assets/example_negative.png +0 -0
- assets/exmaple_positive.png +0 -0
- config.json +16 -0
- inference_example.py +34 -0
- model-index.json +15 -0
- model.safetensors +3 -0
- special_tokens_map.json +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +56 -0
- vocab.txt +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,4 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
assets/attention_visualization.png filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_CARD.md
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Attention-based Sentiment Classifier
|
| 2 |
+
|
| 3 |
+
This model is an attention-based sentiment classification model that uses a bidirectional GRU with an attention mechanism to classify text sentiment as positive or negative.
|
| 4 |
+
|
| 5 |
+
## Model Description
|
| 6 |
+
|
| 7 |
+
- **Developed by:** Lantian Wei
|
| 8 |
+
- **Model type:** Sentiment Classification
|
| 9 |
+
- **Language(s):** English
|
| 10 |
+
- **License:** GNU General Public License v3.0
|
| 11 |
+
- **Finetuned from model:** Trained from scratch, using pre-trained BERT tokenizer
|
| 12 |
+
|
| 13 |
+
This sentiment classifier uses a bidirectional GRU architecture with an attention mechanism to focus on the most sentiment-relevant parts of a sentence. The model was trained on the SST-2 (Stanford Sentiment Treebank) dataset, a collection of movie reviews with binary sentiment labels.
|
| 14 |
+
|
| 15 |
+
### Model Architecture
|
| 16 |
+
|
| 17 |
+
- Embedding layer (100 dimensions)
|
| 18 |
+
- Bidirectional GRU (256 hidden dimensions)
|
| 19 |
+
- Attention mechanism
|
| 20 |
+
- Fully connected layers
|
| 21 |
+
- Output: 2 classes (positive/negative)
|
| 22 |
+
|
| 23 |
+
## Intended Uses & Limitations
|
| 24 |
+
|
| 25 |
+
### Intended Uses
|
| 26 |
+
|
| 27 |
+
- Sentiment analysis of short to medium-length English text
|
| 28 |
+
- Educational purposes to understand attention mechanisms
|
| 29 |
+
- Research on interpretability in NLP models
|
| 30 |
+
|
| 31 |
+
### Limitations
|
| 32 |
+
|
| 33 |
+
- Only trained on movie reviews, may not generalize to other domains
|
| 34 |
+
- Limited to English text
|
| 35 |
+
- Binary classification only (positive/negative)
|
| 36 |
+
- Not suitable for multi-lingual content
|
| 37 |
+
- Performance may degrade on texts significantly different from movie reviews
|
| 38 |
+
|
| 39 |
+
## Training Data
|
| 40 |
+
|
| 41 |
+
The model was trained on the SST-2 (Stanford Sentiment Treebank) dataset, which consists of movie reviews labeled as positive or negative. The dataset is commonly used as a benchmark for sentiment analysis models.
|
| 42 |
+
|
| 43 |
+
- Dataset: SST-2 from the GLUE benchmark
|
| 44 |
+
- Training examples: 30,000
|
| 45 |
+
- Validation examples: 500
|
| 46 |
+
|
| 47 |
+
## Training Procedure
|
| 48 |
+
|
| 49 |
+
### Training Hyperparameters
|
| 50 |
+
|
| 51 |
+
- Learning rate: 1e-3
|
| 52 |
+
- Epochs: 12
|
| 53 |
+
- Optimizer: Adam
|
| 54 |
+
- Loss function: Cross Entropy Loss
|
| 55 |
+
- Embedding dimension: 100
|
| 56 |
+
- Hidden dimension: 256
|
| 57 |
+
- Dropout: 0.3
|
| 58 |
+
|
| 59 |
+
## Evaluation Results
|
| 60 |
+
|
| 61 |
+
- Validation accuracy: [Insert your validation accuracy here]
|
| 62 |
+
- Test accuracy: [Insert your test accuracy here]
|
| 63 |
+
|
| 64 |
+
## Visualization Examples
|
| 65 |
+
|
| 66 |
+
One of the key features of this model is its interpretability through attention visualization. The model can output attention weights that highlight which parts of the input text it focused on to make its prediction.
|
| 67 |
+
|
| 68 |
+

|
| 69 |
+
|
| 70 |
+
## Usage Examples
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
from transformers import AutoTokenizer
|
| 74 |
+
from models.huggingface_model import SentimentClassifierForHuggingFace, SentimentClassifierConfig
|
| 75 |
+
import torch
|
| 76 |
+
import matplotlib.pyplot as plt
|
| 77 |
+
import seaborn as sns
|
| 78 |
+
|
| 79 |
+
# Load the model
|
| 80 |
+
config = SentimentClassifierConfig()
|
| 81 |
+
model = SentimentClassifierForHuggingFace(config)
|
| 82 |
+
model.load_state_dict(torch.load("path_to_weights.pth"))
|
| 83 |
+
model.eval()
|
| 84 |
+
|
| 85 |
+
# Load the tokenizer
|
| 86 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
| 87 |
+
|
| 88 |
+
# Function to make predictions with attention visualization
|
| 89 |
+
def predict_with_attention(text):
|
| 90 |
+
# Tokenize
|
| 91 |
+
tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 92 |
+
input_ids = tokens["input_ids"]
|
| 93 |
+
|
| 94 |
+
# Get prediction and attention weights
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
outputs = model(input_ids, return_attention=True, return_dict=True)
|
| 97 |
+
|
| 98 |
+
logits = outputs["logits"]
|
| 99 |
+
attention_weights = outputs["attention_weights"]
|
| 100 |
+
|
| 101 |
+
# Get prediction and confidence
|
| 102 |
+
probs = torch.nn.functional.softmax(logits, dim=1)
|
| 103 |
+
prediction = torch.argmax(probs, dim=1).item()
|
| 104 |
+
confidence = probs[0][prediction].item()
|
| 105 |
+
sentiment = "Positive" if prediction == 1 else "Negative"
|
| 106 |
+
|
| 107 |
+
# Visualize attention weights
|
| 108 |
+
tokens_list = [tokenizer.convert_ids_to_tokens(id.item()) for id in input_ids[0]]
|
| 109 |
+
|
| 110 |
+
# Plot attention heatmap
|
| 111 |
+
plt.figure(figsize=(10, 2))
|
| 112 |
+
sns.heatmap(
|
| 113 |
+
attention_weights.squeeze(0).cpu().numpy(),
|
| 114 |
+
cmap="YlOrRd",
|
| 115 |
+
annot=True,
|
| 116 |
+
fmt=".2f",
|
| 117 |
+
cbar=False,
|
| 118 |
+
xticklabels=tokens_list,
|
| 119 |
+
yticklabels=["Attention"]
|
| 120 |
+
)
|
| 121 |
+
plt.title(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
|
| 122 |
+
plt.tight_layout()
|
| 123 |
+
plt.show()
|
| 124 |
+
|
| 125 |
+
return {
|
| 126 |
+
"text": text,
|
| 127 |
+
"sentiment": sentiment,
|
| 128 |
+
"confidence": confidence,
|
| 129 |
+
"attention": attention_weights.squeeze(0).cpu().numpy()
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
# Example usage
|
| 133 |
+
result = predict_with_attention("I absolutely loved this movie! The acting was superb.")
|
| 134 |
+
print(f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.4f})")
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
## Citations
|
| 138 |
+
|
| 139 |
+
```
|
| 140 |
+
@inproceedings{socher2013recursive,
|
| 141 |
+
title={Recursive deep models for semantic compositionality over a sentiment treebank},
|
| 142 |
+
author={Socher, Richard and Perelygin, Alex and Wu, Jean and Chuang, Jason and Manning, Christopher D and Ng, Andrew Y and Potts, Christopher},
|
| 143 |
+
booktitle={Proceedings of the 2013 conference on empirical methods in natural language processing},
|
| 144 |
+
pages={1631--1642},
|
| 145 |
+
year={2013}
|
| 146 |
+
}
|
| 147 |
+
```
|
README.md
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Attention-based Sentiment Classifier
|
| 2 |
+
|
| 3 |
+
This repository contains an attention-based sentiment classification model that demonstrates how attention mechanisms can enhance interpretability in NLP tasks.
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
## Model Overview
|
| 8 |
+
|
| 9 |
+
This model uses a bidirectional GRU with an attention mechanism to classify text sentiment (positive/negative). The attention mechanism allows the model to focus on the most relevant parts of the input text, providing insight into which words influence the classification the most.
|
| 10 |
+
|
| 11 |
+
### Key Features
|
| 12 |
+
|
| 13 |
+
- Bidirectional GRU architecture
|
| 14 |
+
- Additive attention mechanism for interpretability
|
| 15 |
+
- Binary sentiment classification (positive/negative)
|
| 16 |
+
- Visualization tools for attention weights
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
from transformers import pipeline
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import seaborn as sns
|
| 24 |
+
|
| 25 |
+
# Load model directly from Hugging Face
|
| 26 |
+
classifier = pipeline(
|
| 27 |
+
"text-classification",
|
| 28 |
+
model="your-username/attention-sentiment-classifier"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Standard prediction
|
| 32 |
+
result = classifier("I absolutely loved this movie! The acting was superb.")
|
| 33 |
+
print(f"Sentiment: {result[0]['label']}, Score: {result[0]['score']:.4f}")
|
| 34 |
+
|
| 35 |
+
# For attention visualization, use the model directly
|
| 36 |
+
from transformers import AutoTokenizer, AutoModel
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/attention-sentiment-classifier")
|
| 40 |
+
model = AutoModel.from_pretrained("your-username/attention-sentiment-classifier")
|
| 41 |
+
|
| 42 |
+
text = "I absolutely loved this movie! The acting was superb."
|
| 43 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 44 |
+
|
| 45 |
+
# Get prediction with attention weights
|
| 46 |
+
model.eval()
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
outputs = model(inputs["input_ids"], return_attention=True, return_dict=True)
|
| 49 |
+
|
| 50 |
+
# Get prediction results
|
| 51 |
+
logits = outputs["logits"]
|
| 52 |
+
attention_weights = outputs["attention_weights"]
|
| 53 |
+
|
| 54 |
+
# Visualize attention
|
| 55 |
+
tokens = [tokenizer.convert_ids_to_tokens(id.item()) for id in inputs["input_ids"][0]]
|
| 56 |
+
|
| 57 |
+
plt.figure(figsize=(10, 2))
|
| 58 |
+
sns.heatmap(
|
| 59 |
+
attention_weights.squeeze(0).cpu().numpy().reshape(1, -1),
|
| 60 |
+
cmap="YlOrRd",
|
| 61 |
+
annot=True,
|
| 62 |
+
fmt=".2f",
|
| 63 |
+
cbar=False,
|
| 64 |
+
xticklabels=tokens,
|
| 65 |
+
yticklabels=["Attention"]
|
| 66 |
+
)
|
| 67 |
+
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
|
| 68 |
+
plt.title("Attention Weights Visualization")
|
| 69 |
+
plt.tight_layout()
|
| 70 |
+
plt.show()
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Demo App
|
| 74 |
+
|
| 75 |
+
This model includes a Streamlit demo app that can be launched directly on Hugging Face Spaces.
|
| 76 |
+
|
| 77 |
+
## Model Architecture
|
| 78 |
+
|
| 79 |
+
The model consists of:
|
| 80 |
+
|
| 81 |
+
1. **Embedding Layer**: Converts token IDs to dense vectors
|
| 82 |
+
2. **Bidirectional GRU**: Processes the text in both directions
|
| 83 |
+
3. **Attention Mechanism**: Focuses on the most relevant parts of the text
|
| 84 |
+
4. **Classifier Head**: Makes the final sentiment prediction
|
| 85 |
+
|
| 86 |
+
## Training
|
| 87 |
+
|
| 88 |
+
The model was trained on the SST-2 (Stanford Sentiment Treebank) dataset using the following hyperparameters:
|
| 89 |
+
|
| 90 |
+
- Learning rate: 1e-3
|
| 91 |
+
- Epochs: 12
|
| 92 |
+
- Optimizer: Adam
|
| 93 |
+
- Loss function: Cross Entropy Loss
|
| 94 |
+
- Embedding dimension: 100
|
| 95 |
+
- Hidden dimension: 256
|
| 96 |
+
|
| 97 |
+
## Limitations
|
| 98 |
+
|
| 99 |
+
- Only trained on movie reviews, may not generalize to other domains
|
| 100 |
+
- Limited to English text
|
| 101 |
+
- Binary classification only (positive/negative)
|
| 102 |
+
- Not suitable for multi-lingual content
|
| 103 |
+
- Performance may degrade on texts significantly different from movie reviews
|
| 104 |
+
|
| 105 |
+
## Citation
|
| 106 |
+
|
| 107 |
+
If you use this model, please cite:
|
| 108 |
+
|
| 109 |
+
```
|
| 110 |
+
@misc{attention-sentiment-classifier,
|
| 111 |
+
author = {Lantian Wei},
|
| 112 |
+
title = {Attention-based Sentiment Classifier},
|
| 113 |
+
year = {2025},
|
| 114 |
+
publisher = {Hugging Face},
|
| 115 |
+
howpublished = {\url{https://huggingface.co/your-username/attention-sentiment-classifier}}
|
| 116 |
+
}
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
## License
|
| 120 |
+
|
| 121 |
+
This model is licensed under the GNU General Public License v3.0.
|
app.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import seaborn as sns
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
from models.huggingface_model import SentimentClassifierForHuggingFace
|
| 7 |
+
import numpy as np
|
| 8 |
+
import io
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
# Load model and tokenizer
|
| 12 |
+
@st.cache_resource
|
| 13 |
+
def load_model():
|
| 14 |
+
model = SentimentClassifierForHuggingFace.from_pretrained("./")
|
| 15 |
+
tokenizer = AutoTokenizer.from_pretrained("./")
|
| 16 |
+
return model, tokenizer
|
| 17 |
+
|
| 18 |
+
def predict_sentiment(text, model, tokenizer):
|
| 19 |
+
# Tokenize the input
|
| 20 |
+
tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 21 |
+
input_ids = tokens["input_ids"]
|
| 22 |
+
|
| 23 |
+
# Run inference
|
| 24 |
+
model.eval()
|
| 25 |
+
with torch.no_grad():
|
| 26 |
+
outputs = model(input_ids, return_attention=True, return_dict=True)
|
| 27 |
+
|
| 28 |
+
# Get prediction results
|
| 29 |
+
logits = outputs["logits"]
|
| 30 |
+
attention_weights = outputs["attention_weights"]
|
| 31 |
+
|
| 32 |
+
# Convert to probabilities and get prediction
|
| 33 |
+
probs = torch.nn.functional.softmax(logits, dim=1)
|
| 34 |
+
prediction = torch.argmax(probs, dim=1).item()
|
| 35 |
+
confidence = probs[0][prediction].item()
|
| 36 |
+
sentiment = "Positive" if prediction == 1 else "Negative"
|
| 37 |
+
|
| 38 |
+
# Get token strings for visualization
|
| 39 |
+
tokens_list = []
|
| 40 |
+
for id in input_ids[0]:
|
| 41 |
+
token = tokenizer.convert_ids_to_tokens(id.item())
|
| 42 |
+
tokens_list.append(token)
|
| 43 |
+
|
| 44 |
+
# Create visualization
|
| 45 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
| 46 |
+
sns.heatmap(
|
| 47 |
+
attention_weights.squeeze(0).cpu().numpy().reshape(1, -1),
|
| 48 |
+
cmap="YlOrRd",
|
| 49 |
+
annot=True,
|
| 50 |
+
fmt=".2f",
|
| 51 |
+
cbar=False,
|
| 52 |
+
xticklabels=tokens_list,
|
| 53 |
+
yticklabels=["Attention"],
|
| 54 |
+
ax=ax
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Rotate x-axis labels for better readability
|
| 58 |
+
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
|
| 59 |
+
plt.title(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
|
| 60 |
+
plt.tight_layout()
|
| 61 |
+
|
| 62 |
+
# Convert plot to image
|
| 63 |
+
buf = io.BytesIO()
|
| 64 |
+
fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
|
| 65 |
+
buf.seek(0)
|
| 66 |
+
img = Image.open(buf)
|
| 67 |
+
plt.close(fig)
|
| 68 |
+
|
| 69 |
+
return sentiment, confidence, img
|
| 70 |
+
|
| 71 |
+
# Streamlit app
|
| 72 |
+
def main():
|
| 73 |
+
st.set_page_config(
|
| 74 |
+
page_title="Sentiment Analysis with Attention",
|
| 75 |
+
page_icon="🧠",
|
| 76 |
+
layout="wide"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
st.title("Sentiment Analysis with Attention Visualization")
|
| 80 |
+
st.write("This model classifies text sentiment as positive or negative and visualizes which parts of the text it focused on using an attention mechanism.")
|
| 81 |
+
|
| 82 |
+
# Load model and tokenizer
|
| 83 |
+
try:
|
| 84 |
+
model, tokenizer = load_model()
|
| 85 |
+
model_loaded = True
|
| 86 |
+
except Exception as e:
|
| 87 |
+
st.error(f"Error loading model: {e}")
|
| 88 |
+
model_loaded = False
|
| 89 |
+
|
| 90 |
+
# Text input
|
| 91 |
+
text_input = st.text_area(
|
| 92 |
+
"Enter text to analyze:",
|
| 93 |
+
value="I absolutely loved this movie! The acting was superb.",
|
| 94 |
+
height=100,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Process when button is clicked
|
| 98 |
+
if st.button("Analyze Sentiment") and model_loaded:
|
| 99 |
+
with st.spinner("Analyzing..."):
|
| 100 |
+
sentiment, confidence, viz_img = predict_sentiment(text_input, model, tokenizer)
|
| 101 |
+
|
| 102 |
+
# Display results
|
| 103 |
+
col1, col2 = st.columns([1, 3])
|
| 104 |
+
|
| 105 |
+
with col1:
|
| 106 |
+
st.subheader("Prediction:")
|
| 107 |
+
sentiment_color = "#5FD068" if sentiment == "Positive" else "#D21312"
|
| 108 |
+
st.markdown(
|
| 109 |
+
f"<div style='background-color:{sentiment_color}; padding:10px; border-radius:5px;"
|
| 110 |
+
f"color:white; text-align:center; font-size:24px;'>{sentiment}</div>",
|
| 111 |
+
unsafe_allow_html=True
|
| 112 |
+
)
|
| 113 |
+
st.metric("Confidence", f"{confidence:.2%}")
|
| 114 |
+
|
| 115 |
+
with col2:
|
| 116 |
+
st.subheader("Attention Visualization:")
|
| 117 |
+
st.image(viz_img, use_column_width=True)
|
| 118 |
+
st.caption("The heatmap shows which words the model focused on most when making its prediction.")
|
| 119 |
+
|
| 120 |
+
st.markdown("---")
|
| 121 |
+
st.subheader("How to interpret the visualization:")
|
| 122 |
+
st.write(
|
| 123 |
+
"The attention heatmap shows the weight assigned to each token in the text. "
|
| 124 |
+
"Darker colors indicate where the model focused more attention when making its prediction. "
|
| 125 |
+
"This can help identify which parts of the text were most influential for sentiment classification."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
if __name__ == "__main__":
|
| 129 |
+
main()
|
assets/attention_visualization.png
ADDED
|
Git LFS Details
|
assets/example_negative.png
ADDED
|
assets/exmaple_positive.png
ADDED
|
config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"SentimentClassifierForHuggingFace"
|
| 4 |
+
],
|
| 5 |
+
"bidirectional": true,
|
| 6 |
+
"dropout": 0.3,
|
| 7 |
+
"embedding_dim": 100,
|
| 8 |
+
"hidden_dim": 256,
|
| 9 |
+
"model_type": "sentiment_classifier",
|
| 10 |
+
"n_layers": 1,
|
| 11 |
+
"output_dim": 2,
|
| 12 |
+
"pad_idx": 0,
|
| 13 |
+
"torch_dtype": "float32",
|
| 14 |
+
"transformers_version": "4.51.2",
|
| 15 |
+
"vocab_size": 30522
|
| 16 |
+
}
|
inference_example.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from models.huggingface_model import SentimentClassifierForHuggingFace
|
| 5 |
+
|
| 6 |
+
# Load the model and tokenizer
|
| 7 |
+
model = SentimentClassifierForHuggingFace.from_pretrained("./")
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained("./")
|
| 9 |
+
|
| 10 |
+
# Prepare text input
|
| 11 |
+
text = "I absolutely loved this movie! The acting was superb."
|
| 12 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 13 |
+
|
| 14 |
+
# Run inference
|
| 15 |
+
model.eval()
|
| 16 |
+
with torch.no_grad():
|
| 17 |
+
outputs = model(inputs["input_ids"], return_attention=True, return_dict=True)
|
| 18 |
+
|
| 19 |
+
# Process results
|
| 20 |
+
logits = outputs["logits"]
|
| 21 |
+
attention_weights = outputs["attention_weights"]
|
| 22 |
+
|
| 23 |
+
# Get prediction and confidence
|
| 24 |
+
probs = torch.nn.functional.softmax(logits, dim=1)
|
| 25 |
+
prediction = torch.argmax(probs, dim=1).item()
|
| 26 |
+
confidence = probs[0][prediction].item()
|
| 27 |
+
sentiment = "Positive" if prediction == 1 else "Negative"
|
| 28 |
+
|
| 29 |
+
print(f"Text: {text}")
|
| 30 |
+
print(f"Sentiment: {sentiment}")
|
| 31 |
+
print(f"Confidence: {confidence:.4f}")
|
| 32 |
+
|
| 33 |
+
# To visualize attention weights, add matplotlib and seaborn imports
|
| 34 |
+
# and use attention_weights to create a heatmap
|
model-index.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "attention-sentiment-classifier",
|
| 3 |
+
"description": "Attention-based sentiment classification model that visualizes which parts of text influence predictions",
|
| 4 |
+
"tags": [
|
| 5 |
+
"pytorch",
|
| 6 |
+
"text-classification",
|
| 7 |
+
"sentiment-analysis",
|
| 8 |
+
"attention-mechanism",
|
| 9 |
+
"english",
|
| 10 |
+
"sst2"
|
| 11 |
+
],
|
| 12 |
+
"license": "gpl-3.0",
|
| 13 |
+
"library_name": "transformers",
|
| 14 |
+
"pipeline_tag": "text-classification"
|
| 15 |
+
}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:914de6219991e18c1a6be6138555ff8cfe4c5b3d238eca3cb8c920dfc261968d
|
| 3 |
+
size 17040676
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": false,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_lower_case": true,
|
| 47 |
+
"extra_special_tokens": {},
|
| 48 |
+
"mask_token": "[MASK]",
|
| 49 |
+
"model_max_length": 512,
|
| 50 |
+
"pad_token": "[PAD]",
|
| 51 |
+
"sep_token": "[SEP]",
|
| 52 |
+
"strip_accents": null,
|
| 53 |
+
"tokenize_chinese_chars": true,
|
| 54 |
+
"tokenizer_class": "BertTokenizer",
|
| 55 |
+
"unk_token": "[UNK]"
|
| 56 |
+
}
|
vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|