| """Sampled softmax with importance-weighted (log-q) correction. |
| |
| Audit 2026-05-09 issue #22 — Cluster E. |
| |
Replaces the uniform-negative sampling in the LM-head loss with negatives
drawn from the unigram (token-frequency) distribution. With uniform sampling
the correction term reduces to a constant (`-log q = log V` here, or
`log(V/K)` when written in terms of expected counts) that cancels inside the
softmax. The negatives are then dominated by rare tokens on which the model
already places near-zero mass, so they carry almost no contrastive signal.
Sampling from the empirical unigram distribution puts the negatives where the
model's softmax mass actually is — common tokens. The per-id
`log p_unigram[id]` correction then makes the resulting loss a consistent
estimate of the full softmax CE. It is biased for finite K, with the bias
shrinking as K grows (Jean et al. 2015 — *On Using Very Large Target
Vocabulary for Neural Machine Translation*).
| |
| Math |
| ---- |
    L_full(x_t, y_t) = -log softmax_{y_t}(W x_t)
                     = -W[y_t]·x_t + logsumexp_v (W[v]·x_t)
| |
We approximate `logsumexp_v` by a Monte-Carlo estimate using K-1 negatives
(K candidates in total, counting the positive) drawn from a proposal
distribution q. The log-q correction is subtracted from each candidate logit:

    z_v_corrected = W[v]·x_t - log q(v)

Sampled softmax CE over the candidates {y_t} ∪ Neg(K-1) then approximates
the full softmax CE. The approximation is biased for finite K but
consistent: the bias vanishes as K grows. q = unigram is the standard choice
for sampled / NCE-style losses because it concentrates samples in the
high-mass region of the model's output distribution.
| |
| Implementation |
| -------------- |
We use the **alias method** (Walker, 1977; Vose's O(V) construction) to
sample in O(1) per draw with no log/exp. The tables (`prob` and `alias`) are
built once on the host at sampler construction and stored on the sampler's
device. `sample(shape, device)` is then a few elementwise device ops
(uniform draws, gathers and a select) with no host sync.
| |
| For numerical stability the log-q correction uses |
| log_q[v] = log(freq[v] + eps_smooth) - log(freq.sum() + V * eps_smooth) |
which floors tokens that never appeared in the frequency sample (zero count
in the cache) at a small but non-zero probability. This keeps training
stable when the cache is incomplete or the data distribution shifts
mid-training.
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| _DEFAULT_CACHE_DIR = Path.home() / ".cache" / "autoresearch" |
|
|
|
|
| def _alias_setup(probs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
| """Build alias-method (prob, alias) tables. |
| |
| Standard Vose construction. O(V) work, runs once per sampler. Returns |
| `prob` (per-bucket acceptance probability in [0, 1]) and `alias` |
| (per-bucket fallback index). Both length V, on the same device as |
| `probs`. |
| |
| Inputs are normalised internally; `probs` need only be non-negative. |
| """ |
| if probs.dim() != 1: |
| raise ValueError(f"alias_setup expects 1-D probs, got {probs.shape}") |
| V = probs.shape[0] |
| if V == 0: |
| raise ValueError("alias_setup: empty probability vector") |
|
|
| |
| |
| |
| p = probs.detach().to(torch.float64) |
| s = p.sum() |
| if not torch.isfinite(s) or s <= 0: |
| raise ValueError(f"alias_setup: probs sum is non-positive or non-finite: {float(s)}") |
| p = p / s |
| scaled = p * V |
|
|
| |
| small: list[int] = [] |
| large: list[int] = [] |
| scaled_cpu = scaled.cpu().tolist() |
| for i, v in enumerate(scaled_cpu): |
| (large if v >= 1.0 else small).append(i) |
|
|
| prob_cpu = [0.0] * V |
| alias_cpu = [0] * V |
| while small and large: |
| s_idx = small.pop() |
| l_idx = large.pop() |
| prob_cpu[s_idx] = scaled_cpu[s_idx] |
| alias_cpu[s_idx] = l_idx |
| scaled_cpu[l_idx] = (scaled_cpu[l_idx] + scaled_cpu[s_idx]) - 1.0 |
| if scaled_cpu[l_idx] < 1.0: |
| small.append(l_idx) |
| else: |
| large.append(l_idx) |
| |
| |
| while large: |
| prob_cpu[large.pop()] = 1.0 |
| while small: |
| prob_cpu[small.pop()] = 1.0 |
|
|
| prob = torch.tensor(prob_cpu, dtype=torch.float32, device=probs.device) |
| alias = torch.tensor(alias_cpu, dtype=torch.long, device=probs.device) |
| return prob, alias |
|
|
|
|
class UnigramSampler:
    """Device-resident alias sampler over a fixed token-frequency distribution.

    Args:
        freq: 1-D tensor of length V. Non-negative; need not sum to 1.
        eps_smooth: non-negative floor added to each token's frequency before
            normalisation. Keeps ``log_q`` finite for tokens with a zero count
            in the cache (only when ``eps_smooth > 0``).
        device: device for the alias tables and ``log_q``. Defaults to
            ``freq.device``.

    Attributes:
        log_q: (V,) float32 tensor of log probabilities with smoothing applied.
        V: vocabulary size.
        eps_smooth: the smoothing floor that was applied.

    Raises:
        ValueError: if ``freq`` is not 1-D or has negative entries, or if
            ``eps_smooth`` is negative or non-finite.
    """

    def __init__(
        self,
        freq: torch.Tensor,
        eps_smooth: float = 1e-6,
        device: torch.device | str | None = None,
    ) -> None:
        if freq.dim() != 1:
            raise ValueError(f"UnigramSampler: freq must be 1-D, got shape {freq.shape}")
        if (freq < 0).any():
            raise ValueError("UnigramSampler: freq must be non-negative")
        eps = float(eps_smooth)
        # Written as `not (0 <= eps < inf)` so NaN is rejected as well.
        if not 0.0 <= eps < float("inf"):
            raise ValueError(
                f"UnigramSampler: eps_smooth must be finite and >= 0, got {eps_smooth}"
            )
        target = freq.device if device is None else torch.device(device)

        self.V = int(freq.shape[0])
        self.eps_smooth = eps

        # Alias construction runs on the host anyway, so normalise there in
        # float64 and move the finished tables to the target device once.
        smoothed = freq.detach().cpu().to(torch.float64) + eps
        probs = smoothed / smoothed.sum()
        prob, alias = _alias_setup(probs)
        self._prob = prob.to(target)
        self._alias = alias.to(target)
        self.log_q = probs.log().to(device=target, dtype=torch.float32)

    @torch.no_grad()
    def sample(self, shape: int | tuple[int, ...], device: torch.device | str | None = None) -> torch.Tensor:
        """Draw ``shape`` i.i.d. samples from the smoothed unigram distribution.

        Args:
            shape: output shape; an int is treated as a 1-D length.
            device: device for the result. If it differs from where the tables
                live, all tables (including ``log_q``) are migrated there first
                and stay there.

        Returns:
            LongTensor of indices in ``[0, V)`` with the requested shape. The
            draw itself is a few elementwise device ops with no host sync.
        """
        shape_t = (shape,) if isinstance(shape, int) else tuple(shape)
        target = self._prob.device if device is None else torch.device(device)
        # Migrate BEFORE the empty-shape early return, so log_q always follows
        # the requested device. For example, a loss that draws zero negatives
        # still indexes log_q with ids on that device.
        if self._prob.device != target:
            self.to(target)

        n = torch.Size(shape_t).numel()
        if n == 0:
            return torch.empty(shape_t, dtype=torch.long, device=target)

        bucket = torch.randint(0, self.V, (n,), device=target)
        accept = torch.rand(n, device=target) < self._prob[bucket]
        return torch.where(accept, bucket, self._alias[bucket]).view(shape_t)

    def to(self, device: torch.device | str) -> "UnigramSampler":
        """Move the alias tables and ``log_q`` to ``device`` in place; returns self."""
        target = torch.device(device)
        self._prob = self._prob.to(target)
        self._alias = self._alias.to(target)
        self.log_q = self.log_q.to(target)
        return self

    @classmethod
    def from_uniform(cls, V: int, device: torch.device | str = "cpu") -> "UnigramSampler":
        """Construct a uniform-distribution sampler. Useful for tests and
        as a debug fallback when no unigram cache is available."""
        return cls(torch.ones(V, dtype=torch.float32), device=device)
|
|
|
|
| |
| |
| |
|
|
|
|
def unigram_cache_path(vocab_size: int, cache_dir: Path | str | None = None) -> Path:
    """Return where the unigram-frequency table for ``vocab_size`` is cached.

    The file is named ``unigram_freq_v{vocab_size}.pt``. It is placed under
    ``cache_dir`` when given, else under the module default cache root.
    """
    root = _DEFAULT_CACHE_DIR if cache_dir is None else Path(cache_dir)
    filename = f"unigram_freq_v{int(vocab_size)}.pt"
    return root.joinpath(filename)
|
|
|
|
def save_unigram_freq(freq: torch.Tensor, vocab_size: int, cache_dir: Path | str | None = None) -> Path:
    """Persist a unigram-frequency tensor to the canonical cache location.

    The tensor is stored as float32 on CPU. The write is atomic:
    - data goes to a per-process temporary file next to the target;
    - that file is then renamed over the target with ``os.replace``.
    A crash mid-write, or several processes writing the cache at once,
    therefore never leaves a truncated file at the cache path.

    Returns:
        The path the tensor was written to.

    Raises:
        ValueError: if ``freq`` is not 1-D with length ``vocab_size``.
    """
    if freq.dim() != 1 or freq.shape[0] != vocab_size:
        raise ValueError(
            f"save_unigram_freq: freq must be 1-D with length {vocab_size}, got {tuple(freq.shape)}"
        )
    path = unigram_cache_path(vocab_size, cache_dir)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the target, so os.replace is a same-filesystem rename.
    tmp = path.with_name(f"{path.name}.{os.getpid()}.tmp")
    try:
        torch.save(freq.detach().to(torch.float32).cpu(), tmp)
        os.replace(tmp, path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error wins.
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass
        raise
    return path
|
|
|
|
def load_unigram_freq(vocab_size: int, cache_dir: Path | str | None = None) -> torch.Tensor | None:
    """Load a cached unigram-frequency tensor, or return None if unavailable.

    Any problem with the cache is treated as a miss (``None``) so the caller
    can rebuild. Problems include:
    - a missing, unreadable or corrupt file;
    - a non-tensor payload;
    - the wrong rank or a length other than ``vocab_size``;
    - counts the sampler cannot use (negative, NaN or infinite).

    Returns:
        A 1-D float32 CPU tensor of length ``vocab_size``, or ``None``.
    """
    path = unigram_cache_path(vocab_size, cache_dir)
    if not path.exists():
        return None
    try:
        # NOTE(review): torch.load unpickles the file. The cache is normally
        # written by save_unigram_freq, but if the cache directory can be
        # tampered with, consider weights_only=True (recent torch versions).
        freq = torch.load(path, map_location="cpu")
    except Exception:
        # Deliberate best-effort: an unreadable file is just a cache miss.
        return None
    if not isinstance(freq, torch.Tensor) or freq.dim() != 1 or freq.shape[0] != vocab_size:
        return None
    freq = freq.to(torch.float32)
    # Reject contents that UnigramSampler would refuse (negative) or turn into
    # NaN log-probabilities (non-finite); rebuilding is the safer recovery.
    if not bool(torch.isfinite(freq).all()) or bool((freq < 0).any()):
        return None
    return freq
|
|
|
|
def build_unigram_freq_from_tokenizer(
    tokenizer,
    vocab_size: int,
    target_tokens: int = 1_000_000,
    batch_size: int = 64,
) -> torch.Tensor:
    """Stream a slice of the training data through the tokenizer and count ids.

    Documents come from ``prepare._document_batches("train", ...)``. They are
    encoded with ``prepend=tokenizer.get_bos_token_id()``. Counting stops
    after the first batch that brings the running total to at least
    ``target_tokens``. Used as a fallback when no cached frequencies exist.
    The caller is responsible for persisting the result via
    ``save_unigram_freq``.

    Returns:
        1-D float32 tensor of raw counts, length ``vocab_size``.

    Raises:
        ValueError: if the tokenizer emits an id outside ``[0, vocab_size)``.
            Without this check ``bincount`` would produce a longer tensor and
            the accumulation would fail with an opaque shape error.
    """
    import itertools

    import prepare as _p

    freq = torch.zeros(vocab_size, dtype=torch.float64)
    bos = tokenizer.get_bos_token_id()
    seen = 0
    for batch, _epoch in _p._document_batches("train", tokenizer_batch_size=batch_size):
        encoded = tokenizer.encode(batch, prepend=bos)
        ids = torch.tensor(list(itertools.chain.from_iterable(encoded)), dtype=torch.long)
        if ids.numel() == 0:
            continue
        lo, hi = int(ids.min()), int(ids.max())
        if lo < 0 or hi >= vocab_size:
            raise ValueError(
                f"build_unigram_freq_from_tokenizer: tokenizer produced ids in "
                f"[{lo}, {hi}], outside [0, vocab_size={vocab_size})"
            )
        # Accumulate in float64 so large counts stay exact before the final cast.
        freq += torch.bincount(ids, minlength=vocab_size).to(torch.float64)
        seen += ids.numel()
        if seen >= target_tokens:
            break
    return freq.to(torch.float32)
|
|
|
|
def get_or_build_unigram_sampler(
    tokenizer,
    vocab_size: int,
    device: torch.device | str = "cuda",
    cache_dir: Path | str | None = None,
    target_tokens: int = 1_000_000,
    rebuild: bool = False,
) -> UnigramSampler:
    """Return a ``UnigramSampler``, preferring the on-disk frequency cache.

    Loads ``unigram_freq_v{V}.pt`` when ``load_unigram_freq`` accepts it.
    Otherwise it counts a streamed slice of the training data, writes the
    cache, and builds the sampler from the fresh counts.

    Environment overrides:
        HYDRA_UNIGRAM_REBUILD=1        ignore any existing cache and rebuild.
        HYDRA_UNIGRAM_TARGET_TOKENS    integer that replaces ``target_tokens``.
    """
    env = os.environ
    rebuild = rebuild or env.get("HYDRA_UNIGRAM_REBUILD", "0") == "1"
    target_override = env.get("HYDRA_UNIGRAM_TARGET_TOKENS")
    if target_override is not None:
        target_tokens = int(target_override)

    cached = None if rebuild else load_unigram_freq(vocab_size, cache_dir)
    if cached is not None:
        return UnigramSampler(cached, device=device)

    fresh = build_unigram_freq_from_tokenizer(
        tokenizer, vocab_size, target_tokens=target_tokens
    )
    save_unigram_freq(fresh, vocab_size, cache_dir)
    return UnigramSampler(fresh, device=device)
|
|
|
|
| |
| |
| |
|
|
|
|
def _softcap_logits(logits: torch.Tensor, softcap: float | None, clamp: bool) -> torch.Tensor:
    """Apply the LM head's logit cap (hard clamp or tanh softcap); no-op when disabled."""
    if softcap is None or softcap <= 0:
        return logits
    if clamp:
        return torch.clamp(logits, -softcap, softcap)
    return softcap * torch.tanh(logits / softcap)


