wtfmahe commited on
Commit
d12141a
·
verified ·
1 Parent(s): f11ed71

Upload 34 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/TRM_fig.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/TRM_pseudocode.png filter=lfs diff=lfs merge=lfs -text
assets/TRM_fig.png ADDED

Git LFS Details

  • SHA256: 3cc0ac2a6eeca5af89d03ae678d121b4a3467ba7e34f0ab1b1c532bbbf4689da
  • Pointer size: 131 Bytes
  • Size of remote file: 354 kB
assets/TRM_pseudocode.png ADDED

Git LFS Details

  • SHA256: cab417c86f074113f40dbd16628b0354adafc8a0db1fc5d489dd427f11927659
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
config/arch/hrm.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
2
+ loss:
3
+ name: losses@ACTLossHead
4
+ loss_type: stablemax_cross_entropy
5
+
6
+ halt_exploration_prob: 0.1
7
+ halt_max_steps: 16
8
+
9
+ H_cycles: 2
10
+ L_cycles: 2
11
+
12
+ H_layers: 4
13
+ L_layers: 4
14
+
15
+ hidden_size: 512
16
+ num_heads: 8 # max(2, hidden_size // 64)
17
+ expansion: 4
18
+
19
+ puzzle_emb_ndim: ${.hidden_size}
20
+
21
+ pos_encodings: rope
22
+ forward_dtype: bfloat16
23
+
24
+ mlp_t: False # use mlp on L instead of transformer
config/arch/transformers_baseline.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: recursive_reasoning.transformers_baseline@Model_ACTV2
2
+ loss:
3
+ name: losses@ACTLossHead
4
+ loss_type: stablemax_cross_entropy
5
+
6
+ halt_exploration_prob: 0.1
7
+ halt_max_steps: 16
8
+
9
+ H_cycles: 1 # kept for compatibility
10
+ H_layers: 8
11
+
12
+ hidden_size: 512
13
+ num_heads: 12
14
+ expansion: 4
15
+
16
+ puzzle_emb_ndim: ${.hidden_size}
17
+
18
+ pos_encodings: rope
config/arch/trm.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
2
+ loss:
3
+ name: losses@ACTLossHead
4
+ loss_type: stablemax_cross_entropy
5
+
6
+ halt_exploration_prob: 0.1
7
+ halt_max_steps: 16
8
+
9
+ H_cycles: 3
10
+ L_cycles: 6
11
+
12
+ H_layers: 0
13
+ L_layers: 2
14
+
15
+ hidden_size: 512
16
+ num_heads: 8 # max(2, hidden_size // 64)
17
+ expansion: 4
18
+
19
+ puzzle_emb_ndim: ${.hidden_size}
20
+
21
+ pos_encodings: rope
22
+ forward_dtype: bfloat16
23
+
24
+ mlp_t: False # use mlp on L instead of transformer
25
+ puzzle_emb_len: 16 # if non-zero, it's set to this value
26
+ no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
config/arch/trm_hier6.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: recursive_reasoning.trm_hier6@TinyRecursiveReasoningModel_ACTV1
2
+ loss:
3
+ name: losses@ACTLossHead
4
+ loss_type: stablemax_cross_entropy
5
+
6
+ halt_exploration_prob: 0.1
7
+ halt_max_steps: 16
8
+
9
+ H_cycles: 3
10
+ L_cycles: 6
11
+
12
+ H_layers: 0
13
+ L_layers: 2
14
+
15
+ hidden_size: 512
16
+ num_heads: 8 # max(2, hidden_size // 64)
17
+ expansion: 4
18
+
19
+ puzzle_emb_ndim: ${.hidden_size}
20
+
21
+ pos_encodings: rope
22
+ forward_dtype: bfloat16
23
+
24
+ mlp_t: False # use mlp on L instead of transformer
25
+ puzzle_emb_len: 16 # if non-zero, it's set to this value
26
+ no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
config/arch/trm_singlez.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: recursive_reasoning.trm_singlez@TinyRecursiveReasoningModel_ACTV1
2
+ loss:
3
+ name: losses@ACTLossHead
4
+ loss_type: stablemax_cross_entropy
5
+
6
+ halt_exploration_prob: 0.1
7
+ halt_max_steps: 16
8
+
9
+ H_cycles: 3
10
+ L_cycles: 6
11
+
12
+ H_layers: 0
13
+ L_layers: 2
14
+
15
+ hidden_size: 512
16
+ num_heads: 8 # max(2, hidden_size // 64)
17
+ expansion: 4
18
+
19
+ puzzle_emb_ndim: ${.hidden_size}
20
+
21
+ pos_encodings: rope
22
+ forward_dtype: bfloat16
23
+
24
+ mlp_t: False # use mlp on L instead of transformer
25
+ puzzle_emb_len: 16 # if non-zero, it's set to this value
26
+ no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
config/cfg_pretrain.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ARC training config
2
+
3
+ defaults:
4
+ - arch: trm
5
+ - _self_
6
+
7
+ hydra:
8
+ output_subdir: null
9
+
10
+ # Data path
11
+ data_paths: ['data/arc-aug-1000']
12
+ data_paths_test: []
13
+
14
+ evaluators:
15
+ - name: arc@ARC
16
+
17
+ # Hyperparams - Training
18
+ global_batch_size: 768
19
+
20
+ epochs: 100000
21
+ eval_interval: 10000
22
+ checkpoint_every_eval: True
23
+
24
+ lr: 1e-4
25
+ lr_min_ratio: 1.0
26
+ lr_warmup_steps: 2000
27
+
28
+ # Standard hyperparameter settings for LM, as used in Llama
29
+ beta1: 0.9
30
+ beta2: 0.95
31
+ weight_decay: 0.1
32
+ puzzle_emb_weight_decay: 0.1
33
+
34
+ # Hyperparams - Puzzle embeddings training
35
+ puzzle_emb_lr: 1e-2
36
+
37
+ seed: 0
38
+ min_eval_interval: 0 # when to start the eval
39
+
40
+ ema: False # use Exponential-Moving-Average
41
+ ema_rate: 0.999 # EMA-rate
42
+ freeze_weights: False # If True, freeze weights and only learn the embeddings
dataset/build_arc_dataset.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Dict
2
+ from dataclasses import dataclass
3
+ import os
4
+ import json
5
+ import hashlib
6
+ import numpy as np
7
+
8
+ from argdantic import ArgParser
9
+ from pydantic import BaseModel
10
+
11
+ from dataset.common import PuzzleDatasetMetadata, dihedral_transform, inverse_dihedral_transform
12
+
13
+
14
cli = ArgParser()


class DataProcessConfig(BaseModel):
    """CLI options for converting raw ARC-AGI JSON files into the numpy training format."""
    input_file_prefix: str   # path prefix: reads "<prefix>_<subset>_challenges.json" (+ _solutions.json)
    output_dir: str          # destination directory for the converted dataset
    subsets: List[str]       # ARC subset names to load
    test_set_name: str       # subset whose puzzles are routed to the "test" split
    test_set_name2: str = "your_test_set"  # optional second subset routed to "test"
    seed: int = 42           # RNG seed (shuffling + augmentation)
    num_aug: int = 1000      # number of augmented copies per puzzle (0 disables augmentation)
    puzzle_identifiers_start: int = 1  # start > 1 to handle multiple datasets

# Maximum ARC grid side length; grids are padded onto a 30x30 canvas.
ARCMaxGridSize = 30
# Try at most this factor times num_aug random augmentations before giving up.
ARCAugmentRetriesFactor = 5

# Separator embedded in puzzle names to encode augmentation parameters.
PuzzleIdSeparator = "|||"
31
+
32
+
33
@dataclass
class ARCPuzzle:
    """A named ARC puzzle: a list of (input, output) example grid pairs."""
    id: str  # puzzle name, possibly carrying an augmentation suffix (see aug())
    examples: List[Tuple[np.ndarray, np.ndarray]]  # (input, label) uint8 grids
37
+
38
+
39
def arc_grid_to_np(grid: List[List[int]]):
    """Convert a nested-list ARC grid into a validated uint8 numpy array.

    Raises AssertionError if the grid is not 2-D, exceeds the maximum
    board size, or contains values outside the ARC color range 0..9.
    """
    cells = np.asarray(grid)

    # Structural checks: 2-D and within the maximum board size.
    assert cells.ndim == 2
    assert cells.shape[0] <= ARCMaxGridSize and cells.shape[1] <= ARCMaxGridSize
    # Every cell must be a valid ARC color.
    assert np.all((cells >= 0) & (cells <= 9))

    return cells.astype(np.uint8)
48
+
49
+
50
def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
    """Flatten an (input, output) grid pair into fixed-length token sequences.

    Token ids: PAD = 0, <eos> = 1, colors 0..9 -> 2..11.  Both grids are
    placed on a 30x30 canvas at the same (optionally random) top-left
    offset and flattened row-major.  Returns [input_seq, output_seq].
    """
    # PAD: 0, <eos>: 1, digits: 2 ... 11
    # Compute random top-left pad
    if do_translation:
        # The offset is shared by input and output, so it must leave room
        # for the larger of the two grids along each axis.
        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
    else:
        pad_r = pad_c = 0

    # Pad grid
    result = []
    for grid in [inp, out]:
        nrow, ncol = grid.shape
        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)

        # Add <eos> markers along the bottom row and right column just
        # outside the placed grid (skipped when the grid touches the edge).
        eos_row, eos_col = pad_r + nrow, pad_c + ncol
        if eos_row < ARCMaxGridSize:
            grid[eos_row, pad_c:eos_col] = 1
        if eos_col < ARCMaxGridSize:
            grid[pad_r:eos_row, eos_col] = 1

        result.append(grid.flatten())

    return result
