| """Losses for BLT-Reasoner. |
| |
| L_total = L_LM(y | z, mask=y→x-blocked) |
| + λ_id · L_InfoNCE(g(z), f(y)) # identifiability lock |
| + λ_kl · KL( q(z) || N(0, I) ) # magnitude regularizer |
| |
| The InfoNCE term is the structural defense against constant-z basins. f(y) |
| is the frozen base model's mean-pooled hidden state over the gold answer |
| (adapters disabled), so f is decoupled from gradient. g(z) is a small |
| learned head over z (mean-pooled across K positions, then projected). |
| |
For a constant-z policy, g(z) is the same for every sample in the batch, so
every row of the InfoNCE logit matrix is identical and the symmetric loss is
bounded below by log(B). Only a sample-dependent z can push it lower.
| """ |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class LossWeights:
    """Scalar weights and the InfoNCE temperature for the loss terms in this module.

    This container only carries the numbers; how they are combined into a
    total objective is up to the caller.
    """

    # Weight on the next-token language-modeling loss over the answer.
    lambda_lm: float = 1.0
    # Weight on the InfoNCE identifiability term.
    lambda_id: float = 1.0
    # Weight on the z-magnitude (KL-to-Gaussian-shaped) regularizer.
    lambda_kl: float = 1e-3
    # Temperature dividing the InfoNCE similarity logits.
    tau_infonce: float = 0.07
|
|
|
|
class InfoNCEHead(nn.Module):
    """Project latent and answer features into a shared unit-norm space.

    ``g`` maps pooled latent features (width ``d_z``) and ``h`` maps pooled
    answer features (width ``d_y``) to ``d_out`` dimensions. Both outputs are
    L2-normalized along the last axis, so their dot products are cosine
    similarities.
    """

    def __init__(self, d_z: int, d_y: int, d_out: int = 256):
        super().__init__()
        # Build g before h so parameter initialization consumes the RNG in
        # the same order as before.
        self.g = self._two_layer_mlp(d_z, d_out)
        self.h = self._two_layer_mlp(d_y, d_out)

    @staticmethod
    def _two_layer_mlp(d_in: int, d_out: int) -> nn.Sequential:
        """Linear(d_in -> d_out) -> GELU -> Linear(d_out -> d_out)."""
        return nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.GELU(),
            nn.Linear(d_out, d_out),
        )

    def forward(self, z_pool: torch.Tensor, y_pool: torch.Tensor):
        """Return ``(z_emb, y_emb)``, each L2-normalized over the last dim."""
        projected_z = self.g(z_pool)
        projected_y = self.h(y_pool)
        return (
            F.normalize(projected_z, dim=-1),
            F.normalize(projected_y, dim=-1),
        )
