| """Continuous latent reasoning model. |
| |
| Sequence layout (no <think>/</think> tokens — latent positions are inputs_embeds): |
| |
|     [ x_tokens ; z_1, ..., z_K ; y_tokens ] |
|       ^^^^^^^^   ^^^^^^^^^^^^^   ^^^^^^^^ |
|       discrete   continuous      discrete |
|                  (W_proj of      (gold answer |
|                   prev hidden)    during training) |
| |
| Gradient flow: full backprop through z_t = W_proj(h_{t-1}). No sampling, |
| no torch.no_grad() in the latent path. The y-row attention mask blocks |
| attention to x columns so the latent is the only information channel. |
| """ |
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Additive attention-mask value written into blocked (query, key) cells.
# NOTE(review): -1e9 lies outside the float16 range — confirm callers that
# build float16 masks handle this.
NEG = -1e9
|
|
|
|
@dataclass
class BLTConfig:
    """Settings for the base LM, its LoRA adapters and the latent loop.

    ``build_base`` reads the base-model, LoRA, dtype, attention and
    gradient-checkpointing fields. The latent/mask fields are not read
    anywhere in this module's visible code; presumably the training script
    forwards them to ``forward_with_latent`` / ``LatentProjector`` (verify
    against the caller).
    """

    # Hugging Face model id (or local path) handed to ``from_pretrained``.
    base_model: str = "Qwen/Qwen2.5-1.5B-Instruct"

    # LoRA adapter settings; only consulted when ``use_lora`` is True.
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: Tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")

    # Number of continuous latent positions z_1..z_K.
    K_latents: int = 4
    # Attention-mask switches (same names as the forward_with_latent kwargs).
    block_y_to_x: bool = True
    block_z_to_x: bool = False
    # Presumably the std used to initialise the latent projector — confirm.
    proj_init_scale: float = 0.02

    # One of "bfloat16", "float16", "float32" (keys accepted by build_base).
    dtype: str = "bfloat16"
    # Passed to ``from_pretrained`` as ``attn_implementation``.
    attn_impl: str = "eager"
    gradient_checkpointing: bool = False
|
|
|
|
def build_base(cfg: BLTConfig):
    """Load the tokenizer and base causal LM described by ``cfg``.

    Steps, in order:
      1. Resolve ``cfg.dtype`` to a torch dtype (unknown names raise KeyError).
      2. Load the tokenizer; fall back to EOS as the pad token if none is set.
      3. Load the causal LM with the requested dtype / attention backend and
         switch ``config.use_cache`` off.
      4. Optionally enable non-reentrant gradient checkpointing.
      5. Optionally wrap the model with LoRA adapters via PEFT.

    Returns:
        (model, tokenizer)
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name_to_dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    dtype = name_to_dtype[cfg.dtype]

    # NOTE(review): trust_remote_code=True lets the hub repo execute its own
    # Python on load — only point cfg.base_model at trusted repositories.
    tok = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        cfg.base_model,
        torch_dtype=dtype,
        attn_implementation=cfg.attn_impl,
        trust_remote_code=True,
    )
    model.config.use_cache = False

    # getattr keeps config objects that predate the field working.
    checkpointing = getattr(cfg, "gradient_checkpointing", False)
    if checkpointing:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    if not cfg.use_lora:
        return model, tok

    from peft import LoraConfig, get_peft_model, TaskType

    lora_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=list(cfg.lora_target_modules),
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    # Repeat the input-grad request on the PEFT-wrapped model.
    if checkpointing and hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    return model, tok
|
|
|
|
class LatentProjector(nn.Module):
    """Project a last-layer hidden state into the next latent input embedding.

    The module stored in ``self.proj`` depends on ``use_mlp``:
      * ``use_mlp=False`` (default): one bias-free ``d_model -> d_model``
        linear layer.
      * ``use_mlp=True``: ``d_model -> hidden_mult * d_model -> d_model``
        with a GELU in between; both biases start at zero.

    Every weight matrix is drawn from N(0, init_scale^2).
    """

    def __init__(self, d_model: int, init_scale: float = 0.02,
                 use_mlp: bool = False, hidden_mult: int = 4):
        super().__init__()
        self.use_mlp = use_mlp

        if not use_mlp:
            linear = nn.Linear(d_model, d_model, bias=False)
            nn.init.normal_(linear.weight, mean=0.0, std=init_scale)
            self.proj = linear
            return

        width = d_model * hidden_mult
        fc_in = nn.Linear(d_model, width, bias=True)
        fc_out = nn.Linear(width, d_model, bias=True)
        # Both layers are constructed before either is re-initialised, so the
        # random-number stream is consumed in a fixed order.
        for layer in (fc_in, fc_out):
            nn.init.normal_(layer.weight, mean=0.0, std=init_scale)
            nn.init.zeros_(layer.bias)
        self.proj = nn.Sequential(fc_in, nn.GELU(), fc_out)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the projection to ``h`` (last dimension must be ``d_model``)."""
        return self.proj(h)
