Basem1166 committed
Commit 5216407 · verified · 1 Parent(s): 9e4022a

Initial World Models checkpoint upload

Files changed (3)
  1. README.md +69 -0
  2. config.json +15 -0
  3. pytorch_model.bin +3 -0
README.md ADDED
@@ -0,0 +1,69 @@
+ ---
+ tags:
+ - reinforcement-learning
+ - world-models
+ - atari
+ - vae
+ - rnn
+ - cma-es
+ - pytorch
+ license: mit
+ ---
+
+ # World Models - Atari Agent
+
+ This is a World Models implementation (Ha & Schmidhuber, 2018) trained on the Atari Breakout environment.
+
+ ## Model Description
+
+ The World Models architecture consists of three main components:
+
+ 1. **VAE (Variational Autoencoder)**: Compresses 64x64 RGB images into a 64-dimensional latent space
+ 2. **MDRNN (Mixture Density Recurrent Neural Network)**: Predicts the next latent representation given the current latent state and action
+ 3. **Controller**: A linear controller optimized with CMA-ES to maximize cumulative reward
+
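To illustrate how the components connect at decision time, here is a minimal sketch of the controller step. The `Controller` class name and the `[z, h]` concatenation are assumptions based on common World Models implementations, not the exact code in auto_train.py:

```python
import torch
import torch.nn as nn

LATENT, HIDDEN, ACTIONS = 64, 256, 4  # sizes from this model card

class Controller(nn.Module):
    """Linear policy mapping concatenated [z, h] to action logits (assumed layout)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(LATENT + HIDDEN, ACTIONS)

    def forward(self, z, h):
        # z: VAE latent vector, h: MDRNN hidden state
        return self.fc(torch.cat([z, h], dim=-1))

controller = Controller()
z = torch.zeros(1, LATENT)   # placeholder latent from the VAE encoder
h = torch.zeros(1, HIDDEN)   # placeholder MDRNN hidden state
action = controller(z, h).argmax(dim=-1)  # discrete Atari action index
```

The VAE compresses each frame to `z`, the MDRNN carries temporal context in `h`, and only this small linear map is evolved with CMA-ES.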
+ ## Model Details
+
+ - **Latent Size**: 64
+ - **Hidden Size**: 256 (MDRNN)
+ - **Action Space**: 4 (Atari discrete actions)
+ - **Architecture**: Convolutional encoder/decoder for VAE, LSTM-based RNN
+ - **Optimization**: CMA-ES for controller training
+
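These sizes explain why gradient-free CMA-ES is practical here: the evolved controller is tiny. Assuming the linear controller acts on the concatenated latent and hidden state (a common choice, not stated explicitly in this card), its parameter count works out to:

```python
latent_size, hidden_size, action_size = 64, 256, 4

# Weights of one linear layer over [z, h], plus one bias per action.
n_params = (latent_size + hidden_size) * action_size + action_size
print(n_params)  # 1284 parameters for CMA-ES to optimize
```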
+ ## Usage
+
+ ```python
+ import torch
+
+ # Load the checkpoint on CPU (no GPU required for inspection)
+ checkpoint = torch.load('pytorch_model.bin', map_location='cpu')
+
+ # Access the per-component state dicts
+ vae_state = checkpoint['vae_state_dict']
+ rnn_state = checkpoint['rnn_state_dict']
+ controller_state = checkpoint['controller_state_dict']
+
+ # Reconstruct the models (see auto_train.py for architecture definitions)
+ # and load the states into them via load_state_dict()
+ ```
+
+ ## Training Details
+
+ - **Environment**: Atari Breakout
+ - **Harvest Episodes**: 100
+ - **VAE Epochs**: 20
+ - **RNN Epochs**: 20
+ - **CMA-ES Generations**: 25
+ - **Population Size**: 32
+ - **Framework**: PyTorch
+ - **Training Script**: auto_train.py
+
+ ## References
+
+ - Ha, D., & Schmidhuber, J. (2018). World Models. arXiv preprint arXiv:1803.10122.
+ - Project page: https://worldmodels.github.io/
+
+ ## License
+
+ MIT License
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architecture": "WorldModels",
+   "latent_size": 64,
+   "hidden_size": 256,
+   "action_size": 4,
+   "n_gaussians": 5,
+   "components": {
+     "vae": "Variational Autoencoder for visual compression",
+     "rnn": "MDRNN (Mixture Density Recurrent Neural Network)",
+     "controller": "CMA-ES optimized linear controller"
+   },
+   "trained_on": "Atari environment (Breakout)",
+   "upload_date": "2025-12-25T08:13:58.344484",
+   "framework": "PyTorch"
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e4139a4a780a9edcf8bb1b81f903d15570f33d8a4266584304898d9480c3f5f
+ size 32753806