|
|
|
|
def infonce_loss(z_emb: torch.Tensor, y_emb: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """Symmetric (CLIP-style) InfoNCE between two batches of embeddings.

    Args:
        z_emb: ``[B, d]`` latent embeddings, expected to be L2-normalized.
        y_emb: ``[B, d]`` answer embeddings, expected to be L2-normalized.
        tau: temperature dividing the similarity logits.

    Row ``i`` of ``z_emb`` is paired with row ``i`` of ``y_emb``, so the
    diagonal of the similarity matrix holds the positives.

    Returns:
        The mean of the z->y and y->z cross-entropies.

    If every row of ``z_emb`` is identical (constant-z), the loss cannot go
    below ``log(B)``.
    """
    batch = z_emb.size(0)
    scores = torch.matmul(z_emb, y_emb.t()) / tau
    labels = torch.arange(batch, device=z_emb.device)
    forward_ce = F.cross_entropy(scores, labels)
    backward_ce = F.cross_entropy(scores.t(), labels)
    return 0.5 * (forward_ce + backward_ce)
|
|
|
|
def slot_decorrelation_loss(z: torch.Tensor) -> torch.Tensor:
    """Penalize pairwise alignment between latent slots.

    Args:
        z: ``[B, K, d]`` per-slot vectors (the per-slot input embeddings from
            ``forward_with_latent``).

    Returns:
        Scalar. For each batch element, this is the mean squared cosine
        similarity over the ``K * (K - 1)`` ordered off-diagonal slot pairs,
        averaged over the batch. It is 0 when all K slots are pairwise
        orthogonal (or zero) and 1 when all slots are (anti-)parallel,
        i.e. redundant. With ``K == 1`` there are no pairs and the result
        is 0.

    Use it as a soft regularizer when the capacity diagnostic shows slots
    are highly redundant (stable_rank << K). It pushes whatever produces
    ``z`` toward slots that span more directions.
    """
    Zn = F.normalize(z.float(), dim=-1, eps=1e-6)
    G = torch.einsum("bkd,bjd->bkj", Zn, Zn)
    K = z.size(1)
    # Zero the diagonal explicitly instead of subtracting I. A zero-norm slot
    # (e.g. padding) normalizes to the zero vector, so its self-similarity is
    # 0, and `G - I` would count a spurious (-1)^2 on the diagonal as an
    # "off-diagonal" penalty.
    diag = torch.eye(K, device=z.device, dtype=torch.bool)
    off_diag = G.masked_fill(diag, 0.0)
    n_off = K * (K - 1)
    return (off_diag.pow(2).sum(dim=(-1, -2)) / max(n_off, 1)).mean()
|
|
|
|
def kl_to_gaussian(z: torch.Tensor) -> torch.Tensor:
    """Soft magnitude prior on z, shaped like a Gaussian KL term.

    Computes ``0.5 * (||z||^2 - d)`` over the last axis (``d = z.size(-1)``)
    and averages it over all remaining positions.

    Note: this is not a true KL divergence. A point mass has infinite KL to
    ``N(0, I)``, and for ``q = N(z, I)`` the KL is ``0.5 * ||z||^2``. The
    ``-d`` offset only shifts the value and carries no gradient, so the
    result can be negative. The term exists to keep z from growing without
    bound, because the residual stream has no inherent norm.
    """
    d = z.size(-1)
    squared_norm = z.pow(2).sum(dim=-1)
    per_slot = 0.5 * (squared_norm - d)
    return per_slot.mean()
|
|
|
|
@torch.no_grad()
def encode_chunks_per_slot(model, tokenizer, chunks_per_problem, device, max_len: int = 32) -> torch.Tensor:
    """Encode a [B][K] list-of-lists of y-chunk strings via the frozen base.

    Args:
        model: wrapper whose (base) model exposes a ``.model`` submodule
            returning ``last_hidden_state``. If ``get_base_model`` exists it
            is used. If ``disable_adapter`` exists, the forward pass runs
            inside that context (presumably a PEFT/LoRA model, so the
            encoding comes from the base weights).
        tokenizer: batched tokenizer supporting padding and truncation.
        chunks_per_problem: list of length B. Each entry is a list of exactly
            K chunk strings, with the same K for every problem. Empty chunks
            are replaced by the literal ``"<pad>"``.
        device: device the tokenized batch is moved to.
        max_len: per-chunk truncation length in tokens.

    Returns:
        ``[B, K, d]`` detached tensor of attention-mask-weighted mean-pooled
        last-layer hidden states.

    Raises:
        ValueError: if ``chunks_per_problem`` is empty, K is 0, or the inner
            lists differ in length. A ragged batch would otherwise either
            fail in ``view`` or, if the totals happen to match, silently
            assign chunks to the wrong slots.

    For per-CoT-step InfoNCE, each slot k receives a *different* target:
    the encoding of chunk k. This pushes slots to specialize instead of all
    learning the same global y representation; the capacity diagnostic
    showed that learning the global y yields stable_rank=6.73 across K=16
    slots.
    """
    import contextlib

    if not chunks_per_problem:
        raise ValueError("chunks_per_problem is empty")
    B = len(chunks_per_problem)
    K = len(chunks_per_problem[0])
    if K == 0:
        raise ValueError("chunks_per_problem[0] has no chunks")
    ragged = [i for i, cps in enumerate(chunks_per_problem) if len(cps) != K]
    if ragged:
        raise ValueError(
            f"all problems must have the same number of chunks (K={K}); "
            f"mismatched problem indices: {ragged[:10]}"
        )
    flat = [c if c else "<pad>" for cps in chunks_per_problem for c in cps]
    enc = tokenizer(flat, return_tensors="pt", padding=True, truncation=True,
                    max_length=max_len, add_special_tokens=False).to(device)
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    ctx = model.disable_adapter() if hasattr(model, "disable_adapter") else contextlib.nullcontext()
    with ctx:
        out = inner.model(
            input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
            use_cache=False, return_dict=True,
        )
    # Mean over real tokens only; the clamp avoids 0/0 for an all-padding row.
    mask = enc["attention_mask"].unsqueeze(-1).to(out.last_hidden_state.dtype)
    pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
    return pooled.detach().view(B, K, -1)
|
|
|
|
def infonce_per_slot_loss(
    z: torch.Tensor,
    y_chunks: torch.Tensor,
    head: nn.Module,
    tau: float = 0.07,
) -> dict:
    """Per-slot symmetric InfoNCE over every (problem, slot) pair in the batch.

    Args:
        z: ``[B, K, d_z]`` per-slot latents.
        y_chunks: per-slot chunk targets, flattened to ``B * K`` rows in the
            same order as ``z``.
        head: module called as ``head(z_flat, y_flat)`` that returns
            ``(z_emb, y_emb)``, each ``[B*K, d]`` (such as ``InfoNCEHead``).
        tau: logit temperature.

    Each (problem b, slot k) row is positive only with chunk k of problem b.
    Every other (b', k') in the batch is a negative. A z that is constant
    across slots cannot solve this, and neither can one that identifies the
    problem but confuses slot indices, so the objective forces the slots to
    specialize.

    Returns:
        dict with:
          - ``"loss"``: mean of the z->y and y->z cross-entropies over the
            ``[B*K, B*K]`` logit matrix.
          - ``"acc_z2y"`` / ``"acc_y2z"``: no-grad probes; the fraction of
            rows whose argmax hits the exact positive among all B*K
            candidates.
          - ``"acc_within_problem"``: no-grad probe; the fraction of slots
            whose best-matching chunk, restricted to the same problem, is
            their own index.
    """
    n_problems, n_slots, _ = z.shape
    n_rows = n_problems * n_slots
    z_emb, y_emb = head(
        z.reshape(n_rows, -1).float(),
        y_chunks.reshape(n_rows, -1).float(),
    )
    scores = z_emb @ y_emb.t() / tau
    scores_t = scores.t()
    labels = torch.arange(n_rows, device=z.device)
    loss = 0.5 * (F.cross_entropy(scores, labels) + F.cross_entropy(scores_t, labels))

    with torch.no_grad():
        acc_z2y = scores.argmax(dim=1).eq(labels).float().mean()
        acc_y2z = scores_t.argmax(dim=1).eq(labels).float().mean()
        # Per-problem [K, K] similarity blocks. The temperature is omitted
        # because it does not change the argmax.
        block_sim = z_emb.view(n_problems, n_slots, -1) @ y_emb.view(
            n_problems, n_slots, -1
        ).transpose(-1, -2)
        slot_ids = torch.arange(n_slots, device=z.device)
        acc_within = block_sim.argmax(dim=-1).eq(slot_ids).float().mean()

    return {
        "loss": loss,
        "acc_z2y": acc_z2y,
        "acc_y2z": acc_y2z,
        "acc_within_problem": acc_within,
    }
|
|
|
|
@torch.no_grad()
def encode_answer_for_infonce(model, tokenizer, y_text: list, device, max_len: int = 64) -> torch.Tensor:
    """Mean-pool the frozen base model's last hidden state over each answer.

    Args:
        model: wrapper. If it has ``get_base_model``, that model's ``.model``
            submodule is called; otherwise ``model.model`` is called. If it
            has ``disable_adapter``, the forward pass runs inside that
            context (presumably a PEFT model, with the LoRA adapters turned
            off).
        tokenizer: batched tokenizer (padding + truncation, no special tokens).
        y_text: list of B answer strings.
        device: device for the tokenized inputs.
        max_len: truncation length in tokens.

    Returns:
        ``[B, d]`` detached tensor of hidden states averaged over the
        positions where ``attention_mask`` is 1.

    For GSM8K we typically feed only the final-number portion
    (``"#### 42"``), so f(y) is anchored to the answer rather than the full
    reasoning text.
    """
    import contextlib

    batch = tokenizer(
        y_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
        add_special_tokens=False,
    ).to(device)
    base = model.get_base_model() if hasattr(model, "get_base_model") else model
    adapter_off = (
        model.disable_adapter()
        if hasattr(model, "disable_adapter")
        else contextlib.nullcontext()
    )
    with adapter_off:
        hidden = base.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=False,
            return_dict=True,
        ).last_hidden_state

    # Mean over real tokens only; the clamp guards against an all-padding row.
    weights = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    counts = weights.sum(dim=1).clamp_min(1.0)
    return ((hidden * weights).sum(dim=1) / counts).detach()
|
|
|
|
def lm_loss_on_y(logits_y: torch.Tensor, y_ids: torch.Tensor, y_attn: torch.Tensor) -> torch.Tensor:
    """Token-averaged next-token cross-entropy over the answer (y) segment.

    Args:
        logits_y: ``[B, L_y, V]`` logits, already aligned so that
            ``logits_y[:, t]`` predicts ``y_ids[:, t]``.
        y_ids: ``[B, L_y]`` target token ids. Padding positions still pass
            through cross_entropy, so they must hold valid ids (or the
            ignore index).
        y_attn: ``[B, L_y]`` mask: 1 for real tokens, 0 for padding.

    Returns:
        Scalar: the sum of per-token CE over masked-in positions divided by
        the number of such positions. The denominator is clamped to at
        least 1, so an all-padding batch does not divide by zero.
    """
    n_batch, n_tok, vocab = logits_y.shape
    token_ce = F.cross_entropy(
        logits_y.reshape(n_batch * n_tok, vocab),
        y_ids.reshape(n_batch * n_tok),
        reduction="none",
    ).reshape(n_batch, n_tok)
    weights = y_attn.float()
    n_real = weights.sum().clamp_min(1.0)
    return (token_ce * weights).sum() / n_real
|
|