75
+
76
+
77
def grid_hash(grid: np.ndarray):
    """Return a hex SHA-256 digest identifying a 2-D uint8 grid (shape + contents)."""
    assert grid.ndim == 2
    assert grid.dtype == np.uint8

    # Prefix the payload with the two shape bytes so grids with identical
    # raw bytes but different shapes hash differently.
    nrow, ncol = grid.shape
    payload = nrow.to_bytes(1, byteorder='big') + ncol.to_bytes(1, byteorder='big') + grid.tobytes()

    return hashlib.sha256(payload).hexdigest()
85
+
86
+
87
def puzzle_hash(puzzle: dict):
    """Order-independent hash of a puzzle's example pairs, for duplicate detection."""
    # Hash each (input, label) pair, then sort so the digest does not
    # depend on iteration order.
    pair_hashes = sorted(
        f"{grid_hash(inp)}|{grid_hash(lbl)}"
        for split in puzzle.values()
        for inp, lbl in split.examples
    )
    return hashlib.sha256("|".join(pair_hashes).encode()).hexdigest()
96
+
97
+
98
def aug(name: str):
    """Sample a random augmentation: a dihedral transform plus a color permutation.

    Returns (augmented_name, map_grid) where augmented_name encodes the
    sampled parameters (so inverse_aug can undo them) and map_grid applies
    them to a grid.
    """
    # Draw the transform id first, then the color permutation, keeping the
    # RNG stream order identical to the original implementation.
    trans_id = np.random.randint(0, 8)
    mapping = np.concatenate([
        np.arange(0, 1, dtype=np.uint8),                      # color 0 (black) stays fixed
        np.random.permutation(np.arange(1, 10, dtype=np.uint8)),  # permute colors 1..9
    ])

    perm_repr = ''.join(str(x) for x in mapping)
    name_with_aug_repr = PuzzleIdSeparator.join((name, f"t{trans_id}", perm_repr))

    def _map_grid(grid: np.ndarray):
        # Recolor first, then apply the spatial transform.
        return dihedral_transform(mapping[grid], trans_id)

    return name_with_aug_repr, _map_grid
109
+
110
+
111
def inverse_aug(name: str):
    """Undo aug(): return the base puzzle name and a function un-mapping grids."""
    if PuzzleIdSeparator not in name:
        # No augmentation suffix: identity mapping.
        return name, lambda x: x

    parts = name.split(PuzzleIdSeparator)
    trans_repr, perm_repr = parts[-2], parts[-1]
    trans_id = int(trans_repr[1:])  # strip the leading "t"
    # argsort of the digit characters equals argsort of the digits
    # (single characters sort lexicographically == numerically),
    # yielding the inverse color permutation.
    inv_perm = np.argsort(list(perm_repr)).astype(np.uint8)

    def _map_grid(grid: np.ndarray):
        # Undo the spatial transform first, then the recoloring.
        return inv_perm[inverse_dihedral_transform(grid, trans_id)]

    return parts[0], _map_grid
124
+
125
+
126
def convert_single_arc_puzzle(results: dict, name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
    """Convert one raw ARC puzzle (plus augmentations) and append it to `results`.

    dest_mapping routes each example type ("train"/"test") to a
    (split, set) destination.  results[split][set] is a list of groups;
    each group is a list of ARCPuzzle variants, the unaugmented one first.
    """
    # Convert
    dests = set(dest_mapping.values())
    converted = {dest: ARCPuzzle(name, []) for dest in dests}
    for example_type, examples in puzzle.items():
        # Map to target split
        dest = dest_mapping[example_type]
        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])

    group = [converted]

    # Augment
    if aug_count > 0:
        # Seed the dedup set with the unaugmented puzzle's hash.
        hashes = {puzzle_hash(converted)}

        for _trial in range(ARCAugmentRetriesFactor * aug_count):
            aug_name, _map_grid = aug(name)

            # Check duplicate
            # NOTE: `puzzle` in this comprehension shadows the function
            # parameter; it is the per-destination ARCPuzzle being augmented.
            augmented = {dest: ARCPuzzle(aug_name, [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
            h = puzzle_hash(augmented)
            if h not in hashes:
                hashes.add(h)
                group.append(augmented)

            if len(group) >= aug_count + 1:
                break

        if len(group) < aug_count + 1:
            print (f"[Puzzle {name}] augmentation not full, only {len(group)}")

    # Append
    for dest in dests:
        # Convert the examples
        dest_split, dest_set = dest

        results.setdefault(dest_split, {})
        results[dest_split].setdefault(dest_set, [])
        results[dest_split][dest_set].append([converted[dest] for converted in group])
165
+
166
+
167
def load_puzzles_arcagi(config: DataProcessConfig):
    """Load raw ARC-AGI puzzles and route them into train/test splits.

    Returns (results, test_puzzles): `results` is the nested
    split -> set -> groups structure filled by convert_single_arc_puzzle;
    `test_puzzles` maps names of puzzles routed to "test" to their raw JSON.
    """
    train_examples_dest = ("train", "all")
    # Per-subset routing: list of (fraction upper bound, destination).
    # Both configured test subsets go entirely to "test"; everything else
    # goes entirely to "train".
    test_examples_map = {
        config.test_set_name: [(1.0, ("test", "all"))],
        config.test_set_name2: [(1.0, ("test", "all"))],
        "_default": [(1.0, ("train", "all"))]
    }

    test_puzzles = {}
    results = {}

    total_puzzles = 0
    for subset_name in config.subsets:
        # Load all puzzles in this subset
        with open(f"{config.input_file_prefix}_{subset_name}_challenges.json", "r") as f:
            puzzles = json.load(f)

        # Merge solutions into the "test" examples when a solutions file exists.
        sols_filename = f"{config.input_file_prefix}_{subset_name}_solutions.json"
        if os.path.isfile(sols_filename):
            with open(sols_filename, "r") as f:
                sols = json.load(f)

            for puzzle_id in puzzles.keys():
                for idx, sol_grid in enumerate(sols[puzzle_id]):
                    puzzles[puzzle_id]["test"][idx]["output"] = sol_grid
        else:
            # Fill with dummy
            print (f"{subset_name} solutions not found, filling with dummy")

            for puzzle_id, puzzle in puzzles.items():
                for example in puzzle["test"]:
                    example.setdefault("output", [[0]])

        # Shuffle puzzles
        puzzles = list(puzzles.items())
        np.random.shuffle(puzzles)

        # Assign by fraction
        for idx, (name, puzzle) in enumerate(puzzles):
            fraction = idx / len(puzzles)
            test_examples_dest = None
            for f, dest in test_examples_map.get(subset_name, test_examples_map["_default"]):
                if fraction < f:
                    test_examples_dest = dest
                    break

            assert test_examples_dest is not None

            # Keep the raw JSON of test-split puzzles for later evaluation.
            if test_examples_dest[0] == "test":
                test_puzzles[name] = puzzle

            convert_single_arc_puzzle(results, name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
            total_puzzles += 1

    print (f"Total puzzles: {total_puzzles}")
    return results, test_puzzles
223
+
224
+
225
def convert_dataset(config: DataProcessConfig):
    """Convert raw ARC data into numpy training files under config.output_dir.

    Writes per split: <set>__inputs.npy, <set>__labels.npy, the index
    arrays, and dataset.json metadata; plus identifiers.json and
    test_puzzles.json at the dataset root.
    """
    np.random.seed(config.seed)

    # Read dataset
    data, test_puzzles = load_puzzles_arcagi(config)

    # Map global puzzle identifiers
    num_identifiers = config.puzzle_identifiers_start  # 0 is blank, start at 1
    identifier_map = {}
    for split_name, split in data.items():
        for subset_name, subset in split.items():
            for group in subset:
                for puzzle in group:
                    if puzzle.id not in identifier_map:
                        identifier_map[puzzle.id] = num_identifiers
                        num_identifiers += 1
    print (f"Total puzzle IDs (including <blank>): {num_identifiers}")

    # Save
    for split_name, split in data.items():
        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)

        # Translational augmentations (train split only)
        enable_translational_augment = split_name == "train"

        # Statistics
        total_examples = 0
        total_puzzles = 0
        total_groups = 0

        for subset_name, subset in split.items():  # "all" is the only subset
            # Construct subset
            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
            results["puzzle_indices"].append(0)
            results["group_indices"].append(0)

            example_id = 0
            puzzle_id = 0

            for group in subset:
                for puzzle in group:
                    # Push puzzle.  One randomly chosen example per puzzle
                    # keeps its canonical (untranslated) placement.
                    no_aug_id = np.random.randint(0, len(puzzle.examples))
                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)

                        results["inputs"].append(inp)
                        results["labels"].append(out)
                        example_id += 1

                        total_examples += 1

                    results["puzzle_indices"].append(example_id)
                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])

                    puzzle_id += 1
                    total_puzzles += 1

                # Push group
                results["group_indices"].append(puzzle_id)
                total_groups += 1

            for k, v in results.items():
                if k in {"inputs", "labels"}:
                    v = np.stack(v, 0)
                else:
                    v = np.array(v, dtype=np.int32)

                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)

        # Metadata
        metadata = PuzzleDatasetMetadata(
            seq_len=ARCMaxGridSize * ARCMaxGridSize,
            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"
            pad_id=0,
            ignore_label_id=0,
            blank_identifier_id=0,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=total_examples / total_puzzles,
            total_puzzles=total_puzzles,
            sets=list(split.keys())
        )

        # Save metadata as JSON.
        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
            json.dump(metadata.model_dump(), f)

    # Save IDs mapping (index -> puzzle name, "<blank>" for unused ids)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        ids_mapping = {v: k for k, v in identifier_map.items()}
        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)

    # Save Test Puzzles
    with open(os.path.join(config.output_dir, "test_puzzles.json"), "w") as f:
        json.dump(test_puzzles, f)