|
|
|
|
def _get_input_embeddings(model) -> nn.Module:
    """Return the input-embedding layer, unwrapping a PEFT model if needed.

    Objects exposing ``get_base_model`` are unwrapped first; anything else is
    queried directly.
    """
    if hasattr(model, "get_base_model"):
        model = model.get_base_model()
    return model.get_input_embeddings()
|
|
|
|
def _get_lm_head(model) -> nn.Module:
    """Return the output-embedding (LM head) layer, unwrapping PEFT if needed."""
    if hasattr(model, "get_base_model"):
        model = model.get_base_model()
    return model.get_output_embeddings()
|
|
|
|
def build_blt_mask(
    B: int, P: int, K: int, L_y: int, device, dtype, *,
    block_y_to_x: bool,
    block_z_to_x: bool = False,
) -> torch.Tensor:
    """4D additive attention mask [B, 1, T, T] with T = P + K + L_y.

    Allowed (query row, key col) cells hold 0; blocked cells hold a large
    negative value.

    - Lower-triangular causal everywhere.
    - If block_y_to_x: y rows (positions [P+K, P+K+L_y)) cannot attend to
      x cols (positions [0, P)).
    - If block_z_to_x: z rows (positions [P, P+K)) ALSO cannot attend to x.
      This closes the architectural "leak" path where z hidden states in
      pass 2 could attend to x and deliver x-info to y bypassing z's input
      content. With block_z_to_x=True, z hidden states depend only on z
      input embeddings + z self-attention.

    The blocked value is ``NEG`` clamped to the most negative finite value of
    ``dtype``, so float16 masks stay finite; for bfloat16/float32 the value
    is exactly ``NEG``.

    Returns:
        A freshly allocated tensor that the caller may modify in place.
    """
    T = P + K + L_y

    blocked_value = NEG
    if dtype.is_floating_point:
        # -1e9 is not representable in float16; clamp to the dtype's range.
        blocked_value = max(NEG, torch.finfo(dtype).min)

    # Boolean [T, T] grid of blocked cells, shared by every batch element.
    pos = torch.arange(T, device=device)
    blocked = pos.unsqueeze(0) > pos.unsqueeze(1)  # key col after query row
    if block_y_to_x and P > 0 and L_y > 0:
        blocked[P + K : P + K + L_y, 0:P] = True
    if block_z_to_x and P > 0 and K > 0:
        blocked[P : P + K, 0:P] = True

    mask_2d = torch.zeros((T, T), device=device, dtype=dtype)
    mask_2d.masked_fill_(blocked, blocked_value)
    # repeat() materialises an independent copy per batch element.
    return mask_2d[None, None, :, :].repeat(B, 1, 1, 1)
