Commit 7a208d8 · koreashin · verified
1 Parent(s): a98591d

Upload 2 files

Files changed (2)
  1. README.md +359 -3
  2. model_best.pt +3 -0
README.md CHANGED
@@ -1,3 +1,359 @@
- ---
- license: apache-2.0
- ---
+ # Mouse AI - Program Generation Model
+
+ A 232M-parameter transformer that generates movement programs for a mouse that navigates a maze, collecting cheese while avoiding cats.
+
+ ## Quick Start
+
+ ```python
+ import torch
+ from model.model_2B import StructureAwareTransformer2B
+ from lightweight_simulator import LightweightGameSimulator
+
+ # Load model
+ device = 'cuda:0'  # or 'cpu'
+ ckpt = torch.load('model_best.pt', map_location='cpu', weights_only=False)
+ config = ckpt['model_config']
+
+ model = StructureAwareTransformer2B(**config)
+ model.load_state_dict(ckpt['model_state_dict'])
+ model = model.to(device)
+ model.eval()
+
+ # Play a game
+ game = LightweightGameSimulator(level=3)
+ game.reset()
+
+ for run in range(20):
+     if game.win_sign or game.lose_sign:
+         break
+
+     # Get state vector (828 dimensions); get_state_vector is defined in the "State Vector" section
+     state = get_state_vector(game).unsqueeze(0).to(device)
+
+     # Generate program
+     with torch.no_grad():
+         prog = model.generate(
+             state, max_length=12, temperature=0.3,
+             top_k=10, grammar_constrained=True
+         )
+
+     # Parse output
+     if isinstance(prog, tuple): prog = prog[0]
+     if isinstance(prog, torch.Tensor): prog = prog[0].tolist()
+     if prog and prog[0] == 0: prog = prog[1:]  # remove start token (note: ID 0 is also UP in the vocabulary)
+     if 112 in prog: prog = prog[:prog.index(112)]  # remove END and after
+
+     # Execute
+     game.execute_program(prog)
+
+ print(f"{'WIN' if game.win_sign else 'LOSE'} | Score: {game.score}")
+ ```
+
+ ## Model Architecture
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Type | StructureAwareTransformer2B |
+ | Total Parameters | 232.2M |
+ | Hidden Dimension | 1024 |
+ | Layers | 16 |
+ | Attention Heads | 16 (Query) / 4 (KV, Grouped Query Attention) |
+ | Feed-Forward Dim | 4096 |
+ | State Input | 828 dimensions |
+ | Vocab Size | 113 tokens |
+ | Max Program Length | 12 tokens |
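
As a rough consistency check on the parameter count, the Git LFS pointer size for `model_best.pt` can be divided by four bytes per parameter. This is only a sketch: it assumes the checkpoint stores weights as 32-bit floats and holds little besides the weights, neither of which this README states.

```python
# Size in bytes, taken from the Git LFS pointer for model_best.pt in this commit.
CHECKPOINT_BYTES = 928_976_104
BYTES_PER_FP32 = 4  # assumption: weights are stored as float32

approx_params_m = CHECKPOINT_BYTES / BYTES_PER_FP32 / 1e6  # millions of parameters
size_mib = CHECKPOINT_BYTES / 2**20                        # size in MiB

print(f"~{approx_params_m:.1f}M parameters, {size_mib:.0f} MiB on disk")
```

Both figures line up with the 232.2M total in the table and the 886MB listed under File Structure, which suggests, but does not prove, float32 storage.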
+
+ ### Model Config (for initialization)
+ ```python
+ config = {
+     'state_dim': 828,
+     'hidden_dim': 1024,
+     'vocab_size': 113,
+     'max_program_length': 12,
+     'num_layers': 16,
+     'num_heads': 16,
+     'num_kv_heads': 4,
+     'ff_dim': 4096,
+     'dropout': 0.1,
+     'end_token': 112,
+ }
+ model = StructureAwareTransformer2B(**config)
+ ```
+
+ ## Token Vocabulary (113 tokens)
+
+ ### Direction Tokens (0-3)
+ | Token ID | Direction | Movement |
+ |----------|-----------|----------|
+ | 0 | UP | Mouse moves up one cell |
+ | 1 | DOWN | Mouse moves down one cell |
+ | 2 | LEFT | Mouse moves left one cell |
+ | 3 | RIGHT | Mouse moves right one cell |
+
+ ### Number Tokens (100-109)
+ | Token ID | Value | Usage |
+ |----------|-------|-------|
+ | 100 | 1 | LOOP repeat count (1 time) |
+ | 104 | 5 | LOOP repeat count (5 times) |
+ | 105 | 6 | LOOP repeat count (6 times) |
+ | 106 | 7 | LOOP repeat count (7 times) |
+ | 107 | 8 | LOOP repeat count (8 times) |
+ | 108 | 9 | LOOP repeat count (9 times) |
+ | 109 | 10 | LOOP repeat count (10 times) |
+
+ Note: Tokens 101-103 (values 2-4) exist in the vocabulary but are NOT used by the grammar. For efficiency, the model only generates NUM tokens >= 104 (5+ repeats); under the grammar rules (NUM must be 104-109), this also excludes token 100.
+
+ ### Special Tokens
+ | Token ID | Name | Function |
+ |----------|------|----------|
+ | 110 | LOOP | Start a loop structure |
+ | 112 | END | End of program |
+
+ Token 111 (IF) was removed due to simulator incompatibility.
+
+ ## Grammar Rules
+
+ Programs follow a strict grammar, written here as the tokens allowed at the start and after each token type:
+
+ ```
+ start -> DIR | LOOP NUM DIR | END
+ after_DIR -> DIR | LOOP NUM DIR | END
+ after_LOOP -> NUM (must be 104-109)
+ after_NUM -> DIR (must be 0-3)
+ after_END -> (stop generation)
+ ```
+
+ ### Valid Program Examples
+ ```
+ [0, 112]                          # Move UP, END
+ [2, 2, 2, 112]                    # Move LEFT 3 times, END
+ [110, 106, 1, 112]                # LOOP(7 times, DOWN), END
+ [0, 110, 104, 2, 3, 112]          # UP, LOOP(5 times, LEFT), RIGHT, END
+ [110, 108, 0, 110, 105, 3, 112]   # LOOP(9, UP), LOOP(6, RIGHT), END
+ ```
+
+ ### Grammar Constraint: LOOP cutoff at position 8
+ The LOOP token (110) is allowed only at positions 0-7 (zero-based indices in the generated sequence). From position 8 onwards, only DIR tokens and END are allowed. This prevents overly long programs.
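
The transition rules and the LOOP cutoff can be checked with a small state machine. The sketch below is not the model's own constraint code, which this README does not show; `is_valid_program` is a hypothetical helper. It assumes that a program must end with END, that END counts toward the 12-token limit, and that a LOOP opened at position 7 may still be followed by its NUM and DIR tokens.

```python
DIRS = {0, 1, 2, 3}            # UP, DOWN, LEFT, RIGHT
NUMS = set(range(104, 110))    # NUM tokens the grammar allows (104-109)
LOOP, END = 110, 112
LOOP_CUTOFF = 8                # LOOP is rejected at index 8 and later


def is_valid_program(tokens, max_length=12):
    """Return True if `tokens` follows the transition rules listed above."""
    if not tokens or len(tokens) > max_length:
        return False
    expect = "any"  # "any" (start / after DIR), "num" (after LOOP), "dir" (after NUM)
    for i, tok in enumerate(tokens):
        if expect == "num":
            if tok not in NUMS:
                return False
            expect = "dir"
        elif expect == "dir":
            if tok not in DIRS:
                return False
            expect = "any"
        elif tok == END:
            # END is only valid as the final token (assumption).
            return i == len(tokens) - 1
        elif tok == LOOP:
            if i >= LOOP_CUTOFF:
                return False
            expect = "num"
        elif tok not in DIRS:
            return False
    return False  # ran out of tokens without reaching END


print(is_valid_program([0, 110, 104, 2, 3, 112]))   # a valid example from above
print(is_valid_program([110, 101, 0, 112]))         # NUM 101 is outside 104-109
```

All five programs under "Valid Program Examples" pass this check, while a LOOP with a NUM outside 104-109 or a LOOP at position 8 or later is rejected.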
+
+ ## State Vector (828 dimensions)
+
+ The 828-dimensional state vector encodes the complete game state:
+
+ ```python
+ import torch
+
+ def get_state_vector(sim):
+     """Extract 828-dim state vector from game simulator"""
+     state_dict = sim.get_state_dict()
+     state = []
+     DYNAMIC_SCALE = 10.0  # Scale factor for dynamic features
+
+     # --- Grid features (11x11 grids) ---
+     # 1. Wall grid (121 dims): 1=wall, 0=empty
+     for row in state_dict['wall']:
+         state.extend(row)
+
+     # 2. Small Cheese grid (121 dims): 1=cheese present, 0=collected; scaled by DYNAMIC_SCALE (10.0)
+     for row in state_dict['sc']:
+         state.extend([v * DYNAMIC_SCALE for v in row])
+
+     # 3. Junction grid (121 dims): 1=junction, 0=not
+     for row in state_dict['junc']:
+         state.extend(row)
+
+     # 4. Dead-end grid (121 dims): 1=dead-end, 0=not
+     for row in state_dict['deadend']:
+         state.extend(row)
+
+     # Total grid: 484 dims (4 * 121)
+
+     # --- Entity positions ---
+
+     # 5. Mouse position (2 dims): [x, y]
+     mouse = state_dict['mouse']
+     state.extend([float(mouse[0]), float(mouse[1])])
+
+     # 6. Cat positions (12 dims): 6 cats * [x, y], unused=-1
+     cat_list = state_dict.get('cat', [])
+     for i in range(6):
+         if i < len(cat_list):
+             state.extend([float(cat_list[i][0]), float(cat_list[i][1])])
+         else:
+             state.extend([-1.0, -1.0])
+
+     # 7. Moving Big Cheese positions (10 dims): 5 * [x, y], unused=-1
+     bc_list = state_dict.get('crzbc', [])
+     for i in range(5):
+         if i < len(bc_list):
+             state.extend([float(bc_list[i][0]), float(bc_list[i][1])])
+         else:
+             state.extend([-1.0, -1.0])
+
+     # Pad to 549 dims (484 + 65)
+     while len(state) < 484 + 65:
+         state.append(0.0)
+
+     # --- Scalar features (6 dims) ---
+
+     # 8. Score (normalized by 1000, scaled)
+     state.append(state_dict.get('score', 0) / 1000.0 * DYNAMIC_SCALE)
+
+     # 9. Life (normalized by 3, scaled) - starts at 3
+     state.append(state_dict.get('life', 3) * DYNAMIC_SCALE / 3.0)
+
+     # 10. Current run number (normalized by 20, scaled)
+     state.append(state_dict.get('run', 0) * DYNAMIC_SCALE / 20.0)
+
+     # 11. Win flag (DYNAMIC_SCALE if won, 0 otherwise)
+     state.append(DYNAMIC_SCALE if state_dict.get('win_sign', False) else 0.0)
+
+     # 12. Lose flag (DYNAMIC_SCALE if lost, 0 otherwise)
+     state.append(DYNAMIC_SCALE if state_dict.get('lose_sign', False) else 0.0)
+
+     # 13. Step progress (current_step / step_limit, scaled)
+     step = state_dict.get('step', 0)
+     step_limit = state_dict.get('step_limit', 200)
+     state.append(step / step_limit * DYNAMIC_SCALE if step_limit > 0 else 0.0)
+
+     # Pad to 828 dims
+     while len(state) < 828:
+         state.append(0.0)
+
+     return torch.tensor(state[:828], dtype=torch.float32)
+ ```
+
+ ### State Vector Layout Summary
+ | Range | Dims | Content | Scale |
+ |-------|------|---------|-------|
+ | 0-120 | 121 | Wall grid (11x11) | 1.0 |
+ | 121-241 | 121 | Small Cheese grid | 10.0 |
+ | 242-362 | 121 | Junction grid | 1.0 |
+ | 363-483 | 121 | Dead-end grid | 1.0 |
+ | 484-485 | 2 | Mouse position [x,y] | 1.0 |
+ | 486-497 | 12 | Cat positions (6 cats) | 1.0 |
+ | 498-507 | 10 | Big Cheese positions (5) | 1.0 |
+ | 508-548 | 41 | Padding (zeros) | - |
+ | 549 | 1 | Score / 1000 * 10 | 10.0 |
+ | 550 | 1 | Life / 3 * 10 | 10.0 |
+ | 551 | 1 | Run / 20 * 10 | 10.0 |
+ | 552 | 1 | Win flag | 10.0 |
+ | 553 | 1 | Lose flag | 10.0 |
+ | 554 | 1 | Step progress | 10.0 |
+ | 555-827 | 273 | Padding (zeros) | - |
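
For debugging, the layout table can be applied in reverse to split a state vector back into named fields. This is an illustrative sketch, not part of the released code: `LAYOUT`, `to_grid`, `xy_pairs`, and `decode_state` are hypothetical names, and the input is assumed to be a plain list of 828 floats (for example `get_state_vector(game).tolist()`).

```python
DYNAMIC_SCALE = 10.0  # same scale factor as in get_state_vector

# Half-open index ranges [start, end), copied from the layout table.
LAYOUT = {
    "wall": (0, 121),
    "small_cheese": (121, 242),
    "junction": (242, 363),
    "dead_end": (363, 484),
    "mouse": (484, 486),
    "cats": (486, 498),
    "big_cheese": (498, 508),
    "scalars": (549, 555),
}


def to_grid(flat, size=11):
    """Reshape a flat list of size*size values into a list of rows."""
    return [flat[r * size:(r + 1) * size] for r in range(size)]


def xy_pairs(flat):
    """Group values into (x, y) pairs, dropping unused slots marked with -1."""
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2) if flat[i] >= 0]


def decode_state(state):
    """Split a flat 828-value state into named fields and undo the scaling."""
    if len(state) != 828:
        raise ValueError("expected 828 values")
    part = {name: state[start:end] for name, (start, end) in LAYOUT.items()}
    score, life, run, win, lose, progress = part["scalars"]
    return {
        "wall": to_grid(part["wall"]),
        "small_cheese": to_grid([v / DYNAMIC_SCALE for v in part["small_cheese"]]),
        "junction": to_grid(part["junction"]),
        "dead_end": to_grid(part["dead_end"]),
        "mouse": tuple(part["mouse"]),
        "cats": xy_pairs(part["cats"]),
        "big_cheese": xy_pairs(part["big_cheese"]),
        "score": score / DYNAMIC_SCALE * 1000.0,
        "life": life / DYNAMIC_SCALE * 3.0,
        "run": run / DYNAMIC_SCALE * 20.0,
        "won": win > 0,
        "lost": lose > 0,
        "step_progress": progress / DYNAMIC_SCALE,
    }
```

Because the values pass through floating-point scaling, decoded scalars such as `life` may differ from the original integers by a tiny rounding error.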
+
+ ## Game Rules (Level 3)
+
+ ### Map
+ - 11x11 grid maze with walls
+ - Fixed wall layout for level 3
+
+ ### Entities
+ - **Mouse**: Player-controlled, starts at position [10, 10]
+ - **Cat 0 (Dummy)**: Starts at [2, 2], moves only during command execution (len(command) steps)
+ - **Cat 1 (Naughty)**: Starts at [5, 5], moves every mouse step
+ - **Small Cheese (SC)**: 75 stationary items, +10 points each
+ - **Stationary Big Cheese (movbc)**: 2 items, +500 points each, do not move
+ - **Moving Big Cheese (crzbc)**: 2 items, +500 points each, move each step
+
+ ### Cat Movement (Random Mode)
+ Cats move randomly at junctions (no turning back), continue straight in corridors, and pick a random direction when blocked. This is the `_get_cats_direct_actions` mode in the simulator.
+
+ ### Scoring
+ | Event | Points |
+ |-------|--------|
+ | Collect Small Cheese | +10 |
+ | Collect Big Cheese | +500 |
+ | Hit Wall | -10 |
+ | Caught by Cat | -500 (+ lose 1 life) |
+ | Win Bonus | +(run * 10 + step) |
+
+ ### Win/Lose Conditions
+ - **WIN**: All 75 Small Cheese collected and the END token executed
+ - **LOSE (life)**: Life reaches 0 (caught 3 times)
+ - **LOSE (step)**: Step count reaches 200
+ - **LOSE (run)**: 20 runs exhausted without winning
+
+ ### Game Flow
+ 1. Game starts with mouse at [10,10], 3 lives, 20 max runs
+ 2. Each run: model generates a program -> program executes step by step
+ 3. During execution: mouse moves, cats move randomly, cheese collected, collisions checked
+ 4. After program ends: next run begins
+ 5. Continue until WIN or LOSE
+
+ ## Program Execution
+
+ When a program like `[0, 110, 106, 2, 3, 112]` executes:
+
+ 1. Token `0` (UP): mouse moves up 1 step
+ 2. Tokens `110, 106, 2` (LOOP 7 LEFT): mouse moves left 7 steps
+ 3. Token `3` (RIGHT): mouse moves right 1 step
+ 4. Token `112` (END): program ends
+
+ Each step:
+ - Mouse attempts to move in the direction
+ - If wall: mouse stays, -10 points
+ - Cat 1 moves (random at junctions)
+ - Cat 0 moves (only during command-length steps)
+ - Check for cat collision: -500 points, lose 1 life, respawn at [10,10]
+ - Check for cheese collection: +10 (SC) or +500 (BC)
+ - Check win/lose conditions
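
The token-to-step expansion described above can be sketched as a small helper that flattens a program into its individual moves. It is illustrative only and separate from the simulator's `execute_program`; `expand_program` and `loop_count` are hypothetical names, and the repeat count is taken as `token - 99`, which matches the number-token table (100 -> 1, 109 -> 10).

```python
DIR_NAMES = {0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}
LOOP, END = 110, 112


def loop_count(num_token):
    """Map a number token (100-109) to its repeat count (1-10)."""
    return num_token - 99


def expand_program(tokens):
    """Flatten a token program into the list of single-cell moves it performs."""
    steps = []
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok == END:
            break
        if tok == LOOP:
            if i + 2 >= len(tokens):
                raise ValueError("LOOP must be followed by NUM and DIR")
            count = loop_count(tokens[i + 1])
            steps.extend([DIR_NAMES[tokens[i + 2]]] * count)
            i += 3
        else:
            steps.append(DIR_NAMES[tok])
            i += 1
    return steps


moves = expand_program([0, 110, 106, 2, 3, 112])
print(len(moves), moves)
```

For the example program this yields 9 moves: one UP, seven LEFT, and one RIGHT.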
+
+ ## Performance
+
+ | Metric | Value |
+ |--------|-------|
+ | Win Rate (temp=0.3, 100 games) | 30% |
+ | Average Score | 1437 |
+ | Average Runs per Win | 13.8 |
+ | Simulator | New simulator (random cats) |
+
+ ### Training Pipeline
+ 1. **Base Model**: Expert R1 checkpoint (trained on the old simulator; 95% win rate on the old simulator, 14% on the new one)
+ 2. **RM32 Data Generation**: 10,000 games with Running Max 32 (exhaustive 33 candidates), 20.4% win rate, 30,788 winning run samples
+ 3. **SFT Training**: 40 epochs, batch 4096, lr 3e-5, cosine schedule -> 30% win rate
+
+ ## Generation Parameters
+
+ | Parameter | Recommended | Description |
+ |-----------|-------------|-------------|
+ | temperature | 0.3 | Lower = more deterministic, higher win rate |
+ | top_k | 10 | Top-k sampling |
+ | grammar_constrained | True | MUST be True to generate valid programs |
+ | max_length | 12 | Maximum program length |
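
To illustrate what `temperature` and `top_k` control, here is a generic sketch of temperature-scaled top-k sampling in plain Python. It is not the model's `generate` implementation, which this README does not show; `sample_top_k` and its arguments are hypothetical names, and grammar masking is left out.

```python
import math
import random


def sample_top_k(logits, temperature=0.3, top_k=10, rng=random):
    """Sample one index from `logits` using temperature scaling and top-k filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    k = min(top_k, len(logits))
    # Keep only the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Dividing by a temperature below 1 sharpens the distribution.
    scaled = [logits[i] / temperature for i in top]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [0.1, 2.0, -1.0, 0.5]
print(sample_top_k(logits, top_k=1))  # with top_k=1 this is always the argmax
```

With `top_k=1` the function always returns the highest-scoring index; a lower temperature moves sampling toward that same greedy choice.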
+
+ ## File Structure
+
+ ```
+ hardai_model_export/
+     model_best.pt              # Model checkpoint (886MB)
+     README.md                  # This file
+     lightweight_simulator.py   # Game simulator
+     model/                     # Model architecture
+         __init__.py
+         model_2B.py            # Main model class
+         state_encoder.py
+         program_embedding.py
+         transformer.py         # Flash Attention + gradient checkpointing
+         multi_task_head.py
+         memory_encoder.py
+         memory_state_fusion.py
+         value_predictor.py
+ ```
+
+ ## Requirements
+
+ ```
+ torch >= 2.0
+ numpy
+ pygame (for simulator, can run headless with SDL_VIDEODRIVER=dummy)
+ ```
+
+ ## Headless Mode (No Display)
+
+ ```python
+ import os
+ os.environ['SDL_VIDEODRIVER'] = 'dummy'
+ os.environ['SDL_AUDIODRIVER'] = 'dummy'
+ ```
+
+ Set these BEFORE importing the simulator.
model_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e004bee68626f67cfb62d95426ece702b1d7ed0a9acb1f43d8a0ea92160c289
+ size 928976104