Upload 34 files
Browse files- .gitattributes +2 -0
- assets/TRM_fig.png +3 -0
- assets/TRM_pseudocode.png +3 -0
- config/arch/hrm.yaml +24 -0
- config/arch/transformers_baseline.yaml +18 -0
- config/arch/trm.yaml +26 -0
- config/arch/trm_hier6.yaml +26 -0
- config/arch/trm_singlez.yaml +26 -0
- config/cfg_pretrain.yaml +42 -0
- dataset/build_arc_dataset.py +341 -0
- dataset/build_maze_dataset.py +140 -0
- dataset/build_sudoku_dataset.py +167 -0
- dataset/common.py +49 -0
- evaluators/arc.py +177 -0
- kaggle/combined/arc-agi_concept_challenges.json +0 -0
- kaggle/combined/arc-agi_concept_solutions.json +0 -0
- kaggle/combined/arc-agi_evaluation2_challenges.json +0 -0
- kaggle/combined/arc-agi_evaluation2_solutions.json +0 -0
- kaggle/combined/arc-agi_evaluation_challenges.json +0 -0
- kaggle/combined/arc-agi_evaluation_solutions.json +0 -0
- kaggle/combined/arc-agi_training2_challenges.json +0 -0
- kaggle/combined/arc-agi_training2_solutions.json +0 -0
- kaggle/combined/arc-agi_training_challenges.json +0 -0
- kaggle/combined/arc-agi_training_solutions.json +0 -0
- models/common.py +32 -0
- models/ema.py +40 -0
- models/layers.py +169 -0
- models/losses.py +103 -0
- models/recursive_reasoning/hrm.py +294 -0
- models/recursive_reasoning/transformers_baseline.py +342 -0
- models/recursive_reasoning/trm.py +297 -0
- models/recursive_reasoning/trm_hier6.py +323 -0
- models/recursive_reasoning/trm_singlez.py +294 -0
- models/sparse_embedding.py +132 -0
- utils/functions.py +19 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/TRM_fig.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/TRM_pseudocode.png filter=lfs diff=lfs merge=lfs -text
|
assets/TRM_fig.png
ADDED
|
Git LFS Details
|
assets/TRM_pseudocode.png
ADDED
|
Git LFS Details
|
config/arch/hrm.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
|
| 2 |
+
loss:
|
| 3 |
+
name: losses@ACTLossHead
|
| 4 |
+
loss_type: stablemax_cross_entropy
|
| 5 |
+
|
| 6 |
+
halt_exploration_prob: 0.1
|
| 7 |
+
halt_max_steps: 16
|
| 8 |
+
|
| 9 |
+
H_cycles: 2
|
| 10 |
+
L_cycles: 2
|
| 11 |
+
|
| 12 |
+
H_layers: 4
|
| 13 |
+
L_layers: 4
|
| 14 |
+
|
| 15 |
+
hidden_size: 512
|
| 16 |
+
num_heads: 8 # min(2, hidden_size // 64)
|
| 17 |
+
expansion: 4
|
| 18 |
+
|
| 19 |
+
puzzle_emb_ndim: ${.hidden_size}
|
| 20 |
+
|
| 21 |
+
pos_encodings: rope
|
| 22 |
+
forward_dtype: bfloat16
|
| 23 |
+
|
| 24 |
+
mlp_t: False # use mlp on L instead of transformer
|
config/arch/transformers_baseline.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: recursive_reasoning.transformers_baseline@Model_ACTV2
|
| 2 |
+
loss:
|
| 3 |
+
name: losses@ACTLossHead
|
| 4 |
+
loss_type: stablemax_cross_entropy
|
| 5 |
+
|
| 6 |
+
halt_exploration_prob: 0.1
|
| 7 |
+
halt_max_steps: 16
|
| 8 |
+
|
| 9 |
+
H_cycles: 1 # kept for compatibility
|
| 10 |
+
H_layers: 8
|
| 11 |
+
|
| 12 |
+
hidden_size: 512
|
| 13 |
+
num_heads: 12
|
| 14 |
+
expansion: 4
|
| 15 |
+
|
| 16 |
+
puzzle_emb_ndim: ${.hidden_size}
|
| 17 |
+
|
| 18 |
+
pos_encodings: rope
|
config/arch/trm.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
|
| 2 |
+
loss:
|
| 3 |
+
name: losses@ACTLossHead
|
| 4 |
+
loss_type: stablemax_cross_entropy
|
| 5 |
+
|
| 6 |
+
halt_exploration_prob: 0.1
|
| 7 |
+
halt_max_steps: 16
|
| 8 |
+
|
| 9 |
+
H_cycles: 3
|
| 10 |
+
L_cycles: 6
|
| 11 |
+
|
| 12 |
+
H_layers: 0
|
| 13 |
+
L_layers: 2
|
| 14 |
+
|
| 15 |
+
hidden_size: 512
|
| 16 |
+
num_heads: 8 # min(2, hidden_size // 64)
|
| 17 |
+
expansion: 4
|
| 18 |
+
|
| 19 |
+
puzzle_emb_ndim: ${.hidden_size}
|
| 20 |
+
|
| 21 |
+
pos_encodings: rope
|
| 22 |
+
forward_dtype: bfloat16
|
| 23 |
+
|
| 24 |
+
mlp_t: False # use mlp on L instead of transformer
|
| 25 |
+
puzzle_emb_len: 16 # if non-zero, its specified to this value
|
| 26 |
+
no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
config/arch/trm_hier6.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: recursive_reasoning.trm_hier6@TinyRecursiveReasoningModel_ACTV1
|
| 2 |
+
loss:
|
| 3 |
+
name: losses@ACTLossHead
|
| 4 |
+
loss_type: stablemax_cross_entropy
|
| 5 |
+
|
| 6 |
+
halt_exploration_prob: 0.1
|
| 7 |
+
halt_max_steps: 16
|
| 8 |
+
|
| 9 |
+
H_cycles: 3
|
| 10 |
+
L_cycles: 6
|
| 11 |
+
|
| 12 |
+
H_layers: 0
|
| 13 |
+
L_layers: 2
|
| 14 |
+
|
| 15 |
+
hidden_size: 512
|
| 16 |
+
num_heads: 8 # min(2, hidden_size // 64)
|
| 17 |
+
expansion: 4
|
| 18 |
+
|
| 19 |
+
puzzle_emb_ndim: ${.hidden_size}
|
| 20 |
+
|
| 21 |
+
pos_encodings: rope
|
| 22 |
+
forward_dtype: bfloat16
|
| 23 |
+
|
| 24 |
+
mlp_t: False # use mlp on L instead of transformer
|
| 25 |
+
puzzle_emb_len: 16 # if non-zero, its specified to this value
|
| 26 |
+
no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
config/arch/trm_singlez.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: recursive_reasoning.trm_singlez@TinyRecursiveReasoningModel_ACTV1
|
| 2 |
+
loss:
|
| 3 |
+
name: losses@ACTLossHead
|
| 4 |
+
loss_type: stablemax_cross_entropy
|
| 5 |
+
|
| 6 |
+
halt_exploration_prob: 0.1
|
| 7 |
+
halt_max_steps: 16
|
| 8 |
+
|
| 9 |
+
H_cycles: 3
|
| 10 |
+
L_cycles: 6
|
| 11 |
+
|
| 12 |
+
H_layers: 0
|
| 13 |
+
L_layers: 2
|
| 14 |
+
|
| 15 |
+
hidden_size: 512
|
| 16 |
+
num_heads: 8 # min(2, hidden_size // 64)
|
| 17 |
+
expansion: 4
|
| 18 |
+
|
| 19 |
+
puzzle_emb_ndim: ${.hidden_size}
|
| 20 |
+
|
| 21 |
+
pos_encodings: rope
|
| 22 |
+
forward_dtype: bfloat16
|
| 23 |
+
|
| 24 |
+
mlp_t: False # use mlp on L instead of transformer
|
| 25 |
+
puzzle_emb_len: 16 # if non-zero, its specified to this value
|
| 26 |
+
no_ACT_continue: True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
config/cfg_pretrain.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ARC training config
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- arch: trm
|
| 5 |
+
- _self_
|
| 6 |
+
|
| 7 |
+
hydra:
|
| 8 |
+
output_subdir: null
|
| 9 |
+
|
| 10 |
+
# Data path
|
| 11 |
+
data_paths: ['data/arc-aug-1000']
|
| 12 |
+
data_paths_test: []
|
| 13 |
+
|
| 14 |
+
evaluators:
|
| 15 |
+
- name: arc@ARC
|
| 16 |
+
|
| 17 |
+
# Hyperparams - Training
|
| 18 |
+
global_batch_size: 768
|
| 19 |
+
|
| 20 |
+
epochs: 100000
|
| 21 |
+
eval_interval: 10000
|
| 22 |
+
checkpoint_every_eval: True
|
| 23 |
+
|
| 24 |
+
lr: 1e-4
|
| 25 |
+
lr_min_ratio: 1.0
|
| 26 |
+
lr_warmup_steps: 2000
|
| 27 |
+
|
| 28 |
+
# Standard hyperparameter settings for LM, as used in Llama
|
| 29 |
+
beta1: 0.9
|
| 30 |
+
beta2: 0.95
|
| 31 |
+
weight_decay: 0.1
|
| 32 |
+
puzzle_emb_weight_decay: 0.1
|
| 33 |
+
|
| 34 |
+
# Hyperparams - Puzzle embeddings training
|
| 35 |
+
puzzle_emb_lr: 1e-2
|
| 36 |
+
|
| 37 |
+
seed: 0
|
| 38 |
+
min_eval_interval: 0 # when to start the eval
|
| 39 |
+
|
| 40 |
+
ema: False # use Exponential-Moving-Average
|
| 41 |
+
ema_rate: 0.999 # EMA-rate
|
| 42 |
+
freeze_weights: False # If True, freeze weights and only learn the embeddings
|
dataset/build_arc_dataset.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Dict
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import hashlib
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from argdantic import ArgParser
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
from dataset.common import PuzzleDatasetMetadata, dihedral_transform, inverse_dihedral_transform
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
cli = ArgParser()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DataProcessConfig(BaseModel):
|
| 18 |
+
input_file_prefix: str
|
| 19 |
+
output_dir: str
|
| 20 |
+
subsets: List[str]
|
| 21 |
+
test_set_name: str
|
| 22 |
+
test_set_name2: str = "your_test_set"
|
| 23 |
+
seed: int = 42
|
| 24 |
+
num_aug: int = 1000
|
| 25 |
+
puzzle_identifiers_start: int = 1 # start > 1 to handle multiple datasets
|
| 26 |
+
|
| 27 |
+
ARCMaxGridSize = 30
|
| 28 |
+
ARCAugmentRetriesFactor = 5
|
| 29 |
+
|
| 30 |
+
PuzzleIdSeparator = "|||"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class ARCPuzzle:
|
| 35 |
+
id: str
|
| 36 |
+
examples: List[Tuple[np.ndarray, np.ndarray]]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def arc_grid_to_np(grid: List[List[int]]):
|
| 40 |
+
arr = np.array(grid)
|
| 41 |
+
|
| 42 |
+
# Shape check
|
| 43 |
+
assert arr.ndim == 2
|
| 44 |
+
assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
|
| 45 |
+
# Element check
|
| 46 |
+
assert np.all((arr >= 0) & (arr <= 9))
|
| 47 |
+
return arr.astype(np.uint8)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
|
| 51 |
+
# PAD: 0, <eos>: 1, digits: 2 ... 11
|
| 52 |
+
# Compute random top-left pad
|
| 53 |
+
if do_translation:
|
| 54 |
+
pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
|
| 55 |
+
pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
|
| 56 |
+
else:
|
| 57 |
+
pad_r = pad_c = 0
|
| 58 |
+
|
| 59 |
+
# Pad grid
|
| 60 |
+
result = []
|
| 61 |
+
for grid in [inp, out]:
|
| 62 |
+
nrow, ncol = grid.shape
|
| 63 |
+
grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
|
| 64 |
+
|
| 65 |
+
# Add <eos>
|
| 66 |
+
eos_row, eos_col = pad_r + nrow, pad_c + ncol
|
| 67 |
+
if eos_row < ARCMaxGridSize:
|
| 68 |
+
grid[eos_row, pad_c:eos_col] = 1
|
| 69 |
+
if eos_col < ARCMaxGridSize:
|
| 70 |
+
grid[pad_r:eos_row, eos_col] = 1
|
| 71 |
+
|
| 72 |
+
result.append(grid.flatten())
|
| 73 |
+
|
| 74 |
+
return result
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def grid_hash(grid: np.ndarray):
|
| 78 |
+
assert grid.ndim == 2
|
| 79 |
+
assert grid.dtype == np.uint8
|
| 80 |
+
|
| 81 |
+
buffer = [x.to_bytes(1, byteorder='big') for x in grid.shape]
|
| 82 |
+
buffer.append(grid.tobytes())
|
| 83 |
+
|
| 84 |
+
return hashlib.sha256(b"".join(buffer)).hexdigest()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def puzzle_hash(puzzle: dict):
|
| 88 |
+
# Hash the puzzle for checking equivalence
|
| 89 |
+
hashes = []
|
| 90 |
+
for example_type, example in puzzle.items():
|
| 91 |
+
for input, label in example.examples:
|
| 92 |
+
hashes.append(f"{grid_hash(input)}|{grid_hash(label)}")
|
| 93 |
+
|
| 94 |
+
hashes.sort()
|
| 95 |
+
return hashlib.sha256("|".join(hashes).encode()).hexdigest()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def aug(name: str):
|
| 99 |
+
# Augment plan
|
| 100 |
+
trans_id = np.random.randint(0, 8)
|
| 101 |
+
mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
|
| 102 |
+
|
| 103 |
+
name_with_aug_repr = f"{name}{PuzzleIdSeparator}t{trans_id}{PuzzleIdSeparator}{''.join(str(x) for x in mapping)}"
|
| 104 |
+
|
| 105 |
+
def _map_grid(grid: np.ndarray):
|
| 106 |
+
return dihedral_transform(mapping[grid], trans_id)
|
| 107 |
+
|
| 108 |
+
return name_with_aug_repr, _map_grid
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def inverse_aug(name: str):
|
| 112 |
+
# Inverse the "aug" function
|
| 113 |
+
if PuzzleIdSeparator not in name:
|
| 114 |
+
return name, lambda x: x
|
| 115 |
+
|
| 116 |
+
trans_id, perm = name.split(PuzzleIdSeparator)[-2:]
|
| 117 |
+
trans_id = int(trans_id[1:]) # Remove "t" letter
|
| 118 |
+
inv_perm = np.argsort(list(perm)).astype(np.uint8)
|
| 119 |
+
|
| 120 |
+
def _map_grid(grid: np.ndarray):
|
| 121 |
+
return inv_perm[inverse_dihedral_transform(grid, trans_id)]
|
| 122 |
+
|
| 123 |
+
return name.split(PuzzleIdSeparator)[0], _map_grid
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def convert_single_arc_puzzle(results: dict, name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
|
| 127 |
+
# Convert
|
| 128 |
+
dests = set(dest_mapping.values())
|
| 129 |
+
converted = {dest: ARCPuzzle(name, []) for dest in dests}
|
| 130 |
+
for example_type, examples in puzzle.items():
|
| 131 |
+
# Map to target split
|
| 132 |
+
dest = dest_mapping[example_type]
|
| 133 |
+
converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
|
| 134 |
+
|
| 135 |
+
group = [converted]
|
| 136 |
+
|
| 137 |
+
# Augment
|
| 138 |
+
if aug_count > 0:
|
| 139 |
+
hashes = {puzzle_hash(converted)}
|
| 140 |
+
|
| 141 |
+
for _trial in range(ARCAugmentRetriesFactor * aug_count):
|
| 142 |
+
aug_name, _map_grid = aug(name)
|
| 143 |
+
|
| 144 |
+
# Check duplicate
|
| 145 |
+
augmented = {dest: ARCPuzzle(aug_name, [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
|
| 146 |
+
h = puzzle_hash(augmented)
|
| 147 |
+
if h not in hashes:
|
| 148 |
+
hashes.add(h)
|
| 149 |
+
group.append(augmented)
|
| 150 |
+
|
| 151 |
+
if len(group) >= aug_count + 1:
|
| 152 |
+
break
|
| 153 |
+
|
| 154 |
+
if len(group) < aug_count + 1:
|
| 155 |
+
print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
|
| 156 |
+
|
| 157 |
+
# Append
|
| 158 |
+
for dest in dests:
|
| 159 |
+
# Convert the examples
|
| 160 |
+
dest_split, dest_set = dest
|
| 161 |
+
|
| 162 |
+
results.setdefault(dest_split, {})
|
| 163 |
+
results[dest_split].setdefault(dest_set, [])
|
| 164 |
+
results[dest_split][dest_set].append([converted[dest] for converted in group])
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_puzzles_arcagi(config: DataProcessConfig):
|
| 168 |
+
train_examples_dest = ("train", "all")
|
| 169 |
+
test_examples_map = {
|
| 170 |
+
config.test_set_name: [(1.0, ("test", "all"))],
|
| 171 |
+
config.test_set_name2: [(1.0, ("test", "all"))],
|
| 172 |
+
"_default": [(1.0, ("train", "all"))]
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
test_puzzles = {}
|
| 176 |
+
results = {}
|
| 177 |
+
|
| 178 |
+
total_puzzles = 0
|
| 179 |
+
for subset_name in config.subsets:
|
| 180 |
+
# Load all puzzles in this subset
|
| 181 |
+
with open(f"{config.input_file_prefix}_{subset_name}_challenges.json", "r") as f:
|
| 182 |
+
puzzles = json.load(f)
|
| 183 |
+
|
| 184 |
+
sols_filename = f"{config.input_file_prefix}_{subset_name}_solutions.json"
|
| 185 |
+
if os.path.isfile(sols_filename):
|
| 186 |
+
with open(sols_filename, "r") as f:
|
| 187 |
+
sols = json.load(f)
|
| 188 |
+
|
| 189 |
+
for puzzle_id in puzzles.keys():
|
| 190 |
+
for idx, sol_grid in enumerate(sols[puzzle_id]):
|
| 191 |
+
puzzles[puzzle_id]["test"][idx]["output"] = sol_grid
|
| 192 |
+
else:
|
| 193 |
+
# Fill with dummy
|
| 194 |
+
print (f"{subset_name} solutions not found, filling with dummy")
|
| 195 |
+
|
| 196 |
+
for puzzle_id, puzzle in puzzles.items():
|
| 197 |
+
for example in puzzle["test"]:
|
| 198 |
+
example.setdefault("output", [[0]])
|
| 199 |
+
|
| 200 |
+
# Shuffle puzzles
|
| 201 |
+
puzzles = list(puzzles.items())
|
| 202 |
+
np.random.shuffle(puzzles)
|
| 203 |
+
|
| 204 |
+
# Assign by fraction
|
| 205 |
+
for idx, (name, puzzle) in enumerate(puzzles):
|
| 206 |
+
fraction = idx / len(puzzles)
|
| 207 |
+
test_examples_dest = None
|
| 208 |
+
for f, dest in test_examples_map.get(subset_name, test_examples_map["_default"]):
|
| 209 |
+
if fraction < f:
|
| 210 |
+
test_examples_dest = dest
|
| 211 |
+
break
|
| 212 |
+
|
| 213 |
+
assert test_examples_dest is not None
|
| 214 |
+
|
| 215 |
+
if test_examples_dest[0] == "test":
|
| 216 |
+
test_puzzles[name] = puzzle
|
| 217 |
+
|
| 218 |
+
convert_single_arc_puzzle(results, name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
|
| 219 |
+
total_puzzles += 1
|
| 220 |
+
|
| 221 |
+
print (f"Total puzzles: {total_puzzles}")
|
| 222 |
+
return results, test_puzzles
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def convert_dataset(config: DataProcessConfig):
|
| 226 |
+
np.random.seed(config.seed)
|
| 227 |
+
|
| 228 |
+
# Read dataset
|
| 229 |
+
data, test_puzzles = load_puzzles_arcagi(config)
|
| 230 |
+
|
| 231 |
+
# Map global puzzle identifiers
|
| 232 |
+
num_identifiers = config.puzzle_identifiers_start # 0 is blank, start at 1
|
| 233 |
+
identifier_map = {}
|
| 234 |
+
for split_name, split in data.items():
|
| 235 |
+
for subset_name, subset in split.items():
|
| 236 |
+
for group in subset:
|
| 237 |
+
for puzzle in group:
|
| 238 |
+
if puzzle.id not in identifier_map:
|
| 239 |
+
identifier_map[puzzle.id] = num_identifiers
|
| 240 |
+
num_identifiers += 1
|
| 241 |
+
print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
|
| 242 |
+
|
| 243 |
+
# Save
|
| 244 |
+
for split_name, split in data.items():
|
| 245 |
+
os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
|
| 246 |
+
|
| 247 |
+
# Translational augmentations
|
| 248 |
+
enable_translational_augment = split_name == "train"
|
| 249 |
+
|
| 250 |
+
# Statistics
|
| 251 |
+
total_examples = 0
|
| 252 |
+
total_puzzles = 0
|
| 253 |
+
total_groups = 0
|
| 254 |
+
|
| 255 |
+
for subset_name, subset in split.items(): # "all" is the only subset
|
| 256 |
+
# Construct subset
|
| 257 |
+
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
| 258 |
+
results["puzzle_indices"].append(0)
|
| 259 |
+
results["group_indices"].append(0)
|
| 260 |
+
|
| 261 |
+
example_id = 0
|
| 262 |
+
puzzle_id = 0
|
| 263 |
+
|
| 264 |
+
for group in subset:
|
| 265 |
+
for puzzle in group:
|
| 266 |
+
# Push puzzle
|
| 267 |
+
no_aug_id = np.random.randint(0, len(puzzle.examples))
|
| 268 |
+
for _idx_ex, (inp, out) in enumerate(puzzle.examples):
|
| 269 |
+
inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
|
| 270 |
+
|
| 271 |
+
results["inputs"].append(inp)
|
| 272 |
+
results["labels"].append(out)
|
| 273 |
+
example_id += 1
|
| 274 |
+
|
| 275 |
+
total_examples += 1
|
| 276 |
+
|
| 277 |
+
results["puzzle_indices"].append(example_id)
|
| 278 |
+
results["puzzle_identifiers"].append(identifier_map[puzzle.id])
|
| 279 |
+
|
| 280 |
+
puzzle_id += 1
|
| 281 |
+
total_puzzles += 1
|
| 282 |
+
|
| 283 |
+
# Push group
|
| 284 |
+
results["group_indices"].append(puzzle_id)
|
| 285 |
+
total_groups += 1
|
| 286 |
+
|
| 287 |
+
for k, v in results.items():
|
| 288 |
+
if k in {"inputs", "labels"}:
|
| 289 |
+
v = np.stack(v, 0)
|
| 290 |
+
else:
|
| 291 |
+
v = np.array(v, dtype=np.int32)
|
| 292 |
+
|
| 293 |
+
np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
|
| 294 |
+
|
| 295 |
+
# Metadata
|
| 296 |
+
metadata = PuzzleDatasetMetadata(
|
| 297 |
+
seq_len=ARCMaxGridSize * ARCMaxGridSize,
|
| 298 |
+
vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
|
| 299 |
+
pad_id=0,
|
| 300 |
+
ignore_label_id=0,
|
| 301 |
+
blank_identifier_id=0,
|
| 302 |
+
num_puzzle_identifiers=num_identifiers,
|
| 303 |
+
total_groups=total_groups,
|
| 304 |
+
mean_puzzle_examples=total_examples / total_puzzles,
|
| 305 |
+
total_puzzles=total_puzzles,
|
| 306 |
+
sets=list(split.keys())
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# Save metadata as JSON.
|
| 310 |
+
with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
|
| 311 |
+
json.dump(metadata.model_dump(), f)
|
| 312 |
+
|
| 313 |
+
# Save IDs mapping
|
| 314 |
+
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
| 315 |
+
ids_mapping = {v: k for k, v in identifier_map.items()}
|
| 316 |
+
json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
|
| 317 |
+
|
| 318 |
+
# Save Test Puzzles
|
| 319 |
+
with open(os.path.join(config.output_dir, "test_puzzles.json"), "w") as f:
|
| 320 |
+
json.dump(test_puzzles, f)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@cli.command(singleton=True)
|
| 324 |
+
def main(config: DataProcessConfig):
|
| 325 |
+
convert_dataset(config)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if __name__ == "__main__":
|
| 329 |
+
cli()
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
dataset/build_maze_dataset.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from argdantic import ArgParser
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
|
| 13 |
+
from common import PuzzleDatasetMetadata, dihedral_transform
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
CHARSET = "# SGo"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
cli = ArgParser()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DataProcessConfig(BaseModel):
|
| 23 |
+
source_repo: str = "sapientinc/maze-30x30-hard-1k"
|
| 24 |
+
output_dir: str = "data/maze-30x30-hard-1k"
|
| 25 |
+
|
| 26 |
+
subsample_size: Optional[int] = None
|
| 27 |
+
aug: bool = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def convert_subset(set_name: str, config: DataProcessConfig):
|
| 31 |
+
# Read CSV
|
| 32 |
+
all_chars = set()
|
| 33 |
+
grid_size = None
|
| 34 |
+
inputs = []
|
| 35 |
+
labels = []
|
| 36 |
+
|
| 37 |
+
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
|
| 38 |
+
reader = csv.reader(csvfile)
|
| 39 |
+
next(reader) # Skip header
|
| 40 |
+
for source, q, a, rating in reader:
|
| 41 |
+
all_chars.update(q)
|
| 42 |
+
all_chars.update(a)
|
| 43 |
+
|
| 44 |
+
if grid_size is None:
|
| 45 |
+
n = int(len(q) ** 0.5)
|
| 46 |
+
grid_size = (n, n)
|
| 47 |
+
|
| 48 |
+
inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
|
| 49 |
+
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
|
| 50 |
+
|
| 51 |
+
# If subsample_size is specified for the training set,
|
| 52 |
+
# randomly sample the desired number of examples.
|
| 53 |
+
if set_name == "train" and config.subsample_size is not None:
|
| 54 |
+
total_samples = len(inputs)
|
| 55 |
+
if config.subsample_size < total_samples:
|
| 56 |
+
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
|
| 57 |
+
inputs = [inputs[i] for i in indices]
|
| 58 |
+
labels = [labels[i] for i in indices]
|
| 59 |
+
|
| 60 |
+
# Generate dataset
|
| 61 |
+
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
| 62 |
+
puzzle_id = 0
|
| 63 |
+
example_id = 0
|
| 64 |
+
|
| 65 |
+
results["puzzle_indices"].append(0)
|
| 66 |
+
results["group_indices"].append(0)
|
| 67 |
+
|
| 68 |
+
for inp, out in zip(tqdm(inputs), labels):
|
| 69 |
+
# Dihedral transformations for augmentation
|
| 70 |
+
for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
|
| 71 |
+
results["inputs"].append(dihedral_transform(inp, aug_idx))
|
| 72 |
+
results["labels"].append(dihedral_transform(out, aug_idx))
|
| 73 |
+
example_id += 1
|
| 74 |
+
puzzle_id += 1
|
| 75 |
+
|
| 76 |
+
results["puzzle_indices"].append(example_id)
|
| 77 |
+
results["puzzle_identifiers"].append(0)
|
| 78 |
+
|
| 79 |
+
# Push group
|
| 80 |
+
results["group_indices"].append(puzzle_id)
|
| 81 |
+
|
| 82 |
+
# Char mappings
|
| 83 |
+
assert len(all_chars - set(CHARSET)) == 0
|
| 84 |
+
|
| 85 |
+
char2id = np.zeros(256, np.uint8)
|
| 86 |
+
char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
|
| 87 |
+
|
| 88 |
+
# To Numpy
|
| 89 |
+
def _seq_to_numpy(seq):
|
| 90 |
+
arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
|
| 91 |
+
|
| 92 |
+
return arr
|
| 93 |
+
|
| 94 |
+
results = {
|
| 95 |
+
"inputs": _seq_to_numpy(results["inputs"]),
|
| 96 |
+
"labels": _seq_to_numpy(results["labels"]),
|
| 97 |
+
|
| 98 |
+
"group_indices": np.array(results["group_indices"], dtype=np.int32),
|
| 99 |
+
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
|
| 100 |
+
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Metadata
|
| 104 |
+
metadata = PuzzleDatasetMetadata(
|
| 105 |
+
seq_len=int(math.prod(grid_size)), # type: ignore
|
| 106 |
+
vocab_size=len(CHARSET) + 1, # PAD + Charset
|
| 107 |
+
pad_id=0,
|
| 108 |
+
ignore_label_id=0,
|
| 109 |
+
blank_identifier_id=0,
|
| 110 |
+
num_puzzle_identifiers=1,
|
| 111 |
+
total_groups=len(results["group_indices"]) - 1,
|
| 112 |
+
mean_puzzle_examples=1,
|
| 113 |
+
total_puzzles=len(results["group_indices"]) - 1,
|
| 114 |
+
sets=["all"]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Save metadata as JSON.
|
| 118 |
+
save_dir = os.path.join(config.output_dir, set_name)
|
| 119 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 120 |
+
|
| 121 |
+
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
|
| 122 |
+
json.dump(metadata.model_dump(), f)
|
| 123 |
+
|
| 124 |
+
# Save data
|
| 125 |
+
for k, v in results.items():
|
| 126 |
+
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
|
| 127 |
+
|
| 128 |
+
# Save IDs mapping (for visualization only)
|
| 129 |
+
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
| 130 |
+
json.dump(["<blank>"], f)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@cli.command(singleton=True)
|
| 134 |
+
def preprocess_data(config: DataProcessConfig):
|
| 135 |
+
convert_subset("train", config)
|
| 136 |
+
convert_subset("test", config)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
if __name__ == "__main__":
|
| 140 |
+
cli()
|
dataset/build_sudoku_dataset.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import os
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from argdantic import ArgParser
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
from common import PuzzleDatasetMetadata
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
cli = ArgParser()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DataProcessConfig(BaseModel):
|
| 19 |
+
source_repo: str = "sapientinc/sudoku-extreme"
|
| 20 |
+
output_dir: str = "data/sudoku-extreme-full"
|
| 21 |
+
|
| 22 |
+
subsample_size: Optional[int] = None
|
| 23 |
+
min_difficulty: Optional[int] = None
|
| 24 |
+
num_aug: int = 0
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
|
| 28 |
+
# Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
|
| 29 |
+
digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
|
| 30 |
+
|
| 31 |
+
# Randomly decide whether to transpose.
|
| 32 |
+
transpose_flag = np.random.rand() < 0.5
|
| 33 |
+
|
| 34 |
+
# Generate a valid row permutation:
|
| 35 |
+
# - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
|
| 36 |
+
bands = np.random.permutation(3)
|
| 37 |
+
row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
|
| 38 |
+
|
| 39 |
+
# Similarly for columns (stacks).
|
| 40 |
+
stacks = np.random.permutation(3)
|
| 41 |
+
col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
|
| 42 |
+
|
| 43 |
+
# Build an 81->81 mapping. For each new cell at (i, j)
|
| 44 |
+
# (row index = i // 9, col index = i % 9),
|
| 45 |
+
# its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
|
| 46 |
+
mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])
|
| 47 |
+
|
| 48 |
+
def apply_transformation(x: np.ndarray) -> np.ndarray:
|
| 49 |
+
# Apply transpose flag
|
| 50 |
+
if transpose_flag:
|
| 51 |
+
x = x.T
|
| 52 |
+
# Apply the position mapping.
|
| 53 |
+
new_board = x.flatten()[mapping].reshape(9, 9).copy()
|
| 54 |
+
# Apply digit mapping
|
| 55 |
+
return digit_map[new_board]
|
| 56 |
+
|
| 57 |
+
return apply_transformation(board), apply_transformation(solution)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def convert_subset(set_name: str, config: DataProcessConfig):
|
| 61 |
+
# Read CSV
|
| 62 |
+
inputs = []
|
| 63 |
+
labels = []
|
| 64 |
+
|
| 65 |
+
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
|
| 66 |
+
reader = csv.reader(csvfile)
|
| 67 |
+
next(reader) # Skip header
|
| 68 |
+
for source, q, a, rating in reader:
|
| 69 |
+
if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
|
| 70 |
+
assert len(q) == 81 and len(a) == 81
|
| 71 |
+
|
| 72 |
+
inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
|
| 73 |
+
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
|
| 74 |
+
|
| 75 |
+
# If subsample_size is specified for the training set,
|
| 76 |
+
# randomly sample the desired number of examples.
|
| 77 |
+
if set_name == "train" and config.subsample_size is not None:
|
| 78 |
+
total_samples = len(inputs)
|
| 79 |
+
if config.subsample_size < total_samples:
|
| 80 |
+
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
|
| 81 |
+
inputs = [inputs[i] for i in indices]
|
| 82 |
+
labels = [labels[i] for i in indices]
|
| 83 |
+
|
| 84 |
+
# Generate dataset
|
| 85 |
+
num_augments = config.num_aug if set_name == "train" else 0
|
| 86 |
+
|
| 87 |
+
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
| 88 |
+
puzzle_id = 0
|
| 89 |
+
example_id = 0
|
| 90 |
+
|
| 91 |
+
results["puzzle_indices"].append(0)
|
| 92 |
+
results["group_indices"].append(0)
|
| 93 |
+
|
| 94 |
+
for orig_inp, orig_out in zip(tqdm(inputs), labels):
|
| 95 |
+
for aug_idx in range(1 + num_augments):
|
| 96 |
+
# First index is not augmented
|
| 97 |
+
if aug_idx == 0:
|
| 98 |
+
inp, out = orig_inp, orig_out
|
| 99 |
+
else:
|
| 100 |
+
inp, out = shuffle_sudoku(orig_inp, orig_out)
|
| 101 |
+
|
| 102 |
+
# Push puzzle (only single example)
|
| 103 |
+
results["inputs"].append(inp)
|
| 104 |
+
results["labels"].append(out)
|
| 105 |
+
example_id += 1
|
| 106 |
+
puzzle_id += 1
|
| 107 |
+
|
| 108 |
+
results["puzzle_indices"].append(example_id)
|
| 109 |
+
results["puzzle_identifiers"].append(0)
|
| 110 |
+
|
| 111 |
+
# Push group
|
| 112 |
+
results["group_indices"].append(puzzle_id)
|
| 113 |
+
|
| 114 |
+
# To Numpy
|
| 115 |
+
def _seq_to_numpy(seq):
|
| 116 |
+
arr = np.concatenate(seq).reshape(len(seq), -1)
|
| 117 |
+
|
| 118 |
+
assert np.all((arr >= 0) & (arr <= 9))
|
| 119 |
+
return arr + 1
|
| 120 |
+
|
| 121 |
+
results = {
|
| 122 |
+
"inputs": _seq_to_numpy(results["inputs"]),
|
| 123 |
+
"labels": _seq_to_numpy(results["labels"]),
|
| 124 |
+
|
| 125 |
+
"group_indices": np.array(results["group_indices"], dtype=np.int32),
|
| 126 |
+
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
|
| 127 |
+
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# Metadata
|
| 131 |
+
metadata = PuzzleDatasetMetadata(
|
| 132 |
+
seq_len=81,
|
| 133 |
+
vocab_size=10 + 1, # PAD + "0" ... "9"
|
| 134 |
+
pad_id=0,
|
| 135 |
+
ignore_label_id=0,
|
| 136 |
+
blank_identifier_id=0,
|
| 137 |
+
num_puzzle_identifiers=1,
|
| 138 |
+
total_groups=len(results["group_indices"]) - 1,
|
| 139 |
+
mean_puzzle_examples=1,
|
| 140 |
+
total_puzzles=len(results["group_indices"]) - 1,
|
| 141 |
+
sets=["all"]
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Save metadata as JSON.
|
| 145 |
+
save_dir = os.path.join(config.output_dir, set_name)
|
| 146 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 147 |
+
|
| 148 |
+
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
|
| 149 |
+
json.dump(metadata.model_dump(), f)
|
| 150 |
+
|
| 151 |
+
# Save data
|
| 152 |
+
for k, v in results.items():
|
| 153 |
+
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
|
| 154 |
+
|
| 155 |
+
# Save IDs mapping (for visualization only)
|
| 156 |
+
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
| 157 |
+
json.dump(["<blank>"], f)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@cli.command(singleton=True)
|
| 161 |
+
def preprocess_data(config: DataProcessConfig):
|
| 162 |
+
convert_subset("train", config)
|
| 163 |
+
convert_subset("test", config)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
cli()
|
dataset/common.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
import pydantic
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Global list mapping each dihedral transform id to its inverse.
|
| 8 |
+
# Index corresponds to the original tid, and the value is its inverse.
|
| 9 |
+
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PuzzleDatasetMetadata(pydantic.BaseModel):
|
| 13 |
+
pad_id: int
|
| 14 |
+
ignore_label_id: Optional[int]
|
| 15 |
+
blank_identifier_id: int
|
| 16 |
+
vocab_size: int
|
| 17 |
+
seq_len: int
|
| 18 |
+
num_puzzle_identifiers: int
|
| 19 |
+
total_groups: int
|
| 20 |
+
mean_puzzle_examples: float
|
| 21 |
+
total_puzzles: int
|
| 22 |
+
sets: List[str]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
|
| 26 |
+
"""8 dihedral symmetries by rotate, flip and mirror"""
|
| 27 |
+
|
| 28 |
+
if tid == 0:
|
| 29 |
+
return arr # identity
|
| 30 |
+
elif tid == 1:
|
| 31 |
+
return np.rot90(arr, k=1)
|
| 32 |
+
elif tid == 2:
|
| 33 |
+
return np.rot90(arr, k=2)
|
| 34 |
+
elif tid == 3:
|
| 35 |
+
return np.rot90(arr, k=3)
|
| 36 |
+
elif tid == 4:
|
| 37 |
+
return np.fliplr(arr) # horizontal flip
|
| 38 |
+
elif tid == 5:
|
| 39 |
+
return np.flipud(arr) # vertical flip
|
| 40 |
+
elif tid == 6:
|
| 41 |
+
return arr.T # transpose (reflection along main diagonal)
|
| 42 |
+
elif tid == 7:
|
| 43 |
+
return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
|
| 44 |
+
else:
|
| 45 |
+
return arr
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
|
| 49 |
+
return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
|
evaluators/arc.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Sequence, Optional
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from numba import njit
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
|
| 10 |
+
from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
|
| 11 |
+
from dataset.common import PuzzleDatasetMetadata
|
| 12 |
+
|
| 13 |
+
@njit
|
| 14 |
+
def _crop(grid: np.ndarray):
|
| 15 |
+
"""Find maximum-sized rectangle without any EOS token inside. """
|
| 16 |
+
grid = grid.reshape(30, 30)
|
| 17 |
+
|
| 18 |
+
max_area = 0
|
| 19 |
+
max_size = (0, 0)
|
| 20 |
+
nr, nc = grid.shape
|
| 21 |
+
|
| 22 |
+
num_c = nc
|
| 23 |
+
for num_r in range(1, nr + 1):
|
| 24 |
+
# Scan for maximum c
|
| 25 |
+
for c in range(1, num_c + 1):
|
| 26 |
+
x = grid[num_r - 1, c - 1]
|
| 27 |
+
if (x < 2) | (x > 11):
|
| 28 |
+
num_c = c - 1
|
| 29 |
+
break
|
| 30 |
+
|
| 31 |
+
area = num_r * num_c
|
| 32 |
+
if area > max_area:
|
| 33 |
+
max_area = area
|
| 34 |
+
max_size = (num_r, num_c)
|
| 35 |
+
|
| 36 |
+
return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ARC:
|
| 40 |
+
required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}
|
| 41 |
+
|
| 42 |
+
def __init__(self, data_path: str,
|
| 43 |
+
eval_metadata: PuzzleDatasetMetadata,
|
| 44 |
+
submission_K: int = 2,
|
| 45 |
+
pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
|
| 46 |
+
aggregated_voting: bool = True):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.pass_Ks = pass_Ks
|
| 49 |
+
self.submission_K = submission_K
|
| 50 |
+
self.aggregated_voting = aggregated_voting
|
| 51 |
+
self.blank_identifier_id = eval_metadata.blank_identifier_id
|
| 52 |
+
|
| 53 |
+
# Load identifiers and test puzzles
|
| 54 |
+
with open(os.path.join(data_path, "identifiers.json"), "r") as f:
|
| 55 |
+
self.identifier_map = json.load(f)
|
| 56 |
+
with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
|
| 57 |
+
self.test_puzzles = json.load(f)
|
| 58 |
+
|
| 59 |
+
# States
|
| 60 |
+
self._local_hmap = {}
|
| 61 |
+
self._local_preds = {}
|
| 62 |
+
|
| 63 |
+
def begin_eval(self):
|
| 64 |
+
if not self.aggregated_voting:
|
| 65 |
+
# Clear previous predictions
|
| 66 |
+
self._local_hmap = {}
|
| 67 |
+
self._local_preds = {}
|
| 68 |
+
|
| 69 |
+
def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
|
| 70 |
+
# Collect required outputs to CPU
|
| 71 |
+
outputs = {}
|
| 72 |
+
q_values = None
|
| 73 |
+
|
| 74 |
+
for collection in (batch, preds):
|
| 75 |
+
for k, v in collection.items():
|
| 76 |
+
if k in self.required_outputs:
|
| 77 |
+
if k == "q_halt_logits":
|
| 78 |
+
q_values = v.to(torch.float64).sigmoid().cpu()
|
| 79 |
+
else:
|
| 80 |
+
outputs[k] = v.cpu()
|
| 81 |
+
|
| 82 |
+
assert q_values is not None
|
| 83 |
+
|
| 84 |
+
# Remove padding from outputs
|
| 85 |
+
mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
|
| 86 |
+
outputs = {k: v[mask] for k, v in outputs.items()}
|
| 87 |
+
|
| 88 |
+
# Get predictions
|
| 89 |
+
for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
|
| 90 |
+
name = self.identifier_map[identifier]
|
| 91 |
+
orig_name, _inverse_fn = inverse_aug(name)
|
| 92 |
+
|
| 93 |
+
input_hash = grid_hash(_inverse_fn(_crop(input)))
|
| 94 |
+
|
| 95 |
+
pred = _inverse_fn(_crop(pred))
|
| 96 |
+
assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range." # Sanity check
|
| 97 |
+
|
| 98 |
+
# Store into local state
|
| 99 |
+
pred_hash = grid_hash(pred)
|
| 100 |
+
|
| 101 |
+
self._local_hmap[pred_hash] = pred
|
| 102 |
+
|
| 103 |
+
self._local_preds.setdefault(orig_name, {})
|
| 104 |
+
self._local_preds[orig_name].setdefault(input_hash, [])
|
| 105 |
+
self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
|
| 106 |
+
|
| 107 |
+
def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
|
| 108 |
+
# Gather predictions to rank 0 for voting
|
| 109 |
+
global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
|
| 110 |
+
dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)
|
| 111 |
+
|
| 112 |
+
# Rank 0 logic
|
| 113 |
+
if rank != 0:
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
submission = {}
|
| 117 |
+
correct = [0.0 for _ in range(len(self.pass_Ks))]
|
| 118 |
+
|
| 119 |
+
for name, puzzle in self.test_puzzles.items():
|
| 120 |
+
# Process test examples in this puzzle
|
| 121 |
+
submission[name] = []
|
| 122 |
+
num_test_correct = [0 for _ in range(len(self.pass_Ks))]
|
| 123 |
+
for pair in puzzle["test"]:
|
| 124 |
+
input_hash = grid_hash(arc_grid_to_np(pair["input"]))
|
| 125 |
+
label_hash = grid_hash(arc_grid_to_np(pair["output"]))
|
| 126 |
+
|
| 127 |
+
p_map = {}
|
| 128 |
+
for hmap, preds in global_hmap_preds: # type: ignore
|
| 129 |
+
for h, q in preds.get(name, {}).get(input_hash, {}):
|
| 130 |
+
p_map.setdefault(h, [0, 0])
|
| 131 |
+
p_map[h][0] += 1
|
| 132 |
+
p_map[h][1] += q
|
| 133 |
+
|
| 134 |
+
if not len(p_map):
|
| 135 |
+
print (f"Puzzle {name} has no predictions.")
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
for h, stats in p_map.items():
|
| 139 |
+
stats[1] /= stats[0]
|
| 140 |
+
|
| 141 |
+
p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)
|
| 142 |
+
|
| 143 |
+
# vote for different Ks
|
| 144 |
+
for i, k in enumerate(self.pass_Ks):
|
| 145 |
+
ok = False
|
| 146 |
+
for h, stats in p_map[:k]:
|
| 147 |
+
ok |= h == label_hash
|
| 148 |
+
|
| 149 |
+
num_test_correct[i] += ok
|
| 150 |
+
|
| 151 |
+
# Query grids
|
| 152 |
+
pred_grids = []
|
| 153 |
+
for h, stats in p_map[:self.submission_K]:
|
| 154 |
+
for hmap, preds in global_hmap_preds: # type: ignore
|
| 155 |
+
if h in hmap:
|
| 156 |
+
pred_grids.append(hmap[h])
|
| 157 |
+
break
|
| 158 |
+
|
| 159 |
+
# Pad to K
|
| 160 |
+
while len(pred_grids) < self.submission_K:
|
| 161 |
+
pred_grids.append(pred_grids[0])
|
| 162 |
+
|
| 163 |
+
submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})
|
| 164 |
+
|
| 165 |
+
# Total correctness
|
| 166 |
+
for i in range(len(self.pass_Ks)):
|
| 167 |
+
correct[i] += num_test_correct[i] / len(puzzle["test"])
|
| 168 |
+
|
| 169 |
+
# Save submission
|
| 170 |
+
if save_path is not None:
|
| 171 |
+
with open(os.path.join(save_path, "submission.json"), "w") as f:
|
| 172 |
+
json.dump(submission, f)
|
| 173 |
+
|
| 174 |
+
# Final result
|
| 175 |
+
all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}
|
| 176 |
+
|
| 177 |
+
return all_results
|
kaggle/combined/arc-agi_concept_challenges.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_concept_solutions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_evaluation2_challenges.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_evaluation2_solutions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_evaluation_challenges.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_evaluation_solutions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_training2_challenges.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_training2_solutions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_training_challenges.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
kaggle/combined/arc-agi_training_solutions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/common.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
|
| 8 |
+
# NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
|
| 9 |
+
# This function is a PyTorch version of jax truncated normal init (default init method in flax)
|
| 10 |
+
# https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
|
| 11 |
+
# https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
|
| 12 |
+
|
| 13 |
+
with torch.no_grad():
|
| 14 |
+
if std == 0:
|
| 15 |
+
tensor.zero_()
|
| 16 |
+
else:
|
| 17 |
+
sqrt2 = math.sqrt(2)
|
| 18 |
+
a = math.erf(lower / sqrt2)
|
| 19 |
+
b = math.erf(upper / sqrt2)
|
| 20 |
+
z = (b - a) / 2
|
| 21 |
+
|
| 22 |
+
c = (2 * math.pi) ** -0.5
|
| 23 |
+
pdf_u = c * math.exp(-0.5 * lower ** 2)
|
| 24 |
+
pdf_l = c * math.exp(-0.5 * upper ** 2)
|
| 25 |
+
comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
|
| 26 |
+
|
| 27 |
+
tensor.uniform_(a, b)
|
| 28 |
+
tensor.erfinv_()
|
| 29 |
+
tensor.mul_(sqrt2 * comp_std)
|
| 30 |
+
tensor.clip_(lower * comp_std, upper * comp_std)
|
| 31 |
+
|
| 32 |
+
return tensor
|
models/ema.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class EMAHelper(object):
|
| 5 |
+
def __init__(self, mu=0.999):
|
| 6 |
+
self.mu = mu
|
| 7 |
+
self.shadow = {}
|
| 8 |
+
|
| 9 |
+
def register(self, module):
|
| 10 |
+
if isinstance(module, nn.DataParallel):
|
| 11 |
+
module = module.module
|
| 12 |
+
for name, param in module.named_parameters():
|
| 13 |
+
if param.requires_grad:
|
| 14 |
+
self.shadow[name] = param.data.clone()
|
| 15 |
+
|
| 16 |
+
def update(self, module):
|
| 17 |
+
if isinstance(module, nn.DataParallel):
|
| 18 |
+
module = module.module
|
| 19 |
+
for name, param in module.named_parameters():
|
| 20 |
+
if param.requires_grad:
|
| 21 |
+
self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
|
| 22 |
+
|
| 23 |
+
def ema(self, module):
|
| 24 |
+
if isinstance(module, nn.DataParallel):
|
| 25 |
+
module = module.module
|
| 26 |
+
for name, param in module.named_parameters():
|
| 27 |
+
if param.requires_grad:
|
| 28 |
+
param.data.copy_(self.shadow[name].data)
|
| 29 |
+
|
| 30 |
+
def ema_copy(self, module):
|
| 31 |
+
module_copy = copy.deepcopy(module)
|
| 32 |
+
self.ema(module_copy)
|
| 33 |
+
return module_copy
|
| 34 |
+
|
| 35 |
+
def state_dict(self):
|
| 36 |
+
return self.shadow
|
| 37 |
+
|
| 38 |
+
def load_state_dict(self, state_dict):
|
| 39 |
+
self.shadow = state_dict
|
| 40 |
+
|
models/layers.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import einops
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
#try:
|
| 8 |
+
# from flash_attn_interface import flash_attn_func # type: ignore[import]
|
| 9 |
+
#except ImportError:
|
| 10 |
+
# # Fallback to FlashAttention 2
|
| 11 |
+
# from flash_attn import flash_attn_func # type: ignore[import]
|
| 12 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 13 |
+
|
| 14 |
+
from models.common import trunc_normal_init_
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
CosSin = Tuple[torch.Tensor, torch.Tensor]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _find_multiple(a, b):
|
| 21 |
+
return (-(a // -b)) * b
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def rotate_half(x: torch.Tensor):
|
| 25 |
+
"""Rotates half the hidden dims of the input."""
|
| 26 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 27 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 28 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
| 32 |
+
# q, k: [bs, seq_len, num_heads, head_dim]
|
| 33 |
+
# cos, sin: [seq_len, head_dim]
|
| 34 |
+
orig_dtype = q.dtype
|
| 35 |
+
q = q.to(cos.dtype)
|
| 36 |
+
k = k.to(cos.dtype)
|
| 37 |
+
|
| 38 |
+
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
|
| 39 |
+
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
|
| 40 |
+
|
| 41 |
+
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CastedLinear(nn.Module):
|
| 45 |
+
def __init__(self,
|
| 46 |
+
in_features: int,
|
| 47 |
+
out_features: int,
|
| 48 |
+
bias: bool):
|
| 49 |
+
super().__init__()
|
| 50 |
+
# Truncated LeCun normal init
|
| 51 |
+
self.weight = nn.Parameter(
|
| 52 |
+
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
|
| 53 |
+
)
|
| 54 |
+
self.bias = None
|
| 55 |
+
if bias:
|
| 56 |
+
# Zero init bias
|
| 57 |
+
self.bias = nn.Parameter(torch.zeros((out_features, )))
|
| 58 |
+
|
| 59 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 60 |
+
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CastedEmbedding(nn.Module):
|
| 64 |
+
def __init__(self,
|
| 65 |
+
num_embeddings: int,
|
| 66 |
+
embedding_dim: int,
|
| 67 |
+
init_std: float,
|
| 68 |
+
cast_to: torch.dtype):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.cast_to = cast_to
|
| 71 |
+
|
| 72 |
+
# Truncated LeCun normal init
|
| 73 |
+
self.embedding_weight = nn.Parameter(
|
| 74 |
+
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 78 |
+
return F.embedding(input, self.embedding_weight.to(self.cast_to))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RotaryEmbedding(nn.Module):
|
| 82 |
+
def __init__(self, dim, max_position_embeddings, base, device=None):
|
| 83 |
+
super().__init__()
|
| 84 |
+
|
| 85 |
+
# RoPE
|
| 86 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
| 87 |
+
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
|
| 88 |
+
freqs = torch.outer(t, inv_freq)
|
| 89 |
+
|
| 90 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
| 91 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 92 |
+
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
|
| 93 |
+
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
|
| 94 |
+
|
| 95 |
+
def forward(self):
|
| 96 |
+
return self.cos_cached, self.sin_cached
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Attention(nn.Module):
|
| 100 |
+
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
|
| 101 |
+
super().__init__()
|
| 102 |
+
|
| 103 |
+
self.hidden_size = hidden_size
|
| 104 |
+
self.head_dim = head_dim
|
| 105 |
+
self.output_size = head_dim * num_heads
|
| 106 |
+
self.num_heads = num_heads
|
| 107 |
+
self.num_key_value_heads = num_key_value_heads
|
| 108 |
+
self.causal = causal
|
| 109 |
+
|
| 110 |
+
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
|
| 111 |
+
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
|
| 112 |
+
|
| 113 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 114 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 115 |
+
|
| 116 |
+
# hidden_states: [bs, seq_len, num_heads, head_dim]
|
| 117 |
+
qkv = self.qkv_proj(hidden_states)
|
| 118 |
+
|
| 119 |
+
# Split head
|
| 120 |
+
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
| 121 |
+
query = qkv[:, :, :self.num_heads]
|
| 122 |
+
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
|
| 123 |
+
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
|
| 124 |
+
|
| 125 |
+
# RoPE
|
| 126 |
+
if cos_sin is not None:
|
| 127 |
+
cos, sin = cos_sin
|
| 128 |
+
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
| 129 |
+
|
| 130 |
+
# flash attn
|
| 131 |
+
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
|
| 132 |
+
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
|
| 133 |
+
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
|
| 134 |
+
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
|
| 135 |
+
return self.o_proj(attn_output)
|
| 136 |
+
|
| 137 |
+
class LinearSwish(nn.Module):
|
| 138 |
+
def __init__(self, hidden_size: int, reverse=False):
|
| 139 |
+
super().__init__()
|
| 140 |
+
|
| 141 |
+
self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
|
| 142 |
+
self.reverse = reverse
|
| 143 |
+
|
| 144 |
+
def forward(self, x):
|
| 145 |
+
if self.reverse:
|
| 146 |
+
return F.silu(self.linear(x))
|
| 147 |
+
else:
|
| 148 |
+
return self.linear(F.silu(x))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class SwiGLU(nn.Module):
|
| 152 |
+
def __init__(self, hidden_size: int, expansion: float):
|
| 153 |
+
super().__init__()
|
| 154 |
+
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
|
| 155 |
+
|
| 156 |
+
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
|
| 157 |
+
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
|
| 158 |
+
|
| 159 |
+
def forward(self, x):
|
| 160 |
+
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
| 161 |
+
return self.down_proj(F.silu(gate) * up)
|
| 162 |
+
|
| 163 |
+
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
|
| 164 |
+
input_dtype = hidden_states.dtype
|
| 165 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 166 |
+
|
| 167 |
+
variance = hidden_states.square().mean(-1, keepdim=True)
|
| 168 |
+
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
|
| 169 |
+
return hidden_states.to(input_dtype)
|
models/losses.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Tuple, Dict, Sequence, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
IGNORE_LABEL_ID = -100
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def s(x, epsilon=1e-30):
|
| 12 |
+
return torch.where(
|
| 13 |
+
x<0,
|
| 14 |
+
1/(1-x+ epsilon),
|
| 15 |
+
x + 1
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def log_stablemax(x, dim=-1):
|
| 20 |
+
s_x = s(x)
|
| 21 |
+
return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
|
| 25 |
+
logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
|
| 26 |
+
|
| 27 |
+
if valid_mask is None:
|
| 28 |
+
valid_mask = (labels != ignore_index)
|
| 29 |
+
transformed_labels = torch.where(valid_mask, labels, 0)
|
| 30 |
+
prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
|
| 31 |
+
|
| 32 |
+
return -torch.where(valid_mask, prediction_logprobs, 0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
|
| 36 |
+
# Cast logits to f32
|
| 37 |
+
# Flatten logits
|
| 38 |
+
return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ACTLossHead(nn.Module):
|
| 42 |
+
def __init__(self, model: nn.Module, loss_type: str):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.model = model
|
| 45 |
+
self.loss_fn = globals()[loss_type]
|
| 46 |
+
|
| 47 |
+
def initial_carry(self, *args, **kwargs):
|
| 48 |
+
return self.model.initial_carry(*args, **kwargs) # type: ignore
|
| 49 |
+
|
| 50 |
+
def forward(
|
| 51 |
+
self,
|
| 52 |
+
return_keys: Sequence[str],
|
| 53 |
+
# Model args
|
| 54 |
+
**model_kwargs,
|
| 55 |
+
) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
|
| 56 |
+
# Model logits
|
| 57 |
+
# B x SeqLen x D
|
| 58 |
+
new_carry, outputs = self.model(**model_kwargs)
|
| 59 |
+
labels = new_carry.current_data["labels"]
|
| 60 |
+
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
# Preds
|
| 63 |
+
outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
|
| 64 |
+
|
| 65 |
+
# Correctness
|
| 66 |
+
mask = (labels != IGNORE_LABEL_ID)
|
| 67 |
+
loss_counts = mask.sum(-1)
|
| 68 |
+
loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
|
| 69 |
+
|
| 70 |
+
is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
|
| 71 |
+
seq_is_correct = is_correct.sum(-1) == loss_counts
|
| 72 |
+
|
| 73 |
+
# Metrics (halted)
|
| 74 |
+
valid_metrics = new_carry.halted & (loss_counts > 0)
|
| 75 |
+
metrics = {
|
| 76 |
+
"count": valid_metrics.sum(),
|
| 77 |
+
|
| 78 |
+
"accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
|
| 79 |
+
"exact_accuracy": (valid_metrics & seq_is_correct).sum(),
|
| 80 |
+
|
| 81 |
+
"q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
|
| 82 |
+
"steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
# Losses
|
| 86 |
+
|
| 87 |
+
lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
|
| 88 |
+
q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
|
| 89 |
+
metrics.update({
|
| 90 |
+
"lm_loss": lm_loss.detach(),
|
| 91 |
+
"q_halt_loss": q_halt_loss.detach(),
|
| 92 |
+
})
|
| 93 |
+
# Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
|
| 94 |
+
q_continue_loss = 0
|
| 95 |
+
if "target_q_continue" in outputs:
|
| 96 |
+
q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
|
| 97 |
+
|
| 98 |
+
metrics["q_continue_loss"] = q_continue_loss.detach()
|
| 99 |
+
# Filter outputs for return
|
| 100 |
+
detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
|
| 101 |
+
|
| 102 |
+
return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
|
| 103 |
+
|
models/recursive_reasoning/hrm.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import nn
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
from models.common import trunc_normal_init_
|
| 10 |
+
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 11 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class HierarchicalReasoningModel_ACTV1InnerCarry:
|
| 15 |
+
z_H: torch.Tensor
|
| 16 |
+
z_L: torch.Tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class HierarchicalReasoningModel_ACTV1Carry:
|
| 21 |
+
inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
|
| 22 |
+
|
| 23 |
+
steps: torch.Tensor
|
| 24 |
+
halted: torch.Tensor
|
| 25 |
+
|
| 26 |
+
current_data: Dict[str, torch.Tensor]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
|
| 30 |
+
batch_size: int
|
| 31 |
+
seq_len: int
|
| 32 |
+
puzzle_emb_ndim: int = 0
|
| 33 |
+
num_puzzle_identifiers: int
|
| 34 |
+
vocab_size: int
|
| 35 |
+
|
| 36 |
+
H_cycles: int
|
| 37 |
+
L_cycles: int
|
| 38 |
+
|
| 39 |
+
H_layers: int
|
| 40 |
+
L_layers: int
|
| 41 |
+
|
| 42 |
+
# Transformer config
|
| 43 |
+
hidden_size: int
|
| 44 |
+
expansion: float
|
| 45 |
+
num_heads: int
|
| 46 |
+
pos_encodings: str
|
| 47 |
+
|
| 48 |
+
rms_norm_eps: float = 1e-5
|
| 49 |
+
rope_theta: float = 10000.0
|
| 50 |
+
|
| 51 |
+
# Halting Q-learning config
|
| 52 |
+
halt_max_steps: int
|
| 53 |
+
halt_exploration_prob: float
|
| 54 |
+
|
| 55 |
+
forward_dtype: str = "bfloat16"
|
| 56 |
+
|
| 57 |
+
# Alexia: added
|
| 58 |
+
mlp_t: bool=False # use mlp on L instead of transformer
|
| 59 |
+
|
| 60 |
+
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
|
| 61 |
+
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
|
| 62 |
+
super().__init__()
|
| 63 |
+
|
| 64 |
+
self.config = config
|
| 65 |
+
if self.config.mlp_t:
|
| 66 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
|
| 67 |
+
self.mlp_t = SwiGLU(
|
| 68 |
+
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
| 69 |
+
expansion=config.expansion,
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
self.self_attn = Attention(
|
| 73 |
+
hidden_size=config.hidden_size,
|
| 74 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 75 |
+
num_heads=config.num_heads,
|
| 76 |
+
num_key_value_heads=config.num_heads,
|
| 77 |
+
causal=False
|
| 78 |
+
)
|
| 79 |
+
self.mlp = SwiGLU(
|
| 80 |
+
hidden_size=config.hidden_size,
|
| 81 |
+
expansion=config.expansion,
|
| 82 |
+
)
|
| 83 |
+
self.norm_eps = config.rms_norm_eps
|
| 84 |
+
|
| 85 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 86 |
+
# B, L, D = hidden_states.shape
|
| 87 |
+
# Post Norm
|
| 88 |
+
if self.config.mlp_t:
|
| 89 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 90 |
+
out = self.mlp_t(hidden_states)
|
| 91 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 92 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 93 |
+
else:
|
| 94 |
+
# Self Attention
|
| 95 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 96 |
+
# Fully Connected
|
| 97 |
+
out = self.mlp(hidden_states)
|
| 98 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 99 |
+
return hidden_states
|
| 100 |
+
|
| 101 |
+
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 102 |
+
def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
|
| 103 |
+
super().__init__()
|
| 104 |
+
|
| 105 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 106 |
+
|
| 107 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 108 |
+
# Input injection (add)
|
| 109 |
+
hidden_states = hidden_states + input_injection
|
| 110 |
+
# Layers
|
| 111 |
+
for layer in self.layers:
|
| 112 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 113 |
+
|
| 114 |
+
return hidden_states
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
|
| 118 |
+
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.config = config
|
| 121 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 122 |
+
|
| 123 |
+
# I/O
|
| 124 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 125 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 126 |
+
|
| 127 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 128 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 129 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 130 |
+
|
| 131 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
|
| 132 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 133 |
+
# Zero init puzzle embeddings
|
| 134 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 135 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 136 |
+
|
| 137 |
+
# LM Blocks
|
| 138 |
+
if self.config.pos_encodings == "rope":
|
| 139 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 140 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 141 |
+
base=self.config.rope_theta)
|
| 142 |
+
elif self.config.pos_encodings == "learned":
|
| 143 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 144 |
+
else:
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
# Reasoning Layers
|
| 148 |
+
self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
|
| 149 |
+
self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 150 |
+
|
| 151 |
+
# Initial states
|
| 152 |
+
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 153 |
+
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 154 |
+
|
| 155 |
+
# Q head special init
|
| 156 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
self.q_head.weight.zero_()
|
| 159 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 160 |
+
|
| 161 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 162 |
+
# Token embedding
|
| 163 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 164 |
+
|
| 165 |
+
# Puzzle embeddings
|
| 166 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 167 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 168 |
+
|
| 169 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 170 |
+
if pad_count > 0:
|
| 171 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 172 |
+
|
| 173 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 174 |
+
|
| 175 |
+
# Position embeddings
|
| 176 |
+
if self.config.pos_encodings == "learned":
|
| 177 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 178 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 179 |
+
|
| 180 |
+
# Scale
|
| 181 |
+
return self.embed_scale * embedding
|
| 182 |
+
|
| 183 |
+
def empty_carry(self, batch_size: int):
|
| 184 |
+
return HierarchicalReasoningModel_ACTV1InnerCarry(
|
| 185 |
+
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 186 |
+
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
|
| 190 |
+
return HierarchicalReasoningModel_ACTV1InnerCarry(
|
| 191 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 192 |
+
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 196 |
+
seq_info = dict(
|
| 197 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Input encoding
|
| 201 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 202 |
+
|
| 203 |
+
# Forward iterations
|
| 204 |
+
with torch.no_grad():
|
| 205 |
+
z_H, z_L = carry.z_H, carry.z_L
|
| 206 |
+
for _H_step in range(self.config.H_cycles):
|
| 207 |
+
for _L_step in range(self.config.L_cycles):
|
| 208 |
+
if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
|
| 209 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 210 |
+
if not (_H_step == self.config.H_cycles - 1):
|
| 211 |
+
z_H = self.H_level(z_H, z_L, **seq_info)
|
| 212 |
+
assert not z_H.requires_grad and not z_L.requires_grad
|
| 213 |
+
# 1-step grad
|
| 214 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 215 |
+
z_H = self.H_level(z_H, z_L, **seq_info)
|
| 216 |
+
|
| 217 |
+
# LM Outputs
|
| 218 |
+
new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
| 219 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
| 220 |
+
|
| 221 |
+
# Q head
|
| 222 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
|
| 223 |
+
|
| 224 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class HierarchicalReasoningModel_ACTV1(nn.Module):
|
| 228 |
+
"""ACT wrapper."""
|
| 229 |
+
|
| 230 |
+
def __init__(self, config_dict: dict):
|
| 231 |
+
super().__init__()
|
| 232 |
+
self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
|
| 233 |
+
self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
|
| 234 |
+
|
| 235 |
+
@property
|
| 236 |
+
def puzzle_emb(self):
|
| 237 |
+
return self.inner.puzzle_emb
|
| 238 |
+
|
| 239 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 240 |
+
batch_size = batch["inputs"].shape[0]
|
| 241 |
+
|
| 242 |
+
return HierarchicalReasoningModel_ACTV1Carry(
|
| 243 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
|
| 244 |
+
|
| 245 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 246 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 247 |
+
|
| 248 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 252 |
+
# Update data, carry (removing halted sequences)
|
| 253 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 254 |
+
|
| 255 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 256 |
+
|
| 257 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 258 |
+
|
| 259 |
+
# Forward inner model
|
| 260 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 261 |
+
|
| 262 |
+
outputs = {
|
| 263 |
+
"logits": logits,
|
| 264 |
+
"q_halt_logits": q_halt_logits,
|
| 265 |
+
"q_continue_logits": q_continue_logits
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
# Step
|
| 270 |
+
new_steps = new_steps + 1
|
| 271 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 272 |
+
|
| 273 |
+
halted = is_last_step
|
| 274 |
+
|
| 275 |
+
# if training, and ACT is enabled
|
| 276 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 277 |
+
# Halt signal
|
| 278 |
+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
| 279 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 280 |
+
|
| 281 |
+
# Exploration
|
| 282 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 283 |
+
|
| 284 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 285 |
+
|
| 286 |
+
# Compute target Q
|
| 287 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 288 |
+
# As batch_size is large, there're many parallel envs.
|
| 289 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 290 |
+
next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
|
| 291 |
+
|
| 292 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 293 |
+
|
| 294 |
+
return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
models/recursive_reasoning/transformers_baseline.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HRM ACT V2: Transformer Baseline for Architecture Ablation
|
| 3 |
+
|
| 4 |
+
This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
|
| 5 |
+
Key changes from V1:
|
| 6 |
+
1. REMOVED hierarchical split (no separate H and L levels)
|
| 7 |
+
2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
|
| 8 |
+
3. KEPT ACT outer loop structure intact
|
| 9 |
+
4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
|
| 10 |
+
|
| 11 |
+
Architecture: Single-level transformer that processes the full 30x30 grid as a
|
| 12 |
+
900-token sequence, with the same positional encodings and sparse embeddings as V1.
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from typing import Tuple, List, Dict, Optional
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from torch import nn
|
| 23 |
+
from pydantic import BaseModel
|
| 24 |
+
|
| 25 |
+
from models.common import trunc_normal_init_
|
| 26 |
+
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 27 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class Model_ACTV2InnerCarry:
|
| 32 |
+
z_H: torch.Tensor
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class Model_ACTV2Carry:
|
| 37 |
+
inner_carry: Model_ACTV2InnerCarry
|
| 38 |
+
|
| 39 |
+
steps: torch.Tensor
|
| 40 |
+
halted: torch.Tensor
|
| 41 |
+
|
| 42 |
+
current_data: Dict[str, torch.Tensor]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Model_ACTV2Config(BaseModel):
|
| 46 |
+
batch_size: int
|
| 47 |
+
seq_len: int
|
| 48 |
+
puzzle_emb_ndim: int = 0
|
| 49 |
+
num_puzzle_identifiers: int
|
| 50 |
+
vocab_size: int
|
| 51 |
+
|
| 52 |
+
H_cycles: int
|
| 53 |
+
|
| 54 |
+
H_layers: int
|
| 55 |
+
|
| 56 |
+
# Transformer config
|
| 57 |
+
hidden_size: int
|
| 58 |
+
expansion: float
|
| 59 |
+
num_heads: int
|
| 60 |
+
pos_encodings: str
|
| 61 |
+
|
| 62 |
+
rms_norm_eps: float = 1e-5
|
| 63 |
+
rope_theta: float = 10000.0
|
| 64 |
+
|
| 65 |
+
# Halting Q-learning config
|
| 66 |
+
halt_max_steps: int
|
| 67 |
+
halt_exploration_prob: float
|
| 68 |
+
act_enabled: bool = True # If False, always run halt_max_steps (no early stopping during training)
|
| 69 |
+
act_inference: bool = False # If True, use adaptive computation during inference
|
| 70 |
+
|
| 71 |
+
forward_dtype: str = "bfloat16"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Model_ACTV2Block(nn.Module):
|
| 75 |
+
def __init__(self, config: Model_ACTV2Config) -> None:
|
| 76 |
+
super().__init__()
|
| 77 |
+
|
| 78 |
+
self.self_attn = Attention(
|
| 79 |
+
hidden_size=config.hidden_size,
|
| 80 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 81 |
+
num_heads=config.num_heads,
|
| 82 |
+
num_key_value_heads=config.num_heads,
|
| 83 |
+
causal=False,
|
| 84 |
+
)
|
| 85 |
+
self.mlp = SwiGLU(
|
| 86 |
+
hidden_size=config.hidden_size,
|
| 87 |
+
expansion=config.expansion,
|
| 88 |
+
)
|
| 89 |
+
self.norm_eps = config.rms_norm_eps
|
| 90 |
+
|
| 91 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 92 |
+
# Post Norm
|
| 93 |
+
# Self Attention
|
| 94 |
+
hidden_states = rms_norm(
|
| 95 |
+
hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
|
| 96 |
+
variance_epsilon=self.norm_eps,
|
| 97 |
+
)
|
| 98 |
+
# Fully Connected
|
| 99 |
+
hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
|
| 100 |
+
return hidden_states
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Model_ACTV2ReasoningModule(nn.Module):
|
| 104 |
+
def __init__(self, layers: List[Model_ACTV2Block]):
|
| 105 |
+
super().__init__()
|
| 106 |
+
|
| 107 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 108 |
+
|
| 109 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 110 |
+
# Input injection (add)
|
| 111 |
+
hidden_states = hidden_states + input_injection
|
| 112 |
+
# Layers
|
| 113 |
+
for layer in self.layers:
|
| 114 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 115 |
+
|
| 116 |
+
return hidden_states
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class Model_ACTV2_Inner(nn.Module):
|
| 120 |
+
def __init__(self, config: Model_ACTV2Config) -> None:
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.config = config
|
| 123 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 124 |
+
|
| 125 |
+
# I/O
|
| 126 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 127 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 128 |
+
|
| 129 |
+
self.embed_tokens = CastedEmbedding(
|
| 130 |
+
self.config.vocab_size,
|
| 131 |
+
self.config.hidden_size,
|
| 132 |
+
init_std=embed_init_std,
|
| 133 |
+
cast_to=self.forward_dtype,
|
| 134 |
+
)
|
| 135 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 136 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 137 |
+
|
| 138 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
|
| 139 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 140 |
+
# Zero init puzzle embeddings
|
| 141 |
+
self.puzzle_emb = CastedSparseEmbedding(
|
| 142 |
+
self.config.num_puzzle_identifiers,
|
| 143 |
+
self.config.puzzle_emb_ndim,
|
| 144 |
+
batch_size=self.config.batch_size,
|
| 145 |
+
init_std=0,
|
| 146 |
+
cast_to=self.forward_dtype,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# LM Blocks
|
| 150 |
+
if self.config.pos_encodings == "rope":
|
| 151 |
+
self.rotary_emb = RotaryEmbedding(
|
| 152 |
+
dim=self.config.hidden_size // self.config.num_heads,
|
| 153 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 154 |
+
base=self.config.rope_theta,
|
| 155 |
+
)
|
| 156 |
+
elif self.config.pos_encodings == "learned":
|
| 157 |
+
self.embed_pos = CastedEmbedding(
|
| 158 |
+
self.config.seq_len + self.puzzle_emb_len,
|
| 159 |
+
self.config.hidden_size,
|
| 160 |
+
init_std=embed_init_std,
|
| 161 |
+
cast_to=self.forward_dtype,
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
raise NotImplementedError()
|
| 165 |
+
|
| 166 |
+
# Reasoning Layers
|
| 167 |
+
self.H_level = Model_ACTV2ReasoningModule(
|
| 168 |
+
layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Initial states
|
| 172 |
+
self.H_init = nn.Buffer(
|
| 173 |
+
trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
|
| 174 |
+
persistent=True,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Q head special init
|
| 178 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 179 |
+
with torch.no_grad():
|
| 180 |
+
self.q_head.weight.zero_()
|
| 181 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 182 |
+
|
| 183 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 184 |
+
# Token embedding
|
| 185 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 186 |
+
|
| 187 |
+
# Puzzle embeddings
|
| 188 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 189 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 190 |
+
|
| 191 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 192 |
+
if pad_count > 0:
|
| 193 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 194 |
+
|
| 195 |
+
embedding = torch.cat(
|
| 196 |
+
(puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Position embeddings
|
| 200 |
+
if self.config.pos_encodings == "learned":
|
| 201 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 202 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 203 |
+
|
| 204 |
+
# Scale
|
| 205 |
+
return self.embed_scale * embedding
|
| 206 |
+
|
| 207 |
+
def empty_carry(self, batch_size: int):
|
| 208 |
+
return Model_ACTV2InnerCarry(
|
| 209 |
+
z_H=torch.empty(
|
| 210 |
+
batch_size,
|
| 211 |
+
self.config.seq_len + self.puzzle_emb_len,
|
| 212 |
+
self.config.hidden_size,
|
| 213 |
+
dtype=self.forward_dtype,
|
| 214 |
+
),
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
|
| 218 |
+
return Model_ACTV2InnerCarry(
|
| 219 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def forward(
|
| 223 |
+
self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
|
| 224 |
+
) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 225 |
+
seq_info = dict(
|
| 226 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Input encoding
|
| 230 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 231 |
+
|
| 232 |
+
# 1-step grad
|
| 233 |
+
z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)
|
| 234 |
+
|
| 235 |
+
# LM Outputs
|
| 236 |
+
new_carry = Model_ACTV2InnerCarry(
|
| 237 |
+
z_H=z_H.detach(),
|
| 238 |
+
) # New carry no grad
|
| 239 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len :]
|
| 240 |
+
|
| 241 |
+
# Q head
|
| 242 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
|
| 243 |
+
|
| 244 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class Model_ACTV2(nn.Module):
|
| 248 |
+
"""ACT wrapper."""
|
| 249 |
+
|
| 250 |
+
def __init__(self, config_dict: dict):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.config = Model_ACTV2Config(**config_dict)
|
| 253 |
+
self.inner = Model_ACTV2_Inner(self.config)
|
| 254 |
+
|
| 255 |
+
@property
|
| 256 |
+
def puzzle_emb(self):
|
| 257 |
+
return self.inner.puzzle_emb
|
| 258 |
+
|
| 259 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 260 |
+
batch_size = batch["inputs"].shape[0]
|
| 261 |
+
|
| 262 |
+
return Model_ACTV2Carry(
|
| 263 |
+
inner_carry=self.inner.empty_carry(
|
| 264 |
+
batch_size
|
| 265 |
+
), # Empty is expected, it will be reseted in first pass as all sequences are halted.
|
| 266 |
+
steps=torch.zeros((batch_size,), dtype=torch.int32),
|
| 267 |
+
halted=torch.ones((batch_size,), dtype=torch.bool), # Default to halted
|
| 268 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()},
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
def forward(
|
| 272 |
+
self,
|
| 273 |
+
carry: Model_ACTV2Carry,
|
| 274 |
+
batch: Dict[str, torch.Tensor],
|
| 275 |
+
compute_target_q: bool = False,
|
| 276 |
+
) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
|
| 277 |
+
# Update data, carry (removing halted sequences)
|
| 278 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 279 |
+
|
| 280 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 281 |
+
|
| 282 |
+
new_current_data = {
|
| 283 |
+
k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
|
| 284 |
+
for k, v in carry.current_data.items()
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
# Forward inner model
|
| 288 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
|
| 289 |
+
new_inner_carry, new_current_data
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}
|
| 293 |
+
|
| 294 |
+
with torch.no_grad():
|
| 295 |
+
# Step
|
| 296 |
+
new_steps = new_steps + 1
|
| 297 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 298 |
+
|
| 299 |
+
halted = is_last_step
|
| 300 |
+
|
| 301 |
+
# Check if adaptive computation should be used
|
| 302 |
+
use_adaptive = (self.config.halt_max_steps > 1) and (
|
| 303 |
+
(self.training and self.config.act_enabled)
|
| 304 |
+
or (not self.training and self.config.act_inference)
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
if use_adaptive:
|
| 308 |
+
# Halt signal based on Q-values (but always halt at max steps)
|
| 309 |
+
q_halt_signal = q_halt_logits > q_continue_logits
|
| 310 |
+
halted = halted | q_halt_signal
|
| 311 |
+
|
| 312 |
+
# Store actual steps used for logging (only during inference)
|
| 313 |
+
if not self.training:
|
| 314 |
+
outputs["actual_steps"] = new_steps.float()
|
| 315 |
+
|
| 316 |
+
# Exploration (only during training)
|
| 317 |
+
if self.training:
|
| 318 |
+
min_halt_steps = (
|
| 319 |
+
torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
|
| 320 |
+
) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 321 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 322 |
+
|
| 323 |
+
# Compute target Q (only during training)
|
| 324 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 325 |
+
# As batch_size is large, there're many parallel envs.
|
| 326 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 327 |
+
if self.training and compute_target_q:
|
| 328 |
+
next_q_halt_logits, next_q_continue_logits = self.inner(
|
| 329 |
+
new_inner_carry, new_current_data
|
| 330 |
+
)[-1]
|
| 331 |
+
|
| 332 |
+
outputs["target_q_continue"] = torch.sigmoid(
|
| 333 |
+
torch.where(
|
| 334 |
+
is_last_step,
|
| 335 |
+
next_q_halt_logits,
|
| 336 |
+
torch.maximum(next_q_halt_logits, next_q_continue_logits),
|
| 337 |
+
)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
return Model_ACTV2Carry(
|
| 341 |
+
new_inner_carry, new_steps, halted, new_current_data
|
| 342 |
+
), outputs
|
models/recursive_reasoning/trm.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import copy
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
import random
|
| 10 |
+
from models.common import trunc_normal_init_
|
| 11 |
+
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 13 |
+
|
| 14 |
+
IGNORE_LABEL_ID = -100
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
| 18 |
+
z_H: torch.Tensor
|
| 19 |
+
z_L: torch.Tensor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TinyRecursiveReasoningModel_ACTV1Carry:
|
| 24 |
+
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
| 25 |
+
|
| 26 |
+
steps: torch.Tensor
|
| 27 |
+
halted: torch.Tensor
|
| 28 |
+
|
| 29 |
+
current_data: Dict[str, torch.Tensor]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
| 33 |
+
batch_size: int
|
| 34 |
+
seq_len: int
|
| 35 |
+
puzzle_emb_ndim: int = 0
|
| 36 |
+
num_puzzle_identifiers: int
|
| 37 |
+
vocab_size: int
|
| 38 |
+
|
| 39 |
+
H_cycles: int
|
| 40 |
+
L_cycles: int
|
| 41 |
+
|
| 42 |
+
H_layers: int # ignored
|
| 43 |
+
L_layers: int
|
| 44 |
+
|
| 45 |
+
# Transformer config
|
| 46 |
+
hidden_size: int
|
| 47 |
+
expansion: float
|
| 48 |
+
num_heads: int
|
| 49 |
+
pos_encodings: str
|
| 50 |
+
|
| 51 |
+
rms_norm_eps: float = 1e-5
|
| 52 |
+
rope_theta: float = 10000.0
|
| 53 |
+
|
| 54 |
+
# Halting Q-learning config
|
| 55 |
+
halt_max_steps: int
|
| 56 |
+
halt_exploration_prob: float
|
| 57 |
+
|
| 58 |
+
forward_dtype: str = "bfloat16"
|
| 59 |
+
|
| 60 |
+
# Alexia: added
|
| 61 |
+
mlp_t: bool = False # use mlp on L instead of transformer
|
| 62 |
+
puzzle_emb_len: int = 16 # if non-zero, its specified to this value
|
| 63 |
+
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
| 64 |
+
|
| 65 |
+
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
| 66 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
if self.config.mlp_t:
|
| 71 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
| 72 |
+
self.mlp_t = SwiGLU(
|
| 73 |
+
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
| 74 |
+
expansion=config.expansion,
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
self.self_attn = Attention(
|
| 78 |
+
hidden_size=config.hidden_size,
|
| 79 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 80 |
+
num_heads=config.num_heads,
|
| 81 |
+
num_key_value_heads=config.num_heads,
|
| 82 |
+
causal=False
|
| 83 |
+
)
|
| 84 |
+
self.mlp = SwiGLU(
|
| 85 |
+
hidden_size=config.hidden_size,
|
| 86 |
+
expansion=config.expansion,
|
| 87 |
+
)
|
| 88 |
+
self.norm_eps = config.rms_norm_eps
|
| 89 |
+
|
| 90 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
# B, L, D = hidden_states.shape
|
| 92 |
+
# Post Norm
|
| 93 |
+
if self.config.mlp_t:
|
| 94 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 95 |
+
out = self.mlp_t(hidden_states)
|
| 96 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 97 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 98 |
+
else:
|
| 99 |
+
# Self Attention
|
| 100 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 101 |
+
# Fully Connected
|
| 102 |
+
out = self.mlp(hidden_states)
|
| 103 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 107 |
+
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 110 |
+
|
| 111 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 112 |
+
hidden_states = hidden_states + input_injection
|
| 113 |
+
for layer in self.layers:
|
| 114 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 115 |
+
return hidden_states
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
| 119 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.config = config
|
| 122 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 123 |
+
|
| 124 |
+
# I/O
|
| 125 |
+
|
| 126 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 127 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 128 |
+
|
| 129 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 130 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 131 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 132 |
+
|
| 133 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
| 134 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 135 |
+
# Zero init puzzle embeddings
|
| 136 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 137 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 138 |
+
|
| 139 |
+
# LM Blocks
|
| 140 |
+
if self.config.pos_encodings == "rope":
|
| 141 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 142 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 143 |
+
base=self.config.rope_theta)
|
| 144 |
+
elif self.config.pos_encodings == "learned":
|
| 145 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 146 |
+
else:
|
| 147 |
+
pass
|
| 148 |
+
|
| 149 |
+
# Reasoning Layers
|
| 150 |
+
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 151 |
+
|
| 152 |
+
# Initial states
|
| 153 |
+
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 154 |
+
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 155 |
+
|
| 156 |
+
# Q head special init
|
| 157 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
self.q_head.weight.zero_()
|
| 160 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 161 |
+
|
| 162 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 163 |
+
# Token embedding
|
| 164 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 165 |
+
|
| 166 |
+
# Puzzle embeddings
|
| 167 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 168 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 169 |
+
|
| 170 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 171 |
+
if pad_count > 0:
|
| 172 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 173 |
+
|
| 174 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 175 |
+
|
| 176 |
+
# Position embeddings
|
| 177 |
+
if self.config.pos_encodings == "learned":
|
| 178 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 179 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 180 |
+
|
| 181 |
+
# Scale
|
| 182 |
+
return self.embed_scale * embedding
|
| 183 |
+
|
| 184 |
+
def empty_carry(self, batch_size: int):
|
| 185 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 186 |
+
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 187 |
+
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
| 191 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 192 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 193 |
+
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 197 |
+
seq_info = dict(
|
| 198 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# Input encoding
|
| 202 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 203 |
+
|
| 204 |
+
# Forward iterations
|
| 205 |
+
it = 0
|
| 206 |
+
z_H, z_L = carry.z_H, carry.z_L
|
| 207 |
+
# H_cycles-1 without grad
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
for _H_step in range(self.config.H_cycles-1):
|
| 210 |
+
for _L_step in range(self.config.L_cycles):
|
| 211 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 212 |
+
z_H = self.L_level(z_H, z_L, **seq_info)
|
| 213 |
+
# 1 with grad
|
| 214 |
+
for _L_step in range(self.config.L_cycles):
|
| 215 |
+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
| 216 |
+
z_H = self.L_level(z_H, z_L, **seq_info)
|
| 217 |
+
|
| 218 |
+
# LM Outputs
|
| 219 |
+
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
| 220 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
| 221 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
| 222 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
| 226 |
+
"""ACT wrapper."""
|
| 227 |
+
|
| 228 |
+
def __init__(self, config_dict: dict):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
| 231 |
+
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
| 232 |
+
|
| 233 |
+
@property
|
| 234 |
+
def puzzle_emb(self):
|
| 235 |
+
return self.inner.puzzle_emb
|
| 236 |
+
|
| 237 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 238 |
+
batch_size = batch["inputs"].shape[0]
|
| 239 |
+
|
| 240 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(
|
| 241 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
|
| 242 |
+
|
| 243 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 244 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 245 |
+
|
| 246 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 250 |
+
|
| 251 |
+
# Update data, carry (removing halted sequences)
|
| 252 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 253 |
+
|
| 254 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 255 |
+
|
| 256 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 257 |
+
|
| 258 |
+
# Forward inner model
|
| 259 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 260 |
+
|
| 261 |
+
outputs = {
|
| 262 |
+
"logits": logits,
|
| 263 |
+
"q_halt_logits": q_halt_logits,
|
| 264 |
+
"q_continue_logits": q_continue_logits
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
with torch.no_grad():
|
| 268 |
+
# Step
|
| 269 |
+
new_steps = new_steps + 1
|
| 270 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 271 |
+
|
| 272 |
+
halted = is_last_step
|
| 273 |
+
|
| 274 |
+
# if training, and ACT is enabled
|
| 275 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 276 |
+
|
| 277 |
+
# Halt signal
|
| 278 |
+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
| 279 |
+
|
| 280 |
+
if self.config.no_ACT_continue:
|
| 281 |
+
halted = halted | (q_halt_logits > 0)
|
| 282 |
+
else:
|
| 283 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 284 |
+
|
| 285 |
+
# Exploration
|
| 286 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 287 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 288 |
+
|
| 289 |
+
if not self.config.no_ACT_continue:
|
| 290 |
+
# Compute target Q
|
| 291 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 292 |
+
# As batch_size is large, there're many parallel envs.
|
| 293 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 294 |
+
_, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data)
|
| 295 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 296 |
+
|
| 297 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
models/recursive_reasoning/trm_hier6.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import copy
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
import random
|
| 10 |
+
from models.common import trunc_normal_init_
|
| 11 |
+
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 13 |
+
|
| 14 |
+
IGNORE_LABEL_ID = -100
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
| 18 |
+
z_H: torch.Tensor
|
| 19 |
+
z_L1: torch.Tensor
|
| 20 |
+
z_L2: torch.Tensor
|
| 21 |
+
z_L3: torch.Tensor
|
| 22 |
+
z_L4: torch.Tensor
|
| 23 |
+
z_L5: torch.Tensor
|
| 24 |
+
z_L6: torch.Tensor
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class TinyRecursiveReasoningModel_ACTV1Carry:
|
| 30 |
+
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
| 31 |
+
|
| 32 |
+
steps: torch.Tensor
|
| 33 |
+
halted: torch.Tensor
|
| 34 |
+
|
| 35 |
+
current_data: Dict[str, torch.Tensor]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
| 39 |
+
batch_size: int
|
| 40 |
+
seq_len: int
|
| 41 |
+
puzzle_emb_ndim: int = 0
|
| 42 |
+
num_puzzle_identifiers: int
|
| 43 |
+
vocab_size: int
|
| 44 |
+
|
| 45 |
+
H_cycles: int
|
| 46 |
+
L_cycles: int
|
| 47 |
+
|
| 48 |
+
H_layers: int # ignored
|
| 49 |
+
L_layers: int
|
| 50 |
+
|
| 51 |
+
# Transformer config
|
| 52 |
+
hidden_size: int
|
| 53 |
+
expansion: float
|
| 54 |
+
num_heads: int
|
| 55 |
+
pos_encodings: str
|
| 56 |
+
|
| 57 |
+
rms_norm_eps: float = 1e-5
|
| 58 |
+
rope_theta: float = 10000.0
|
| 59 |
+
|
| 60 |
+
# Halting Q-learning config
|
| 61 |
+
halt_max_steps: int
|
| 62 |
+
halt_exploration_prob: float
|
| 63 |
+
|
| 64 |
+
forward_dtype: str = "bfloat16"
|
| 65 |
+
|
| 66 |
+
# Alexia: added
|
| 67 |
+
mlp_t: bool = False # use mlp on L instead of transformer
|
| 68 |
+
puzzle_emb_len: int = 16 # if non-zero, its specified to this value
|
| 69 |
+
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
| 70 |
+
|
| 71 |
+
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
| 72 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 73 |
+
super().__init__()
|
| 74 |
+
|
| 75 |
+
self.config = config
|
| 76 |
+
if self.config.mlp_t:
|
| 77 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
| 78 |
+
self.mlp_t = SwiGLU(
|
| 79 |
+
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
| 80 |
+
expansion=config.expansion,
|
| 81 |
+
)
|
| 82 |
+
else:
|
| 83 |
+
self.self_attn = Attention(
|
| 84 |
+
hidden_size=config.hidden_size,
|
| 85 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 86 |
+
num_heads=config.num_heads,
|
| 87 |
+
num_key_value_heads=config.num_heads,
|
| 88 |
+
causal=False
|
| 89 |
+
)
|
| 90 |
+
self.mlp = SwiGLU(
|
| 91 |
+
hidden_size=config.hidden_size,
|
| 92 |
+
expansion=config.expansion,
|
| 93 |
+
)
|
| 94 |
+
self.norm_eps = config.rms_norm_eps
|
| 95 |
+
|
| 96 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 97 |
+
# B, L, D = hidden_states.shape
|
| 98 |
+
# Post Norm
|
| 99 |
+
if self.config.mlp_t:
|
| 100 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 101 |
+
out = self.mlp_t(hidden_states)
|
| 102 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 103 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 104 |
+
else:
|
| 105 |
+
# Self Attention
|
| 106 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 107 |
+
# Fully Connected
|
| 108 |
+
out = self.mlp(hidden_states)
|
| 109 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 110 |
+
return hidden_states
|
| 111 |
+
|
| 112 |
+
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 113 |
+
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 116 |
+
|
| 117 |
+
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 118 |
+
hidden_states = hidden_states + input_injection
|
| 119 |
+
for layer in self.layers:
|
| 120 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 121 |
+
return hidden_states
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
| 125 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 126 |
+
super().__init__()
|
| 127 |
+
self.config = config
|
| 128 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 129 |
+
|
| 130 |
+
# I/O
|
| 131 |
+
|
| 132 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 133 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 134 |
+
|
| 135 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 136 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 137 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 138 |
+
|
| 139 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
| 140 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 141 |
+
# Zero init puzzle embeddings
|
| 142 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 143 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 144 |
+
|
| 145 |
+
# LM Blocks
|
| 146 |
+
if self.config.pos_encodings == "rope":
|
| 147 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 148 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 149 |
+
base=self.config.rope_theta)
|
| 150 |
+
elif self.config.pos_encodings == "learned":
|
| 151 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 152 |
+
else:
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
# Reasoning Layers
|
| 156 |
+
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 157 |
+
|
| 158 |
+
# Initial states
|
| 159 |
+
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 160 |
+
self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 161 |
+
self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 162 |
+
self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 163 |
+
self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 164 |
+
self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 165 |
+
self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 166 |
+
|
| 167 |
+
# Q head special init
|
| 168 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
self.q_head.weight.zero_()
|
| 171 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 172 |
+
|
| 173 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 174 |
+
# Token embedding
|
| 175 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 176 |
+
|
| 177 |
+
# Puzzle embeddings
|
| 178 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 179 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 180 |
+
|
| 181 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 182 |
+
if pad_count > 0:
|
| 183 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 184 |
+
|
| 185 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 186 |
+
|
| 187 |
+
# Position embeddings
|
| 188 |
+
if self.config.pos_encodings == "learned":
|
| 189 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 190 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 191 |
+
|
| 192 |
+
# Scale
|
| 193 |
+
return self.embed_scale * embedding
|
| 194 |
+
|
| 195 |
+
def empty_carry(self, batch_size: int):
|
| 196 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 197 |
+
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 198 |
+
z_L1=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 199 |
+
z_L2=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 200 |
+
z_L3=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 201 |
+
z_L4=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 202 |
+
z_L5=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 203 |
+
z_L6=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
| 207 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 208 |
+
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
| 209 |
+
z_L1=torch.where(reset_flag.view(-1, 1, 1), self.L1_init, carry.z_L1),
|
| 210 |
+
z_L2=torch.where(reset_flag.view(-1, 1, 1), self.L2_init, carry.z_L2),
|
| 211 |
+
z_L3=torch.where(reset_flag.view(-1, 1, 1), self.L3_init, carry.z_L3),
|
| 212 |
+
z_L4=torch.where(reset_flag.view(-1, 1, 1), self.L4_init, carry.z_L4),
|
| 213 |
+
z_L5=torch.where(reset_flag.view(-1, 1, 1), self.L5_init, carry.z_L5),
|
| 214 |
+
z_L6=torch.where(reset_flag.view(-1, 1, 1), self.L6_init, carry.z_L6),
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 219 |
+
seq_info = dict(
|
| 220 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Input encoding
|
| 224 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 225 |
+
|
| 226 |
+
# Forward iterations
|
| 227 |
+
it = 0
|
| 228 |
+
z_H, z_L = carry.z_H, [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6]
|
| 229 |
+
# H_cycles-1 without grad
|
| 230 |
+
with torch.no_grad():
|
| 231 |
+
for _H_step in range(self.config.H_cycles-1):
|
| 232 |
+
for _L_step in range(self.config.L_cycles):
|
| 233 |
+
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
| 234 |
+
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
|
| 235 |
+
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
| 236 |
+
z_H = self.L_level(z_H, z_L_, **seq_info)
|
| 237 |
+
# 1 with grad
|
| 238 |
+
for _L_step in range(self.config.L_cycles):
|
| 239 |
+
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
| 240 |
+
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
|
| 241 |
+
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
| 242 |
+
z_H = self.L_level(z_H, z_L_, **seq_info)
|
| 243 |
+
|
| 244 |
+
# LM Outputs
|
| 245 |
+
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach()) # New carry no grad
|
| 246 |
+
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
| 247 |
+
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
| 248 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
| 252 |
+
"""ACT wrapper."""
|
| 253 |
+
|
| 254 |
+
def __init__(self, config_dict: dict):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
| 257 |
+
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
| 258 |
+
|
| 259 |
+
@property
|
| 260 |
+
def puzzle_emb(self):
|
| 261 |
+
return self.inner.puzzle_emb
|
| 262 |
+
|
| 263 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 264 |
+
batch_size = batch["inputs"].shape[0]
|
| 265 |
+
|
| 266 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(
|
| 267 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
|
| 268 |
+
|
| 269 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 270 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 271 |
+
|
| 272 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 276 |
+
|
| 277 |
+
# Update data, carry (removing halted sequences)
|
| 278 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 279 |
+
|
| 280 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 281 |
+
|
| 282 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 283 |
+
|
| 284 |
+
# Forward inner model
|
| 285 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 286 |
+
|
| 287 |
+
outputs = {
|
| 288 |
+
"logits": logits,
|
| 289 |
+
"q_halt_logits": q_halt_logits,
|
| 290 |
+
"q_continue_logits": q_continue_logits
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
with torch.no_grad():
|
| 294 |
+
# Step
|
| 295 |
+
new_steps = new_steps + 1
|
| 296 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 297 |
+
|
| 298 |
+
halted = is_last_step
|
| 299 |
+
|
| 300 |
+
# if training, and ACT is enabled
|
| 301 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 302 |
+
|
| 303 |
+
# Halt signal
|
| 304 |
+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
| 305 |
+
|
| 306 |
+
if self.config.no_ACT_continue:
|
| 307 |
+
halted = halted | (q_halt_logits > 0)
|
| 308 |
+
else:
|
| 309 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 310 |
+
|
| 311 |
+
# Exploration
|
| 312 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 313 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 314 |
+
|
| 315 |
+
if not self.config.no_ACT_continue:
|
| 316 |
+
# Compute target Q
|
| 317 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 318 |
+
# As batch_size is large, there're many parallel envs.
|
| 319 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 320 |
+
_, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data)
|
| 321 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 322 |
+
|
| 323 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
models/recursive_reasoning/trm_singlez.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Dict, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import copy
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
import random
|
| 10 |
+
from models.common import trunc_normal_init_
|
| 11 |
+
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
| 12 |
+
from models.sparse_embedding import CastedSparseEmbedding
|
| 13 |
+
|
| 14 |
+
IGNORE_LABEL_ID = -100
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
| 18 |
+
z_L: torch.Tensor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class TinyRecursiveReasoningModel_ACTV1Carry:
|
| 24 |
+
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
| 25 |
+
|
| 26 |
+
steps: torch.Tensor
|
| 27 |
+
halted: torch.Tensor
|
| 28 |
+
|
| 29 |
+
current_data: Dict[str, torch.Tensor]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
| 33 |
+
batch_size: int
|
| 34 |
+
seq_len: int
|
| 35 |
+
puzzle_emb_ndim: int = 0
|
| 36 |
+
num_puzzle_identifiers: int
|
| 37 |
+
vocab_size: int
|
| 38 |
+
|
| 39 |
+
H_cycles: int
|
| 40 |
+
L_cycles: int
|
| 41 |
+
|
| 42 |
+
H_layers: int # ignored
|
| 43 |
+
L_layers: int
|
| 44 |
+
|
| 45 |
+
# Transformer config
|
| 46 |
+
hidden_size: int
|
| 47 |
+
expansion: float
|
| 48 |
+
num_heads: int
|
| 49 |
+
pos_encodings: str
|
| 50 |
+
|
| 51 |
+
rms_norm_eps: float = 1e-5
|
| 52 |
+
rope_theta: float = 10000.0
|
| 53 |
+
|
| 54 |
+
# Halting Q-learning config
|
| 55 |
+
halt_max_steps: int
|
| 56 |
+
halt_exploration_prob: float
|
| 57 |
+
|
| 58 |
+
forward_dtype: str = "bfloat16"
|
| 59 |
+
|
| 60 |
+
# Alexia: added
|
| 61 |
+
mlp_t: bool = False # use mlp on L instead of transformer
|
| 62 |
+
puzzle_emb_len: int = 16 # if non-zero, its specified to this value
|
| 63 |
+
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
| 64 |
+
|
| 65 |
+
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
| 66 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.config = config
|
| 70 |
+
if self.config.mlp_t:
|
| 71 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
| 72 |
+
self.mlp_t = SwiGLU(
|
| 73 |
+
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
| 74 |
+
expansion=config.expansion,
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
self.self_attn = Attention(
|
| 78 |
+
hidden_size=config.hidden_size,
|
| 79 |
+
head_dim=config.hidden_size // config.num_heads,
|
| 80 |
+
num_heads=config.num_heads,
|
| 81 |
+
num_key_value_heads=config.num_heads,
|
| 82 |
+
causal=False
|
| 83 |
+
)
|
| 84 |
+
self.mlp = SwiGLU(
|
| 85 |
+
hidden_size=config.hidden_size,
|
| 86 |
+
expansion=config.expansion,
|
| 87 |
+
)
|
| 88 |
+
self.norm_eps = config.rms_norm_eps
|
| 89 |
+
|
| 90 |
+
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
# B, L, D = hidden_states.shape
|
| 92 |
+
# Post Norm
|
| 93 |
+
if self.config.mlp_t:
|
| 94 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 95 |
+
out = self.mlp_t(hidden_states)
|
| 96 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 97 |
+
hidden_states = hidden_states.transpose(1,2)
|
| 98 |
+
else:
|
| 99 |
+
# Self Attention
|
| 100 |
+
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
| 101 |
+
# Fully Connected
|
| 102 |
+
out = self.mlp(hidden_states)
|
| 103 |
+
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
| 107 |
+
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.layers = torch.nn.ModuleList(layers)
|
| 110 |
+
|
| 111 |
+
def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 112 |
+
for layer in self.layers:
|
| 113 |
+
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
| 114 |
+
return hidden_states
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
| 118 |
+
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.config = config
|
| 121 |
+
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
| 122 |
+
|
| 123 |
+
# I/O
|
| 124 |
+
|
| 125 |
+
self.embed_scale = math.sqrt(self.config.hidden_size)
|
| 126 |
+
embed_init_std = 1.0 / self.embed_scale
|
| 127 |
+
|
| 128 |
+
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 129 |
+
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
| 130 |
+
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
| 131 |
+
|
| 132 |
+
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
| 133 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 134 |
+
# Zero init puzzle embeddings
|
| 135 |
+
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
| 136 |
+
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
| 137 |
+
|
| 138 |
+
# LM Blocks
|
| 139 |
+
if self.config.pos_encodings == "rope":
|
| 140 |
+
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
| 141 |
+
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
| 142 |
+
base=self.config.rope_theta)
|
| 143 |
+
elif self.config.pos_encodings == "learned":
|
| 144 |
+
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
| 145 |
+
else:
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
# Reasoning Layers
|
| 149 |
+
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
| 150 |
+
|
| 151 |
+
# Initial states
|
| 152 |
+
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
| 153 |
+
|
| 154 |
+
# Q head special init
|
| 155 |
+
# Init Q to (almost) zero for faster learning during bootstrapping
|
| 156 |
+
with torch.no_grad():
|
| 157 |
+
self.q_head.weight.zero_()
|
| 158 |
+
self.q_head.bias.fill_(-5) # type: ignore
|
| 159 |
+
|
| 160 |
+
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
| 161 |
+
# Token embedding
|
| 162 |
+
embedding = self.embed_tokens(input.to(torch.int32))
|
| 163 |
+
|
| 164 |
+
# Puzzle embeddings
|
| 165 |
+
if self.config.puzzle_emb_ndim > 0:
|
| 166 |
+
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
| 167 |
+
|
| 168 |
+
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
| 169 |
+
if pad_count > 0:
|
| 170 |
+
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
| 171 |
+
|
| 172 |
+
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
| 173 |
+
|
| 174 |
+
# Position embeddings
|
| 175 |
+
if self.config.pos_encodings == "learned":
|
| 176 |
+
# scale by 1/sqrt(2) to maintain forward variance
|
| 177 |
+
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
| 178 |
+
|
| 179 |
+
# Scale
|
| 180 |
+
return self.embed_scale * embedding
|
| 181 |
+
|
| 182 |
+
def empty_carry(self, batch_size: int):
|
| 183 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 184 |
+
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
| 188 |
+
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
| 189 |
+
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 193 |
+
seq_info = dict(
|
| 194 |
+
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Input encoding
|
| 198 |
+
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
| 199 |
+
|
| 200 |
+
# Forward iterations
|
| 201 |
+
it = 0
|
| 202 |
+
z_L = carry.z_L
|
| 203 |
+
# H_cycles-1 without grad
|
| 204 |
+
with torch.no_grad():
|
| 205 |
+
for _H_step in range(self.config.H_cycles-1):
|
| 206 |
+
for _L_step in range(self.config.L_cycles):
|
| 207 |
+
z_L = self.L_level(z_L + input_embeddings, **seq_info)
|
| 208 |
+
z_L = self.L_level(z_L, **seq_info)
|
| 209 |
+
# 1 with grad
|
| 210 |
+
for _L_step in range(self.config.L_cycles):
|
| 211 |
+
z_L = self.L_level(z_L + input_embeddings, **seq_info)
|
| 212 |
+
z_L = self.L_level(z_L, **seq_info)
|
| 213 |
+
z_out = z_L
|
| 214 |
+
|
| 215 |
+
# LM Outputs
|
| 216 |
+
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach()) # New carry no grad
|
| 217 |
+
output = self.lm_head(z_out)[:, self.puzzle_emb_len:]
|
| 218 |
+
q_logits = self.q_head(z_out[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
| 219 |
+
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
| 223 |
+
"""ACT wrapper."""
|
| 224 |
+
|
| 225 |
+
def __init__(self, config_dict: dict):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
| 228 |
+
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
| 229 |
+
|
| 230 |
+
@property
|
| 231 |
+
def puzzle_emb(self):
|
| 232 |
+
return self.inner.puzzle_emb
|
| 233 |
+
|
| 234 |
+
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
| 235 |
+
batch_size = batch["inputs"].shape[0]
|
| 236 |
+
|
| 237 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(
|
| 238 |
+
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
|
| 239 |
+
|
| 240 |
+
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
| 241 |
+
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
| 242 |
+
|
| 243 |
+
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
| 247 |
+
|
| 248 |
+
# Update data, carry (removing halted sequences)
|
| 249 |
+
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
| 250 |
+
|
| 251 |
+
new_steps = torch.where(carry.halted, 0, carry.steps)
|
| 252 |
+
|
| 253 |
+
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
| 254 |
+
|
| 255 |
+
# Forward inner model
|
| 256 |
+
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
| 257 |
+
|
| 258 |
+
outputs = {
|
| 259 |
+
"logits": logits,
|
| 260 |
+
"q_halt_logits": q_halt_logits,
|
| 261 |
+
"q_continue_logits": q_continue_logits
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
with torch.no_grad():
|
| 265 |
+
# Step
|
| 266 |
+
new_steps = new_steps + 1
|
| 267 |
+
is_last_step = new_steps >= self.config.halt_max_steps
|
| 268 |
+
|
| 269 |
+
halted = is_last_step
|
| 270 |
+
|
| 271 |
+
# if training, and ACT is enabled
|
| 272 |
+
if self.training and (self.config.halt_max_steps > 1):
|
| 273 |
+
|
| 274 |
+
# Halt signal
|
| 275 |
+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
| 276 |
+
|
| 277 |
+
if self.config.no_ACT_continue:
|
| 278 |
+
halted = halted | (q_halt_logits > 0)
|
| 279 |
+
else:
|
| 280 |
+
halted = halted | (q_halt_logits > q_continue_logits)
|
| 281 |
+
|
| 282 |
+
# Exploration
|
| 283 |
+
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
| 284 |
+
halted = halted & (new_steps >= min_halt_steps)
|
| 285 |
+
|
| 286 |
+
if not self.config.no_ACT_continue:
|
| 287 |
+
# Compute target Q
|
| 288 |
+
# NOTE: No replay buffer and target networks for computing target Q-value.
|
| 289 |
+
# As batch_size is large, there're many parallel envs.
|
| 290 |
+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
| 291 |
+
_, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data)
|
| 292 |
+
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
| 293 |
+
|
| 294 |
+
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
models/sparse_embedding.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
from torch.optim.optimizer import Optimizer, ParamsT
|
| 7 |
+
|
| 8 |
+
from models.common import trunc_normal_init_
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CastedSparseEmbedding(nn.Module):
|
| 12 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.cast_to = cast_to
|
| 15 |
+
|
| 16 |
+
# Real Weights
|
| 17 |
+
# Truncated LeCun normal init
|
| 18 |
+
self.weights = nn.Buffer(
|
| 19 |
+
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Local weights and IDs
|
| 23 |
+
# Local embeddings, with gradient, not persistent
|
| 24 |
+
self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
|
| 25 |
+
# Local embedding IDs, not persistent
|
| 26 |
+
self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
|
| 27 |
+
|
| 28 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
if not self.training:
|
| 30 |
+
# Test mode, no gradient
|
| 31 |
+
return self.weights[inputs].to(self.cast_to)
|
| 32 |
+
|
| 33 |
+
# Training mode, fill puzzle embedding from weights
|
| 34 |
+
with torch.no_grad():
|
| 35 |
+
self.local_weights.copy_(self.weights[inputs])
|
| 36 |
+
self.local_ids.copy_(inputs)
|
| 37 |
+
|
| 38 |
+
return self.local_weights.to(self.cast_to)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
params: ParamsT,
|
| 45 |
+
|
| 46 |
+
world_size: int,
|
| 47 |
+
lr: Union[float, torch.Tensor] = 1e-3,
|
| 48 |
+
weight_decay: float = 1e-2,
|
| 49 |
+
):
|
| 50 |
+
if not 0.0 <= lr:
|
| 51 |
+
raise ValueError(f"Invalid learning rate: {lr}")
|
| 52 |
+
if not 0.0 <= weight_decay:
|
| 53 |
+
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
| 54 |
+
|
| 55 |
+
defaults = dict(
|
| 56 |
+
lr=lr,
|
| 57 |
+
weight_decay=weight_decay,
|
| 58 |
+
world_size=world_size
|
| 59 |
+
)
|
| 60 |
+
super().__init__(params, defaults)
|
| 61 |
+
|
| 62 |
+
@torch.no_grad
|
| 63 |
+
def step(self, closure=None): # type: ignore
|
| 64 |
+
for group in self.param_groups:
|
| 65 |
+
# Find the sparse embedding weights
|
| 66 |
+
local_weights_grad = None
|
| 67 |
+
local_ids = None
|
| 68 |
+
weights = None
|
| 69 |
+
|
| 70 |
+
assert len(group["params"]) == 3
|
| 71 |
+
for p in group["params"]:
|
| 72 |
+
if p.requires_grad:
|
| 73 |
+
local_weights_grad = p.grad
|
| 74 |
+
elif p.ndim == 1:
|
| 75 |
+
local_ids = p
|
| 76 |
+
elif p.ndim == 2:
|
| 77 |
+
weights = p
|
| 78 |
+
else:
|
| 79 |
+
assert False
|
| 80 |
+
|
| 81 |
+
assert local_ids is not None
|
| 82 |
+
assert weights is not None
|
| 83 |
+
|
| 84 |
+
# Apply SignSGD
|
| 85 |
+
# Adam ≈ SignSGD if gradient is very sparse
|
| 86 |
+
if local_weights_grad is not None:
|
| 87 |
+
_sparse_emb_signsgd_dist(
|
| 88 |
+
local_weights_grad,
|
| 89 |
+
local_ids,
|
| 90 |
+
weights,
|
| 91 |
+
|
| 92 |
+
lr=group["lr"],
|
| 93 |
+
weight_decay=group["weight_decay"],
|
| 94 |
+
world_size=group["world_size"]
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _sparse_emb_signsgd_dist(
|
| 99 |
+
local_weights_grad: torch.Tensor,
|
| 100 |
+
local_ids: torch.Tensor,
|
| 101 |
+
weights: torch.Tensor,
|
| 102 |
+
|
| 103 |
+
lr: float,
|
| 104 |
+
weight_decay: float,
|
| 105 |
+
world_size: int
|
| 106 |
+
) -> None:
|
| 107 |
+
N, D = local_weights_grad.shape
|
| 108 |
+
|
| 109 |
+
# All-gather
|
| 110 |
+
all_weights_grad = local_weights_grad
|
| 111 |
+
all_ids = local_ids
|
| 112 |
+
|
| 113 |
+
if world_size > 1:
|
| 114 |
+
all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
|
| 115 |
+
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
|
| 116 |
+
|
| 117 |
+
dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
|
| 118 |
+
dist.all_gather_into_tensor(all_ids, local_ids)
|
| 119 |
+
|
| 120 |
+
# Unique
|
| 121 |
+
grad_ids, inv = all_ids.unique(return_inverse=True)
|
| 122 |
+
|
| 123 |
+
grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
|
| 124 |
+
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
|
| 125 |
+
|
| 126 |
+
# SignSGD with decoupled weight decay
|
| 127 |
+
p = weights[grad_ids]
|
| 128 |
+
|
| 129 |
+
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
|
| 130 |
+
|
| 131 |
+
# Write updated slices back
|
| 132 |
+
weights[grad_ids] = p
|
utils/functions.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import inspect
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def load_model_class(identifier: str, prefix: str = "models."):
|
| 6 |
+
module_path, class_name = identifier.split('@')
|
| 7 |
+
|
| 8 |
+
# Import the module
|
| 9 |
+
module = importlib.import_module(prefix + module_path)
|
| 10 |
+
cls = getattr(module, class_name)
|
| 11 |
+
|
| 12 |
+
return cls
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_model_source_path(identifier: str, prefix: str = "models."):
|
| 16 |
+
module_path, class_name = identifier.split('@')
|
| 17 |
+
|
| 18 |
+
module = importlib.import_module(prefix + module_path)
|
| 19 |
+
return inspect.getsourcefile(module)
|