|
|
|
|
def forward_with_latent(
    model,
    x_ids: torch.Tensor,
    x_attn: torch.Tensor,
    y_ids: Optional[torch.Tensor],
    projector: LatentProjector,
    K: int,
    *,
    block_y_to_x: bool = True,
    block_z_to_x: bool = False,
    return_z: bool = True,
):
    """Run [x; z_1..z_K; y] in two passes.

    pass-1 (KV-cached, with grad): encode x, then build z_1..z_K one step at
        a time; each z_t is ``projector`` applied to the last-layer hidden
        state at the final position of the previous step.
    pass-2 (single full forward): [embed(x); z_1..z_K; embed(y)] with the
        custom 4D mask from ``build_blt_mask`` (plus masking of padded x
        columns). The LM head is applied only to the L_y positions that
        predict y.

    Args:
        model: causal LM, optionally PEFT-wrapped; the unwrapped model must
            expose ``.model``, ``get_input_embeddings`` and
            ``get_output_embeddings``.
        x_ids: [B, P] prompt token ids.
        x_attn: [B, P] prompt attention mask (0 marks padding). The hidden
            state at position -1 seeds the latents, so this presumably
            expects left padding — TODO confirm with the data collator.
        y_ids: [B, L_y] target token ids, or None to stop after pass 1.
        projector: maps a [B, d] hidden state to a [B, d] latent embedding.
        K: number of latent steps. K == 0 yields an empty [B, 0, d] latent.
        block_y_to_x / block_z_to_x: forwarded to ``build_blt_mask``.
        return_z: currently unused; z is always returned.

    Returns:
        logits_y : [B, L_y, V]; row j is computed at position P+K-1+j and
                   scores y token j (None if y_ids is None)
        z        : [B, K, d] latent vectors (with grad)
        h_last_y : [B, L_y, d] last-layer hidden states at y positions
                   (None if y_ids is None)

    NOTE(review): pass 1 relies on ``use_cache=True`` returning a usable KV
    cache; HF decoders may disable caching under gradient checkpointing in
    train mode — confirm when cfg.gradient_checkpointing is on.
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    transformer = inner.model
    device = x_ids.device
    B, P = x_ids.shape

    # ---- pass 1: encode x, then roll out the latent chain on the KV cache.
    x_embeds = embed_in(x_ids)
    out = transformer(
        inputs_embeds=x_embeds,
        attention_mask=x_attn,
        use_cache=True,
        return_dict=True,
    )
    past = out.past_key_values
    h_prev = out.last_hidden_state[:, -1, :]

    z_list: List[torch.Tensor] = []
    cur_attn = x_attn
    one_col = torch.ones(B, 1, device=device, dtype=x_attn.dtype)
    for _ in range(K):
        z_t = projector(h_prev)
        z_list.append(z_t)
        cur_attn = torch.cat([cur_attn, one_col], dim=1)
        out = transformer(
            inputs_embeds=z_t.unsqueeze(1),
            attention_mask=cur_attn,
            past_key_values=past,
            use_cache=True,
            return_dict=True,
        )
        past = out.past_key_values
        h_prev = out.last_hidden_state[:, -1, :]

    if z_list:
        z = torch.stack(z_list, dim=1)
    else:
        # torch.stack rejects an empty list; K == 0 means "no latents".
        z = x_embeds.new_zeros((B, 0, x_embeds.size(-1)))

    if y_ids is None:
        return None, z, None

    # ---- pass 2: one full forward over [x; z; y] with the custom mask.
    y_embeds = embed_in(y_ids)
    L_y = y_ids.size(1)
    full_embeds = torch.cat([x_embeds, z.to(y_embeds.dtype), y_embeds], dim=1)
    mask = build_blt_mask(
        B, P, K, L_y, device=device, dtype=full_embeds.dtype,
        block_y_to_x=block_y_to_x, block_z_to_x=block_z_to_x,
    )

    # Hide padded x columns from every query row. Applied unconditionally:
    # it is a no-op without padding and avoids a host sync on `.any()`.
    pad_cols = torch.cat(
        [x_attn == 0, torch.zeros(B, K + L_y, device=device, dtype=torch.bool)],
        dim=1,
    )
    # Clamp so the fill value stays representable in float16.
    pad_value = max(NEG, torch.finfo(mask.dtype).min)
    mask.masked_fill_(pad_cols[:, None, None, :], pad_value)

    out2 = transformer(
        inputs_embeds=full_embeds,
        attention_mask=mask,
        use_cache=False,
        return_dict=True,
    )
    h_full = out2.last_hidden_state

    # Position P+K-1+j predicts y token j. Slice before the LM head so the
    # vocabulary projection runs on L_y rows instead of all P+K+L_y rows.
    first = P + K - 1
    logits_y = lm_head(h_full[:, first : first + L_y, :])
    h_last_y = h_full[:, P + K : P + K + L_y, :]

    return logits_y, z, h_last_y
|
|
|
|
@torch.no_grad()
def generate_with_latent(
    model,
    tokenizer,
    projector: LatentProjector,
    x_ids: torch.Tensor,
    x_attn: torch.Tensor,
    K: int,
    *,
    block_y_to_x: bool = True,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    eos_token_id: Optional[int] = None,
    override_z: Optional[torch.Tensor] = None,
):
    """Greedy / temperature decoding with the latent loop.

    Args:
        model: causal LM, optionally PEFT-wrapped (see forward_with_latent).
        tokenizer: supplies ``eos_token_id`` (fallback) and ``pad_token_id``
            (written for rows that have already emitted EOS).
        projector: maps a [B, d] hidden state to the next latent embedding.
        x_ids / x_attn: [B, P] prompt ids and attention mask (0 = padding).
        K: number of latent steps; ignored when ``override_z`` is given.
        block_y_to_x: if True, generated tokens cannot attend to x columns.
        max_new_tokens: upper bound on generated tokens; <= 0 returns an
            empty [B, 0] tensor.
        temperature: <= 0 selects argmax decoding, otherwise multinomial
            sampling from softmax(logits / temperature).
        eos_token_id: overrides ``tokenizer.eos_token_id`` when not None.
        override_z: if provided, skip the latent-loop pass and use these
            latents directly. For ablations: random-z (gaussian noise),
            zero-z (K=0), shuffled-z, etc.

    Returns:
        LongTensor [B, n_generated]; decoding stops early once every row has
        produced EOS.

    NOTE(review): there is no ``block_z_to_x`` option here, so z keys/values
    always see x during generation — confirm this matches how the model was
    trained.
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    transformer = inner.model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    device = x_ids.device
    B, P = x_ids.shape
    eos = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id

    if max_new_tokens <= 0:
        # Nothing to decode; torch.stack would reject an empty list below.
        return torch.empty((B, 0), dtype=torch.long, device=device)

    x_embeds = embed_in(x_ids)

    if override_z is not None:
        # Prefill [x; z] in one shot with a causal mask that hides padded x.
        K_eff = override_z.size(1)
        prefix = torch.cat([x_embeds, override_z.to(x_embeds.dtype)], dim=1)
        mask = build_blt_mask(
            B, P, K_eff, 0, device=device, dtype=x_embeds.dtype,
            block_y_to_x=False,
        )
        pad_cols = torch.cat(
            [x_attn == 0, torch.zeros(B, K_eff, device=device, dtype=torch.bool)],
            dim=1,
        )
        # Clamp so the fill value stays representable in float16.
        pad_value = max(NEG, torch.finfo(mask.dtype).min)
        mask.masked_fill_(pad_cols[:, None, None, :], pad_value)
        out = transformer(inputs_embeds=prefix, attention_mask=mask,
                          use_cache=True, return_dict=True)
        past = out.past_key_values
        h_last = out.last_hidden_state[:, -1, :]
    else:
        # Encode x, then roll out K latent steps on the KV cache.
        K_eff = K
        out = transformer(inputs_embeds=x_embeds, attention_mask=x_attn,
                          use_cache=True, return_dict=True)
        past = out.past_key_values
        h_last = out.last_hidden_state[:, -1, :]
        cur_attn = x_attn
        one_col = torch.ones(B, 1, device=device, dtype=x_attn.dtype)
        for _ in range(K):
            z_t = projector(h_last)
            cur_attn = torch.cat([cur_attn, one_col], dim=1)
            out = transformer(inputs_embeds=z_t.unsqueeze(1), attention_mask=cur_attn,
                              past_key_values=past, use_cache=True, return_dict=True)
            past = out.past_key_values
            h_last = out.last_hidden_state[:, -1, :]

    last_logits = lm_head(h_last)

    # 2D key mask used by every y step: x columns are zeroed when y->x is
    # blocked (otherwise only padding is hidden); z columns are visible.
    x_cols = (
        torch.zeros(B, P, device=device, dtype=x_attn.dtype)
        if block_y_to_x else x_attn
    )
    y_kv_mask = torch.cat(
        [x_cols, torch.ones(B, K_eff, device=device, dtype=x_attn.dtype)],
        dim=1,
    )
    y_col = torch.ones(B, 1, device=device, dtype=y_kv_mask.dtype)

    gen_ids = []
    done = torch.zeros(B, dtype=torch.bool, device=device)
    for step in range(max_new_tokens):
        if temperature <= 0.0:
            nxt = last_logits.argmax(dim=-1)
        else:
            probs = torch.softmax(last_logits.float() / max(temperature, 1e-6), dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).squeeze(-1)
        # Rows that already finished keep emitting the pad token.
        nxt = torch.where(done, torch.full_like(nxt, tokenizer.pad_token_id), nxt)
        gen_ids.append(nxt)
        done = done | (nxt == eos)

        # Stop when the budget is spent (the next forward's logits would be
        # discarded) or when every row has produced EOS.
        if step + 1 >= max_new_tokens or bool(done.all().item()):
            break

        y_emb = embed_in(nxt.unsqueeze(-1))
        y_kv_mask = torch.cat([y_kv_mask, y_col], dim=1)
        out = transformer(inputs_embeds=y_emb, attention_mask=y_kv_mask,
                          past_key_values=past, use_cache=True, return_dict=True)
        past = out.past_key_values
        last_logits = lm_head(out.last_hidden_state[:, -1, :])

    return torch.stack(gen_ids, dim=1)
|
|