Update README.md
README.md
CHANGED
|
@@ -206,10 +206,8 @@ This model is a **World Model** that combines **Transformers**, **Mixture of Exp
|
|
| 206 |
1. **Transformer**: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
|
| 207 |
2. **MCTS**: The Monte Carlo Tree Search module iteratively simulates actions to select the best possible path based on exploration and exploitation.
|
| 208 |
3. **PPO Agent**: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
|
| 209 |
-
4. **Custom
|
| 210 |
|
| 211 |
-
### Intended Use
|
| 212 |
-
This model is suitable for tasks that require complex decision-making and optimization based on action-state transitions. It can be applied in fields like game development, reinforcement learning environments, and AI simulation tasks where sequential decision-making and policy optimization are essential.
|
| 213 |
|
| 214 |
## Model Architecture
|
| 215 |
|
|
@@ -245,36 +243,7 @@ This process allows the model to generate more fluent and accurate sequences by
|
|
| 245 |
|
| 246 |
---
|
| 247 |
|
| 248 |
-
##
|
| 249 |
-
|
| 250 |
-
The Transformer architecture, introduced in the paper "Attention is All You Need," is a powerful neural network design for handling sequential data, especially in natural language processing tasks. Transformers are known for their parallelism and ability to capture long-range dependencies in data.
|
| 251 |
-
|
| 252 |
-
#### Key Components of the Transformer
|
| 253 |
-
|
| 254 |
-
1. **Embeddings and Positional Encoding**:
|
| 255 |
-
- The input tokens are embedded into dense vectors. Since Transformers do not inherently encode the sequence order (as opposed to RNNs), they require **positional encodings**. These encodings are added to the embeddings to provide information about the token positions in the sequence.
|
| 256 |
-
|
| 257 |
-
2. **Multi-Head Self-Attention**:
|
| 258 |
-
- Each token in a sequence attends to every other token, capturing dependencies regardless of distance. Multiple attention heads allow the model to focus on different parts of the sequence, extracting varied features.
|
| 259 |
-
- In self-attention, the model computes **query**, **key**, and **value** vectors for each token. The output is a weighted sum of values, where the weights are determined by the similarity between the query and key vectors.
|
| 260 |
-
|
| 261 |
-
3. **Feedforward Neural Networks**:
|
| 262 |
-
- After self-attention, a position-wise feedforward neural network is applied to each token independently. This network consists of two linear layers with a ReLU or GELU activation function in between.
|
| 263 |
-
|
| 264 |
-
4. **Layer Normalization and Residual Connections**:
|
| 265 |
-
- To improve learning stability, **layer normalization** is applied. Residual connections help the model to learn effectively by adding the input of a layer to its output, allowing gradients to flow more easily during backpropagation.
|
| 266 |
-
|
| 267 |
-
5. **Stacking of Layers**:
|
| 268 |
-
- The Transformer consists of **multiple encoder and decoder layers**. Each encoder layer is identical and consists of self-attention and feedforward layers. The decoder layers include an additional cross-attention mechanism to attend to the encoder's output.
|
| 269 |
-
|
| 270 |
-
6. **Final Linear and Softmax Layer**:
|
| 271 |
-
- The final output of the decoder layer is passed through a linear layer, projecting it onto the vocabulary size. A **softmax** function then converts the output into a probability distribution over the vocabulary, from which the next token is selected or sampled.
|
| 272 |
-
|
| 273 |
-
#### Encoder-Decoder Structure
|
| 274 |
-
|
| 275 |
-
- **Encoder**: The encoder processes the input sequence into a contextualized representation that captures relationships between tokens. It consists of multiple layers of self-attention and feedforward networks.
|
| 276 |
-
- **Decoder**: The decoder generates the output sequence by attending to both the encoded input representation (using cross-attention) and previously generated tokens (using self-attention). The decoder's output is used to predict the next token in the sequence.
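The query/key/value weighting described under Multi-Head Self-Attention above can be sketched for a single head in plain Python. This is an illustrative sketch, not the model's actual implementation: the function names `softmax` and `attention` are hypothetical, and learned projections, multiple heads, masking, and rotary encoding are all omitted.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Single-head scaled dot-product attention over lists of vectors.

    Hypothetical helper for illustration; inputs are lists of equal-length
    float lists (one vector per token).
    """
    d_k = len(keys[0])
    outputs = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
            for k in keys
        ]
        weights = softmax(scores)
        # The output is the weighted sum of the value vectors.
        out = [
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ]
        outputs.append(out)
    return outputs
```

When every key is identical, the weights are uniform and each output is simply the mean of the value vectors, which is a quick sanity check for the weighting step.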
|
| 277 |
-
|
| 278 |
|
| 279 |
2. **Representation Network**: This module encodes the Transformer output to generate a state representation, reducing dimensionality and making it suitable for further processing.
|
| 280 |
3. **Dynamics Network**: This module predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
|
|
@@ -299,8 +268,6 @@ thought_1 = {P1, ... , PN}
|
|
| 299 |
The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each step of granularity.
|
| 300 |
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
## Training Details
|
| 305 |
|
| 306 |
The model is trained with the following components and techniques:
|
|
@@ -325,77 +292,6 @@ After each epoch, the model is evaluated on the validation set, computing the av
|
|
| 325 |
### Checkpoints
|
| 326 |
At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
|
| 327 |
|
| 328 |
-
## Usage
|
| 329 |
-
|
| 330 |
-
To use this model, ensure you have the necessary libraries installed, including `torch`, `transformers`, `datasets`, and `argparse`. The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here’s an example of how to start training:
|
| 331 |
-
|
| 332 |
-
# To Train Language Model
|
| 333 |
-
```bash
|
| 334 |
-
|
| 335 |
-
python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"
|
| 336 |
-
```
|
| 337 |
-
|
| 338 |
-
# To Train World Model
|
| 339 |
-
```bash
|
| 340 |
-
|
| 341 |
-
python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models' --transformer_model_path 'path/to/transformer/model'
|
| 342 |
-
```
|
| 343 |
-
|
| 344 |
-
# Language Model Args:
|
| 345 |
-
|
| 346 |
-
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
|
| 347 |
-
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
|
| 348 |
-
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
|
| 349 |
-
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
|
| 350 |
-
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
|
| 351 |
-
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
|
| 352 |
-
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
|
| 353 |
-
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
|
| 354 |
-
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
|
| 355 |
-
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
|
| 356 |
-
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
|
| 357 |
-
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
|
| 358 |
-
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
|
| 359 |
-
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
|
| 360 |
-
|
| 361 |
-
# World Model Args:
|
| 362 |
-
|
| 363 |
-
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
|
| 364 |
-
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
|
| 365 |
-
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
|
| 366 |
-
parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
|
| 367 |
-
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
|
| 368 |
-
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
|
| 369 |
-
parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS Iterations')
|
| 370 |
-
parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='MCTS exploration constant')
|
| 371 |
-
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
|
| 372 |
-
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
|
| 373 |
-
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
|
| 374 |
-
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
|
| 375 |
-
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
|
| 376 |
-
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
|
| 377 |
-
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
|
| 378 |
-
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
|
| 379 |
-
parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
This script will train the model on the specified dataset for the defined number of epochs, using a batch size of 2, and loading a pretrained Transformer model from the specified path.
|
| 384 |
-
|
| 385 |
-
### Model Hyperparameters
|
| 386 |
-
Here are the main parameters you can set:
|
| 387 |
-
- `--model_name`: Name of the pretrained model for tokenization.
|
| 388 |
-
- `--dataset_name`: Hugging Face dataset name.
|
| 389 |
-
- `--batch_size`: Batch size for training.
|
| 390 |
-
- `--num_epochs`: Number of epochs to train.
|
| 391 |
-
- `--max_length`: Max sequence length.
|
| 392 |
-
- `--transformer_model_path`: Path to the pretrained Transformer model.
|
| 393 |
-
- `--learning_rate`: Learning rate for optimizer.
|
| 394 |
-
- `--save_dir`: Directory to save model checkpoints.
|
| 395 |
-
- `--temperature`, `--alpha`, `--beta`, `--lambda_reg`: Hyperparameters for regularization.
|
| 396 |
-
|
| 397 |
-
### Expected Results
|
| 398 |
-
As training proceeds, you should see progressively lower training and evaluation losses. Upon completion, the model can perform complex decision-making tasks by generating sequences of actions with MCTS and PPO optimization.
|
| 399 |
|
| 400 |
## Requirements
|
| 401 |
|
|
@@ -406,14 +302,7 @@ This code requires:
|
|
| 406 |
- `datasets`
|
| 407 |
- `argparse`
|
| 408 |
|
| 409 |
-
## Limitations
|
| 410 |
-
|
| 411 |
-
Due to the heavy computational nature of this model, training time may be significant, especially on a CPU. GPU support is recommended for efficient training. Additionally, the MCTS and PPO implementations here are designed for demonstration purposes and may need further tuning for specific use cases.
|
| 412 |
|
| 413 |
## Citation
|
| 414 |
|
| 415 |
If you use this model in your research, please cite the author.
|
| 416 |
-
|
| 417 |
-
---
|
| 418 |
-
|
| 419 |
-
This model card should provide an overview for anyone looking to understand, utilize, or modify your World Model with MCTS and Transformer components.
|
|
|
|
| 206 |
1. **Transformer**: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
|
| 207 |
2. **MCTS**: The Monte Carlo Tree Search module iteratively simulates actions to select the best possible path based on exploration and exploitation.
|
| 208 |
3. **PPO Agent**: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
|
| 209 |
+
4. **Custom Objective Functions**: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
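The exploration/exploitation trade-off mentioned for the MCTS component is commonly scored with the UCT rule. The sketch below assumes UCT is the selection rule, which the README does not confirm; the names `uct_score` and `select_child` are hypothetical, and the default constant of 1.414 is borrowed from the `--mcts_exploration_constant` argument.

```python
import math


def uct_score(total_value, visits, parent_visits, c=1.414):
    """UCT score: mean value (exploitation) plus a visit-count bonus (exploration)."""
    if visits == 0:
        # Unvisited children are tried first.
        return math.inf
    exploit = total_value / visits
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore


def select_child(children, parent_visits, c=1.414):
    """Pick the child with the highest UCT score.

    `children` is a list of (name, total_value, visits) tuples; this layout
    is an assumption made for the example only.
    """
    best = max(children, key=lambda ch: uct_score(ch[1], ch[2], parent_visits, c))
    return best[0]
```

A larger `c` favors rarely visited children, while `c = 0` reduces selection to picking the highest mean value.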
|
| 210 |
|
| 211 |
|
| 212 |
## Model Architecture
|
| 213 |
|
|
|
|
| 243 |
|
| 244 |
---
|
| 245 |
|
| 246 |
+
## World Model
|
| 247 |
|
| 248 |
2. **Representation Network**: This module encodes the Transformer output to generate a state representation, reducing dimensionality and making it suitable for further processing.
|
| 249 |
3. **Dynamics Network**: This module predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
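A minimal sketch of a dynamics step, assuming the state and action vectors are concatenated, passed through one linear layer, normalized, and then activated with GELU. The README only states that layer normalization and GELU are used; the layer order, the concatenation, and the names `layer_norm`, `gelu`, and `dynamics_step` are assumptions for illustration.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def gelu(x):
    """Exact GELU: x times the standard normal CDF evaluated at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def dynamics_step(state, action, weights, bias):
    """Predict the next state from the current state and an action.

    `weights` has one row per output dimension; each row spans the
    concatenated state and action (a hypothetical layout).
    """
    inputs = state + action  # list concatenation of state and action
    pre = [
        sum(w * v for w, v in zip(row, inputs)) + b
        for row, b in zip(weights, bias)
    ]
    return [gelu(v) for v in layer_norm(pre)]
```

For example, `dynamics_step([1.0, 2.0], [0.5], [[0.1, 0.2, 0.3], [0.4, -0.5, 0.6]], [0.0, 0.1])` returns a two-dimensional next state.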
|
|
|
|
| 268 |
The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each step of granularity.
|
| 269 |
|
| 270 |
|
| 271 |
## Training Details
|
| 272 |
|
| 273 |
The model is trained with the following components and techniques:
|
|
|
|
| 292 |
### Checkpoints
|
| 293 |
At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
|
| 294 |
|
| 295 |
|
| 296 |
## Requirements
|
| 297 |
|
|
|
|
| 302 |
- `datasets`
|
| 303 |
- `argparse`
|
| 304 |
|
| 305 |
|
| 306 |
## Citation
|
| 307 |
|
| 308 |
If you use this model in your research, please cite the author.