World Models for Space Invaders

This is a World Models agent trained on the SpaceInvadersNoFrameskip-v4 environment.

Model Description

World Models is a model-based reinforcement learning approach: it learns a compressed model of the environment (its observations and their dynamics) and trains a small controller on top of that learned model to maximize reward.

The architecture consists of three components:

  • V (Vision): Variational Autoencoder that compresses 64x64 RGB frames to 32-dimensional latent vectors
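  • M (Memory): MDN-RNN (a recurrent network with a mixture-density output) that predicts a distribution over the next latent vector given the current latent vector, the action taken, and its own hidden state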
  • C (Controller): Linear policy trained with CMA-ES evolution strategy
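
M outputs a mixture of Gaussians rather than a single point prediction. The sketch below shows how the negative log-likelihood of an observed next latent vector can be computed under such a mixture. It assumes the diagonal setup described by Ha & Schmidhuber, where each latent dimension is modeled independently by K components (K = 5 in the hyperparameters below). The function name `mdn_nll` and its argument layout are illustrative only and are not this repository's API.

```python
import math

def log_sum_exp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mdn_nll(z_next, logit_pi, mu, log_sigma):
    """Negative log-likelihood of z_next under a per-dimension Gaussian mixture.

    z_next:    list of D floats (observed next latent vector)
    logit_pi:  D lists of K unnormalized mixture logits
    mu:        D lists of K component means
    log_sigma: D lists of K component log standard deviations
    """
    total = 0.0
    for d, z in enumerate(z_next):
        # Normalize the mixture logits into log-weights (log-softmax).
        log_norm = log_sum_exp(logit_pi[d])
        log_terms = []
        for k in range(len(mu[d])):
            sigma = math.exp(log_sigma[d][k])
            # log N(z | mu, sigma^2) for component k
            log_gauss = (-0.5 * ((z - mu[d][k]) / sigma) ** 2
                         - log_sigma[d][k] - 0.5 * math.log(2 * math.pi))
            log_terms.append(logit_pi[d][k] - log_norm + log_gauss)
        # Mixture likelihood for this dimension, accumulated as a negative log.
        total -= log_sum_exp(log_terms)
    return total
```

With five identical, equally weighted components the mixture collapses to a single standard normal, so the result equals 0.5·log(2π) per dimension. This is a convenient sanity check for a real MDN head.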

Training Details

Hyperparameters

  • VAE Latent Dimension: 32
  • RNN Hidden Dimension: 256
  • Number of Gaussian Mixtures: 5
  • Population Size (CMA-ES): 64
  • Training Episodes: 100
  • VAE Epochs: 10
  • RNN Epochs: 20
  • Controller Generations: 10
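
Two quantities follow directly from these hyperparameters. The sketch below computes them under one assumption taken from Ha & Schmidhuber: C reads the concatenation of the latent vector z (32 values) and the RNN hidden state h (256 values) and maps it linearly to one score per action, with the highest-scoring of the 6 Space Invaders actions chosen. The card does not say how the repository turns C's output into a discrete action, so the argmax head here is an assumption. The function names are hypothetical.

```python
def controller_param_count(latent_dim, hidden_dim, action_dim):
    """Weights plus biases of a linear map from [z, h] to action scores."""
    input_dim = latent_dim + hidden_dim
    return input_dim * action_dim + action_dim

def cmaes_candidate_evaluations(population_size, generations):
    """Candidate controllers sampled over the whole run (one fitness each)."""
    return population_size * generations

def linear_controller_action(weights, bias, z, h):
    """Return the index of the action with the highest linear score for [z, h]."""
    x = list(z) + list(h)
    scores = [sum(w * xi for w, xi in zip(row, x)) + b
              for row, b in zip(weights, bias)]
    return max(range(len(scores)), key=scores.__getitem__)

print(controller_param_count(32, 256, 6))   # 1734
print(cmaes_candidate_evaluations(64, 10))  # 640
```

Under this assumption, CMA-ES searches a 1734-dimensional space and adapts a covariance matrix with 1734² ≈ 3.0 million entries. This is why World Models keeps C deliberately small. If each candidate's fitness is averaged over several rollouts, the number of environment episodes is a multiple of the 640 candidates.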

Evaluation Results

  • Mean Reward: 532.00 ± 149.42
  • Max Reward: 715.00
  • Mean Episode Length: 782.80
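
The card does not state how many evaluation episodes were run, or whether "±" is the sample or the population standard deviation. A summary of this form is typically computed as sketched below. The inputs in the usage line are illustrative only and are not this model's episodes.

```python
import statistics

def summarize_rewards(episode_rewards, episode_lengths):
    """Mean ± standard deviation of reward, max reward, and mean episode length."""
    return {
        "mean_reward": statistics.mean(episode_rewards),
        # pstdev is the population std; statistics.stdev gives the sample version.
        "std_reward": statistics.pstdev(episode_rewards),
        "max_reward": max(episode_rewards),
        "mean_length": statistics.mean(episode_lengths),
    }

# Illustrative inputs only, not results from this model.
summary = summarize_rewards([100.0, 200.0, 300.0], [500, 700, 900])
print(f"{summary['mean_reward']:.2f} ± {summary['std_reward']:.2f}")  # 200.00 ± 81.65
```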

Usage

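import torch
import gymnasium as gym
import ale_py  # importing ale_py registers the Atari (ALE) environments with Gymnasium

# VAE, MDNRNN and Controller are the model classes defined in this model's
# repository; import them from there before running this snippet.

# Load the trained weights (map_location lets a CPU-only machine load GPU checkpoints)
vae = VAE(latent_dim=32)
vae.load_state_dict(torch.load('vae_model.pt', map_location='cpu'))
vae.eval()

rnn = MDNRNN(latent_dim=32, action_dim=6)  # 6 = size of Space Invaders' action set
rnn.load_state_dict(torch.load('mdnrnn_model.pt', map_location='cpu'))
rnn.eval()

controller = Controller(latent_dim=32, hidden_dim=256)
controller.load_state_dict(torch.load('controller_model.pt', map_location='cpu'))
controller.eval()

# Run agent
env = gym.make('SpaceInvadersNoFrameskip-v4')
# ... (see repository for full inference code)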
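
The full inference code lives in the repository. For orientation, the World Models control loop is sketched below. V encodes the current frame into z, C picks an action from [z, h], and M advances its hidden state using (z, action). `encode`, `act` and `rnn_step` are hypothetical callables standing in for the repository's VAE, Controller and MDNRNN. For Space Invaders, `encode` would also need to resize frames to 64x64. The environment only needs the Gymnasium interface: `reset()` returns `(obs, info)` and `step()` returns `(obs, reward, terminated, truncated, info)`. `ToyEnv` is a made-up stand-in so the sketch runs without Atari installed.

```python
from typing import Any, Callable, Sequence, Tuple

def run_episode(
    env: Any,
    encode: Callable[[Any], Sequence[float]],              # V: frame -> z
    act: Callable[[Sequence[float], Any], int],            # C: (z, h) -> action
    rnn_step: Callable[[Sequence[float], int, Any], Any],  # M: (z, a, h) -> next h
    initial_hidden: Any,
) -> Tuple[float, int]:
    """Roll out one episode with the V -> C -> M loop; return (total reward, length)."""
    obs, _info = env.reset()
    h = initial_hidden
    total_reward, steps = 0.0, 0
    while True:
        z = encode(obs)             # compress the current frame
        action = act(z, h)          # controller sees [z, h]
        h = rnn_step(z, action, h)  # memory update conditioned on the chosen action
        obs, reward, terminated, truncated, _info = env.step(action)
        total_reward += float(reward)
        steps += 1
        if terminated or truncated:
            return total_reward, steps

class ToyEnv:
    """Toy stand-in with the Gymnasium reset/step signature; ends after 3 steps."""
    def reset(self):
        self.t = 0
        return 0, {}

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3, False, {}

result = run_episode(ToyEnv(), encode=lambda obs: [float(obs)],
                     act=lambda z, h: 0, rnn_step=lambda z, a, h: h,
                     initial_hidden=None)
print(result)  # (3.0, 3)
```

Updating h after the action is chosen mirrors the paper's formulation. There, the action at step t depends on z_t and h_t, and M produces h_{t+1} from z_t, the action, and h_t.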

References

Citation

@article{ha2018worldmodels,
  title={World Models},
  author={Ha, David and Schmidhuber, J{\"u}rgen},
  journal={arXiv preprint arXiv:1803.10122},
  year={2018}
}