# Uploaded by LauraGG — commit bc7101b (verified):
#   "Refresh code/ with latest BLT-Reasoner sources (post-campaign)"
"""Continuous latent reasoning model.
Sequence layout (no <think>/</think> tokens — latent positions are inputs_embeds):

    [ x_tokens ; z_1, ..., z_K ; y_tokens ]
      ^^^^^^^^   ^^^^^^^^^^^^^   ^^^^^^^^
      discrete    continuous     discrete
                 (W_proj of      (gold answer
                  prev hidden)    during training)

Gradient flow: full backprop through z_t = W_proj(h_{t-1}). No sampling,
no torch.no_grad() in the latent path. The y-row attention mask blocks
attention to x columns so the latent is the only information channel.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Large negative fill value written into the additive attention masks at
# blocked (query, key) positions; allowed positions hold 0.
NEG = -1e9
@dataclass
class BLTConfig:
    """Hyper-parameters for building the base LM and the latent-reasoning setup.

    Field order and defaults are part of the public interface (positional
    construction and saved configs rely on them).
    """

    # --- base model -------------------------------------------------------
    base_model: str = "Qwen/Qwen2.5-1.5B-Instruct"

    # --- LoRA adapter -----------------------------------------------------
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj")

    # --- latent loop / attention masking ----------------------------------
    K_latents: int = 4
    block_y_to_x: bool = True
    # Close the z→x architectural leak path (see build_blt_mask).
    block_z_to_x: bool = False
    proj_init_scale: float = 0.02

    # --- numerics / runtime -----------------------------------------------
    dtype: str = "bfloat16"
    # Required for the 4D additive mask.
    attn_impl: str = "eager"
    # Trade compute for activation memory; needed for 7B.
    gradient_checkpointing: bool = False
def build_base(cfg: BLTConfig):
    """Load the tokenizer and base causal LM named by ``cfg``; optionally add LoRA.

    Order of operations:
      1. Load the tokenizer; if it has no pad token, reuse the EOS token.
      2. Load the model in ``cfg.dtype`` with ``cfg.attn_impl`` attention and
         switch ``config.use_cache`` off.
      3. If ``cfg.gradient_checkpointing`` is set, enable non-reentrant
         checkpointing and input-require-grads on the raw model.
      4. If ``cfg.use_lora`` is set, wrap the model with a PEFT LoRA adapter,
         print the trainable-parameter summary, and (when checkpointing)
         enable input-require-grads again on the wrapped model.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        KeyError: if ``cfg.dtype`` is not "bfloat16", "float16" or "float32".
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch_dtypes = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    dtype = torch_dtypes[cfg.dtype]
    wants_ckpt = getattr(cfg, "gradient_checkpointing", False)

    tok = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    load_kwargs = dict(
        torch_dtype=dtype,
        attn_implementation=cfg.attn_impl,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(cfg.base_model, **load_kwargs)
    model.config.use_cache = False

    if wants_ckpt:
        # Checkpointing is switched on BEFORE the peft wrap; peft propagates
        # the flag to the base model. use_reentrant=False avoids the
        # deprecation warning and is the recommended mode alongside custom
        # (4D) attention masks.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # The latent loop feeds inputs_embeds, so inputs must require grad
        # for checkpointed segments to propagate gradients.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    if not cfg.use_lora:
        return model, tok

    from peft import LoraConfig, get_peft_model, TaskType

    lora_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=list(cfg.lora_target_modules),
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    # Repeat on the wrapped model so the hook is registered there as well.
    if wants_ckpt and hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    return model, tok
class LatentProjector(nn.Module):
    """Project a last-layer hidden state into the next latent input embedding.

    ``use_mlp`` picks one of two shapes:

      * ``False`` (default): one bias-free ``d_model → d_model`` linear map.
      * ``True``: ``d_model → hidden_mult·d_model → d_model`` with a GELU in
        between. Both layers carry a bias, and both biases start at zero.

    Every weight matrix is drawn from ``N(0, init_scale²)``.
    """

    def __init__(self, d_model: int, init_scale: float = 0.02,
                 use_mlp: bool = False, hidden_mult: int = 4):
        super().__init__()
        self.use_mlp = use_mlp

        if not use_mlp:
            self.proj = nn.Linear(d_model, d_model, bias=False)
            nn.init.normal_(self.proj.weight, mean=0.0, std=init_scale)
            return

        width = d_model * hidden_mult
        expand = nn.Linear(d_model, width, bias=True)
        contract = nn.Linear(width, d_model, bias=True)
        # Indices 0 and 2 of the Sequential hold the two linear layers, so the
        # parameter names stay ``proj.0.*`` / ``proj.2.*``.
        self.proj = nn.Sequential(expand, nn.GELU(), contract)
        for layer in (expand, contract):
            nn.init.normal_(layer.weight, mean=0.0, std=init_scale)
            nn.init.zeros_(layer.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the projection to ``h`` (last dimension must be ``d_model``)."""
        return self.proj(h)
def _get_input_embeddings(model) -> nn.Module:
"""Returns the input embedding layer, working through PEFT wrap."""
inner = model.get_base_model() if hasattr(model, "get_base_model") else model
return inner.get_input_embeddings()
def _get_lm_head(model) -> nn.Module:
inner = model.get_base_model() if hasattr(model, "get_base_model") else model
return inner.get_output_embeddings()
def build_blt_mask(
    B: int, P: int, K: int, L_y: int, device, dtype, *,
    block_y_to_x: bool,
    block_z_to_x: bool = False,
    mask_value: Optional[float] = None,
) -> torch.Tensor:
    """4D additive attention mask [B, 1, T, T] with T = P + K + L_y.

    Allowed (query row, key col) pairs hold 0; blocked pairs hold a large
    negative value.

    - Lower-triangular causal everywhere.
    - If block_y_to_x: y rows (positions [P+K, P+K+L_y)) cannot attend to
      x cols (positions [0, P)).
    - If block_z_to_x: z rows (positions [P, P+K)) ALSO cannot attend to x.
      This closes the architectural "leak" path where z hidden states in
      pass 2 could attend to x and deliver x-info to y bypassing z's input
      content. With block_z_to_x=True, z hidden states depend only on z
      input embeddings + z self-attention. The z input (= π(h_{t-1}) from
      pass 1) becomes the *only* carrier of x→y information, forcing z's
      input value to actually matter at inference.

    Args:
        B, P, K, L_y: batch size and the lengths of the x / z / y segments.
        device, dtype: placement and (floating) dtype of the returned mask.
        block_y_to_x, block_z_to_x: see above.
        mask_value: value written at blocked positions. ``None`` (default)
            uses the module-level ``NEG``. In either case the value is
            clamped to the most negative finite number of ``dtype`` so that
            half-precision masks never overflow.

    Returns:
        A freshly allocated tensor of shape [B, 1, T, T]; callers may modify
        it in place.
    """
    T = P + K + L_y

    fill = NEG if mask_value is None else mask_value
    if dtype.is_floating_point:
        # -1e9 is outside float16's range; clamp instead of overflowing.
        fill = max(float(fill), torch.finfo(dtype).min)

    row = torch.arange(T, device=device).unsqueeze(1)  # [T, 1]
    col = torch.arange(T, device=device).unsqueeze(0)  # [1, T]
    blocked = col > row                                # [T, T], True = future position

    if block_y_to_x and P > 0 and L_y > 0:
        # y rows × x cols
        blocked[P + K : P + K + L_y, 0:P] = True
    if block_z_to_x and P > 0 and K > 0:
        # z rows × x cols
        blocked[P : P + K, 0:P] = True

    single = torch.zeros((T, T), device=device, dtype=dtype).masked_fill_(blocked, fill)
    # The pattern is identical for every batch row; repeat() materialises an
    # independent copy per row so in-place edits by the caller are safe.
    return single.unsqueeze(0).unsqueeze(0).repeat(B, 1, 1, 1)
def forward_with_latent(
    model,
    x_ids: torch.Tensor,            # [B, P]
    x_attn: torch.Tensor,           # [B, P] 1=keep, 0=pad (left-padded)
    y_ids: Optional[torch.Tensor],  # [B, L_y]  None at inference
    projector: LatentProjector,
    K: int,
    *,
    block_y_to_x: bool = True,
    block_z_to_x: bool = False,
    return_z: bool = True,
):
    """Run [x; z_1..z_K; y] in two passes.

    pass-1 (KV-cached, with grad): iteratively build z_1..z_K, each one the
        projection of the running last-position, last-layer hidden state.
    pass-2 (single full forward): [embed(x); z_1..z_K; embed(y)] under the
        custom 4D mask from ``build_blt_mask`` (plus x pad columns masked
        for every row). Skipped when ``y_ids`` is None.

    Args:
        model: causal LM, optionally PEFT-wrapped; its unwrapped form must
            expose ``.model``, ``get_input_embeddings()`` and
            ``get_output_embeddings()``.
        x_ids, x_attn: prompt ids and their left-padded attention mask.
        y_ids: gold answer ids, or None to compute the latents only.
        projector: maps a hidden state [B, d] to the next latent input [B, d].
        K: number of latent steps; ``K == 0`` yields an empty z.
        block_y_to_x, block_z_to_x: forwarded to ``build_blt_mask``.
        return_z: accepted for interface compatibility; currently unused —
            z is always returned.

    Returns:
        logits_y : [B, L_y, V]  (None if y_ids is None)
        z        : [B, K, d]    latent vectors (with grad)
        h_last_y : [B, L_y, d]  last-layer hidden states at y positions (None if y is None)
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    transformer = inner.model  # decoder stack, for direct hidden-state access
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    device = x_ids.device
    B, P = x_ids.shape

    # ---- Pass 1: iterative z construction with KV cache, grad retained ----
    # NOTE(review): HF decoders commonly force use_cache=False when gradient
    # checkpointing is active in training mode — confirm this loop still
    # receives a populated cache when cfg.gradient_checkpointing is on.
    x_embeds = embed_in(x_ids)
    out0 = transformer(
        inputs_embeds=x_embeds,
        attention_mask=x_attn,
        use_cache=True,
        return_dict=True,
    )
    past = out0.past_key_values
    # x is left-padded, so the final position is always a real token.
    h_prev = out0.last_hidden_state[:, -1, :]  # [B, d]

    z_list: List[torch.Tensor] = []
    cur_attn = x_attn
    for _ in range(K):
        z_t = projector(h_prev)  # [B, d]
        z_list.append(z_t)
        cur_attn = torch.cat(
            [cur_attn, torch.ones(B, 1, device=device, dtype=cur_attn.dtype)], dim=1
        )
        out_t = transformer(
            inputs_embeds=z_t.unsqueeze(1),
            attention_mask=cur_attn,
            past_key_values=past,
            use_cache=True,
            return_dict=True,
        )
        past = out_t.past_key_values
        h_prev = out_t.last_hidden_state[:, -1, :]

    if z_list:
        z = torch.stack(z_list, dim=1)  # [B, K, d]
    else:
        # K == 0: torch.stack rejects an empty list, so build the empty
        # latent block explicitly.
        z = x_embeds.new_zeros(B, 0, x_embeds.size(-1))

    if y_ids is None:
        return None, z, None

    # ---- Pass 2: full forward with custom mask, no past_kv ----
    y_embeds = embed_in(y_ids)
    L_y = y_ids.size(1)
    # Cast z to the embedding dtype so the concatenation is homogeneous.
    full_embeds = torch.cat([x_embeds, z.to(y_embeds.dtype), y_embeds], dim=1)
    full_4d = build_blt_mask(B, P, K, L_y, device=device, dtype=full_embeds.dtype,
                             block_y_to_x=block_y_to_x, block_z_to_x=block_z_to_x)

    # x pad columns must be hidden from EVERY row, latents included.
    pad_cols = (x_attn == 0)  # [B, P], True where pad
    if pad_cols.any():
        # Clamp so the fill value is representable in half precision.
        neg = max(NEG, torch.finfo(full_4d.dtype).min)
        pad_kv = torch.cat(
            [pad_cols, torch.zeros(B, K + L_y, device=device, dtype=torch.bool)], dim=1
        )
        full_4d = full_4d.masked_fill(pad_kv[:, None, None, :], neg)

    out2 = transformer(
        inputs_embeds=full_embeds,
        attention_mask=full_4d,
        use_cache=False,
        return_dict=True,
    )
    h_full = out2.last_hidden_state  # [B, T, d]

    # Position t predicts token t+1, so the predictions for y live at
    # positions [P+K-1, P+K+L_y-1). Slice the hidden states first so the
    # vocabulary projection runs on L_y positions instead of all T.
    first_pred = P + K - 1
    pred_slice = lm_head(h_full[:, first_pred : first_pred + L_y, :])  # [B, L_y, V]
    h_last_y = h_full[:, P + K : P + K + L_y, :]                       # [B, L_y, d]
    return pred_slice, z, h_last_y
@torch.no_grad()
def generate_with_latent(
    model,
    tokenizer,
    projector: LatentProjector,
    x_ids: torch.Tensor,            # [B, P]
    x_attn: torch.Tensor,
    K: int,
    *,
    block_y_to_x: bool = True,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    eos_token_id: Optional[int] = None,
    override_z: Optional[torch.Tensor] = None,  # [B, K, d] forced latents (ablation)
):
    """Greedy / temperature decoding with the latent loop.

    Args:
        model: causal LM, optionally PEFT-wrapped; its unwrapped form must
            expose ``.model``, ``get_input_embeddings()`` and
            ``get_output_embeddings()``.
        tokenizer: supplies ``eos_token_id`` and ``pad_token_id``.
        projector: maps a hidden state [B, d] to the next latent input.
        x_ids, x_attn: prompt ids and their (left-padded) attention mask.
        K: number of latent steps when the latents are computed here.
        block_y_to_x: if True, answer tokens cannot attend to the P prompt
            positions (2D mask is 0 over x, 1 over latents and prior y).
        max_new_tokens: maximum number of answer tokens; ``<= 0`` returns an
            empty [B, 0] tensor.
        temperature: ``<= 0`` selects argmax decoding, otherwise multinomial
            sampling from the temperature-scaled softmax.
        eos_token_id: overrides ``tokenizer.eos_token_id`` when given.
        override_z: if provided, skip the latent-loop pass and use these latents
            directly. For ablations: random-z (gaussian noise), zero-z
            (K=0), shuffled-z, etc.

    Returns:
        LongTensor [B, L_gen]. Rows that already emitted EOS are filled with
        the pad id (falling back to EOS if the tokenizer defines no pad id).
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    transformer = inner.model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    device = x_ids.device
    B, P = x_ids.shape
    eos = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id

    if max_new_tokens <= 0:
        # Nothing to decode; torch.stack would reject an empty list below.
        return torch.empty(B, 0, dtype=torch.long, device=device)

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos if eos is not None else 0

    x_embeds = embed_in(x_ids)

    # ---- z (computed or overridden) ----
    if override_z is not None:
        K_eff = override_z.size(1)
        # The prompt and the forced latents still have to go through the
        # transformer once to build the KV cache used for answer decoding.
        z = override_z.to(device=device, dtype=x_embeds.dtype)
        prefix_embeds = torch.cat([x_embeds, z], dim=1)

        # 4D additive mask: causal, with x pad columns hidden from all rows.
        T0 = P + K_eff
        row = torch.arange(T0, device=device).unsqueeze(1)
        col = torch.arange(T0, device=device).unsqueeze(0)
        pad_kv = torch.cat(
            [(x_attn == 0), torch.zeros(B, K_eff, device=device, dtype=torch.bool)],
            dim=1,
        )
        blocked = (col > row)[None, None, :, :] | pad_kv[:, None, None, :]
        # Clamp so the fill value is representable in half precision.
        neg = max(NEG, torch.finfo(x_embeds.dtype).min)
        add = torch.zeros(B, 1, T0, T0, device=device, dtype=x_embeds.dtype)
        add = add.masked_fill(blocked, neg)

        out0 = transformer(inputs_embeds=prefix_embeds, attention_mask=add,
                           use_cache=True, return_dict=True)
        past = out0.past_key_values
        h_last = out0.last_hidden_state[:, -1, :]
    else:
        K_eff = K
        out0 = transformer(inputs_embeds=x_embeds, attention_mask=x_attn,
                           use_cache=True, return_dict=True)
        past = out0.past_key_values
        h_last = out0.last_hidden_state[:, -1, :]
        cur_attn = x_attn
        for _ in range(K):
            z_t = projector(h_last)
            cur_attn = torch.cat(
                [cur_attn, torch.ones(B, 1, device=device, dtype=cur_attn.dtype)], dim=1
            )
            out_t = transformer(inputs_embeds=z_t.unsqueeze(1), attention_mask=cur_attn,
                                past_key_values=past, use_cache=True, return_dict=True)
            past = out_t.past_key_values
            h_last = out_t.last_hidden_state[:, -1, :]

    # ---- Answer phase: autoregressive decoding. ----
    # 2D mask over the kv length for y queries: 0 over x when block_y_to_x
    # (otherwise the original x mask), 1 over latents, 1 over prior y. It
    # grows by one column per decoded token.
    mask_dtype = x_attn.dtype
    if block_y_to_x:
        x_part = torch.zeros(B, P, device=device, dtype=mask_dtype)
    else:
        x_part = x_attn
    y_kv = torch.cat(
        [x_part, torch.ones(B, K_eff, device=device, dtype=mask_dtype)], dim=1
    )
    one_col = torch.ones(B, 1, device=device, dtype=mask_dtype)
    pad_fill = torch.full((B,), pad_id, dtype=torch.long, device=device)

    gen_ids = []
    last_logits = lm_head(h_last)  # [B, V]
    done = torch.zeros(B, dtype=torch.bool, device=device)
    for step in range(max_new_tokens):
        if temperature <= 0.0:
            nxt = last_logits.argmax(dim=-1)
        else:
            probs = torch.softmax(last_logits.float() / max(temperature, 1e-6), dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).squeeze(-1)
        # Finished rows keep emitting the pad id.
        nxt = torch.where(done, pad_fill, nxt)
        gen_ids.append(nxt)
        if eos is not None:
            done = done | (nxt == eos)
        # Stop before running a forward pass whose logits would go unused.
        if step == max_new_tokens - 1 or bool(done.all().item()):
            break
        y_emb = embed_in(nxt.unsqueeze(-1))  # [B, 1, d]
        y_kv = torch.cat([y_kv, one_col], dim=1)
        out = transformer(inputs_embeds=y_emb, attention_mask=y_kv,
                          past_key_values=past, use_cache=True, return_dict=True)
        past = out.past_key_values
        last_logits = lm_head(out.last_hidden_state[:, -1, :])
    return torch.stack(gen_ids, dim=1)  # [B, L_gen]