"""Losses for BLT-Reasoner. L_total = L_LM(y | z, mask=y→x-blocked) + λ_id · L_InfoNCE(g(z), f(y)) # identifiability lock + λ_kl · KL( q(z) || N(0, I) ) # magnitude regularizer The InfoNCE term is the structural defense against constant-z basins. f(y) is the frozen base model's mean-pooled hidden state over the gold answer (adapters disabled), so f is decoupled from gradient. g(z) is a small learned head over z (mean-pooled across K positions, then projected). For a constant-z policy: g(z) is the same for every sample in the batch, so all logits in InfoNCE are equal, and the loss is bounded below by log(B). Only a sample-dependent z can lower this. """ from __future__ import annotations from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @dataclass class LossWeights: lambda_lm: float = 1.0 lambda_id: float = 1.0 lambda_kl: float = 1e-3 tau_infonce: float = 0.07 class InfoNCEHead(nn.Module): """Two MLPs project z and f(y) into a shared embedding space for InfoNCE.""" def __init__(self, d_z: int, d_y: int, d_out: int = 256): super().__init__() self.g = nn.Sequential( nn.Linear(d_z, d_out), nn.GELU(), nn.Linear(d_out, d_out) ) self.h = nn.Sequential( nn.Linear(d_y, d_out), nn.GELU(), nn.Linear(d_out, d_out) ) def forward(self, z_pool: torch.Tensor, y_pool: torch.Tensor): z_emb = F.normalize(self.g(z_pool), dim=-1) y_emb = F.normalize(self.h(y_pool), dim=-1) return z_emb, y_emb def infonce_loss(z_emb: torch.Tensor, y_emb: torch.Tensor, tau: float = 0.07) -> torch.Tensor: """Symmetric InfoNCE (CLIP-style). z_emb, y_emb: [B, d] L2-normalized. Diagonal = positives. Lower bound for constant-z (z_emb identical across rows): -log(1/B) = log(B). """ B = z_emb.size(0) logits = z_emb @ y_emb.t() / tau # [B, B] targets = torch.arange(B, device=z_emb.device) loss_z2y = F.cross_entropy(logits, targets) loss_y2z = F.cross_entropy(logits.t(), targets) return 0.5 * (loss_z2y + loss_y2z) def slot_decorrelation_loss(z: torch.Tensor) -> torch.Tensor: """Penalize pairwise alignment between latent slots. z: [B, K, d] (the per-slot input embeddings from forward_with_latent) Returns: scalar = mean squared off-diagonal of per-batch row-normalized Gram matrices. When this is 0, all K slots are pairwise orthogonal (cos-sim = 0). When this is high, slots are aligned (redundant). Used as a soft regularizer when the capacity diagnostic shows slots are highly redundant (stable_rank << K). Forces the projector / loop to produce slots that span more directions. """ Zn = F.normalize(z.float(), dim=-1, eps=1e-6) # [B, K, d] G = torch.einsum("bkd,bjd->bkj", Zn, Zn) # [B, K, K] K = z.size(1) eye = torch.eye(K, device=z.device, dtype=G.dtype).unsqueeze(0) off_diag = G - eye # diagonal -> 0 # Mean over off-diagonal entries only (more interpretable than mean over all K²) n_off = K * (K - 1) return (off_diag.pow(2).sum(dim=(-1, -2)) / max(n_off, 1)).mean() def kl_to_gaussian(z: torch.Tensor) -> torch.Tensor: """Approximate KL(z || N(0, I)) treating z as a deterministic point. With deterministic z this reduces to 0.5 * (||z||² - d) per latent slot, which is the standard β-VAE regularizer when q is a delta. We use it as a soft magnitude prior so z doesn't grow unboundedly (no inherent norm in the residual stream). """ # z: [B, K, d] return 0.5 * (z.pow(2).sum(dim=-1) - z.size(-1)).mean() @torch.no_grad() def encode_chunks_per_slot(model, tokenizer, chunks_per_problem, device, max_len: int = 32) -> torch.Tensor: """Encode a [B][K] list-of-lists of y-chunk strings via the frozen base. chunks_per_problem: List of length B, each a list of length K of chunk strings. Returns: tensor [B, K, d] of mean-pooled last-layer hidden states (frozen). For per-CoT-step InfoNCE: each slot k receives a *different* target — the encoding of chunk k. Forces slots to specialize rather than all learning the same global y representation (which our capacity diagnostic showed yields stable_rank=6.73 across K=16 slots). """ import torch import contextlib if not chunks_per_problem: raise ValueError("chunks_per_problem is empty") B = len(chunks_per_problem) K = len(chunks_per_problem[0]) flat = [] for cps in chunks_per_problem: for c in cps: flat.append(c if c else "") enc = tokenizer(flat, return_tensors="pt", padding=True, truncation=True, max_length=max_len, add_special_tokens=False).to(device) inner = model.get_base_model() if hasattr(model, "get_base_model") else model ctx = model.disable_adapter() if hasattr(model, "disable_adapter") else contextlib.nullcontext() with ctx: out = inner.model( input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], use_cache=False, return_dict=True, ) mask = enc["attention_mask"].unsqueeze(-1).to(out.last_hidden_state.dtype) pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0) return pooled.detach().view(B, K, -1) def infonce_per_slot_loss( z: torch.Tensor, # [B, K, d_z] y_chunks: torch.Tensor, # [B, K, d_y] head: nn.Module, # InfoNCEHead, used flattened tau: float = 0.07, ) -> dict: """Per-slot InfoNCE: for each (problem b, slot k), positive = chunk_k of problem b's gold y. Negatives = all other (b', k') combinations. Construct a [B*K, B*K] similarity matrix and CE against the identity. A constant z (across slots) cannot satisfy this — neither can a z that correctly identifies "which problem" but mixes up slot index. Forces the projection π to specialize z's slots. Returns dict with the loss + within-batch/within-slot accuracy probes. """ B, K, _ = z.shape z_flat = z.reshape(B * K, -1).float() y_flat = y_chunks.reshape(B * K, -1).float() z_emb, y_emb = head(z_flat, y_flat) # [B*K, d_out] L2-normalized logits = z_emb @ y_emb.t() / tau # [B*K, B*K] targets = torch.arange(B * K, device=z.device) loss_z2y = F.cross_entropy(logits, targets) loss_y2z = F.cross_entropy(logits.t(), targets) loss = 0.5 * (loss_z2y + loss_y2z) with torch.no_grad(): # Top-1 accuracy of identifying the correct (problem, slot) acc_z2y = (logits.argmax(dim=1) == targets).float().mean() acc_y2z = (logits.t().argmax(dim=1) == targets).float().mean() # Within-problem accuracy: among the K chunks of the SAME problem, # does slot k pick chunk k? Tests whether slots distinguish their # *positions* (not just their problem). sim = (z_emb.view(B, K, -1) @ y_emb.view(B, K, -1).transpose(-1, -2)) # [B, K, K] pred = sim.argmax(dim=-1) acc_within = (pred == torch.arange(K, device=z.device).unsqueeze(0)).float().mean() return { "loss": loss, "acc_z2y": acc_z2y, "acc_y2z": acc_y2z, "acc_within_problem": acc_within, } @torch.no_grad() def encode_answer_for_infonce(model, tokenizer, y_text: list, device, max_len: int = 64) -> torch.Tensor: """Encode gold answer strings via the frozen base (LoRA adapters disabled), return mean-pooled last-layer hidden state. For GSM8K we typically feed only the final-number portion ("#### 42") so f(y) is anchored to the answer rather than full reasoning text. """ enc = tokenizer(y_text, return_tensors="pt", padding=True, truncation=True, max_length=max_len, add_special_tokens=False).to(device) inner = model.get_base_model() if hasattr(model, "get_base_model") else model # PEFT: disable adapters for this encoding pass. if hasattr(model, "disable_adapter"): ctx = model.disable_adapter() else: import contextlib ctx = contextlib.nullcontext() with ctx: out = inner.model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], use_cache=False, return_dict=True) # Mean-pool over non-pad tokens. mask = enc["attention_mask"].unsqueeze(-1).to(out.last_hidden_state.dtype) pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0) return pooled.detach() def lm_loss_on_y(logits_y: torch.Tensor, y_ids: torch.Tensor, y_attn: torch.Tensor) -> torch.Tensor: """Standard next-token CE over the y segment. logits_y: [B, L_y, V] (already sliced so logits[:, t] predicts y[:, t]) y_ids: [B, L_y] y_attn: [B, L_y] 1 where real, 0 where pad """ B, L_y, V = logits_y.shape flat_logits = logits_y.reshape(B * L_y, V) flat_targets = y_ids.reshape(B * L_y) per_tok = F.cross_entropy(flat_logits, flat_targets, reduction="none").reshape(B, L_y) # Mask out pad positions mask = y_attn.float() return (per_tok * mask).sum() / mask.sum().clamp_min(1.0)