from typing import List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding

"""
Global-Local Predictive Solver (GLPS)
------------------------------------
A light-weight control-policy on top of the style blocks:
- H1: global scan -> certainty map
- L1: fill-obvious (lock stable cells)
- H2: dependency scoring over remaining cells
- L2: targeted refinement (masked updates)
- H3: energy-based confidence -> (optional) one global propagate sweep -> halt

This file keeps parameter growth tiny: a few heads + (optional) low-rank dependency scorer.
"""

@dataclass
class GLPS_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor

@dataclass
class GLPS_ACTV1Carry:
    inner_carry: GLPS_ACTV1InnerCarry
    steps: torch.Tensor
    halted: torch.Tensor
    current_data: Dict[str, torch.Tensor]

class GLPS_ACTV1Config(BaseModel):
    # Core IO / shapes
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int = 1
    vocab_size: int = 256

    # Cycle schedule
    H_cycles: int = 3              # (scan -> refine -> check) typical
    L_cycles: int = 1

    # Depth
    H_layers: int = 2
    L_layers: int = 4
    # Parameter sharing (TRM-style): when true, use one shared stack for H and L
    share_levels: bool = True
    # If > 0, overrides depth of shared stack; otherwise min(H_layers, L_layers)
    shared_layers: int = 0

    # Transformer config
    hidden_size: int = 512
    expansion: float = 2.0
    num_heads: int = 8
    pos_encodings: str = "rope"

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # ACT wrapper
    halt_max_steps: int = 4
    halt_exploration_prob: float = 0.1

    forward_dtype: str = "bfloat16"

    # Optional: use an MLP over the sequence dimension on L instead of attention (mlp-t)
    mlp_t: bool = False

    # ---- GLPS extras (tiny) ----
    glps_enabled: bool = True
    glps_fill_obvious: bool = True
    glps_dep_graph: bool = True
    glps_token_masking: bool = True
    glps_global_propagate_on_low_conf: bool = True

    glps_tau_halt: float = 0.95       # final confidence to halt
    glps_tau_uncertain: float = 0.60  # cell-wise certainty threshold
    glps_max_targeted_iters: int = 2  # small number: 1-2

    # Dependency scorer (low rank bilinear)
    dep_rank: int = 32
    dep_topk: int = 8

    # When True, use simple halt threshold (q_halt > 0) instead of comparing q_halt vs q_continue
    no_ACT_continue: bool = True

class GLPSBlock(nn.Module):
    def __init__(self, config: GLPS_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len, # treat sequence as channel
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False,
            )
        self.mlp = SwiGLU(hidden_size=config.hidden_size, expansion=config.expansion)
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.config.mlp_t:
            # MLP over sequence dimension (mlp-t)
            hidden_states = hidden_states.transpose(1, 2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1, 2)
        else:
            hidden_states = rms_norm(
                hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
                variance_epsilon=self.norm_eps,
            )
        out = self.mlp(hidden_states)
        hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states

class GLPSReasoningModule(nn.Module):
    """Reasoning stack with optional masked updates (only update uncertain tokens)."""
    def __init__(self, layers: List[GLPSBlock]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, update_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        x = hidden_states
        for layer in self.layers:
            # Compute candidate update using injected context
            y = layer(hidden_states=x + input_injection, **kwargs)
            if update_mask is not None:
                # Convex blend keeps frozen tokens unchanged
                m = update_mask.to(x.dtype)[..., None]
                x = x + m * (y - x)
            else:
                x = y
        return x
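
# Masked-update semantics (illustrative note on the convex blend above): with a
# boolean per-token mask m, the update x <- x + m * (y - x) returns y wherever
# m is True and leaves x unchanged wherever m is False, so tokens already judged
# stable pass through the stack frozen while uncertain tokens get refined.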

class GLPS_ACTV1_Inner(nn.Module):
    def __init__(self, config: GLPS_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale  = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head      = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head       = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Puzzle embedding (optional); packed into ceil(puzzle_emb_ndim / hidden_size) prefix tokens
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # Positional encodings
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)

        # Reasoning stacks (optionally shared between H and L, TRM-style)
        if self.config.share_levels:
            depth = self.config.shared_layers if (getattr(self.config, "shared_layers", 0) and self.config.shared_layers > 0) else min(self.config.H_layers, self.config.L_layers)
            shared_reasoner = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(depth)])
            self.H_level = shared_reasoner
            self.L_level = shared_reasoner
        else:
            self.H_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.H_layers)])
            self.L_level = GLPSReasoningModule(layers=[GLPSBlock(self.config) for _ in range(self.config.L_layers)])

        # Initial states for the H and L recurrences
        H_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
        L_init = trunc_normal_init_(torch.empty(1, 1, self.config.hidden_size, dtype=self.forward_dtype), std=1.0)
        self.register_buffer("H_init", H_init, persistent=True)
        self.register_buffer("L_init", L_init, persistent=True)

        # GLPS small heads
        self.candidate_head = CastedLinear(self.config.hidden_size, 16, bias=True)  # task-specific; for Sudoku you can slice to 9
        self.certainty_head = CastedLinear(self.config.hidden_size, 1, bias=True)
        self.energy_head    = CastedLinear(self.config.hidden_size, 1, bias=True)

        # Low-rank dependency scorer (shared)
        r = max(1, self.config.dep_rank)
        self.dep_q = CastedLinear(self.config.hidden_size, r, bias=False)
        self.dep_k = CastedLinear(self.config.hidden_size, r, bias=False)

        # Q head init: zero weights, strongly negative bias -> near-zero initial halt probability (easier bootstrapping)
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return GLPS_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: GLPS_ACTV1InnerCarry):
        # Explicitly expand buffers and mask to target shapes to avoid shape confusion
        B, L, D = carry.z_H.shape
        # Reduce/reset flag to per-batch boolean vector of shape [B]
        if reset_flag.ndim == 1 and reset_flag.shape[0] == B:
            reset_b = reset_flag.to(torch.bool)
        else:
            # If shape is [B, ...] reduce across non-batch dims; otherwise fallback to first B entries
            try:
                reset_b = reset_flag.reshape(B, -1).any(dim=1).to(torch.bool)
            except Exception:
                reset_b = reset_flag.reshape(-1)[:B].to(torch.bool)
        m = reset_b.view(B, 1, 1)
        mH = m.expand(B, L, D)
        mL = mH  # same shape for z_L
        H_init_exp = self.H_init.expand(B, L, D)
        L_init_exp = self.L_init.expand(B, L, D)
        return GLPS_ACTV1InnerCarry(
            z_H=torch.where(mH, H_init_exp, carry.z_H),
            z_L=torch.where(mL, L_init_exp, carry.z_L),
        )

    def _global_scan(self, z_L, z_H, input_embeddings, seq_info):
        # One light pass to gather global signals
        z_scan = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
        cand_logits = self.candidate_head(z_scan)              # [B, L, C]
        certainty   = torch.sigmoid(self.certainty_head(z_scan))  # [B, L, 1]
        return z_scan, cand_logits, certainty

    def _build_dep_mask(self, z_ctx, uncertain_mask: torch.Tensor):
        """Compute a dependency-based focus mask from a low-rank bilinear score.
        uncertain_mask: [B, L] boolean mask of cells that are currently uncertain
        Returns: dep_mask [B, L] boolean mask of cells to (re)update.
        """
        B, L, D = z_ctx.shape
        # Project to low-rank space; use float32 to avoid dtype mismatches under torch.compile
        Q = self.dep_q(z_ctx).to(torch.float32)   # [B, L, r]
        K = self.dep_k(z_ctx).to(torch.float32)   # [B, L, r]
        r = max(1, int(Q.shape[-1]))
        sim = torch.matmul(Q, K.transpose(1, 2))  # [B, L, L] (float32)
        sim = sim / math.sqrt(r)

        # Aggregate influence from uncertain queries onto target tokens
        src = uncertain_mask.to(sim.dtype).unsqueeze(1)      # [B, 1, L]
        influence = torch.matmul(src, sim).squeeze(1)        # [B, L]

        # Top-k influenced tokens per batch
        topk = min(self.config.dep_topk, L)
        vals, idx = torch.topk(influence, k=topk, dim=-1)    # [B, topk]
        dep_mask = torch.zeros_like(uncertain_mask)
        dep_mask.scatter_(1, idx, True)

        # Always include uncertain cells themselves
        dep_mask = dep_mask | uncertain_mask
        return dep_mask

    def forward(self, carry: GLPS_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]):
        seq_info = dict(cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None)

        # Encode inputs
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # States
        z_H, z_L = carry.z_H, carry.z_L

        if not self.config.glps_enabled:
            # Fallback: run all cycles with gradients (TRM-style full backprop)
            for _H in range(self.config.H_cycles):
                for _L in range(self.config.L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
                z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)

            # Outputs
            new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
            logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
            q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
            conf = torch.zeros_like(q_logits[..., :1]) + 0.5  # neutral
            return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf

        # ===== GLPS path =====
        # H1: global scan (keep gradients to enable full backprop through recursion)
        z_scan, cand_logits, certainty = self._global_scan(z_L, z_H, input_embeddings, seq_info)

        # L1: fill-obvious -> compute stable vs uncertain masks
        if self.config.glps_fill_obvious:
            obvious_mask = (certainty >= self.config.glps_tau_uncertain)  # [B, L, 1]
        else:
            obvious_mask = torch.zeros_like(certainty).bool()
        stable_mask = obvious_mask.squeeze(-1)         # [B, L]
        uncertain_mask = ~stable_mask                  # [B, L]

        # H2: dependency prediction over remaining cells (no_grad; selection only)
        if self.config.glps_dep_graph:
            with torch.no_grad():
                dep_mask = self._build_dep_mask(z_scan, uncertain_mask)   # [B, L]
        else:
            dep_mask = uncertain_mask

        # L2: targeted refinement — run all iters with gradients (full backprop)
        update_mask = dep_mask if self.config.glps_token_masking else None
        z = z_scan  # use scanned context as start (no detach to keep gradients)
        iters = max(1, int(self.config.glps_max_targeted_iters))
        for _ in range(iters):
            z = self.L_level(z, z_H + input_embeddings, update_mask=update_mask, **seq_info)
            # Refresh certainty to shrink mask; mask ops are non-differentiable, keep them out of graph
            if self.config.glps_token_masking:
                with torch.no_grad():
                    cert_now = torch.sigmoid(self.certainty_head(z)).squeeze(-1)
                    update_mask = dep_mask & (cert_now < self.config.glps_tau_uncertain)

        # Merge into H and do a light H update with grad
        z_L = z
        z_H = self.H_level(z_H, z_L, update_mask=None, **seq_info)

        # H3: energy/consistency -> confidence & optional global propagate
        with torch.no_grad():
            energy = torch.sigmoid(self.energy_head(z_H)).mean(dim=1)  # [B, 1]
            conf = 1.0 - energy
            need_sweep = (conf.squeeze(-1) < self.config.glps_tau_halt)
            perform_sweep = self.config.glps_global_propagate_on_low_conf and bool(need_sweep.any())
        if perform_sweep:
            # One final full sweep (run with gradients); results are merged back only into rows that need it
            maskB = need_sweep.view(-1, 1, 1)
            zL2 = self.L_level(z_L, z_H + input_embeddings, update_mask=None, **seq_info)
            zH2 = self.H_level(z_H, zL2, update_mask=None, **seq_info)
            z_L = torch.where(maskB, zL2, z_L)
            z_H = torch.where(maskB, zH2, z_H)

        # Outputs
        new_carry = GLPS_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
        logits = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        return new_carry, logits, (q_logits[..., 0], q_logits[..., 1]), conf

class GLPS_ACTV1(nn.Module):
    """ACT-style wrapper that mixes Q-halt with GLPS confidence."""
    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = GLPS_ACTV1Config(**config_dict)
        self.inner = GLPS_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]
        return GLPS_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),
            steps=torch.zeros((batch_size,), dtype=torch.int32),
            halted=torch.ones((batch_size,), dtype=torch.bool),  # start halted to force reset on first pass
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: GLPS_ACTV1Carry, batch: Dict[str, torch.Tensor]):
        # Reset halted seqs
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        new_steps = torch.where(carry.halted, 0, carry.steps)
        new_current_data = {k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Inner step
        new_inner_carry, logits, (q_halt_logits, q_continue_logits), conf = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits,
            "conf": conf.squeeze(-1),
        }

        with torch.no_grad():
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            # Combine halt signals: max-steps, Q-head, and confidence
            if self.config.no_ACT_continue:
                # Threshold-style: halt when q_halt > 0 (no comparison with q_continue)
                q_halt_signal = (q_halt_logits > 0)
            else:
                # RL-style: compare q_halt vs q_continue
                q_halt_signal = (q_halt_logits > q_continue_logits)
            
            halted = is_last_step | q_halt_signal | (conf.squeeze(-1) >= self.config.glps_tau_halt)

            # Exploration during training only
            if self.training and (self.config.halt_max_steps > 1):
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)
                
                # Optional Q-learning target (only if using RL-style)
                if not self.config.no_ACT_continue:
                    _carry2, _logits2, (next_q_halt_logits, next_q_continue_logits), _conf2 = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
            else:
                # During eval, always run to max_steps so every sequence gets the same reasoning depth
                halted = is_last_step

        return GLPS_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
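
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the model). It
    # assumes the models.* dependencies imported above are available and run
    # on the current device; the config values and batch shapes below are
    # made up for demonstration (e.g. seq_len=81 for a flattened 9x9 grid).
    cfg = dict(
        batch_size=2,
        seq_len=81,
        vocab_size=16,
        hidden_size=128,
        num_heads=4,
        H_cycles=2,
        L_cycles=1,
        H_layers=2,
        L_layers=2,
        halt_max_steps=4,
    )
    model = GLPS_ACTV1(cfg).eval()
    batch = {
        "inputs": torch.randint(0, cfg["vocab_size"], (cfg["batch_size"], cfg["seq_len"])),
        "puzzle_identifiers": torch.zeros(cfg["batch_size"], dtype=torch.int32),
    }
    carry = model.initial_carry(batch)
    # Drive the ACT loop until every sequence in the batch has halted
    # (in eval mode this takes exactly halt_max_steps inner passes).
    with torch.no_grad():
        while True:
            carry, outputs = model(carry, batch)
            if bool(carry.halted.all()):
                break
    print("logits:", tuple(outputs["logits"].shape), "conf:", outputs["conf"])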