---
language: en
license: mit
pipeline_tag: text-classification
tags:
- text-classification
- transformers
- pytorch
- multi-label-classification
- multi-class-classification
- emotion
- bert
- go_emotions
- emotion-classification
datasets:
- google-research-datasets/go_emotions
metrics:
- f1
- precision
- recall
widget:
- text: I’m just chilling today.
  example_title: Neutral Example
- text: Thank you for saving my life!
  example_title: Gratitude Example
- text: I’m nervous about my exam tomorrow.
  example_title: Nervousness Example
base_model:
- google-bert/bert-base-uncased
---

# GoEmotions BERT Classifier

Fine-tuned [BERT-base-uncased](https://huggingface.co/bert-base-uncased) on [go_emotions](https://huggingface.co/datasets/go_emotions) for multi-label emotion classification across 28 categories.
|
|
## Model Details

- **Architecture**: BERT-base-uncased (110M parameters)
- **Training Data**: [GoEmotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) (58k Reddit comments labeled with 28 emotions)
- **Loss Function**: Focal loss (gamma = 2)
- **Optimizer**: AdamW (lr = 2e-5, weight_decay = 0.01)
- **Epochs**: 5
- **Hardware**: 2x NVIDIA T4 GPUs (Kaggle)
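
Focal loss down-weights examples the model already classifies well, so training concentrates on hard, infrequent labels — useful for GoEmotions' long-tailed classes. A minimal sketch of a binary focal loss for multi-label logits (our own helper, not the exact training code; `gamma=2` matches the setting above):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Binary focal loss for multi-label classification (sketch).

    With gamma=0 this reduces to plain BCE-with-logits.
    """
    # Elementwise BCE, one term per (sample, label) pair
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigned to the true class
    loss = ((1.0 - p_t) ** gamma) * bce  # down-weight easy (high p_t) terms
    return loss.mean()

logits = torch.tensor([[2.0, -1.0, 0.5]])
targets = torch.tensor([[1.0, 0.0, 1.0]])
print(focal_loss(logits, targets))
```

Because the modulating factor `(1 - p_t) ** gamma` is always below 1 for `gamma > 0`, the focal loss is strictly smaller than the corresponding BCE while preserving its gradient signal on hard examples.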

## Try It Out

For predictions using the optimized per-class thresholds, try the [Gradio demo](https://logasanjeev-goemotions-bert-demo.hf.space).

## Performance

Test-set results with optimized per-class thresholds:

- **Micro F1**: 0.6025
- **Macro F1**: 0.5266
- **Precision**: 0.5425
- **Recall**: 0.6775
- **Hamming Loss**: 0.0372
- **Avg. Positive Predictions per Sample**: 1.4564
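
These aggregates follow the standard multi-label definitions. A toy sketch with scikit-learn (4 labels instead of 28, made-up prediction matrices) shows how each number is computed:

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss

# Toy multi-label matrices: rows = samples, columns = emotion labels
y_true = np.array([[1, 0, 0, 1],
                   [0, 1, 0, 0],
                   [0, 0, 1, 0]])
y_pred = np.array([[1, 0, 1, 1],
                   [0, 1, 0, 0],
                   [0, 0, 0, 0]])

# Micro F1 pools TP/FP/FN over all (sample, label) cells;
# macro F1 averages the per-label F1 scores
print("Micro F1:", f1_score(y_true, y_pred, average="micro"))
print("Macro F1:", f1_score(y_true, y_pred, average="macro"))

# Hamming loss: fraction of label cells predicted incorrectly
print("Hamming loss:", hamming_loss(y_true, y_pred))

# Avg. positive predictions per sample
print("Avg positives:", y_pred.sum(axis=1).mean())
```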
### Class-Wise Performance

The table below reports per-class metrics on the test set using the optimized thresholds (see `thresholds.json`):

| | Emotion | F1 Score | Precision | Recall | Support | |
| |----------------|----------|-----------|--------|---------| |
| | admiration | 0.7022 | 0.6980 | 0.7063 | 504 | |
| | amusement | 0.8171 | 0.7692 | 0.8712 | 264 | |
| | anger | 0.5123 | 0.5000 | 0.5253 | 198 | |
| | annoyance | 0.3820 | 0.2908 | 0.5563 | 320 | |
| | approval | 0.4112 | 0.3485 | 0.5014 | 351 | |
| | caring | 0.4601 | 0.4045 | 0.5333 | 135 | |
| | confusion | 0.4488 | 0.4533 | 0.4444 | 153 | |
| | curiosity | 0.5721 | 0.4402 | 0.8169 | 284 | |
| | desire | 0.4068 | 0.6857 | 0.2892 | 83 | |
| | disappointment | 0.3476 | 0.3220 | 0.3775 | 151 | |
| | disapproval | 0.4126 | 0.3433 | 0.5169 | 267 | |
| | disgust | 0.4950 | 0.6329 | 0.4065 | 123 | |
| | embarrassment | 0.5000 | 0.7368 | 0.3784 | 37 | |
| | excitement | 0.4084 | 0.4432 | 0.3786 | 103 | |
| | fear | 0.6311 | 0.5078 | 0.8333 | 78 | |
| | gratitude | 0.9173 | 0.9744 | 0.8665 | 352 | |
| | grief | 0.2500 | 0.5000 | 0.1667 | 6 | |
| | joy | 0.6246 | 0.5798 | 0.6770 | 161 | |
| | love | 0.8110 | 0.7630 | 0.8655 | 238 | |
| | nervousness | 0.3830 | 0.3750 | 0.3913 | 23 | |
| | optimism | 0.5777 | 0.5856 | 0.5699 | 186 | |
| | pride | 0.4138 | 0.4615 | 0.3750 | 16 | |
| | realization | 0.2421 | 0.5111 | 0.1586 | 145 | |
| | relief | 0.5385 | 0.4667 | 0.6364 | 11 | |
| | remorse | 0.6797 | 0.5361 | 0.9286 | 56 | |
| | sadness | 0.5391 | 0.6900 | 0.4423 | 156 | |
| | surprise | 0.5724 | 0.5570 | 0.5887 | 141 | |
| | neutral | 0.6895 | 0.5826 | 0.8444 | 1787 | |
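
The optimized thresholds referenced above are typically derived by sweeping a grid of candidate thresholds per label on validation probabilities and keeping the value that maximizes that label's F1. A sketch on synthetic data (the helper `tune_thresholds` is ours, not part of this repo):

```python
import numpy as np
from sklearn.metrics import f1_score

def tune_thresholds(y_true, y_prob, grid=np.arange(0.1, 0.91, 0.05)):
    """For each label, pick the grid threshold maximizing that label's F1."""
    n_labels = y_true.shape[1]
    best = np.full(n_labels, 0.5)
    for j in range(n_labels):
        scores = [
            f1_score(y_true[:, j], (y_prob[:, j] >= t).astype(int), zero_division=0)
            for t in grid
        ]
        best[j] = grid[int(np.argmax(scores))]
    return best

# Synthetic validation set: positives score in [0.6, 1.0], negatives in [0, 0.4)
rng = np.random.default_rng(0)
y_true = (rng.random((200, 3)) < 0.3).astype(int)
y_prob = np.clip(y_true * 0.6 + rng.random((200, 3)) * 0.4, 0.0, 1.0)
print(tune_thresholds(y_true, y_prob))
```

Per-label tuning like this is what lets low-support classes (e.g. grief, relief) use a different operating point than dominant ones such as neutral.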

## Usage

Predictions use the optimized per-class thresholds stored in `thresholds.json`. Example in Python:

```python
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests

# Load model and tokenizer
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)

# Load optimized per-class thresholds and label names
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]

# Predict
text = "I’m just chilling today."
encodings = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
model.eval()
with torch.no_grad():
    probs = torch.sigmoid(model(**encodings).logits).squeeze(0).numpy()

# Keep only the emotions whose sigmoid probability clears that label's threshold
predictions = [
    (label, round(float(prob), 4))
    for label, prob, thresh in zip(emotion_labels, probs, thresholds)
    if prob >= thresh
]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
# Output: [('neutral', 0.8147)]
```