321
+
322
+
323
@cli.command(singleton=True)
def main(config: DataProcessConfig):
    """CLI entry point: convert the configured ARC subsets to numpy format."""
    convert_dataset(config)


if __name__ == "__main__":
    cli()
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
+
338
+
339
+
340
+
341
+
dataset/build_maze_dataset.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import math
3
+ import os
4
+ import csv
5
+ import json
6
+ import numpy as np
7
+
8
+ from argdantic import ArgParser
9
+ from pydantic import BaseModel
10
+ from tqdm import tqdm
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from common import PuzzleDatasetMetadata, dihedral_transform
14
+
15
+
16
# Maze cell characters (presumably: '#' wall, ' ' open, 'S' start, 'G' goal,
# 'o' solution path) — TODO confirm against the source dataset.
CHARSET = "# SGo"


cli = ArgParser()


class DataProcessConfig(BaseModel):
    """CLI options for converting the maze CSV dataset."""
    source_repo: str = "sapientinc/maze-30x30-hard-1k"  # Hugging Face dataset repo
    output_dir: str = "data/maze-30x30-hard-1k"         # destination directory

    subsample_size: Optional[int] = None  # optional cap on number of training examples
    aug: bool = False                     # apply the 8 dihedral transforms to the train split
28
+
29
+
30
def convert_subset(set_name: str, config: DataProcessConfig):
    """Convert one maze CSV split ("train"/"test") into numpy dataset files."""
    # Read CSV
    all_chars = set()
    grid_size = None
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:  # type: ignore
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            all_chars.update(q)
            all_chars.update(a)

            if grid_size is None:
                # Infer a square grid from the first row's length.
                n = int(len(q) ** 0.5)
                grid_size = (n, n)

            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for inp, out in zip(tqdm(inputs), labels):
        # Dihedral transformations for augmentation (train + aug only);
        # each transformed copy counts as its own puzzle within one group.
        for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
            results["inputs"].append(dihedral_transform(inp, aug_idx))
            results["labels"].append(dihedral_transform(out, aug_idx))
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # Char mappings: ensure the CSV only uses the expected alphabet.
    assert len(all_chars - set(CHARSET)) == 0

    # char byte -> 1-based token id (0 is reserved for PAD).
    char2id = np.zeros(256, np.uint8)
    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])

        return arr

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=int(math.prod(grid_size)),  # type: ignore
        vocab_size=len(CHARSET) + 1,  # PAD + Charset
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)
131
+
132
+
133
@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    """CLI entry point: convert both the train and test CSV splits."""
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
dataset/build_sudoku_dataset.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import os
3
+ import csv
4
+ import json
5
+ import numpy as np
6
+
7
+ from argdantic import ArgParser
8
+ from pydantic import BaseModel
9
+ from tqdm import tqdm
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ from common import PuzzleDatasetMetadata
13
+
14
+
15
cli = ArgParser()


class DataProcessConfig(BaseModel):
    """CLI options for converting the sudoku-extreme CSV dataset."""
    source_repo: str = "sapientinc/sudoku-extreme"  # Hugging Face dataset repo
    output_dir: str = "data/sudoku-extreme-full"    # destination directory

    subsample_size: Optional[int] = None  # optional cap on number of training examples
    min_difficulty: Optional[int] = None  # keep only rows whose rating >= this value
    num_aug: int = 0                      # augmented copies per training puzzle
25
+
26
+
27
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    """Apply one random validity-preserving Sudoku transformation to a pair.

    The transformation combines a digit relabeling (0/blank fixed), an
    optional transpose, and band/row plus stack/column permutations; the
    same transformation is applied to both `board` and `solution`.
    """
    # Relabel digits: random permutation of 1..9, with 0 (blank) mapped to itself.
    # RNG draws below happen in the same order as the original implementation.
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))

    # Optionally transpose the whole grid.
    do_transpose = np.random.rand() < 0.5

    # Permute the 3 row bands, and the 3 rows within each band.
    band_order = np.random.permutation(3)
    row_perm = np.concatenate([band * 3 + np.random.permutation(3) for band in band_order])

    # Likewise for column stacks.
    stack_order = np.random.permutation(3)
    col_perm = np.concatenate([stack * 3 + np.random.permutation(3) for stack in stack_order])

    # Flat 81 -> 81 index map: new cell i (row i//9, col i%9) reads the old
    # cell at (row_perm[i//9], col_perm[i%9]).
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def _transform(grid: np.ndarray) -> np.ndarray:
        source = grid.T if do_transpose else grid
        # Reposition cells, then relabel digits.
        repositioned = source.flatten()[mapping].reshape(9, 9).copy()
        return digit_map[repositioned]

    return _transform(board), _transform(solution)
58
+
59
+
60
def convert_subset(set_name: str, config: DataProcessConfig):
    """Convert one sudoku CSV split ("train"/"test") into numpy dataset files."""
    # Read CSV
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81

                # '.' marks a blank cell; map characters '0'..'9' to ints 0..9.
                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset (augmentation only for the train split)
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # First index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (only single example)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)

        assert np.all((arr >= 0) & (arr <= 9))
        # Shift by one so that token 0 can serve as PAD.
        return arr + 1

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)
158
+
159
+
160
@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    """CLI entry point: convert both the train and test CSV splits."""
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
dataset/common.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import pydantic
4
+ import numpy as np
5
+
6
+
7
# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
# (Rotations 1 and 3 swap; the 180° rotation and all reflections are
# their own inverses.)
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    """Metadata stored alongside each converted dataset split (dataset.json)."""
    pad_id: int                     # token id used for padding
    ignore_label_id: Optional[int]  # label id excluded from the loss, if any
    blank_identifier_id: int        # puzzle identifier reserved for "no puzzle"
    vocab_size: int                 # size of the token vocabulary (incl. PAD)
    seq_len: int                    # fixed flattened sequence length
    num_puzzle_identifiers: int     # number of distinct puzzle ids (incl. blank)
    total_groups: int               # number of puzzle groups in the split
    mean_puzzle_examples: float     # average number of examples per puzzle
    total_puzzles: int              # number of puzzles across all groups
    sets: List[str]                 # subset names present in the split
23
+
24
+
25
def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Apply one of the 8 dihedral symmetries of the square to a 2-D array.

    tid 0 is the identity, 1-3 are CCW rotations by 90/180/270 degrees,
    4-5 are horizontal/vertical flips, 6-7 are the two diagonal
    reflections.  Any other tid returns the array unchanged.
    """
    ops = (
        lambda a: a,                            # 0: identity
        lambda a: np.rot90(a, k=1),             # 1: rotate 90 CCW
        lambda a: np.rot90(a, k=2),             # 2: rotate 180
        lambda a: np.rot90(a, k=3),             # 3: rotate 270 CCW
        np.fliplr,                              # 4: horizontal flip
        np.flipud,                              # 5: vertical flip
        lambda a: a.T,                          # 6: main-diagonal reflection
        lambda a: np.fliplr(np.rot90(a, k=1)),  # 7: anti-diagonal reflection
    )
    # Out-of-range ids fall through to the identity, matching the original.
    if 0 <= tid < len(ops):
        return ops[tid](arr)
    return arr
46
+
47
+
48
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Undo dihedral_transform(arr, tid) by applying the inverse transform id."""
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
evaluators/arc.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Sequence, Optional
2
+ import os
3
+ import json
4
+
5
+ import torch
6
+ import numpy as np
7
+ from numba import njit
8
+ import torch.distributed as dist
9
+
10
+ from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
11
+ from dataset.common import PuzzleDatasetMetadata
12
+
13
@njit
def _crop(grid: np.ndarray):
    """Find maximum-sized rectangle without any EOS token inside. """
    # Token ids follow the ARC dataset builder: PAD = 0, <eos> = 1,
    # colors 0..9 -> 2..11.  The input is a flat length-900 sequence;
    # recover the color grid as the largest top-left rectangle that
    # contains only color tokens, then shift ids back to colors.
    grid = grid.reshape(30, 30)

    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    num_c = nc
    for num_r in range(1, nr + 1):
        # Scan for maximum c.  num_c only shrinks as rows are added, so the
        # candidate width is monotonically non-increasing.
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    # Map token ids 2..11 back to colors 0..9.
    return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)
