---
language: en
license: mit
pipeline_tag: text-classification
tags:
- text-classification
- transformers
- pytorch
- multi-label-classification
- multi-class-classification
- emotion
- bert
- go_emotions
- emotion-classification
datasets:
- google-research-datasets/go_emotions
metrics:
- f1
- precision
- recall
widget:
- text: I’m just chilling today.
  example_title: Neutral Example
- text: Thank you for saving my life!
  example_title: Gratitude Example
- text: I’m nervous about my exam tomorrow.
  example_title: Nervousness Example
base_model:
- google-bert/bert-base-uncased
---
# GoEmotions BERT Classifier

Fine-tuned [BERT-base-uncased](https://huggingface.co/bert-base-uncased) on [go_emotions](https://huggingface.co/datasets/go_emotions) for multi-label classification (28 emotions).

## Model Details

- **Architecture**: BERT-base-uncased (110M parameters)
- **Training Data**: [GoEmotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) (58k Reddit comments, 28 emotions)
- **Loss Function**: Focal Loss (gamma=2; see the sketch below)
- **Optimizer**: AdamW (lr=2e-5, weight_decay=0.01)
- **Epochs**: 5
- **Hardware**: Kaggle T4 x2 GPUs
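
Focal loss down-weights labels the model already classifies confidently, which helps with GoEmotions' heavily imbalanced classes. Below is a minimal sketch of a standard binary focal loss for multi-label logits; the exact training implementation (e.g., any per-class alpha weighting) may differ.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Binary focal loss for multi-label classification (illustrative sketch).

    logits:  (batch, num_labels) raw model outputs
    targets: (batch, num_labels) 0/1 float labels
    """
    # Per-element BCE, kept unreduced so each label can be re-weighted
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # exp(-bce) recovers p_t, the predicted probability of the true class
    p_t = torch.exp(-bce)
    # (1 - p_t)^gamma shrinks the loss on easy, well-classified labels
    return ((1.0 - p_t) ** gamma * bce).mean()
```
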
## Try It Out

For accurate predictions with optimized thresholds, use the [Gradio demo](https://logasanjeev-goemotions-bert-demo.hf.space).

## Performance

- **Micro F1**: 0.6025 (optimized thresholds)
- **Macro F1**: 0.5266
- **Precision**: 0.5425
- **Recall**: 0.6775
- **Hamming Loss**: 0.0372
- **Avg. Positive Predictions per Sample**: 1.4564
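
These aggregate numbers can be reproduced with scikit-learn once the per-class thresholds are applied. A sketch, assuming `probs` is an `(N, 28)` array of sigmoid outputs on the test set, `labels` the binary ground truth, and `thresholds` the 28 cutoffs from `thresholds.json`:

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss, precision_score, recall_score

# Binarize: an emotion counts as predicted when its probability
# clears that emotion's own optimized cutoff
preds = (probs >= np.asarray(thresholds)).astype(int)

print("Micro F1:     ", f1_score(labels, preds, average="micro"))
print("Macro F1:     ", f1_score(labels, preds, average="macro"))
print("Precision:    ", precision_score(labels, preds, average="micro"))
print("Recall:       ", recall_score(labels, preds, average="micro"))
print("Hamming Loss: ", hamming_loss(labels, preds))
print("Avg positives:", preds.sum(axis=1).mean())
```
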
### Class-Wise Performance

The following table shows per-class metrics on the test set using the optimized thresholds (see `thresholds.json`):

| Emotion        | F1 Score | Precision | Recall | Support |
|----------------|----------|-----------|--------|---------|
| admiration     | 0.7022   | 0.6980    | 0.7063 | 504     |
| amusement      | 0.8171   | 0.7692    | 0.8712 | 264     |
| anger          | 0.5123   | 0.5000    | 0.5253 | 198     |
| annoyance      | 0.3820   | 0.2908    | 0.5563 | 320     |
| approval       | 0.4112   | 0.3485    | 0.5014 | 351     |
| caring         | 0.4601   | 0.4045    | 0.5333 | 135     |
| confusion      | 0.4488   | 0.4533    | 0.4444 | 153     |
| curiosity      | 0.5721   | 0.4402    | 0.8169 | 284     |
| desire         | 0.4068   | 0.6857    | 0.2892 | 83      |
| disappointment | 0.3476   | 0.3220    | 0.3775 | 151     |
| disapproval    | 0.4126   | 0.3433    | 0.5169 | 267     |
| disgust        | 0.4950   | 0.6329    | 0.4065 | 123     |
| embarrassment  | 0.5000   | 0.7368    | 0.3784 | 37      |
| excitement     | 0.4084   | 0.4432    | 0.3786 | 103     |
| fear           | 0.6311   | 0.5078    | 0.8333 | 78      |
| gratitude      | 0.9173   | 0.9744    | 0.8665 | 352     |
| grief          | 0.2500   | 0.5000    | 0.1667 | 6       |
| joy            | 0.6246   | 0.5798    | 0.6770 | 161     |
| love           | 0.8110   | 0.7630    | 0.8655 | 238     |
| nervousness    | 0.3830   | 0.3750    | 0.3913 | 23      |
| optimism       | 0.5777   | 0.5856    | 0.5699 | 186     |
| pride          | 0.4138   | 0.4615    | 0.3750 | 16      |
| realization    | 0.2421   | 0.5111    | 0.1586 | 145     |
| relief         | 0.5385   | 0.4667    | 0.6364 | 11      |
| remorse        | 0.6797   | 0.5361    | 0.9286 | 56      |
| sadness        | 0.5391   | 0.6900    | 0.4423 | 156     |
| surprise       | 0.5724   | 0.5570    | 0.5887 | 141     |
| neutral        | 0.6895   | 0.5826    | 0.8444 | 1787    |
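
The card does not document how the thresholds were derived, but a common approach (and a reasonable assumption for how `thresholds.json` was produced) is to sweep candidate cutoffs per label on a validation set and keep the one that maximizes that label's F1:

```python
import numpy as np
from sklearn.metrics import f1_score

def tune_thresholds(probs, labels, candidates=np.arange(0.05, 0.96, 0.01)):
    """Per-label threshold search (illustrative; the exact procedure may differ).

    probs:  (N, num_labels) validation sigmoid outputs
    labels: (N, num_labels) binary ground truth
    """
    best = []
    for j in range(labels.shape[1]):
        # Score every candidate cutoff for this label, keep the F1-maximizing one
        scores = [f1_score(labels[:, j], (probs[:, j] >= t).astype(int),
                           zero_division=0)
                  for t in candidates]
        best.append(float(candidates[int(np.argmax(scores))]))
    return best
```
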
## Usage

The model uses optimized per-class thresholds stored in `thresholds.json` for predictions. Example in Python:

```python
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests

# Load model and tokenizer from the Hub
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)
model.eval()

# Load the optimized per-class thresholds shipped with the model
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]

# Predict
text = "I’m just chilling today."
encodings = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    # Sigmoid turns the raw logits into per-emotion probabilities
    probs = torch.sigmoid(model(**encodings).logits).numpy()[0]

# Keep every emotion whose probability clears its own threshold
predictions = [
    (emotion_labels[i], round(float(p), 4))
    for i, (p, thresh) in enumerate(zip(probs, thresholds))
    if p >= thresh
]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
# Output: [('neutral', 0.8147)]
```
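
Because each emotion has its own cutoff, the same probability can trigger one label but not another; the metrics reported above assume this per-class calibration, so prefer it over a flat 0.5 cutoff.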