Update README.md
README.md CHANGED

@@ -68,11 +68,11 @@ Consider the following chat interaction:

The model must predict the bolded parts. So, we randomly mask tokens from the bolded parts, and run the model once on the masked sequence and once on the full sequence.
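
The diff does not show the masking step itself, so here is a minimal sketch of one way it could work; the names (`input_ids`, `target_mask`, `mask_token_id`, `p_mask`) are hypothetical, not taken from the repo:

```python
import torch

def mask_targets(input_ids, target_mask, mask_token_id, p_mask=0.3):
    """Hypothetical helper: randomly mask tokens inside the target spans.

    input_ids: (batch, seq) token ids.
    target_mask: (batch, seq) bool, True on the spans the model must predict.
    mask_token_id and p_mask are illustrative assumptions, not repo values.
    """
    drop = torch.rand(input_ids.shape, device=input_ids.device) < p_mask
    fill = torch.full_like(input_ids, mask_token_id)
    return torch.where(target_mask & drop, fill, input_ids)
```

Running the model on `mask_targets(...)` and again on the untouched `input_ids` gives the two predictions, `p_masked` and `p_full`, compared next.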

- We then compute a
+ We then compute a divergence loss `D(p_masked, p_full)` between the two predictions. For this, I used the average of the backwards and forwards KL divergences between the predictions.

Finally, we add this loss to the standard cross-entropy language modeling losses from each prediction, with a weighting value:

```
- loss = CE(p_masked, labels) + CE(p_full, labels) + weight*D(p_masked, p_full)
+ loss = 0.5*(CE(p_masked, labels) + CE(p_full, labels)) + weight*D(p_masked, p_full)
```
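
To make the updated objective concrete, here is a minimal PyTorch sketch of the new loss line, using the symmetric KL divergence described above and the 0.1 regularization weight listed under the training details. The tensor names and shapes are illustrative assumptions, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

def remask_loss(logits_masked, logits_full, labels, weight=0.1):
    """Hypothetical sketch: averaged CE plus a symmetric KL consistency term.

    logits_masked / logits_full: (batch, seq, vocab) from the two passes.
    labels: (batch, seq) with -100 on positions that are not predicted.
    weight: the regularization weight (0.1 in the training details below).
    """
    ce_masked = F.cross_entropy(
        logits_masked.flatten(0, 1), labels.flatten(), ignore_index=-100
    )
    ce_full = F.cross_entropy(
        logits_full.flatten(0, 1), labels.flatten(), ignore_index=-100
    )

    log_p_masked = F.log_softmax(logits_masked, dim=-1)
    log_p_full = F.log_softmax(logits_full, dim=-1)
    # D(p_masked, p_full): average of the two KL directions.
    kl_a = F.kl_div(log_p_masked, log_p_full, log_target=True, reduction="batchmean")
    kl_b = F.kl_div(log_p_full, log_p_masked, log_target=True, reduction="batchmean")
    divergence = 0.5 * (kl_a + kl_b)

    return 0.5 * (ce_masked + ce_full) + weight * divergence
```

Averaging the two cross-entropy terms keeps the language-modeling loss on the same scale as a single forward pass, so `weight` only scales the consistency term.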

***ReMask-CoT:***

@@ -103,7 +103,7 @@ Here are some benchmark results, computed using the LM Evaluation Harness wi

| Masked Thought | 24.18% | *43.60%* |
| **ReMask** | **27.90%** | 43.26% |

- As I expected, it improves GSM8K doesn't do much to ARC.
+ As I expected, it improves GSM8K, but doesn't do much to ARC.

## Training details
- Framework: PyTorch Lightning

@@ -115,4 +115,5 @@ As I expected, it improves GSM8K, but doesn't do much to ARC.

- Batch size: 16, accumulated to 256
- Epochs: 6
- Learning rate: 1e-5
- Learning rate schedule: One Cycle, cosine, no cycle_momentum
+ - Regularization weight: 0.1
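
For the schedule listed above (One Cycle, cosine annealing, momentum cycling off), a plain-PyTorch sketch might look like the following; the AdamW optimizer, placeholder model, and `total_steps` value are assumptions, and in PyTorch Lightning this would live in `configure_optimizers` with a per-step interval:

```python
import torch

# Placeholder model and optimizer; AdamW and the step count are assumptions.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# One Cycle schedule with cosine annealing and momentum cycling disabled,
# matching "One Cycle, cosine, no cycle_momentum" with max_lr = 1e-5.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-5,
    total_steps=1000,        # assumed: (steps per epoch) * (epochs)
    anneal_strategy="cos",
    cycle_momentum=False,
)

for _ in range(1000):
    optimizer.step()         # gradient computation omitted in this sketch
    scheduler.step()         # OneCycleLR advances once per optimizer step
```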