## Model Profile
| Feature | Specification |
|---|---|
| Model Name | ViT-DeepRL-1M |
| Architecture | Vision Transformer (ViT) Encoder + Conv2DTranspose Decoder |
| Parameters | ~1,000,000 (1.05M) |
| Grid Size | 128 x 128 |
| Channels | 8 (Life, Food, Lava, 5x Internal Signaling/State) |
| Patch Size | 8 x 8 (256 Total Tokens) |
| Embedding Dim | 192 |
| Heads / Depth | 6 Heads (Key Dim 32) / 3 Transformer Blocks |
| Activation | Swish ($x \cdot \text{sigmoid}(x)$) |
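The Swish activation from the table has a one-line definition; a minimal NumPy sketch (illustrative, not the training code):

```python
import numpy as np

def swish(x):
    """Swish (a.k.a. SiLU): x * sigmoid(x)."""
    return x * (1.0 / (1.0 + np.exp(-x)))

print(swish(np.array([-1.0, 0.0, 1.0])))
```

In JAX this is available directly as `jax.nn.silu`.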
## 🧬 Architecture & Logic
Unlike the other local-only Neural Cellular Automata (NCA) models in the DeepRL series, this agent treats the world as a sequence of visual tokens, allowing for non-local decision making.
- Linear Projection of Patches: The 128x128x8 grid is partitioned into 256 non-overlapping $8 \times 8$ patches. Each patch ($8 \times 8 \times 8 = 512$ values) is flattened and projected into a 192-dimensional embedding space.
- Global Self-Attention: Three Transformer blocks allow every patch to attend to every other patch. This enables "Life" pixels in one quadrant to perceive and move toward "Food" clusters in another quadrant without needing a continuous "scent" trail.
- Generative Decoding: The latent representation ($16 \times 16 \times 192$) is upscaled through a `Conv2DTranspose` layer to reconstruct a high-resolution 128x128 update map.
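The patch-tokenization step above can be sketched in NumPy; the shapes follow the profile table, while `W_proj` is a random stand-in for the learned projection kernel:

```python
import numpy as np

H = W = 128; C = 8; P = 8; D = 192           # grid, channels, patch size, embed dim
grid = np.random.rand(H, W, C).astype(np.float32)

# Partition into (H/P) * (W/P) = 256 patches, flatten each to P*P*C = 512 values
patches = grid.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
tokens = patches.reshape(-1, P * P * C)      # shape (256, 512)

# Linear projection into the 192-dim embedding space (W_proj is illustrative)
W_proj = np.random.rand(P * P * C, D).astype(np.float32)
embedded = tokens @ W_proj                   # shape (256, 192)
print(embedded.shape)
```

The 256 embedded tokens are what the three Transformer blocks attend over globally.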
## Training Environment (GCP TPU v5e-16)
- Framework: JAX + Keras 3 (JAX Backend).
- Hardware: Single-host Google Cloud TPU v5e-16 (TRC Program).
- Optimization: AdamW ($3 \times 10^{-4}$ Learning Rate, $1 \times 10^{-4}$ Weight Decay).
- Batching: Per-device batch size of 4 (replicated across TPU cores).
- Reward Function:
- Food Consumption: $+150.0$ per overlap.
- Lava Contact: $-300.0$ penalty.
- Extinction Event: $-30,000.0$ if total mass $< 5.0$.
- Metabolism: Constant $-0.003$ decay per step to discourage stationary camping.
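The reward terms above combine into a single scalar per step; this sketch assumes the overlap counts and total mass are supplied by the environment (the function name and input conventions are illustrative):

```python
def step_reward(food_overlap, lava_overlap, total_mass):
    """Per-step scalar reward from the terms listed above."""
    r = 150.0 * food_overlap      # food consumption bonus per overlap
    r -= 300.0 * lava_overlap     # lava contact penalty
    if total_mass < 5.0:          # extinction event
        r -= 30_000.0
    r -= 0.003                    # constant metabolic decay each step
    return r

print(step_reward(food_overlap=2, lava_overlap=0, total_mass=40.0))  # 299.997
```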
## 💻 Hardware Target & Deployment
- Primary Target: Intel Core i7-7700 or newer, with 16 GB RAM or more.
- Inference: Designed for high-speed JAX/XLA execution on standard x86 hardware.
- Storage: Distributed as a ragged NumPy object array (`.npy`) containing the Transformer weights.
## 🛠️ Usage (Loading Weights)
```python
import numpy as np

# Note: allow_pickle=True is required for the ragged object array format
weights = np.load("ViT-DeepRL-1M.npy", allow_pickle=True)

# Parameter structure follows build_vit_1m() layer order:
# [0]:    Patch projection kernels
# [1]:    Positional embeddings
# [2-13]: Transformer layers (Attention + Dense blocks)
# [14]:   Conv2DTranspose decoder weights
```
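For reference, a ragged object array of this kind round-trips with plain NumPy; the layer shapes below are placeholders, not the actual ViT-DeepRL-1M layout:

```python
import numpy as np

# Build a ragged object array from per-layer weight arrays of differing shapes
layers = [np.zeros((512, 192)), np.zeros((256, 192)), np.zeros((192,))]
ragged = np.empty(len(layers), dtype=object)
ragged[:] = layers
np.save("weights_demo.npy", ragged)

# Round-trip: allow_pickle=True is needed because dtype=object serializes via pickle
loaded = np.load("weights_demo.npy", allow_pickle=True)
print([a.shape for a in loaded])  # [(512, 192), (256, 192), (192,)]
```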