SayaGugu committed on
Commit 45bc017 · 0 Parent(s)

Upload SR2 ARC-AGI checkpoints (all evaluator steps and configs)

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/all_config.yaml +41 -0
  3. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1030727/submission.json +0 -0
  4. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1091359/submission.json +0 -0
  5. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1151990/submission.json +0 -0
  6. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_121262/submission.json +0 -0
  7. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1212621/submission.json +0 -0
  8. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1273252/submission.json +0 -0
  9. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1333883/submission.json +0 -0
  10. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1394514/submission.json +0 -0
  11. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1455145/submission.json +0 -0
  12. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1515776/submission.json +0 -0
  13. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1576408/submission.json +0 -0
  14. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1637039/submission.json +0 -0
  15. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1697671/submission.json +0 -0
  16. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1758302/submission.json +0 -0
  17. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_181893/submission.json +0 -0
  18. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1818933/submission.json +0 -0
  19. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_242524/submission.json +0 -0
  20. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_303155/submission.json +0 -0
  21. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_363786/submission.json +0 -0
  22. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_424417/submission.json +0 -0
  23. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_485048/submission.json +0 -0
  24. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_545679/submission.json +0 -0
  25. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_606308/submission.json +0 -0
  26. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_60631/submission.json +0 -0
  27. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_666939/submission.json +0 -0
  28. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_727571/submission.json +0 -0
  29. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_788203/submission.json +0 -0
  30. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_848834/submission.json +0 -0
  31. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_909465/submission.json +0 -0
  32. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_970096/submission.json +0 -0
  33. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/losses.py +105 -0
  34. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/sr2.py +360 -0
  35. Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/step_1818933 +3 -0
  36. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/all_config.yaml +41 -0
  37. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1036146/submission.json +0 -0
  38. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_103616/submission.json +0 -0
  39. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1087953/submission.json +0 -0
  40. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1139760/submission.json +0 -0
  41. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1191567/submission.json +0 -0
  42. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1243374/submission.json +0 -0
  43. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1295181/submission.json +0 -0
  44. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1346988/submission.json +0 -0
  45. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1398795/submission.json +0 -0
  46. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1450603/submission.json +0 -0
  47. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1502410/submission.json +0 -0
  48. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1554218/submission.json +0 -0
  49. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_155423/submission.json +0 -0
  50. Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_207230/submission.json +0 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ step_* filter=lfs diff=lfs merge=lfs -text
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/all_config.yaml ADDED
@@ -0,0 +1,41 @@
+ arch:
+   H_cycles: 1
+   H_layers: 16
+   expansion: 4
+   halt_exploration_prob: 0.1
+   halt_max_steps: 16
+   hidden_size: 512
+   loss:
+     loss_type: stablemax_cross_entropy
+     name: losses@ACTLossHead
+   name: hrm.sr2@HierarchicalReasoningModel_ACTV3
+   num_heads: 8
+   pos_encodings: rope
+   puzzle_emb_ndim: 512
+ beta1: 0.9
+ beta2: 0.95
+ checkpoint_every_eval: true
+ checkpoint_path: checkpoints/Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3
+   spectacular-dragon
+ data_path: data/arc-2-aug-1000
+ ema_decay: 0.999
+ ema_device: cuda
+ ema_enabled: true
+ ema_use_buffers: true
+ epochs: 300000
+ eval_interval: 10000
+ eval_save_outputs: []
+ evaluators:
+ - name: arc@ARC
+ global_batch_size: 768
+ load_checkpoint: null
+ lr: 0.0001
+ lr_min_ratio: 1.0
+ lr_warmup_steps: 2000
+ project_name: Arc-2-aug-1000 ACT-torch
+ puzzle_emb_lr: 0.01
+ puzzle_emb_weight_decay: 0.1
+ run_name: HierarchicalReasoningModel_ACTV3 spectacular-dragon
+ seed: 0
+ target_q_update_every: 4
+ weight_decay: 0.1
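With `lr_min_ratio: 1.0`, the schedule implied by this config is linear warmup over `lr_warmup_steps` followed by a constant learning rate (no decay). The function below is a hypothetical plain-Python sketch of that schedule, not code from the repository:

```python
def lr_at(step, base_lr=1e-4, warmup_steps=2000, min_ratio=1.0):
    """Hypothetical schedule implied by the config above:
    linear warmup, then (with lr_min_ratio = 1.0) a constant rate."""
    if step < warmup_steps:
        # Linear ramp from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # lr_min_ratio = 1.0 means the post-warmup rate never decays
    return base_lr * min_ratio

print(lr_at(0), lr_at(999), lr_at(100000))
```

With this config's values, the rate reaches 1e-4 after 2000 steps and stays there for the remaining ~1.8 M steps.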
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1030727/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1091359/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1151990/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_121262/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1212621/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1273252/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1333883/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1394514/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1455145/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1515776/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1576408/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1637039/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1697671/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1758302/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_181893/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_1818933/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_242524/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_303155/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_363786/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_424417/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_485048/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_545679/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_606308/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_60631/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_666939/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_727571/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_788203/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_848834/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_909465/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/evaluator_ARC_step_970096/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/losses.py ADDED
@@ -0,0 +1,105 @@
+ from typing import Any, Tuple, Dict, Set, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ IGNORE_LABEL_ID = -100
+
+
+ def s(x, epsilon=1e-30):
+     return torch.where(
+         x < 0,
+         1 / (1 - x + epsilon),
+         x + 1
+     )
+
+
+ def log_stablemax(x, dim=-1):
+     s_x = s(x)
+     return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))
+
+
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
+     logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+     valid_mask = labels != ignore_index
+     transformed_labels = torch.where(valid_mask, labels, 0)
+     prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+     return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+     # Cast logits to f32
+     # Flatten logits
+     return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+ class ACTLossHead(nn.Module):
+     def __init__(self, model: nn.Module, loss_type: str):
+         super().__init__()
+         self.model = model
+         self.loss_fn = globals()[loss_type]
+
+     def initial_carry(self, *args, **kwargs):
+         return self.model.initial_carry(*args, **kwargs)  # type: ignore
+
+     def forward(
+         self,
+         return_keys: Set[str],
+         # Model args
+         **model_kwargs,
+     ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+         # Model logits
+         # B x SeqLen x D
+         new_carry, outputs = self.model(**model_kwargs)
+         labels = new_carry.current_data["labels"]
+
+         # Correctness
+         with torch.no_grad():
+             # Preds
+             outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
+
+             # Correctness
+             mask = labels != IGNORE_LABEL_ID
+             loss_counts = mask.sum(-1)
+             loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division
+
+             is_correct = mask & (outputs["preds"] == labels)
+             seq_is_correct = is_correct.sum(-1) == loss_counts
+
+             # Metrics (halted)
+             valid_metrics = new_carry.halted & (loss_counts > 0)
+             metrics = {
+                 "count": valid_metrics.sum(),
+
+                 "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+                 "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+                 "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+                 "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+             }
+
+         # Losses
+         # FIXME: Assuming the batch is always full
+         lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
+         q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+
+         metrics.update({
+             "lm_loss": lm_loss.detach(),
+             "q_halt_loss": q_halt_loss.detach(),
+         })
+
+         # Q continue (bootstrapping target loss)
+         q_continue_loss = 0
+         if "target_q_continue" in outputs:
+             q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+             metrics["q_continue_loss"] = q_continue_loss.detach()
+
+         # Filter outputs for return
+         detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+         return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
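As a quick sanity check (not part of the repository), the `s` transform and `log_stablemax` used by `stablemax_cross_entropy` can be exercised with plain Python floats: `s` maps negative logits to `(0, 1)` via a reciprocal and non-negative logits to `x + 1`, so no exponentials are needed, and the normalized result should still be a valid log-probability distribution:

```python
import math

def s(x, epsilon=1e-30):
    # Piecewise transform: 1 / (1 - x + eps) for x < 0, else x + 1.
    # Always positive, so the ratio below is a proper probability.
    return 1.0 / (1.0 - x + epsilon) if x < 0 else x + 1.0

def log_stablemax(xs):
    # Scalar re-implementation of the tensor version above.
    sx = [s(v) for v in xs]
    total = sum(sx)
    return [math.log(v / total) for v in sx]

logits = [2.0, 0.0, -3.0]
logp = log_stablemax(logits)
print([round(v, 4) for v in logp])
```

The exponentiated outputs sum to 1 and preserve the ordering of the logits, which is the property the stablemax loss relies on.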
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/sr2.py ADDED
@@ -0,0 +1,360 @@
+ """
+ HRM ACT V2: Transformer Baseline for Architecture Ablation
+
+ This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
+ Key changes from V1:
+ 1. REMOVED hierarchical split (no separate H and L levels)
+ 2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
+ 3. KEPT ACT outer loop structure intact
+ 4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
+
+ Architecture: Single-level transformer that processes the full 30x30 grid as a
+ 900-token sequence, with the same positional encodings and sparse embeddings as V1.
+
+ """
+
+ from typing import Tuple, List, Dict, Optional
+ from dataclasses import dataclass
+ import math
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from pydantic import BaseModel
+
+ from models.common import trunc_normal_init_
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
+ from models.sparse_embedding import CastedSparseEmbedding
+
+
+ @dataclass
+ class HierarchicalReasoningModel_ACTV3InnerCarry:
+     z_H: torch.Tensor
+
+
+ @dataclass
+ class HierarchicalReasoningModel_ACTV3Carry:
+     inner_carry: HierarchicalReasoningModel_ACTV3InnerCarry
+
+     steps: torch.Tensor
+     halted: torch.Tensor
+
+     current_data: Dict[str, torch.Tensor]
+
+
+ class HierarchicalReasoningModel_ACTV3Config(BaseModel):
+     batch_size: int
+     seq_len: int
+     puzzle_emb_ndim: int = 0
+     num_puzzle_identifiers: int
+     vocab_size: int
+
+     H_cycles: int
+
+     H_layers: int
+
+     # Transformer config
+     hidden_size: int
+     expansion: float
+     num_heads: int
+     pos_encodings: str
+
+     rms_norm_eps: float = 1e-5
+     rope_theta: float = 10000.0
+
+     # Halting Q-learning config
+     halt_max_steps: int
+     halt_exploration_prob: float
+     act_enabled: bool = True  # If False, always run halt_max_steps (no early stopping during training)
+     act_inference: bool = False  # If True, use adaptive computation during inference
+
+     forward_dtype: str = "bfloat16"
+
+
+ class HierarchicalReasoningModel_ACTV3Block(nn.Module):
+     def __init__(self, config: HierarchicalReasoningModel_ACTV3Config) -> None:
+         super().__init__()
+
+         self.self_attn = Attention(
+             hidden_size=config.hidden_size,
+             head_dim=config.hidden_size // config.num_heads,
+             num_heads=config.num_heads,
+             num_key_value_heads=config.num_heads,
+             causal=False,
+         )
+         self.mlp = SwiGLU(
+             hidden_size=config.hidden_size,
+             expansion=config.expansion,
+         )
+         self.norm_eps = config.rms_norm_eps
+
+     def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+         # Post Norm
+         # Self Attention
+         hidden_states = rms_norm(
+             hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
+             variance_epsilon=self.norm_eps,
+         )
+         # Fully Connected
+         hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
+         return hidden_states
+
+
+ class HierarchicalReasoningModel_ACTV3ReasoningModule(nn.Module):
+     """
+     Note (modified):
+     - Previously this took a List[Block] and built several layers with mutually independent parameters.
+     - It now holds a single shared block and repeatedly calls that same block `self.repeats` times in the forward pass.
+     - To minimize impact on external code, the `self.layers` attribute is kept, but it contains only the one shared block.
+     """
+     def __init__(self, block: HierarchicalReasoningModel_ACTV3Block, repeats: int):
+         super().__init__()
+         # Register only one shared block (keep the attribute name `layers` so external code does not break)
+         self.layers = torch.nn.ModuleList([block])
+         self.repeats = int(repeats)
+
+     def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
+         # Input injection (add)
+
+         # hidden_states = hidden_states + input_injection
+         # Apply the same block n times (weights fully shared)
+         shared_block = self.layers[0]
+         for _ in range(self.repeats):
+             hidden_states = shared_block(hidden_states=hidden_states + input_injection, **kwargs)
+         return hidden_states
+
+
+ class HierarchicalReasoningModel_ACTV3_Inner(nn.Module):
+     def __init__(self, config: HierarchicalReasoningModel_ACTV3Config) -> None:
+         super().__init__()
+         self.config = config
+         self.forward_dtype = getattr(torch, self.config.forward_dtype)
+
+         # I/O
+         self.embed_scale = math.sqrt(self.config.hidden_size)
+         embed_init_std = 1.0 / self.embed_scale
+
+         self.embed_tokens = CastedEmbedding(
+             self.config.vocab_size,
+             self.config.hidden_size,
+             init_std=embed_init_std,
+             cast_to=self.forward_dtype,
+         )
+         self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+         self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+
+         self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
+         if self.config.puzzle_emb_ndim > 0:
+             # Zero init puzzle embeddings
+             self.puzzle_emb = CastedSparseEmbedding(
+                 self.config.num_puzzle_identifiers,
+                 self.config.puzzle_emb_ndim,
+                 batch_size=self.config.batch_size,
+                 init_std=0,
+                 cast_to=self.forward_dtype,
+             )
+
+         # LM Blocks
+         if self.config.pos_encodings == "rope":
+             self.rotary_emb = RotaryEmbedding(
+                 dim=self.config.hidden_size // self.config.num_heads,
+                 max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
+                 base=self.config.rope_theta,
+             )
+         elif self.config.pos_encodings == "learned":
+             self.embed_pos = CastedEmbedding(
+                 self.config.seq_len + self.puzzle_emb_len,
+                 self.config.hidden_size,
+                 init_std=embed_init_std,
+                 cast_to=self.forward_dtype,
+             )
+         else:
+             raise NotImplementedError()
+
+         # Reasoning Layers
+         # self.H_level = HierarchicalReasoningModel_ACTV3ReasoningModule(
+         #     layers=[HierarchicalReasoningModel_ACTV3Block(self.config) for _i in range(self.config.H_layers)]
+         # )
+
+         H_block = HierarchicalReasoningModel_ACTV3Block(self.config)
+         self.H_level = HierarchicalReasoningModel_ACTV3ReasoningModule(
+             block=H_block,
+             repeats=self.config.H_layers
+         )
+
+         # Initial states
+         self.H_init = nn.Buffer(
+             trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
+             persistent=True,
+         )
+
+         # Q head special init
+         # Init Q to (almost) zero for faster learning during bootstrapping
+         with torch.no_grad():
+             self.q_head.weight.zero_()
+             self.q_head.bias.fill_(-5)  # type: ignore
+
+     def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
+         # Token embedding
+         embedding = self.embed_tokens(input.to(torch.int32))
+
+         # Puzzle embeddings
+         if self.config.puzzle_emb_ndim > 0:
+             puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
+
+             pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+             if pad_count > 0:
+                 puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
+
+             embedding = torch.cat(
+                 (puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
+             )
+
+         # Position embeddings
+         if self.config.pos_encodings == "learned":
+             # scale by 1/sqrt(2) to maintain forward variance
+             embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
+
+         # Scale
+         return self.embed_scale * embedding
+
+     def empty_carry(self, batch_size: int):
+         return HierarchicalReasoningModel_ACTV3InnerCarry(
+             z_H=torch.empty(
+                 batch_size,
+                 self.config.seq_len + self.puzzle_emb_len,
+                 self.config.hidden_size,
+                 dtype=self.forward_dtype,
+             ),
+         )
+
+     def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV3InnerCarry):
+         return HierarchicalReasoningModel_ACTV3InnerCarry(
+             z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+         )
+
+     def forward(
+         self, carry: HierarchicalReasoningModel_ACTV3InnerCarry, batch: Dict[str, torch.Tensor], carry_steps: torch.Tensor
+     ) -> Tuple[HierarchicalReasoningModel_ACTV3InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         seq_info = dict(
+             cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
+         )
+
+         # Input encoding
+         input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+         # Inject the input only on the first step
+         first_step_mask = (carry_steps == 0).view(-1, 1, 1).to(device=input_embeddings.device, dtype=input_embeddings.dtype)
+         gated_injection = input_embeddings * first_step_mask  # [B,S_full,D]
+
+         # 1-step grad
+         z_H = self.H_level(carry.z_H, gated_injection, **seq_info)
+
+         # LM Outputs
+         new_carry = HierarchicalReasoningModel_ACTV3InnerCarry(
+             z_H=z_H.detach(),
+         )  # New carry no grad
+         output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+
+         # Q head
+         q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
+
+         return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+ class HierarchicalReasoningModel_ACTV3(nn.Module):
+     """ACT wrapper."""
+
+     def __init__(self, config_dict: dict):
+         super().__init__()
+         self.config = HierarchicalReasoningModel_ACTV3Config(**config_dict)
+         self.inner = HierarchicalReasoningModel_ACTV3_Inner(self.config)
+
+     @property
+     def puzzle_emb(self):
+         return self.inner.puzzle_emb
+
+     def initial_carry(self, batch: Dict[str, torch.Tensor]):
+         batch_size = batch["inputs"].shape[0]
+
+         return HierarchicalReasoningModel_ACTV3Carry(
+             inner_carry=self.inner.empty_carry(
+                 batch_size
+             ),  # Empty is expected, it will be reset in the first pass as all sequences are halted.
+             steps=torch.zeros((batch_size,), dtype=torch.int32),
+             halted=torch.ones((batch_size,), dtype=torch.bool),  # Default to halted
+             current_data={k: torch.empty_like(v) for k, v in batch.items()},
+         )
+
+     def forward(
+         self,
+         carry: HierarchicalReasoningModel_ACTV3Carry,
+         batch: Dict[str, torch.Tensor],
+         compute_target_q: bool = False,
+     ) -> Tuple[HierarchicalReasoningModel_ACTV3Carry, Dict[str, torch.Tensor]]:
+         # Update data, carry (removing halted sequences)
+         new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+
+         new_steps = torch.where(carry.halted, 0, carry.steps)
+
+         new_current_data = {
+             k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
+             for k, v in carry.current_data.items()
+         }
+
+         # Forward inner model
+         new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
+             new_inner_carry, new_current_data, new_steps
+         )
+
+         outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}
+
+         with torch.no_grad():
+             # Step
+             new_steps = new_steps + 1
+             is_last_step = new_steps >= self.config.halt_max_steps
+
+             halted = is_last_step
+
+             # Check if adaptive computation should be used
+             use_adaptive = (self.config.halt_max_steps > 1) and (
+                 (self.training and self.config.act_enabled)
+                 or (not self.training and self.config.act_inference)
+             )
+
+             if use_adaptive:
+                 # Halt signal based on Q-values (but always halt at max steps)
+                 q_halt_signal = q_halt_logits > q_continue_logits
+                 halted = halted | q_halt_signal
+
+                 # Store actual steps used for logging (only during inference)
+                 if not self.training:
+                     outputs["actual_steps"] = new_steps.float()
+
+                 # Exploration (only during training)
+                 if self.training:
+                     min_halt_steps = (
+                         torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
+                     ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+                     halted = halted & (new_steps >= min_halt_steps)
+
+             # Compute target Q (only during training)
+             # NOTE: No replay buffer and target networks for computing target Q-value.
+             # As batch_size is large, there are many parallel envs.
+             # Similar concept as PQN https://arxiv.org/abs/2407.04811
+             if self.training and compute_target_q:
+                 next_q_halt_logits, next_q_continue_logits = self.inner(
+                     new_inner_carry, new_current_data, new_steps
+                 )[-1]
+
+                 outputs["target_q_continue"] = torch.sigmoid(
+                     torch.where(
+                         is_last_step,
+                         next_q_halt_logits,
+                         torch.maximum(next_q_halt_logits, next_q_continue_logits),
+                     )
+                 )
+
+         return HierarchicalReasoningModel_ACTV3Carry(
+             new_inner_carry, new_steps, halted, new_current_data
+         ), outputs
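The weight-sharing scheme in `HierarchicalReasoningModel_ACTV3ReasoningModule` (one block applied `repeats` times, with the input injection re-added before every application) can be sketched without PyTorch. `SharedBlockModule` below is a hypothetical plain-Python stand-in, using scalars instead of tensors:

```python
class SharedBlockModule:
    """Toy analogue of the shared-block reasoning module:
    applies one shared transform `repeats` times, re-adding
    the injection before every application."""

    def __init__(self, block, repeats):
        self.block = block          # single shared "layer"
        self.repeats = int(repeats)

    def __call__(self, h, injection):
        for _ in range(self.repeats):
            h = self.block(h + injection)
        return h

# With block(v) = 2*v, h0 = 1, injection = 1, three repeats unfold as
# ((1+1)*2 + 1)*2 + 1)*2 = 22.
mod = SharedBlockModule(lambda v: 2 * v, repeats=3)
print(mod(1, injection=1))  # 22
```

The design trade-off mirrors the real module: parameter count stays that of one block regardless of `repeats` (here `H_layers = 16`), at the cost of every "layer" computing the same function.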
Arc-2-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 spectacular-dragon/step_1818933 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8436b037d0f366ca02870c66dfd9a22590f17677aee752bf3ac3f43b2089a10
+ size 4311137979
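The three lines above are a standard Git LFS pointer file: the actual ~4.3 GB checkpoint lives in LFS storage, and only this small stub is versioned in the repository. A minimal, hypothetical parser for such pointer files (the `version`/`oid`/`size` key-value layout is the LFS pointer format):

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict of its key/value lines."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:c8436b037d0f366ca02870c66dfd9a22590f17677aee752bf3ac3f43b2089a10
size 4311137979
"""
info = parse_lfs_pointer(pointer)
print(info["size"])  # 4311137979 (bytes, roughly 4.3 GB)
```

The `oid` is the SHA-256 of the real file's contents, which is how LFS deduplicates and verifies the download.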
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/all_config.yaml ADDED
@@ -0,0 +1,41 @@
+ arch:
+   H_cycles: 1
+   H_layers: 16
+   expansion: 4
+   halt_exploration_prob: 0.1
+   halt_max_steps: 16
+   hidden_size: 512
+   loss:
+     loss_type: stablemax_cross_entropy
+     name: losses@ACTLossHead
+   name: hrm.sr2@HierarchicalReasoningModel_ACTV3
+   num_heads: 8
+   pos_encodings: rope
+   puzzle_emb_ndim: 512
+ beta1: 0.9
+ beta2: 0.95
+ checkpoint_every_eval: true
+ checkpoint_path: checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3
+   quizzical-labradoodle
+ data_path: data/arc-aug-1000
+ ema_decay: 0.999
+ ema_device: cuda
+ ema_enabled: true
+ ema_use_buffers: true
+ epochs: 300000
+ eval_interval: 10000
+ eval_save_outputs: []
+ evaluators:
+ - name: arc@ARC
+ global_batch_size: 768
+ load_checkpoint: null
+ lr: 0.0001
+ lr_min_ratio: 1.0
+ lr_warmup_steps: 2000
+ project_name: Arc-aug-1000 ACT-torch
+ puzzle_emb_lr: 0.01
+ puzzle_emb_weight_decay: 0.1
+ run_name: HierarchicalReasoningModel_ACTV3 quizzical-labradoodle
+ seed: 0
+ target_q_update_every: 4
+ weight_decay: 0.1
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1036146/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_103616/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1087953/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1139760/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1191567/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1243374/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1295181/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1346988/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1398795/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1450603/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1502410/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_1554218/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_155423/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV3 quizzical-labradoodle/evaluator_ARC_step_207230/submission.json ADDED
The diff for this file is too large to render. See raw diff