37
+
38
+
39
+ class ARC:
40
+ required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}
41
+
42
    def __init__(self, data_path: str,
                 eval_metadata: PuzzleDatasetMetadata,
                 submission_K: int = 2,
                 pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
                 aggregated_voting: bool = True):
        """ARC evaluator.

        data_path: dataset root containing identifiers.json and
            test_puzzles.json written by the dataset builder.
        submission_K: presumably the number of top-voted predictions kept
            per test input — confirm in result() (not visible here).
        pass_Ks: K values used for pass@K-style metrics.
        aggregated_voting: if True, predictions accumulate across eval runs
            instead of being cleared by begin_eval().
        """
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id

        # Load identifiers and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
            self.test_puzzles = json.load(f)

        # States:
        # _local_hmap: pred_hash -> predicted grid
        # _local_preds: puzzle_name -> input_hash -> [(pred_hash, q_halt prob)]
        self._local_hmap = {}
        self._local_preds = {}
62
+
63
+ def begin_eval(self):
64
+ if not self.aggregated_voting:
65
+ # Clear previous predictions
66
+ self._local_hmap = {}
67
+ self._local_preds = {}
68
+
69
+ def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
70
+ # Collect required outputs to CPU
71
+ outputs = {}
72
+ q_values = None
73
+
74
+ for collection in (batch, preds):
75
+ for k, v in collection.items():
76
+ if k in self.required_outputs:
77
+ if k == "q_halt_logits":
78
+ q_values = v.to(torch.float64).sigmoid().cpu()
79
+ else:
80
+ outputs[k] = v.cpu()
81
+
82
+ assert q_values is not None
83
+
84
+ # Remove padding from outputs
85
+ mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
86
+ outputs = {k: v[mask] for k, v in outputs.items()}
87
+
88
+ # Get predictions
89
+ for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
90
+ name = self.identifier_map[identifier]
91
+ orig_name, _inverse_fn = inverse_aug(name)
92
+
93
+ input_hash = grid_hash(_inverse_fn(_crop(input)))
94
+
95
+ pred = _inverse_fn(_crop(pred))
96
+ assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range." # Sanity check
97
+
98
+ # Store into local state
99
+ pred_hash = grid_hash(pred)
100
+
101
+ self._local_hmap[pred_hash] = pred
102
+
103
+ self._local_preds.setdefault(orig_name, {})
104
+ self._local_preds[orig_name].setdefault(input_hash, [])
105
+ self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
106
+
107
+ def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
108
+ # Gather predictions to rank 0 for voting
109
+ global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
110
+ dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)
111
+
112
+ # Rank 0 logic
113
+ if rank != 0:
114
+ return
115
+
116
+ submission = {}
117
+ correct = [0.0 for _ in range(len(self.pass_Ks))]
118
+
119
+ for name, puzzle in self.test_puzzles.items():
120
+ # Process test examples in this puzzle
121
+ submission[name] = []
122
+ num_test_correct = [0 for _ in range(len(self.pass_Ks))]
123
+ for pair in puzzle["test"]:
124
+ input_hash = grid_hash(arc_grid_to_np(pair["input"]))
125
+ label_hash = grid_hash(arc_grid_to_np(pair["output"]))
126
+
127
+ p_map = {}
128
+ for hmap, preds in global_hmap_preds: # type: ignore
129
+ for h, q in preds.get(name, {}).get(input_hash, {}):
130
+ p_map.setdefault(h, [0, 0])
131
+ p_map[h][0] += 1
132
+ p_map[h][1] += q
133
+
134
+ if not len(p_map):
135
+ print (f"Puzzle {name} has no predictions.")
136
+ continue
137
+
138
+ for h, stats in p_map.items():
139
+ stats[1] /= stats[0]
140
+
141
+ p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)
142
+
143
+ # vote for different Ks
144
+ for i, k in enumerate(self.pass_Ks):
145
+ ok = False
146
+ for h, stats in p_map[:k]:
147
+ ok |= h == label_hash
148
+
149
+ num_test_correct[i] += ok
150
+
151
+ # Query grids
152
+ pred_grids = []
153
+ for h, stats in p_map[:self.submission_K]:
154
+ for hmap, preds in global_hmap_preds: # type: ignore
155
+ if h in hmap:
156
+ pred_grids.append(hmap[h])
157
+ break
158
+
159
+ # Pad to K
160
+ while len(pred_grids) < self.submission_K:
161
+ pred_grids.append(pred_grids[0])
162
+
163
+ submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})
164
+
165
+ # Total correctness
166
+ for i in range(len(self.pass_Ks)):
167
+ correct[i] += num_test_correct[i] / len(puzzle["test"])
168
+
169
+ # Save submission
170
+ if save_path is not None:
171
+ with open(os.path.join(save_path, "submission.json"), "w") as f:
172
+ json.dump(submission, f)
173
+
174
+ # Final result
175
+ all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}
176
+
177
+ return all_results
kaggle/combined/arc-agi_concept_challenges.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_concept_solutions.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_evaluation2_challenges.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_evaluation2_solutions.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_evaluation_challenges.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_evaluation_solutions.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_training2_challenges.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_training2_solutions.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_training_challenges.json ADDED
The diff for this file is too large to render. See raw diff
 
kaggle/combined/arc-agi_training_solutions.json ADDED
The diff for this file is too large to render. See raw diff
 
models/common.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """In-place truncated-normal init matching JAX/flax semantics.

    Samples a standard normal truncated to [lower, upper] (in sigma units)
    and rescales so that the *actual* standard deviation of the result is
    `std` (PyTorch's nn.init.trunc_normal_ does not compensate for the
    variance lost to truncation).
    https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    Args:
        tensor: tensor to initialize in place.
        std: target standard deviation; 0 zero-fills the tensor.
        lower, upper: truncation bounds in units of sigma.

    Returns:
        The same tensor, for chaining.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2  # = Phi(upper) - Phi(lower): probability mass inside the window

            c = (2 * math.pi) ** -0.5
            # BUGFIX: pdf_u/pdf_l previously used the opposite bound in the
            # exponent (pdf_u from `lower`, pdf_l from `upper`), giving a wrong
            # variance correction for asymmetric bounds. Identical behavior for
            # the symmetric default (-2, 2) where phi(lower) == phi(upper).
            pdf_u = c * math.exp(-0.5 * upper ** 2)  # phi(upper)
            pdf_l = c * math.exp(-0.5 * lower ** 2)  # phi(lower)
            # Variance of the truncated standard normal:
            #   1 - (upper*phi(upper) - lower*phi(lower))/Z - ((phi(upper) - phi(lower))/Z)^2
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Inverse-CDF sampling: uniform over [erf(l/√2), erf(u/√2)], then erfinv.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
models/ema.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch.nn as nn
3
+
4
class EMAHelper(object):
    """Exponential moving average of model parameters.

    Maintains shadow[name] = mu * shadow[name] + (1 - mu) * param for every
    trainable parameter of a registered module.
    """

    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    @staticmethod
    def _unwrap(module):
        # Look through a DataParallel wrapper to the underlying module.
        return module.module if isinstance(module, nn.DataParallel) else module

    def register(self, module):
        """Snapshot the module's trainable parameters as the initial shadow."""
        module = self._unwrap(module)
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend current parameter values into the shadow with decay mu."""
        module = self._unwrap(module)
        decay = self.mu
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - decay) * param.data + decay * self.shadow[name].data

    def ema(self, module):
        """Overwrite the module's trainable parameters with the shadow values."""
        module = self._unwrap(module)
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a deep copy of the module with EMA weights applied."""
        clone = copy.deepcopy(module)
        self.ema(clone)
        return clone

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict
40
+
models/layers.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import einops
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+ #try:
8
+ # from flash_attn_interface import flash_attn_func # type: ignore[import]
9
+ #except ImportError:
10
+ # # Fallback to FlashAttention 2
11
+ # from flash_attn import flash_attn_func # type: ignore[import]
12
+ from torch.nn.functional import scaled_dot_product_attention
13
+
14
+ from models.common import trunc_normal_init_
15
+
16
+
17
+ CosSin = Tuple[torch.Tensor, torch.Tensor]
18
+
19
+
20
+ def _find_multiple(a, b):
21
+ return (-(a // -b)) * b
22
+
23
+
24
def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
29
+
30
+
31
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embeddings to query/key tensors.

    q, k: [bs, seq_len, num_heads, head_dim]; cos, sin: [seq_len, head_dim].
    Arithmetic runs in the cos/sin dtype and is cast back to the input dtype.
    """
    in_dtype = q.dtype
    # Broadcast the tables over the head axis.
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)

    q_f = q.to(cos.dtype)
    k_f = k.to(cos.dtype)

    rotated_q = q_f * cos_b + rotate_half(q_f) * sin_b
    rotated_k = k_f * cos_b + rotate_half(k_f) * sin_b

    return rotated_q.to(in_dtype), rotated_k.to(in_dtype)
