File size: 16,263 Bytes
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc7101b
9477b5c
 
 
bc7101b
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc7101b
 
 
 
 
 
 
 
 
 
9477b5c
 
 
 
 
 
 
 
 
bc7101b
 
9477b5c
 
 
 
 
 
bc7101b
 
 
 
 
 
 
 
9477b5c
bc7101b
 
9477b5c
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc7101b
 
 
9477b5c
 
 
 
 
bc7101b
 
 
 
 
 
 
 
9477b5c
 
 
 
 
 
 
 
 
 
 
 
bc7101b
 
 
9477b5c
 
 
 
 
 
 
 
 
 
 
 
bc7101b
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc7101b
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
"""Continuous latent reasoning model.

Sequence layout (no <think>/</think> tokens — latent positions are inputs_embeds):

    [ x_tokens ; z_1, ..., z_K ; y_tokens ]
       ^^^^^^^    ^^^^^^^^^^^^   ^^^^^^^^
       discrete    continuous     discrete
                   (W_proj of      (gold answer
                    prev hidden)   during training)

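Gradient flow: full backprop through z_t = W_proj(h_{t-1}). No sampling,
no torch.no_grad() in the latent path. With block_y_to_x, the y-row attention
mask blocks attention to x columns; z rows can still attend to x unless
block_z_to_x is also set, so the latents are the only x→y information
channel only when both blocks are enabled.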
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Additive attention-mask value for disallowed (query, key) pairs: a large
# finite negative number rather than -inf.
NEG = -1.0e9


@dataclass
class BLTConfig:
    base_model: str = "Qwen/Qwen2.5-1.5B-Instruct"
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj")
    K_latents: int = 4
    block_y_to_x: bool = True
    block_z_to_x: bool = False  # close the z→x architectural leak path (see build_blt_mask)
    proj_init_scale: float = 0.02
    dtype: str = "bfloat16"
    attn_impl: str = "eager"  # required for 4D additive mask
    gradient_checkpointing: bool = False  # trade compute for activation memory; needed for 7B


def build_base(cfg: BLTConfig):
    """Create ``(model, tokenizer)`` for ``cfg.base_model``.

    The model is loaded in ``cfg.dtype`` with ``cfg.attn_impl`` attention and
    ``use_cache`` disabled in its config. If ``cfg.gradient_checkpointing`` is
    set, non-reentrant checkpointing is switched on before any LoRA wrapping.
    If ``cfg.use_lora`` is set, the model is wrapped with a PEFT LoRA adapter
    and its trainable-parameter summary is printed.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Resolved first so an unknown dtype string fails before any download.
    dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }[cfg.dtype]
    wants_ckpt = getattr(cfg, "gradient_checkpointing", False)

    tokenizer = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Padded batches need a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        cfg.base_model,
        torch_dtype=dtype,
        attn_implementation=cfg.attn_impl,
        trust_remote_code=True,
    )
    model.config.use_cache = False

    if wants_ckpt:
        # Enabled on the bare model, before any PEFT wrap. use_reentrant=False
        # is the non-deprecated mode and copes with our custom 4D mask path.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # The latent loop feeds inputs_embeds; make the embedding output
        # require grad so checkpointed segments still propagate gradients.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    if not cfg.use_lora:
        return model, tokenizer

    from peft import LoraConfig, TaskType, get_peft_model

    lora_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=list(cfg.lora_target_modules),
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    # NOTE(review): repeats the pre-wrap call on the PEFT wrapper; likely
    # redundant but kept to preserve behavior.
    if wants_ckpt and hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    return model, tokenizer


class LatentProjector(nn.Module):
    """Project a last-layer hidden state onto the next latent input embedding.

    ``use_mlp=False`` (default, original): one bias-free ``d_model -> d_model``
    linear map.
    ``use_mlp=True``: ``d_model -> hidden_mult*d_model -> d_model`` with a GELU
    between the two layers. This is a more expressive non-linear compression
    for when a single linear map bottlenecks the latent. Both biases start at
    zero, so the untrained output stays small, like the linear variant.

    Every weight matrix is re-drawn from a zero-mean normal with standard
    deviation ``init_scale``.
    """

    def __init__(self, d_model: int, init_scale: float = 0.02,
                 use_mlp: bool = False, hidden_mult: int = 4):
        super().__init__()
        self.use_mlp = use_mlp
        if not use_mlp:
            self.proj = nn.Linear(d_model, d_model, bias=False)
            nn.init.normal_(self.proj.weight, mean=0.0, std=init_scale)
            return

        width = hidden_mult * d_model
        self.proj = nn.Sequential(
            nn.Linear(d_model, width, bias=True),
            nn.GELU(),
            nn.Linear(width, d_model, bias=True),
        )
        # Indices 0 and 2 are the Linear layers (index 1 is the GELU).
        for layer in (self.proj[0], self.proj[2]):
            nn.init.normal_(layer.weight, mean=0.0, std=init_scale)
            nn.init.zeros_(layer.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the projection over the last dimension of ``h``."""
        return self.proj(h)


def _get_input_embeddings(model) -> nn.Module:
    """Return the token-embedding layer, looking through a PEFT wrapper if present."""
    if not hasattr(model, "get_base_model"):
        return model.get_input_embeddings()
    return model.get_base_model().get_input_embeddings()


def _get_lm_head(model) -> nn.Module:
    """Return the output (LM-head) layer, looking through a PEFT wrapper if present."""
    if not hasattr(model, "get_base_model"):
        return model.get_output_embeddings()
    return model.get_base_model().get_output_embeddings()


def build_blt_mask(
    B: int, P: int, K: int, L_y: int, device, dtype, *,
    block_y_to_x: bool,
    block_z_to_x: bool = False,
) -> torch.Tensor:
    """4D additive attention mask [B, 1, T, T] with T = P + K + L_y.

    Entry [b, 0, i, j] is 0 where query i may attend to key j, and a large
    negative value otherwise. Segments: x = [0, P), z = [P, P+K),
    y = [P+K, T).

    - Lower-triangular causal everywhere.
    - If block_y_to_x: y rows (positions [P+K, P+K+L_y)) cannot attend to
      x cols (positions [0, P)).
    - If block_z_to_x: z rows (positions [P, P+K)) ALSO cannot attend to x.
      This closes the architectural "leak" path where z hidden states in
      pass 2 could attend to x and deliver x-info to y, bypassing z's input
      content. With block_z_to_x=True, z hidden states depend only on z
      input embeddings + z self-attention. The z input (= π(h_{t-1}) from
      pass 1) becomes the *only* carrier of x→y information, forcing z's
      input value to actually matter at inference.

    The masked value is NEG, raised to torch.finfo(dtype).min when NEG lies
    outside the dtype's finite range. float16 bottoms out at -65504, and
    filling it with -1e9 would overflow. For bfloat16/float32 the value is
    exactly NEG.

    Returns a tensor owned by the caller, which may modify it in place.
    """
    neg = max(NEG, torch.finfo(dtype).min)
    T = P + K + L_y
    pos = torch.arange(T, device=device)
    # causal[i, j] is True when key j <= query i.
    causal = pos.unsqueeze(0) <= pos.unsqueeze(1)
    mask = torch.full((T, T), neg, device=device, dtype=dtype)
    mask.masked_fill_(causal, 0.0)
    if block_y_to_x and P > 0 and L_y > 0:
        # y rows x x cols.
        mask[P + K:, :P] = neg
    if block_z_to_x and P > 0 and K > 0:
        # z rows x x cols.
        mask[P:P + K, :P] = neg
    # Build the [T, T] pattern once, then materialise one contiguous batch copy.
    return mask.expand(B, 1, T, T).contiguous()


def forward_with_latent(
    model,
    x_ids: torch.Tensor,       # [B, P]
    x_attn: torch.Tensor,      # [B, P]  1=keep, 0=pad (left-padded)
    y_ids: Optional[torch.Tensor],   # [B, L_y]  None at inference
    projector: LatentProjector,
    K: int,
    *,
    block_y_to_x: bool = True,
    block_z_to_x: bool = False,
    return_z: bool = True,
):
    """Run [x; z_1..z_K; y] in two passes.

      pass-1 (with grad):  iteratively build z_1..z_K from the running
                           last-layer hidden state. The KV cache is used when
                           available. When gradient checkpointing is active in
                           training mode, Hugging Face decoders turn the cache
                           off, so instead the growing prefix [x; z_1..z_t] is
                           re-encoded at each step. The result is the same, at
                           extra compute cost.
      pass-2 (single full forward): [embed(x); z_1..z_K; embed(y)] with the
                           custom 4D mask from build_blt_mask, plus x pad
                           columns masked for every row. Returns logits for
                           y positions.

    Args:
      model:        causal LM, optionally PEFT-wrapped.
      x_ids/x_attn: prompt ids and keep/pad mask. Assumed left-padded, so the
                    last column is always a real token.
      y_ids:        gold answer ids, or None to only build the latents.
      projector:    maps a hidden state to the next latent input embedding.
      K:            number of latent steps (K == 0 yields an empty latent block).
      block_y_to_x, block_z_to_x: forwarded to build_blt_mask.
      return_z:     unused; kept for interface compatibility.

    Returns:
      logits_y : [B, L_y, V]   (None if y_ids is None)
      z        : [B, K, d]     latent vectors (with grad)
      h_last_y : [B, L_y, d]   last-layer hidden states at y positions (None if y is None)
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    transformer = inner.model  # decoder stack without LM head, e.g. Qwen2Model
    device = x_ids.device
    B, P = x_ids.shape

    # ---- Pass 1: iterative z construction, grad retained ----
    # Under training-mode gradient checkpointing the HF decoder drops the KV
    # cache, and a cached step-by-step loop would then lose the x/z prefix.
    use_kv = not (getattr(transformer, "gradient_checkpointing", False) and transformer.training)

    x_embeds = embed_in(x_ids)
    out0 = transformer(
        inputs_embeds=x_embeds,
        attention_mask=x_attn,
        use_cache=use_kv,
        return_dict=True,
    )
    past = out0.past_key_values if use_kv else None
    if use_kv and past is None:
        # The model declined to return a cache; fall back to re-encoding.
        use_kv = False
    # Left padding: the last position is always a real token.
    h_prev = out0.last_hidden_state[:, -1, :]   # [B, d]

    z_list: List[torch.Tensor] = []
    prefix: List[torch.Tensor] = [x_embeds]     # grown only on the re-encode path
    cur_attn = x_attn
    for _ in range(K):
        z_t = projector(h_prev)                        # [B, d]
        z_list.append(z_t)
        cur_attn = torch.cat(
            [cur_attn, torch.ones(B, 1, device=device, dtype=cur_attn.dtype)], dim=1
        )
        if use_kv:
            out_t = transformer(
                inputs_embeds=z_t.unsqueeze(1),
                attention_mask=cur_attn,
                past_key_values=past,
                use_cache=True,
                return_dict=True,
            )
            past = out_t.past_key_values
        else:
            prefix.append(z_t.unsqueeze(1).to(x_embeds.dtype))
            out_t = transformer(
                inputs_embeds=torch.cat(prefix, dim=1),
                attention_mask=cur_attn,
                use_cache=False,
                return_dict=True,
            )
        h_prev = out_t.last_hidden_state[:, -1, :]

    if z_list:
        z = torch.stack(z_list, dim=1)        # [B, K, d]
    else:
        # K == 0: torch.stack rejects an empty list; return an empty latent block.
        z = x_embeds.new_zeros(B, 0, x_embeds.size(-1))

    if y_ids is None:
        return None, z, None

    # ---- Pass 2: full forward with custom mask, no past_kv ----
    y_embeds = embed_in(y_ids)
    L_y = y_ids.size(1)
    # Cast z to the embedding dtype to match.
    full_embeds = torch.cat([x_embeds, z.to(y_embeds.dtype), y_embeds], dim=1)
    mask_dtype = full_embeds.dtype
    full_4d = build_blt_mask(B, P, K, L_y, device=device, dtype=mask_dtype,
                             block_y_to_x=block_y_to_x, block_z_to_x=block_z_to_x)

    # Left-pad x columns must be hidden from EVERY row, including latents.
    pad_cols = x_attn == 0                     # [B, P]
    if pad_cols.any():
        # Clamp so float16 masks do not overflow.
        neg = max(NEG, torch.finfo(mask_dtype).min)
        pad_kv = torch.cat(
            [pad_cols, torch.zeros(B, K + L_y, device=device, dtype=torch.bool)], dim=1
        )
        full_4d = full_4d.masked_fill(pad_kv[:, None, None, :], neg)

    out2 = transformer(
        inputs_embeds=full_embeds,
        attention_mask=full_4d,
        use_cache=False,
        return_dict=True,
    )
    h_full = out2.last_hidden_state            # [B, T, d]
    # Position t predicts token t+1, so y_ids[:, i] is predicted at position
    # P+K-1+i. lm_head is position-wise, so applying it to just those L_y rows
    # equals slicing the full [B, T, V] logits without materialising them.
    y_start = P + K
    logits_y = lm_head(h_full[:, y_start - 1 : y_start - 1 + L_y, :])   # [B, L_y, V]
    h_last_y = h_full[:, y_start : y_start + L_y, :]                    # [B, L_y, d]

    return logits_y, z, h_last_y


def _latent_prefix_mask(
    B: int,
    P: int,
    K: int,
    x_attn: torch.Tensor,
    device,
    dtype,
    *,
    block_z_to_x: bool = False,
) -> torch.Tensor:
    """Additive [B, 1, P+K, P+K] mask for encoding [x; z_1..z_K] in one pass.

    Causal, with x pad columns (x_attn == 0) masked for every row and, when
    block_z_to_x, z rows masked from all x columns. The masked value is NEG
    clamped to dtype's finite range, so float16 does not overflow.
    """
    neg = max(NEG, torch.finfo(dtype).min)
    T = P + K
    pos = torch.arange(T, device=device)
    mask = torch.full((T, T), neg, device=device, dtype=dtype)
    mask.masked_fill_(pos.unsqueeze(0) <= pos.unsqueeze(1), 0.0)
    if block_z_to_x and P > 0 and K > 0:
        mask[P:, :P] = neg
    pad_kv = torch.cat(
        [x_attn == 0, torch.zeros(B, K, device=device, dtype=torch.bool)], dim=1
    )
    # Out-of-place fill materialises a fresh [B, 1, T, T] tensor.
    return mask.expand(B, 1, T, T).masked_fill(pad_kv[:, None, None, :], neg)


@torch.no_grad()
def generate_with_latent(
    model,
    tokenizer,
    projector: LatentProjector,
    x_ids: torch.Tensor,         # [B, P]
    x_attn: torch.Tensor,
    K: int,
    *,
    block_y_to_x: bool = True,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    eos_token_id: Optional[int] = None,
    override_z: Optional[torch.Tensor] = None,  # [B, K, d] forced latents (ablation)
    block_z_to_x: bool = False,
):
    """Greedy / temperature decoding with the latent loop.

    Args:
      model, tokenizer: causal LM (optionally PEFT-wrapped) and its tokenizer.
      projector:   maps the previous last-layer hidden state to the next latent.
      x_ids, x_attn: [B, P] prompt ids and 1=keep / 0=pad mask (left-padded).
      K:           number of latent steps (ignored when override_z is given).
      block_y_to_x: answer tokens may not attend to x keys.
      max_new_tokens: decoding budget; <= 0 returns an empty [B, 0] tensor.
      temperature: <= 0 → argmax; otherwise sample from softmax(logits / T).
      eos_token_id: stop token; defaults to tokenizer.eos_token_id.
      override_z:  if provided, skip the latent loop and use these latents
                   directly. For ablations: random-z (gaussian noise), zero-z
                   (K=0), shuffled-z, etc.
      block_z_to_x: mirror forward_with_latent(block_z_to_x=True). After the
                   latents are built, [x; z] is re-encoded with z rows blocked
                   from x, so the cached z keys/values (and the first answer
                   logits) match training pass 2. Default False keeps the
                   original behavior.

    Returns:
      [B, L_gen] token ids with L_gen <= max_new_tokens. Rows that hit EOS
      earlier are filled with the tokenizer's pad id (EOS if it has none).
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    transformer = inner.model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    device = x_ids.device
    B, P = x_ids.shape
    eos = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos
    if max_new_tokens <= 0:
        # Nothing to decode; torch.stack would reject an empty list below.
        return torch.empty(B, 0, dtype=torch.long, device=device)

    attn_dtype = x_attn.dtype
    x_embeds = embed_in(x_ids)

    # ---- z (computed or overridden) ----
    if override_z is not None:
        z = override_z
        K_eff = z.size(1)
    else:
        K_eff = K
        out0 = transformer(inputs_embeds=x_embeds, attention_mask=x_attn,
                           use_cache=True, return_dict=True)
        past = out0.past_key_values
        h_last = out0.last_hidden_state[:, -1, :]
        z_steps: List[torch.Tensor] = []
        cur_attn = x_attn
        for _ in range(K):
            z_t = projector(h_last)
            z_steps.append(z_t)
            cur_attn = torch.cat(
                [cur_attn, torch.ones(B, 1, device=device, dtype=attn_dtype)], dim=1
            )
            out_t = transformer(inputs_embeds=z_t.unsqueeze(1), attention_mask=cur_attn,
                                past_key_values=past, use_cache=True, return_dict=True)
            past = out_t.past_key_values
            h_last = out_t.last_hidden_state[:, -1, :]
        z = torch.stack(z_steps, dim=1) if z_steps else None

    # Forced latents still have to be "consumed" through the transformer to
    # build past_kv. With block_z_to_x, the incrementally cached z entries saw
    # x, unlike training pass 2, so they are rebuilt with the blocking mask.
    if override_z is not None or (block_z_to_x and K_eff > 0):
        prefix_mask = _latent_prefix_mask(B, P, K_eff, x_attn, device, x_embeds.dtype,
                                          block_z_to_x=block_z_to_x)
        out_prefix = transformer(
            inputs_embeds=torch.cat([x_embeds, z.to(x_embeds.dtype)], dim=1),
            attention_mask=prefix_mask,
            use_cache=True,
            return_dict=True,
        )
        past = out_prefix.past_key_values
        h_last = out_prefix.last_hidden_state[:, -1, :]

    # ---- Answer phase: autoregressive decoding ----
    # With KV cache + eager attention, each y query gets a 2D mask over the
    # kv length: 0 over x when block_y_to_x (otherwise the x pad mask), and 1
    # over latents and previously generated y. It grows by one per step.
    x_part = torch.zeros(B, P, device=device, dtype=attn_dtype) if block_y_to_x else x_attn
    y_kv_mask = torch.cat(
        [x_part, torch.ones(B, K_eff, device=device, dtype=attn_dtype)], dim=1
    )
    one_col = torch.ones(B, 1, device=device, dtype=attn_dtype)

    last_logits = lm_head(h_last)   # [B, V]
    done = torch.zeros(B, dtype=torch.bool, device=device)
    gen_ids: List[torch.Tensor] = []
    for _ in range(max_new_tokens):
        if temperature <= 0.0:
            nxt = last_logits.argmax(dim=-1)
        else:
            probs = torch.softmax(last_logits.float() / max(temperature, 1e-6), dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).squeeze(-1)
        nxt = torch.where(done, torch.full_like(nxt, pad_id), nxt)
        gen_ids.append(nxt)
        done = done | (nxt == eos)
        if bool(done.all().item()):
            break

        y_kv_mask = torch.cat([y_kv_mask, one_col], dim=1)
        out = transformer(inputs_embeds=embed_in(nxt.unsqueeze(-1)), attention_mask=y_kv_mask,
                          past_key_values=past, use_cache=True, return_dict=True)
        past = out.past_key_values
        last_logits = lm_head(out.last_hidden_state[:, -1, :])

    return torch.stack(gen_ids, dim=1)        # [B, L_gen]