The model is trained with the following components and techniques:
- **Optimization**: Training uses the **AdamW** optimizer with a **CosineAnnealingLR** scheduler for learning-rate adjustment. A **gradient scaler** scales the loss to keep small gradients from underflowing when training with mixed precision.
- **Gradient Accumulation**: Because the model is computationally heavy, gradients are accumulated over several steps before each optimizer update, giving a larger effective batch size without a corresponding increase in memory usage.
- **Loss Functions**: The training process leverages a comprehensive set of custom loss functions:
**1. InfoNCE Loss (Info Noise Contrastive Estimation Loss):**

Definition: This loss function is used for contrastive learning, encouraging similar samples to be close in the embedding space while pushing dissimilar samples apart.

Formula:

```
L_InfoNCE = -log[ exp(sim(z_i, z_j) / τ) / Σ_k exp(sim(z_i, z_k) / τ) ]
```

where sim() is the cosine similarity, τ is the temperature parameter, z_i and z_j are paired samples, and the sum in the denominator is over all other samples in the batch.
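The formula can be sketched in plain NumPy (a minimal illustration; the function and variable names are not from the repository):

```python
import numpy as np

def info_nce_loss(z1, z2, tau=0.5):
    """InfoNCE over a batch: z1[i] and z2[i] form a positive pair."""
    # L2-normalize so the dot product equals cosine similarity
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    sim = (z1 @ z2.T) / tau                   # (N, N) temperature-scaled similarities
    sim = sim - sim.max(axis=1, keepdims=True)  # stabilize the softmax
    log_softmax = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    # positives sit on the diagonal; average the negative log-likelihood
    return -np.mean(np.diag(log_softmax))
```

When each `z_i` is most similar to its own `z_j`, the loss is low; mismatched pairs drive it up.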
**2. Covariance Regularization:**

Definition: This regularization term encourages the learned representations to have uncorrelated dimensions, promoting more diverse and informative embeddings.

Formula:

```
L_cov = λ * Σ_{i≠j} Cov(i,j)^2
```

where Cov is the covariance matrix of the embeddings, the sum runs over its off-diagonal entries, and λ is a regularization coefficient.
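A minimal NumPy sketch of this regularizer, assuming the penalty covers the squared off-diagonal covariance entries (names are illustrative):

```python
import numpy as np

def covariance_regularization(z, lam=1.0):
    """Penalize squared off-diagonal entries of the embedding covariance."""
    zc = z - z.mean(axis=0)                  # center each embedding dimension
    cov = (zc.T @ zc) / (z.shape[0] - 1)     # (d, d) sample covariance matrix
    off_diag = cov - np.diag(np.diag(cov))   # zero out the diagonal (variances)
    return lam * np.sum(off_diag ** 2)
```

Independent dimensions give a value near zero; correlated dimensions inflate it.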
**3. Dynamics Performance Loss:**

Definition: This loss measures the accuracy of predicted next states while also controlling the variance of the predictions.

Formula:

```
L_dynamics = MSE(true_next_state, predicted_next_state) + λ * Var(predicted_next_state)
```

where MSE is the mean squared error, Var is the variance of the predictions, and λ is a weighting factor.
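A NumPy sketch of this loss as written (interpreting Var as the per-dimension variance across the batch, which is an assumption; names are illustrative):

```python
import numpy as np

def dynamics_performance_loss(true_next, pred_next, lam=0.01):
    """MSE on next-state predictions plus a λ-weighted variance term."""
    mse = np.mean((true_next - pred_next) ** 2)
    var = np.mean(np.var(pred_next, axis=0))  # spread of predictions per dimension
    return mse + lam * var
```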
**4. Thought Consistency Loss:**

Definition: This loss encourages consistency between true next states and perturbed next states.

Formula:

```
L_consistency = MSE(true_next_state, perturbed_next_state)
```
**5. Policy Value Joint Loss:**

Definition: This loss combines policy and value losses for reinforcement learning tasks.

Formula:

```
L_joint = CrossEntropy(policy_logits, true_policy) + λ * MSE(value_pred, true_value)
```

where λ is a weighting factor balancing the policy and value losses.
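A NumPy sketch of the joint loss, assuming the cross-entropy is a softmax cross-entropy over action logits (names are illustrative, not from the repository):

```python
import numpy as np

def policy_value_loss(policy_logits, true_policy, value_pred, true_value, lam=0.5):
    """Cross-entropy on the policy head plus λ-weighted MSE on the value head."""
    # numerically stable log-softmax over actions
    logits = policy_logits - policy_logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    cross_entropy = -np.mean(np.sum(true_policy * log_probs, axis=1))
    value_mse = np.mean((value_pred - true_value) ** 2)
    return cross_entropy + lam * value_mse
```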
**6. Action Diversity Reward:**

Definition: This term encourages diversity in action embeddings by penalizing high pairwise similarity.

Formula:

```
R_diversity = λ * Σ_{i,j} cos_sim(a_i, a_j)^2
```

where cos_sim is the cosine similarity between action embeddings, and λ is a scaling factor; minimizing this term pushes the action embeddings apart.
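A NumPy sketch, treating the term as a penalty on pairwise similarity and excluding self-pairs (both are assumptions; names are illustrative):

```python
import numpy as np

def action_diversity_penalty(actions, lam=0.1):
    """Sum of squared pairwise cosine similarities between action embeddings."""
    a = actions / np.linalg.norm(actions, axis=1, keepdims=True)
    cos = a @ a.T                              # (N, N) cosine similarities
    mask = ~np.eye(len(a), dtype=bool)         # exclude each embedding's self-similarity
    return lam * np.sum(cos[mask] ** 2)
```

Orthogonal embeddings score zero; identical embeddings are maximally penalized.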
**7. Expected Thought Value Loss:**

Definition: This loss aims to maximize the expected value from Monte Carlo Tree Search; minimizing the negated mean of the best values achieves this.

Formula:

```
L_ETV = -mean(mcts_best_values)
```
**8. Exploration Regularization:**

Definition: This regularization encourages exploration by rewarding less-visited actions.

Formula:

```
R_exploration = λ * mean(Σ_a (1 / (visit_count(a) + 1)))
```

where λ is a scaling factor.
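A NumPy sketch, assuming visit counts come as a (batch, n_actions) array with the sum over actions and the mean over the batch (names are illustrative):

```python
import numpy as np

def exploration_bonus(visit_counts, lam=0.1):
    """Inverse-visit-count bonus: rarely tried actions contribute more."""
    counts = np.asarray(visit_counts, dtype=float)
    # +1 keeps the term finite for never-visited actions
    return lam * np.mean(np.sum(1.0 / (counts + 1.0), axis=-1))
```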
**9. KL Divergence Loss:**

Definition: This loss measures the difference between old and new policies in policy optimization.

Formula:

```
L_KL = KL(old_policy || new_policy) = Σ_{i=1..n} old_policy_i * log(old_policy_i / new_policy_i)
```

where KL is the Kullback-Leibler divergence, and:

- L_KL is the KL divergence loss
- old_policy and new_policy are probability distributions over actions
- i indexes each possible outcome or action
- n is the total number of possible outcomes or actions
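A NumPy sketch of the divergence, with a small epsilon added for numerical safety (an assumption; names are illustrative):

```python
import numpy as np

def kl_divergence_loss(old_policy, new_policy, eps=1e-12):
    """KL(old || new) summed over actions, averaged over the batch."""
    old = np.asarray(old_policy, dtype=float)
    new = np.asarray(new_policy, dtype=float)
    # eps guards against log(0) when a probability is exactly zero
    kl = np.sum(old * np.log((old + eps) / (new + eps)), axis=-1)
    return np.mean(kl)
```

The loss is zero when the policies match and grows as they diverge.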
### Evaluation

After each epoch, the model is evaluated on the validation set by computing the average loss over the dataset. The evaluation function uses the same loss functions as training but does not backpropagate, so it can run in inference mode.
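The gradient-accumulation scheme described at the top of this section can be sketched framework-agnostically in plain NumPy (the actual code would use PyTorch with AdamW, CosineAnnealingLR, and a gradient scaler; all names here are illustrative):

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of mean squared error for a linear model y_hat = X @ w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 4))
y = rng.normal(size=64)
w = np.zeros(4)

# Accumulate gradients over 4 micro-batches of 16 before one update,
# mimicking an effective batch of 64 at a quarter of the activation memory.
accum = np.zeros_like(w)
for i in range(4):
    Xb, yb = X[i * 16:(i + 1) * 16], y[i * 16:(i + 1) * 16]
    accum += mse_grad(w, Xb, yb) / 4   # scale each micro-batch by 1/accum_steps
w_accum = w - 0.1 * accum

# A single full-batch step yields (numerically) the same update
w_full = w - 0.1 * mse_grad(w, X, y)
```

Because the micro-batch gradients are averaged, the accumulated update matches the full-batch one, which is why the technique trades compute steps for memory.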
|