# Mouse AI - Program Generation Model

A 232M-parameter transformer that generates movement programs for a mouse navigating a maze to collect cheese while avoiding cats.

## Quick Start

```python
import torch
from model.model_2B import StructureAwareTransformer2B
from lightweight_simulator import LightweightGameSimulator

# Load model
device = 'cuda:0'  # or 'cpu'
ckpt = torch.load('model_best.pt', map_location='cpu', weights_only=False)
config = ckpt['model_config']
model = StructureAwareTransformer2B(**config)
model.load_state_dict(ckpt['model_state_dict'])
model = model.to(device)
model.eval()

# Play a game
game = LightweightGameSimulator(level=3)
game.reset()

for run in range(20):
    if game.win_sign or game.lose_sign:
        break

    # Get state vector (828 dimensions); get_state_vector() is defined
    # in the State Vector section below
    state = get_state_vector(game).unsqueeze(0).to(device)

    # Generate program
    with torch.no_grad():
        prog = model.generate(
            state,
            max_length=12,
            temperature=0.3,
            top_k=10,
            grammar_constrained=True,
        )

    # Parse output
    if isinstance(prog, tuple):
        prog = prog[0]
    if isinstance(prog, torch.Tensor):
        prog = prog[0].tolist()
    # Remove the start token. Token ID 0 is also UP, so this step relies on
    # generate() always prepending 0 as a start token.
    if prog and prog[0] == 0:
        prog = prog[1:]
    if 112 in prog:
        prog = prog[:prog.index(112)]  # remove END and everything after it

    # Execute
    game.execute_program(prog)

print(f"{'WIN' if game.win_sign else 'LOSE'} | Score: {game.score}")
```

## Model Architecture

| Parameter | Value |
|-----------|-------|
| Type | StructureAwareTransformer2B |
| Total Parameters | 232.2M |
| Hidden Dimension | 1024 |
| Layers | 16 |
| Attention Heads | 16 (Query) / 4 (KV, Grouped Query Attention) |
| Feed-Forward Dim | 4096 |
| State Input | 828 dimensions |
| Vocab Size | 113 tokens |
| Max Program Length | 12 tokens |

### Model Config (for initialization)

```python
from model.model_2B import StructureAwareTransformer2B

config = {
    'state_dim': 828,
    'hidden_dim': 1024,
    'vocab_size': 113,
    'max_program_length': 12,
    'num_layers': 16,
    'num_heads': 16,
    'num_kv_heads': 4,
    'ff_dim': 4096,
    'dropout': 0.1,
    'end_token': 112,
}
model = StructureAwareTransformer2B(**config)
```

## Token Vocabulary (113 tokens)

### Direction Tokens (0-3)

| Token ID | Direction | Movement |
|----------|-----------|----------|
| 0 | UP | Mouse moves up one cell |
| 1 | DOWN | Mouse moves down one cell |
| 2 | LEFT | Mouse moves left one cell |
| 3 | RIGHT | Mouse moves right one cell |

### Number Tokens (100-109)

| Token ID | Value | Usage |
|----------|-------|-------|
| 100 | 1 | LOOP repeat count (1 time) |
| 104 | 5 | LOOP repeat count (5 times) |
| 105 | 6 | LOOP repeat count (6 times) |
| 106 | 7 | LOOP repeat count (7 times) |
| 107 | 8 | LOOP repeat count (8 times) |
| 108 | 9 | LOOP repeat count (9 times) |
| 109 | 10 | LOOP repeat count (10 times) |

Note: Tokens 101-103 (values 2-4) exist in the vocabulary but are NOT used by the grammar. The model only generates NUM tokens >= 104 (5+ repeats) for efficiency, so under grammar-constrained decoding token 100 is never generated either. In general, a NUM token's value is its ID minus 99.

### Special Tokens

| Token ID | Name | Function |
|----------|------|----------|
| 110 | LOOP | Start a loop structure |
| 112 | END | End of program |

Token 111 (IF) was removed due to simulator incompatibility.

## Grammar Rules

Programs follow a strict context-free grammar:

```
start      -> DIR | LOOP NUM DIR | END
after_DIR  -> DIR | LOOP NUM DIR | END
after_LOOP -> NUM  (must be 104-109)
after_NUM  -> DIR  (must be 0-3)
after_END  -> (stop generation)
```

### Valid Program Examples

```
[0, 112]                          # Move UP, END
[2, 2, 2, 112]                    # Move LEFT 3 times, END
[110, 106, 1, 112]                # LOOP(7 times, DOWN), END
[0, 110, 104, 2, 3, 112]          # UP, LOOP(5 times, LEFT), RIGHT, END
[110, 108, 0, 110, 105, 3, 112]   # LOOP(9, UP), LOOP(6, RIGHT), END
```

### Grammar Constraint: LOOP cutoff at position 8

The LOOP token (110) is only allowed at zero-based indices 0-7 of the generated sequence. From index 8 onward no new LOOP can start, so only DIR tokens and END are allowed; note that the grammar still requires a LOOP opened at index 7 to be completed by its NUM and DIR. This prevents overly long programs.
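The grammar and the LOOP cutoff can be checked mechanically. The block below is a minimal, hypothetical sketch that restates the rules above in plain Python; it is not code from this repository, and this README does not show how `model.generate` implements `grammar_constrained=True`. The names `allowed_next`, `is_valid`, and `expand` are illustrative. Two assumptions are labeled in the code: the 12-token limit is taken to include END, and a LOOP opened at index 7 may take its NUM and DIR at indices 8-9. `expand` converts a program into single-cell moves using "value = token ID - 99" from the Number Tokens table.

```python
# Hypothetical helpers (not part of the repository) that mirror the
# documented grammar, for checking and expanding programs offline.
DIRS = {0, 1, 2, 3}           # UP, DOWN, LEFT, RIGHT
NUMS = set(range(104, 110))   # repeat counts 5-10, the only NUMs the grammar allows
LOOP, END = 110, 112
LOOP_CUTOFF = 8               # no new LOOP may start at index >= 8
MAX_LEN = 12                  # assumption: the 12-token limit includes END


def allowed_next(prefix):
    """Return the set of token IDs the grammar allows after `prefix`."""
    if prefix and prefix[-1] == END:
        return set()              # after_END: stop generation
    if prefix and prefix[-1] == LOOP:
        return set(NUMS)          # after_LOOP -> NUM (104-109)
    if prefix and prefix[-1] in NUMS:
        return set(DIRS)          # after_NUM -> DIR (0-3); not limited by the cutoff
    allowed = DIRS | {END}        # start / after_DIR
    if len(prefix) < LOOP_CUTOFF:
        allowed.add(LOOP)         # LOOP only at indices 0-7
    return allowed


def is_valid(program):
    """True if every token is allowed by the grammar and the program ends with END."""
    if not program or len(program) > MAX_LEN or program[-1] != END:
        return False
    return all(tok in allowed_next(program[:i]) for i, tok in enumerate(program))


def expand(program):
    """Expand a valid program into a flat list of single-cell moves (direction IDs)."""
    moves, i = [], 0
    while i < len(program) and program[i] != END:
        if program[i] == LOOP:
            count = program[i + 1] - 99           # NUM token ID -> repeat count
            moves.extend([program[i + 2]] * count)
            i += 3                                # skip LOOP, NUM, DIR
        else:
            moves.append(program[i])
            i += 1
    return moves


print(is_valid([0, 110, 104, 2, 3, 112]))   # True
print(is_valid([110, 101, 1, 112]))         # False: 101 is not an allowed NUM
print(expand([110, 106, 1, 112]))           # [1, 1, 1, 1, 1, 1, 1]
```

Checking each token against `allowed_next` of its prefix is the same per-step check a grammar-constrained decoder performs, so one function serves both validation and decoding-time masking.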
## State Vector (828 dimensions)

The 828-dimensional state vector encodes the complete game state:

```python
import torch


def get_state_vector(sim):
    """Extract the 828-dim state vector from a game simulator."""
    state_dict = sim.get_state_dict()
    state = []
    DYNAMIC_SCALE = 10.0  # Scale factor for dynamic features

    # --- Grid features (11x11 grids) ---
    # 1. Wall grid (121 dims): 1=wall, 0=empty
    for row in state_dict['wall']:
        state.extend(row)

    # 2. Small Cheese grid (121 dims): 1=cheese present, 0=collected
    #    Scaled by DYNAMIC_SCALE (10.0)
    for row in state_dict['sc']:
        state.extend([v * DYNAMIC_SCALE for v in row])

    # 3. Junction grid (121 dims): 1=junction, 0=not
    for row in state_dict['junc']:
        state.extend(row)

    # 4. Dead-end grid (121 dims): 1=dead-end, 0=not
    for row in state_dict['deadend']:
        state.extend(row)
    # Total grid: 484 dims (4 * 121)

    # --- Entity positions ---
    # 5. Mouse position (2 dims): [x, y]
    mouse = state_dict['mouse']
    state.extend([float(mouse[0]), float(mouse[1])])

    # 6. Cat positions (12 dims): 6 cats * [x, y], unused=-1
    cat_list = state_dict.get('cat', [])
    for i in range(6):
        if i < len(cat_list):
            state.extend([float(cat_list[i][0]), float(cat_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # 7. Moving Big Cheese positions (10 dims): 5 * [x, y], unused=-1
    bc_list = state_dict.get('crzbc', [])
    for i in range(5):
        if i < len(bc_list):
            state.extend([float(bc_list[i][0]), float(bc_list[i][1])])
        else:
            state.extend([-1.0, -1.0])

    # Pad to 549 dims (484 + 65: 24 entity dims + 41 zeros)
    while len(state) < 484 + 65:
        state.append(0.0)

    # --- Scalar features (6 dims) ---
    # 8. Score (normalized by 1000, scaled)
    state.append(state_dict.get('score', 0) / 1000.0 * DYNAMIC_SCALE)
    # 9. Life (normalized by 3, scaled) - starts at 3
    state.append(state_dict.get('life', 3) * DYNAMIC_SCALE / 3.0)
    # 10. Current run number (normalized by 20, scaled)
    state.append(state_dict.get('run', 0) * DYNAMIC_SCALE / 20.0)
    # 11. Win flag (DYNAMIC_SCALE if won, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('win_sign', False) else 0.0)
    # 12. Lose flag (DYNAMIC_SCALE if lost, 0 otherwise)
    state.append(DYNAMIC_SCALE if state_dict.get('lose_sign', False) else 0.0)
    # 13. Step progress (current_step / step_limit, scaled)
    step = state_dict.get('step', 0)
    step_limit = state_dict.get('step_limit', 200)
    state.append(step / step_limit * DYNAMIC_SCALE if step_limit > 0 else 0.0)

    # Pad to 828 dims
    while len(state) < 828:
        state.append(0.0)

    return torch.tensor(state[:828], dtype=torch.float32)
```

### State Vector Layout Summary

| Range | Dims | Content | Scale |
|-------|------|---------|-------|
| 0-120 | 121 | Wall grid (11x11) | 1.0 |
| 121-241 | 121 | Small Cheese grid | 10.0 |
| 242-362 | 121 | Junction grid | 1.0 |
| 363-483 | 121 | Dead-end grid | 1.0 |
| 484-485 | 2 | Mouse position [x, y] | 1.0 |
| 486-497 | 12 | Cat positions (6 slots, unused = -1) | 1.0 |
| 498-507 | 10 | Moving Big Cheese (crzbc) positions (5 slots, unused = -1) | 1.0 |
| 508-548 | 41 | Padding (zeros) | - |
| 549 | 1 | Score / 1000 * 10 | 10.0 |
| 550 | 1 | Life / 3 * 10 | 10.0 |
| 551 | 1 | Run / 20 * 10 | 10.0 |
| 552 | 1 | Win flag | 10.0 |
| 553 | 1 | Lose flag | 10.0 |
| 554 | 1 | Step progress (step / step_limit * 10) | 10.0 |
| 555-827 | 273 | Padding (zeros) | - |

## Game Rules (Level 3)

### Map

- 11x11 grid maze with walls
- Fixed wall layout for level 3

### Entities

- **Mouse**: Player-controlled, starts at position [10, 10]
- **Cat 0 (Dummy)**: Starts at [2, 2]; moves only during command execution (len(command) steps)
- **Cat 1 (Naughty)**: Starts at [5, 5]; moves every mouse step
- **Small Cheese (SC)**: 75 stationary items, +10 points each
- **Stationary Big Cheese (movbc)**: 2 items, +500 points each, do not move
- **Moving Big Cheese (crzbc)**: 2 items, +500 points each, move every step

### Cat Movement (Random Mode)

Cats move randomly at junctions (no turning back), continue straight in corridors, and pick a random direction when blocked.
This is the `_get_cats_direct_actions` mode in the simulator.

### Scoring

| Event | Points |
|-------|--------|
| Collect Small Cheese | +10 |
| Collect Big Cheese | +500 |
| Hit Wall | -10 |
| Caught by Cat | -500 (and lose 1 life) |
| Win Bonus | +(run * 10 + step) |

### Win/Lose Conditions

- **WIN**: Collect all 75 Small Cheese and execute an END token
- **LOSE (life)**: Life reaches 0 (caught 3 times)
- **LOSE (step)**: Step count reaches 200
- **LOSE (run)**: 20 runs exhausted without winning

### Game Flow

1. The game starts with the mouse at [10, 10], 3 lives, and a maximum of 20 runs
2. Each run: the model generates a program, which then executes step by step
3. During execution: the mouse moves, cats move randomly, cheese is collected, and collisions are checked
4. After the program ends, the next run begins
5. Continue until WIN or LOSE

## Program Execution

When a program like `[0, 110, 106, 2, 3, 112]` executes:

1. Token `0` (UP): mouse moves up 1 step
2. Tokens `110, 106, 2` (LOOP 7 LEFT): mouse moves left 7 steps
3. Token `3` (RIGHT): mouse moves right 1 step
4. Token `112` (END): program ends

Each step:

- The mouse attempts to move in the given direction
- If it hits a wall: the mouse stays in place, -10 points
- Cat 1 moves (random at junctions)
- Cat 0 moves (only during command-length steps)
- Check for cat collision: -500 points, lose 1 life, mouse respawns at [10, 10]
- Check for cheese collection: +10 (SC) or +500 (BC)
- Check win/lose conditions

## Performance

| Metric | Value |
|--------|-------|
| Win Rate (temp=0.3, 100 games) | 30% |
| Average Score | 1437 |
| Average Runs per Win | 13.8 |
| Simulator | New simulator (random cats) |

### Training Pipeline

1. **Base Model**: Expert R1 checkpoint (trained on the old simulator; 95% win rate on the old sim, 14% on the new sim)
2. **RM32 Data Generation**: 10,000 games with Running Max 32 (exhaustive 33 candidates), 20.4% win rate, 30,788 winning run samples
3. **SFT Training**: 40 epochs, batch 4096, lr 3e-5, cosine schedule -> 30% win rate

## Generation Parameters

| Parameter | Recommended | Description |
|-----------|-------------|-------------|
| temperature | 0.3 | Lower = more deterministic, higher win rate |
| top_k | 10 | Top-k sampling |
| grammar_constrained | True | MUST be True to generate valid programs |
| max_length | 12 | Maximum program length |

## File Structure

```
hardai_model_export/
    model_best.pt              # Model checkpoint (886MB)
    README.md                  # This file
    lightweight_simulator.py   # Game simulator
    model/                     # Model architecture
        __init__.py
        model_2B.py            # Main model class
        state_encoder.py
        program_embedding.py
        transformer.py         # Flash Attention + gradient checkpointing
        multi_task_head.py
        memory_encoder.py
        memory_state_fusion.py
        value_predictor.py
```

## Requirements

```
torch >= 2.0
numpy
pygame (for simulator, can run headless with SDL_VIDEODRIVER=dummy)
```

## Headless Mode (No Display)

```python
import os

os.environ['SDL_VIDEODRIVER'] = 'dummy'
os.environ['SDL_AUDIODRIVER'] = 'dummy'
```

Set these BEFORE importing the simulator, so that pygame picks them up.
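The Generation Parameters table recommends `temperature=0.3`, `top_k=10`, and `grammar_constrained=True`. This README does not show how `model.generate` applies them, so the block below is only a minimal sketch of one common way to combine the three for a single decoding step: drop tokens the grammar disallows, keep the `top_k` highest logits, divide by the temperature, and sample from the softmax. `sample_next_token` and its arguments are hypothetical names, and a plain Python list of 113 logits stands in for the model's output so the sketch runs without the checkpoint.

```python
import math
import random


def sample_next_token(logits, allowed, temperature=0.3, top_k=10, rng=random):
    """Sample one token ID from `logits` (a list indexed by token ID).

    Hypothetical sketch: grammar mask -> top-k -> temperature -> softmax sample.
    """
    # 1. Grammar mask: keep only the token IDs the grammar allows at this step.
    candidates = [(logits[t], t) for t in allowed]
    if not candidates:
        raise ValueError("grammar allows no token at this step")

    # 2. Top-k: keep the k highest-scoring allowed tokens.
    candidates.sort(reverse=True)
    candidates = candidates[:top_k]

    # 3. Temperature: dividing by a value below 1 sharpens the distribution.
    scaled = [score / temperature for score, _ in candidates]

    # 4. Softmax (subtract the max for numerical stability), then sample.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    ids = [t for _, t in candidates]
    return rng.choices(ids, weights=weights, k=1)[0]


# Usage: LOOP (110) has the highest logit, but the mask below excludes it.
logits = [0.0] * 113
logits[2] = 5.0     # LEFT
logits[110] = 9.0   # LOOP
token = sample_next_token(logits, allowed={0, 1, 2, 3, 112}, rng=random.Random(0))
```

Because temperature only rescales the logits by a positive factor, applying it before or after the top-k cut selects the same candidates. A low value such as 0.3 concentrates probability on the best allowed token, which is consistent with the table's note that lower temperatures are more deterministic.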