42
+
43
+
44
class CastedLinear(nn.Module):
    """Linear layer whose weight (and bias) are cast to the input dtype at call time."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init (fan-in scaling).
        init_std = 1.0 / (in_features ** 0.5)
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=init_std)
        )
        # Zero-initialized bias, created only when requested.
        self.bias = nn.Parameter(torch.zeros((out_features, ))) if bias else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(input.dtype)
        bias = None if self.bias is None else self.bias.to(input.dtype)
        return F.linear(input, weight, bias=bias)
61
+
62
+
63
class CastedEmbedding(nn.Module):
    """Embedding table cast to a fixed dtype on every lookup."""

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        table = trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        self.embedding_weight = nn.Parameter(table)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))
79
+
80
+
81
+ class RotaryEmbedding(nn.Module):
82
+ def __init__(self, dim, max_position_embeddings, base, device=None):
83
+ super().__init__()
84
+
85
+ # RoPE
86
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
87
+ t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
88
+ freqs = torch.outer(t, inv_freq)
89
+
90
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
91
+ emb = torch.cat((freqs, freqs), dim=-1)
92
+ self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
93
+ self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
94
+
95
+ def forward(self):
96
+ return self.cos_cached, self.sin_cached
97
+
98
+
99
class Attention(nn.Module):
    """Multi-head self-attention with optional RoPE, via PyTorch SDPA."""

    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        # Fused QKV projection; no biases anywhere in this block.
        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        """cos_sin: RoPE cos/sin tables, or None when RoPE is disabled."""
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, num_heads, head_dim]
        qkv = self.qkv_proj(hidden_states)

        # Split head
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # flash attn
        # NOTE(review): SDPA as called here expects num_key_value_heads ==
        # num_heads (no GQA head-repeat is done); also the rearrange below can
        # return a non-contiguous tensor, so the subsequent .view may require
        # .reshape — confirm on the target torch version.
        query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value))  # needed for scaled_dot_product_attention but not flash_attn_func
        attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
        attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
        attn_output = attn_output.view(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)
136
+
137
class LinearSwish(nn.Module):
    """Square linear projection combined with SiLU.

    reverse=True applies the activation after the projection; otherwise before.
    """

    def __init__(self, hidden_size: int, reverse=False):
        super().__init__()

        self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
        self.reverse = reverse

    def forward(self, x):
        if not self.reverse:
            return self.linear(F.silu(x))
        return F.silu(self.linear(x))
149
+
150
+
151
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate) * up)."""

    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        # The 2/3 factor keeps parameter count comparable to a plain MLP of the
        # same expansion; width is rounded up to a multiple of 256.
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        projected = self.gate_up_proj(x)
        gate, up = projected.chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
162
+
163
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    """RMS-normalize over the last dim in fp32, returning the input dtype."""
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)

    inv_rms = torch.rsqrt(x.square().mean(-1, keepdim=True) + variance_epsilon)
    return (x * inv_rms).to(orig_dtype)
models/losses.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
def s(x, epsilon=1e-30):
    """Stablemax transform: x + 1 for x >= 0, 1/(1 - x) for x < 0.

    Always positive, so it can replace exp() in a softmax-like normalizer.
    """
    nonnegative_branch = x + 1
    negative_branch = 1 / (1 - x + epsilon)
    return torch.where(x < 0, negative_branch, nonnegative_branch)
17
+
18
+
19
def log_stablemax(x, dim=-1):
    """Log-probabilities of the stablemax distribution along `dim`."""
    transformed = s(x)
    return torch.log(transformed / transformed.sum(dim=dim, keepdim=True))
22
+
23
+
24
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    """Per-token stablemax cross-entropy; ignored positions contribute 0.

    Logits are promoted to float64 for numerical stability.
    """
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = (labels != ignore_index)
    # Replace ignored labels with 0 so the gather index stays in range.
    safe_labels = torch.where(valid_mask, labels, 0)
    gathered = torch.gather(logprobs, dim=-1, index=safe_labels.to(torch.long).unsqueeze(-1)).squeeze(-1)

    return -torch.where(valid_mask, gathered, 0)
33
+
34
+
35
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    """Unreduced softmax cross-entropy in fp32, reshaped back to the label shape."""
    flat_logits = logits.to(torch.float32).view(-1, logits.shape[-1])
    flat_labels = labels.to(torch.long).view(-1)
    losses = F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index, reduction="none")
    return losses.view(labels.shape)
39
+
40
+
41
class ACTLossHead(nn.Module):
    """Wraps an ACT model; computes the LM loss, Q-halting losses, and metrics."""

    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        # loss_type names a function in this module, e.g. "stablemax_cross_entropy".
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        """Run one ACT step; return (carry, loss, metrics, outputs, all_halted)."""
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            # Only sequences that halted at this step (and have labels) count.
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # LM loss is normalized per sequence by its number of labeled tokens.
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        # The halt head is trained to predict exact-sequence correctness.
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })
        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()
        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
103
+
models/recursive_reasoning/hrm.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import math
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from pydantic import BaseModel
8
+
9
+ from models.common import trunc_normal_init_
10
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
11
+ from models.sparse_embedding import CastedSparseEmbedding
12
+
13
@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    # Recurrent hidden states carried across ACT steps; shape
    # [bs, seq_len + puzzle_emb_len, hidden_size] (see empty_carry).
    z_H: torch.Tensor  # high-level (slow) state
    z_L: torch.Tensor  # low-level (fast) state
17
+
18
+
19
@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    # Full ACT carry: inner recurrent state plus per-sequence halting bookkeeping.
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor   # int32 [bs]: ACT steps taken for the current sequence
    halted: torch.Tensor  # bool [bs]: halted slots get refilled with fresh data

    current_data: Dict[str, torch.Tensor]  # batch currently assigned to each slot
27
+
28
+
29
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    # Pydantic configuration for the HRM ACT-v1 model.
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0  # 0 disables puzzle embeddings
    num_puzzle_identifiers: int
    vocab_size: int

    # Recursion counts per ACT step: outer (H) and inner (L) cycles.
    H_cycles: int
    L_cycles: int

    # Number of blocks in the H- and L-level stacks.
    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str  # "rope", "learned", or anything else for none

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool=False # use mlp on L instead of transformer
59
+
60
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    """One post-norm reasoning block: self-attention (or token-mixing MLP) + SwiGLU."""

    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            # Token-mixing variant: the SwiGLU mixes across the sequence axis
            # (seq_len + puzzle-embedding slots) instead of the hidden axis.
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            # Mix across tokens: transpose to [B, D, L], residual + norm, back.
            hidden_states = hidden_states.transpose(1,2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states
100
+
101
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    """Stack of reasoning blocks run after an additive input injection."""

    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Inject the conditioning signal additively, then apply every block.
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state
115
+
116
+
117
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    """Core HRM: embeddings, H/L recurrent reasoning stacks, LM head, Q head."""

    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)  # outputs [halt, continue] logits

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states; broadcast over batch and sequence in reset_carry.
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Token + puzzle (+ positional) embeddings, scaled by sqrt(hidden_size)."""
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Zero-pad the puzzle embedding to a whole number of token slots.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            # Puzzle slots are prepended to the token sequence.
            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        # Uninitialized state; reset_carry fills it before first use.
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        # Where reset_flag is set, restore the learned initial states.
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """One ACT step: run the H/L recursion with grad only on the final L+H pass."""
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        # All but the last L update (and last H update) run without grad:
        # the 1-step gradient approximation backpropagates only through the
        # final pair of updates below.
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L
            for _H_step in range(self.config.H_cycles):
                for _L_step in range(self.config.L_cycles):
                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                if not (_H_step == self.config.H_cycles - 1):
                    z_H = self.H_level(z_H, z_L, **seq_info)
        assert not z_H.requires_grad and not z_L.requires_grad
        # 1-step grad
        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.H_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]  # drop the puzzle slots

        # Q head
        # Halting logits are read from the first (puzzle) position of z_H.
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
225
+
226
+
227
class HierarchicalReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Expose the sparse puzzle embedding of the inner model.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build a fresh carry; all slots start halted so the first forward resets them."""
        batch_size = batch["inputs"].shape[0]

        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected, it will be reseted in first pass as all sequences are halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """One ACT outer step with Q-learning-based halting."""
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        # Halted slots are refilled with the incoming batch's data.
        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
                halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration: with prob halt_exploration_prob, force a random
                # minimum number of steps before halting is allowed.
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)

                halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q
                # NOTE: No replay buffer and target networks for computing target Q-value.
                # As batch_size is large, there're many parallel envs.
                # Similar concept as PQN https://arxiv.org/abs/2407.04811
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]

                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
