loayahmed123 commited on
Commit
fc4d17b
·
verified ·
1 Parent(s): 9d635d1

Upload World Models for Space Invaders

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sample_gameplay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - reinforcement-learning
4
+ - world-models
5
+ - atari
6
+ - space-invaders
7
+ - deep-learning
8
+ library_name: pytorch
9
+ ---
10
+
11
+ # World Models for Space Invaders
12
+
13
+ This is a World Models agent trained on the `SpaceInvadersNoFrameskip-v4` environment.
14
+
15
+ ## Model Description
16
+
17
+ World Models is a model-based reinforcement learning approach that learns a compressed representation
18
+ of the environment and trains a controller to maximize reward in the learned model.
19
+
20
+ The architecture consists of three components:
21
+ - **V (Vision)**: Variational Autoencoder that compresses 64x64 RGB frames to 32-dimensional latent vectors
22
+ - **M (Memory)**: MDN-RNN that predicts the next latent state given current state and action
23
+ - **C (Controller)**: Linear policy trained with CMA-ES evolution strategy
24
+
25
+ ## Training Details
26
+
27
+ ### Hyperparameters
28
+ - VAE Latent Dimension: 32
29
+ - RNN Hidden Dimension: 256
30
+ - Number of Gaussian Mixtures: 5
31
+ - Population Size (CMA-ES): 64
32
+ - Training Episodes: 100
33
+ - VAE Epochs: 10
34
+ - RNN Epochs: 20
35
+ - Controller Generations: 10
36
+
37
+ ## Evaluation Results
38
+
39
+ - **Mean Reward**: 506.50 ± 34.57
40
+ - **Max Reward**: 565.00
41
+ - **Mean Episode Length**: 852.00
42
+
43
+ ## Usage
44
+
45
+ ```python
46
+ import torch
47
+ import gymnasium as gym
48
+
49
+ # Load models
50
+ vae = VAE(latent_dim=32)
51
+ vae.load_state_dict(torch.load('vae_model.pt'))
52
+
53
+ rnn = MDNRNN(latent_dim=32, action_dim=6)
54
+ rnn.load_state_dict(torch.load('mdnrnn_model.pt'))
55
+
56
+ controller = Controller(latent_dim=32, hidden_dim=256)
57
+ controller.load_state_dict(torch.load('controller_model.pt'))
58
+
59
+ # Run agent
60
+ env = gym.make('SpaceInvadersNoFrameskip-v4')
61
+ # ... (see repository for full inference code)
62
+ ```
63
+
64
+ ## References
65
+
66
+ - Paper: [World Models (Ha & Schmidhuber, 2018)](https://worldmodels.github.io/)
67
+ - Code: Based on the original World Models implementation
68
+
69
+ ## Citation
70
+
71
+ ```bibtex
72
+ @article{ha2018worldmodels,
73
+ title={World Models},
74
+ author={Ha, David and Schmidhuber, J{\"u}rgen},
75
+ journal={arXiv preprint arXiv:1803.10122},
76
+ year={2018}
77
+ }
78
+ ```
config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "env_name": "ALE/SpaceInvaders-v5",
3
+ "max_episode_steps": 1000,
4
+ "latent_dim": 32,
5
+ "hidden_dim": 256,
6
+ "num_mixtures": 5,
7
+ "action_dim": 6,
8
+ "vae_epochs": 10,
9
+ "vae_batch_size": 32,
10
+ "vae_lr": 0.0001,
11
+ "num_random_episodes": 100,
12
+ "rnn_epochs": 20,
13
+ "rnn_batch_size": 16,
14
+ "rnn_lr": 0.0001,
15
+ "sequence_length": 32,
16
+ "population_size": 64,
17
+ "num_generations": 10,
18
+ "sigma": 0.5,
19
+ "eval_episodes": 10,
20
+ "record_video_every": 10,
21
+ "device": "cuda"
22
+ }
controller_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83c1183938a96735212d6b52013a043f55857b2c48001669288ba2d02e3c53a3
3
+ size 8941
eval_stats.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "mean_reward": 506.5,
3
+ "std_reward": 34.57238782612506,
4
+ "mean_length": 852.0,
5
+ "min_reward": 460.0,
6
+ "max_reward": 565.0
7
+ }
mdnrnn_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:469d6dde4ed5c223b1af9685430526cc6bc8d6c6954ec41fa1dd70a9426c6d12
3
+ size 1549070
sample_gameplay.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1af84b887edca9c959663e80a493c724429698b1e40e91486a17a9a73ee96f8
3
+ size 228198
vae_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f14b7c69a35616ee0422da3c102d974e86023b7063b5ad4ff79be7a88bc8e7a9
3
+ size 7117721