Teen-Different committed on
Commit
138a6a8
·
verified ·
1 Parent(s): 66cd553

Upload ARC-IT model checkpoint

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. README.md +55 -0
  3. config.json +111 -0
  4. model.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ model.pt filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - arc-agi
4
+ - abstract-reasoning
5
+ - rule-conditioned-transformer
6
+ - discrete-reasoning
7
+ license: mit
8
+ ---
9
+
10
+ # ARC-IT: Rule-Conditioned Transformer for ARC-AGI
11
+
12
+ A novel architecture that solves abstract reasoning tasks (ARC-AGI) by explicitly
13
+ extracting transformation rules from demonstration pairs and applying them to new inputs:
14
+
15
+ - **GridTokenizer** -- Embeds discrete ARC grids (0-11) into continuous patch tokens
16
+ - **RuleEncoder** -- Extracts transformation rules from demo input/output pairs via cross-attention
17
+ - **RuleApplier** -- Applies the learned rules to a test input via cross-attention
18
+ - **SpatialDecoder** -- Converts output tokens to 64x64 grid logits
19
+
20
+ ## Architecture
21
+
22
+ ```
23
+ Demo Pairs -> GridTokenizer -> RuleEncoder (cross-attention + aggregation) -> Rule Tokens
24
+ Test Input -> GridTokenizer -> RuleApplier (cross-attention to rules) -> SpatialDecoder -> Predicted Grid
25
+ ```
26
+
27
+ ## Training
28
+
29
+ - **2-stage training**: Full Training -> Hard Focus (AGI-2 oversampling)
30
+ - **Test-Time Training (TTT)**: Per-task fine-tuning on demonstration examples
31
+
32
+ ## Model Details
33
+
34
+ - **Training step**: 18000
35
+ - **Best validation accuracy**: 0.733029360572497
36
+ - **Hidden size**: 384
37
+ - **Rule Encoder**: 2 pair layers, 2 agg layers, 64 rule tokens
38
+ - **Rule Applier**: 4 layers, 8 heads
39
+ - **Canvas size**: 64
40
+
41
+ ## Usage
42
+
43
+ ```python
44
+ import torch
45
+ from arc_it.models.arc_it_model import ARCITModel
46
+
47
+ model = ARCITModel.from_config(config)
48
+ ckpt = torch.load("model.pt", map_location="cpu", weights_only=False)
49
+ model.load_state_dict(ckpt["model_state_dict"])
50
+ ```
51
+
52
+ ## Links
53
+
54
+ - **Repository**: [github.com/REDDITARUN/arc_it](https://github.com/REDDITARUN/arc_it)
55
+ - **ARC-AGI**: [arcprize.org](https://arcprize.org)
config.json ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "data": {
3
+ "arc_agi1_path": "References/ARC-AGI",
4
+ "arc_agi2_path": "References/ARC-AGI-2",
5
+ "re_arc_path": "References/RE-ARC",
6
+ "canvas_size": 64,
7
+ "num_colors": 12,
8
+ "max_grid_size": 30,
9
+ "max_demos": 5,
10
+ "re_arc_samples_per_task": 50,
11
+ "repeat_factor": 1,
12
+ "augmentation": {
13
+ "geometric": true,
14
+ "color_permutation": true,
15
+ "num_color_perms": 10,
16
+ "keep_background": true,
17
+ "resolution_scaling": true,
18
+ "translation": true
19
+ }
20
+ },
21
+ "model": {
22
+ "hidden_size": 384,
23
+ "mlp_ratio": 2.5,
24
+ "tokenizer": {
25
+ "patch_size": 4
26
+ },
27
+ "rule_encoder": {
28
+ "pair_layers": 2,
29
+ "agg_layers": 2,
30
+ "num_heads": 8,
31
+ "num_rule_tokens": 64
32
+ },
33
+ "rule_applier": {
34
+ "num_layers": 4,
35
+ "num_heads": 8
36
+ },
37
+ "decoder": {
38
+ "upsample_method": "transposed_conv",
39
+ "hidden_channels": [
40
+ 192,
41
+ 96
42
+ ]
43
+ }
44
+ },
45
+ "training": {
46
+ "batch_size": 64,
47
+ "num_workers": 8,
48
+ "gradient_clip": 1.0,
49
+ "stage1": {
50
+ "name": "pretrain",
51
+ "data_sources": [
52
+ "re_arc"
53
+ ],
54
+ "epochs": 50,
55
+ "lr": 0.0003
56
+ },
57
+ "stage2": {
58
+ "name": "finetune",
59
+ "data_sources": [
60
+ "agi1",
61
+ "agi2"
62
+ ],
63
+ "epochs": 30,
64
+ "lr": 0.0001
65
+ },
66
+ "stage3": {
67
+ "name": "hard_focus",
68
+ "data_sources": [
69
+ "agi1",
70
+ "agi2"
71
+ ],
72
+ "epochs": 10,
73
+ "lr": 3e-05,
74
+ "agi2_oversample": 2.0
75
+ },
76
+ "optimizer": {
77
+ "name": "adamw",
78
+ "weight_decay": 0.01,
79
+ "betas": [
80
+ 0.9,
81
+ 0.999
82
+ ]
83
+ },
84
+ "scheduler": {
85
+ "name": "cosine",
86
+ "warmup_ratio": 0.1
87
+ },
88
+ "log_every_n_steps": 100,
89
+ "save_every_n_epochs": 10,
90
+ "checkpoint_dir": "checkpoints"
91
+ },
92
+ "ttt": {
93
+ "enabled": true,
94
+ "steps": 100,
95
+ "lr": 0.0001,
96
+ "batch_size": 8,
97
+ "num_candidates": 32
98
+ },
99
+ "evaluation": {
100
+ "val_split_ratio": 0.1,
101
+ "val_data_sources": [
102
+ "agi1",
103
+ "agi2"
104
+ ],
105
+ "metrics": [
106
+ "pixel_accuracy",
107
+ "grid_exact_match"
108
+ ],
109
+ "visualize_every_n_tasks": 50
110
+ }
111
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9cbf16263cf79d22e6f219cb5e11110d6dc6ed84aedb204222900ce9727f119
3
+ size 68062850