"""Continuous latent reasoning model. Sequence layout (no / tokens — latent positions are inputs_embeds): [ x_tokens ; z_1, ..., z_K ; y_tokens ] ^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^ discrete continuous discrete (W_proj of (gold answer prev hidden) during training) Gradient flow: full backprop through z_t = W_proj(h_{t-1}). No sampling, no torch.no_grad() in the latent path. The y-row attention mask blocks attention to x columns so the latent is the only information channel. """ from __future__ import annotations import math from dataclasses import dataclass from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F NEG = -1e9 @dataclass class BLTConfig: base_model: str = "Qwen/Qwen2.5-1.5B-Instruct" use_lora: bool = True lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 lora_target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj") K_latents: int = 4 block_y_to_x: bool = True block_z_to_x: bool = False # close the z→x architectural leak path (see build_blt_mask) proj_init_scale: float = 0.02 dtype: str = "bfloat16" attn_impl: str = "eager" # required for 4D additive mask gradient_checkpointing: bool = False # trade compute for activation memory; needed for 7B def build_base(cfg: BLTConfig): """Load tokenizer + base CausalLM, optionally wrap with LoRA.""" from transformers import AutoModelForCausalLM, AutoTokenizer dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[cfg.dtype] tok = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True) if tok.pad_token is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained( cfg.base_model, torch_dtype=dtype, attn_implementation=cfg.attn_impl, trust_remote_code=True, ) model.config.use_cache = False if getattr(cfg, "gradient_checkpointing", False): # Must enable BEFORE peft wrap; peft propagates the flag to the base model. # use_reentrant=False avoids the deprecation warning and is recommended for # modern HF + custom attention masks (our 4D mask path is non-trivial). model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) # For HF checkpointing to actually propagate grads through inputs_embeds # (which we use for the latent loop), we need to make inputs require grad. # peft handles this via enable_input_require_grads on the wrapped model. if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() if cfg.use_lora: from peft import LoraConfig, get_peft_model, TaskType lcfg = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, lora_dropout=cfg.lora_dropout, bias="none", task_type=TaskType.CAUSAL_LM, target_modules=list(cfg.lora_target_modules), ) model = get_peft_model(model, lcfg) model.print_trainable_parameters() if getattr(cfg, "gradient_checkpointing", False) and hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() return model, tok class LatentProjector(nn.Module): """Maps last-layer hidden state to next-step input embedding. Two variants, selected via `use_mlp`: * Linear (default, original): a single d→d linear layer, bias=False. * MLP: d → (hidden_mult·d) → d with GELU. More expressive non-linear compression — necessary if the single linear projection bottlenecks latent informativeness. Output bias is zeroed at init so the first forward is near-zero, mimicking the Linear variant's startup. `init_scale` controls the std of all weight initializations. """ def __init__(self, d_model: int, init_scale: float = 0.02, use_mlp: bool = False, hidden_mult: int = 4): super().__init__() self.use_mlp = use_mlp if use_mlp: d_hidden = d_model * hidden_mult self.proj = nn.Sequential( nn.Linear(d_model, d_hidden, bias=True), nn.GELU(), nn.Linear(d_hidden, d_model, bias=True), ) nn.init.normal_(self.proj[0].weight, mean=0.0, std=init_scale) nn.init.zeros_(self.proj[0].bias) nn.init.normal_(self.proj[2].weight, mean=0.0, std=init_scale) nn.init.zeros_(self.proj[2].bias) else: self.proj = nn.Linear(d_model, d_model, bias=False) nn.init.normal_(self.proj.weight, mean=0.0, std=init_scale) def forward(self, h: torch.Tensor) -> torch.Tensor: return self.proj(h) def _get_input_embeddings(model) -> nn.Module: """Returns the input embedding layer, working through PEFT wrap.""" inner = model.get_base_model() if hasattr(model, "get_base_model") else model return inner.get_input_embeddings() def _get_lm_head(model) -> nn.Module: inner = model.get_base_model() if hasattr(model, "get_base_model") else model return inner.get_output_embeddings() def build_blt_mask( B: int, P: int, K: int, L_y: int, device, dtype, *, block_y_to_x: bool, block_z_to_x: bool = False, ) -> torch.Tensor: """4D additive attention mask [B, 1, T, T] with T = P + K + L_y. - Lower-triangular causal everywhere. - If block_y_to_x: y rows (positions [P+K, P+K+L_y)) cannot attend to x cols (positions [0, P)). - If block_z_to_x: z rows (positions [P, P+K)) ALSO cannot attend to x. This closes the architectural "leak" path where z hidden states in pass 2 could attend to x and deliver x-info to y bypassing z's input content. With block_z_to_x=True, z hidden states depend only on z input embeddings + z self-attention. The z input (= π(h_{t-1}) from pass 1) becomes the *only* carrier of x→y information, forcing z's input value to actually matter at inference. """ T = P + K + L_y # Start with full -inf, fill 0 where attention is allowed. add = torch.full((B, 1, T, T), NEG, device=device, dtype=dtype) # Causal: allow j <= i row = torch.arange(T, device=device).unsqueeze(1) # [T, 1] col = torch.arange(T, device=device).unsqueeze(0) # [1, T] causal = (col <= row) # [T, T] bool add[:, 0, :, :] = torch.where(causal, torch.zeros_like(add[0, 0]), torch.full_like(add[0, 0], NEG)) if block_y_to_x and P > 0 and L_y > 0: # zero-out y→x by re-applying NEG to the y-row × x-col block. add[:, 0, P + K : P + K + L_y, 0:P] = NEG if block_z_to_x and P > 0 and K > 0: # zero-out z→x: z rows cannot attend to x cols. add[:, 0, P : P + K, 0:P] = NEG return add def forward_with_latent( model, x_ids: torch.Tensor, # [B, P] x_attn: torch.Tensor, # [B, P] 1=keep, 0=pad (left-padded) y_ids: Optional[torch.Tensor], # [B, L_y] None at inference projector: LatentProjector, K: int, *, block_y_to_x: bool = True, block_z_to_x: bool = False, return_z: bool = True, ): """Run [x; z_1..z_K; y] in two passes: pass-1 (KV-cached, with grad): iteratively build z_1..z_K from the running last-layer hidden state. pass-2 (single full forward): [embed(x); z_1..z_K; embed(y)] with custom 4D mask blocking y→x. Returns logits for y positions. Returns: logits_y : [B, L_y, V] (None if y_ids is None) z : [B, K, d] latent vectors (with grad) h_last_y : [B, L_y, d] last-layer hidden states at y positions (None if y is None) """ inner = model.get_base_model() if hasattr(model, "get_base_model") else model embed_in = inner.get_input_embeddings() lm_head = inner.get_output_embeddings() device = x_ids.device dtype = embed_in.weight.dtype B, P = x_ids.shape # ---- Pass 1: iterative z construction with KV cache, grad retained ---- # Initial forward over x to produce the running last-position hidden state. # We use the underlying base model (`inner.model`) for hidden-state access. base_lm = inner # e.g., Qwen2ForCausalLM transformer = base_lm.model # Qwen2Model x_embeds = embed_in(x_ids) out0 = transformer( inputs_embeds=x_embeds, attention_mask=x_attn, use_cache=True, return_dict=True, ) past = out0.past_key_values # Grab last-token hidden state, accounting for left-pad: use the last # non-pad position. Since we left-pad, the last position is always real. h_prev = out0.last_hidden_state[:, -1, :] # [B, d] z_list: List[torch.Tensor] = [] cur_attn = x_attn for t in range(K): z_t = projector(h_prev) # [B, d] z_list.append(z_t) cur_attn = torch.cat( [cur_attn, torch.ones(B, 1, device=device, dtype=cur_attn.dtype)], dim=1 ) out_t = transformer( inputs_embeds=z_t.unsqueeze(1), attention_mask=cur_attn, past_key_values=past, use_cache=True, return_dict=True, ) past = out_t.past_key_values h_prev = out_t.last_hidden_state[:, -1, :] z = torch.stack(z_list, dim=1) # [B, K, d] if y_ids is None: return None, z, None # ---- Pass 2: full forward with custom mask, no past_kv ---- y_embeds = embed_in(y_ids) L_y = y_ids.size(1) # Cast z to the embedding dtype to match. full_embeds = torch.cat([x_embeds, z.to(y_embeds.dtype), y_embeds], dim=1) full_4d = build_blt_mask(B, P, K, L_y, device=device, dtype=full_embeds.dtype, block_y_to_x=block_y_to_x, block_z_to_x=block_z_to_x) # We also need to respect x pad columns (left-pad → kv positions in x # that are pad should be masked from EVERYTHING, including latents). if (x_attn == 0).any(): # Build a 1D mask of pad columns: True where pad. pad_cols = (x_attn == 0) # [B, P] pad_kv = torch.cat([pad_cols, torch.zeros(B, K + L_y, device=device, dtype=torch.bool)], dim=1) # Broadcast: for each (b), set add[b, 0, :, j] = NEG where pad_kv[b, j]. full_4d = full_4d.clone() full_4d.masked_fill_(pad_kv[:, None, None, :], NEG) out2 = transformer( inputs_embeds=full_embeds, attention_mask=full_4d, use_cache=False, return_dict=True, ) h_full = out2.last_hidden_state # [B, T, d] # logits over y *predictions*: position t predicts token t+1, so for the # y-segment we read logits at positions [P+K-1, P+K+L_y-1) and compare # with y_ids[:, :L_y]. logits_all = lm_head(h_full) # [B, T, V] pred_slice = logits_all[:, P + K - 1 : P + K - 1 + L_y, :] # [B, L_y, V] h_last_y = h_full[:, P + K : P + K + L_y, :] # [B, L_y, d] return pred_slice, z, h_last_y @torch.no_grad() def generate_with_latent( model, tokenizer, projector: LatentProjector, x_ids: torch.Tensor, # [B, P] x_attn: torch.Tensor, K: int, *, block_y_to_x: bool = True, max_new_tokens: int = 256, temperature: float = 0.0, eos_token_id: Optional[int] = None, override_z: Optional[torch.Tensor] = None, # [B, K, d] forced latents (ablation) ): """Greedy / temperature decoding with the latent loop. override_z: if provided, skip the latent-loop pass and use these latents directly. For ablations: random-z (gaussian noise), zero-z (K=0), shuffled-z, etc. """ inner = model.get_base_model() if hasattr(model, "get_base_model") else model transformer = inner.model embed_in = inner.get_input_embeddings() lm_head = inner.get_output_embeddings() device = x_ids.device B, P = x_ids.shape eos = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id x_embeds = embed_in(x_ids) # ---- z (computed or overridden) ---- if override_z is not None: K_eff = override_z.size(1) z = override_z # Still need to "consume" x and the latents through the transformer # to build past_kv used for answer generation. Do a single forward. full_embeds = torch.cat([x_embeds, z.to(x_embeds.dtype)], dim=1) cur_attn = torch.cat( [x_attn, torch.ones(B, K_eff, device=device, dtype=x_attn.dtype)], dim=1 ) # Build a 4D mask: causal + x-pads masked T0 = P + K_eff add = torch.full((B, 1, T0, T0), NEG, device=device, dtype=x_embeds.dtype) row = torch.arange(T0, device=device).unsqueeze(1) col = torch.arange(T0, device=device).unsqueeze(0) causal = (col <= row) add[:, 0, :, :] = torch.where(causal, torch.zeros_like(add[0, 0]), torch.full_like(add[0, 0], NEG)) if (x_attn == 0).any(): pad_kv = torch.cat([(x_attn == 0), torch.zeros(B, K_eff, device=device, dtype=torch.bool)], dim=1) add.masked_fill_(pad_kv[:, None, None, :], NEG) out0 = transformer(inputs_embeds=full_embeds, attention_mask=add, use_cache=True, return_dict=True) past = out0.past_key_values h_last = out0.last_hidden_state[:, -1, :] else: K_eff = K out0 = transformer(inputs_embeds=x_embeds, attention_mask=x_attn, use_cache=True, return_dict=True) past = out0.past_key_values h_prev = out0.last_hidden_state[:, -1, :] cur_attn = x_attn for t in range(K): z_t = projector(h_prev) cur_attn = torch.cat([cur_attn, torch.ones(B, 1, device=device, dtype=cur_attn.dtype)], dim=1) out_t = transformer(inputs_embeds=z_t.unsqueeze(1), attention_mask=cur_attn, past_key_values=past, use_cache=True, return_dict=True) past = out_t.past_key_values h_prev = out_t.last_hidden_state[:, -1, :] h_last = h_prev # ---- Answer phase: autoregressive decoding. ---- # When block_y_to_x is on, we need y rows to not attend to the first P kv # positions. With KV cache + eager, we pass a 2D attn mask over kv-length # where the x portion is 0. This zeroes out x in additive form. # NB: We zero x but keep latent + prior y at 1. gen_ids = [] last_logits = lm_head(h_last) # [B, V] # Build a base attn mask for y queries: 0 over x, 1 over latents, 1 over prior y. # Sequence length grows by 1 each step. y_kv_base = torch.cat( [torch.zeros(B, P, device=device, dtype=cur_attn.dtype) if block_y_to_x else x_attn, torch.ones(B, K_eff, device=device, dtype=cur_attn.dtype)], dim=1, ) done = torch.zeros(B, dtype=torch.bool, device=device) for step in range(max_new_tokens): if temperature <= 0.0: nxt = last_logits.argmax(dim=-1) else: probs = torch.softmax(last_logits.float() / max(temperature, 1e-6), dim=-1) nxt = torch.multinomial(probs, num_samples=1).squeeze(-1) nxt = torch.where(done, torch.full_like(nxt, tokenizer.pad_token_id), nxt) gen_ids.append(nxt) new_done = done | (nxt == eos) if bool(new_done.all().item()): done = new_done break done = new_done y_emb = embed_in(nxt.unsqueeze(-1)) # [B, 1, d] y_kv_base = torch.cat([y_kv_base, torch.ones(B, 1, device=device, dtype=y_kv_base.dtype)], dim=1) out = transformer(inputs_embeds=y_emb, attention_mask=y_kv_base, past_key_values=past, use_cache=True, return_dict=True) past = out.past_key_values last_logits = lm_head(out.last_hidden_state[:, -1, :]) return torch.stack(gen_ids, dim=1) # [B, L_gen]