---
license: apache-2.0
datasets:
- dair-ai/emotion
language:
- en
metrics:
- accuracy
- f1
- precision
- recall
pipeline_tag: text-classification
---

# Emotion Classification with BERT + RL Fine-tuning

This model combines a BERT architecture with Reinforcement Learning (RL) for emotion classification. It was first fine-tuned on the `dair-ai/emotion` dataset (20k English sentences labeled with 6 emotions) and then further trained with PPO to optimize its prediction behavior.

## Training Approach

1. **Supervised Phase**:
   - Base BERT model fine-tuned with cross-entropy loss
   - Reached 0.9205 accuracy and 0.9227 F1 before the RL phase (the Pre-RL column in the comparison table)

2. **RL Phase**:
   - Implemented Actor-Critic architecture
   - Policy Gradient optimization with custom rewards
   - PPO clipping (ε=0.2) and entropy regularization
   - Custom reward function: `+1.0` for correct, `-0.1` for incorrect predictions
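
The RL phase can be illustrated with a minimal sketch of the reward and a PPO-style clipped loss for a single example. This is not the code from the repository: it is written in plain Python for readability, it assumes a one-step advantage (reward minus the critic's value estimate), and the entropy coefficient `ENTROPY_COEF` is a placeholder value that this card does not state. The helper names (`reward`, `ppo_loss`) are hypothetical.

```python
import math

CLIP_EPS = 0.2        # PPO clipping range listed above
ENTROPY_COEF = 0.01   # assumed value; the card does not give the coefficient


def reward(predicted: int, label: int) -> float:
    """+1.0 for a correct prediction, -0.1 for an incorrect one."""
    return 1.0 if predicted == label else -0.1


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy (in nats) of a probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def ppo_loss(new_logits, old_logits, action, label, value_estimate):
    """Clipped PPO policy loss with an entropy bonus for one example.

    `action` is the class sampled from the old policy; `value_estimate`
    is the critic's prediction of the expected reward (an assumption here).
    """
    new_probs = softmax(new_logits)
    old_probs = softmax(old_logits)
    advantage = reward(action, label) - value_estimate
    ratio = new_probs[action] / old_probs[action]
    clipped_ratio = min(max(ratio, 1.0 - CLIP_EPS), 1.0 + CLIP_EPS)
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # Minimizing this loss maximizes the surrogate and the policy entropy.
    return -surrogate - ENTROPY_COEF * entropy(new_probs)


uniform = [0.0] * 6  # six emotion classes, untrained policy
print(ppo_loss(uniform, uniform, action=2, label=2, value_estimate=0.5))
```

With identical old and new policies the probability ratio is 1, so the loss reduces to the negative advantage minus the entropy bonus. Once the new policy moves the ratio outside `[0.8, 1.2]` in the direction the advantage favors, the clipped term takes over and limits the size of the update.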

## Performance Comparison

| Metric    | Pre-RL | Post-RL | Δ (relative) |
|-----------|--------|---------|--------------|
| Accuracy  | 0.9205 | 0.931   | +1.14%       |
| F1-Score  | 0.9227 | 0.9298  | +0.77%       |
| Precision | 0.9325 | 0.9305  | -0.21%       |
| Recall    | 0.9205 | 0.931   | +1.14%       |

Key observation: RL fine-tuning gave modest relative gains in accuracy and recall (+1.14% each) and in F1 (+0.77%), while precision dipped slightly (-0.21%).
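
The last column of the table is consistent with a relative change, `(post - pre) / pre`, rather than a difference in percentage points. A short sketch that recomputes it from the table values (the function name `relative_change_pct` is only illustrative):

```python
# Scores copied from the comparison table: (pre-RL, post-RL)
scores = {
    "Accuracy": (0.9205, 0.931),
    "F1-Score": (0.9227, 0.9298),
    "Precision": (0.9325, 0.9305),
    "Recall": (0.9205, 0.931),
}


def relative_change_pct(before: float, after: float) -> float:
    """Relative change from `before` to `after`, in percent."""
    return (after - before) / before * 100.0


for metric, (pre, post) in scores.items():
    print(f"{metric:<10} {relative_change_pct(pre, post):+.2f}%")
```

For accuracy, for example, the absolute gain is 0.0105 (about 1.05 percentage points), which corresponds to the +1.14% relative change reported in the table.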

## Usage

```python
from transformers import pipeline

# Load the model from its Hugging Face Hub repository
classifier = pipeline(
    "text-classification",
    model="SimoGiuffrida/SentimentRL",
    tokenizer="bert-base-uncased",
)

results = classifier("I'm thrilled about this new opportunity!")
print(results)  # a list with one dict containing "label" and "score"
```
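
If the pipeline returns generic labels such as `LABEL_1` instead of emotion names, they can be mapped back to the classes of `dair-ai/emotion`. The sketch below assumes the model was trained with the dataset's default label ids (0 = sadness through 5 = surprise). Whether the published model config already includes readable label names is not stated in this card, so the helper passes through any label it does not recognize. `to_emotion` is a hypothetical helper, not part of `transformers`.

```python
# Default label order of the dair-ai/emotion dataset (ids 0-5)
EMOTIONS = ["sadness", "joy", "love", "anger", "fear", "surprise"]


def to_emotion(label: str) -> str:
    """Map a generic 'LABEL_<id>' string to an emotion name.

    Labels that are not in that form, or whose id is out of range,
    are returned unchanged.
    """
    prefix = "LABEL_"
    suffix = label[len(prefix):]
    if label.startswith(prefix) and suffix.isdigit():
        idx = int(suffix)
        if idx < len(EMOTIONS):
            return EMOTIONS[idx]
    return label


# Hand-written example in the pipeline's output shape (not real model output)
sample = [{"label": "LABEL_1", "score": 0.98}]
print([{**r, "label": to_emotion(r["label"])} for r in sample])
```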

## Key Features

- Hybrid training: Supervised + Reinforcement Learning
- Optimized for nuanced emotion detection
- Handles class imbalance (see confusion matrix in repo)

For full training details and analysis, visit the [GitHub repository](https://github.com/SimoGiuffrida/DLA2).