models/recursive_reasoning/transformers_baseline.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HRM ACT V2: Transformer Baseline for Architecture Ablation
3
+
4
+ This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
5
+ Key changes from V1:
6
+ 1. REMOVED hierarchical split (no separate H and L levels)
7
+ 2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
8
+ 3. KEPT ACT outer loop structure intact
9
+ 4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
10
+
11
+ Architecture: Single-level transformer that processes the full 30x30 grid as a
12
+ 900-token sequence, with the same positional encodings and sparse embeddings as V1.
13
+
14
+ """
15
+
16
+ from typing import Tuple, List, Dict, Optional
17
+ from dataclasses import dataclass
18
+ import math
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+ from pydantic import BaseModel
24
+
25
+ from models.common import trunc_normal_init_
26
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
27
+ from models.sparse_embedding import CastedSparseEmbedding
28
+
29
+
30
@dataclass
class Model_ACTV2InnerCarry:
    # Hidden state of the inner transformer carried between ACT steps.
    # Shape per empty_carry: (batch, seq_len + puzzle_emb_len, hidden_size).
    z_H: torch.Tensor
33
+
34
+
35
@dataclass
class Model_ACTV2Carry:
    """Full state of the ACT outer loop, threaded through successive forward calls."""
    inner_carry: Model_ACTV2InnerCarry

    steps: torch.Tensor   # per-sequence step counter (int32)
    halted: torch.Tensor  # per-sequence halt flag (bool)

    current_data: Dict[str, torch.Tensor]  # the batch each sequence is currently working on
43
+
44
+
45
class Model_ACTV2Config(BaseModel):
    """Pydantic config for the ACT-V2 single-level transformer baseline."""

    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0  # 0 disables the puzzle-embedding prefix
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int  # kept for config compatibility; presumably unused by this single-level model — confirm

    H_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str  # "rope" or "learned"

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float
    act_enabled: bool = True  # If False, always run halt_max_steps (no early stopping during training)
    act_inference: bool = False  # If True, use adaptive computation during inference

    forward_dtype: str = "bfloat16"
72
+
73
+
74
class Model_ACTV2Block(nn.Module):
    """Post-norm transformer block: non-causal self-attention, then a SwiGLU MLP,
    each wrapped in a residual connection followed by RMS normalization."""

    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()

        head_dim = config.hidden_size // config.num_heads
        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=head_dim,
            num_heads=config.num_heads,
            num_key_value_heads=config.num_heads,
            causal=False,
        )
        self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # Post-norm residual sub-layers: normalize AFTER adding each sub-layer's output.
        attn_out = self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states)
        hidden_states = rms_norm(hidden_states + attn_out, variance_epsilon=self.norm_eps)

        mlp_out = self.mlp(hidden_states)
        return rms_norm(hidden_states + mlp_out, variance_epsilon=self.norm_eps)
101
+
102
+
103
class Model_ACTV2ReasoningModule(nn.Module):
    """Stack of transformer blocks applied after additively injecting an input into the state."""

    def __init__(self, layers: List[Model_ACTV2Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Additive input injection, then run every block in order.
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state
117
+
118
+
119
class Model_ACTV2_Inner(nn.Module):
    """Single-level transformer core for the ACT-V2 baseline.

    Embeds tokens (plus an optional puzzle-embedding prefix), runs the H_level
    block stack once per ACT step, and emits token logits plus halt/continue
    Q-logits read from the first sequence position.
    """

    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            init_std=embed_init_std,
            cast_to=self.forward_dtype,
        )
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)  # 2 outputs: (halt, continue)

        # Number of prefix positions reserved for the puzzle embedding.
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(
                self.config.num_puzzle_identifiers,
                self.config.puzzle_emb_ndim,
                batch_size=self.config.batch_size,
                init_std=0,
                cast_to=self.forward_dtype,
            )

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(
                dim=self.config.hidden_size // self.config.num_heads,
                max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                base=self.config.rope_theta,
            )
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                init_std=embed_init_std,
                cast_to=self.forward_dtype,
            )
        else:
            raise NotImplementedError()

        # Reasoning Layers
        self.H_level = Model_ACTV2ReasoningModule(
            layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
        )

        # Initial states: learned reset vector broadcast over batch and sequence.
        self.H_init = nn.Buffer(
            trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Build scaled input embeddings: optional puzzle prefix + token embeddings (+ learned positions)."""
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Zero-pad so the puzzle embedding fills puzzle_emb_len whole positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat(
                (puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
            )

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        # Uninitialized state; reset_carry overwrites it on the first step
        # because all sequences start halted.
        return Model_ACTV2InnerCarry(
            z_H=torch.empty(
                batch_size,
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                dtype=self.forward_dtype,
            ),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
        # Per-sequence reset: flagged sequences restart from the learned H_init vector.
        return Model_ACTV2InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
        )

    def forward(
        self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
    ) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """One reasoning step. Returns (detached new carry, token logits, (q_halt, q_continue))."""
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # 1-step grad: a single pass through the block stack keeps gradients.
        z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)

        # LM Outputs
        new_carry = Model_ACTV2InnerCarry(
            z_H=z_H.detach(),
        )  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len :]  # drop the puzzle-embedding prefix

        # Q head: reads the first position (the puzzle-embedding slot).
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
245
+
246
+
247
class Model_ACTV2(nn.Module):
    """ACT wrapper: drives the inner model with an adaptive-computation-time loop."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = Model_ACTV2Config(**config_dict)
        self.inner = Model_ACTV2_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Expose the inner sparse puzzle embedding.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build the starting carry: all sequences halted, so the first forward resets everything."""
        batch_size = batch["inputs"].shape[0]

        return Model_ACTV2Carry(
            inner_carry=self.inner.empty_carry(
                batch_size
            ),  # Empty is expected; it will be reset in the first pass as all sequences are halted.
            steps=torch.zeros((batch_size,), dtype=torch.int32),
            halted=torch.ones((batch_size,), dtype=torch.bool),  # Default to halted
            current_data={k: torch.empty_like(v) for k, v in batch.items()},
        )

    def forward(
        self,
        carry: Model_ACTV2Carry,
        batch: Dict[str, torch.Tensor],
        compute_target_q: bool = False,
    ) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
        """One ACT step: reset halted sequences, run the inner model once, update halting state."""
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        # Halted sequences restart their step counter at 0.
        new_steps = torch.where(carry.halted, 0, carry.steps)

        # Halted sequences pick up fresh data from `batch`; the rest keep their current sample.
        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
            for k, v in carry.current_data.items()
        }

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
            new_inner_carry, new_current_data
        )

        outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # Check if adaptive computation should be used
            use_adaptive = (self.config.halt_max_steps > 1) and (
                (self.training and self.config.act_enabled)
                or (not self.training and self.config.act_inference)
            )

            if use_adaptive:
                # Halt signal based on Q-values (but always halt at max steps)
                q_halt_signal = q_halt_logits > q_continue_logits
                halted = halted | q_halt_signal

                # Store actual steps used for logging (only during inference)
                if not self.training:
                    outputs["actual_steps"] = new_steps.float()

                # Exploration (only during training): randomly force a minimum step count.
                if self.training:
                    min_halt_steps = (
                        torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
                    ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                    halted = halted & (new_steps >= min_halt_steps)

            # Compute target Q (only during training)
            # NOTE: No replay buffer and target networks for computing target Q-value.
            # As batch_size is large, there're many parallel envs.
            # Similar concept as PQN https://arxiv.org/abs/2407.04811
            if self.training and compute_target_q:
                # Bootstrap from one extra (no-grad) inner step; [-1] picks the Q-logit tuple.
                next_q_halt_logits, next_q_continue_logits = self.inner(
                    new_inner_carry, new_current_data
                )[-1]

                outputs["target_q_continue"] = torch.sigmoid(
                    torch.where(
                        is_last_step,
                        next_q_halt_logits,
                        torch.maximum(next_q_halt_logits, next_q_continue_logits),
                    )
                )

        return Model_ACTV2Carry(
            new_inner_carry, new_steps, halted, new_current_data
        ), outputs
models/recursive_reasoning/trm.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import math
4
+ import torch
5
+ import copy
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+ import random
10
+ from models.common import trunc_normal_init_
11
+ from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
12
+ from models.sparse_embedding import CastedSparseEmbedding
13
+
14
+ IGNORE_LABEL_ID = -100
15
+
16
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    # Latent states carried between ACT steps (detached between steps).
    z_H: torch.Tensor  # state read by the LM head and Q head
    z_L: torch.Tensor  # latent reasoning state updated from the input
20
+
21
+
22
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    """Full state of the ACT outer loop, threaded through successive forward calls."""
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor   # per-sequence step counter (int32)
    halted: torch.Tensor  # per-sequence halt flag (bool)

    current_data: Dict[str, torch.Tensor]  # the batch each sequence is currently working on
30
+
31
+
32
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    """Pydantic config for the Tiny Recursive Reasoning Model (TRM)."""

    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0  # 0 disables the puzzle-embedding prefix
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int  # ignored
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str  # "rope", "learned", or anything else for no positional encoding

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False  # use an MLP across the sequence (L) dimension instead of a transformer block
    puzzle_emb_len: int = 16  # if non-zero, the puzzle-embedding prefix length is fixed to this value
    no_ACT_continue: bool = True  # no continue ACT loss, only use the sigmoid of the halt which makes much more sense
64
+
65
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    """One post-norm reasoning block: either a token-mixing MLP over the sequence
    dimension (mlp_t=True) or self-attention followed by a SwiGLU MLP."""

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            # Prefix length for the token-mixing MLP; ceil-div fallback when unspecified (0).
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L: operates across sequence positions
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            # Token mixing: transpose to (B, D, L) so the MLP mixes across positions, then back.
            hidden_states = hidden_states.transpose(1,2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states
105
+
106
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    """Stack of reasoning blocks applied after additively injecting an input into the state."""

    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Additive input injection, then run every block in order.
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state
116
+
117
+
118
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    """Inner TRM core.

    A single shared `L_level` module is recursed H_cycles x (L_cycles + 1)
    times: each H cycle updates z_L from (z_H + input) L_cycles times, then
    updates z_H from z_L once. Only the final H cycle keeps gradients.
    Emits token logits and halt/continue Q-logits.
    """

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O

        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)  # 2 outputs: (halt, continue)

        # Prefix length for the puzzle embedding; ceil-div fallback when unspecified (0).
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            # No positional encoding (unlike the HRM baseline, this is allowed here).
            pass

        # Reasoning Layers: one module shared by both the z_L and z_H updates.
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states: learned reset vectors broadcast over batch and sequence.
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Build scaled input embeddings: optional puzzle prefix + token embeddings (+ learned positions)."""
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Zero-pad so the puzzle embedding fills puzzle_emb_len whole positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        # Uninitialized state; reset_carry overwrites it on the first ACT step.
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        # Per-sequence reset: flagged sequences restart from the learned init vectors.
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """One reasoning step. Returns (detached new carry, token logits, (q_halt, q_continue))."""
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        # (removed dead local `it = 0` — it was never read)
        z_H, z_L = carry.z_H, carry.z_L
        # H_cycles-1 without grad
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles-1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, z_L, **seq_info)
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]  # drop the puzzle-embedding prefix
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Q-head; uses the first puzzle_emb position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
223
+
224
+
225
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper.

    Drives the inner TRM with an adaptive-computation-time loop: halted
    sequences are reset and given fresh data, the inner model runs one step,
    and halting is decided from the Q-head logits (training only; during
    evaluation every sequence always runs halt_max_steps for batching).
    """

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Expose the inner sparse puzzle embedding.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build the starting carry: all sequences halted, so the first forward resets everything."""
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected, it will be reset in the first pass as all sequences are halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """One ACT step: reset halted sequences, run the inner model once, update halting state."""
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        # Halted sequences restart their step counter at 0.
        new_steps = torch.where(carry.halted, 0, carry.steps)

        # Halted sequences pick up fresh data from `batch`; the rest keep their current sample.
        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration: occasionally force a random minimum number of steps.
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q
                    # NOTE: No replay buffer and target networks for computing target Q-value.
                    # As batch_size is large, there're many parallel envs.
                    # Similar concept as PQN https://arxiv.org/abs/2407.04811
                    # BUG FIX: self.inner(...) returns a 3-tuple (carry, logits, (q_halt, q_continue));
                    # the original unpacked it into 5 values, which raised ValueError at runtime.
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
models/recursive_reasoning/trm_hier6.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import math
4
+ import torch
5
+ import copy
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+ import random
10
+ from models.common import trunc_normal_init_
11
+ from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
12
+ from models.sparse_embedding import CastedSparseEmbedding
13
+
14
+ IGNORE_LABEL_ID = -100
15
+
16
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    # Latent states carried between ACT steps: one high-level state plus six low-level states.
    z_H: torch.Tensor
    z_L1: torch.Tensor
    z_L2: torch.Tensor
    z_L3: torch.Tensor
    z_L4: torch.Tensor
    z_L5: torch.Tensor
    z_L6: torch.Tensor
25
+
26
+
27
+
28
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    """Full state of the ACT outer loop, threaded through successive forward calls."""
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor   # per-sequence step counter (int32)
    halted: torch.Tensor  # per-sequence halt flag (bool)

    current_data: Dict[str, torch.Tensor]  # the batch each sequence is currently working on
36
+
37
+
38
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    """Pydantic config for the 6-level hierarchical TRM variant."""

    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0  # 0 disables the puzzle-embedding prefix
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int  # ignored
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str  # "rope", "learned", or anything else for no positional encoding

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False  # use an MLP across the sequence (L) dimension instead of a transformer block
    puzzle_emb_len: int = 16  # if non-zero, the puzzle-embedding prefix length is fixed to this value
    no_ACT_continue: bool = True  # no continue ACT loss, only use the sigmoid of the halt which makes much more sense
70
+
71
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    """One post-norm reasoning block: either a token-mixing MLP over the sequence
    dimension (mlp_t=True) or self-attention followed by a SwiGLU MLP."""

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            # Prefix length for the token-mixing MLP; ceil-div fallback when unspecified (0).
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L: operates across sequence positions
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            # Token mixing: transpose to (B, D, L) so the MLP mixes across positions, then back.
            hidden_states = hidden_states.transpose(1,2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states
111
+
112
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    """Stack of reasoning blocks applied after additively injecting an input into the state."""

    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Additive input injection, then run every block in order.
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state
122
+
123
+
124
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    """Inner recursive-reasoning core with one high-level state (z_H) and six
    low-level latent states (z_L1..z_L6), all refined by a single shared block
    stack (``L_level``).

    Fix vs. original: removed the unused local ``it = 0`` in ``forward``.
    """

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O: token embedding, LM head, and the 2-way Q head (halt / continue).
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Sequence positions reserved for the puzzle embedding: ceil-div of
        # puzzle_emb_ndim / hidden_size when puzzle_emb_len == 0, else as configured.
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
        if self.config.puzzle_emb_ndim > 0:
            # Zero-init puzzle embeddings, trained sparsely per puzzle id.
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # Positional encodings ("rope", "learned", or none).
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)

        # Reasoning layers: one stack shared by both L-state and H-state updates.
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states, persisted so reloads reproduce the same reset values.
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q-head special init: (almost) zero so Q-learning bootstraps quickly.
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Embed tokens, prepend the (padded) puzzle embedding, add positions, scale."""
        embedding = self.embed_tokens(input.to(torch.int32))

        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Right-pad so the puzzle vector tiles exactly puzzle_emb_len positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        """Uninitialized carry; reset_carry fills it before first use."""
        shape = (batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size)

        def blank() -> torch.Tensor:
            return torch.empty(shape, dtype=self.forward_dtype)

        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=blank(), z_L1=blank(), z_L2=blank(), z_L3=blank(),
            z_L4=blank(), z_L5=blank(), z_L6=blank(),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        """Replace states of flagged (halted) sequences with the learned init vectors."""
        mask = reset_flag.view(-1, 1, 1)
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(mask, self.H_init, carry.z_H),
            z_L1=torch.where(mask, self.L1_init, carry.z_L1),
            z_L2=torch.where(mask, self.L2_init, carry.z_L2),
            z_L3=torch.where(mask, self.L3_init, carry.z_L3),
            z_L4=torch.where(mask, self.L4_init, carry.z_L4),
            z_L5=torch.where(mask, self.L5_init, carry.z_L5),
            z_L6=torch.where(mask, self.L6_init, carry.z_L6),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """One ACT step: recurse on (z_H, z_L1..z_L6), then decode logits and Q-values.

        Only the final H-cycle runs with gradient (1-step gradient trick); the
        returned carry is detached.
        """
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        z_H = carry.z_H
        z_L = [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6]

        def combined(states):
            # Sum of the six latent streams, fed to the shared stack.
            return states[0] + states[1] + states[2] + states[3] + states[4] + states[5]

        # NOTE(review): z_L[_L_step] indexing assumes L_cycles <= 6 (one latent
        # per L step) — confirm against the training config.
        # First H_cycles-1 outer cycles run without gradient.
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles - 1):
                for _L_step in range(self.config.L_cycles):
                    z_L[_L_step] = self.L_level(combined(z_L), z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, combined(z_L), **seq_info)
        # Final outer cycle with gradient.
        for _L_step in range(self.config.L_cycles):
            z_L[_L_step] = self.L_level(combined(z_L), z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, combined(z_L), **seq_info)

        # LM outputs: drop the puzzle-embedding prefix from token logits.
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach())  # new carry carries no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Q-head reads the first (puzzle-emb) position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
249
+
250
+
251
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper: runs the inner model for a variable number of steps per
    sequence, with Q-learned halting.

    Fix vs. original: the target-Q path unpacked five values from
    ``self.inner(...)``, which returns a 3-tuple ``(carry, output,
    (q_halt, q_continue))`` — that raised ``ValueError`` whenever
    ``no_ACT_continue`` was False.
    """

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Exposed so the sparse-embedding optimizer can find the puzzle table.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build a fresh carry; everything starts halted so the first forward resets it."""
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # uninitialized is fine: reset on first pass since all sequences start halted

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """One ACT step over the whole batch; returns (new_carry, outputs)."""
        # Halted sequences get their state reset and their data replaced by the new batch.
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step counter and hard cap.
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # Early halting only during training with ACT enabled.
            # NOTE: during evaluation we always run max steps, so every sequence
            # in a batch halts together (batching requirement).
            if self.training and (self.config.halt_max_steps > 1):
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration: with prob halt_exploration_prob, force a random
                # minimum number of steps before halting is allowed.
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q.
                    # NOTE: no replay buffer / target network; with a large batch
                    # there are many parallel envs (similar to PQN, arXiv:2407.04811).
                    # BUG FIX: inner returns 3 values, not 5.
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
models/recursive_reasoning/trm_singlez.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import math
4
+ import torch
5
+ import copy
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+ import random
10
+ from models.common import trunc_normal_init_
11
+ from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
12
+ from models.sparse_embedding import CastedSparseEmbedding
13
+
14
+ IGNORE_LABEL_ID = -100
15
+
16
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    """Recurrent state carried between ACT steps by the inner model."""
    # Latent reasoning state; created by Inner.empty_carry with shape
    # (batch, seq_len + puzzle_emb_len, hidden_size).
    z_L: torch.Tensor
