---
tags:
- reinforcement-learning
- world-models
- atari
- vae
- rnn
- cma-es
- pytorch
license: mit
---

# World Models - Atari Agent

This is an implementation of the World Models architecture (Ha & Schmidhuber, 2018), trained on the Atari Breakout environment.

## Model Description

The World Models architecture consists of three main components:

1. **VAE (Variational Autoencoder)**: Compresses 64x64 RGB frames into a 64-dimensional latent vector
2. **RNN (Mixture Density Recurrent Neural Network, MDRNN)**: Predicts the next latent vector given the current latent vector and action
3. **Controller**: A linear policy whose parameters are optimized with CMA-ES to maximize cumulative reward

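How the three components fit together at inference time can be sketched as a single agent step. This is a minimal illustration, not the code from `auto_train.py`: the modules below are hypothetical stand-ins (a linear layer in place of the convolutional VAE encoder, an `LSTMCell` in place of the full MDRNN), and it assumes, following the paper, that the controller reads the concatenation of the latent vector and the RNN hidden state and that the action is fed to the RNN as a one-hot vector.

```python
import torch
import torch.nn as nn

LATENT, HIDDEN, ACTIONS = 64, 256, 4  # sizes quoted in this model card

# Hypothetical stand-ins; the real definitions live in auto_train.py.
encoder = nn.Linear(3 * 64 * 64, LATENT)          # replaces the convolutional VAE encoder
memory = nn.LSTMCell(LATENT + ACTIONS, HIDDEN)    # replaces the LSTM core of the MDRNN
controller = nn.Linear(LATENT + HIDDEN, ACTIONS)  # linear controller (assumed input: [z, h])


def agent_step(frame, h, c):
    """Run one V -> C -> M step and return the action and the updated memory."""
    z = encoder(frame.flatten(1))                  # V: compress the frame to a latent vector
    logits = controller(torch.cat([z, h], dim=1))  # C: score each discrete action
    action = logits.argmax(dim=1)                  # assumption: greedy action selection
    one_hot = nn.functional.one_hot(action, ACTIONS).float()
    h, c = memory(torch.cat([z, one_hot], dim=1), (h, c))  # M: update the memory state
    return action, h, c


frame = torch.rand(1, 3, 64, 64)  # one random 64x64 RGB frame
h = torch.zeros(1, HIDDEN)
c = torch.zeros(1, HIDDEN)
with torch.no_grad():
    action, h, c = agent_step(frame, h, c)
```

The hidden state `h` is carried from step to step, which is what gives the controller access to information from earlier frames.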
## Model Details

- **Latent Size**: 64
- **Hidden Size**: 256 (MDRNN)
- **Action Space**: 4 (Atari discrete actions)
- **Architecture**: Convolutional encoder/decoder for VAE, LSTM-based RNN
- **Optimization**: CMA-ES for controller training

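These sizes determine how many parameters CMA-ES has to search over. The arithmetic below is a sketch that assumes the controller is a single linear layer with a bias, fed the concatenated latent vector and RNN hidden state as in the paper; the function name is hypothetical, and the actual layout in `auto_train.py` may differ.

```python
def linear_controller_param_count(latent_size, hidden_size, num_actions):
    """Weights plus biases of one linear layer mapping [z, h] to action scores."""
    input_size = latent_size + hidden_size
    return input_size * num_actions + num_actions


n_params = linear_controller_param_count(latent_size=64, hidden_size=256, num_actions=4)
print(n_params)  # (64 + 256) * 4 + 4 = 1284
```

A search space of this size is small enough for a gradient-free optimizer, which is the reason the paper gives for keeping the controller linear.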
## Usage

```python
from pathlib import Path

import torch

# Load the checkpoint onto the CPU so it also works on machines without a GPU
checkpoint_path = Path('pytorch_model.bin')
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Access the state dict of each component
vae_state = checkpoint['vae_state_dict']
rnn_state = checkpoint['rnn_state_dict']
controller_state = checkpoint['controller_state_dict']

# Reconstruct the models (see auto_train.py for the architecture definitions)
# and load these state dicts into them with load_state_dict()
```

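The last step, loading a state dict into a reconstructed model, might look like the following. The `Controller` class here is a hypothetical stand-in, not the definition from `auto_train.py`, and the checkpoint is built in memory so the example runs without the released file; with the real checkpoint, the class must match the one used in training or `load_state_dict` will report missing or unexpected keys.

```python
import io

import torch
import torch.nn as nn


class Controller(nn.Module):
    """Hypothetical linear controller; the real class is defined in auto_train.py."""

    def __init__(self, latent_size=64, hidden_size=256, num_actions=4):
        super().__init__()
        self.fc = nn.Linear(latent_size + hidden_size, num_actions)

    def forward(self, z, h):
        return self.fc(torch.cat([z, h], dim=1))


# Build a stand-in checkpoint in memory, using the same key as the released file.
source = Controller()
buffer = io.BytesIO()
torch.save({'controller_state_dict': source.state_dict()}, buffer)
buffer.seek(0)

# Restore the weights into a fresh instance, as one would with pytorch_model.bin.
checkpoint = torch.load(buffer, map_location='cpu')
controller = Controller()
controller.load_state_dict(checkpoint['controller_state_dict'])
controller.eval()  # switch to inference mode
```

The same pattern applies to the VAE and RNN with `vae_state_dict` and `rnn_state_dict`.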
## Training Details

- **Environment**: Atari Breakout
- **Harvest Episodes**: 100
- **VAE Epochs**: 20
- **RNN Epochs**: 20
- **CMA-ES Generations**: 25
- **Population Size**: 32
- **Framework**: PyTorch
- **Training Script**: auto_train.py

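To show what a generation and a population mean in this setting, here is a deliberately simplified evolution strategy using the population size and generation count listed above. It is not CMA-ES: it samples from an isotropic Gaussian with a fixed step size and does not adapt a covariance matrix. The toy reward, parameter count, step size, and elite fraction are all hypothetical; in real training the reward would be the episode return of a controller built from each candidate.

```python
import random

POP_SIZE, GENERATIONS = 32, 25  # values from this model card
NUM_PARAMS = 8                  # toy size (hypothetical); a real controller has far more
SIGMA = 0.3                     # fixed sampling step size (hypothetical)
ELITE = POP_SIZE // 4           # number of top candidates averaged each generation


def toy_reward(params):
    """Hypothetical stand-in for an episode return; highest when every parameter is 1.0."""
    return -sum((p - 1.0) ** 2 for p in params)


def evolve(seed=0):
    rng = random.Random(seed)
    mean = [0.0] * NUM_PARAMS
    history = []
    for _ in range(GENERATIONS):
        # Sample a population of candidate parameter vectors around the current mean.
        population = [[m + SIGMA * rng.gauss(0.0, 1.0) for m in mean]
                      for _ in range(POP_SIZE)]
        # Evaluate every candidate and keep the best-scoring ones.
        ranked = sorted(population, key=toy_reward, reverse=True)
        elite = ranked[:ELITE]
        # Move the mean to the average of the elite candidates.
        mean = [sum(column) / ELITE for column in zip(*elite)]
        history.append(toy_reward(mean))
    return mean, history


best_params, history = evolve()
```

With one rollout per candidate, 25 generations of 32 candidates amount to 800 evaluations; averaging several rollouts per candidate would multiply that number.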
## References

- Ha, D., & Schmidhuber, J. (2018). World Models. arXiv preprint arXiv:1803.10122.
- Interactive version of the paper: https://worldmodels.github.io/

## License

MIT License