Viharikvs commited on
Commit
08b7e03
·
verified ·
1 Parent(s): 1732005

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/all_config.yaml +61 -0
  3. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_10361/submission.json +0 -0
  4. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_20722/submission.json +0 -0
  5. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_31083/submission.json +0 -0
  6. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_41444/submission.json +0 -0
  7. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/glps.py +409 -0
  8. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/losses.py +102 -0
  9. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_10361 +3 -0
  10. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_20722 +3 -0
  11. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_31083 +3 -0
  12. Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_41444 +3 -0
  13. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos/all_config.yaml +60 -0
  14. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos/glps.py +409 -0
  15. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos/losses.py +102 -0
  16. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160/all_config.yaml +60 -0
  17. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160/glps.py +409 -0
  18. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160/losses.py +102 -0
  19. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/all_config.yaml +60 -0
  20. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/glps.py +409 -0
  21. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/losses.py +102 -0
  22. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme/all_config.yaml +60 -0
  23. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme/glps.py +409 -0
  24. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme/losses.py +102 -0
  25. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra/all_config.yaml +60 -0
  26. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra/glps.py +409 -0
  27. Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra/losses.py +102 -0
  28. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku/all_config.yaml +56 -0
  29. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku/glps.py +390 -0
  30. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku/losses.py +102 -0
  31. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile/all_config.yaml +56 -0
  32. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile/glps.py +406 -0
  33. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile/losses.py +102 -0
  34. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2/all_config.yaml +56 -0
  35. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2/glps.py +409 -0
  36. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2/losses.py +102 -0
  37. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4/all_config.yaml +56 -0
  38. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4/glps.py +409 -0
  39. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4/losses.py +102 -0
  40. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/all_config.yaml +56 -0
  41. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/glps.py +409 -0
  42. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/losses.py +102 -0
  43. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/all_config.yaml +56 -0
  44. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/glps.py +409 -0
  45. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/losses.py +102 -0
  46. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost/all_config.yaml +60 -0
  47. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost/glps.py +409 -0
  48. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost/losses.py +102 -0
  49. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost2/all_config.yaml +60 -0
  50. Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost2/glps.py +409 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_10361 filter=lfs diff=lfs merge=lfs -text
