Text Generation · Transformers · Safetensors · English · stablelm · conversational

euclaise committed 01e78cf · verified · 1 Parent(s): 5ac31cd

Update README.md

Files changed (1): README.md (+5 −4)
README.md CHANGED
@@ -68,11 +68,11 @@ Consider the following chat interaction:
 
 The model must predict the bolded parts. So, we randomly mask tokens from the bolded parts, and run the model once on the masked sequence and once on the full sequence.
 
-We then compute a distance loss `D(p_masked, p_full)` between the two predictions. For this, I used the average of the backwards and forwards KL divergences between the predictions.
+We then compute a divergence loss `D(p_masked, p_full)` between the two predictions. For this, I used the average of the backwards and forwards KL divergences between the predictions.
 
 Finally, we add this loss to the standard cross-entropy language modeling losses from each prediction, with a weighting value:
 ```
-loss = CE(p_masked, labels) + CE(p_full, labels) + weight*D(p_masked, p_full)
+loss = 0.5*(CE(p_masked, labels) + CE(p_full, labels)) + weight*D(p_masked, p_full)
 ```
 
 ***ReMask-CoT:***
@@ -103,7 +103,7 @@ Here are some benchmark results, computed using the the LM Evaluation Harness wi
 | Masked Thought | 24.18% | *43.60%* |
 | **ReMask** | **27.90%** | 43.26% |
 
-As I expected, it improves GSM8K doesn't do much to ARC.
+As I expected, it improves GSM8K, but doesn't do much to ARC.
 
 ## Training details
 - Framework: PyTorch Lightning
@@ -115,4 +115,5 @@ As I expected, it improves GSM8K doesn't do much to ARC.
 - Batch size: 16, accumulated to 256
 - Epochs: 6
 - Learning rate: 1e-5
-- Learning rate schedule: One Cycle, cosine, no cycle_momentum
+- Learning rate schedule: One Cycle, cosine, no cycle_momentum
+- Regularization weight: 0.1
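The updated loss described above can be sketched in PyTorch. This is a hedged reconstruction, not the author's training code: the names `p_masked`, `p_full`, `labels`, and `weight` follow the README's pseudocode, while the `symmetric_kl` helper and the tensor shapes are my assumptions.

```python
# Sketch of the combined ReMask loss (assumed shapes, not the author's code).
# p_masked, p_full: logits of shape (batch, seq, vocab) from the two forward
# passes (masked and full sequence); labels: token ids, -100 = ignored.
import torch
import torch.nn.functional as F

def symmetric_kl(logits_a, logits_b):
    """Average of forward and backward KL divergences between two predictions."""
    log_pa = F.log_softmax(logits_a, dim=-1).flatten(0, -2)  # (batch*seq, vocab)
    log_pb = F.log_softmax(logits_b, dim=-1).flatten(0, -2)
    # F.kl_div(input, target) computes KL(target || input); log_target=True
    # means the target is given as log-probabilities.
    kl_ab = F.kl_div(log_pb, log_pa, reduction="batchmean", log_target=True)
    kl_ba = F.kl_div(log_pa, log_pb, reduction="batchmean", log_target=True)
    return 0.5 * (kl_ab + kl_ba)

def remask_loss(p_masked, p_full, labels, weight=0.1):
    """loss = 0.5*(CE(p_masked) + CE(p_full)) + weight*D(p_masked, p_full)."""
    ce_masked = F.cross_entropy(p_masked.flatten(0, 1), labels.flatten(),
                                ignore_index=-100)
    ce_full = F.cross_entropy(p_full.flatten(0, 1), labels.flatten(),
                              ignore_index=-100)
    return 0.5 * (ce_masked + ce_full) + weight * symmetric_kl(p_masked, p_full)
```

The symmetrized KL term pushes the masked and full predictions toward each other; `weight=0.1` matches the regularization weight added in this commit.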