📊 Model Profile

| Feature | Specification |
| --- | --- |
| Model Name | ViT-DeepRL-1M |
| Architecture | Vision Transformer (ViT) Encoder + Conv2DTranspose Decoder |
| Parameters | ~1,000,000 (1.05M) |
| Grid Size | 128 x 128 |
| Channels | 8 (Life, Food, Lava, 5x Internal Signaling/State) |
| Patch Size | 8 x 8 (256 Total Tokens) |
| Embedding Dim | 192 |
| Heads / Depth | 6 Heads (Key Dim 32) / 3 Transformer Blocks |
| Activation | Swish ($x \cdot \text{sigmoid}(x)$) |
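
The figures in the profile are mutually consistent, which a few lines of arithmetic can confirm. This is only a bookkeeping sketch: the constant names are illustrative and do not come from the model's source code.

```python
# Shape bookkeeping derived only from the figures in the profile above.
GRID = 128        # grid height and width
CHANNELS = 8      # Life, Food, Lava, 5x internal signaling/state
PATCH = 8         # patch height and width
EMBED_DIM = 192
HEADS = 6
KEY_DIM = 32

patches_per_side = GRID // PATCH              # patches along one edge
num_tokens = patches_per_side ** 2            # tokens fed to the Transformer
values_per_patch = PATCH * PATCH * CHANNELS   # inputs to the patch projection

assert GRID % PATCH == 0            # patches tile the grid exactly
assert num_tokens == 256            # matches "256 Total Tokens"
assert HEADS * KEY_DIM == EMBED_DIM # 6 heads x 32 = 192

print(patches_per_side, num_tokens, values_per_patch)  # 16 256 512
```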

🧬 Architecture & Logic

Unlike the other local-only Neural Cellular Automata (NCA) models in the DeepRL series, this agent treats the world as a series of visual tokens, allowing for non-local decision making.

  1. Linear Projection of Patches: The 128x128x8 grid is partitioned into 256 patches of $8 \times 8$ cells (a $16 \times 16$ patch grid). Each patch is flattened to 512 values ($8 \times 8 \times 8$) and projected into a 192-dimensional embedding space.
  2. Global Self-Attention: Three Transformer blocks allow every patch to attend to every other patch. This enables "Life" pixels in one quadrant to perceive and move toward "Food" clusters in another quadrant without needing a continuous "scent" trail.
  3. Generative Decoding: The latent representation ($16 \times 16 \times 192$) is upsampled through a Conv2DTranspose layer to produce a full-resolution 128x128 update map.
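
The three steps above can be traced shape by shape in plain NumPy. This is a rough sketch, not the model's code: the weights are random stand-ins, attention uses a single head of width 32, and positional embeddings, layer norms, and MLP blocks are left out. The decoder assumes a transposed convolution with kernel size and stride both equal to 8 and 8 output channels. The card states neither, but a single layer going from 16x16 to 128x128 needs stride 8, and with kernel equal to stride the operation reduces to a per-token linear map followed by a rearrangement.

```python
import numpy as np

GRID, CHANNELS, PATCH, EMBED_DIM = 128, 8, 8, 192
SIDE = GRID // PATCH  # 16 patches per side

def patchify(grid):
    """Split an (H, W, C) grid into (num_patches, PATCH*PATCH*C) flat patches."""
    h, w, c = grid.shape
    x = grid.reshape(h // PATCH, PATCH, w // PATCH, PATCH, c)
    x = x.transpose(0, 2, 1, 3, 4)  # (rows, cols, PATCH, PATCH, C)
    return x.reshape(-1, PATCH * PATCH * c)

def self_attention(tokens, wq, wk, wv):
    """Single-head scaled dot-product attention over all tokens."""
    q, k, v = tokens @ wq, tokens @ wk, tokens @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

def unpatchify(flat_patches, channels):
    """Inverse of patchify: (num_patches, PATCH*PATCH*C) -> (H, W, C)."""
    x = flat_patches.reshape(SIDE, SIDE, PATCH, PATCH, channels)
    x = x.transpose(0, 2, 1, 3, 4)  # (rows, PATCH, cols, PATCH, C)
    return x.reshape(GRID, GRID, channels)

rng = np.random.default_rng(0)
grid = rng.random((GRID, GRID, CHANNELS))

# Step 1: flatten patches and project them (random stand-in weights).
patches = patchify(grid)                                  # (256, 512)
w_embed = rng.normal(0.0, 0.02, (patches.shape[1], EMBED_DIM))
tokens = patches @ w_embed                                # (256, 192)

# Step 2: every token attends to every other token.
wq, wk, wv = (rng.normal(0.0, 0.02, (EMBED_DIM, 32)) for _ in range(3))
attended, attn = self_attention(tokens, wq, wk, wv)       # (256, 32), (256, 256)

# Step 3: decode each token into one 8x8 output patch
# (equivalent to a transposed conv with kernel = stride = 8, bias omitted).
w_dec = rng.normal(0.0, 0.02, (EMBED_DIM, PATCH * PATCH * CHANNELS))
update = unpatchify(tokens @ w_dec, CHANNELS)             # (128, 128, 8)

print(patches.shape, tokens.shape, attn.shape, update.shape)
```

The `attn` matrix has one row per token and one column per token, so a patch in one corner of the grid can place weight on a patch in the opposite corner within a single block.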

🚀 Training Environment (GCP TPU v5e-16)

  • Framework: JAX + Keras 3 (JAX Backend).
  • Hardware: Single-host Google Cloud TPU v5e-16 (TRC Program).
  • Optimization: AdamW ($3 \times 10^{-4}$ Learning Rate, $1 \times 10^{-4}$ Weight Decay).
  • Batching: 4 batches per device (replicated across TPU cores).
  • Reward Function:
    • Food Consumption: $+150.0$ per overlap.
    • Lava Contact: $-300.0$ penalty.
    • Extinction Event: $-30,000.0$ if total mass $< 5.0$.
    • Metabolism: Constant $-0.003$ decay per step to discourage stationary camping.
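
The reward terms listed above could be combined along the following lines. Several details are assumptions rather than facts from the training code: the channel order (Life, Food, Lava at indices 0 to 2), the occupancy threshold of 0.5, counting food and lava per overlapping cell, defining "total mass" as the sum of the Life channel, and treating metabolism as an additive per-step reward term. `step_reward` is a hypothetical name.

```python
import numpy as np

LIFE, FOOD, LAVA = 0, 1, 2  # assumed channel order, not confirmed by the card

def step_reward(grid, threshold=0.5):
    """Reward for one step of an (H, W, C) grid.

    `threshold` is a hypothetical cutoff for treating a cell as occupied.
    """
    life = grid[..., LIFE] > threshold
    food_overlaps = np.count_nonzero(life & (grid[..., FOOD] > threshold))
    lava_contacts = np.count_nonzero(life & (grid[..., LAVA] > threshold))

    reward = 150.0 * food_overlaps   # food consumption
    reward -= 300.0 * lava_contacts  # lava contact
    reward -= 0.003                  # metabolism, applied every step
    if grid[..., LIFE].sum() < 5.0:  # extinction: total mass below 5.0
        reward -= 30_000.0
    return float(reward)

# Tiny example: six live cells, one on food and one on lava.
demo = np.zeros((4, 4, 8))
demo[0, :4, LIFE] = 1.0
demo[1, :2, LIFE] = 1.0
demo[0, 0, FOOD] = 1.0
demo[1, 1, LAVA] = 1.0
demo_reward = step_reward(demo)
```

Because the extinction penalty is two orders of magnitude larger than any single food or lava event, survival dominates the return under this reading.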

💻 Hardware Target & Deployment

  • Primary Target: Intel Core i7-7700 and above / 16GB RAM and above.
  • Inference: Designed for high-speed JAX/XLA execution on standard x86 hardware.
  • Storage: Distributed as a ragged NumPy object array (.npy) containing the Transformer weights.

🛠️ Usage (Loading Weights)

import numpy as np
# Note: allow_pickle=True is required for the ragged object array format
weights = np.load("ViT-DeepRL-1M.npy", allow_pickle=True)

# Parameter structure follows build_vit_1m() layer order:
# [0]: Patch projection kernels
# [1]: Positional embeddings
# [2-13]: Transformer layers (Attention + Dense blocks)
# [14]: Conv2DTranspose decoder weights
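
Before wiring the arrays into a model, it can help to list what the file actually contains. The sketch below writes a small stand-in object array to a temporary file and reads it back, so it runs without the real checkpoint. `describe_weights` is a hypothetical helper, and the two shapes are placeholders rather than the real layer shapes.

```python
import os
import tempfile

import numpy as np

def describe_weights(path):
    """Return (index, shape, dtype) for each entry of a ragged object array."""
    weights = np.load(path, allow_pickle=True)
    return [(i, np.asarray(w).shape, np.asarray(w).dtype)
            for i, w in enumerate(weights)]

# Stand-in file with made-up shapes, only to demonstrate the container format.
arrays = [np.zeros((512, 192), np.float32), np.zeros((256, 192), np.float32)]
ragged = np.empty(len(arrays), dtype=object)
for i, a in enumerate(arrays):
    ragged[i] = a  # each slot holds one array of arbitrary shape

with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo_weights.npy")
    np.save(demo_path, ragged, allow_pickle=True)
    summary = describe_weights(demo_path)

for index, shape, dtype in summary:
    print(index, shape, dtype)
```

Pointing `describe_weights` at "ViT-DeepRL-1M.npy" and summing the element counts of the listed shapes gives a quick check against the stated ~1.05M parameters. Since `allow_pickle=True` can execute code embedded in a file, it is best used only on files from a trusted source.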