Added Class Wise Performance
Browse files
README.md
CHANGED
|
@@ -54,12 +54,39 @@ For accurate predictions with optimized thresholds, use the [Gradio demo](https:
|
|
| 54 |
- **Hamming Loss**: 0.0372
|
| 55 |
- **Avg Positive Predictions**: 1.4564
|
| 56 |
|
| 57 |
-
|
| 58 |
-
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
## Usage
|
| 65 |
|
|
@@ -89,5 +116,4 @@ with torch.no_grad():
|
|
| 89 |
logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
|
| 90 |
predictions = [(emotion_labels[i], logit) for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
|
| 91 |
print(sorted(predictions, key=lambda x: x[1], reverse=True))
|
| 92 |
-
# Output: [('neutral', 0.8147)]
|
| 93 |
-
```
|
|
|
|
| 54 |
- **Hamming Loss**: 0.0372
|
| 55 |
- **Avg Positive Predictions**: 1.4564
|
| 56 |
|
| 57 |
+
### Class-Wise Performance
|
| 58 |
+
The following table shows per-class metrics on the test set using optimized thresholds (see `thresholds.json`):
|
| 59 |
+
|
| 60 |
+
| Emotion | F1 Score | Precision | Recall | Support |
|
| 61 |
+
|----------------|----------|-----------|--------|---------|
|
| 62 |
+
| admiration | 0.7022 | 0.6980 | 0.7063 | 504 |
|
| 63 |
+
| amusement | 0.8171 | 0.7692 | 0.8712 | 264 |
|
| 64 |
+
| anger | 0.5123 | 0.5000 | 0.5253 | 198 |
|
| 65 |
+
| annoyance | 0.3820 | 0.2908 | 0.5563 | 320 |
|
| 66 |
+
| approval | 0.4112 | 0.3485 | 0.5014 | 351 |
|
| 67 |
+
| caring | 0.4601 | 0.4045 | 0.5333 | 135 |
|
| 68 |
+
| confusion | 0.4488 | 0.4533 | 0.4444 | 153 |
|
| 69 |
+
| curiosity | 0.5721 | 0.4402 | 0.8169 | 284 |
|
| 70 |
+
| desire | 0.4068 | 0.6857 | 0.2892 | 83 |
|
| 71 |
+
| disappointment | 0.3476 | 0.3220 | 0.3775 | 151 |
|
| 72 |
+
| disapproval | 0.4126 | 0.3433 | 0.5169 | 267 |
|
| 73 |
+
| disgust | 0.4950 | 0.6329 | 0.4065 | 123 |
|
| 74 |
+
| embarrassment | 0.5000 | 0.7368 | 0.3784 | 37 |
|
| 75 |
+
| excitement | 0.4084 | 0.4432 | 0.3786 | 103 |
|
| 76 |
+
| fear | 0.6311 | 0.5078 | 0.8333 | 78 |
|
| 77 |
+
| gratitude | 0.9173 | 0.9744 | 0.8665 | 352 |
|
| 78 |
+
| grief | 0.2500 | 0.5000 | 0.1667 | 6 |
|
| 79 |
+
| joy | 0.6246 | 0.5798 | 0.6770 | 161 |
|
| 80 |
+
| love | 0.8110 | 0.7630 | 0.8655 | 238 |
|
| 81 |
+
| nervousness | 0.3830 | 0.3750 | 0.3913 | 23 |
|
| 82 |
+
| optimism | 0.5777 | 0.5856 | 0.5699 | 186 |
|
| 83 |
+
| pride | 0.4138 | 0.4615 | 0.3750 | 16 |
|
| 84 |
+
| realization | 0.2421 | 0.5111 | 0.1586 | 145 |
|
| 85 |
+
| relief | 0.5385 | 0.4667 | 0.6364 | 11 |
|
| 86 |
+
| remorse | 0.6797 | 0.5361 | 0.9286 | 56 |
|
| 87 |
+
| sadness | 0.5391 | 0.6900 | 0.4423 | 156 |
|
| 88 |
+
| surprise | 0.5724 | 0.5570 | 0.5887 | 141 |
|
| 89 |
+
| neutral | 0.6895 | 0.5826 | 0.8444 | 1787 |
|
| 90 |
|
| 91 |
## Usage
|
| 92 |
|
|
|
|
| 116 |
logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
|
| 117 |
predictions = [(emotion_labels[i], logit) for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
|
| 118 |
print(sorted(predictions, key=lambda x: x[1], reverse=True))
|
| 119 |
+
# Output: [('neutral', 0.8147)]
|
|
|