---
tags:
- reinforcement-learning
- world-models
- atari
- vae
- rnn
- cma-es
- pytorch
license: mit
---

# World Models - Atari Agent

This is an implementation of World Models (Ha & Schmidhuber, 2018) trained on the Atari Breakout environment.

## Model Description

The World Models architecture consists of three main components:

1. **VAE (Variational Autoencoder)**: Compresses 64x64 RGB frames into a 64-dimensional latent vector.
2. **RNN (Mixture Density Recurrent Neural Network, MDRNN)**: Given the current latent vector and action, predicts a probability distribution (a mixture of Gaussians) over the next latent vector.
3. **Controller**: A linear controller, optimized with CMA-ES to maximize cumulative reward.

## Model Details

- **Latent Size**: 64
- **Hidden Size**: 256 (MDRNN)
- **Action Space**: 4 (Breakout's discrete action set: NOOP, FIRE, RIGHT, LEFT)
- **Architecture**: Convolutional encoder/decoder for the VAE; LSTM-based MDRNN
- **Optimization**: CMA-ES for controller training

## Usage

```python
import torch

# Load the checkpoint onto the CPU (no GPU required)
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# The checkpoint is a dict holding one state dict per component
vae_state = checkpoint['vae_state_dict']
rnn_state = checkpoint['rnn_state_dict']
controller_state = checkpoint['controller_state_dict']

# Rebuild the modules from the architecture definitions in auto_train.py,
# then load each state dict into its module, e.g. vae.load_state_dict(vae_state)
```

## Training Details

- **Environment**: Atari Breakout
- **Harvest Episodes** (rollouts collected to train the VAE and RNN): 100
- **VAE Epochs**: 20
- **RNN Epochs**: 20
- **CMA-ES Generations**: 25
- **CMA-ES Population Size**: 32
- **Framework**: PyTorch
- **Training Script**: `auto_train.py`

## References

- Ha, D., & Schmidhuber, J. (2018). World Models. arXiv preprint arXiv:1803.10122.
- Interactive version of the paper: https://worldmodels.github.io/

## License

MIT License
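## Architecture Sketch

This card does not reproduce the module definitions from `auto_train.py`. The following is therefore a minimal, hypothetical PyTorch sketch of how the three components fit together. It uses only the sizes this card lists: latent 64, hidden 256, and 4 actions. Everything else is an assumption:

- The convolutional layer sizes follow the original paper.
- The mixture has 5 components (`N_GAUSS`, the paper's choice), with one mixture weight per component shared across latent dimensions.
- The class names `Encoder`, `MDNRNN`, and `Controller` are illustrative.
- The VAE decoder is omitted for brevity.

The parameter names will not match the checkpoint's state dicts, so use `auto_train.py` to actually load the weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT = 64   # Latent Size (from this card)
HIDDEN = 256  # MDRNN hidden size (from this card)
ACTIONS = 4   # Breakout action space (from this card)
N_GAUSS = 5   # assumption: number of mixture components (5 in the paper)


class Encoder(nn.Module):
    """VAE encoder sketch: 3x64x64 frame -> (mu, logvar), each of size LATENT."""

    def __init__(self):
        super().__init__()
        # Four stride-2 convs (paper-style, assumed): 64 -> 31 -> 14 -> 6 -> 2
        self.convs = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(),
        )
        self.mu = nn.Linear(2 * 2 * 256, LATENT)
        self.logvar = nn.Linear(2 * 2 * 256, LATENT)

    def forward(self, x):
        h = self.convs(x).flatten(start_dim=1)
        return self.mu(h), self.logvar(h)


class MDNRNN(nn.Module):
    """LSTM over [z_t, one-hot a_t]; outputs a Gaussian mixture over z_{t+1}."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(LATENT + ACTIONS, HIDDEN, batch_first=True)
        # Per component: LATENT means, LATENT log-stds, and one mixture logit
        self.head = nn.Linear(HIDDEN, N_GAUSS * (2 * LATENT + 1))

    def forward(self, z, a_onehot, state=None):
        out, state = self.lstm(torch.cat([z, a_onehot], dim=-1), state)
        p = self.head(out)
        B, T = out.shape[:2]
        gauss = p[..., : 2 * N_GAUSS * LATENT]
        mu, log_sigma = torch.split(gauss, N_GAUSS * LATENT, dim=-1)
        mu = mu.view(B, T, N_GAUSS, LATENT)
        log_sigma = log_sigma.view(B, T, N_GAUSS, LATENT)
        # Log mixture weights, normalized over the N_GAUSS components
        log_pi = F.log_softmax(p[..., 2 * N_GAUSS * LATENT:], dim=-1)
        return mu, log_sigma, log_pi, state


class Controller(nn.Module):
    """Linear map from [z_t, h_t] to action scores; greedy action = argmax."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(LATENT + HIDDEN, ACTIONS)

    def act(self, z, h):
        return self.fc(torch.cat([z, h], dim=-1)).argmax(dim=-1)


# One step of the perception -> control -> memory loop on a fake frame
torch.manual_seed(0)
enc, rnn, ctrl = Encoder(), MDNRNN(), Controller()
frame = torch.rand(1, 3, 64, 64)                      # fake RGB frame in [0, 1]
mu, logvar = enc(frame)
z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterized sample
h = torch.zeros(1, HIDDEN)                            # initial LSTM hidden state
action = ctrl.act(z, h)
a = F.one_hot(action, ACTIONS).float()
_, _, _, (h_n, c_n) = rnn(z.unsqueeze(1), a.unsqueeze(1))
n_params = sum(t.numel() for t in ctrl.parameters())  # 1284
print(z.shape, action.item(), h_n.shape, n_params)
```

If the controller takes the concatenated `[z, h]` as input, as in the paper, it has only 320 × 4 + 4 = 1,284 parameters. That small size is why a derivative-free optimizer such as CMA-ES, with a population of 32, is practical for training it.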