World Models - Atari Agent
This is an implementation of World Models (Ha & Schmidhuber, 2018) trained on the Atari Breakout environment.
Model Description
The World Models architecture consists of three main components:
- VAE (Variational Autoencoder): Compresses 64x64 RGB images into a 64-dimensional latent space
- RNN (Mixture Density Recurrent Neural Network, MDRNN): Predicts the next latent representation given the current latent state and action
- Controller: A linear controller optimized using CMA-ES to maximize cumulative reward
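The way the three components interact at each environment step can be sketched as follows. This is a minimal illustration with stand-in callables, not the code from auto_train.py. The names `agent_step`, `encode`, `act`, and `step_memory` are hypothetical, and feeding both the latent vector and the hidden state to the controller follows the original paper and is an assumption about this implementation.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]

def agent_step(
    frame: Sequence[float],
    hidden: Vector,
    encode: Callable[[Sequence[float]], Vector],
    act: Callable[[Vector, Vector], int],
    step_memory: Callable[[Vector, int, Vector], Vector],
) -> Tuple[int, Vector]:
    """One control step: encode the frame, pick an action, update the memory."""
    z = encode(frame)                             # VAE: frame -> latent vector
    action = act(z, hidden)                       # Controller: (z, h) -> action
    next_hidden = step_memory(z, action, hidden)  # MDRNN: (z, a, h) -> next h
    return action, next_hidden

# Toy stand-ins so the sketch runs without trained weights (all hypothetical).
def toy_encode(frame: Sequence[float]) -> Vector:
    return [sum(frame) / len(frame)]

def toy_act(z: Vector, h: Vector) -> int:
    return 1 if z[0] + h[0] > 0.5 else 0

def toy_step_memory(z: Vector, a: int, h: Vector) -> Vector:
    return [0.9 * h[0] + 0.1 * z[0]]

action, hidden = agent_step(
    [0.2, 0.8, 1.0], [0.0], toy_encode, toy_act, toy_step_memory
)
```

In a real rollout the returned hidden state would be passed back into the next call, so the controller always acts on the current frame plus a summary of the past.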
Model Details
- Latent Size: 64
- Hidden Size: 256 (MDRNN)
- Action Space: 4 (Atari discrete actions)
- Architecture: Convolutional encoder/decoder for VAE, LSTM-based RNN
- Optimization: CMA-ES for controller training
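With the sizes listed above, a linear controller has few enough parameters to be searched with an evolution strategy. The sketch below assumes, as in the original paper, that the controller sees the latent vector concatenated with the MDRNN hidden state and picks the action with the largest score. Whether auto_train.py uses exactly this input and action rule is not stated here, and the parameter layout (all weights first, then biases) is chosen only for illustration.

```python
from typing import List, Sequence

LATENT_SIZE = 64   # from the model details
HIDDEN_SIZE = 256  # from the model details
NUM_ACTIONS = 4    # from the model details
INPUT_SIZE = LATENT_SIZE + HIDDEN_SIZE  # assumption: controller input is [z, h]

def num_controller_params(input_size: int = INPUT_SIZE,
                          num_actions: int = NUM_ACTIONS) -> int:
    """Weights plus one bias per action for a single linear layer."""
    return input_size * num_actions + num_actions

def linear_policy(params: Sequence[float], z: Sequence[float],
                  h: Sequence[float]) -> int:
    """Score each action with a linear map of [z, h] and return the argmax."""
    features = list(z) + list(h)
    n = len(features)
    num_actions = len(params) // (n + 1)
    scores: List[float] = []
    for a in range(num_actions):
        weights = params[a * n:(a + 1) * n]   # row of weights for action a
        bias = params[num_actions * n + a]    # biases are stored after weights
        scores.append(sum(w * x for w, x in zip(weights, features)) + bias)
    return max(range(num_actions), key=scores.__getitem__)

n_params = num_controller_params()  # 320 * 4 + 4
params = [0.0] * n_params
params[-1] = 1.0  # give the last action a positive bias
action = linear_policy(params, [0.0] * LATENT_SIZE, [0.0] * HIDDEN_SIZE)
```

Under these assumptions the controller is a flat vector of 1284 numbers, which is the object an evolution strategy would perturb and evaluate.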
Usage
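import torch
from pathlib import Path
# Load the checkpoint onto the CPU so it also works on machines without a GPU
checkpoint_path = Path('pytorch_model.bin')
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Access the state dict of each component
vae_state = checkpoint['vae_state_dict']
rnn_state = checkpoint['rnn_state_dict']
controller_state = checkpoint['controller_state_dict']
# Reconstruct the models (see auto_train.py for the architecture definitions)
# and load each state dict into its model with model.load_state_dict(...)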
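Before rebuilding the models, it can help to list the parameter names and shapes stored in each state dict, since these have to match the architecture definitions. The helper below is a small sketch: `summarize_state_dict` is a hypothetical name, and it only assumes that each value exposes a `.shape` attribute, as PyTorch tensors do. The example keys are made up so the snippet runs without the checkpoint file.

```python
from typing import Any, List, Mapping, Tuple

def summarize_state_dict(
    state: Mapping[str, Any],
) -> List[Tuple[str, Tuple[int, ...]]]:
    """Return (parameter name, shape) pairs for every entry in a state dict."""
    return [(name, tuple(value.shape)) for name, value in state.items()]

class _FakeTensor:
    """Stand-in with a .shape attribute so the sketch runs without a checkpoint."""
    def __init__(self, *shape: int) -> None:
        self.shape = shape

# Hypothetical parameter names; the real keys depend on auto_train.py.
example_state = {"fc.weight": _FakeTensor(4, 320), "fc.bias": _FakeTensor(4)}
summary = summarize_state_dict(example_state)
for name, shape in summary:
    print(f"{name}: {shape}")
```

With the real checkpoint loaded, the same helper could be called as `summarize_state_dict(checkpoint['controller_state_dict'])`.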
Training Details
- Environment: Atari Breakout
- Harvest Episodes: 100
- VAE Epochs: 20
- RNN Epochs: 20
- CMA-ES Generations: 25
- Population Size: 32
- Framework: PyTorch
- Training Script: auto_train.py
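The controller search can be pictured as a generational loop: sample a population of parameter vectors, score each one, and move the search distribution toward the best candidates. The sketch below is a simplified evolution strategy with a fixed isotropic step size, not CMA-ES itself (it does not adapt a covariance matrix), and it uses a toy fitness function in place of Breakout rollouts. Only the generation count and population size come from the training details; `toy_fitness`, `simple_es`, `sigma`, and `elite_frac` are illustrative assumptions.

```python
import random
from typing import Callable, List

GENERATIONS = 25      # from the training details
POPULATION_SIZE = 32  # from the training details

def toy_fitness(params: List[float]) -> float:
    """Stand-in for an episode return: highest (zero) when every parameter is 1."""
    return -sum((p - 1.0) ** 2 for p in params)

def simple_es(fitness: Callable[[List[float]], float], dim: int,
              sigma: float = 0.3, elite_frac: float = 0.25,
              seed: int = 0) -> List[float]:
    """Simplified evolution strategy: average the top candidates each generation."""
    rng = random.Random(seed)
    mean = [0.0] * dim
    n_elite = max(1, int(POPULATION_SIZE * elite_frac))
    for _ in range(GENERATIONS):
        # Sample candidates around the current mean.
        population = [[m + sigma * rng.gauss(0.0, 1.0) for m in mean]
                      for _ in range(POPULATION_SIZE)]
        population.sort(key=fitness, reverse=True)  # best candidates first
        elite = population[:n_elite]
        # Move the mean to the average of the elite candidates.
        mean = [sum(c[i] for c in elite) / n_elite for i in range(dim)]
    return mean

best = simple_es(toy_fitness, dim=5)
total_evaluations = GENERATIONS * POPULATION_SIZE  # candidates scored in total
```

With the settings listed above, the search scores 25 x 32 = 800 candidate controllers in total. In the actual training setup, the fitness of a candidate would be the cumulative reward it obtains in the environment rather than a closed-form function.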
References
- Ha, D., & Schmidhuber, J. (2018). World Models. arXiv preprint arXiv:1803.10122.
- Interactive version of the paper: https://worldmodels.github.io/
License
MIT License