---
tags:
- reinforcement-learning
- world-models
- atari
- vae
- rnn
- cma-es
- pytorch
license: mit
---
# World Models - Atari Agent
This is an implementation of World Models (Ha & Schmidhuber, 2018), trained on the Atari Breakout environment.
## Model Description
The World Models architecture consists of three main components:
1. **VAE (Variational Autoencoder)**: Compresses each 64x64 RGB frame into a 64-dimensional latent vector
2. **RNN (MDRNN, a Mixture Density Network RNN)**: Predicts a mixture-of-Gaussians distribution over the next latent vector, given the current latent vector, the action, and its own hidden state
3. **Controller**: A linear controller that maps the latent vector and the RNN hidden state to an action, optimized with CMA-ES to maximize cumulative reward
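To make the controller concrete, here is a minimal NumPy sketch of how a linear controller could turn a latent vector `z` and hidden state `h` into one of Breakout's 4 actions. It assumes, as in the original paper, that the controller input is the concatenation `[z, h]`; selecting the action by argmax over 4 outputs is an assumption for the discrete action space, and the function names are hypothetical rather than taken from `auto_train.py`.

```python
import numpy as np

LATENT_SIZE = 64   # z dimension (from Model Details)
HIDDEN_SIZE = 256  # MDRNN hidden state h (from Model Details)
N_ACTIONS = 4      # discrete Breakout actions

def num_controller_params(latent=LATENT_SIZE, hidden=HIDDEN_SIZE, actions=N_ACTIONS):
    # Weight matrix (actions x (latent + hidden)) plus one bias per action
    return actions * (latent + hidden) + actions

def controller_action(params, z, h):
    """Hypothetical helper: pick a discrete action from a flat parameter vector,
    which is the form in which CMA-ES proposes candidate controllers."""
    n_in = z.shape[0] + h.shape[0]
    W = params[: N_ACTIONS * n_in].reshape(N_ACTIONS, n_in)
    b = params[N_ACTIONS * n_in :]
    logits = W @ np.concatenate([z, h]) + b
    return int(np.argmax(logits))  # assumption: greedy choice over action logits

rng = np.random.default_rng(0)
params = rng.standard_normal(num_controller_params())
z = rng.standard_normal(LATENT_SIZE)
h = rng.standard_normal(HIDDEN_SIZE)
print(num_controller_params())  # 1284
print(controller_action(params, z, h))
```

Under these assumptions the controller has 4 × (64 + 256) + 4 = 1,284 parameters, which is small enough for CMA-ES to search directly.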
## Model Details
- **Latent Size**: 64
- **Hidden Size**: 256 (MDRNN)
- **Action Space**: 4 discrete actions (Atari Breakout)
- **Architecture**: Convolutional encoder/decoder for the VAE; LSTM-based MDRNN
- **Optimization**: CMA-ES for controller training
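The sizes above can be sketched as PyTorch modules. This is an illustrative sketch, not the definitions in `auto_train.py`: the convolutional layout follows the VAE described in the original paper, while the class names, the one-hot action input, and the choice of 5 mixture components are assumptions, so these modules are not guaranteed to accept this checkpoint's state dicts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT, HIDDEN, ACTIONS = 64, 256, 4  # from Model Details
GAUSSIANS = 5                         # assumption: number of mixture components

class ConvVAE(nn.Module):
    def __init__(self, latent=LATENT):
        super().__init__()
        # Encoder: 3x64x64 -> 256x2x2 via four stride-2 convolutions
        self.enc = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(1024, latent)
        self.fc_logvar = nn.Linear(1024, latent)
        self.fc_dec = nn.Linear(latent, 1024)
        # Decoder: 1024x1x1 -> 3x64x64 via four transposed convolutions
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(1024, 128, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 5, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 6, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 6, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.enc(x).flatten(start_dim=1)                     # (B, 1024)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        recon = self.dec(self.fc_dec(z).view(-1, 1024, 1, 1))
        return recon, mu, logvar

class MDRNN(nn.Module):
    def __init__(self, latent=LATENT, actions=ACTIONS, hidden=HIDDEN, gaussians=GAUSSIANS):
        super().__init__()
        self.latent, self.gaussians = latent, gaussians
        self.lstm = nn.LSTM(latent + actions, hidden, batch_first=True)
        # Per component: a mean and a log-std per latent dim, plus one mixing logit
        self.mdn = nn.Linear(hidden, gaussians * (2 * latent + 1))

    def forward(self, z, a_onehot):
        out, _ = self.lstm(torch.cat([z, a_onehot], dim=-1))  # (B, T, hidden)
        p = self.mdn(out)
        B, T, _ = p.shape
        g, L = self.gaussians, self.latent
        mus = p[..., : g * L].reshape(B, T, g, L)
        log_sigmas = p[..., g * L : 2 * g * L].reshape(B, T, g, L)
        log_pi = F.log_softmax(p[..., 2 * g * L :], dim=-1)   # (B, T, g)
        return mus, log_sigmas, log_pi, out

def mdn_nll(z_next, mus, log_sigmas, log_pi):
    """Negative log-likelihood of z_next under the predicted Gaussian mixture."""
    z = z_next.unsqueeze(-2)  # (B, T, 1, L), broadcast over components
    log_prob = torch.distributions.Normal(mus, log_sigmas.exp()).log_prob(z).sum(dim=-1)
    return -torch.logsumexp(log_pi + log_prob, dim=-1).mean()

torch.manual_seed(0)
vae, rnn = ConvVAE(), MDRNN()
recon, mu, logvar = vae(torch.rand(8, 3, 64, 64))
z_seq = mu.view(2, 4, LATENT)  # 2 sequences of 4 steps
actions = F.one_hot(torch.randint(0, ACTIONS, (2, 4)), ACTIONS).float()
mus, log_sigmas, log_pi, h = rnn(z_seq, actions)
# Predict step t+1 from step t
loss = mdn_nll(z_seq[:, 1:], mus[:, :-1], log_sigmas[:, :-1], log_pi[:, :-1])
print(recon.shape, h.shape, loss.item())
```

The hidden states `h` returned by the MDRNN, together with the VAE means, are what a controller like the one described above would consume.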
## Usage
```python
import torch

# Load the checkpoint on CPU (works whether or not a GPU is available)
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Access the state dict of each component
vae_state = checkpoint['vae_state_dict']
rnn_state = checkpoint['rnn_state_dict']
controller_state = checkpoint['controller_state_dict']

# Reconstruct the models (see auto_train.py for the architecture definitions)
# and load the states into them, e.g. vae.load_state_dict(vae_state)
```
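Before rebuilding the models, it can help to see what each component holds. The helper below (a hypothetical name, not part of this repository) counts the tensors and parameters in every checkpoint entry whose key ends in `_state_dict`, assuming each such entry is a plain state dict of tensors.

```python
import torch

def summarize_checkpoint(path_or_file):
    """Map each *_state_dict key to (tensor_count, parameter_count)."""
    checkpoint = torch.load(path_or_file, map_location="cpu")
    summary = {}
    for key, state in checkpoint.items():
        if key.endswith("_state_dict"):
            summary[key] = (len(state), sum(t.numel() for t in state.values()))
    return summary
```

For example, `summarize_checkpoint('pytorch_model.bin')` should return one entry each for `vae_state_dict`, `rnn_state_dict`, and `controller_state_dict`; comparing those counts with `sum(p.numel() for p in model.parameters())` on a reconstructed module is a quick check that the architecture matches before calling `load_state_dict`.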
## Training Details
- **Environment**: Atari Breakout
- **Harvest Episodes**: 100 (rollouts collected to train the VAE and MDRNN)
- **VAE Epochs**: 20
- **RNN Epochs**: 20
- **CMA-ES Generations**: 25
- **CMA-ES Population Size**: 32
- **Framework**: PyTorch
- **Training Script**: auto_train.py
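The generation/population loop behind the CMA-ES settings above can be illustrated with a deliberately simplified evolution strategy. This is not CMA-ES: it samples from an isotropic Gaussian, recombines the best half of each population with log-rank weights, and shrinks the step size on a fixed schedule, whereas CMA-ES also adapts a full covariance matrix and its step size from the search history. The `evolve` function and its defaults (other than 25 generations and a population of 32) are hypothetical; in the real setup, `fitness` would run the controller in Breakout and return its cumulative reward.

```python
import numpy as np

def evolve(fitness, n_params, generations=25, popsize=32, sigma=0.5, seed=0):
    """Simplified (mu/mu_w, lambda) evolution strategy that MAXIMIZES fitness.
    Not full CMA-ES: no covariance matrix or step-size adaptation."""
    rng = np.random.default_rng(seed)
    mean = np.zeros(n_params)
    n_elite = popsize // 2
    # Log-rank recombination weights: better-ranked samples count more
    weights = np.log(n_elite + 0.5) - np.log(np.arange(1, n_elite + 1))
    weights /= weights.sum()
    best_x, best_f = mean.copy(), fitness(mean)
    for _ in range(generations):
        pop = mean + sigma * rng.standard_normal((popsize, n_params))
        scores = np.array([fitness(x) for x in pop])  # e.g. cumulative episode reward
        order = np.argsort(-scores)                   # best first
        mean = weights @ pop[order[:n_elite]]         # weighted recombination
        if scores[order[0]] > best_f:
            best_x, best_f = pop[order[0]].copy(), scores[order[0]]
        sigma *= 0.95                                 # fixed decay schedule
    return best_x, best_f

# Toy objective standing in for episode reward: peak of 0 at x = target
target = np.ones(10)
best_x, best_f = evolve(lambda x: -np.sum((x - target) ** 2), n_params=10)
print(best_f)
```

For the full algorithm, the `cma` Python package provides `CMAEvolutionStrategy` with an ask/tell interface; it minimizes its objective, so rewards would need to be negated before being passed to `tell`.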
## References
- Ha, D., & Schmidhuber, J. (2018). World Models. arXiv preprint arXiv:1803.10122.
- Interactive version of the paper: https://worldmodels.github.io/
## License
MIT License