Mouse AI - Program Generation Model

A 232M-parameter transformer that generates movement programs for a mouse navigating a maze to collect cheese while avoiding cats.

Quick Start

import torch
from model.model_2B import StructureAwareTransformer2B
from lightweight_simulator import LightweightGameSimulator

# Load model
device = 'cuda:0'  # or 'cpu'
ckpt = torch.load('model_best.pt', map_location='cpu', weights_only=False)
config = ckpt['model_config']

model = StructureAwareTransformer2B(**config)
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)
model.eval()

# Play a game
game = LightweightGameSimulator(level=3)
game.reset()

for run in range(20):
    if game.win_sign or game.lose_sign:
        break

    # Get state vector (828 dimensions); get_state_vector is defined
    # in the State Vector section below
    state = get_state_vector(game).unsqueeze(0).to(device)

    # Generate program
    with torch.no_grad():
        prog = model.generate(
            state, max_length=12, temperature=0.3,
            top_k=10, grammar_constrained=True
        )

    # Parse output
    if isinstance(prog, tuple): prog = prog[0]
    if isinstance(prog, torch.Tensor): prog = prog[0].tolist()
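    # Remove the start token. ID 0 is also the UP direction, so this
    # assumes generate() always prepends a start token with ID 0.
    if prog and prog[0] == 0: prog = prog[1:]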
    if 112 in prog: prog = prog[:prog.index(112)]  # remove END and after

    # Execute
    game.execute_program(prog)

print(f"{'WIN' if game.win_sign else 'LOSE'} | Score: {game.score}")

Model Architecture

Parameter            Value
Type                 StructureAwareTransformer2B
Total Parameters     232.2M
Hidden Dimension     1024
Layers               16
Attention Heads      16 (query) / 4 (KV, Grouped Query Attention)
Feed-Forward Dim     4096
State Input          828 dimensions
Vocab Size           113 tokens
Max Program Length   12 tokens
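The arithmetic below makes the Grouped Query Attention row concrete. It derives the per-head width and the per-layer attention projection sizes from the table's numbers.

It is a sketch with two assumptions this card does not state: head_dim = hidden_dim / num_heads, and the Q/K/V/output projections are bias-free. It does not try to reproduce the 232.2M total.

```python
# Grouped Query Attention shapes implied by the architecture table (sketch).
# Assumes head_dim = hidden_dim / num_heads and bias-free Q/K/V/O projections.
hidden_dim = 1024
num_heads = 16      # query heads
num_kv_heads = 4    # key/value heads shared across query heads

head_dim = hidden_dim // num_heads                 # width of each head
kv_dim = num_kv_heads * head_dim                   # total K (or V) width
queries_per_kv_head = num_heads // num_kv_heads    # query heads per KV head

# Per-layer attention projection weights (no biases).
gqa_params = (hidden_dim * hidden_dim       # Q projection
              + 2 * hidden_dim * kv_dim     # K and V projections
              + hidden_dim * hidden_dim)    # output projection
mha_params = 4 * hidden_dim * hidden_dim    # same layer with 16 KV heads

print(head_dim, kv_dim, queries_per_kv_head)  # 64 256 4
print(gqa_params, mha_params)                 # 2621440 4194304
```

Sharing each K/V head across 4 query heads shrinks the K and V projections from 1024 to 256 output columns. Under these assumptions, that cuts the attention projection weights per layer by about 37%.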

Model Config (for initialization)

config = {
    'state_dim': 828,
    'hidden_dim': 1024,
    'vocab_size': 113,
    'max_program_length': 12,
    'num_layers': 16,
    'num_heads': 16,
    'num_kv_heads': 4,
    'ff_dim': 4096,
    'dropout': 0.1,
    'end_token': 112,
}
model = StructureAwareTransformer2B(**config)

Token Vocabulary (113 tokens)

Direction Tokens (0-3)

Token ID   Direction   Movement
0          UP          Mouse moves up one cell
1          DOWN        Mouse moves down one cell
2          LEFT        Mouse moves left one cell
3          RIGHT       Mouse moves right one cell

Number Tokens (100-109)

Token ID   Value   Usage
100        1       LOOP repeat count (1 time)
104        5       LOOP repeat count (5 times)
105        6       LOOP repeat count (6 times)
106        7       LOOP repeat count (7 times)
107        8       LOOP repeat count (8 times)
108        9       LOOP repeat count (9 times)
109        10      LOOP repeat count (10 times)

Note: Tokens 101-103 (values 2-4) exist in the vocabulary but are NOT used by the grammar, and token 100 is not emitted either. Under grammar-constrained decoding, the model only generates NUM tokens 104-109 (5 to 10 repeats), for efficiency.

Special Tokens

Token ID   Name   Function
110        LOOP   Start a loop structure
112        END    End of program

Token 111 (IF) was removed due to simulator incompatibility.
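The helper below is handy for reading programs by hand. It only encodes the mapping given in the tables above, where a NUM token's value is its ID minus 99.

Tokens 4-99 and 111 are not described in this card, so the sketch reports them as unknown. The function names are illustrative and are not part of the released code.

```python
# Token vocabulary from the tables above (illustrative helper).
DIRECTIONS = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP, END = 110, 112

def num_value(token_id):
    """NUM tokens 100-109 encode repeat counts 1-10 (value = id - 99)."""
    if not 100 <= token_id <= 109:
        raise ValueError(f"not a NUM token: {token_id}")
    return token_id - 99

def token_name(token_id):
    """Human-readable name for a single token ID."""
    if token_id in DIRECTIONS:
        return DIRECTIONS[token_id]
    if 100 <= token_id <= 109:
        return f"NUM({num_value(token_id)})"
    if token_id == LOOP:
        return "LOOP"
    if token_id == END:
        return "END"
    return f"UNKNOWN({token_id})"  # IDs this card does not document

print([token_name(t) for t in [0, 110, 104, 2, 3, 112]])
# ['UP', 'LOOP', 'NUM(5)', 'LEFT', 'RIGHT', 'END']
```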

Grammar Rules

Programs follow a strict context-free grammar:

start       -> DIR | LOOP NUM DIR | END
after_DIR   -> DIR | LOOP NUM DIR | END
after_LOOP  -> NUM (must be 104-109)
after_NUM   -> DIR (must be 0-3)
after_END   -> (stop generation)

Valid Program Examples

[0, 112]                         # Move UP, END
[2, 2, 2, 112]                   # Move LEFT 3 times, END
[110, 106, 1, 112]               # LOOP(7 times, DOWN), END
[0, 110, 104, 2, 3, 112]         # UP, LOOP(5 times, LEFT), RIGHT, END
[110, 108, 0, 110, 105, 3, 112]  # LOOP(9 times, UP), LOOP(6 times, RIGHT), END
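The grammar rules can be checked with a short state machine.

This sketch assumes that once a LOOP's DIR has been read, the program continues as in after_DIR, which is what the examples imply. It ignores the position-8 LOOP cutoff described next. It is not the validator used inside model.generate.

```python
DIRS = {0, 1, 2, 3}
NUMS = set(range(104, 110))   # repeat counts 5-10 allowed by the grammar
LOOP, END = 110, 112

def is_valid_program(tokens):
    """Check a token list against the grammar rules (sketch).

    'start'/'after_DIR' accept DIR, LOOP or END; 'after_LOOP' needs a NUM;
    'after_NUM' needs a DIR. The program must end with END, with nothing after it.
    """
    state = "start"
    for i, tok in enumerate(tokens):
        if state in ("start", "after_DIR"):
            if tok in DIRS:
                state = "after_DIR"
            elif tok == LOOP:
                state = "after_LOOP"
            elif tok == END:
                return i == len(tokens) - 1   # END must be the last token
            else:
                return False
        elif state == "after_LOOP":
            if tok not in NUMS:
                return False
            state = "after_NUM"
        else:  # after_NUM: the single DIR that the LOOP repeats
            if tok not in DIRS:
                return False
            state = "after_DIR"
    return False  # ran out of tokens before END

examples = [[0, 112], [2, 2, 2, 112], [110, 106, 1, 112],
            [0, 110, 104, 2, 3, 112], [110, 108, 0, 110, 105, 3, 112]]
print(all(is_valid_program(p) for p in examples))  # True
```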

Grammar Constraint: LOOP cutoff at position 8

The LOOP token (110) is only allowed at positions 0-7 (indices 0-7 in the generated sequence). From position 8 onwards no new LOOP can start, so only DIR tokens and END are allowed, apart from the NUM and DIR that complete a LOOP opened at position 7. This prevents overly long programs.
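Grammar-constrained decoding needs, at each position, the set of tokens that may come next. The sketch below combines the grammar rules with the LOOP cutoff.

Two things are assumptions here: positions are counted from the first program token (no start token), and max_length is not enforced. The function name is hypothetical.

```python
DIRS = [0, 1, 2, 3]
NUMS = list(range(104, 110))
LOOP, END = 110, 112
LOOP_CUTOFF = 8   # LOOP may only appear at indices 0-7

def allowed_next_tokens(prefix):
    """Token IDs the grammar allows after `prefix` (sketch).

    `prefix` is the program generated so far, without any start token.
    Returns an empty list once END has been emitted.
    """
    if prefix and prefix[-1] == END:
        return []
    if prefix and prefix[-1] == LOOP:
        return list(NUMS)          # after_LOOP -> NUM (104-109)
    if prefix and prefix[-1] in NUMS:
        return list(DIRS)          # after_NUM -> DIR
    # start / after_DIR: a direction, END, or (before the cutoff) LOOP
    allowed = DIRS + [END]
    if len(prefix) < LOOP_CUTOFF:  # next token would sit at index len(prefix)
        allowed.append(LOOP)
    return allowed

print(allowed_next_tokens([0] * 7))  # [0, 1, 2, 3, 112, 110]
print(allowed_next_tokens([0] * 8))  # [0, 1, 2, 3, 112]
```

In a sampler, every token outside this list would get its logit set to minus infinity before sampling.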

State Vector (828 dimensions)

The 828-dimensional state vector encodes the complete game state:

import torch

def get_state_vector(sim):
    """Extract 828-dim state vector from game simulator"""
    state_dict = sim.get_state_dict()
    state = []
    DYNAMIC_SCALE = 10.0  # Scale factor for dynamic features

    # --- Grid features (11x11 grids) ---

    # 1. Wall grid (121 dims): 1=wall, 0=empty
    for row in state_dict['wall']:
        state.extend(row)

    # 2. Small Cheese grid (121 dims): 1=cheese present, 0=collected
    #    Scaled by DYNAMIC_SCALE (10.0)
    for row in state_dict['sc']:
        state.extend([v * DYNAMIC_SCALE for v in row])

    # 3. Junction grid (121 dims): 1=junction, 0=not
    for row in state_dict['junc']:
        state.extend(row)

    # 4. Dead-end grid (121 dims): 1=dead-end, 0=not
    for row in state_dict['deadend']:
        state.extend(row)

    # Total grid: 484 dims (4 * 121)

    # --- Entity positions ---

    # 5. Mouse position (2 dims): [x, y]
    mouse = state_dict['mouse']
    state.extend([float(mouse[0]), float(mouse[1])])

    # 6. Cat positions (12 dims): 6 cats * [x, y], unused=-1
    cat_list = state_dict.get('cat', [])
    for i in range(6):
        if i < len(cat_list):
            state.extend([float(cat_list[i][0]), float(cat_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # 7. Moving Big Cheese positions (10 dims): 5 * [x, y], unused=-1
    bc_list = state_dict.get('crzbc', [])
    for i in range(5):
        if i < len(bc_list):
            state.extend([float(bc_list[i][0]), float(bc_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # Pad to 549 dims (484 + 65)
    while len(state) < 484 + 65:
        state.append(0.0)

    # --- Scalar features (6 dims) ---

    # 8. Score (normalized by 1000, scaled)
    state.append(state_dict.get('score', 0) / 1000.0 * DYNAMIC_SCALE)

    # 9. Life (normalized by 3, scaled) - starts at 3
    state.append(state_dict.get('life', 3) * DYNAMIC_SCALE / 3.0)

    # 10. Current run number (normalized by 20, scaled)
    state.append(state_dict.get('run', 0) * DYNAMIC_SCALE / 20.0)

    # 11. Win flag (DYNAMIC_SCALE if won, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('win_sign', False) else 0.0)

    # 12. Lose flag (DYNAMIC_SCALE if lost, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('lose_sign', False) else 0.0)

    # 13. Step progress (current_step / step_limit, scaled)
    step = state_dict.get('step', 0)
    step_limit = state_dict.get('step_limit', 200)
    state.append(step / step_limit * DYNAMIC_SCALE if step_limit > 0 else 0.0)

    # Pad to 828 dims
    while len(state) < 828:
        state.append(0.0)

    return torch.tensor(state[:828], dtype=torch.float32)

State Vector Layout Summary

Range     Dims   Content                              Scale
0-120     121    Wall grid (11x11)                    1.0
121-241   121    Small Cheese grid (11x11)            10.0
242-362   121    Junction grid (11x11)                1.0
363-483   121    Dead-end grid (11x11)                1.0
484-485   2      Mouse position [x, y]                1.0
486-497   12     Cat positions (6 slots, unused=-1)   1.0
498-507   10     Moving Big Cheese (5 slots)          1.0
508-548   41     Padding (zeros)                      -
549       1      Score / 1000 * 10                    10.0
550       1      Life / 3 * 10                        10.0
551       1      Run / 20 * 10                        10.0
552       1      Win flag                             10.0
553       1      Lose flag                            10.0
554       1      Step / step_limit * 10               10.0
555-827   273    Padding (zeros)                      -
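When debugging, it can help to split a state vector back into named fields. The offsets below come straight from the layout table.

STATE_LAYOUT and split_state are hypothetical names. The grids are regrouped into 11 rows in the same row order get_state_vector flattens them. Each coordinate pair follows whatever [x, y] convention the simulator uses.

```python
# Field offsets from the layout table (half-open slices); padding is skipped.
STATE_LAYOUT = {
    "wall":         slice(0, 121),
    "small_cheese": slice(121, 242),
    "junction":     slice(242, 363),
    "dead_end":     slice(363, 484),
    "mouse":        slice(484, 486),
    "cats":         slice(486, 498),
    "big_cheese":   slice(498, 508),
    "scalars":      slice(549, 555),
}
STATE_DIM = 828

def split_state(vec):
    """Split a flat 828-long state (list or 1-D tensor) into named fields."""
    if len(vec) != STATE_DIM:
        raise ValueError(f"expected {STATE_DIM} values, got {len(vec)}")
    fields = {name: list(vec[s]) for name, s in STATE_LAYOUT.items()}
    # Grids back to 11 rows of 11 values.
    for name in ("wall", "small_cheese", "junction", "dead_end"):
        flat = fields[name]
        fields[name] = [flat[r * 11:(r + 1) * 11] for r in range(11)]
    # Pair up [x, y] coordinates; (-1, -1) marks an unused slot.
    for name in ("mouse", "cats", "big_cheese"):
        flat = fields[name]
        fields[name] = [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]
    return fields
```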

Game Rules (Level 3)

Map

  • 11x11 grid maze with walls
  • Fixed wall layout for level 3

Entities

  • Mouse: Player-controlled, starts at position [10, 10]
  • Cat 0 (Dummy): Starts at [2, 2], moves only during command execution (len(command) steps)
  • Cat 1 (Naughty): Starts at [5, 5], moves every mouse step
  • Small Cheese (SC): 75 stationary items, +10 points each
  • Stationary Big Cheese (movbc): 2 items, +500 points each, don't move
  • Moving Big Cheese (crzbc): 2 items, +500 points each, move each step

Cat Movement (Random Mode)

Cats move randomly at junctions (never turning back), continue straight in corridors, and pick a random direction when blocked. This is the _get_cats_direct_actions mode in the simulator.
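The movement rule can be sketched as a choice among a cell's open directions. It takes one literal reading of "pick a random direction when blocked": any open direction, including reversing, is allowed.

The sketch assumes the caller supplies the open directions and the junction flag. It is not the simulator's _get_cats_direct_actions code.

```python
import random

# Direction IDs follow the token vocabulary: 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT.
OPPOSITE = {0: 1, 1: 0, 2: 3, 3: 2}

def choose_cat_direction(open_dirs, heading, is_junction, rng=random):
    """Pick a cat's next direction under the random-mode rules (sketch).

    open_dirs:   directions not blocked by walls from the cat's cell.
    heading:     the cat's current direction (None if it has not moved yet).
    is_junction: whether the cell is a junction.
    """
    open_dirs = sorted(open_dirs)
    if not open_dirs:
        return None  # fully enclosed cell: the cat cannot move
    if is_junction:
        # Junction: choose at random, but never turn back.
        choices = [d for d in open_dirs if d != OPPOSITE.get(heading)]
        return rng.choice(choices or open_dirs)
    if heading in open_dirs:
        # Corridor: keep going straight.
        return heading
    # Blocked: pick any open direction at random.
    return rng.choice(open_dirs)
```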

Scoring

Event                  Points
Collect Small Cheese   +10
Collect Big Cheese     +500
Hit Wall               -10
Caught by Cat          -500 (+ lose 1 life)
Win Bonus              +(run * 10 + step)
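The table's arithmetic is easy to check with a small tally. The win bonus formula is applied literally as listed; the card does not say exactly which run and step values the simulator plugs in. The event names are hypothetical labels.

```python
# Point values from the scoring table.
POINTS = {
    "small_cheese": 10,
    "big_cheese": 500,
    "wall": -10,
    "caught": -500,
}

def win_bonus(run, step):
    """Win bonus as listed in the table: run * 10 + step."""
    return run * 10 + step

def total_score(events, won=False, run=0, step=0):
    """Sum the points for a list of event names, plus the bonus if won."""
    score = sum(POINTS[e] for e in events)
    if won:
        score += win_bonus(run, step)
    return score

# All 75 small cheese plus all 4 big cheese (2 stationary + 2 moving),
# with no penalties, before any win bonus:
print(total_score(["small_cheese"] * 75 + ["big_cheese"] * 4))  # 2750
```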

Win/Lose Conditions

  • WIN: all 75 Small Cheese collected and an END token executed
  • LOSE (life): Life reaches 0 (caught 3 times)
  • LOSE (step): Step count reaches 200
  • LOSE (run): 20 runs exhausted without winning
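The conditions above can be written as a single check. This sketch assumes WIN is tested before the LOSE conditions; the card does not state the precedence. The parameter names are hypothetical.

```python
def game_outcome(small_cheese_left, end_executed, life, step, runs_used,
                 step_limit=200, max_runs=20):
    """Return 'WIN', 'LOSE', or None while the game continues (sketch)."""
    if small_cheese_left == 0 and end_executed:
        return "WIN"                 # all 75 small cheese + END executed
    if life <= 0:
        return "LOSE"                # caught 3 times
    if step >= step_limit:
        return "LOSE"                # step budget used up
    if runs_used >= max_runs:
        return "LOSE"                # 20 runs without winning
    return None
```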

Game Flow

  1. Game starts with mouse at [10,10], 3 lives, 20 max runs
  2. Each run: model generates a program -> program executes step by step
  3. During execution: mouse moves, cats move randomly, cheese collected, collisions checked
  4. After program ends: next run begins
  5. Continue until WIN or LOSE

Program Execution

When a program like [0, 110, 106, 2, 3, 112] executes:

  1. Token 0 (UP): mouse moves up 1 step
  2. Tokens 110, 106, 2 (LOOP 7 LEFT): mouse moves left 7 steps
  3. Token 3 (RIGHT): mouse moves right 1 step
  4. Token 112 (END): program ends
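The walkthrough above amounts to expanding a program into a flat list of single-cell moves.

This sketch assumes the program is already grammatically valid; a truncated LOOP would raise IndexError. It is not the simulator's execute_program.

```python
LOOP, END = 110, 112

def expand_program(tokens):
    """Turn a token program into the list of single-cell moves (sketch).

    A DIR token (0-3) becomes one move; LOOP NUM DIR becomes DIR repeated
    (NUM - 99) times; execution stops at END.
    """
    moves = []
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok == END:
            break
        if tok == LOOP:
            count = tokens[i + 1] - 99      # NUM token -> repeat count
            direction = tokens[i + 2]
            moves.extend([direction] * count)
            i += 3
        else:
            moves.append(tok)               # plain DIR token
            i += 1
    return moves

print(expand_program([0, 110, 106, 2, 3, 112]))
# [0, 2, 2, 2, 2, 2, 2, 2, 3]  (UP, 7x LEFT, RIGHT)
```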

Each step:

  • Mouse attempts to move in the direction
  • If wall: mouse stays, -10 points
  • Cat 1 moves (random at junctions)
  • Cat 0 moves (only during command-length steps)
  • Check for cat collision: -500 points, lose 1 life, respawn at [10,10]
  • Check for cheese collection: +10 (SC) or +500 (BC)
  • Check win/lose conditions
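The per-step order above can be sketched on a toy state. Everything beyond the listed order is an assumption:

  • the coordinate convention (UP decreases y)
  • treating off-grid moves like wall hits
  • detecting a catch only when mouse and cat share a cell after the cats move
  • the state keys and the move_cats callback, which are hypothetical

Moving big cheese and the win/lose check are left out.

```python
# Assumed coordinate convention (not stated in this card): [x, y] with
# UP = y - 1, DOWN = y + 1, LEFT = x - 1, RIGHT = x + 1 on an 11x11 grid.
DELTAS = {0: (0, -1), 1: (0, 1), 2: (-1, 0), 3: (1, 0)}
SPAWN = (10, 10)

def mouse_step(state, direction, move_cats):
    """Apply one step, in the order listed above, to a toy state dict.

    Hypothetical keys: mouse, walls (set), small_cheese (set),
    big_cheese (set), cats (list of cells), score, life.
    move_cats(state) is expected to update state['cats'] in place.
    """
    dx, dy = DELTAS[direction]
    x, y = state["mouse"]
    nxt = (x + dx, y + dy)
    off_grid = not (0 <= nxt[0] < 11 and 0 <= nxt[1] < 11)
    if off_grid or nxt in state["walls"]:
        state["score"] -= 10          # wall hit: mouse stays put
    else:
        state["mouse"] = nxt
    move_cats(state)                  # cats move after the mouse
    if state["mouse"] in state["cats"]:
        state["score"] -= 500         # caught: lose a life and respawn
        state["life"] -= 1
        state["mouse"] = SPAWN
    if state["mouse"] in state["small_cheese"]:
        state["small_cheese"].discard(state["mouse"])
        state["score"] += 10
    if state["mouse"] in state["big_cheese"]:
        state["big_cheese"].discard(state["mouse"])
        state["score"] += 500
    return state
```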

Performance

Metric                           Value
Win Rate (temp=0.3, 100 games)   30%
Average Score                    1437
Average Runs per Win             13.8
Simulator                        New simulator (random cats)

Training Pipeline

  1. Base Model: Expert R1 checkpoint (trained on old simulator, 95% win rate on old sim, 14% on new sim)
  2. RM32 Data Generation: 10,000 games with Running Max 32 (exhaustive 33 candidates), 20.4% win rate, 30,788 winning run samples
  3. SFT Training: 40 epochs, batch 4096, lr 3e-5, cosine schedule -> 30% win rate

Generation Parameters

Parameter             Recommended   Description
temperature           0.3           Lower = more deterministic, higher win rate
top_k                 10            Top-k sampling: only the k most likely tokens are kept
grammar_constrained   True          MUST be True to generate valid programs
max_length            12            Maximum program length
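The sketch below shows how these parameters combine when sampling one token from a logit vector. It is plain Python for illustration.

The order (grammar mask, then temperature, then top-k) and the function name are assumptions. This is not a description of model.generate.

```python
import math
import random

def sample_next_token(logits, allowed, temperature=0.3, top_k=10, rng=random):
    """Sample one token ID under a grammar mask, temperature and top-k (sketch).

    logits:  per-token scores indexed by token ID (length 113 here).
    allowed: token IDs the grammar permits at this position.
    """
    if not allowed or temperature <= 0:
        raise ValueError("need at least one allowed token and temperature > 0")
    # 1. Grammar mask: tokens outside `allowed` are never candidates.
    # 2. Temperature: dividing by T < 1 sharpens the distribution.
    scored = sorted(((logits[t] / temperature, t) for t in allowed), reverse=True)
    # 3. Top-k: keep only the k highest-scoring candidates.
    scored = scored[:top_k]
    # Softmax weights (subtracting the max keeps exp() from overflowing).
    best = scored[0][0]
    weights = [math.exp(s - best) for s, _ in scored]
    tokens = [t for _, t in scored]
    return rng.choices(tokens, weights=weights, k=1)[0]
```

With temperature 0.3, a logit gap of 1.0 becomes a gap of about 3.3 before the softmax. That gap pushes sampling strongly toward the top candidate, which matches the "more deterministic" description in the table.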

File Structure

hardai_model_export/
  model_best.pt              # Model checkpoint (886MB)
  README.md                  # This file
  lightweight_simulator.py   # Game simulator
  model/                     # Model architecture
    __init__.py
    model_2B.py              # Main model class
    state_encoder.py
    program_embedding.py
    transformer.py           # Flash Attention + gradient checkpointing
    multi_task_head.py
    memory_encoder.py
    memory_state_fusion.py
    value_predictor.py

Requirements

torch >= 2.0
numpy
pygame (for simulator, can run headless with SDL_VIDEODRIVER=dummy)

Headless Mode (No Display)

import os
os.environ['SDL_VIDEODRIVER'] = 'dummy'
os.environ['SDL_AUDIODRIVER'] = 'dummy'

Set these BEFORE importing the simulator.