19
+
20
+
21
+
22
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    """Full ACT carry: inner latent state plus per-sequence bookkeeping."""
    # Latent state of the inner recursive model.
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry

    # ACT steps taken so far per sequence (int32, shape (batch,)).
    steps: torch.Tensor
    # Per-sequence halt flag; halted sequences are reset and given fresh data.
    halted: torch.Tensor

    # The batch each sequence is currently working on (kept across steps).
    current_data: Dict[str, torch.Tensor]
30
+
31
+
32
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    """Hyperparameters for the single-latent TRM ACT model (pydantic-validated)."""
    batch_size: int          # needed by the sparse puzzle embedding's local buffers
    seq_len: int             # sequence length, excluding the puzzle-embedding prefix
    puzzle_emb_ndim: int = 0 # total puzzle-embedding width; 0 disables puzzle embeddings
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int  # outer recursion count (only the last cycle keeps gradients)
    L_cycles: int  # inner recursion count per outer cycle

    H_layers: int  # ignored
    L_layers: int  # depth of the single shared reasoning stack

    # Transformer config
    hidden_size: int
    expansion: float  # SwiGLU expansion ratio
    num_heads: int
    pos_encodings: str  # "rope", "learned", or anything else for none

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int           # hard cap on ACT steps; eval always runs this many
    halt_exploration_prob: float  # chance of forcing a random minimum step count

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False # use mlp on L instead of transformer
    puzzle_emb_len: int = 16 # if non-zero, its specified to this value
    no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