def _sampled_candidate_ce(
    logits: torch.Tensor,
    hit_mask: torch.Tensor | None,
    label_smoothing: float,
) -> torch.Tensor:
    """Per-row cross-entropy over (n, K) candidate logits whose target is column 0.

    Columns flagged in ``hit_mask`` are removed from the softmax and from the
    label-smoothing average. The result equals ``F.cross_entropy`` applied to
    only the unmasked candidates of each row.
    """
    if hit_mask is None:
        targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, targets, reduction="none", label_smoothing=label_smoothing)
    log_p = F.log_softmax(logits.masked_fill(hit_mask, float("-inf")), dim=1)
    nll = -log_p[:, 0]
    if label_smoothing <= 0:
        return nll
    # F.cross_entropy's smoothing term is mean_c(-log p_c) over all C classes;
    # here C is the number of live (unmasked) candidates in each row. Column 0
    # is never masked, so the count is >= 1.
    n_live = (~hit_mask).sum(dim=1).to(log_p.dtype)
    smooth = -log_p.masked_fill(hit_mask, 0.0).sum(dim=1) / n_live
    return (1.0 - label_smoothing) * nll + label_smoothing * smooth


def sampled_softmax_loss(
    x_flat: torch.Tensor,
    y_flat: torch.Tensor,
    lm_head_weight: torch.Tensor,
    sampler: UnigramSampler,
    K: int,
    *,
    label_smoothing: float = 0.0,
    softcap: float | None = None,
    softcap_clamp: bool = False,
    valid_mask: torch.Tensor | None = None,
    reduction: str = "mean",
    shared_negatives: bool = True,
    remove_accidental_hits: bool = False,
) -> torch.Tensor:
    """Importance-sampled (unigram) sampled softmax cross-entropy.

    Each row's candidate set is its target plus K-1 negatives drawn from
    ``sampler``. Every candidate logit is optionally softcapped, then has
    ``log q(id)`` subtracted. The loss is a K-way softmax cross-entropy with
    the target at column 0.

    Args:
        x_flat:          (N, d) hidden states.
        y_flat:          (N,) target token ids. Negative entries are always
                         treated as invalid, whether or not ``valid_mask`` is
                         given.
        lm_head_weight:  (V, d) LM-head weight (typically `model.lm_head.weight`).
        sampler:         a ``UnigramSampler`` (anything exposing
                         ``sample(shape, device=...)`` and a (V,) ``log_q``).
                         ``log_q`` is moved to ``x_flat.device`` as needed.
        K:               total candidates per row including the positive
                         (K-1 negatives drawn). K == 1 yields zero loss.
        label_smoothing: label smoothing over the K candidates (not the full
                         vocabulary), with ``F.cross_entropy`` semantics.
        softcap:         if non-None and > 0, cap candidate logits before the
                         log-q correction.
        softcap_clamp:   True → torch.clamp instead of tanh-softcap.
        valid_mask:      (N,) mask, cast to bool. Invalid positions contribute
                         zero loss and are excluded from the mean.
        reduction:       'mean' | 'none'.
        shared_negatives: if True (default), draw a SINGLE batch of K-1
                         negatives shared across all N rows. If False, draw
                         independent negatives per row. Shared is faster
                         (single (n, K) matmul, no (n, K, d) gather) and is
                         what Jean et al. 2015 use. Per-row is statistically
                         slightly better but expensive at typical d/K.
        remove_accidental_hits: if True, a negative equal to its row's own
                         target is dropped from that row's softmax. This is
                         the standard sampled-softmax treatment. It is off by
                         default to keep existing loss values. With unigram
                         proposals, frequent targets are often drawn as
                         negatives; each copy shares the target's corrected
                         logit, putting a floor of about
                         ``log(1 + #copies)`` under that row's loss.

    Returns:
        Scalar (mean over valid positions) or per-token (none) cross-entropy.

    Raises:
        ValueError: on shape mismatches, K outside [1, V], or an unknown
            ``reduction``.

    Backward: gradient on `lm_head_weight` flows only through the gathered
    rows (positives + drawn negatives). No full V x d gradient.
    """
    if x_flat.dim() != 2:
        raise ValueError(f"sampled_softmax_loss: x_flat must be 2-D, got {x_flat.shape}")
    if y_flat.shape != (x_flat.shape[0],):
        raise ValueError(
            f"sampled_softmax_loss: y_flat shape {tuple(y_flat.shape)} "
            f"does not match x_flat batch {x_flat.shape[0]}"
        )
    V, d = lm_head_weight.shape
    if x_flat.shape[1] != d:
        raise ValueError(
            f"sampled_softmax_loss: x_flat.shape[-1]={x_flat.shape[1]} != lm_head dim {d}"
        )
    if K <= 0 or K > V:
        raise ValueError(f"sampled_softmax_loss: K={K} out of range [1, V={V}]")
    if reduction not in ("mean", "none"):
        raise ValueError(f"sampled_softmax_loss: unknown reduction {reduction!r}")

    n = x_flat.shape[0]
    device = x_flat.device
    K_neg = K - 1

    # Negative ids are invalid even under an explicit mask; otherwise they
    # would reach F.embedding and raise.
    y_nonneg = y_flat >= 0
    valid_mask = y_nonneg if valid_mask is None else (valid_mask.bool() & y_nonneg)
    y_safe = torch.where(valid_mask, y_flat, torch.zeros_like(y_flat))

    if shared_negatives:
        neg_idx = sampler.sample((K_neg,), device=device)                   # (K-1,)
        # Fetch log_q explicitly on the compute device. Do not rely on sample()
        # having migrated it (e.g. when K == 1 nothing is drawn).
        log_q = sampler.log_q.to(device)
        pos_logit = (x_flat * F.embedding(y_safe, lm_head_weight)).sum(-1)   # (n,)
        neg_logits = x_flat @ F.embedding(neg_idx, lm_head_weight).t()       # (n, K-1)
        pos_logit = _softcap_logits(pos_logit, softcap, softcap_clamp)
        neg_logits = _softcap_logits(neg_logits, softcap, softcap_clamp)
        # The log-q correction applies to the final (post-softcap) logits.
        logits = torch.cat(
            [(pos_logit - log_q[y_safe]).unsqueeze(-1), neg_logits - log_q[neg_idx]],
            dim=1,
        ).float()
        neg_ids = neg_idx.view(1, K_neg)  # broadcasts against (n, 1) below
    else:
        neg_ids = sampler.sample((n, K_neg), device=device)                 # (n, K-1)
        log_q = sampler.log_q.to(device)
        cand_idx = torch.cat([y_safe.view(n, 1), neg_ids], dim=1)           # (n, K)
        logits = torch.einsum("nd,nkd->nk", x_flat, F.embedding(cand_idx, lm_head_weight))
        logits = _softcap_logits(logits, softcap, softcap_clamp)
        logits = (logits - log_q[cand_idx]).float()

    hit_mask = None
    if remove_accidental_hits and K_neg > 0:
        hits = neg_ids == y_safe.view(n, 1)                                  # (n, K-1)
        # Column 0 (the positive) is never masked.
        hit_mask = torch.cat(
            [torch.zeros(n, 1, dtype=torch.bool, device=hits.device), hits], dim=1
        )

    per_tok = _sampled_candidate_ce(logits, hit_mask, label_smoothing)

    valid_f = valid_mask.to(per_tok.dtype)
    per_tok = per_tok * valid_f
    if reduction == "none":
        return per_tok
    return per_tok.sum() / valid_f.sum().clamp(min=1)
|
|