# Mouse AI - Program Generation Model

A 232M-parameter transformer that generates movement programs for a mouse navigating a maze to collect cheese while avoiding cats.

## Quick Start

```python
import torch
from model.model_2B import StructureAwareTransformer2B
from lightweight_simulator import LightweightGameSimulator
# get_state_vector() is defined in the "State Vector" section of this README

# Load model
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
ckpt = torch.load('model_best.pt', map_location='cpu', weights_only=False)
config = ckpt['model_config']

model = StructureAwareTransformer2B(**config)
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)
model.eval()

# Play a game (at most 20 runs)
game = LightweightGameSimulator(level=3)
game.reset()

for run in range(20):
    if game.win_sign or game.lose_sign:
        break

    # Get state vector (828 dimensions) and add a batch dimension
    state = get_state_vector(game).unsqueeze(0).to(device)

    # Generate program
    with torch.no_grad():
        prog = model.generate(
            state, max_length=12, temperature=0.3,
            top_k=10, grammar_constrained=True
        )

    # Parse output
    if isinstance(prog, tuple):
        prog = prog[0]
    if isinstance(prog, torch.Tensor):
        prog = prog[0].tolist()
    if prog and prog[0] == 0:
        prog = prog[1:]  # drop the start token (ID 0; note it shares its ID with UP)
    if 112 in prog:
        prog = prog[:prog.index(112)]  # drop END and everything after it

    # Execute
    game.execute_program(prog)

print(f"{'WIN' if game.win_sign else 'LOSE'} | Score: {game.score}")
```

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Type | StructureAwareTransformer2B |
| Total Parameters | 232.2M |
| Hidden Dimension | 1024 |
| Layers | 16 |
| Attention Heads | 16 query / 4 key-value (Grouped Query Attention) |
| Feed-Forward Dim | 4096 |
| State Input | 828 dimensions |
| Vocab Size | 113 tokens |
| Max Program Length | 12 tokens |

### Model Config (for initialization)
```python
from model.model_2B import StructureAwareTransformer2B

config = {
    'state_dim': 828,
    'hidden_dim': 1024,
    'vocab_size': 113,
    'max_program_length': 12,
    'num_layers': 16,
    'num_heads': 16,
    'num_kv_heads': 4,
    'ff_dim': 4096,
    'dropout': 0.1,
    'end_token': 112,
}
model = StructureAwareTransformer2B(**config)
```
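
The head settings above imply the projection widths below, assuming the standard grouped-query attention layout in which `head_dim = hidden_dim / num_heads` and each key-value head is shared by an equal-sized group of query heads. This is a sketch of the arithmetic only; `gqa_shapes` is a hypothetical helper, and the real layer definitions live in `model/transformer.py`.

```python
def gqa_shapes(hidden_dim: int, num_heads: int, num_kv_heads: int) -> dict:
    """Projection widths for grouped-query attention (standard layout assumed)."""
    assert hidden_dim % num_heads == 0, "hidden_dim must split evenly across heads"
    assert num_heads % num_kv_heads == 0, "query heads must split evenly across KV heads"
    head_dim = hidden_dim // num_heads
    return {
        "head_dim": head_dim,                              # width of one attention head
        "queries_per_kv_head": num_heads // num_kv_heads,  # query heads sharing one K/V head
        "q_proj_out": num_heads * head_dim,                # output width of the Q projection
        "kv_proj_out": num_kv_heads * head_dim,            # output width of K and of V
    }

shapes = gqa_shapes(hidden_dim=1024, num_heads=16, num_kv_heads=4)
print(shapes)
# {'head_dim': 64, 'queries_per_kv_head': 4, 'q_proj_out': 1024, 'kv_proj_out': 256}
```

With these values the K and V projections are each a quarter of the width of Q, which is where grouped-query attention saves memory relative to full multi-head attention.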

## Token Vocabulary (113 tokens)

### Direction Tokens (0-3)
| Token ID | Direction | Movement |
|----------|-----------|----------|
| 0 | UP | Mouse moves up one cell |
| 1 | DOWN | Mouse moves down one cell |
| 2 | LEFT | Mouse moves left one cell |
| 3 | RIGHT | Mouse moves right one cell |

### Number Tokens (100-109)
| Token ID | Value | Usage |
|----------|-------|-------|
| 100 | 1 | Unused (not allowed by the grammar) |
| 101 | 2 | Unused (not allowed by the grammar) |
| 102 | 3 | Unused (not allowed by the grammar) |
| 103 | 4 | Unused (not allowed by the grammar) |
| 104 | 5 | LOOP repeat count (5 times) |
| 105 | 6 | LOOP repeat count (6 times) |
| 106 | 7 | LOOP repeat count (7 times) |
| 107 | 8 | LOOP repeat count (8 times) |
| 108 | 9 | LOOP repeat count (9 times) |
| 109 | 10 | LOOP repeat count (10 times) |

Note: Tokens 100-103 (values 1-4) exist in the vocabulary but are NOT allowed by the grammar. For efficiency, the model only generates NUM tokens 104-109 (5-10 repeats).

### Special Tokens
| Token ID | Name | Function |
|----------|------|----------|
| 110 | LOOP | Start a loop structure |
| 112 | END | End of program |

Token 111 (IF) was removed due to simulator incompatibility.
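
For debugging it helps to print programs in readable form. The sketch below uses only the token IDs documented above (NUM token `100 + k` encodes the value `k + 1`); `describe_program` and `num_value` are hypothetical helpers, not part of the exported code.

```python
DIRECTIONS = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP, END = 110, 112

def num_value(token: int) -> int:
    """NUM tokens 100-109 encode the values 1-10."""
    if not 100 <= token <= 109:
        raise ValueError(f"not a NUM token: {token}")
    return token - 99

def describe_program(tokens: list[int]) -> str:
    """Render a token program such as [0, 110, 104, 2, 3, 112] as text."""
    parts, i = [], 0
    while i < len(tokens):
        tok = tokens[i]
        if tok in DIRECTIONS:
            parts.append(DIRECTIONS[tok])
            i += 1
        elif tok == LOOP:
            # LOOP is always followed by one NUM token and one DIR token
            count, direction = num_value(tokens[i + 1]), DIRECTIONS[tokens[i + 2]]
            parts.append(f"LOOP({count}, {direction})")
            i += 3
        elif tok == END:
            parts.append("END")
            break
        else:
            raise ValueError(f"unexpected token {tok} at position {i}")
    return ", ".join(parts)

print(describe_program([0, 110, 104, 2, 3, 112]))
# UP, LOOP(5, LEFT), RIGHT, END
```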

## Grammar Rules

Programs follow a strict context-free grammar:

```
start      -> DIR | LOOP NUM DIR | END
after_DIR  -> DIR | LOOP NUM DIR | END
after_LOOP -> NUM   (must be 104-109)
after_NUM  -> DIR   (must be 0-3)
after_END  -> (stop generation)
```

### Valid Program Examples
```
[0, 112]                          # Move UP, END
[2, 2, 2, 112]                    # Move LEFT 3 times, END
[110, 106, 1, 112]                # LOOP(7 times, DOWN), END
[0, 110, 104, 2, 3, 112]          # UP, LOOP(5 times, LEFT), RIGHT, END
[110, 108, 0, 110, 105, 3, 112]   # LOOP(9, UP), LOOP(6, RIGHT), END
```
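
The grammar can be checked with a small state machine whose states mirror the rule names above. This is a sketch under two assumptions: programs are passed without the start token, and a complete program must end with END (as all the examples do). `is_valid_program` is a hypothetical helper; the position-8 LOOP cutoff is not checked here.

```python
DIRS = {0, 1, 2, 3}
LOOP_NUMS = set(range(104, 110))  # NUM tokens the grammar allows (values 5-10)
LOOP, END = 110, 112

def is_valid_program(tokens: list[int]) -> bool:
    """Check a program (start token removed) against the grammar rules."""
    expect = "start"  # name of the rule that governs the next token
    for tok in tokens:
        if expect in ("start", "after_DIR"):
            if tok in DIRS:
                expect = "after_DIR"
            elif tok == LOOP:
                expect = "after_LOOP"
            elif tok == END:
                expect = "after_END"
            else:
                return False
        elif expect == "after_LOOP":
            if tok not in LOOP_NUMS:
                return False
            expect = "after_NUM"
        elif expect == "after_NUM":
            if tok not in DIRS:
                return False
            expect = "after_DIR"  # LOOP NUM DIR behaves like a DIR afterwards
        else:  # after_END: nothing may follow END
            return False
    return expect == "after_END"

examples = [[0, 112], [2, 2, 2, 112], [110, 106, 1, 112],
            [0, 110, 104, 2, 3, 112], [110, 108, 0, 110, 105, 3, 112]]
print(all(is_valid_program(p) for p in examples))  # True
print(is_valid_program([110, 101, 2, 112]))        # False: NUM 101 is not allowed
```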

### Grammar Constraint: LOOP cutoff at position 8
The LOOP token (110) may only start at positions 0-7 of the generated sequence. From position 8 onwards no new LOOP can begin, so only DIR tokens and END are allowed, apart from the NUM and DIR that complete a LOOP opened at position 6 or 7. This prevents overly long programs.
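
Grammar-constrained decoding is usually implemented by masking, at each step, the tokens the grammar forbids. The sketch below computes the allowed set for the next token from the prefix generated so far, applying both the grammar rules and the LOOP cutoff. It assumes positions are counted from 0 over program tokens (start token excluded) and does not enforce `max_length`; `allowed_next_tokens` is hypothetical, not the model's internal implementation.

```python
DIRS = {0, 1, 2, 3}
LOOP_NUMS = set(range(104, 110))  # loop counts 5-10
LOOP, END = 110, 112
LOOP_CUTOFF = 8                   # a LOOP may only start at positions 0-7

def allowed_next_tokens(prefix: list[int]) -> set[int]:
    """Tokens the grammar permits at position len(prefix)."""
    if prefix and prefix[-1] == END:
        return set()               # generation has stopped
    if prefix and prefix[-1] == LOOP:
        return set(LOOP_NUMS)      # LOOP must be followed by a NUM
    if prefix and prefix[-1] in LOOP_NUMS:
        return set(DIRS)           # NUM must be followed by a DIR
    # start of program or after a DIR: DIR or END, plus LOOP before the cutoff
    allowed = DIRS | {END}
    if len(prefix) < LOOP_CUTOFF:
        allowed.add(LOOP)
    return allowed

print(sorted(allowed_next_tokens([0] * 7)))  # position 7
# [0, 1, 2, 3, 110, 112]
print(sorted(allowed_next_tokens([0] * 8)))  # position 8
# [0, 1, 2, 3, 112]
```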

## State Vector (828 dimensions)

The 828-dimensional state vector encodes the maze layout, the positions of the mouse, cats, and moving big cheese, and scalar game status:

```python
import torch

def get_state_vector(sim):
    """Extract the 828-dim state vector from a game simulator."""
    state_dict = sim.get_state_dict()
    state = []
    DYNAMIC_SCALE = 10.0  # Scale factor for dynamic features

    # --- Grid features (11x11 grids, flattened row by row) ---

    # 1. Wall grid (121 dims): 1=wall, 0=empty
    for row in state_dict['wall']:
        state.extend(row)

    # 2. Small Cheese grid (121 dims): 1=cheese present, 0=collected
    #    Scaled by DYNAMIC_SCALE (10.0)
    for row in state_dict['sc']:
        state.extend([v * DYNAMIC_SCALE for v in row])

    # 3. Junction grid (121 dims): 1=junction, 0=not
    for row in state_dict['junc']:
        state.extend(row)

    # 4. Dead-end grid (121 dims): 1=dead-end, 0=not
    for row in state_dict['deadend']:
        state.extend(row)

    # Total grid: 484 dims (4 * 121)

    # --- Entity positions ---

    # 5. Mouse position (2 dims): [x, y]
    mouse = state_dict['mouse']
    state.extend([float(mouse[0]), float(mouse[1])])

    # 6. Cat positions (12 dims): 6 cats * [x, y], unused=-1
    cat_list = state_dict.get('cat', [])
    for i in range(6):
        if i < len(cat_list):
            state.extend([float(cat_list[i][0]), float(cat_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # 7. Moving Big Cheese positions (10 dims): 5 * [x, y], unused=-1
    bc_list = state_dict.get('crzbc', [])
    for i in range(5):
        if i < len(bc_list):
            state.extend([float(bc_list[i][0]), float(bc_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # Pad to 549 dims (484 + 65)
    while len(state) < 484 + 65:
        state.append(0.0)

    # --- Scalar features (6 dims) ---

    # 8. Score (normalized by 1000, scaled)
    state.append(state_dict.get('score', 0) / 1000.0 * DYNAMIC_SCALE)

    # 9. Life (normalized by 3, scaled) - starts at 3
    state.append(state_dict.get('life', 3) * DYNAMIC_SCALE / 3.0)

    # 10. Current run number (normalized by 20, scaled)
    state.append(state_dict.get('run', 0) * DYNAMIC_SCALE / 20.0)

    # 11. Win flag (DYNAMIC_SCALE if won, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('win_sign', False) else 0.0)

    # 12. Lose flag (DYNAMIC_SCALE if lost, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('lose_sign', False) else 0.0)

    # 13. Step progress (current_step / step_limit, scaled)
    step = state_dict.get('step', 0)
    step_limit = state_dict.get('step_limit', 200)
    state.append(step / step_limit * DYNAMIC_SCALE if step_limit > 0 else 0.0)

    # Pad to 828 dims
    while len(state) < 828:
        state.append(0.0)

    return torch.tensor(state[:828], dtype=torch.float32)
```

### State Vector Layout Summary
| Range | Dims | Content | Scale |
|-------|------|---------|-------|
| 0-120 | 121 | Wall grid (11x11) | 1.0 |
| 121-241 | 121 | Small Cheese grid (11x11) | 10.0 |
| 242-362 | 121 | Junction grid (11x11) | 1.0 |
| 363-483 | 121 | Dead-end grid (11x11) | 1.0 |
| 484-485 | 2 | Mouse position [x, y] | 1.0 |
| 486-497 | 12 | Cat positions (6 slots, unused = -1) | 1.0 |
| 498-507 | 10 | Moving Big Cheese (crzbc) positions (5 slots, unused = -1) | 1.0 |
| 508-548 | 41 | Padding (zeros) | - |
| 549 | 1 | Score / 1000 * 10 | 10.0 |
| 550 | 1 | Life / 3 * 10 | 10.0 |
| 551 | 1 | Run / 20 * 10 | 10.0 |
| 552 | 1 | Win flag | 10.0 |
| 553 | 1 | Lose flag | 10.0 |
| 554 | 1 | Step / step_limit * 10 | 10.0 |
| 555-827 | 273 | Padding (zeros) | - |

## Game Rules (Level 3)

### Map
- 11x11 grid maze with walls
- Fixed wall layout for level 3

### Entities
- **Mouse**: Player-controlled, starts at position [10, 10]
- **Cat 0 (Dummy)**: Starts at [2, 2], moves only during command execution (len(command) steps)
- **Cat 1 (Naughty)**: Starts at [5, 5], moves every mouse step
- **Small Cheese (SC)**: 75 stationary items, +10 points each
- **Stationary Big Cheese (movbc)**: 2 items, +500 points each, do not move
- **Moving Big Cheese (crzbc)**: 2 items, +500 points each, move every step

### Cat Movement (Random Mode)
Cats pick a random direction at junctions (never turning back), continue straight in corridors, and pick a random direction when blocked. This is the `_get_cats_direct_actions` mode in the simulator.
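
The random-mode rule can be sketched as a function that chooses a cat's next heading on a wall grid. This illustrates the behavior described above and is not a copy of `_get_cats_direct_actions`; the grid encoding (1 = wall), `(row, col)` coordinates, direction deltas, and treating any cell with three or more open neighbors as a junction are all assumptions. `next_cat_heading` is a hypothetical name.

```python
import random

# Assumed encoding: walls[row][col] == 1 means wall; UP decreases the row index.
DELTAS = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1)}
REVERSE = {"UP": "DOWN", "DOWN": "UP", "LEFT": "RIGHT", "RIGHT": "LEFT"}

def open_directions(walls, pos):
    """Directions whose neighboring cell is inside the grid and not a wall."""
    r, c = pos
    result = []
    for name, (dr, dc) in DELTAS.items():
        nr, nc = r + dr, c + dc
        if 0 <= nr < len(walls) and 0 <= nc < len(walls[0]) and walls[nr][nc] == 0:
            result.append(name)
    return result

def next_cat_heading(walls, pos, heading, rng=random):
    """Pick a cat's next heading following the random-mode description."""
    options = open_directions(walls, pos)
    if not options:
        return heading                  # enclosed cell: nothing to choose
    if len(options) >= 3:               # junction (assumed: 3+ open neighbors)
        return rng.choice([d for d in options if d != REVERSE[heading]])
    if heading in options:
        return heading                  # corridor: keep going straight
    return rng.choice(options)          # blocked ahead: any open direction

walls = [[1, 1, 1, 1, 1],
         [1, 0, 0, 0, 1],
         [1, 0, 1, 0, 1],
         [1, 0, 0, 0, 1],
         [1, 1, 1, 1, 1]]
print(next_cat_heading(walls, (1, 2), "RIGHT"))  # corridor cell: keeps heading RIGHT
```

Passing a seeded `random.Random` instance as `rng` makes cat trajectories reproducible when comparing programs.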

### Scoring
| Event | Points |
|-------|--------|
| Collect Small Cheese | +10 |
| Collect Big Cheese | +500 |
| Hit Wall | -10 |
| Caught by Cat | -500 (+ lose 1 life) |
| Win Bonus | +(run * 10 + step) |
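
A score change can be tallied from event counts using the point values in this table. A minimal sketch; `tally_score` and its event names are hypothetical, and the win bonus simply applies the table's `run * 10 + step` formula to whatever `run` and `step` values the simulator reports.

```python
POINTS = {
    "small_cheese": 10,
    "big_cheese": 500,
    "wall_hit": -10,
    "caught": -500,
}

def tally_score(events: dict[str, int], won: bool = False, run: int = 0, step: int = 0) -> int:
    """Sum points for counted events, plus the win bonus if the game was won."""
    total = sum(POINTS[name] * count for name, count in events.items())
    if won:
        total += run * 10 + step  # win bonus formula from the scoring table
    return total

print(tally_score({"small_cheese": 12, "wall_hit": 2}))  # 120 - 20
# 100
```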

### Win/Lose Conditions
- **WIN**: Collect ALL 75 Small Cheese and execute an END token
- **LOSE (life)**: Life reaches 0 (caught 3 times)
- **LOSE (step)**: Step count reaches 200
- **LOSE (run)**: 20 runs exhausted without winning

### Game Flow
1. The game starts with the mouse at [10, 10], 3 lives, and a maximum of 20 runs
2. Each run: the model generates a program, which then executes step by step
3. During execution: the mouse moves, cats move randomly, cheese is collected, and collisions are checked
4. After the program ends, the next run begins
5. Play continues until WIN or LOSE

## Program Execution

When a program like `[0, 110, 106, 2, 3, 112]` executes:

1. Token `0` (UP): mouse moves up 1 step
2. Tokens `110, 106, 2` (LOOP 7 LEFT): mouse moves left 7 steps
3. Token `3` (RIGHT): mouse moves right 1 step
4. Token `112` (END): program ends
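
Because each LOOP repeats a single direction, a program can be expanded into its flat sequence of single-cell moves. The sketch below reproduces the example above; `expand_program` is a hypothetical helper, not the simulator's own executor.

```python
DIRECTIONS = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP, END = 110, 112

def expand_program(tokens: list[int]) -> list[str]:
    """Flatten a token program into one direction per mouse step."""
    moves, i = [], 0
    while i < len(tokens) and tokens[i] != END:
        if tokens[i] == LOOP:
            repeats = tokens[i + 1] - 99   # NUM token 100 + k encodes k + 1
            moves += [DIRECTIONS[tokens[i + 2]]] * repeats
            i += 3
        else:
            moves.append(DIRECTIONS[tokens[i]])
            i += 1
    return moves

moves = expand_program([0, 110, 106, 2, 3, 112])
print(len(moves), moves)
# 9 ['UP', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'LEFT', 'RIGHT']
```

The length of the expanded list is the number of steps the program consumes, which is what counts toward the 200-step limit.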

During each step:
- The mouse attempts to move in the given direction
- If it hits a wall: the mouse stays in place, -10 points
- Cat 1 moves (random at junctions)
- Cat 0 moves (only during command-length steps)
- Check for cat collision: -500 points, lose 1 life, mouse respawns at [10, 10]
- Check for cheese collection: +10 (SC) or +500 (BC)
- Check win/lose conditions

## Performance

| Metric | Value |
|--------|-------|
| Win Rate (temp=0.3, 100 games) | 30% |
| Average Score | 1437 |
| Average Runs per Win | 13.8 |
| Simulator | New simulator (random cats) |

### Training Pipeline
1. **Base Model**: Expert R1 checkpoint (trained on the old simulator; 95% win rate on the old simulator, 14% on the new one)
2. **RM32 Data Generation**: 10,000 games with Running Max 32 (exhaustive, 33 candidates), 20.4% win rate, 30,788 winning-run samples
3. **SFT Training**: 40 epochs, batch size 4096, lr 3e-5, cosine schedule, reaching a 30% win rate
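
The README does not state warmup or a minimum learning rate for the cosine schedule, so the sketch below assumes the common form with no warmup that decays from the peak lr of 3e-5 to 0 over training. `cosine_lr` is a hypothetical helper, not the training script; PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` follows the same curve with `eta_min` as the floor.

```python
import math

def cosine_lr(step: int, total_steps: int, peak_lr: float = 3e-5, min_lr: float = 0.0) -> float:
    """Cosine decay from peak_lr at step 0 to min_lr at total_steps (no warmup)."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for s in (0, 250, 500, 750, 1000):
    print(s, f"{cosine_lr(s, total_steps=1000):.2e}")
```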

## Generation Parameters

| Parameter | Recommended | Description |
|-----------|-------------|-------------|
| temperature | 0.3 | Lower = more deterministic; 0.3 is the setting used for the reported 30% win rate |
| top_k | 10 | Sample only from the 10 most likely tokens |
| grammar_constrained | True | MUST be True to generate valid programs |
| max_length | 12 | Maximum program length (tokens) |
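
These settings combine three standard decoding steps: drop grammar-forbidden tokens, divide logits by the temperature, and sample from the `top_k` highest-scoring survivors. The pure-Python sketch below shows one such step; it illustrates the technique and is not the code inside `model.generate`. `sample_next_token` and the `allowed` set are hypothetical.

```python
import math
import random

def sample_next_token(logits, allowed, temperature=0.3, top_k=10, rng=random):
    """Sample one token ID from `logits` (a list of floats indexed by token ID)."""
    # 1. Grammar mask: keep only tokens allowed at this position, temperature-scaled
    candidates = [(tok, logits[tok] / temperature) for tok in allowed]
    if not candidates:
        raise ValueError("no token is allowed at this position")
    # 2. Top-k: keep the k highest scaled logits
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    candidates = candidates[:top_k]
    # 3. Softmax over the survivors (subtract the max for numerical stability)
    max_score = candidates[0][1]
    weights = [math.exp(score - max_score) for _, score in candidates]
    tokens = [tok for tok, _ in candidates]
    return rng.choices(tokens, weights=weights, k=1)[0]

logits = [0.0] * 113
logits[2], logits[3], logits[110] = 2.0, 1.5, 3.0
print(sample_next_token(logits, allowed={0, 1, 2, 3, 112}))  # LOOP (110) is masked out
```

At temperature 0.3 the 0.5 logit gap between tokens 2 and 3 becomes about 1.67 after scaling, so in this example token 2 is drawn roughly 84% of the time.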

## File Structure

```
hardai_model_export/
    model_best.pt              # Model checkpoint (886MB)
    README.md                  # This file
    lightweight_simulator.py   # Game simulator
    model/                     # Model architecture
        __init__.py
        model_2B.py            # Main model class
        state_encoder.py
        program_embedding.py
        transformer.py         # Flash Attention + gradient checkpointing
        multi_task_head.py
        memory_encoder.py
        memory_state_fusion.py
        value_predictor.py
```
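
The checkpoint size is consistent with the parameter count if the weights are stored as 32-bit floats (an assumption; the dtype is not stated here). The arithmetic:

```python
params = 232.2e6          # total parameters from the architecture table
bytes_per_param = 4       # assumed float32 storage
size_bytes = params * bytes_per_param
print(f"{size_bytes / 1e6:.1f} MB (decimal), {size_bytes / 2**20:.1f} MiB (binary)")
# 928.8 MB (decimal), 885.8 MiB (binary)
```

The listed 886MB therefore matches about 885.8 MiB of float32 weights.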

## Requirements

```
torch >= 2.0
numpy
pygame (required by the simulator; can run headless with SDL_VIDEODRIVER=dummy)
```

## Headless Mode (No Display)

```python
import os
os.environ['SDL_VIDEODRIVER'] = 'dummy'
os.environ['SDL_AUDIODRIVER'] = 'dummy'
```

Set these environment variables BEFORE importing the simulator.
|