feather-a10g-large-runtime / overlay /hydra /sampled_softmax.py
icarus112's picture
Upload folder using huggingface_hub
c383594 verified
"""Sampled softmax with importance-weighted (log-q) correction.
Audit 2026-05-09 issue #22 — Cluster E.
Replaces the uniform-negative sampling in the LM-head loss with negatives
drawn from the unigram (token-frequency) distribution. With uniform sampling
the correction term reduces to a constant `log(V/K)` and the negatives are
dominated by rare tokens that the model already places near-zero mass on,
so they carry almost no contrastive signal. Sampling from the empirical
unigram distribution puts the negatives where the model's softmax mass
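actually is — common tokens — and subtracting the per-id `log p_unigram[id]`
correction makes the resulting loss a consistent estimate of the full softmax
CE: it is biased for finite K, but the bias vanishes as K grows
(Jean et al. 2015 — *On Using Very Large Target Vocabulary for Neural
Machine Translation*).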
Math
----
L_full(x_t, y_t) = -log softmax_y_t(W x_t)
= -W[y_t]·x_t + logsumexp_v (W[v]·x_t)
We approximate `logsumexp_v` by a Monte-Carlo estimate using K negatives
drawn from a proposal distribution q. With log-q correction subtracted from
each candidate logit:
z_v_corrected = W[v]·x_t - log q(v)
then sampled softmax CE over candidates {y_t} ∪ Neg(K) approximates the full
softmax CE, and the approximation becomes exact as K → ∞ (it is not unbiased
for finite K). q = unigram is a standard, near-optimal proposal because it
concentrates samples in the high-mass region of the model's output
distribution.
Implementation
--------------
We use the **alias method** (Walker, 1977) to sample in O(1) per draw with
no log/exp. Tables (`prob` and `alias`) are built once at sampler
construction (a host-side Python loop in float64) and then stored on the
sampler's device. `sample(shape, device)` is a few element-wise device ops
(uniform draws + gathers) with no host sync.
For numerical stability the log-q correction uses
log_q[v] = log(freq[v] + eps_smooth) - log(freq.sum() + V * eps_smooth)
which floors out-of-vocabulary tokens (zero frequency in the cache) at a
small but non-zero probability — keeps training stable when the cache is
incomplete or shifts mid-training.
"""
from __future__ import annotations
import os
from pathlib import Path
import torch
import torch.nn.functional as F
# Root directory for on-disk caches when the caller passes no cache_dir:
# ~/.cache/autoresearch (unigram frequency tables live here).
_DEFAULT_CACHE_DIR = Path.home().joinpath(".cache", "autoresearch")
def _alias_setup(probs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Build alias-method (prob, alias) tables for a discrete distribution.

    Standard Vose construction. It does O(V) work and runs once per sampler.
    Drawing bucket ``i`` uniformly and keeping it with probability ``prob[i]``
    (otherwise returning ``alias[i]``) reproduces the normalised ``probs``.

    Args:
        probs: 1-D tensor of non-negative weights. It need not sum to 1 and
            is normalised internally.

    Returns:
        ``(prob, alias)``. ``prob`` is a float32 tensor of per-bucket
        acceptance probabilities in [0, 1]. ``alias`` is a long tensor of
        per-bucket fallback indices. Both have length V and live on
        ``probs.device``.

    Raises:
        ValueError: if ``probs`` is not 1-D, is empty, contains a negative
            entry, or sums to a non-positive or non-finite value.
    """
    if probs.dim() != 1:
        raise ValueError(f"alias_setup expects 1-D probs, got {probs.shape}")
    V = probs.shape[0]
    if V == 0:
        raise ValueError("alias_setup: empty probability vector")
    # The construction runs in float64 for numerical stability on large V.
    # For V=200k, float32 probs sum to 1.0 only within ~1e-7 relative error,
    # which leaves persistent under-/over-flowing buckets.
    p = probs.detach().to(torch.float64)
    # A negative weight would yield an acceptance probability outside [0, 1]
    # and a silently wrong sampler, so reject it up front.
    if (p < 0).any():
        raise ValueError("alias_setup: probs must be non-negative")
    s = p.sum()
    # NaN entries surface here as a non-finite sum.
    if not torch.isfinite(s) or s <= 0:
        raise ValueError(f"alias_setup: probs sum is non-positive or non-finite: {float(s)}")
    # Python loop is fine: it runs once per training session.
    scaled_cpu = (p / s * V).cpu().tolist()
    # Two work queues: small (scaled < 1) and large (scaled >= 1).
    small: list[int] = []
    large: list[int] = []
    for i, v in enumerate(scaled_cpu):
        (large if v >= 1.0 else small).append(i)
    prob_cpu = [0.0] * V
    alias_cpu = [0] * V
    while small and large:
        s_idx = small.pop()
        l_idx = large.pop()
        prob_cpu[s_idx] = scaled_cpu[s_idx]
        alias_cpu[s_idx] = l_idx
        # The large bucket donates (1 - scaled[s_idx]) of its mass to fill s_idx.
        scaled_cpu[l_idx] = (scaled_cpu[l_idx] + scaled_cpu[s_idx]) - 1.0
        if scaled_cpu[l_idx] < 1.0:
            small.append(l_idx)
        else:
            large.append(l_idx)
    # Drain. Leftover buckets are ~1.0 up to floating-point residue. Accepting
    # them unconditionally means their alias entry is never used.
    while large:
        prob_cpu[large.pop()] = 1.0
    while small:
        prob_cpu[small.pop()] = 1.0
    prob = torch.tensor(prob_cpu, dtype=torch.float32, device=probs.device)
    alias = torch.tensor(alias_cpu, dtype=torch.long, device=probs.device)
    return prob, alias
class UnigramSampler:
    """Alias-method sampler over a fixed token-frequency distribution.

    Args:
        freq: 1-D tensor of length V. Non-negative; need not sum to 1.
        eps_smooth: non-negative floor added to every token's frequency
            before normalisation. A positive value keeps ``log_q`` finite for
            tokens that have zero count in the cache.
        device: device for the alias tables. Defaults to ``freq.device``.

    Attributes:
        log_q: (V,) float32 tensor of smoothed log probabilities.
        V: vocabulary size.
        eps_smooth: the smoothing floor that was applied.

    Raises:
        ValueError: if ``freq`` is not 1-D or has negative entries, or if
            ``eps_smooth`` is negative or non-finite.
    """

    def __init__(
        self,
        freq: torch.Tensor,
        eps_smooth: float = 1e-6,
        device: torch.device | str | None = None,
    ) -> None:
        if freq.dim() != 1:
            raise ValueError(f"UnigramSampler: freq must be 1-D, got shape {freq.shape}")
        if (freq < 0).any():
            raise ValueError("UnigramSampler: freq must be non-negative")
        eps = float(eps_smooth)
        # `not (0 <= eps < inf)` also rejects NaN, whose comparisons are all False.
        # A negative floor could make smoothed frequencies negative.
        if not 0.0 <= eps < float("inf"):
            raise ValueError(
                f"UnigramSampler: eps_smooth must be finite and >= 0, got {eps_smooth}"
            )
        device = freq.device if device is None else torch.device(device)
        self.V = int(freq.shape[0])
        self.eps_smooth = eps
        # The same smoothed distribution drives both the alias tables and the
        # log-q correction, so sampling and correction agree.
        smoothed = freq.detach().to(device=device, dtype=torch.float64) + eps
        probs = smoothed / smoothed.sum()
        self._prob, self._alias = _alias_setup(probs)
        # log_q is kept in float32: it is subtracted from logits inside the CE.
        self.log_q = probs.log().to(torch.float32)

    # ------------------------------------------------------------------
    @property
    def device(self) -> torch.device:
        """Device on which the alias tables currently live."""
        return self._prob.device

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(V={self.V}, eps_smooth={self.eps_smooth}, "
            f"device={self._prob.device})"
        )

    # ------------------------------------------------------------------
    @torch.no_grad()
    def sample(self, shape: int | tuple[int, ...], device: torch.device | str | None = None) -> torch.Tensor:
        """Draw ``shape`` i.i.d. samples from the unigram distribution.

        Returns a LongTensor of indices in ``[0, V)``. Sampling issues only
        device ops (no host syncs). If ``device`` differs from where the
        tables live, the tables (and ``log_q``) are moved there permanently.
        """
        shape_t = (shape,) if isinstance(shape, int) else tuple(shape)
        n = 1
        for dim in shape_t:
            n *= dim
        if n == 0:
            return torch.empty(shape_t, dtype=torch.long, device=device or self._prob.device)
        target = self._prob.device if device is None else torch.device(device)
        if self._prob.device != target:
            # Rare: tables are usually built on the device they are used on.
            self.to(target)
        # Vose: pick a uniform bucket, keep it with probability prob[bucket],
        # otherwise jump to alias[bucket].
        bucket = torch.randint(0, self.V, (n,), device=target)
        accept = torch.rand(n, device=target) < self._prob[bucket]
        return torch.where(accept, bucket, self._alias[bucket]).view(shape_t)

    # ------------------------------------------------------------------
    def to(self, device: torch.device | str) -> "UnigramSampler":
        """Move the alias tables and ``log_q`` to ``device`` in place; returns self."""
        target = torch.device(device)
        self._prob = self._prob.to(target)
        self._alias = self._alias.to(target)
        self.log_q = self.log_q.to(target)
        return self

    # ------------------------------------------------------------------
    @classmethod
    def from_uniform(cls, V: int, device: torch.device | str = "cpu") -> "UnigramSampler":
        """Construct a uniform-distribution sampler. Useful for tests and
        as a debug fallback when no unigram cache is available."""
        return cls(torch.ones(V, dtype=torch.float32), device=device)
# ----------------------------------------------------------------------
# Frequency-cache build/load helpers
# ----------------------------------------------------------------------
def unigram_cache_path(vocab_size: int, cache_dir: Path | str | None = None) -> Path:
    """Return where the unigram-frequency cache for ``vocab_size`` lives.

    The file is ``unigram_freq_v{vocab_size}.pt`` inside ``cache_dir``, or
    inside the module's default cache root when ``cache_dir`` is None.
    """
    root = _DEFAULT_CACHE_DIR if cache_dir is None else Path(cache_dir)
    filename = f"unigram_freq_v{int(vocab_size)}.pt"
    return root.joinpath(filename)
def save_unigram_freq(freq: torch.Tensor, vocab_size: int, cache_dir: Path | str | None = None) -> Path:
    """Persist a unigram-frequency tensor to the canonical cache location.

    The tensor is stored as float32 on CPU. It is first written to a
    process-unique temporary file, which is then renamed over the final path
    with ``os.replace``. A crash or a concurrent writer therefore never
    leaves a truncated cache file behind.

    Args:
        freq: 1-D tensor of length ``vocab_size``.
        vocab_size: vocabulary size; also selects the cache filename.
        cache_dir: cache root; ``None`` selects the module default.

    Returns:
        The path the tensor was written to.

    Raises:
        ValueError: if ``freq`` is not 1-D with length ``vocab_size``.
    """
    if freq.dim() != 1 or freq.shape[0] != vocab_size:
        raise ValueError(
            f"save_unigram_freq: freq must be 1-D with length {vocab_size}, got {tuple(freq.shape)}"
        )
    path = unigram_cache_path(vocab_size, cache_dir)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(f"{path.name}.tmp{os.getpid()}")
    try:
        torch.save(freq.detach().to(torch.float32).cpu(), tmp_path)
        # os.replace is atomic on POSIX when source and target share a filesystem.
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists here if save or replace failed.
        tmp_path.unlink(missing_ok=True)
    return path
def load_unigram_freq(vocab_size: int, cache_dir: Path | str | None = None) -> torch.Tensor | None:
    """Load a cached unigram-frequency tensor, or return None if unavailable.

    Any of the following is treated as a cache miss (returns None), so the
    caller can rebuild:

    - the file is missing or unreadable;
    - the file does not hold a 1-D tensor of length ``vocab_size``;
    - the tensor contains negative or non-finite counts.

    Returns:
        A float32 CPU tensor of length ``vocab_size``, or None.
    """
    path = unigram_cache_path(vocab_size, cache_dir)
    if not path.exists():
        return None
    try:
        try:
            # weights_only=True refuses arbitrary pickled objects. The cache
            # only ever holds a plain tensor, so nothing legitimate is lost.
            freq = torch.load(path, map_location="cpu", weights_only=True)
        except TypeError:
            # torch versions without the `weights_only` keyword. NOTE: this
            # falls back to a full unpickle of the cache file.
            freq = torch.load(path, map_location="cpu")
    except Exception:
        # Deliberate best-effort: a corrupt or incompatible cache is a miss.
        return None
    if not isinstance(freq, torch.Tensor) or freq.dim() != 1 or freq.shape[0] != vocab_size:
        return None
    freq = freq.to(torch.float32)
    # Bad counts would make UnigramSampler raise later; rebuilding is better.
    if not bool(torch.isfinite(freq).all()) or bool((freq < 0).any()):
        return None
    return freq
def build_unigram_freq_from_tokenizer(
    tokenizer,
    vocab_size: int,
    target_tokens: int = 1_000_000,
    batch_size: int = 64,
) -> torch.Tensor:
    """Count token frequencies over a slice of the training data.

    Documents come from ``prepare._document_batches("train", ...)``. Each
    batch is encoded with the tokenizer's BOS id prepended to every row, and
    ids are histogrammed until at least ``target_tokens`` tokens have been
    counted (the last batch may overshoot) or the batches run out. The
    caller is responsible for persisting the result via
    ``save_unigram_freq``.

    Returns:
        float32 CPU tensor of length ``vocab_size`` holding raw counts.

    Raises:
        ValueError: if the tokenizer emits an id outside ``[0, vocab_size)``.
    """
    # Lazy import keeps the module importable on machines without the data
    # path provisioned (e.g. CI, unit tests).
    import prepare as _p

    freq = torch.zeros(vocab_size, dtype=torch.float64)
    seen = 0
    bos_id = tokenizer.get_bos_token_id()
    for batch, _epoch in _p._document_batches("train", tokenizer_batch_size=batch_size):
        encoded = tokenizer.encode(batch, prepend=bos_id)
        flat = [tok for row in encoded for tok in row]
        if not flat:
            continue
        ids = torch.tensor(flat, dtype=torch.long)
        # Out-of-range ids would otherwise fail inside bincount or the `+=`
        # below, with an opaque shape-mismatch error.
        lo, hi = int(ids.min()), int(ids.max())
        if lo < 0 or hi >= vocab_size:
            raise ValueError(
                f"build_unigram_freq_from_tokenizer: token ids span [{lo}, {hi}], "
                f"outside [0, {vocab_size})"
            )
        # minlength keeps the histogram aligned to V.
        freq += torch.bincount(ids, minlength=vocab_size).to(torch.float64)
        seen += ids.numel()
        if seen >= target_tokens:
            break
    return freq.to(torch.float32)
def get_or_build_unigram_sampler(
    tokenizer,
    vocab_size: int,
    device: torch.device | str = "cuda",
    cache_dir: Path | str | None = None,
    target_tokens: int = 1_000_000,
    rebuild: bool = False,
) -> UnigramSampler:
    """Cache-aware constructor for a ``UnigramSampler``.

    Loads ``unigram_freq_v{V}.pt`` if present; otherwise builds frequencies
    from a streamed slice of training data and tries to persist them.
    Persisting is best-effort: if the cache cannot be written (read-only or
    full filesystem), a RuntimeWarning is emitted and the freshly built
    frequencies are still used.

    Environment overrides:
        HYDRA_UNIGRAM_REBUILD=1 forces a rebuild even if the cache exists.
        HYDRA_UNIGRAM_TARGET_TOKENS overrides ``target_tokens``.
    """
    if os.environ.get("HYDRA_UNIGRAM_REBUILD", "0") == "1":
        rebuild = True
    env_target = os.environ.get("HYDRA_UNIGRAM_TARGET_TOKENS")
    if env_target is not None:
        target_tokens = int(env_target)
    freq = None if rebuild else load_unigram_freq(vocab_size, cache_dir)
    if freq is None:
        freq = build_unigram_freq_from_tokenizer(
            tokenizer, vocab_size, target_tokens=target_tokens
        )
        try:
            save_unigram_freq(freq, vocab_size, cache_dir)
        except OSError as err:
            # The cache is an optimisation; failing to write it should not
            # abort training when the frequencies are already in memory.
            import warnings

            warnings.warn(
                f"get_or_build_unigram_sampler: could not write unigram cache ({err}); "
                "continuing with the freshly built frequencies",
                RuntimeWarning,
                stacklevel=2,
            )
    return UnigramSampler(freq, device=device)
# ----------------------------------------------------------------------
# Loss
# ----------------------------------------------------------------------
def _apply_softcap(logits: torch.Tensor, softcap: float | None, clamp: bool) -> torch.Tensor:
    """Bound logits to [-softcap, softcap] (tanh or hard clamp); no-op if softcap is None or <= 0."""
    if softcap is None or softcap <= 0:
        return logits
    if clamp:
        return torch.clamp(logits, -softcap, softcap)
    return softcap * torch.tanh(logits / softcap)


def _label_smoothed_ce_col0(
    logits: torch.Tensor, excluded: torch.Tensor, label_smoothing: float
) -> torch.Tensor:
    """Per-row label-smoothed CE with target column 0, where the smoothing mass
    is spread only over columns not flagged in ``excluded``.

    Equals ``F.cross_entropy(..., label_smoothing=eps)`` when nothing is
    excluded. Excluded columns are expected to hold -inf logits.
    """
    logp = logits.log_softmax(dim=-1)
    nll = -logp[:, 0]
    kept = (~excluded).sum(dim=-1).to(logp.dtype)  # >= 1: column 0 is never excluded
    # Zero out the -inf entries before summing so they neither add to the loss
    # nor backpropagate NaNs.
    smooth = -logp.masked_fill(excluded, 0.0).sum(dim=-1) / kept
    return (1.0 - label_smoothing) * nll + label_smoothing * smooth


def sampled_softmax_loss(
    x_flat: torch.Tensor,
    y_flat: torch.Tensor,
    lm_head_weight: torch.Tensor,
    sampler: UnigramSampler,
    K: int,
    *,
    label_smoothing: float = 0.0,
    softcap: float | None = None,
    softcap_clamp: bool = False,
    valid_mask: torch.Tensor | None = None,
    reduction: str = "mean",
    shared_negatives: bool = True,
    remove_accidental_hits: bool = True,
) -> torch.Tensor:
    """Importance-sampled (unigram) sampled softmax cross-entropy.

    Args:
        x_flat: (N, d) hidden states.
        y_flat: (N,) target token ids. Negative entries are treated as
            invalid. If ``valid_mask`` is None, the mask defaults to
            ``y_flat >= 0``.
        lm_head_weight: (V, d) LM-head weight (typically ``model.lm_head.weight``).
        sampler: object exposing ``sample(shape, device=...)`` and a (V,)
            ``log_q`` tensor (a ``UnigramSampler``). ``log_q`` must end up on
            ``x_flat``'s device; ``UnigramSampler.sample`` moves it there.
        K: total candidates per row including the positive (K-1 negatives
            drawn). Must satisfy 1 <= K <= V.
        label_smoothing: label smoothing over the candidate columns. When
            accidental hits are removed, the smoothing mass is spread only
            over the surviving columns.
        softcap: if non-None and > 0, bound candidate logits to
            [-softcap, softcap] before the log-q correction.
        softcap_clamp: True → torch.clamp instead of tanh-softcap.
        valid_mask: (N,) mask (converted to bool). Invalid positions
            contribute exactly zero loss.
        reduction: 'mean' (over valid positions) | 'none'.
        shared_negatives: if True (default), draw a SINGLE batch of K-1
            negatives shared across all N rows: one (N, K-1) matmul and no
            (N, K, d) gather. If False, draw independent negatives per row,
            which is statistically slightly better but gathers (N, K, d).
        remove_accidental_hits: if True (default), mask negatives that equal
            the row's target. Such a column is an exact copy of the positive
            logit. Under unigram sampling, frequent targets collect many of
            these copies, which floor the loss at log(1 + #hits) and cancel
            the positive's gradient.

    Returns:
        Scalar (mean) or per-token (none) cross-entropy, float32.

    Raises:
        ValueError: on shape mismatches, K outside [1, V], or an unknown
            reduction.

    Backward: the gradient on ``lm_head_weight`` is nonzero only on the
    gathered rows (positives + drawn negatives). Note that ``F.embedding``
    (default ``sparse=False``) still materialises it as a dense V x d tensor.
    """
    if reduction not in ("mean", "none"):
        raise ValueError(f"sampled_softmax_loss: unknown reduction {reduction!r}")
    if x_flat.dim() != 2:
        raise ValueError(f"sampled_softmax_loss: x_flat must be 2-D, got {x_flat.shape}")
    if y_flat.shape != (x_flat.shape[0],):
        raise ValueError(
            f"sampled_softmax_loss: y_flat shape {tuple(y_flat.shape)} "
            f"does not match x_flat batch {x_flat.shape[0]}"
        )
    V, d = lm_head_weight.shape
    if x_flat.shape[1] != d:
        raise ValueError(
            f"sampled_softmax_loss: x_flat.shape[-1]={x_flat.shape[1]} != lm_head dim {d}"
        )
    if K <= 0 or K > V:
        raise ValueError(f"sampled_softmax_loss: K={K} out of range [1, V={V}]")
    n = x_flat.shape[0]
    device = x_flat.device
    if valid_mask is None:
        valid_mask = y_flat >= 0
    else:
        valid_mask = valid_mask.to(dtype=torch.bool)
    # Invalid targets are redirected to id 0 so the gathers stay in range;
    # their loss is zeroed at the end.
    y_safe = torch.where(valid_mask, y_flat, torch.zeros_like(y_flat))

    if shared_negatives:
        # Shared-batch path: (n, d) x (d, K-1) matmul + per-row positive dot.
        # It avoids the O(nKd) gather of the per-row path.
        neg_idx = sampler.sample((K - 1,), device=device)  # (K-1,)
        pos_w = F.embedding(y_safe, lm_head_weight)  # (n, d)
        pos_logit = (x_flat * pos_w).sum(-1)  # (n,)
        neg_w = F.embedding(neg_idx, lm_head_weight)  # (K-1, d)
        neg_logits = x_flat @ neg_w.t()  # (n, K-1)
        pos_logit = _apply_softcap(pos_logit, softcap, softcap_clamp)
        neg_logits = _apply_softcap(neg_logits, softcap, softcap_clamp)
        # log-q correction (broadcasts over rows for the shared negatives).
        pos_logit = pos_logit - sampler.log_q[y_safe]
        neg_logits = neg_logits - sampler.log_q[neg_idx]
        logits = torch.cat([pos_logit.unsqueeze(-1), neg_logits], dim=1).float()  # (n, K)
        hits = neg_idx.view(1, K - 1) == y_safe.view(n, 1) if remove_accidental_hits else None
    else:
        # Per-row independent negatives: (n, K-1) draws, (n, K, d) gather.
        neg = sampler.sample((n, K - 1), device=device)
        cand_idx = torch.cat([y_safe.view(n, 1), neg], dim=1)  # (n, K)
        cand_w = F.embedding(cand_idx, lm_head_weight)  # (n, K, d)
        logits = torch.einsum("nd,nkd->nk", x_flat, cand_w)  # (n, K)
        logits = _apply_softcap(logits, softcap, softcap_clamp)
        logits = (logits - sampler.log_q[cand_idx]).float()
        hits = neg == y_safe.view(n, 1) if remove_accidental_hits else None

    hit_mask = None
    if hits is not None:
        # Column 0 (the positive) is never masked, so every row keeps a finite logit.
        pos_col = torch.zeros(n, 1, dtype=torch.bool, device=device)
        hit_mask = torch.cat([pos_col, hits], dim=1)  # (n, K)
        logits = logits.masked_fill(hit_mask, float("-inf"))

    # CE with the positive at column 0.
    if hit_mask is not None and label_smoothing > 0:
        # F.cross_entropy would spread smoothing mass onto the -inf columns (→ inf loss).
        per_tok = _label_smoothed_ce_col0(logits, hit_mask, label_smoothing)
    else:
        ce_targets = torch.zeros(n, dtype=torch.long, device=device)
        per_tok = F.cross_entropy(
            logits, ce_targets, reduction="none", label_smoothing=label_smoothing
        )
    # torch.where (not a multiply) so a non-finite value on an invalid row
    # cannot leak NaN into the mean.
    per_tok = torch.where(valid_mask, per_tok, torch.zeros_like(per_tok))
    if reduction == "none":
        return per_tok
    denom = valid_mask.sum().clamp(min=1).to(per_tok.dtype)
    return per_tok.sum() / denom