Update README.md
Browse files
README.md
CHANGED
|
@@ -10,4 +10,53 @@ metrics:
|
|
| 10 |
- precision
|
| 11 |
- recall
|
| 12 |
pipeline_tag: text-classification
|
| 13 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
- precision
|
| 11 |
- recall
|
| 12 |
pipeline_tag: text-classification
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Emotion Classification with BERT + RL Fine-tuning
|
| 18 |
+
|
| 19 |
+
This model combines BERT architecture with Reinforcement Learning (RL) for emotion classification. Initially fine-tuned on the `dair-ai/emotion` dataset (20k English sentences with 6 emotions), we then applied PPO reinforcement learning to optimize prediction behavior.
|
| 20 |
+
|
| 21 |
+
## 🔧 Training Approach
|
| 22 |
+
|
| 23 |
+
1. **Supervised Phase**:
|
| 24 |
+
- Base BERT model fine-tuned with cross-entropy loss
|
| 25 |
+
- Achieved strong baseline performance
|
| 26 |
+
|
| 27 |
+
2. **RL Phase**:
|
| 28 |
+
- Implemented Actor-Critic architecture
|
| 29 |
+
- Policy Gradient optimization with custom rewards
|
| 30 |
+
- PPO clipping (ε=0.2) and entropy regularization
|
| 31 |
+
- Custom reward function: `+1.0` for correct, `-0.1` for incorrect predictions
|
| 32 |
+
|
| 33 |
+
## 📊 Performance Comparison
|
| 34 |
+
|
| 35 |
+
| Metric | Pre-RL | Post-RL | Δ |
|
| 36 |
+
|------------|---------|---------|---------|
|
| 37 |
+
| Accuracy | 0.9205 | 0.931 | +1.14% |
|
| 38 |
+
| F1-Score | 0.9227 | 0.9298 | +0.77% |
|
| 39 |
+
| Precision | 0.9325 | 0.9305 | -0.21% |
|
| 40 |
+
| Recall | 0.9205 | 0.931 | +1.14% |
|
| 41 |
+
|
| 42 |
+
Key observation: RL fine-tuning provided modest but consistent improvements across most metrics, particularly in recall.
|
| 43 |
+
|
| 44 |
+
## 🚀 Usage
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
from transformers import pipeline
|
| 48 |
+
|
| 49 |
+
# Load from your repository
|
| 50 |
+
classifier = pipeline("text-classification",
|
| 51 |
+
model="SimoGiuffrida/SentimentRL",
|
| 52 |
+
tokenizer="bert-base-uncased")
|
| 53 |
+
|
| 54 |
+
results = classifier("I'm thrilled about this new opportunity!")
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## 💡 Key Features
|
| 58 |
+
- Hybrid training: Supervised + Reinforcement Learning
|
| 59 |
+
- Optimized for nuanced emotion detection
|
| 60 |
+
- Handles class imbalance (see confusion matrix in repo)
|
| 61 |
+
|
| 62 |
+
For full training details and analysis, visit the [GitHub repository](https://github.com/SimoGiuffrida/DLA2).
|