Mouse AI - Program Generation Model
A 232M-parameter transformer that generates movement programs for a mouse navigating a maze to collect cheese while avoiding cats.
Quick Start
```python
import torch
from model.model_2B import StructureAwareTransformer2B
from lightweight_simulator import LightweightGameSimulator

# get_state_vector() is defined in the "State Vector" section below

# Load model
device = 'cuda:0'  # or 'cpu'
ckpt = torch.load('model_best.pt', map_location='cpu', weights_only=False)
config = ckpt['model_config']
model = StructureAwareTransformer2B(**config)
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)
model.eval()

# Play a game
game = LightweightGameSimulator(level=3)
game.reset()
for run in range(20):
    if game.win_sign or game.lose_sign:
        break
    # Get state vector (828 dimensions), add a batch dimension
    state = get_state_vector(game).unsqueeze(0).to(device)
    # Generate program
    with torch.no_grad():
        prog = model.generate(
            state, max_length=12, temperature=0.3,
            top_k=10, grammar_constrained=True
        )
    # Parse output
    if isinstance(prog, tuple):
        prog = prog[0]
    if isinstance(prog, torch.Tensor):
        prog = prog[0].tolist()
    # Remove the start token (note: id 0 is also the UP direction token)
    if prog and prog[0] == 0:
        prog = prog[1:]
    # Remove END (112) and everything after it
    if 112 in prog:
        prog = prog[:prog.index(112)]
    # Execute
    game.execute_program(prog)

print(f"{'WIN' if game.win_sign else 'LOSE'} | Score: {game.score}")
```
Model Architecture
| Parameter | Value |
|---|---|
| Type | StructureAwareTransformer2B |
| Total Parameters | 232.2M |
| Hidden Dimension | 1024 |
| Layers | 16 |
| Attention Heads | 16 (Query) / 4 (KV, Grouped Query Attention) |
| Feed-Forward Dim | 4096 |
| State Input | 828 dimensions |
| Vocab Size | 113 tokens |
| Max Program Length | 12 tokens |
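With 16 query heads sharing 4 key/value heads, each KV head serves a group of 4 query heads. The sketch below works through the head arithmetic implied by the table. It assumes the common grouped-query layout (head_dim = hidden_dim / num_heads, consecutive query heads sharing one KV head); `gqa_layout` is a hypothetical helper and has not been checked against `transformer.py`.

```python
def gqa_layout(hidden_dim: int = 1024, num_heads: int = 16, num_kv_heads: int = 4) -> dict:
    """Head arithmetic for grouped-query attention (GQA).

    Assumes head_dim = hidden_dim // num_heads and that consecutive
    query heads share one KV head (a common convention, not confirmed here).
    """
    assert hidden_dim % num_heads == 0, "hidden_dim must split evenly across heads"
    assert num_heads % num_kv_heads == 0, "query heads must split evenly into KV groups"
    head_dim = hidden_dim // num_heads          # 1024 // 16 = 64
    group_size = num_heads // num_kv_heads      # 16 // 4 = 4 query heads per KV head
    return {
        "head_dim": head_dim,
        "group_size": group_size,
        "q_proj_out": num_heads * head_dim,     # width of the Q projection
        "kv_proj_out": num_kv_heads * head_dim, # width of each of the K and V projections
        # query head h reads keys/values from KV head h // group_size
        "kv_head_for_query": [h // group_size for h in range(num_heads)],
    }


layout = gqa_layout()
print(layout["head_dim"], layout["group_size"], layout["kv_proj_out"])  # 64 4 256
print(layout["kv_head_for_query"])  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
```

Under these assumptions the K and V projections are a quarter of the width of the Q projection, which is where GQA saves memory in the key/value cache.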
Model Config (for initialization)
```python
from model.model_2B import StructureAwareTransformer2B

config = {
    'state_dim': 828,
    'hidden_dim': 1024,
    'vocab_size': 113,
    'max_program_length': 12,
    'num_layers': 16,
    'num_heads': 16,
    'num_kv_heads': 4,
    'ff_dim': 4096,
    'dropout': 0.1,
    'end_token': 112,
}
model = StructureAwareTransformer2B(**config)
```
Token Vocabulary (113 tokens)
Direction Tokens (0-3)
| Token ID | Direction | Movement |
|---|---|---|
| 0 | UP | Mouse moves up one cell |
| 1 | DOWN | Mouse moves down one cell |
| 2 | LEFT | Mouse moves left one cell |
| 3 | RIGHT | Mouse moves right one cell |
Number Tokens (100-109)
| Token ID | Value | Usage |
|---|---|---|
| 100 | 1 | LOOP repeat count (1 time) |
| 104 | 5 | LOOP repeat count (5 times) |
| 105 | 6 | LOOP repeat count (6 times) |
| 106 | 7 | LOOP repeat count (7 times) |
| 107 | 8 | LOOP repeat count (8 times) |
| 108 | 9 | LOOP repeat count (9 times) |
| 109 | 10 | LOOP repeat count (10 times) |
Note: Tokens 101-103 (values 2-4) exist in the vocabulary but are NOT used by the grammar. For efficiency, the model only generates NUM tokens >= 104 (5 or more repeats).
Special Tokens
| Token ID | Name | Function |
|---|---|---|
| 110 | LOOP | Start a loop structure |
| 112 | END | End of program |
Token 111 (IF) was removed due to simulator incompatibility.
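From the tables above, a NUM token's value is its id minus 99 (100 → 1, 104 → 5, 109 → 10). A small decoder covering only the documented ids might look like this; `describe_token` is a hypothetical helper, not part of the repository, and ids 4-99 and 111 are reported as unknown because this README does not describe them.

```python
DIRECTIONS = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP_TOKEN = 110
END_TOKEN = 112
NUM_FIRST, NUM_LAST = 100, 109  # NUM token id t encodes the value t - 99


def describe_token(tid: int) -> str:
    """Return a readable name for a token id, using only the documented ranges."""
    if tid in DIRECTIONS:
        return DIRECTIONS[tid]
    if NUM_FIRST <= tid <= NUM_LAST:
        return f"NUM({tid - 99})"
    if tid == LOOP_TOKEN:
        return "LOOP"
    if tid == END_TOKEN:
        return "END"
    # 4-99 and 111 (removed IF) are not described in this README
    return f"UNKNOWN({tid})"


print([describe_token(t) for t in [0, 110, 104, 2, 3, 112]])
# ['UP', 'LOOP', 'NUM(5)', 'LEFT', 'RIGHT', 'END']
```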
Grammar Rules
Programs follow a strict context-free grammar:
```
start      -> DIR | LOOP NUM DIR | END
after_DIR  -> DIR | LOOP NUM DIR | END
after_LOOP -> NUM   (must be 104-109)
after_NUM  -> DIR   (must be 0-3)
after_END  -> (stop generation)
```
Valid Program Examples
```python
[0, 112]                          # Move UP, END
[2, 2, 2, 112]                    # Move LEFT 3 times, END
[110, 106, 1, 112]                # LOOP(7 times, DOWN), END
[0, 110, 104, 2, 3, 112]          # UP, LOOP(5 times, LEFT), RIGHT, END
[110, 108, 0, 110, 105, 3, 112]   # LOOP(9, UP), LOOP(6, RIGHT), END
```
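The grammar is small enough to check with a three-state machine: "free" (at the start or after a DIR), "need_num" (after LOOP) and "need_dir" (after NUM). The sketch below is a hypothetical validator, not code from the repository; it additionally requires the program to finish with END, as all the examples above do.

```python
DIR_TOKENS = {0, 1, 2, 3}
LOOP_NUM_TOKENS = set(range(104, 110))  # 104-109, i.e. repeat counts 5-10
LOOP, END = 110, 112


def is_valid_program(tokens: list[int]) -> bool:
    """Check a token list against the grammar rules above (hypothetical helper)."""
    state = "free"
    for i, tok in enumerate(tokens):
        if state == "free":
            if tok == END:
                return i == len(tokens) - 1  # END must be the last token
            if tok == LOOP:
                state = "need_num"
            elif tok not in DIR_TOKENS:
                return False
        elif state == "need_num":
            if tok not in LOOP_NUM_TOKENS:
                return False
            state = "need_dir"
        else:  # need_dir: a LOOP body is exactly one direction
            if tok not in DIR_TOKENS:
                return False
            state = "free"
    return False  # ran out of tokens without reaching END


examples = [[0, 112], [2, 2, 2, 112], [110, 106, 1, 112],
            [0, 110, 104, 2, 3, 112], [110, 108, 0, 110, 105, 3, 112]]
print(all(is_valid_program(p) for p in examples))  # True
print(is_valid_program([110, 101, 0, 112]))        # False: 101 is not an allowed LOOP count
```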
Grammar Constraint: LOOP cutoff at position 8
The LOOP token (110) is only allowed at indices 0-7 of the generated sequence. From index 8 onward no new LOOP can start, so only DIR tokens and END are allowed (apart from the tokens that complete a LOOP already opened). This keeps programs from growing too long.
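During grammar-constrained decoding, rules like these are typically turned into a mask over the 113 logits before sampling. The helper below is a hypothetical sketch of which ids the grammar and the index-8 LOOP cutoff permit next; it is not the model's own `generate()` logic, and it assumes indices are counted after any start token has been removed.

```python
DIR_TOKENS = [0, 1, 2, 3]
LOOP_NUM_TOKENS = list(range(104, 110))  # repeat counts 5-10
LOOP, END = 110, 112
LOOP_CUTOFF = 8  # a LOOP may only start at indices 0-7


def allowed_next_tokens(prefix: list[int]) -> list[int]:
    """Token ids the grammar permits at index len(prefix) (hypothetical helper).

    The prefix is assumed to be grammatical so far and to exclude any start token.
    """
    if prefix and prefix[-1] == END:
        return []                       # generation stops after END
    if prefix and prefix[-1] == LOOP:
        return list(LOOP_NUM_TOKENS)    # after_LOOP -> NUM
    if prefix and prefix[-1] in LOOP_NUM_TOKENS:
        return list(DIR_TOKENS)         # after_NUM -> DIR
    # start / after_DIR -> DIR | LOOP NUM DIR | END
    allowed = DIR_TOKENS + [END]
    if len(prefix) < LOOP_CUTOFF:
        allowed.append(LOOP)
    return allowed


print(allowed_next_tokens([0, 1, 2]))  # [0, 1, 2, 3, 112, 110]
print(allowed_next_tokens([0] * 8))    # [0, 1, 2, 3, 112]  (LOOP no longer allowed)
```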
State Vector (828 dimensions)
The 828-dimensional state vector encodes the complete game state:
```python
import torch


def get_state_vector(sim):
    """Extract 828-dim state vector from game simulator"""
    state_dict = sim.get_state_dict()
    state = []
    DYNAMIC_SCALE = 10.0  # Scale factor for dynamic features

    # --- Grid features (11x11 grids) ---
    # 1. Wall grid (121 dims): 1=wall, 0=empty
    for row in state_dict['wall']:
        state.extend(row)
    # 2. Small Cheese grid (121 dims): 1=cheese present, 0=collected
    #    Scaled by DYNAMIC_SCALE (10.0)
    for row in state_dict['sc']:
        state.extend([v * DYNAMIC_SCALE for v in row])
    # 3. Junction grid (121 dims): 1=junction, 0=not
    for row in state_dict['junc']:
        state.extend(row)
    # 4. Dead-end grid (121 dims): 1=dead-end, 0=not
    for row in state_dict['deadend']:
        state.extend(row)
    # Total grid: 484 dims (4 * 121)

    # --- Entity positions ---
    # 5. Mouse position (2 dims): [x, y]
    mouse = state_dict['mouse']
    state.extend([float(mouse[0]), float(mouse[1])])
    # 6. Cat positions (12 dims): 6 cats * [x, y], unused=-1
    cat_list = state_dict.get('cat', [])
    for i in range(6):
        if i < len(cat_list):
            state.extend([float(cat_list[i][0]), float(cat_list[i][1])])
        else:
            state.extend([-1.0, -1.0])
    # 7. Moving Big Cheese positions (10 dims): 5 * [x, y], unused=-1
    bc_list = state_dict.get('crzbc', [])
    for i in range(5):
        if i < len(bc_list):
            state.extend([float(bc_list[i][0]), float(bc_list[i][1])])
        else:
            state.extend([-1.0, -1.0])
    # Pad to 549 dims (484 + 65)
    while len(state) < 484 + 65:
        state.append(0.0)

    # --- Scalar features (6 dims) ---
    # 8. Score (normalized by 1000, scaled)
    state.append(state_dict.get('score', 0) / 1000.0 * DYNAMIC_SCALE)
    # 9. Life (normalized by 3, scaled) - starts at 3
    state.append(state_dict.get('life', 3) * DYNAMIC_SCALE / 3.0)
    # 10. Current run number (normalized by 20, scaled)
    state.append(state_dict.get('run', 0) * DYNAMIC_SCALE / 20.0)
    # 11. Win flag (DYNAMIC_SCALE if won, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('win_sign', False) else 0.0)
    # 12. Lose flag (DYNAMIC_SCALE if lost, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('lose_sign', False) else 0.0)
    # 13. Step progress (current_step / step_limit, scaled)
    step = state_dict.get('step', 0)
    step_limit = state_dict.get('step_limit', 200)
    state.append(step / step_limit * DYNAMIC_SCALE if step_limit > 0 else 0.0)

    # Pad to 828 dims
    while len(state) < 828:
        state.append(0.0)
    return torch.tensor(state[:828], dtype=torch.float32)
```
State Vector Layout Summary
| Range | Dims | Content | Scale |
|---|---|---|---|
| 0-120 | 121 | Wall grid (11x11) | 1.0 |
| 121-241 | 121 | Small Cheese grid | 10.0 |
| 242-362 | 121 | Junction grid | 1.0 |
| 363-483 | 121 | Dead-end grid | 1.0 |
| 484-485 | 2 | Mouse position [x,y] | 1.0 |
| 486-497 | 12 | Cat positions (6 cats) | 1.0 |
| 498-507 | 10 | Moving Big Cheese positions (5) | 1.0 |
| 508-548 | 41 | Padding (zeros) | - |
| 549 | 1 | Score / 1000 * 10 | 10.0 |
| 550 | 1 | Life / 3 * 10 | 10.0 |
| 551 | 1 | Run / 20 * 10 | 10.0 |
| 552 | 1 | Win flag | 10.0 |
| 553 | 1 | Lose flag | 10.0 |
| 554 | 1 | Step progress | 10.0 |
| 555-827 | 273 | Padding (zeros) | - |
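When inspecting a state vector, it helps to derive the slice boundaries from the sizes in the table rather than hard-coding them. The sketch below does that in plain Python; the feature names are hypothetical labels chosen for illustration, and the two padding blocks are named `pad_a` and `pad_b`.

```python
# (name, size) pairs in the order listed in the layout table
LAYOUT = [
    ("wall", 121), ("small_cheese", 121), ("junction", 121), ("dead_end", 121),
    ("mouse", 2), ("cats", 12), ("moving_big_cheese", 10), ("pad_a", 41),
    ("score", 1), ("life", 1), ("run", 1), ("win", 1), ("lose", 1),
    ("step_progress", 1), ("pad_b", 273),
]


def layout_offsets(layout=LAYOUT) -> dict[str, tuple[int, int]]:
    """Map each feature name to its half-open [start, end) slice in the state vector."""
    offsets, start = {}, 0
    for name, size in layout:
        offsets[name] = (start, start + size)
        start += size
    return offsets


OFFSETS = layout_offsets()
assert OFFSETS["pad_b"][1] == 828  # the sizes add up to the full 828 dims
print(OFFSETS["mouse"], OFFSETS["score"], OFFSETS["step_progress"])
# (484, 486) (549, 550) (554, 555)
```

With a tensor from `get_state_vector`, the wall grid would then be `state[0:121].reshape(11, 11)`, assuming the row-major order in which the grids are appended, and the raw score can be recovered as `state[549] / 10 * 1000`.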
Game Rules (Level 3)
Map
- 11x11 grid maze with walls
- Fixed wall layout for level 3
Entities
- Mouse: Player-controlled, starts at position [10, 10]
- Cat 0 (Dummy): Starts at [2, 2], moves only during command execution (len(command) steps)
- Cat 1 (Naughty): Starts at [5, 5], moves every mouse step
- Small Cheese (SC): 75 stationary items, +10 points each
- Stationary Big Cheese (movbc): 2 items, +500 points each, don't move
- Moving Big Cheese (crzbc): 2 items, +500 points each, move each step
Cat Movement (Random Mode)
Cats choose a random direction at junctions (never turning back), continue straight in corridors, and pick a random direction when blocked. This is the _get_cats_direct_actions mode in the simulator.
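The rules above can be read as a per-step policy for each cat. The sketch below is one interpretation, not the simulator's `_get_cats_direct_actions` code: it assumes a junction is a cell with three or more open neighbours, that "blocked" means the cell straight ahead is a wall or the grid edge, and a (row, col) coordinate convention. All names are hypothetical.

```python
import random

# Direction vectors as (d_row, d_col); this coordinate convention is an assumption
MOVES = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1)}
OPPOSITE = {"UP": "DOWN", "DOWN": "UP", "LEFT": "RIGHT", "RIGHT": "LEFT"}


def open_directions(pos, walls, size=11):
    """Directions from pos that stay inside the grid and do not enter a wall cell."""
    r, c = pos
    result = []
    for name, (dr, dc) in MOVES.items():
        nr, nc = r + dr, c + dc
        if 0 <= nr < size and 0 <= nc < size and (nr, nc) not in walls:
            result.append(name)
    return result


def choose_cat_direction(pos, heading, walls, rng=random):
    """Pick a cat's next direction under the assumed reading of the rules."""
    options = open_directions(pos, walls)
    if not options:
        return None                     # boxed in: no legal move
    if len(options) >= 3:               # junction: random, but never turn back
        forward = [d for d in options if d != OPPOSITE.get(heading)]
        return rng.choice(forward)
    if heading in options:
        return heading                  # corridor: keep going straight
    return rng.choice(options)          # blocked ahead: random open direction


print(choose_cat_direction((5, 5), "UP", walls={(5, 4), (5, 6)}))  # UP (straight corridor)
```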
Scoring
| Event | Points |
|---|---|
| Collect Small Cheese | +10 |
| Collect Big Cheese | +500 |
| Hit Wall | -10 |
| Caught by Cat | -500 (+ lose 1 life) |
| Win Bonus | +(run * 10 + step) |
Win/Lose Conditions
- WIN: Collect ALL 75 Small Cheese + END token executed
- LOSE (life): Life reaches 0 (caught 3 times)
- LOSE (step): Step count reaches 200
- LOSE (run): 20 runs exhausted without winning
Game Flow
- Game starts with mouse at [10,10], 3 lives, 20 max runs
- Each run: model generates a program -> program executes step by step
- During execution: mouse moves, cats move randomly, cheese collected, collisions checked
- After program ends: next run begins
- Continue until WIN or LOSE
Program Execution
When a program like [0, 110, 106, 2, 3, 112] executes:
- Token 0 (UP): mouse moves up 1 step
- Tokens 110, 106, 2 (LOOP 7 LEFT): mouse moves left 7 steps
- Token 3 (RIGHT): mouse moves right 1 step
- Token 112 (END): program ends
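Because a LOOP repeats a single direction, a program can be unrolled into the list of single-cell moves the mouse will attempt; with one step per move, its length is the number of steps the program consumes. A minimal sketch, assuming a grammatical program; `expand_program` is a hypothetical helper.

```python
DIRECTION_NAMES = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP, END = 110, 112


def expand_program(tokens: list[int]) -> list[str]:
    """Unroll a token program into the sequence of single-cell moves it requests.

    Assumes LOOP is always followed by NUM then DIR, and that a NUM token
    id t stands for t - 99 repetitions.
    """
    moves, i = [], 0
    while i < len(tokens) and tokens[i] != END:
        if tokens[i] == LOOP:
            count = tokens[i + 1] - 99  # e.g. 106 -> 7
            moves.extend([DIRECTION_NAMES[tokens[i + 2]]] * count)
            i += 3
        else:
            moves.append(DIRECTION_NAMES[tokens[i]])
            i += 1
    return moves


print(expand_program([0, 110, 106, 2, 3, 112]))
# ['UP', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'RIGHT']
```

Counting the unrolled moves also gives a quick check of how much of the 200-step budget a program will use.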
Each step:
- Mouse attempts to move in the direction
- If wall: mouse stays, -10 points
- Cat 1 moves (random at junctions)
- Cat 0 moves (only during command-length steps)
- Check for cat collision: -500 points, lose 1 life, respawn at [10,10]
- Check for cheese collection: +10 (SC) or +500 (BC)
- Check win/lose conditions
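The per-step order above can be sketched as a small state update. This is a simplified, hypothetical model, not the simulator's code: cat movement is abstracted away (the caller passes the cats' positions after they have moved), grid edges are assumed to behave like walls, the (row, col) deltas are an assumption, mouse/cat position swaps are not modelled, and the WIN check is left out because it also requires END to be executed.

```python
from dataclasses import dataclass, field

MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # assumed (row, col) deltas per DIR token
START = (10, 10)


@dataclass
class StepState:
    """Hypothetical minimal game state; not the simulator's real data structure."""
    mouse: tuple[int, int] = START
    walls: set = field(default_factory=set)
    small_cheese: set = field(default_factory=set)
    big_cheese: set = field(default_factory=set)
    score: int = 0
    life: int = 3
    step: int = 0
    lose: bool = False


def apply_step(s: StepState, direction: int, cats: list[tuple[int, int]], step_limit: int = 200) -> None:
    """Apply one mouse step in the order listed above; `cats` are post-move cat positions."""
    dr, dc = MOVES[direction]
    target = (s.mouse[0] + dr, s.mouse[1] + dc)
    in_bounds = 0 <= target[0] < 11 and 0 <= target[1] < 11
    if target in s.walls or not in_bounds:
        s.score -= 10                 # wall hit: mouse stays put
    else:
        s.mouse = target
    if s.mouse in cats:               # caught by a cat
        s.score -= 500
        s.life -= 1
        s.mouse = START               # respawn
    if s.mouse in s.small_cheese:
        s.small_cheese.discard(s.mouse)
        s.score += 10
    elif s.mouse in s.big_cheese:
        s.big_cheese.discard(s.mouse)
        s.score += 500
    s.step += 1
    s.lose = s.life <= 0 or s.step >= step_limit


state = StepState(walls={(9, 10)}, small_cheese={(10, 9)})
apply_step(state, 0, cats=[])    # UP into a wall
print(state.mouse, state.score)  # (10, 10) -10
```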
Performance
| Metric | Value |
|---|---|
| Win Rate (temp=0.3, 100 games) | 30% |
| Average Score | 1437 |
| Average Runs per Win | 13.8 |
| Simulator | New simulator (random cats) |
Training Pipeline
- Base Model: Expert R1 checkpoint (trained on old simulator, 95% win rate on old sim, 14% on new sim)
- RM32 Data Generation: 10,000 games with Running Max 32 (exhaustive 33 candidates), 20.4% win rate, 30,788 winning run samples
- SFT Training: 40 epochs, batch size 4096, lr 3e-5 with a cosine schedule, reaching a 30% win rate
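For reference, a cosine schedule starting from the stated 3e-5 is commonly defined as below; PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` uses the same formula with `eta_min` as the floor. The README does not say whether warmup was used, whether the schedule steps per epoch or per batch, or what the final learning rate is, so this sketch assumes no warmup, per-epoch stepping over the 40 epochs, and decay to zero. `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 3e-5, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate with no warmup, decaying from base_lr to min_lr."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 10, 20, 30, 40):
    print(epoch, f"{cosine_lr(epoch, total_steps=40):.2e}")
```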
Generation Parameters
| Parameter | Recommended | Description |
|---|---|---|
| temperature | 0.3 | Lower = more deterministic, higher win rate |
| top_k | 10 | Top-k sampling |
| grammar_constrained | True | MUST be True to generate valid programs |
| max_length | 12 | Maximum program length |
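The three sampling knobs combine roughly as follows: the grammar mask removes illegal tokens, top_k keeps the 10 highest-scoring survivors, and temperature sharpens the softmax over them (at 0.3, a logit gap of 1 becomes a probability ratio of about e^(1/0.3) ≈ 28). The sketch below is a generic pure-Python illustration of that combination; it is an assumption about how `generate()` behaves, not its actual code, and `sample_token` is a hypothetical helper.

```python
import math
import random


def sample_token(logits, temperature=0.3, top_k=10, allowed=None, rng=random):
    """Sample one token id from raw logits with grammar masking, top-k and temperature.

    `logits` is a list of floats indexed by token id; `allowed` is an optional
    collection of token ids permitted by the grammar at this position.
    """
    candidates = [i for i in range(len(logits)) if allowed is None or i in allowed]
    # keep the k highest-scoring permitted tokens (temperature does not change this order)
    candidates.sort(key=lambda i: logits[i], reverse=True)
    candidates = candidates[:top_k]
    # temperature-scaled softmax over the survivors; subtract the max for numerical stability
    scaled = [logits[i] / temperature for i in candidates]
    peak = max(scaled)
    weights = [math.exp(v - peak) for v in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]
```

Since dividing by a positive temperature preserves the ranking of the logits, applying top-k before or after temperature selects the same candidates; only the final probabilities change.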
File Structure
```
hardai_model_export/
    model_best.pt              # Model checkpoint (886MB)
    README.md                  # This file
    lightweight_simulator.py   # Game simulator
    model/                     # Model architecture
        __init__.py
        model_2B.py            # Main model class
        state_encoder.py
        program_embedding.py
        transformer.py         # Flash Attention + gradient checkpointing
        multi_task_head.py
        memory_encoder.py
        memory_state_fusion.py
        value_predictor.py
```
Requirements
- torch >= 2.0
- numpy
- pygame (for the simulator; can run headless with SDL_VIDEODRIVER=dummy)
Headless Mode (No Display)
Set these environment variables BEFORE importing the simulator:

```python
import os

os.environ['SDL_VIDEODRIVER'] = 'dummy'
os.environ['SDL_AUDIODRIVER'] = 'dummy'
```