64
+
65
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    """One post-norm reasoning block.

    Two variants selected by config.mlp_t:
      * False: standard self-attention + SwiGLU MLP over the hidden dimension.
      * True:  a token-mixing SwiGLU applied across the sequence axis (the
        input is transposed so the MLP width is seq_len + puzzle_emb_len).
    """
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            # Same ceil-div fallback as the Inner model when puzzle_emb_len == 0.
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,  # no GQA: kv heads == q heads
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm: each residual update is rms_norm(x + f(x)).
        if self.config.mlp_t:
            # cos_sin is unused here (no attention, hence no RoPE).
            hidden_states = hidden_states.transpose(1,2)  # (B, D, L): MLP mixes across tokens
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1,2)  # back to (B, L, D)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states
105
+
106
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    """Sequentially applies a stack of reasoning blocks to a hidden state."""

    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # Thread the state through every block in order; extra kwargs
        # (e.g. cos_sin) are forwarded unchanged.
        state = hidden_states
        for block in self.layers:
            state = block(hidden_states=state, **kwargs)
        return state
115
+
116
+
117
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    """Inner recursive-reasoning core with a single latent state z_L, refined by
    one shared block stack over H_cycles x L_cycles iterations.

    Fix vs. original: removed the unused local ``it = 0`` in ``forward``.
    """

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O: token embedding, LM head, and the 2-way Q head (halt / continue).
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Sequence positions reserved for the puzzle embedding: ceil-div of
        # puzzle_emb_ndim / hidden_size when puzzle_emb_len == 0, else as configured.
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
        if self.config.puzzle_emb_ndim > 0:
            # Zero-init puzzle embeddings, trained sparsely per puzzle id.
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # Positional encodings ("rope", "learned", or none).
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)

        # Reasoning layers: the single shared stack.
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial state, persisted so reloads reproduce the same reset value.
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q-head special init: (almost) zero so Q-learning bootstraps quickly.
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Embed tokens, prepend the (padded) puzzle embedding, add positions, scale."""
        embedding = self.embed_tokens(input.to(torch.int32))

        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            # Right-pad so the puzzle vector tiles exactly puzzle_emb_len positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        """Uninitialized carry; reset_carry fills it before first use."""
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        """Replace states of flagged (halted) sequences with the learned init vector."""
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """One ACT step: recurse on z_L, then decode logits and Q-values.

        Only the final outer cycle runs with gradient (1-step gradient trick);
        the returned carry is detached.
        """
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        z_L = carry.z_L
        # First H_cycles-1 outer cycles run without gradient.
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles - 1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.L_level(z_L + input_embeddings, **seq_info)
                z_L = self.L_level(z_L, **seq_info)
        # Final outer cycle with gradient.
        for _L_step in range(self.config.L_cycles):
            z_L = self.L_level(z_L + input_embeddings, **seq_info)
        z_L = self.L_level(z_L, **seq_info)
        z_out = z_L

        # LM outputs: drop the puzzle-embedding prefix from token logits.
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach())  # new carry carries no grad
        output = self.lm_head(z_out)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_out[:, 0]).to(torch.float32)  # Q-head reads the first (puzzle-emb) position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
220
+
221
+
222
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper: runs the inner model for a variable number of steps per
    sequence, with Q-learned halting.

    Fix vs. original: the target-Q path unpacked five values from
    ``self.inner(...)``, which returns a 3-tuple ``(carry, output,
    (q_halt, q_continue))`` — that raised ``ValueError`` whenever
    ``no_ACT_continue`` was False.
    """

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        # Exposed so the sparse-embedding optimizer can find the puzzle table.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build a fresh carry; everything starts halted so the first forward resets it."""
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # uninitialized is fine: reset on first pass since all sequences start halted

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """One ACT step over the whole batch; returns (new_carry, outputs)."""
        # Halted sequences get their state reset and their data replaced by the new batch.
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step counter and hard cap.
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # Early halting only during training with ACT enabled.
            # NOTE: during evaluation we always run max steps, so every sequence
            # in a batch halts together (batching requirement).
            if self.training and (self.config.halt_max_steps > 1):
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration: with prob halt_exploration_prob, force a random
                # minimum number of steps before halting is allowed.
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q.
                    # NOTE: no replay buffer / target network; with a large batch
                    # there are many parallel envs (similar to PQN, arXiv:2407.04811).
                    # BUG FIX: inner returns 3 values, not 5.
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
models/sparse_embedding.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.distributed as dist
6
+ from torch.optim.optimizer import Optimizer, ParamsT
7
+
8
+ from models.common import trunc_normal_init_
9
+
10
+
11
class CastedSparseEmbedding(nn.Module):
    """Embedding table trained via sparse updates.

    The full table (``weights``) is a plain buffer with no gradient. Each
    training forward copies the rows for the current batch into
    ``local_weights`` (a small buffer created with requires_grad=True) and
    records the row ids in ``local_ids``; a matching optimizer (see
    CastedSparseEmbeddingSignSGD_Distributed below) scatters the gradients of
    ``local_weights`` back into ``weights``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs: one row id per batch element; copy_ below implies
        # inputs.shape[0] must equal the configured batch_size — TODO confirm
        # callers never pass a smaller final batch in training mode.
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)
39
+
40
+
41
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    """SignSGD optimizer for CastedSparseEmbedding, distributed-aware.

    Expects each param group to contain exactly the three tensors of one
    CastedSparseEmbedding: local_weights (requires_grad), local_ids (1-D) and
    weights (2-D). Gradients are all-gathered across ranks, summed per unique
    id, and applied as a sign step with decoupled weight decay.
    """

    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Identify the three tensors by shape/grad properties rather than
            # by position. NOTE(review): this heuristic relies on local_ids
            # being 1-D and weights 2-D with requires_grad=False — it matches
            # CastedSparseEmbedding above but would misfire on other layouts.
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,

                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )
96
+
97
+
98
+ def _sparse_emb_signsgd_dist(
99
+ local_weights_grad: torch.Tensor,
100
+ local_ids: torch.Tensor,
101
+ weights: torch.Tensor,
102
+
103
+ lr: float,
104
+ weight_decay: float,
105
+ world_size: int
106
+ ) -> None:
107
+ N, D = local_weights_grad.shape
108
+
109
+ # All-gather
110
+ all_weights_grad = local_weights_grad
111
+ all_ids = local_ids
112
+
113
+ if world_size > 1:
114
+ all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
115
+ all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
116
+
117
+ dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
118
+ dist.all_gather_into_tensor(all_ids, local_ids)
119
+
120
+ # Unique
121
+ grad_ids, inv = all_ids.unique(return_inverse=True)
122
+
123
+ grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
124
+ grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
125
+
126
+ # SignSGD with decoupled weight decay
127
+ p = weights[grad_ids]
128
+
129
+ p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
130
+
131
+ # Write updated slices back
132
+ weights[grad_ids] = p
utils/functions.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import inspect
3
+
4
+
5
def load_model_class(identifier: str, prefix: str = "models."):
    """Resolve a "module_path@ClassName" identifier to the class object.

    The module path is imported relative to `prefix` (default the project's
    `models` package). Raises ValueError for malformed identifiers.
    """
    module_name, attr_name = identifier.split('@')
    module = importlib.import_module(prefix + module_name)
    return getattr(module, attr_name)
13
+
14
+
15
def get_model_source_path(identifier: str, prefix: str = "models."):
    """Return the source-file path of the module named in "module_path@ClassName".

    Splitting also validates the identifier format (raises ValueError without '@').
    """
    module_name, _class_name = identifier.split('@')
    return inspect.getsourcefile(importlib.import_module(prefix + module_name))