37
+ Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_20722 filter=lfs diff=lfs merge=lfs -text
38
+ Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_31083 filter=lfs diff=lfs merge=lfs -text
39
+ Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_41444 filter=lfs diff=lfs merge=lfs -text
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/all_config.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 6
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 512
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1
35
+ data_paths:
36
+ - data/arc1concept-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 100000
41
+ eval_glps_max_targeted_iters: null
42
+ eval_glps_tau_halt: null
43
+ eval_halt_max_steps: null
44
+ eval_interval: 2000
45
+ eval_only: false
46
+ eval_save_outputs: []
47
+ evaluators:
48
+ - name: arc@ARC
49
+ freeze_weights: false
50
+ global_batch_size: 768
51
+ load_checkpoint: null
52
+ lr: 0.0001
53
+ lr_min_ratio: 0.3
54
+ lr_warmup_steps: 2000
55
+ min_eval_interval: 0
56
+ project_name: Arc1concept-aug-1000-ACT-torch
57
+ puzzle_emb_lr: 0.01
58
+ puzzle_emb_weight_decay: 0.1
59
+ run_name: pretrain_glps_arc1_h200_v1
60
+ seed: 0
61
+ weight_decay: 0.1
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_10361/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_20722/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_31083/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/evaluator_ARC_step_41444/submission.json ADDED
The diff for this file is too large to render. See raw diff
 
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_10361 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16fe5d9f72d90890d9de1ebcec61294dff2f7ca98f75442d000e338993a9945f
3
+ size 1904307893
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_20722 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36344623a0bb8305681ad97ff31e5b8dc0af6f15ff83fd0ffcaa8c770d1601f1
3
+ size 1904307893
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_31083 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c39d0dc62d8c9ecfc11c9f7cf48aedeec307e731b0c785c9e4e6c262496786c2
3
+ size 1904307893
Arc1concept-aug-1000-ACT-torch/pretrain_glps_arc1_h200_v1/step_41444 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6259ae684f464a5605db790399e6fac7d7228b63ea1273c48616c5e03ab92843
3
+ size 1904307893
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos
35
+ data_paths:
36
+ - data/maze-30x30-hard-1k
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_glps_max_targeted_iters: null
42
+ eval_glps_tau_halt: null
43
+ eval_halt_max_steps: null
44
+ eval_interval: 5000
45
+ eval_only: false
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 64
50
+ load_checkpoint: null
51
+ lr: 0.0001
52
+ lr_min_ratio: 0.1
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Maze-30x30-hard-1k-ACT-torch
56
+ puzzle_emb_lr: 0.0001
57
+ puzzle_emb_weight_decay: 1.0
58
+ run_name: pretrain_glps_maze30_v1_cos
59
+ seed: 0
60
+ weight_decay: 1.0
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160
35
+ data_paths:
36
+ - data/maze-30x30-hard-1k
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_glps_max_targeted_iters: null
42
+ eval_glps_tau_halt: null
43
+ eval_halt_max_steps: null
44
+ eval_interval: 5000
45
+ eval_only: false
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 160
50
+ load_checkpoint: null
51
+ lr: 0.0001
52
+ lr_min_ratio: 0.1
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Maze-30x30-hard-1k-ACT-torch
56
+ puzzle_emb_lr: 0.0001
57
+ puzzle_emb_weight_decay: 1.0
58
+ run_name: pretrain_glps_maze30_v1_cos_b160
59
+ seed: 0
60
+ weight_decay: 1.0
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_cos_b160/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160
35
+ data_paths:
36
+ - data/maze-30x30-hard-1k
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 10000
41
+ eval_glps_max_targeted_iters: null
42
+ eval_glps_tau_halt: null
43
+ eval_halt_max_steps: null
44
+ eval_interval: 1000
45
+ eval_only: false
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 160
50
+ load_checkpoint: null
51
+ lr: 0.0001
52
+ lr_min_ratio: 0.1
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Maze-30x30-hard-1k-ACT-torch
56
+ puzzle_emb_lr: 0.0001
57
+ puzzle_emb_weight_decay: 1.0
58
+ run_name: pretrain_glps_maze30_v1_e10k_b160
59
+ seed: 0
60
+ weight_decay: 1.0
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme
35
+ data_paths:
36
+ - data/maze-30x30-hard-1k
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 10000
41
+ eval_glps_max_targeted_iters: 10
42
+ eval_glps_tau_halt: 0.75
43
+ eval_halt_max_steps: 48
44
+ eval_interval: 1000
45
+ eval_only: true
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 160
50
+ load_checkpoint: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/step_6250
51
+ lr: 0.0001
52
+ lr_min_ratio: 0.1
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Maze-30x30-hard-1k-ACT-torch
56
+ puzzle_emb_lr: 0.0001
57
+ puzzle_emb_weight_decay: 1.0
58
+ run_name: pretrain_glps_maze30_v1_e10k_evalboost_extreme
59
+ seed: 0
60
+ weight_decay: 1.0
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_extreme/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 5
3
+ H_layers: 2
4
+ L_cycles: 2
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.7
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra
35
+ data_paths:
36
+ - data/maze-30x30-hard-1k
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 10000
41
+ eval_glps_max_targeted_iters: 16
42
+ eval_glps_tau_halt: 0.65
43
+ eval_halt_max_steps: 96
44
+ eval_interval: 1000
45
+ eval_only: true
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 32
50
+ load_checkpoint: checkpoints/Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_b160/step_6250
51
+ lr: 0.0001
52
+ lr_min_ratio: 1.0
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Maze-30x30-hard-1k-ACT-torch
56
+ puzzle_emb_lr: 0.01
57
+ puzzle_emb_weight_decay: 0.1
58
+ run_name: pretrain_glps_maze30_v1_e10k_evalboost_ultra
59
+ seed: 0
60
+ weight_decay: 0.1
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Maze-30x30-hard-1k-ACT-torch/pretrain_glps_maze30_v1_e10k_evalboost_ultra/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku/all_config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 4
6
+ dep_rank: 32
7
+ dep_topk: 8
8
+ expansion: 2.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 2
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.6
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 4
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_interval: 5000
42
+ eval_save_outputs: []
43
+ evaluators: []
44
+ freeze_weights: false
45
+ global_batch_size: 768
46
+ load_checkpoint: null
47
+ lr: 0.0001
48
+ lr_min_ratio: 1.0
49
+ lr_warmup_steps: 2000
50
+ min_eval_interval: 0
51
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
52
+ puzzle_emb_lr: 0.0001
53
+ puzzle_emb_weight_decay: 1.0
54
+ run_name: pretrain_glps_sudoku
55
+ seed: 0
56
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku/glps.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ m = reset_flag.view(-1, 1, 1)
221
+ return GLPS_ACTV1InnerCarry(
222
+ z_H=torch.where(m, self.H_init.expand_as(carry.z_H), carry.z_H),
223
+ z_L=torch.where(m, self.L_init.expand_as(carry.z_L), carry.z_L),
224
+ )
225
+
226
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
227
+ # One light pass to gather global signals
228
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
229
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
230
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
231
+ return z_scan, cand_logits, certainty
232
+
233
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
234
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
235
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
236
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
237
+ """
238
+ B, L, D = z_ctx.shape
239
+ Q = self.dep_q(z_ctx) # [B, L, r]
240
+ K = self.dep_k(z_ctx) # [B, L, r]
241
+ r = max(1, int(Q.shape[-1]))
242
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L]
243
+ sim = sim / math.sqrt(r)
244
+
245
+ # Aggregate influence from uncertain queries onto target tokens
246
+ src = uncertain_mask.float().unsqueeze(1) # [B, 1, L]
247
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
248
+
249
+ # Top-k influenced tokens per batch
250
+ topk = min(self.config.dep_topk, L)
251
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
252
+ dep_mask = torch.zeros_like(uncertain_mask)
253
+ dep_mask.scatter_(1, idx, True)
254
+
255
+ # Always include uncertain cells themselves
256
+ dep_mask = dep_mask | uncertain_mask
257
+ return dep_mask
258
+
259
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
260
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
261
+
262
+ # Encode inputs
263
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
264
+
265
+ # States
266
+ z_H, z_L = carry.z_H, carry.z_L
267
+
268
+ if not self.config.glps_enabled:
269
+ # Fallback to an HRM-like single-cycle grad update for compatibility
270
+ with torch.no_grad():
271
+ for _H in range(self.config.H_cycles - 1):
272
+ for _L in range(self.config.L_cycles):
273
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
274
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
275
+ # final grad step
276
+ for _L in range(self.config.L_cycles):
277
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
278
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
279
+
280
+ # Outputs
281
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
282
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
283
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
284
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
285
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
286
+
287
+ # ===== GLPS path =====
288
+ # H1: global scan (cheap)
289
+ with torch.no_grad():
290
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
291
+
292
+ # L1: fill-obvious -> compute stable vs uncertain masks
293
+ if self.config.glps_fill_obvious:
294
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
295
+ else:
296
+ obvious_mask = torch.zeros_like(certainty).bool()
297
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
298
+ uncertain_mask = ~stable_mask # [B, L]
299
+
300
+ # H2: dependency prediction over remaining cells
301
+ if self.config.glps_dep_graph:
302
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
303
+ else:
304
+ dep_mask = uncertain_mask
305
+
306
+ # L2: targeted refinement (a couple of masked iters)
307
+ update_mask = dep_mask if self.config.glps_token_masking else None
308
+ z = z_scan.detach() # use scanned context as start
309
+ for _ in range(self.config.glps_max_targeted_iters):
310
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
311
+ # Refresh certainty to shrink mask (optional but cheap)
312
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
313
+ if self.config.glps_token_masking:
314
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
315
+
316
+ # Merge into H and do a light H update with grad
317
+ z_L = z
318
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
319
+
320
+ # H3: energy/consistency -> confidence & optional global propagate
321
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1, keepdim=True) # [B,1]
322
+ conf = 1.0 - energy
323
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
324
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
325
+ # one final full sweep only for rows needing it
326
+ maskB = need_sweep.view(-1, 1, 1)
327
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
328
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
329
+ z_L = torch.where(maskB, zL2, z_L)
330
+ z_H = torch.where(maskB, zH2, z_H)
331
+
332
+ # Outputs
333
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
334
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
335
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
336
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
337
+
338
+ class GLPS_ACTV1(nn.Module):
339
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
340
+ def __init__(self, config_dict: dict):
341
+ super().__init__()
342
+ self.config = GLPS_ACTV1Config(**config_dict)
343
+ self.inner = GLPS_ACTV1_Inner(self.config)
344
+
345
+ @property
346
+ def puzzle_emb(self):
347
+ return self.inner.puzzle_emb
348
+
349
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
350
+ batch_size = batch["inputs"].shape[0]
351
+ return GLPS_ACTV1Carry(
352
+ inner_carry=self.inner.empty_carry(batch_size),
353
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
354
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
355
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
356
+ )
357
+
358
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
359
+ # Reset halted seqs
360
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
361
+ new_steps = torch.where(carry.halted, 0, carry.steps)
362
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
363
+
364
+ # Inner step
365
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
366
+
367
+ outputs = {
368
+ "logits": logits,
369
+ "q_halt_logits": q_halt_logits,
370
+ "q_continue_logits": q_continue_logits,
371
+ "conf": conf.squeeze(-1),
372
+ }
373
+
374
+ with torch.no_grad():
375
+ new_steps = new_steps + 1
376
+ is_last_step = new_steps >= self.config.halt_max_steps
377
+
378
+ # Combine halt signals: Q or confidence or last-step
379
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
380
+
381
+ # Exploration during training only
382
+ if self.training and (self.config.halt_max_steps > 1):
383
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
384
+ halted = halted & (new_steps >= min_halt_steps)
385
+
386
+ # Optional target for Q-learning (kept similar to HRM)
387
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
388
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
389
+
390
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile/all_config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 4
6
+ dep_rank: 32
7
+ dep_topk: 8
8
+ expansion: 2.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 2
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.6
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 4
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_interval: 5000
42
+ eval_save_outputs: []
43
+ evaluators: []
44
+ freeze_weights: false
45
+ global_batch_size: 768
46
+ load_checkpoint: null
47
+ lr: 0.0001
48
+ lr_min_ratio: 1.0
49
+ lr_warmup_steps: 2000
50
+ min_eval_interval: 0
51
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
52
+ puzzle_emb_lr: 0.0001
53
+ puzzle_emb_weight_decay: 1.0
54
+ run_name: pretrain_glps_sudoku_nocompile
55
+ seed: 0
56
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile/glps.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+
406
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_nocompile/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2/all_config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 6
6
+ dep_rank: 64
7
+ dep_topk: 10
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 3
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_interval: 5000
42
+ eval_save_outputs: []
43
+ evaluators: []
44
+ freeze_weights: false
45
+ global_batch_size: 768
46
+ load_checkpoint: null
47
+ lr: 0.0001
48
+ lr_min_ratio: 1.0
49
+ lr_warmup_steps: 2000
50
+ min_eval_interval: 0
51
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
52
+ puzzle_emb_lr: 0.0001
53
+ puzzle_emb_weight_decay: 1.0
54
+ run_name: pretrain_glps_sudoku_v2
55
+ seed: 0
56
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v2/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4/all_config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_interval: 5000
42
+ eval_save_outputs: []
43
+ evaluators: []
44
+ freeze_weights: false
45
+ global_batch_size: 768
46
+ load_checkpoint: null
47
+ lr: 0.0001
48
+ lr_min_ratio: 1.0
49
+ lr_warmup_steps: 2000
50
+ min_eval_interval: 0
51
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
52
+ puzzle_emb_lr: 0.0001
53
+ puzzle_emb_weight_decay: 1.0
54
+ run_name: pretrain_glps_sudoku_v4
55
+ seed: 0
56
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/all_config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_interval: 5000
42
+ eval_save_outputs: []
43
+ evaluators: []
44
+ freeze_weights: false
45
+ global_batch_size: 768
46
+ load_checkpoint: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/step_45570
47
+ lr: 0.0001
48
+ lr_min_ratio: 0.1
49
+ lr_warmup_steps: 2000
50
+ min_eval_interval: 0
51
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
52
+ puzzle_emb_lr: 0.0001
53
+ puzzle_emb_weight_decay: 1.0
54
+ run_name: pretrain_glps_sudoku_v4_decay
55
+ seed: 0
56
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/all_config.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.9
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_interval: 5000
42
+ eval_save_outputs: []
43
+ evaluators: []
44
+ freeze_weights: false
45
+ global_batch_size: 768
46
+ load_checkpoint: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay/step_52080
47
+ lr: 0.0001
48
+ lr_min_ratio: 0.01
49
+ lr_warmup_steps: 2000
50
+ min_eval_interval: 0
51
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
52
+ puzzle_emb_lr: 0.0001
53
+ puzzle_emb_weight_decay: 1.0
54
+ run_name: pretrain_glps_sudoku_v4_decay_ft10k
55
+ seed: 0
56
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_glps_max_targeted_iters: 6
42
+ eval_glps_tau_halt: 0.85
43
+ eval_halt_max_steps: 32
44
+ eval_interval: 5000
45
+ eval_only: true
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 768
50
+ load_checkpoint: /teamspace/studios/this_studio/TinyRecursiveModels/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/step_58590
51
+ lr: 0.0001
52
+ lr_min_ratio: 0.01
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
56
+ puzzle_emb_lr: 0.0001
57
+ puzzle_emb_weight_decay: 1.0
58
+ run_name: pretrain_glps_sudoku_v4_evalboost
59
+ seed: 0
60
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost/losses.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Tuple, Dict, Sequence, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ import math
7
+
8
+ IGNORE_LABEL_ID = -100
9
+
10
+
11
+ def s(x, epsilon=1e-30):
12
+ return torch.where(
13
+ x<0,
14
+ 1/(1-x+ epsilon),
15
+ x + 1
16
+ )
17
+
18
+
19
+ def log_stablemax(x, dim=-1):
20
+ s_x = s(x)
21
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
22
+
23
+
24
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
25
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
26
+
27
+ if valid_mask is None:
28
+ valid_mask = (labels != ignore_index)
29
+ transformed_labels = torch.where(valid_mask, labels, 0)
30
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
31
+
32
+ return -torch.where(valid_mask, prediction_logprobs, 0)
33
+
34
+
35
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
36
+ # Cast logits to f32
37
+ # Flatten logits
38
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
39
+
40
+
41
+ class ACTLossHead(nn.Module):
42
+ def __init__(self, model: nn.Module, loss_type: str):
43
+ super().__init__()
44
+ self.model = model
45
+ self.loss_fn = globals()[loss_type]
46
+
47
+ def initial_carry(self, *args, **kwargs):
48
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
49
+
50
+ def forward(
51
+ self,
52
+ return_keys: Sequence[str],
53
+ # Model args
54
+ **model_kwargs,
55
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
56
+ # Model logits
57
+ # B x SeqLen x D
58
+ new_carry, outputs = self.model(**model_kwargs)
59
+ labels = new_carry.current_data["labels"]
60
+
61
+ with torch.no_grad():
62
+ # Preds
63
+ outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
64
+
65
+ # Correctness
66
+ mask = (labels != IGNORE_LABEL_ID)
67
+ loss_counts = mask.sum(-1)
68
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
69
+
70
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
71
+ seq_is_correct = is_correct.sum(-1) == loss_counts
72
+
73
+ # Metrics (halted)
74
+ valid_metrics = new_carry.halted & (loss_counts > 0)
75
+ metrics = {
76
+ "count": valid_metrics.sum(),
77
+
78
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
79
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
80
+
81
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
82
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
83
+ }
84
+
85
+ # Losses
86
+
87
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
88
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
89
+ metrics.update({
90
+ "lm_loss": lm_loss.detach(),
91
+ "q_halt_loss": q_halt_loss.detach(),
92
+ })
93
+ # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary
94
+ q_continue_loss = 0
95
+ if "target_q_continue" in outputs:
96
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
97
+
98
+ metrics["q_continue_loss"] = q_continue_loss.detach()
99
+ # Filter outputs for return
100
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
101
+
102
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost2/all_config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ arch:
2
+ H_cycles: 3
3
+ H_layers: 2
4
+ L_cycles: 1
5
+ L_layers: 7
6
+ dep_rank: 64
7
+ dep_topk: 12
8
+ expansion: 4.0
9
+ forward_dtype: bfloat16
10
+ glps_dep_graph: true
11
+ glps_enabled: true
12
+ glps_fill_obvious: true
13
+ glps_global_propagate_on_low_conf: true
14
+ glps_max_targeted_iters: 4
15
+ glps_tau_halt: 0.95
16
+ glps_tau_uncertain: 0.8
17
+ glps_token_masking: true
18
+ halt_exploration_prob: 0.1
19
+ halt_max_steps: 16
20
+ hidden_size: 512
21
+ loss:
22
+ loss_type: stablemax_cross_entropy
23
+ name: losses@ACTLossHead
24
+ mlp_t: false
25
+ name: recursive_reasoning.glps@GLPS_ACTV1
26
+ num_heads: 8
27
+ pos_encodings: rope
28
+ puzzle_emb_ndim: 0
29
+ rms_norm_eps: 1.0e-05
30
+ rope_theta: 10000.0
31
+ beta1: 0.9
32
+ beta2: 0.95
33
+ checkpoint_every_eval: true
34
+ checkpoint_path: checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost2
35
+ data_paths:
36
+ - data/sudoku-extreme-1k-aug-1000
37
+ data_paths_test: []
38
+ ema: true
39
+ ema_rate: 0.999
40
+ epochs: 50000
41
+ eval_glps_max_targeted_iters: 8
42
+ eval_glps_tau_halt: 0.8
43
+ eval_halt_max_steps: 48
44
+ eval_interval: 5000
45
+ eval_only: true
46
+ eval_save_outputs: []
47
+ evaluators: []
48
+ freeze_weights: false
49
+ global_batch_size: 768
50
+ load_checkpoint: /teamspace/studios/this_studio/TinyRecursiveModels/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_decay_ft10k/step_58590
51
+ lr: 0.0001
52
+ lr_min_ratio: 0.01
53
+ lr_warmup_steps: 2000
54
+ min_eval_interval: 0
55
+ project_name: Sudoku-extreme-1k-aug-1000-ACT-torch
56
+ puzzle_emb_lr: 0.0001
57
+ puzzle_emb_weight_decay: 1.0
58
+ run_name: pretrain_glps_sudoku_v4_evalboost2
59
+ seed: 0
60
+ weight_decay: 1.0
Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_glps_sudoku_v4_evalboost2/glps.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, List, Dict
3
+ from dataclasses import dataclass
4
+ import math
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from pydantic import BaseModel
9
+
10
+ # Reuse the same building blocks as HRM/TRM
11
+ from models.common import trunc_normal_init_
12
+ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
13
+ from models.sparse_embedding import CastedSparseEmbedding
14
+
15
+ """
16
+ Global-Local Predictive Solver (GLPS)
17
+ ------------------------------------
18
+ A light-weight control-policy on top of the HRM/TRM style blocks:
19
+ - H1: global scan -> certainty map
20
+ - L1: fill-obvious (lock stable cells)
21
+ - H2: dependency scoring over remaining cells
22
+ - L2: targeted refinement (masked updates)
23
+ - H3: energy-based confidence -> (optional) one global propagate sweep -> halt
24
+
25
+ This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
26
+ """
27
+
28
+ @dataclass
29
+ class GLPS_ACTV1InnerCarry:
30
+ z_H: torch.Tensor
31
+ z_L: torch.Tensor
32
+
33
+ @dataclass
34
+ class GLPS_ACTV1Carry:
35
+ inner_carry: GLPS_ACTV1InnerCarry
36
+ steps: torch.Tensor
37
+ halted: torch.Tensor
38
+ current_data: Dict[str, torch.Tensor]
39
+
40
+ class GLPS_ACTV1Config(BaseModel):
41
+ # Core IO / shapes
42
+ batch_size: int
43
+ seq_len: int
44
+ puzzle_emb_ndim: int = 0
45
+ num_puzzle_identifiers: int = 1
46
+ vocab_size: int = 256
47
+
48
+ # Cycle schedule
49
+ H_cycles: int = 3 # (scan -> refine -> check) typical
50
+ L_cycles: int = 1
51
+
52
+ # Depth
53
+ H_layers: int = 2
54
+ L_layers: int = 4
55
+
56
+ # Transformer config
57
+ hidden_size: int = 512
58
+ expansion: float = 2.0
59
+ num_heads: int = 8
60
+ pos_encodings: str = "rope"
61
+
62
+ rms_norm_eps: float = 1e-5
63
+ rope_theta: float = 10000.0
64
+
65
+ # ACT wrapper
66
+ halt_max_steps: int = 4
67
+ halt_exploration_prob: float = 0.1
68
+
69
+ forward_dtype: str = "bfloat16"
70
+
71
+ # Optional: use MLP on L instead of attention (matches HRM/TRM option)
72
+ mlp_t: bool = False
73
+
74
+ # ---- GLPS extras (tiny) ----
75
+ glps_enabled: bool = True
76
+ glps_fill_obvious: bool = True
77
+ glps_dep_graph: bool = True
78
+ glps_token_masking: bool = True
79
+ glps_global_propagate_on_low_conf: bool = True
80
+
81
+ glps_tau_halt: float = 0.95 # final confidence to halt
82
+ glps_tau_uncertain: float = 0.60 # cell-wise certainty threshold
83
+ glps_max_targeted_iters: int = 2 # small number: 1-2
84
+
85
+ # Dependency scorer (low rank bilinear)
86
+ dep_rank: int = 32
87
+ dep_topk: int = 8
88
+
89
+ class GLPSBlock(nn.Module):
90
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
91
+ super().__init__()
92
+ self.config = config
93
+ if self.config.mlp_t:
94
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
95
+ self.mlp_t = SwiGLU(
96
+ hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
97
+ expansion=config.expansion,
98
+ )
99
+ else:
100
+ self.self_attn = Attention(
101
+ hidden_size=config.hidden_size,
102
+ head_dim=config.hidden_size // config.num_heads,
103
+ num_heads=config.num_heads,
104
+ num_key_value_heads=config.num_heads,
105
+ causal=False,
106
+ )
107
+ self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
108
+ self.norm_eps = config.rms_norm_eps
109
+
110
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
111
+ if self.config.mlp_t:
112
+ # MLP over sequence dimension (mlp-t)
113
+ hidden_states = hidden_states.transpose(1, 2)
114
+ out = self.mlp_t(hidden_states)
115
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
116
+ hidden_states = hidden_states.transpose(1, 2)
117
+ else:
118
+ hidden_states = rms_norm(
119
+ hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
120
+ variance_epsilon=self.norm_eps,
121
+ )
122
+ out = self.mlp(hidden_states)
123
+ hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
124
+ return hidden_states
125
+
126
+ class GLPSReasoningModule(nn.Module):
127
+ """Reasoning stack with optional masked updates (only update uncertain tokens)."""
128
+ def __init__(self, layers: List[GLPSBlock]):
129
+ super().__init__()
130
+ self.layers = torch.nn.ModuleList(layers)
131
+
132
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: torch.Tensor = None, **kwargs) -> torch.Tensor:
133
+ x = hidden_states
134
+ for layer in self.layers:
135
+ # Compute candidate update using injected context
136
+ y = layer(hidden_states=x + input_injection, **kwargs)
137
+ if update_mask is not None:
138
+ # Convex blend keeps frozen tokens unchanged
139
+ m = update_mask.to(x.dtype)[..., None]
140
+ x = x + m * (y - x)
141
+ else:
142
+ x = y
143
+ return x
144
+
145
+ class GLPS_ACTV1_Inner(nn.Module):
146
+ def __init__(self, config: GLPS_ACTV1Config) -> None:
147
+ super().__init__()
148
+ self.config = config
149
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
150
+
151
+ # I/O
152
+ self.embed_scale = math.sqrt(self.config.hidden_size)
153
+ embed_init_std = 1.0 / self.embed_scale
154
+
155
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
156
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
157
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
158
+
159
+ # Puzzle emb (optional) — same convention as HRM/TRM
160
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
161
+ if self.config.puzzle_emb_ndim > 0:
162
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
163
+
164
+ # Positional encodings
165
+ if self.config.pos_encodings == "rope":
166
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
167
+ elif self.config.pos_encodings == "learned":
168
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
169
+
170
+ # Reasoning stacks
171
+ self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
172
+ self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])
173
+
174
+ # Initial states (match HRM/TRM style)
175
+ H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
176
+ L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
177
+ self.register_buffer("H_init", H_init, persistent=True)
178
+ self.register_buffer("L_init", L_init, persistent=True)
179
+
180
+ # GLPS small heads
181
+ self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True) # task-specific; for Sudoku you can slice to 9
182
+ self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
183
+ self.energy_head = CastedLinear(self.config.hidden_size, 1, bias=True)
184
+
185
+ # Low-rank dependency scorer (shared)
186
+ r = max(1, self.config.dep_rank)
187
+ self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
188
+ self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)
189
+
190
+ # Q head init like HRM/TRM (near-zero -> easier bootstrapping)
191
+ with torch.no_grad():
192
+ self.q_head.weight.zero_()
193
+ self.q_head.bias.fill_(-5)
194
+
195
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
196
+ # Token embedding
197
+ embedding = self.embed_tokens(input.to(torch.int32))
198
+
199
+ # Puzzle embeddings
200
+ if self.config.puzzle_emb_ndim > 0:
201
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
202
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
203
+ if pad_count > 0:
204
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
205
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
206
+
207
+ # Position embeddings
208
+ if self.config.pos_encodings == "learned":
209
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
210
+
211
+ return self.embed_scale * embedding
212
+
213
+ def empty_carry(self, batch_size: int):
214
+ return GLPS_ACTV1InnerCarry(
215
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
216
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
217
+ )
218
+
219
+ def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
220
+ # Explicitly expand buffers and mask to target shapes to avoid shape confusion
221
+ B, L, D = carry.z_H.shape
222
+ # Reduce/reset flag to per-batch boolean vector of shape [B]
223
+ if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
224
+ reset_b = reset_flag.to(torch.bool)
225
+ else:
226
+ # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
227
+ try:
228
+ reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
229
+ except Exception:
230
+ reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
231
+ m = reset_b.view(B, 1, 1)
232
+ mH = m.expand(B, L, D)
233
+ mL = mH # same shape for z_L
234
+ H_init_exp = self.H_init.expand(B, L, D)
235
+ L_init_exp = self.L_init.expand(B, L, D)
236
+ return GLPS_ACTV1InnerCarry(
237
+ z_H=torch.where(mH, H_init_exp, carry.z_H),
238
+ z_L=torch.where(mL, L_init_exp, carry.z_L),
239
+ )
240
+
241
+ def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
242
+ # One light pass to gather global signals
243
+ z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
244
+ cand_logits = self.candidate_head(z_scan) # [B, L, C]
245
+ certainty = torch.sigmoid(self.certainty_head(z_scan)) # [B, L, 1]
246
+ return z_scan, cand_logits, certainty
247
+
248
+ def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
249
+ """Compute a dependency-based focus mask from a low-rank bilinear score.
250
+ uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
251
+ Returns: dep_mask [B, L] boolean mask of cells to (re)update.
252
+ """
253
+ B, L, D = z_ctx.shape
254
+ # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
255
+ Q = self.dep_q(z_ctx).to(torch.float32) # [B, L, r]
256
+ K = self.dep_k(z_ctx).to(torch.float32) # [B, L, r]
257
+ r = max(1, int(Q.shape[-1]))
258
+ sim = torch.matmul(Q, K.transpose(1, 2)) # [B, L, L] (float32)
259
+ sim = sim / math.sqrt(r)
260
+
261
+ # Aggregate influence from uncertain queries onto target tokens
262
+ src = uncertain_mask.to(sim.dtype).unsqueeze(1) # [B, 1, L]
263
+ influence = torch.matmul(src, sim).squeeze(1) # [B, L]
264
+
265
+ # Top-k influenced tokens per batch
266
+ topk = min(self.config.dep_topk, L)
267
+ vals, idx = torch.topk(influence, k=topk, dim=-1) # [B, topk]
268
+ dep_mask = torch.zeros_like(uncertain_mask)
269
+ dep_mask.scatter_(1, idx, True)
270
+
271
+ # Always include uncertain cells themselves
272
+ dep_mask = dep_mask | uncertain_mask
273
+ return dep_mask
274
+
275
+ def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
276
+ seq_info = dict(cos_sin=getattr(self, "rotary_emb", None)() if hasattr(self, "rotary_emb") else None)
277
+
278
+ # Encode inputs
279
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
280
+
281
+ # States
282
+ z_H, z_L = carry.z_H, carry.z_L
283
+
284
+ if not self.config.glps_enabled:
285
+ # Fallback to an HRM-like single-cycle grad update for compatibility
286
+ with torch.no_grad():
287
+ for _H in range(self.config.H_cycles - 1):
288
+ for _L in range(self.config.L_cycles):
289
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
290
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
291
+ # final grad step
292
+ for _L in range(self.config.L_cycles):
293
+ z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
294
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
295
+
296
+ # Outputs
297
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
298
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
299
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
300
+ conf = torch.zeros_like(q_logits[..., :1]) + 0.5 # neutral
301
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
302
+
303
+ # ===== GLPS path =====
304
+ # H1: global scan (cheap)
305
+ with torch.no_grad():
306
+ z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)
307
+
308
+ # L1: fill-obvious -> compute stable vs uncertain masks
309
+ if self.config.glps_fill_obvious:
310
+ obvious_mask = (certainty >= self.config.glps_tau_uncertain) # [B, L, 1]
311
+ else:
312
+ obvious_mask = torch.zeros_like(certainty).bool()
313
+ stable_mask = obvious_mask.squeeze(-1) # [B, L]
314
+ uncertain_mask = ~stable_mask # [B, L]
315
+
316
+ # H2: dependency prediction over remaining cells
317
+ if self.config.glps_dep_graph:
318
+ dep_mask = self._build_dep_mask(z_scan, uncertain_mask) # [B, L]
319
+ else:
320
+ dep_mask = uncertain_mask
321
+
322
+ # L2: targeted refinement (a couple of masked iters)
323
+ update_mask = dep_mask if self.config.glps_token_masking else None
324
+ z = z_scan.detach() # use scanned context as start
325
+ for _ in range(self.config.glps_max_targeted_iters):
326
+ z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
327
+ # Refresh certainty to shrink mask (optional but cheap)
328
+ cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
329
+ if self.config.glps_token_masking:
330
+ update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)
331
+
332
+ # Merge into H and do a light H update with grad
333
+ z_L = z
334
+ z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)
335
+
336
+ # H3: energy/consistency -> confidence & optional global propagate
337
+ energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1) # [B, 1]
338
+ conf = 1.0 - energy
339
+ need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
340
+ if self.config.glps_global_propagate_on_low_conf and need_sweep.any():
341
+ # one final full sweep only for rows needing it
342
+ maskB = need_sweep.view(-1, 1, 1)
343
+ zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
344
+ zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
345
+ z_L = torch.where(maskB, zL2, z_L)
346
+ z_H = torch.where(maskB, zH2, z_H)
347
+
348
+ # Outputs
349
+ new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
350
+ logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
351
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
352
+ return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf
353
+
354
+ class GLPS_ACTV1(nn.Module):
355
+ """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
356
+ def __init__(self, config_dict: dict):
357
+ super().__init__()
358
+ self.config = GLPS_ACTV1Config(**config_dict)
359
+ self.inner = GLPS_ACTV1_Inner(self.config)
360
+
361
+ @property
362
+ def puzzle_emb(self):
363
+ return self.inner.puzzle_emb
364
+
365
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
366
+ batch_size = batch["inputs"].shape[0]
367
+ return GLPS_ACTV1Carry(
368
+ inner_carry=self.inner.empty_carry(batch_size),
369
+ steps=torch.zeros((batch_size,), dtype=torch.int32),
370
+ halted=torch.ones((batch_size,), dtype=torch.bool), # start halted to force reset on first pass
371
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
372
+ )
373
+
374
+ def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
375
+ # Reset halted seqs
376
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
377
+ new_steps = torch.where(carry.halted, 0, carry.steps)
378
+ new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
379
+
380
+ # Inner step
381
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)
382
+
383
+ outputs = {
384
+ "logits": logits,
385
+ "q_halt_logits": q_halt_logits,
386
+ "q_continue_logits": q_continue_logits,
387
+ "conf": conf.squeeze(-1),
388
+ }
389
+
390
+ with torch.no_grad():
391
+ new_steps = new_steps + 1
392
+ is_last_step = new_steps >= self.config.halt_max_steps
393
+
394
+ # Combine halt signals: Q or confidence or last-step
395
+ halted = is_last_step | (q_halt_logits > q_continue_logits) | (conf.squeeze(-1) >= self.config.glps_tau_halt)
396
+
397
+ # Exploration during training only
398
+ if self.training and (self.config.halt_max_steps > 1):
399
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
400
+ halted = halted & (new_steps >= min_halt_steps)
401
+
402
+ # Optional target for Q-learning (kept similar to HRM)
403
+ next_conf = self.inner(new_inner_carry, new_current_data)[-1]
404
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, q_halt_logits, torch.maximum(q_halt_logits, q_continue_logits)))
405
+ else:
406
+ # During eval, always use max_steps to ensure consistent reasoning depth (same as TRM/HRM eval behavior)
407
+ halted = is_last_step
408
+
409
+ return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs