loayahmed123 committed on
Commit 7c219ec · verified · 1 Parent(s): d9fbaa0

Upload World Models for Space Invaders

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sample_gameplay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,78 @@
+ ---
+ tags:
+ - reinforcement-learning
+ - world-models
+ - atari
+ - space-invaders
+ - deep-learning
+ library_name: pytorch
+ ---
+
+ # World Models for Space Invaders
+
+ This is a World Models agent trained on the `SpaceInvadersNoFrameskip-v4` environment.
+
+ ## Model Description
+
+ World Models is a model-based reinforcement learning approach that learns a compressed representation
+ of the environment and trains a controller to maximize reward in the learned model.
+
+ The architecture consists of three components:
+ - **V (Vision)**: Variational Autoencoder that compresses 64x64 RGB frames to 32-dimensional latent vectors
+ - **M (Memory)**: MDN-RNN that predicts the next latent state given the current latent state and action
+ - **C (Controller)**: Linear policy trained with CMA-ES (Covariance Matrix Adaptation Evolution Strategy)
+
+ ## Training Details
+
+ ### Hyperparameters
+ - VAE Latent Dimension: 32
+ - RNN Hidden Dimension: 256
+ - Number of Gaussian Mixtures: 5
+ - Population Size (CMA-ES): 64
+ - Training Episodes: 100
+ - VAE Epochs: 10
+ - RNN Epochs: 20
+ - Controller Generations: 10
+
+ ## Evaluation Results
+
+ - **Mean Reward**: 155.00 ± 0.00
+ - **Max Reward**: 155.00
+ - **Mean Episode Length**: 1000.00
+
+ ## Usage
+
+ ```python
+ import torch
+ import gymnasium as gym
+
+ # Load models (VAE, MDNRNN, and Controller are the model classes from the repository code)
+ vae = VAE(latent_dim=32)
+ vae.load_state_dict(torch.load('vae_model.pt', map_location='cpu'))
+
+ rnn = MDNRNN(latent_dim=32, action_dim=6)
+ rnn.load_state_dict(torch.load('mdnrnn_model.pt', map_location='cpu'))
+
+ controller = Controller(latent_dim=32, hidden_dim=256)
+ controller.load_state_dict(torch.load('controller_model.pt', map_location='cpu'))
+
+ # Run agent
+ env = gym.make('SpaceInvadersNoFrameskip-v4')
+ # ... (see repository for full inference code)
+ ```
+
+ ## References
+
+ - Paper: [World Models (Ha & Schmidhuber, 2018)](https://worldmodels.github.io/)
+ - Code: Based on the original World Models implementation
+
+ ## Citation
+
+ ```bibtex
+ @article{ha2018worldmodels,
+ title={World Models},
+ author={Ha, David and Schmidhuber, J{\"u}rgen},
+ journal={arXiv preprint arXiv:1803.10122},
+ year={2018}
+ }
+ ```
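The README's Usage snippet stops before the inference loop, so the following is a minimal sketch of how the three components could interact at each step: V encodes the frame, C picks an action from the latent and hidden vectors, and M updates the hidden state. The encoder and RNN here are stand-in stubs, and picking the action by argmax over linear scores is an assumption; none of the function names come from the repository.

```python
import random

# Dimensions taken from config.json in this commit.
LATENT_DIM, HIDDEN_DIM, ACTION_DIM = 32, 256, 6


def linear_controller(z, h, weights, bias):
    """Score each action with a linear map over [z, h] and return the best index.

    Argmax action selection is an assumption for a discrete action space.
    """
    features = z + h  # list concatenation: 32 + 256 = 288 inputs
    scores = [
        sum(w * x for w, x in zip(row, features)) + b
        for row, b in zip(weights, bias)
    ]
    return max(range(len(scores)), key=scores.__getitem__)


def rollout_step(frame, h, encode, rnn_step, weights, bias):
    """One V -> C -> M step; `encode` and `rnn_step` are hypothetical callables."""
    z = encode(frame)                                # V: frame -> latent vector
    action = linear_controller(z, h, weights, bias)  # C: (z, h) -> action
    h_next = rnn_step(z, action, h)                  # M: update hidden state
    return action, h_next


# Stand-in components so the sketch runs without the trained checkpoints.
rng = random.Random(0)
weights = [
    [rng.gauss(0, 1) for _ in range(LATENT_DIM + HIDDEN_DIM)]
    for _ in range(ACTION_DIM)
]
bias = [0.0] * ACTION_DIM


def stub_encode(frame):
    return [rng.gauss(0, 1) for _ in range(LATENT_DIM)]


def stub_rnn_step(z, action, h):
    return [0.9 * x for x in h]


h = [0.0] * HIDDEN_DIM
action, h = rollout_step(None, h, stub_encode, stub_rnn_step, weights, bias)
```

In a real run the stubs would be replaced by the loaded `vae`, `rnn`, and `controller` modules, and `frame` would be the preprocessed observation from the environment.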
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "env_name": "SpaceInvadersNoFrameskip-v4",
+ "max_episode_steps": 1000,
+ "latent_dim": 32,
+ "hidden_dim": 256,
+ "num_mixtures": 5,
+ "action_dim": 6,
+ "vae_epochs": 10,
+ "vae_batch_size": 32,
+ "vae_lr": 0.0001,
+ "num_random_episodes": 100,
+ "rnn_epochs": 20,
+ "rnn_batch_size": 16,
+ "rnn_lr": 0.0001,
+ "sequence_length": 32,
+ "population_size": 64,
+ "num_generations": 10,
+ "sigma": 0.5,
+ "eval_episodes": 10,
+ "record_video_every": 10,
+ "device": "cuda"
+ }
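Two quantities can be derived from this config with simple arithmetic. The sketch below assumes the controller is a single linear layer over the concatenated latent and hidden vectors (as the README describes it) and that CMA-ES evaluates one full population per generation; the function names are illustrative and not taken from the repository.

```python
import json

# Subset of config.json from this commit.
CONFIG = json.loads("""
{
  "latent_dim": 32,
  "hidden_dim": 256,
  "action_dim": 6,
  "population_size": 64,
  "num_generations": 10
}
""")


def controller_param_count(cfg):
    """Weights plus biases of a linear map from [z, h] to one score per action."""
    inputs = cfg["latent_dim"] + cfg["hidden_dim"]
    return inputs * cfg["action_dim"] + cfg["action_dim"]


def cma_es_candidates(cfg):
    """Total candidate controllers evaluated across all generations."""
    return cfg["population_size"] * cfg["num_generations"]


print(controller_param_count(CONFIG))  # 1734
print(cma_es_candidates(CONFIG))       # 640
```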
controller_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac7c86f564aef60093bf7fcd22b86e3b5b9f07e59a0b43c5822566203ad3e44c
+ size 8941
eval_stats.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "mean_reward": 155.0,
+ "std_reward": 0.0,
+ "mean_length": 1000.0,
+ "min_reward": 155.0,
+ "max_reward": 155.0
+ }
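These statistics are worth reading alongside `max_episode_steps` in config.json. The sketch below is a hypothetical sanity check, not part of the repository: it flags identical returns across episodes and episodes that appear to run to the step cap, assuming episode length can never exceed that cap.

```python
import json

# eval_stats.json from this commit.
EVAL_STATS = json.loads("""
{
  "mean_reward": 155.0,
  "std_reward": 0.0,
  "mean_length": 1000.0,
  "min_reward": 155.0,
  "max_reward": 155.0
}
""")
MAX_EPISODE_STEPS = 1000  # "max_episode_steps" in config.json


def sanity_notes(stats, max_steps):
    """Return human-readable observations about the evaluation summary."""
    notes = []
    if stats["min_reward"] == stats["max_reward"]:
        notes.append("all evaluation episodes scored identically")
    # If no episode can exceed the cap, a mean equal to the cap means
    # every episode reached it.
    if stats["mean_length"] >= max_steps:
        notes.append("every episode ran to the step cap")
    return notes


for note in sanity_notes(EVAL_STATS, MAX_EPISODE_STEPS):
    print(note)
```

Identical returns across episodes can indicate a deterministic policy and environment, so the ± 0.00 in the README should not be read as evidence of robustness.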
mdnrnn_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0962ecad39b967e2b3d34d6954f51687a5ca30d7ad97edbe2032daaecd0259a1
+ size 1549070
sample_gameplay.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05b84dc8b3b4a842081ac4d3277ca78e7b86e45a14e5ab3814ad8a693c8548ca
+ size 109782
vae_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9bf44de7a3baa9fab029a96395ddc1ae28b7931bef2682e11cc3b12eb23cd9a
+ size 7117721
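The four binary files in this commit are stored as Git LFS pointers: three `key value` lines giving the spec version, the content hash, and the size in bytes. As a minimal sketch, a pointer can be parsed as follows; `parse_lfs_pointer` is an illustrative helper, not part of any Git LFS tooling.

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into its version, hash, and size fields."""
    # Each line is "<key> <value>"; split only on the first space.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    hash_algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file, in bytes
    }


# Pointer contents of vae_model.pt from this commit.
POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:b9bf44de7a3baa9fab029a96395ddc1ae28b7931bef2682e11cc3b12eb23cd9a
size 7117721
"""

info = parse_lfs_pointer(POINTER)
print(info["hash_algo"], info["size"])
```

After downloading the real file, its SHA-256 digest and byte length can be compared against `info["digest"]` and `info["size"]` to confirm the download is intact.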