LauraGG's picture
Refresh code/ with latest BLT-Reasoner sources (post-campaign)
bc7101b verified
"""Losses for BLT-Reasoner.
L_total = L_LM(y | z, mask=y→x-blocked)
+ λ_id · L_InfoNCE(g(z), f(y)) # identifiability lock
+ λ_kl · KL( q(z) || N(0, I) ) # magnitude regularizer
The InfoNCE term is the structural defense against constant-z basins. f(y)
is the frozen base model's mean-pooled hidden state over the gold answer
(adapters disabled), so f is decoupled from gradient. g(z) is a small
learned head over z (mean-pooled across K positions, then projected).
For a constant-z policy: g(z) is the same for every sample in the batch,
so every row of the [B, B] logit matrix is identical (the logits no longer
depend on which z is the query), and the loss is bounded below by log(B).
Only a sample-dependent z can lower this.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class LossWeights:
    """Scalar hyperparameters for combining the loss terms.

    The field names mirror the symbols in the module docstring's L_total
    formula. This class only stores the values; the weighted sum itself is
    not computed in the visible part of this file.
    """

    # Presumably scales the L_LM term (shown unweighted in the module
    # docstring formula) — confirm at the call site.
    lambda_lm: float = 1.0
    # Presumably λ_id, the weight on the InfoNCE identifiability term.
    lambda_id: float = 1.0
    # Presumably λ_kl, the weight on the KL magnitude regularizer.
    lambda_kl: float = 1e-3
    # Presumably the InfoNCE softmax temperature; same default as the `tau`
    # argument of infonce_loss.
    tau_infonce: float = 0.07
class InfoNCEHead(nn.Module):
    """Pair of two-layer MLP projectors feeding the InfoNCE objective.

    ``g`` maps pooled latents (width ``d_z``) and ``h`` maps pooled target
    encodings (width ``d_y``) into a common ``d_out``-dimensional space.
    Both outputs are L2-normalized along the last axis, so a dot product
    between them is a cosine similarity.
    """

    def __init__(self, d_z: int, d_y: int, d_out: int = 256):
        super().__init__()
        # Attribute names and construction order (g first, then h) are kept
        # stable so state_dict keys and seeded initialization do not change.
        self.g = self._projector(d_z, d_out)
        self.h = self._projector(d_y, d_out)

    @staticmethod
    def _projector(d_in: int, d_out: int) -> nn.Sequential:
        """Build one Linear -> GELU -> Linear projection stack."""
        return nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.GELU(),
            nn.Linear(d_out, d_out),
        )

    def forward(self, z_pool: torch.Tensor, y_pool: torch.Tensor):
        """Project both inputs and return ``(z_emb, y_emb)``, each unit-norm."""
        projected_z = self.g(z_pool)
        projected_y = self.h(y_pool)
        return F.normalize(projected_z, dim=-1), F.normalize(projected_y, dim=-1)
def infonce_loss(z_emb: torch.Tensor, y_emb: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """Symmetric (CLIP-style) InfoNCE over a batch of paired embeddings.

    z_emb, y_emb: [B, d], expected to be L2-normalized. Row i of ``z_emb``
    is the positive for row i of ``y_emb``; every other row is a negative.
    The cross-entropy is taken in both directions (z -> y and y -> z) and
    averaged.

    If ``z_emb`` is identical across rows, the result cannot drop below
    -log(1/B) = log(B).
    """
    # [B, B] temperature-scaled similarity; entry (i, j) compares z_i with y_j.
    similarity = torch.matmul(z_emb, y_emb.t()) / tau
    positives = torch.arange(z_emb.size(0), device=z_emb.device)
    z_to_y, y_to_z = (
        F.cross_entropy(view, positives) for view in (similarity, similarity.t())
    )
    return 0.5 * (z_to_y + y_to_z)
def slot_decorrelation_loss(z: torch.Tensor) -> torch.Tensor:
    """Penalize pairwise alignment between latent slots.

    z: [B, K, d] (the per-slot input embeddings from forward_with_latent)

    Returns: scalar = mean squared off-diagonal entry of the per-sample
    cosine-similarity (row-normalized Gram) matrices, averaged over the
    batch. It is 0 when all K slots are pairwise orthogonal (cos-sim = 0)
    and approaches 1 when slots are aligned (redundant).

    Used as a soft regularizer when the capacity diagnostic shows slots are
    highly redundant (stable_rank << K). Forces the projector / loop to
    produce slots that span more directions.
    """
    # Upcast so the normalization and Gram product run in float32.
    unit = F.normalize(z.float(), dim=-1, eps=1e-6)          # [B, K, d]
    gram = torch.einsum("bkd,bjd->bkj", unit, unit)          # [B, K, K]
    K = z.size(1)
    # Mask the diagonal explicitly instead of subtracting the identity:
    # a slot whose norm is below eps normalizes to a vector of norm < 1, so
    # its diagonal entry is not 1 and `gram - I` would leave a residual that
    # gets counted as off-diagonal energy.
    diagonal = torch.eye(K, device=z.device, dtype=torch.bool)
    off_diag = gram.masked_fill(diagonal, 0.0)
    # Mean over off-diagonal entries only (more interpretable than mean over
    # all K²); max(..., 1) keeps K == 1 from dividing by zero.
    n_off = K * (K - 1)
    return (off_diag.pow(2).sum(dim=(-1, -2)) / max(n_off, 1)).mean()
def kl_to_gaussian(z: torch.Tensor) -> torch.Tensor:
    """Approximate KL(z || N(0, I)) treating z as a deterministic point.

    With deterministic z this reduces to 0.5 * (||z||² - d) per latent slot,
    which is the standard β-VAE regularizer when q is a delta. We use it as
    a soft magnitude prior so z doesn't grow unboundedly (no inherent norm
    in the residual stream).

    z: [B, K, d]. Returns a scalar: the per-slot value averaged over all
    leading dimensions. The value is negative whenever ||z||² < d (it is
    -d/2 at z = 0).

    Half-precision inputs (float16 / bfloat16) are upcast to float32 before
    squaring, because the sum of squares over d can overflow float16. The
    result is then float32. Other dtypes are left as they are.
    """
    # z: [B, K, d]
    if z.dtype in (torch.float16, torch.bfloat16):
        z = z.float()
    sq_norm = z.pow(2).sum(dim=-1)
    return 0.5 * (sq_norm - z.size(-1)).mean()
@torch.no_grad()
def encode_chunks_per_slot(model, tokenizer, chunks_per_problem, device, max_len: int = 32) -> torch.Tensor:
    """Encode a [B][K] list-of-lists of y-chunk strings via the frozen base.

    chunks_per_problem: List of length B, each a list of length K of chunk
        strings. Every inner list must have the same length K >= 1. Empty
        (falsy) chunks are replaced by the literal string "<pad>" before
        tokenization.
    model: object exposing ``.model(...)`` (directly, or via
        ``get_base_model()``) that returns ``last_hidden_state``. If it has
        ``disable_adapter()``, the forward pass runs inside that context
        (presumably a PEFT-wrapped model — confirm against the caller).
    tokenizer: callable following the Hugging Face tokenizer call
        convention used below.
    max_len: truncation length passed to the tokenizer.

    Returns: tensor [B, K, d] of mean-pooled last-layer hidden states
    (pooled over non-pad tokens, detached).

    Raises: ValueError if ``chunks_per_problem`` is empty, if K == 0, or if
    the inner lists do not all have the same length.

    For per-CoT-step InfoNCE: each slot k receives a *different* target —
    the encoding of chunk k. Forces slots to specialize rather than all
    learning the same global y representation (which our capacity diagnostic
    showed yields stable_rank=6.73 across K=16 slots).
    """
    import contextlib

    if not chunks_per_problem:
        raise ValueError("chunks_per_problem is empty")
    B = len(chunks_per_problem)
    K = len(chunks_per_problem[0])
    if K == 0:
        raise ValueError("chunks_per_problem[0] contains no chunks")
    # Reject ragged input up front: the final view(B, K, -1) would otherwise
    # either fail with an opaque shape error or reshape silently misaligned.
    for idx, cps in enumerate(chunks_per_problem):
        if len(cps) != K:
            raise ValueError(
                f"chunks_per_problem[{idx}] has {len(cps)} chunks, expected {K}"
            )

    # Flatten to B*K strings in (problem, slot) order so one batched forward
    # pass encodes everything.
    flat = [c if c else "<pad>" for cps in chunks_per_problem for c in cps]
    enc = tokenizer(flat, return_tensors="pt", padding=True, truncation=True,
                    max_length=max_len, add_special_tokens=False).to(device)
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    ctx = model.disable_adapter() if hasattr(model, "disable_adapter") else contextlib.nullcontext()
    with ctx:
        out = inner.model(
            input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
            use_cache=False, return_dict=True,
        )
    # Mean-pool over non-pad tokens; clamp guards against an all-pad row.
    mask = enc["attention_mask"].unsqueeze(-1).to(out.last_hidden_state.dtype)
    pooled = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
    return pooled.detach().view(B, K, -1)
def infonce_per_slot_loss(
    z: torch.Tensor,           # [B, K, d_z]
    y_chunks: torch.Tensor,    # [B, K, d_y]
    head: nn.Module,           # InfoNCEHead, used flattened
    tau: float = 0.07,
) -> dict:
    """Slot-level symmetric InfoNCE over every (problem, slot) pair.

    Each of the B*K latent slots is matched against all B*K chunk
    encodings: the positive for (problem b, slot k) is chunk k of problem b,
    and every other (b', k') pair is a negative. The loss is the
    two-direction cross-entropy of the [B*K, B*K] similarity matrix against
    the identity.

    A z that is constant across slots cannot satisfy this, and neither can a
    z that identifies the right problem but confuses the slot index, so the
    projection is pushed to specialize its slots.

    ``head`` is called as ``head(z_flat, y_flat)`` on float32 inputs of
    shape [B*K, ·] and must return two embedding tensors of shape
    [B*K, d_out].

    Returns a dict with keys:
      "loss"               — the symmetric InfoNCE scalar (carries gradient)
      "acc_z2y"            — top-1 accuracy, z row -> correct (problem, slot)
      "acc_y2z"            — top-1 accuracy in the opposite direction
      "acc_within_problem" — top-1 accuracy restricted to the K chunks of
                             the same problem (does slot k pick chunk k?)
    """
    n_problems, n_slots, _ = z.shape
    n_pairs = n_problems * n_slots

    z_emb, y_emb = head(
        z.reshape(n_pairs, -1).float(),
        y_chunks.reshape(n_pairs, -1).float(),
    )                                                   # [B*K, d_out] each
    scores = z_emb @ y_emb.t() / tau                    # [B*K, B*K]
    gold = torch.arange(n_pairs, device=z.device)
    loss = 0.5 * (F.cross_entropy(scores, gold) + F.cross_entropy(scores.t(), gold))

    with torch.no_grad():
        # Global probes: is the best match the exact (problem, slot) pair?
        hits_z2y = scores.argmax(dim=1) == gold
        hits_y2z = scores.t().argmax(dim=1) == gold

        # Local probe: restrict candidates to the same problem's K chunks, so
        # it measures slot *position* discrimination rather than problem
        # identification.
        z_by_problem = z_emb.view(n_problems, n_slots, -1)
        y_by_problem = y_emb.view(n_problems, n_slots, -1)
        local_sim = torch.matmul(z_by_problem, y_by_problem.transpose(-1, -2))  # [B, K, K]
        slot_ids = torch.arange(n_slots, device=z.device).unsqueeze(0)
        hits_within = local_sim.argmax(dim=-1) == slot_ids

    return {
        "loss": loss,
        "acc_z2y": hits_z2y.float().mean(),
        "acc_y2z": hits_y2z.float().mean(),
        "acc_within_problem": hits_within.float().mean(),
    }
@torch.no_grad()
def encode_answer_for_infonce(model, tokenizer, y_text: list, device, max_len: int = 64) -> torch.Tensor:
    """Mean-pooled last-layer encoding of gold answer strings.

    The strings in ``y_text`` are tokenized as one padded batch (truncated
    to ``max_len``, no special tokens) and run through ``.model(...)`` of
    the base model. When ``model`` exposes ``disable_adapter()`` the forward
    pass happens inside that context, so adapters (LoRA, per the original
    note) are off. Hidden states are averaged over non-pad tokens and
    returned detached, one row per input string.

    For GSM8K the caller typically passes only the final-number portion
    ("#### 42") so f(y) is anchored to the answer rather than to the full
    reasoning text.
    """
    import contextlib

    batch = tokenizer(
        y_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
        add_special_tokens=False,
    ).to(device)

    backbone = model.get_base_model() if hasattr(model, "get_base_model") else model
    # PEFT: switch adapters off for this pass when the wrapper supports it.
    adapters_off = (
        model.disable_adapter()
        if hasattr(model, "disable_adapter")
        else contextlib.nullcontext()
    )
    with adapters_off:
        result = backbone.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=False,
            return_dict=True,
        )

    hidden = result.last_hidden_state
    # Average over real tokens only; the clamp avoids 0/0 on an all-pad row.
    keep = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    token_count = keep.sum(dim=1).clamp_min(1.0)
    return ((hidden * keep).sum(dim=1) / token_count).detach()
def lm_loss_on_y(logits_y: torch.Tensor, y_ids: torch.Tensor, y_attn: torch.Tensor) -> torch.Tensor:
    """Masked mean next-token cross-entropy over the y segment.

    logits_y: [B, L_y, V]  (already aligned so logits[:, t] predicts y[:, t])
    y_ids:    [B, L_y]     target token ids
    y_attn:   [B, L_y]     1 for real tokens, 0 for padding

    Returns the per-token CE summed over real positions and divided by the
    number of real positions. The denominator is floored at 1, so an
    all-padding batch yields 0 rather than NaN.
    """
    batch, seq_len, vocab = logits_y.shape
    n_tokens = batch * seq_len
    token_nll = F.cross_entropy(
        logits_y.reshape(n_tokens, vocab),
        y_ids.reshape(n_tokens),
        reduction="none",
    ).reshape(batch, seq_len)
    # Zero out padding before reducing so it contributes to neither the
    # numerator nor the count.
    keep = y_attn.float()
    total = (token_nll * keep).sum()
    count = keep.sum().clamp_min(1.0)
    return total / count