File size: 16,263 Bytes
9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 | """Continuous latent reasoning model.
Sequence layout (no <think>/</think> tokens — latent positions are inputs_embeds):

    [ x_tokens ; z_1, ..., z_K ; y_tokens ]
      ^^^^^^^^   ^^^^^^^^^^^^^   ^^^^^^^^
      discrete    continuous     discrete
                 (W_proj of     (gold answer
                 prev hidden)   during training)

Gradient flow: full backprop through z_t = W_proj(h_{t-1}). No sampling,
no torch.no_grad() in the latent path. The y-row attention mask blocks
attention to x columns so the latent is the only information channel.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Additive attention bias written into mask entries where attention is
# disallowed (allowed entries are 0). It is a large finite number rather than
# -inf; presumably so that a fully-masked row still yields a finite softmax —
# confirm.
# NOTE(review): -1e9 lies outside float16's finite range (max ~65504), so a
# float16 mask needs a clamped value.
NEG = -1e9
@dataclass
class BLTConfig:
    """Settings for the base causal LM, its LoRA adapter and the latent loop.

    `build_base` reads `base_model`, `dtype`, `attn_impl`,
    `gradient_checkpointing`, `use_lora` and the `lora_*` fields. The remaining
    fields are not read by `build_base`; presumably the training / evaluation
    code forwards them to the latent forward pass and the projector — verify
    against the callers.
    """

    # ---- base model -------------------------------------------------------
    # Hub id or local path handed to both `from_pretrained` calls.
    base_model: str = "Qwen/Qwen2.5-1.5B-Instruct"

    # ---- LoRA adapter -----------------------------------------------------
    # When True the loaded model is wrapped with a PEFT LoRA adapter.
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Names of the modules that receive LoRA weights.
    lora_target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj")

    # ---- latent reasoning -------------------------------------------------
    # Presumably the number K of continuous latent positions — confirm.
    K_latents: int = 4
    # Presumably forwarded as `block_y_to_x` to the mask builder — confirm.
    block_y_to_x: bool = True
    # Close the z→x architectural leak path (see build_blt_mask).
    block_z_to_x: bool = False
    # Presumably the projector's weight-init std — confirm.
    proj_init_scale: float = 0.02

    # ---- runtime ----------------------------------------------------------
    # One of "bfloat16", "float16", "float32" (the keys `build_base` accepts).
    dtype: str = "bfloat16"
    # Passed as `attn_implementation`; "eager" is required for the 4D additive mask.
    attn_impl: str = "eager"
    # Trade compute for activation memory; the original note says this is
    # needed for 7B.
    gradient_checkpointing: bool = False
def build_base(cfg: BLTConfig):
    """Load the tokenizer and base causal LM, optionally adding a LoRA adapter.

    Args:
        cfg: configuration; `cfg.dtype` must be "bfloat16", "float16" or
            "float32" (any other value raises KeyError before anything is
            loaded).

    Returns:
        `(model, tok)` — the (possibly PEFT-wrapped) model and its tokenizer.
        The tokenizer's pad token is set to its EOS token when it has none.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch_dtypes = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    model_dtype = torch_dtypes[cfg.dtype]
    checkpointing = bool(getattr(cfg, "gradient_checkpointing", False))

    tok = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    load_kwargs = dict(
        torch_dtype=model_dtype,
        attn_implementation=cfg.attn_impl,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(cfg.base_model, **load_kwargs)
    model.config.use_cache = False

    if checkpointing:
        # Checkpointing has to be switched on BEFORE the PEFT wrap; PEFT then
        # propagates the flag to the base model. use_reentrant=False avoids the
        # deprecation warning and is the recommended mode with custom 4D masks.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        # The latent loop feeds inputs_embeds, so the embedding outputs must
        # require grad for checkpointed segments to receive gradients.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

    if not cfg.use_lora:
        return model, tok

    from peft import LoraConfig, get_peft_model, TaskType

    lora_cfg = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=list(cfg.lora_target_modules),
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    # Repeat the input-grad hook on the wrapped model, as the original does.
    if checkpointing and hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    return model, tok
class LatentProjector(nn.Module):
    """Project a last-layer hidden state into the next latent input embedding.

    Two architectures are available, chosen with `use_mlp`:

    * ``use_mlp=False`` (default): one bias-free ``d_model -> d_model`` linear map.
    * ``use_mlp=True``: ``d_model -> hidden_mult * d_model -> d_model`` with a
      GELU in between; both linear layers carry a bias that starts at zero.

    Every weight matrix is drawn from ``N(0, init_scale**2)``.
    """

    def __init__(self, d_model: int, init_scale: float = 0.02,
                 use_mlp: bool = False, hidden_mult: int = 4):
        super().__init__()
        self.use_mlp = use_mlp

        if not use_mlp:
            linear = nn.Linear(d_model, d_model, bias=False)
            self.proj = linear
            nn.init.normal_(linear.weight, mean=0.0, std=init_scale)
            return

        width = d_model * hidden_mult
        # Both layers are constructed before either is re-initialised, which
        # keeps the random-number consumption order of the original code.
        expand = nn.Linear(d_model, width, bias=True)
        contract = nn.Linear(width, d_model, bias=True)
        self.proj = nn.Sequential(expand, nn.GELU(), contract)
        for layer in (expand, contract):
            nn.init.normal_(layer.weight, mean=0.0, std=init_scale)
            nn.init.zeros_(layer.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the projection to `h` (last dimension of size d_model)."""
        return self.proj(h)
def _get_input_embeddings(model) -> nn.Module:
    """Return the input-embedding module, unwrapping a PEFT wrapper if present.

    A model exposing `get_base_model()` is unwrapped through it first; any
    other model is queried directly.
    """
    if hasattr(model, "get_base_model"):
        model = model.get_base_model()
    return model.get_input_embeddings()
def _get_lm_head(model) -> nn.Module:
    """Return the output-embedding (LM head) module, unwrapping PEFT if present.

    A model exposing `get_base_model()` is unwrapped through it first; any
    other model is queried directly.
    """
    if hasattr(model, "get_base_model"):
        model = model.get_base_model()
    return model.get_output_embeddings()
def build_blt_mask(
    B: int, P: int, K: int, L_y: int, device, dtype, *,
    block_y_to_x: bool,
    block_z_to_x: bool = False,
) -> torch.Tensor:
    """Build the 4D additive attention mask [B, 1, T, T] with T = P + K + L_y.

    Entries are 0 where attention is allowed and a large negative value where
    it is not. Rows are queries, columns are keys, laid out as
    ``[x (P) ; z (K) ; y (L_y)]``.

    - Lower-triangular causal everywhere (key index <= query index).
    - If block_y_to_x: y rows (positions [P+K, P+K+L_y)) cannot attend to
      x cols (positions [0, P)).
    - If block_z_to_x: z rows (positions [P, P+K)) ALSO cannot attend to x.
      This closes the architectural "leak" path where z hidden states in
      pass 2 could attend to x and deliver x-info to y bypassing z's input
      content, so the z input embeddings become the only x→y carrier.

    Args:
        B: batch size.
        P: number of x (prompt) positions.
        K: number of latent positions.
        L_y: number of y (answer) positions.
        device: device for the returned tensor.
        dtype: dtype of the returned tensor.
        block_y_to_x: block y rows from x columns.
        block_z_to_x: block z rows from x columns.

    Returns:
        A freshly allocated, contiguous tensor of shape [B, 1, T, T]; callers
        may modify it in place.
    """
    T = P + K + L_y

    # NEG (-1e9) is outside float16's finite range, so writing it into a
    # float16 tensor overflows. Clamp to the dtype's most negative finite
    # value; for bfloat16 / float32 this leaves NEG unchanged.
    if dtype.is_floating_point:
        fill = max(NEG, torch.finfo(dtype).min)
    else:
        fill = NEG

    # Boolean "may attend" matrix: causal first, then carve out blocked regions.
    pos = torch.arange(T, device=device)
    allowed = pos.unsqueeze(0) <= pos.unsqueeze(1)  # [T, T]: col <= row
    if block_y_to_x and P > 0 and L_y > 0:
        allowed[P + K:, :P] = False  # y rows × x cols
    if block_z_to_x and P > 0 and K > 0:
        allowed[P:P + K, :P] = False  # z rows × x cols

    mask_2d = torch.zeros((T, T), device=device, dtype=dtype).masked_fill(~allowed, fill)
    # repeat() materialises an independent copy per batch element, so later
    # per-example in-place edits (e.g. pad columns) do not alias each other.
    return mask_2d.repeat(B, 1, 1, 1)
def forward_with_latent(
    model,
    x_ids: torch.Tensor,            # [B, P]
    x_attn: torch.Tensor,           # [B, P]  1=keep, 0=pad (left-padded)
    y_ids: Optional[torch.Tensor],  # [B, L_y]  None at inference
    projector: LatentProjector,
    K: int,
    *,
    block_y_to_x: bool = True,
    block_z_to_x: bool = False,
    return_z: bool = True,
):
    """Run [x; z_1..z_K; y] in two passes.

    pass-1 (KV-cached, with grad): iteratively build z_1..z_K, each being
        `projector` applied to the running last-position hidden state.
    pass-2 (single full forward): [embed(x); z_1..z_K; embed(y)] with the
        custom 4D mask from `build_blt_mask`, plus x pad columns masked for
        every query row.

    Args:
        model: causal LM, optionally PEFT-wrapped (unwrapped via
            `get_base_model`). The unwrapped model must expose `.model` (the
            backbone returning `last_hidden_state` / `past_key_values`).
        x_ids: prompt token ids.
        x_attn: prompt attention mask; the code reads the hidden state at the
            last position, so prompts are expected to be left-padded.
        y_ids: answer token ids, or None to only compute the latents.
        projector: maps a hidden state [B, d] to the next latent input [B, d].
        K: number of latent steps; K == 0 yields an empty latent tensor.
        block_y_to_x, block_z_to_x: forwarded to `build_blt_mask`.
        return_z: accepted for interface compatibility; not read here (z is
            always returned).

    Returns:
        logits_y : [B, L_y, V]  logits at positions [P+K-1, P+K+L_y-1), i.e.
                   the predictions for y_ids (None if y_ids is None)
        z        : [B, K, d]    latent vectors (with grad)
        h_last_y : [B, L_y, d]  last-layer hidden states at y positions
                   (None if y_ids is None)

    NOTE(review): pass 1 relies on the backbone returning a populated KV cache
    for use_cache=True. Some Hugging Face versions force use_cache=False when
    gradient checkpointing is active in training mode — confirm this
    combination behaves as intended.
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    transformer = inner.model  # backbone without the LM head, e.g. Qwen2Model
    device = x_ids.device
    B, P = x_ids.shape

    # ---- Pass 1: iterative z construction with KV cache, grad retained ----
    x_embeds = embed_in(x_ids)
    out0 = transformer(
        inputs_embeds=x_embeds,
        attention_mask=x_attn,
        use_cache=True,
        return_dict=True,
    )
    past = out0.past_key_values
    # With left padding the final position is always a real token.
    h_prev = out0.last_hidden_state[:, -1, :]  # [B, d]

    z_list: List[torch.Tensor] = []
    cur_attn = x_attn
    one_col = torch.ones(B, 1, device=device, dtype=x_attn.dtype)
    for _ in range(K):
        z_t = projector(h_prev)  # [B, d]
        z_list.append(z_t)
        cur_attn = torch.cat([cur_attn, one_col], dim=1)
        out_t = transformer(
            inputs_embeds=z_t.unsqueeze(1),
            attention_mask=cur_attn,
            past_key_values=past,
            use_cache=True,
            return_dict=True,
        )
        past = out_t.past_key_values
        h_prev = out_t.last_hidden_state[:, -1, :]

    if z_list:
        z = torch.stack(z_list, dim=1)  # [B, K, d]
    else:
        # K == 0: torch.stack rejects an empty list, so return an empty
        # latent block of the right shape instead of crashing.
        z = h_prev.new_zeros((B, 0, h_prev.size(-1)))

    if y_ids is None:
        return None, z, None

    # ---- Pass 2: full forward with custom mask, no past_kv ----
    y_embeds = embed_in(y_ids)
    L_y = y_ids.size(1)
    # Cast z to the embedding dtype so the three segments can be concatenated.
    full_embeds = torch.cat([x_embeds, z.to(y_embeds.dtype), y_embeds], dim=1)
    full_4d = build_blt_mask(B, P, K, L_y, device=device, dtype=full_embeds.dtype,
                             block_y_to_x=block_y_to_x, block_z_to_x=block_z_to_x)

    # x pad columns must be hidden from EVERY query row, including latents.
    pad_cols = x_attn == 0  # [B, P]
    if pad_cols.any():
        # Clamp the fill so a float16 mask does not overflow; for bfloat16 /
        # float32 this is exactly NEG.
        fill = max(NEG, torch.finfo(full_4d.dtype).min)
        pad_kv = torch.cat(
            [pad_cols, torch.zeros(B, K + L_y, device=device, dtype=torch.bool)], dim=1
        )
        full_4d = full_4d.masked_fill(pad_kv[:, None, None, :], fill)

    out2 = transformer(
        inputs_embeds=full_embeds,
        attention_mask=full_4d,
        use_cache=False,
        return_dict=True,
    )
    h_full = out2.last_hidden_state  # [B, T, d]

    # Position t predicts token t+1, so the predictions for y live at
    # positions [P+K-1, P+K+L_y-1). Apply the LM head to that slice only
    # rather than materialising [B, T, V] logits and discarding most of them.
    first = P + K - 1
    pred_slice = lm_head(h_full[:, first:first + L_y, :])  # [B, L_y, V]
    h_last_y = h_full[:, P + K:P + K + L_y, :]             # [B, L_y, d]
    return pred_slice, z, h_last_y
@torch.no_grad()
def generate_with_latent(
    model,
    tokenizer,
    projector: LatentProjector,
    x_ids: torch.Tensor,  # [B, P]
    x_attn: torch.Tensor,
    K: int,
    *,
    block_y_to_x: bool = True,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    eos_token_id: Optional[int] = None,
    override_z: Optional[torch.Tensor] = None,  # [B, K, d] forced latents (ablation)
):
    """Greedy / temperature decoding with the latent loop.

    Args:
        model: causal LM, optionally PEFT-wrapped; the unwrapped model must
            expose `.model` (the backbone).
        tokenizer: supplies `eos_token_id` (when `eos_token_id` is None) and
            `pad_token_id` (written for sequences that already finished).
        projector: maps a hidden state [B, d] to the next latent input.
        x_ids, x_attn: prompt ids and attention mask; the last position is
            read as the prompt summary, so left padding is expected.
        K: number of latent steps (ignored when `override_z` is given).
        block_y_to_x: if True, answer tokens cannot attend to the x positions.
        max_new_tokens: decoding budget; <= 0 returns an empty [B, 0] tensor.
        temperature: <= 0 selects argmax decoding, otherwise multinomial
            sampling from softmax(logits / temperature).
        eos_token_id: overrides the tokenizer's EOS id.
        override_z: if provided, skip the latent loop and use these latents
            directly. For ablations: random-z (gaussian noise), zero-z
            (K=0), shuffled-z, etc.

    Returns:
        LongTensor [B, L_gen] of generated ids. Rows that finished early are
        padded with the pad id; generation stops once every row emitted EOS.

    NOTE(review): `build_blt_mask` supports block_z_to_x for training, but
    there is no equivalent here — the cached z keys/values always come from a
    pass in which z can attend to x. Confirm this matches how the checkpoint
    was trained.
    """
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    transformer = inner.model
    embed_in = inner.get_input_embeddings()
    lm_head = inner.get_output_embeddings()
    device = x_ids.device
    B, P = x_ids.shape

    if max_new_tokens <= 0:
        # Nothing to decode; torch.stack on an empty list would raise.
        return torch.empty((B, 0), dtype=torch.long, device=device)

    eos = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id
    # Tokenizers without a pad token would otherwise crash in full_like below.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos if eos is not None else 0

    x_embeds = embed_in(x_ids)
    mask_dtype = x_attn.dtype

    # ---- z (computed or overridden) ----
    if override_z is not None:
        K_eff = override_z.size(1)
        # x and the forced latents still have to go through the backbone once
        # to build the KV cache used for answer generation.
        full_embeds = torch.cat([x_embeds, override_z.to(x_embeds.dtype)], dim=1)
        # Plain causal mask over [x; z] (L_y = 0, nothing blocked).
        add = build_blt_mask(B, P, K_eff, 0, device=device, dtype=x_embeds.dtype,
                             block_y_to_x=False)
        pad_cols = x_attn == 0
        if pad_cols.any():
            # Clamp so a float16 mask does not overflow; equals NEG otherwise.
            fill = max(NEG, torch.finfo(add.dtype).min)
            pad_kv = torch.cat(
                [pad_cols, torch.zeros(B, K_eff, device=device, dtype=torch.bool)], dim=1
            )
            add = add.masked_fill(pad_kv[:, None, None, :], fill)
        out0 = transformer(inputs_embeds=full_embeds, attention_mask=add,
                           use_cache=True, return_dict=True)
        past = out0.past_key_values
        h_last = out0.last_hidden_state[:, -1, :]
    else:
        K_eff = K
        out0 = transformer(inputs_embeds=x_embeds, attention_mask=x_attn,
                           use_cache=True, return_dict=True)
        past = out0.past_key_values
        h_last = out0.last_hidden_state[:, -1, :]
        cur_attn = x_attn
        one_col = torch.ones(B, 1, device=device, dtype=mask_dtype)
        for _ in range(K):
            z_t = projector(h_last)
            cur_attn = torch.cat([cur_attn, one_col], dim=1)
            out_t = transformer(inputs_embeds=z_t.unsqueeze(1), attention_mask=cur_attn,
                                past_key_values=past, use_cache=True, return_dict=True)
            past = out_t.past_key_values
            h_last = out_t.last_hidden_state[:, -1, :]

    # ---- Answer phase: autoregressive decoding. ----
    # With block_y_to_x the answer queries must not see the first P cached
    # positions, so the 2D key mask is 0 over x and 1 over latents and prior
    # y tokens. Without it, the original prompt mask is used for x.
    x_part = torch.zeros(B, P, device=device, dtype=mask_dtype) if block_y_to_x else x_attn
    y_kv_mask = torch.cat(
        [x_part, torch.ones(B, K_eff, device=device, dtype=mask_dtype)], dim=1
    )

    gen_ids = []
    last_logits = lm_head(h_last)  # [B, V]
    done = torch.zeros(B, dtype=torch.bool, device=device)
    for _ in range(max_new_tokens):
        if temperature <= 0.0:
            nxt = last_logits.argmax(dim=-1)
        else:
            probs = torch.softmax(last_logits.float() / max(temperature, 1e-6), dim=-1)
            nxt = torch.multinomial(probs, num_samples=1).squeeze(-1)
        # Rows that already emitted EOS only produce padding from here on.
        nxt = torch.where(done, torch.full_like(nxt, pad_id), nxt)
        gen_ids.append(nxt)

        done = done | (nxt == eos)
        if bool(done.all().item()):
            break

        y_emb = embed_in(nxt.unsqueeze(-1))  # [B, 1, d]
        y_kv_mask = torch.cat(
            [y_kv_mask, torch.ones(B, 1, device=device, dtype=y_kv_mask.dtype)], dim=1
        )
        out = transformer(inputs_embeds=y_emb, attention_mask=y_kv_mask,
                          past_key_values=past, use_cache=True, return_dict=True)
        past = out.past_key_values
        last_logits = lm_head(out.last_hidden_state[:, -1, :])

    return torch.stack(gen_ids, dim=1)  # [B, L_gen]
|