Initial World Models checkpoint upload
- README.md +69 -0
- config.json +15 -0
- pytorch_model.bin +3 -0
README.md
ADDED
---
tags:
- reinforcement-learning
- world-models
- atari
- vae
- rnn
- cma-es
- pytorch
license: mit
---

# World Models - Atari Agent

This is an implementation of World Models (Ha & Schmidhuber, 2018), trained on the Atari Breakout environment.

## Model Description

The World Models architecture consists of three main components:

1. **VAE (Variational Autoencoder)**: Compresses 64x64 RGB frames into a 64-dimensional latent space
2. **MDN-RNN (Mixture Density Network RNN, called MDRNN in the code)**: Predicts a distribution over the next latent state given the current latent state and action
3. **Controller**: A linear policy optimized with CMA-ES to maximize cumulative reward

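The actual component definitions live in auto_train.py, which is not part of this upload. Purely to illustrate the shapes implied by the configuration (64-d latent, 256 hidden units, 4 actions, 5 mixture components), a minimal PyTorch sketch of the three modules might look like the following; the class names and layer choices are assumptions, not the repository's architecture:

```python
import torch
import torch.nn as nn

LATENT, HIDDEN, ACTIONS, GAUSSIANS = 64, 256, 4, 5

class Encoder(nn.Module):
    """Conv encoder: 64x64x3 frame -> mean and log-variance of a 64-d latent."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),    # -> 32 x 31 x 31
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),   # -> 64 x 14 x 14
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),  # -> 128 x 6 x 6
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(), # -> 256 x 2 x 2
        )
        self.mu = nn.Linear(256 * 2 * 2, LATENT)
        self.logvar = nn.Linear(256 * 2 * 2, LATENT)

    def forward(self, x):
        h = self.conv(x).flatten(1)
        return self.mu(h), self.logvar(h)

class MDRNN(nn.Module):
    """LSTM that outputs a 5-component Gaussian mixture over the next latent."""
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(LATENT + ACTIONS, HIDDEN, batch_first=True)
        # Per mixture component: a mean and sigma per latent dim, plus one weight
        self.head = nn.Linear(HIDDEN, GAUSSIANS * (2 * LATENT + 1))

    def forward(self, z_and_action, hidden=None):
        out, hidden = self.lstm(z_and_action, hidden)
        return self.head(out), hidden

class Controller(nn.Module):
    """Linear policy over [z, h]; its flat parameter vector is what CMA-ES tunes."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(LATENT + HIDDEN, ACTIONS)

    def forward(self, z, h):
        return self.fc(torch.cat([z, h], dim=-1))
```

The controller is deliberately tiny ((64 + 256) * 4 + 4 parameters in this sketch), which is what makes black-box optimization with CMA-ES practical.
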
## Model Details

- **Latent Size**: 64
- **Hidden Size**: 256 (MDRNN)
- **Action Space**: 4 (Atari discrete actions)
- **Architecture**: Convolutional encoder/decoder for the VAE, LSTM-based MDRNN
- **Optimization**: CMA-ES for controller training

## Usage

```python
import torch

# Load the checkpoint on CPU; it is a plain dict of component state dicts
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Access components
vae_state = checkpoint['vae_state_dict']
rnn_state = checkpoint['rnn_state_dict']
controller_state = checkpoint['controller_state_dict']

# Reconstruct the models (see auto_train.py for the architecture definitions)
# and load the state dicts into them
```

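Because the weights are stored through Git LFS, a truncated or partially fetched pytorch_model.bin is a common failure mode; hashing the downloaded file and comparing against the `oid` recorded in the LFS pointer catches this before `torch.load` fails confusingly. A small helper (the function name is ours, not part of the repo):

```python
import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 in 1 MiB chunks, so a large
    checkpoint is hashed without loading it fully into memory."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# The digest should equal the sha256 oid in the Git LFS pointer
# for pytorch_model.bin, and the file size should be 32753806 bytes.
```
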
## Training Details

- **Environment**: Atari Breakout
- **Harvest Episodes**: 100
- **VAE Epochs**: 20
- **RNN Epochs**: 20
- **CMA-ES Generations**: 25
- **Population Size**: 32
- **Framework**: PyTorch
- **Training Script**: auto_train.py

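Each CMA-ES generation follows an ask-evaluate-tell cycle: sample a population of 32 flat parameter vectors for the linear controller, score each by its average episode reward, then update the search distribution from the scores. The sketch below illustrates that cycle with a simplified isotropic evolution strategy standing in for full CMA-ES, and a toy quadratic fitness standing in for environment rollouts; both substitutions are ours, and the real run uses CMA-ES proper as listed above:

```python
import numpy as np

def es_optimize(fitness, dim, generations=25, popsize=32,
                sigma=0.5, elite_frac=0.25, seed=0):
    """Simplified (mu, lambda) evolution strategy: sample a population around
    the current mean, keep the best quarter, recentre the mean on their
    average, and anneal the step size. A stand-in for CMA-ES."""
    rng = np.random.default_rng(seed)
    mean = np.zeros(dim)
    n_elite = max(1, int(popsize * elite_frac))
    for _ in range(generations):
        pop = mean + sigma * rng.standard_normal((popsize, dim))  # "ask"
        scores = np.array([fitness(p) for p in pop])              # rollouts in the real run
        elite = pop[np.argsort(scores)[-n_elite:]]                # keep highest rewards
        mean = elite.mean(axis=0)                                 # "tell"
        sigma *= 0.95                                             # shrink the search radius
    return mean

# Toy stand-in for "average episode reward of a controller with params p"
target = np.full(8, 0.7)
best = es_optimize(lambda p: -np.sum((p - target) ** 2), dim=8)
```

With 25 generations and a population of 32 (the values in the table above), `best` lands close to `target`; in the real pipeline `fitness` would roll out the controller in Breakout and return the episode reward.
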
## References

- Ha, D., & Schmidhuber, J. (2018). World Models. arXiv preprint arXiv:1803.10122.
- Interactive version of the paper: https://worldmodels.github.io/

## License

MIT License
config.json
ADDED
{
  "architecture": "WorldModels",
  "latent_size": 64,
  "hidden_size": 256,
  "action_size": 4,
  "n_gaussians": 5,
  "components": {
    "vae": "Variational Autoencoder for visual compression",
    "rnn": "MDRNN (Mixture Density Network RNN)",
    "controller": "CMA-ES optimized linear controller"
  },
  "trained_on": "Atari environment (Breakout)",
  "upload_date": "2025-12-25T08:13:58.344484",
  "framework": "PyTorch"
}
pytorch_model.bin
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:5e4139a4a780a9edcf8bb1b81f903d15570f33d8a4266584304898d9480c3f5f
size 32753806