---
license: apache-2.0
language:
- en
tags:
- jax
- reinforcement-learning
- evolution
- artificial-life
- cellular-automata
pipeline_tag: reinforcement-learning
library_name: jax
---

## 📊 Model Profile

| Feature | Specification |
| :---------------- | :---------------------------------------------------- |
| **Model Name** | ViT-DeepRL-1M |
| **Architecture** | Vision Transformer (ViT) Encoder + Conv2DTranspose Decoder |
| **Parameters** | ~1,000,000 (1.05M) |
| **Grid Size** | 128 x 128 |
| **Channels** | 8 (Life, Food, Lava, 5x Internal Signaling/State) |
| **Patch Size** | 8 x 8 (256 Total Tokens) |
| **Embedding Dim** | 192 |
| **Heads / Depth** | 6 Heads (Key Dim 32) / 3 Transformer Blocks |
| **Activation** | Swish ($x \cdot \text{sigmoid}(x)$) |
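The token and head dimensions in the table follow directly from the grid, patch, and embedding sizes; a quick sanity check (the variable names here are illustrative, not from the training code):

```python
# Sketch: how the profile numbers relate to each other.
grid, patch, channels, embed_dim, heads = 128, 8, 8, 192, 6

tokens_per_side = grid // patch        # 128 / 8 = 16 patches per side
num_tokens = tokens_per_side ** 2      # 16 * 16 = 256 total tokens
patch_dim = patch * patch * channels   # 8 * 8 * 8 = 512 values per flattened patch
key_dim = embed_dim // heads           # 192 / 6 = 32, matching "Key Dim 32"

print(num_tokens, patch_dim, key_dim)  # 256 512 32
```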

---

## 🧬 Architecture & Logic

Unlike the other local-only Neural Cellular Automata (NCA) models in the DeepRL series, this agent treats the world as a series of visual tokens, allowing for non-local decision making.

1. **Linear Projection of Patches:** The 128x128x8 grid is partitioned into 256 patches ($8 \times 8$). Each patch is flattened and projected into a 192-dimensional embedding space.
2. **Global Self-Attention:** Three Transformer blocks allow every patch to attend to every other patch. This enables "Life" pixels in one quadrant to perceive and move toward "Food" clusters in another quadrant without needing a continuous "scent" trail.
3. **Generative Decoding:** The latent representation ($16 \times 16 \times 192$) is upscaled through a `Conv2DTranspose` layer to reconstruct a high-resolution 128x128 update map.
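The three steps above can be sketched in plain NumPy using the shapes from the profile table. The weights below are random placeholders (not the released checkpoint), and a single attention head is shown for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
grid = rng.random((128, 128, 8)).astype(np.float32)  # Life/Food/Lava/state channels

# 1. Patchify: (128, 128, 8) -> (16, 16, 8, 8, 8) -> (256, 512),
#    then linearly project each flattened patch into the 192-d embedding space.
patches = grid.reshape(16, 8, 16, 8, 8).transpose(0, 2, 1, 3, 4).reshape(256, -1)
W_proj = (rng.standard_normal((512, 192)) * 0.02).astype(np.float32)
tokens = patches @ W_proj                      # (256, 192)

# 2. Global self-attention: every patch attends to every other patch.
q = k = v = tokens
scores = q @ k.T / np.sqrt(192)                # (256, 256) attention logits
attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)       # softmax over all 256 tokens
mixed = attn @ v                               # (256, 192)

# 3. Reshape back to the (16, 16, 192) latent that the Conv2DTranspose
#    decoder upsamples into the 128x128 update map.
latent = mixed.reshape(16, 16, 192)
print(patches.shape, tokens.shape, latent.shape)
```

Because the attention matrix is dense over all 256 tokens, information propagates across the whole grid in a single step, which is what allows the non-local "perceive food in another quadrant" behavior described above.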

---

## 🚀 Training Environment (GCP TPU v5e-16)

* **Framework:** JAX + Keras 3 (JAX Backend).
* **Hardware:** Single-host Google Cloud TPU v5e-16 (TRC Program).
* **Optimization:** AdamW ($3 \times 10^{-4}$ Learning Rate, $1 \times 10^{-4}$ Weight Decay).
* **Batching:** 4 batches per device (replicated across TPU cores).
* **Reward Function:**
    * **Food Consumption:** $+150.0$ per overlap.
    * **Lava Contact:** $-300.0$ penalty.
    * **Extinction Event:** $-30,000.0$ if total mass $< 5.0$.
    * **Metabolism:** Constant $-0.003$ decay per step to discourage stationary camping.
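The reward terms above can be assembled into a minimal sketch. The actual training code is not published here, so the function name is illustrative, and whether the metabolism decay applies once per step or per cell is an assumption (flat per step is used below):

```python
import numpy as np

# Constants taken from the reward description in this card.
FOOD_REWARD = 150.0
LAVA_PENALTY = -300.0
EXTINCTION_PENALTY = -30_000.0
EXTINCTION_THRESHOLD = 5.0
METABOLISM = -0.003

def step_reward(life, food, lava):
    """life/food/lava: (H, W) float arrays in [0, 1]. Hypothetical helper."""
    reward = FOOD_REWARD * float(np.sum(life * food))    # +150.0 per unit of overlap
    reward += LAVA_PENALTY * float(np.sum(life * lava))  # -300.0 per unit of contact
    reward += METABOLISM                                 # flat per-step decay (assumption)
    if float(life.sum()) < EXTINCTION_THRESHOLD:         # total mass below 5.0
        reward += EXTINCTION_PENALTY
    return reward
```

The large extinction penalty dominates all other terms, so the agent is pushed hard toward keeping total mass alive even at the cost of occasional lava contact.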

---

## 💻 Hardware Target & Deployment

* **Primary Target:** Intel Core i7-7700 or newer, with 16 GB RAM or more.
* **Inference:** Designed for high-speed JAX/XLA execution on standard x86 hardware.
* **Storage:** Distributed as a ragged NumPy object array (`.npy`) containing the Transformer weights.

---

## 🛠️ Usage (Loading Weights)

```python
import numpy as np

# Note: allow_pickle=True is required for the ragged object array format
weights = np.load("ViT-DeepRL-1M.npy", allow_pickle=True)

# Parameter structure follows build_vit_1m() layer order:
# [0]: Patch projection kernels
# [1]: Positional embeddings
# [2-13]: Transformer layers (Attention + Dense blocks)
# [14]: Conv2DTranspose decoder weights
```
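To see why `allow_pickle=True` is needed, here is a self-contained round trip of the same storage pattern with dummy weights. The shapes below are placeholders for illustration, not the actual checkpoint layout:

```python
import os
import tempfile
import numpy as np

# A ragged object array: each slot holds a differently-shaped weight tensor,
# so NumPy must pickle the container rather than store one homogeneous block.
dummy = np.empty(3, dtype=object)
dummy[0] = np.zeros((512, 192), dtype=np.float32)  # e.g. a projection kernel
dummy[1] = np.zeros((256, 192), dtype=np.float32)  # e.g. positional embeddings
dummy[2] = np.zeros((192,), dtype=np.float32)      # e.g. a bias vector

path = os.path.join(tempfile.mkdtemp(), "dummy_weights.npy")
np.save(path, dummy, allow_pickle=True)

loaded = np.load(path, allow_pickle=True)
for i, w in enumerate(loaded):
    print(i, w.shape)  # inspect each layer's shape after loading
```

Iterating and printing shapes like this is a quick way to verify a downloaded checkpoint matches the layer order listed above before wiring the arrays into a model.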