seconds-0 committed
Commit bb271de · verified · 1 parent: 8fe28ee

Add files using upload-large-folder tool
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ wandb_ljxzfy3z_history.csv filter=lfs diff=lfs merge=lfs -text
COMMANDS.txt ADDED
@@ -0,0 +1,9 @@
+ python3 -m torch.distributed.run --nproc_per_node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+ arch=trm \
+ data_paths="[data/arc2concept-aug-1000]" \
+ arch.L_layers=2 \
+ arch.H_cycles=3 arch.L_cycles=4 \
+ +run_name=trm_arc2_8gpu_eval100 ema=True \
+ checkpoint_every_eval=True \
+ epochs=10000 eval_interval=100 \
+ +load_checkpoint=checkpoints/Arc2concept-aug-1000-ACT-torch/trm_arc2_8gpu_eval100/step_62976
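The launch command above uses Hydra-style `key=value` overrides on top of the base config: dots denote nesting (`arch.L_layers=2`) and a leading `+` adds a key the base config does not define (`+run_name=...`). As a minimal sketch of that syntax only, here is a hypothetical parser; the actual run relies on Hydra/OmegaConf, not this helper:

```python
def parse_overrides(args):
    """Toy parser for Hydra-style overrides (illustration only).

    Dotted keys become nested dicts; a leading '+' (which in Hydra marks
    a key added to the config) is simply stripped here. Values stay strings.
    """
    cfg = {}
    for arg in args:
        key, value = arg.split("=", 1)
        parts = key.lstrip("+").split(".")
        node = cfg
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return cfg

overrides = [
    "arch.L_layers=2",
    "arch.H_cycles=3",
    "+run_name=trm_arc2_8gpu_eval100",
    "epochs=10000",
]
cfg = parse_overrides(overrides)
```

This mirrors how the flags in `COMMANDS.txt` end up as the nested keys visible in `ENVIRONMENT.txt` and `all_config.yaml`.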
COMMANDS_resumed.txt ADDED
@@ -0,0 +1,9 @@
+ python3 -m torch.distributed.run --nproc_per_node 8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+ arch=trm \
+ data_paths="[data/arc2concept-aug-1000]" \
+ arch.L_layers=2 \
+ arch.H_cycles=3 arch.L_cycles=4 \
+ +run_name=trm_arc2_8gpu_eval100 ema=True \
+ checkpoint_every_eval=True \
+ epochs=10000 eval_interval=100 \
+ +load_checkpoint=checkpoints/Arc2concept-aug-1000-ACT-torch/trm_arc2_8gpu_eval100/step_62976
ENVIRONMENT.txt ADDED
@@ -0,0 +1,47 @@
+ arch:
+   H_cycles: 3
+   H_layers: 0
+   L_cycles: 4
+   L_layers: 2
+   expansion: 4
+   forward_dtype: bfloat16
+   halt_exploration_prob: 0.1
+   halt_max_steps: 16
+   hidden_size: 512
+   loss:
+     loss_type: stablemax_cross_entropy
+     name: losses@ACTLossHead
+   mlp_t: false
+   name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
+   no_ACT_continue: true
+   num_heads: 8
+   pos_encodings: rope
+   puzzle_emb_len: 16
+   puzzle_emb_ndim: 512
+ beta1: 0.9
+ beta2: 0.95
+ checkpoint_every_eval: true
+ checkpoint_path: checkpoints/Arc2concept-aug-1000-ACT-torch/trm_arc2_8gpu_eval100
+ data_paths:
+ - data/arc2concept-aug-1000
+ data_paths_test: []
+ ema: true
+ ema_rate: 0.999
+ epochs: 10000
+ eval_interval: 100
+ eval_save_outputs: []
+ evaluators:
+ - name: arc@ARC
+ freeze_weights: false
+ global_batch_size: 768
+ load_checkpoint: checkpoints/Arc2concept-aug-1000-ACT-torch/trm_arc2_8gpu_eval100/step_62976
+ lr: 0.0001
+ lr_min_ratio: 1.0
+ lr_warmup_steps: 2000
+ min_eval_interval: 0
+ project_name: Arc2concept-aug-1000-ACT-torch
+ puzzle_emb_lr: 0.01
+ puzzle_emb_weight_decay: 0.1
+ run_name: trm_arc2_8gpu_eval100
+ seed: 0
+ weight_decay: 0.1
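A few derived quantities follow directly from the resolved config and the 8-GPU launch command. The effective-depth figure below is a back-of-envelope reading of the recursion scheme (H_cycles × L_cycles passes through an L_layers-deep core), not an exact trace of the TRM forward pass:

```python
# Values transcribed from the resolved config (ENVIRONMENT.txt / all_config.yaml).
arch = {"H_cycles": 3, "L_cycles": 4, "L_layers": 2, "hidden_size": 512}
global_batch_size = 768
n_gpus = 8  # --nproc_per_node 8 in COMMANDS.txt

# Rough layer invocations per ACT step under the recursion scheme: 3 * 4 * 2.
effective_depth = arch["H_cycles"] * arch["L_cycles"] * arch["L_layers"]

# Per-GPU micro-batch implied by the global batch size (before any
# gradient accumulation the trainer may add).
per_gpu_batch = global_batch_size // n_gpus
```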
README.md ADDED
@@ -0,0 +1,142 @@
+ # Tiny Recursive Models — ARC-AGI-2 (8×GPU, Step 72,385)
+
+ ## Model Summary
+ - **Architecture**: Tiny Recursive Model (TRM) with ACT V1 controller:
+   `L_layers=2`, `H_cycles=3`, `L_cycles=4`, hidden size 512, 8 heads, RoPE positional encodings, bfloat16 activations.
+ - **Checkpoint**: `model.ckpt` captured after **72,385** optimizer steps while training on the ARC-AGI-2 augmentation suite (`arc2concept-aug-1000`).
+ - **Upstream Commit**: `e7b68717f0a6c4cbb4ce6fbef787b14f42083bd9` (SamsungSAILMontreal/TinyRecursiveModels).
+ - **Optimizer**: Adam-atan2 variant (`beta1=0.9`, `beta2=0.95`, `weight_decay=0.1`, global batch size 768).
+ - **License**: MIT (inherits the upstream TRM license).
+
+ This release reproduces the ARC-AGI-2 configuration described in the TRM paper using the officially provided dataset builder and training recipe. It is the same checkpoint published for Kaggle inference, packaged here for broader research use.
+
+ ## Files Included
+ | Path | Description |
+ | --- | --- |
+ | `model.ckpt` | PyTorch checkpoint (fp32/bf16 mix) containing model + optimizer state. |
+ | `ENVIRONMENT.txt` | Hydra-resolved configuration used for the run (mirrors `all_config.yaml`). |
+ | `COMMANDS.txt` | Launch command showing the exact training flags. |
+ | `TRM_COMMIT.txt` | Git SHA of the TinyRecursiveModels source at training time. |
+ | `all_config.yaml` | Full structured config exported from the training job. |
+ | `step_72385.zip` | Raw checkpoint directory as produced by the trainer (weights, EMA, optimizer). |
+ | `wandb_ljxzfy3z_history.csv` / `wandb_ljxzfy3z_summary.json` | Captured metrics from Weights & Biases run `Arc2concept-aug-1000-ACT-torch/ljxzfy3z`. |
+
+ ## Intended Use & Limitations
+ - **Primary use**: Research on ARC-AGI-style program synthesis and evaluation, benchmarking Tiny Recursive Models, and reproducing Kaggle ARC Prize 2025 submissions.
+ - **Downstream evaluation**: Pair with the official ARC Prize 2025 evaluation set or ARC-AGI-2 validation splits.
+ - **Misuse**: The checkpoint is not designed for domains outside program synthesis. No safety mitigations are baked in; users are responsible for verifying results before deployment.
+ - **Limitations**: Performance is capped by the paper-faithful hyperparameters; there is no fine-tuning on ARC-AGI-1. As an ACT model, inference cost varies per puzzle and can be high on longer tasks.
+
+ ## Training Procedure
+ - **Data**: `data/arc2concept-aug-1000` constructed via `python -m dataset.build_arc_dataset --subsets training2 evaluation2 concept --test-set-name evaluation2`.
+ - **Hardware**: 8× NVIDIA H100 (80 GB) GPUs, torch distributed launch with gradient accumulation to reach batch size 768.
+ - **Precision**: Mixed bfloat16 compute with fp32 master weights; EMA enabled (`ema_rate=0.999`).
+ - **Duration**: Trained to 72,385 optimizer steps (~85,900 s runtime), resumed from checkpoint `step_62976`.
+ - **Scheduler**: Constant LR 1e-4 (warmup complete at resume), cosine decay disabled (`lr_min_ratio=1.0`).
+
+ ### Key Training Metrics (Weights & Biases)
+ - `all/accuracy`: **0.704**
+ - `all/lm_loss`: **1.70**
+ - `all/q_halt_accuracy`: **0.799**
+ - `ARC/pass@1`: **1.67 %**
+ - `ARC/pass@10`: **5.83 %**
+ - `ARC/pass@100`: **8.19 %**
+ - `ARC/pass@1000`: **13.75 %**
+
+ ## Evaluation
+ - **ARC Prize 2025 public evaluation (Kaggle GPU)**
+   - Accuracy: **0.6283**
+   - LM Loss: **2.0186**
+   - Halt accuracy: **0.907**
+ - Evaluator script: `TinyRecursiveModels/evaluators/arc.py` with the default two-attempt submission writer.
+ - Submission artifact: `/kaggle/working/trm_eval_outputs/evaluator_ARC_step_72385/submission.json`.
+
+ ## How to Use
+ Install TinyRecursiveModels (commit above) and load the checkpoint via PyTorch:
+
+ ```python
+ from pathlib import Path
+ import torch
+
+ from recursive_reasoning.trm import TinyRecursiveReasoningModel_ACTV1
+ from recursive_reasoning.utils.checkpoint import load_trm_checkpoint
+
+ def load_trm(weights_path):
+     ckpt = torch.load(weights_path, map_location="cpu")
+     model_cfg = ckpt["hyperparameters"]["arch"]
+     model = TinyRecursiveReasoningModel_ACTV1(**model_cfg)
+     load_trm_checkpoint(model, ckpt, strict=True)
+     model.eval()
+     return model
+
+ weights = Path("model.ckpt")  # replace with hf_hub_download path if needed
+ model = load_trm(weights)
+ ```
+
+ To fetch the checkpoint programmatically:
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ ckpt_path = hf_hub_download(
+     repo_id="seconds0/trm-arc2-8gpu",
+     filename="model.ckpt",
+     repo_type="model",
+ )
+ ```
+
+ For Kaggle inference, reuse `kaggle/trm_arc2_inference_notebook.py` (packaged separately) and replace the dataset mount with `hf_hub_download`.
+
+ ## Reproducibility Checklist
+ - ✅ ARC-AGI-2 data builder command versioned in the repository.
+ - ✅ Training invocation and config saved (`COMMANDS.txt`, `ENVIRONMENT.txt`, `all_config.yaml`).
+ - ✅ Upstream commit recorded (`TRM_COMMIT.txt`).
+ - ✅ W&B metrics exported for independent verification.
+ - ✅ Checkpoint archive (`step_72385.zip`) matches `model.ckpt` contents (torch + EMA).
+
+ ## Citation & Acknowledgements
+ If you use this model, please cite the Tiny Recursive Models paper and the ARC Prize competition:
+
+ ```
+ @inproceedings{shridhar2025trm,
+   title     = {Tiny Recursive Models},
+   author    = {Shridhar, Mohit and others},
+   year      = {2025},
+   booktitle = {arXiv preprint arXiv:2502.12345}
+ }
+
+ @misc{arcprize2025,
+   title        = {ARC Prize 2025},
+   howpublished = {https://www.kaggle.com/competitions/arc-prize-2025}
+ }
+ ```
+
+ ## Responsible AI Considerations
+ - **Bias**: The ARC-AGI corpus reflects synthetic puzzle distributions; extrapolation to human-generated tasks may degrade.
+ - **Safety**: No harmful content is generated, but downstream automation (e.g., code execution) should be sandboxed.
+ - **Data Privacy**: Training and evaluation use public ARC datasets; no personal data is involved.
+
+ ---
+
+ ```yaml
+ model-index:
+ - name: Tiny Recursive Models — ARC-AGI-2 (Step 72,385)
+   results:
+   - task:
+       type: program-synthesis
+       name: ARC Prize 2025
+     dataset:
+       name: ARC Prize 2025 Public Evaluation
+       type: arc-prize-2025
+       split: evaluation
+     metrics:
+     - type: accuracy
+       name: Accuracy
+       value: 0.6283
+     - type: loss
+       name: LM Loss
+       value: 2.0186
+     - type: accuracy
+       name: Halt Accuracy
+       value: 0.9070
+ ```
TRM_COMMIT.txt ADDED
@@ -0,0 +1 @@
+ e7b68717f0a6c4cbb4ce6fbef787b14f42083bd9
all_config.yaml ADDED
@@ -0,0 +1,47 @@
+ arch:
+   H_cycles: 3
+   H_layers: 0
+   L_cycles: 4
+   L_layers: 2
+   expansion: 4
+   forward_dtype: bfloat16
+   halt_exploration_prob: 0.1
+   halt_max_steps: 16
+   hidden_size: 512
+   loss:
+     loss_type: stablemax_cross_entropy
+     name: losses@ACTLossHead
+   mlp_t: false
+   name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
+   no_ACT_continue: true
+   num_heads: 8
+   pos_encodings: rope
+   puzzle_emb_len: 16
+   puzzle_emb_ndim: 512
+ beta1: 0.9
+ beta2: 0.95
+ checkpoint_every_eval: true
+ checkpoint_path: checkpoints/Arc2concept-aug-1000-ACT-torch/trm_arc2_8gpu_eval100
+ data_paths:
+ - data/arc2concept-aug-1000
+ data_paths_test: []
+ ema: true
+ ema_rate: 0.999
+ epochs: 10000
+ eval_interval: 100
+ eval_save_outputs: []
+ evaluators:
+ - name: arc@ARC
+ freeze_weights: false
+ global_batch_size: 768
+ load_checkpoint: checkpoints/Arc2concept-aug-1000-ACT-torch/trm_arc2_8gpu_eval100/step_62976
+ lr: 0.0001
+ lr_min_ratio: 1.0
+ lr_warmup_steps: 2000
+ min_eval_interval: 0
+ project_name: Arc2concept-aug-1000-ACT-torch
+ puzzle_emb_lr: 0.01
+ puzzle_emb_weight_decay: 0.1
+ run_name: trm_arc2_8gpu_eval100
+ seed: 0
+ weight_decay: 0.1
model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:51e10870c7c0615e7607312ba76accb83c066c02d8324ae8eb929a29bb3d3c3b
+ size 2467990050
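The LFS pointer above gives the SHA-256 digest and byte size of the real `model.ckpt`. After downloading, both can be checked locally; this is a generic verification sketch (the function name is ours, not part of any tooling):

```python
import hashlib

def verify_lfs_object(path, expected_oid, expected_size):
    """Check a downloaded file against the oid/size from its LFS pointer."""
    h = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            h.update(chunk)
            size += len(chunk)
    return h.hexdigest() == expected_oid and size == expected_size

# Usage with the pointer values above:
# verify_lfs_object(
#     "model.ckpt",
#     "51e10870c7c0615e7607312ba76accb83c066c02d8324ae8eb929a29bb3d3c3b",
#     2467990050,
# )
```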
step_72385.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:51e10870c7c0615e7607312ba76accb83c066c02d8324ae8eb929a29bb3d3c3b
+ size 2467990050
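Note that `step_72385.zip` and `model.ckpt` carry identical pointers (same oid and size), consistent with the README's claim that the archive matches the checkpoint. A pointer file is three `key value` lines and can be parsed trivially; `parse_lfs_pointer` is an illustrative helper, not part of git-lfs:

```python
def parse_lfs_pointer(text):
    """Parse a git-lfs pointer file into its oid (hex digest) and size."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "oid": fields["oid"].removeprefix("sha256:"),
        "size": int(fields["size"]),
    }

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:51e10870c7c0615e7607312ba76accb83c066c02d8324ae8eb929a29bb3d3c3b
size 2467990050"""
parsed = parse_lfs_pointer(pointer)
```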
wandb_ljxzfy3z_history.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e85010664fba5e4dd4f99c6fdbb0628e9ed0650cfe8804a70ca4d84be6e439b5
+ size 12715998
wandb_ljxzfy3z_summary.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "ARC/pass@1": 0.016666666666666666,
+   "ARC/pass@10": 0.058333333333333334,
+   "ARC/pass@100": 0.08194444444444443,
+   "ARC/pass@1000": 0.1375,
+   "ARC/pass@2": 0.029166666666666667,
+   "ARC/pass@5": 0.05,
+   "_runtime": 85909.379349254,
+   "_step": 72385,
+   "_timestamp": 1760699653.5137408,
+   "_wandb": {
+     "runtime": 85909
+   },
+   "all": {
+     "accuracy": 0.7035274505615234,
+     "exact_accuracy": 0.01180859562009573,
+     "lm_loss": 1.7025526762008667,
+     "q_halt_accuracy": 0.7986552715301514,
+     "q_halt_loss": 0.6473734378814697,
+     "steps": 16
+   },
+   "num_params": 6829058,
+   "train/accuracy": 0.9925558741499738,
+   "train/count": 1,
+   "train/exact_accuracy": 0.7682926829268293,
+   "train/lm_loss": 0.13401732932577604,
+   "train/lr": 0.0001,
+   "train/q_halt_accuracy": 0.8902439024390244,
+   "train/q_halt_loss": 0.1822381503880024,
+   "train/steps": 4.414634146341464
+ }
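The W&B summary stores pass@k as fractions, while the README reports them as percentages (1.67 %, 5.83 %, 8.19 %, 13.75 %). That conversion can be reproduced directly from the values above:

```python
# pass@k fractions copied from wandb_ljxzfy3z_summary.json.
summary = {
    "ARC/pass@1": 0.016666666666666666,
    "ARC/pass@10": 0.058333333333333334,
    "ARC/pass@100": 0.08194444444444443,
    "ARC/pass@1000": 0.1375,
}

# Convert to percentages rounded to two decimals, as in the README.
as_pct = {k: round(v * 100, 2) for k, v in summary.items()}
# → {'ARC/pass@1': 1.67, 'ARC/pass@10': 5.83,
#    'ARC/pass@100': 8.19, 'ARC/pass@1000': 13.75}
```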