"""Sampled softmax with importance-weighted (log-q) correction.

Audit 2026-05-09 issue #22 - Cluster E.

Replaces the uniform-negative sampling in the LM-head loss with negatives
drawn from the unigram (token-frequency) distribution. With uniform sampling
the correction term reduces to a constant `log(V/K)` and the negatives are
dominated by rare tokens that the model already places near-zero mass on,
so they carry almost no contrastive signal. Sampling from the empirical
unigram distribution puts the negatives where the model's softmax mass
actually is (common tokens), and the per-id `log p_unigram[id]` correction
makes the resulting loss a consistent estimator of the full softmax CE: the
importance-weighted estimate of the partition sum is unbiased, and the bias
introduced by taking its log shrinks as K grows
(Jean et al. 2015, *On Using Very Large Target Vocabulary for Neural
Machine Translation*).

Math
----
    L_full(x_t, y_t)  =  -log softmax_y_t(W x_t)
                      =  -W[y_t]·x_t  +  logsumexp_v (W[v]·x_t)

We approximate `logsumexp_v` by a Monte-Carlo estimate using K negatives
drawn from a proposal distribution q. With log-q correction subtracted from
each candidate logit:

    z_v_corrected = W[v]·x_t  -  log q(v)

the sampled softmax CE over candidates {y_t} ∪ Neg(K) approximates the full
softmax CE, with a bias that shrinks as K grows. q = unigram is the
standard choice (it is also the usual proposal for NCE-style losses): the
ideal proposal is the model's own softmax, and the unigram distribution is
a cheap, fixed proxy that concentrates samples in the high-mass region of
the model's output distribution.
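
As a small check of the correction term, the sketch below computes the exact
expectation of `exp(z_v - log q(v))` under `v ~ q` for a toy 4-token
vocabulary and compares it with the full partition sum `sum_v exp(z_v)`. The
logits, the proposal `q`, and the helper name `importance_weighted_partition`
are illustrative assumptions, not part of this module.

```python
import math


def importance_weighted_partition(logits, q):
    # Exact expectation over v ~ q of exp(z_v - log q(v)); the q(v) weights
    # cancel, so this equals sum_v exp(z_v) for any q with full support.
    return sum(qv * math.exp(z - math.log(qv)) for z, qv in zip(logits, q))


logits = [2.0, 0.5, -1.0, -3.0]   # hypothetical W[v]·x_t values
q = [0.6, 0.25, 0.1, 0.05]        # hypothetical unigram proposal

full_partition = sum(math.exp(z) for z in logits)
expected_estimate = importance_weighted_partition(logits, q)
print(full_partition, expected_estimate)
```

The two numbers agree up to rounding. Only the partition estimate is
unbiased; its logarithm, which is what enters the loss, is not.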

Implementation
--------------
We use the **alias method** (Walker, 1977) to sample in O(1) per draw with
no log/exp. Tables (`prob` and `alias`) are built once at sampler
construction (a host-side Python loop) and then stored on the target device;
`sample(shape, device)` is a few device-side ops (uniform draws, gathers,
and a `torch.where`) with no host sync.
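
The alias construction and draw described above can be sketched in pure
Python. This is a minimal illustration, assuming an already-normalised list
of probabilities; `build_alias_tables` and `draw` are hypothetical names, and
the module's own `_alias_setup` works on tensors instead.

```python
import random


def build_alias_tables(probs):
    # Vose-style construction: pair each under-full bucket with an
    # over-full one until every bucket holds exactly 1/n of the mass.
    n = len(probs)
    scaled = [p * n for p in probs]
    small = [i for i, s in enumerate(scaled) if s < 1.0]
    large = [i for i, s in enumerate(scaled) if s >= 1.0]
    accept = [1.0] * n
    alias = list(range(n))
    while small and large:
        s_idx = small.pop()
        l_idx = large.pop()
        accept[s_idx] = scaled[s_idx]
        alias[s_idx] = l_idx
        scaled[l_idx] = (scaled[l_idx] + scaled[s_idx]) - 1.0
        (small if scaled[l_idx] < 1.0 else large).append(l_idx)
    return accept, alias


def draw(accept, alias, rng):
    # O(1) per draw: one uniform bucket, one uniform accept test.
    bucket = rng.randrange(len(accept))
    return bucket if rng.random() < accept[bucket] else alias[bucket]


probs = [0.5, 0.3, 0.15, 0.05]
accept, alias = build_alias_tables(probs)

# Mass the tables assign to each index; it should reproduce `probs`.
implied = [0.0] * len(probs)
for b in range(len(probs)):
    implied[b] += accept[b] / len(probs)
    implied[alias[b]] += (1.0 - accept[b]) / len(probs)

rng = random.Random(0)
samples = [draw(accept, alias, rng) for _ in range(1000)]
```

Summing the accept and alias mass per index, as `implied` does, is a cheap
way to verify the tables without relying on sampling statistics.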

For numerical stability the log-q correction uses
    log_q[v] = log(freq[v] + eps_smooth) - log(freq.sum() + V * eps_smooth)
which floors out-of-vocabulary tokens (zero frequency in the cache) at a
small but non-zero probability. This keeps training stable when the cache is
incomplete or shifts mid-training.
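
A minimal sketch of the smoothing formula above, in pure Python. The counts
are made up for illustration and `smoothed_log_q` is a hypothetical helper;
the sampler itself computes the same quantity on tensors.

```python
import math


def smoothed_log_q(freq, eps_smooth=1e-6):
    # log_q[v] = log(freq[v] + eps) - log(sum(freq) + V * eps)
    log_total = math.log(sum(freq) + len(freq) * eps_smooth)
    return [math.log(f + eps_smooth) - log_total for f in freq]


freq = [900.0, 99.0, 1.0, 0.0]   # last token never appears in the cache
log_q = smoothed_log_q(freq)
total_mass = sum(math.exp(lq) for lq in log_q)
print(log_q, total_mass)
```

The zero-count token gets a finite (very negative) `log_q`, and the smoothed
probabilities still sum to 1.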
"""

from __future__ import annotations

import os
from pathlib import Path

import torch
import torch.nn.functional as F


_DEFAULT_CACHE_DIR = Path.home() / ".cache" / "autoresearch"


def _alias_setup(probs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Build alias-method (prob, alias) tables.

    Standard Vose construction. O(V) work, runs once per sampler. Returns
    `prob` (per-bucket acceptance probability in [0, 1]) and `alias`
    (per-bucket fallback index). Both length V, on the same device as
    `probs`.

    Inputs are normalised internally; `probs` need only be non-negative.
    """
    if probs.dim() != 1:
        raise ValueError(f"alias_setup expects 1-D probs, got {probs.shape}")
    V = probs.shape[0]
    if V == 0:
        raise ValueError("alias_setup: empty probability vector")

    # Vose construction operates in float64 for numerical stability on
    # large V (V=200k probs sum to 1.0 with relative error ~1e-7 in
    # float32, which produces persistent under-/over-flowing buckets).
    p = probs.detach().to(torch.float64)
    s = p.sum()
    if not torch.isfinite(s) or s <= 0:
        raise ValueError(f"alias_setup: probs sum is non-positive or non-finite: {float(s)}")
    p = p / s
    scaled = p * V

    # Two work queues: small (scaled < 1) and large (scaled >= 1).
    small: list[int] = []
    large: list[int] = []
    scaled_cpu = scaled.cpu().tolist()  # Python loop is fine: runs once per training session.
    for i, v in enumerate(scaled_cpu):
        (large if v >= 1.0 else small).append(i)

    prob_cpu = [0.0] * V
    alias_cpu = [0] * V
    while small and large:
        s_idx = small.pop()
        l_idx = large.pop()
        prob_cpu[s_idx] = scaled_cpu[s_idx]
        alias_cpu[s_idx] = l_idx
        scaled_cpu[l_idx] = (scaled_cpu[l_idx] + scaled_cpu[s_idx]) - 1.0
        if scaled_cpu[l_idx] < 1.0:
            small.append(l_idx)
        else:
            large.append(l_idx)
    # Drain. Remaining entries should be near-1.0 by construction; floating
    # point leaves negligible residue. Set acceptance to 1.0 so the alias is
    # never used.
    while large:
        prob_cpu[large.pop()] = 1.0
    while small:
        prob_cpu[small.pop()] = 1.0

    prob = torch.tensor(prob_cpu, dtype=torch.float32, device=probs.device)
    alias = torch.tensor(alias_cpu, dtype=torch.long, device=probs.device)
    return prob, alias


class UnigramSampler:
    """GPU-resident alias sampler over a fixed token-frequency distribution.

    Args:
        freq:          1-D tensor of length V. Non-negative; need not sum to 1.
        eps_smooth:    floor added to each token's frequency before
                       normalisation. Keeps log_q finite for OOV tokens
                       (zero-count in the cache).
        device:        device for the alias tables. Defaults to ``freq.device``.

    Attributes:
        log_q:         (V,) tensor of log probabilities with smoothing applied.
        V:             vocabulary size.
    """

    def __init__(
        self,
        freq: torch.Tensor,
        eps_smooth: float = 1e-6,
        device: torch.device | str | None = None,
    ) -> None:
        if freq.dim() != 1:
            raise ValueError(f"UnigramSampler: freq must be 1-D, got shape {freq.shape}")
        if (freq < 0).any():
            raise ValueError("UnigramSampler: freq must be non-negative")
        if device is None:
            device = freq.device
        else:
            device = torch.device(device)

        V = int(freq.shape[0])
        self.V = V
        self.eps_smooth = float(eps_smooth)

        # Smoothed probabilities for both alias-build and log_q correction.
        smoothed = freq.detach().to(device=device, dtype=torch.float64) + self.eps_smooth
        total = smoothed.sum()
        probs = smoothed / total

        self._prob, self._alias = _alias_setup(probs)
        # log_q is stored as float32 (a plain attribute, not a registered
        # buffer): autocast-friendly, used in CE.
        self.log_q = probs.log().to(torch.float32)

    # ------------------------------------------------------------------
    @torch.no_grad()
    def sample(self, shape: int | tuple[int, ...], device: torch.device | str | None = None) -> torch.Tensor:
        """Draw `shape` samples from the unigram distribution.

        Returns a LongTensor of indices in `[0, V)`. All operations are
        GPU-resident; no host syncs.
        """
        if isinstance(shape, int):
            shape_t = (shape,)
        else:
            shape_t = tuple(shape)
        n = 1
        for d in shape_t:
            n *= d
        if n == 0:
            return torch.empty(shape_t, dtype=torch.long, device=device or self._prob.device)

        target_device = torch.device(device) if device is not None else self._prob.device
        # Move alias tables on demand (rare; usually constructed on CUDA).
        if self._prob.device != target_device:
            self._prob = self._prob.to(target_device)
            self._alias = self._alias.to(target_device)
            self.log_q = self.log_q.to(target_device)

        # Vose: pick a uniform bucket, then with probability prob[bucket] keep
        # it, else jump to alias[bucket].
        u_bucket = torch.randint(0, self.V, (n,), device=target_device)
        u_accept = torch.rand(n, device=target_device)
        keep = u_accept < self._prob[u_bucket]
        out = torch.where(keep, u_bucket, self._alias[u_bucket])
        return out.view(shape_t)

    # ------------------------------------------------------------------
    def to(self, device: torch.device | str) -> "UnigramSampler":
        target = torch.device(device)
        self._prob = self._prob.to(target)
        self._alias = self._alias.to(target)
        self.log_q = self.log_q.to(target)
        return self

    # ------------------------------------------------------------------
    @classmethod
    def from_uniform(cls, V: int, device: torch.device | str = "cpu") -> "UnigramSampler":
        """Construct a uniform-distribution sampler. Useful for tests and
        as a debug fallback when no unigram cache is available."""
        return cls(torch.ones(V, dtype=torch.float32), device=device)


# ----------------------------------------------------------------------
# Frequency-cache build/load helpers
# ----------------------------------------------------------------------


def unigram_cache_path(vocab_size: int, cache_dir: Path | str | None = None) -> Path:
    """Canonical path for the unigram-frequency cache file."""
    base = Path(cache_dir) if cache_dir is not None else _DEFAULT_CACHE_DIR
    return base / f"unigram_freq_v{int(vocab_size)}.pt"


def save_unigram_freq(freq: torch.Tensor, vocab_size: int, cache_dir: Path | str | None = None) -> Path:
    """Persist a unigram-frequency tensor to the canonical cache location."""
    if freq.dim() != 1 or freq.shape[0] != vocab_size:
        raise ValueError(
            f"save_unigram_freq: freq must be 1-D with length {vocab_size}, got {tuple(freq.shape)}"
        )
    path = unigram_cache_path(vocab_size, cache_dir)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(freq.detach().to(torch.float32).cpu(), path)
    return path


def load_unigram_freq(vocab_size: int, cache_dir: Path | str | None = None) -> torch.Tensor | None:
    """Load a cached unigram-frequency tensor, or return None if unavailable.

    Validates the loaded tensor's length matches `vocab_size`; mismatches
    return None (treated as cache miss) so the caller can rebuild.
    """
    path = unigram_cache_path(vocab_size, cache_dir)
    if not path.exists():
        return None
    try:
        freq = torch.load(path, map_location="cpu")
    except Exception:
        return None
    if not isinstance(freq, torch.Tensor) or freq.dim() != 1 or freq.shape[0] != vocab_size:
        return None
    return freq.to(torch.float32)


def build_unigram_freq_from_tokenizer(
    tokenizer,
    vocab_size: int,
    target_tokens: int = 1_000_000,
    batch_size: int = 64,
) -> torch.Tensor:
    """Stream a small slice of the training data through the tokenizer and
    return per-token frequencies (length V). Caller is responsible for
    persisting via ``save_unigram_freq``.

    Used as a fallback when no cached frequencies exist:
    ``get_or_build_unigram_sampler`` runs it once on first training start
    and writes the result to the cache.
    """
    # Lazy import: keeps the module importable on machines without the data
    # path provisioned (e.g., CI, unit tests).
    import prepare as _p

    freq = torch.zeros(vocab_size, dtype=torch.float64)
    seen = 0
    for batch, _epoch in _p._document_batches("train", tokenizer_batch_size=batch_size):
        encoded = tokenizer.encode(batch, prepend=tokenizer.get_bos_token_id())
        flat: list[int] = []
        for row in encoded:
            flat.extend(row)
        if not flat:
            continue
        ids = torch.tensor(flat, dtype=torch.long)
        # bincount with minlength keeps the histogram aligned to V.
        freq += torch.bincount(ids, minlength=vocab_size).to(torch.float64)
        seen += ids.numel()
        if seen >= target_tokens:
            break
    return freq.to(torch.float32)


def get_or_build_unigram_sampler(
    tokenizer,
    vocab_size: int,
    device: torch.device | str = "cuda",
    cache_dir: Path | str | None = None,
    target_tokens: int = 1_000_000,
    rebuild: bool = False,
) -> UnigramSampler:
    """Cache-aware constructor: load `unigram_freq_v{V}.pt` if present,
    otherwise build from a streamed slice and persist.

    HYDRA_UNIGRAM_REBUILD=1 forces a rebuild even if the cache exists.
    HYDRA_UNIGRAM_TARGET_TOKENS overrides `target_tokens` at the env level.
    """
    if os.environ.get("HYDRA_UNIGRAM_REBUILD", "0") == "1":
        rebuild = True
    env_target = os.environ.get("HYDRA_UNIGRAM_TARGET_TOKENS")
    if env_target is not None:
        target_tokens = int(env_target)

    freq = None if rebuild else load_unigram_freq(vocab_size, cache_dir)
    if freq is None:
        freq = build_unigram_freq_from_tokenizer(
            tokenizer, vocab_size, target_tokens=target_tokens
        )
        save_unigram_freq(freq, vocab_size, cache_dir)
    return UnigramSampler(freq, device=device)


# ----------------------------------------------------------------------
# Loss
# ----------------------------------------------------------------------


def sampled_softmax_loss(
    x_flat: torch.Tensor,
    y_flat: torch.Tensor,
    lm_head_weight: torch.Tensor,
    sampler: UnigramSampler,
    K: int,
    *,
    label_smoothing: float = 0.0,
    softcap: float | None = None,
    softcap_clamp: bool = False,
    valid_mask: torch.Tensor | None = None,
    reduction: str = "mean",
    shared_negatives: bool = True,
) -> torch.Tensor:
    """Importance-sampled (unigram) sampled softmax cross-entropy.

    Args:
        x_flat:         (N, d) hidden states.
        y_flat:         (N,) target token ids. Negative entries treated as
                        invalid; if `valid_mask` is None the function falls
                        back to ``y_flat >= 0``.
        lm_head_weight: (V, d) LM-head weight (typically `model.lm_head.weight`).
        sampler:        a ``UnigramSampler`` instance (its log_q must live on
                        the same device as `x_flat`).
        K:              total candidates per row including the positive
                        (K-1 negatives drawn).
        label_smoothing: passed through to F.cross_entropy.
        softcap:        if non-None, apply tanh-softcap to candidate logits.
        softcap_clamp:  if True, use torch.clamp instead of tanh-softcap.
        valid_mask:     (N,) bool mask. Invalid positions contribute zero loss.
        reduction:      'mean' | 'none'.
        shared_negatives: if True (default), draw a SINGLE batch of K-1
                        negatives shared across all N rows. If False, draw
                        independent negatives per row. Shared is faster
                        (single (n, K) matmul, no (n, K, d) gather) and is
                        what Jean et al. 2015 use; per-row is statistically
                        slightly better but expensive at typical d/K.

    Returns:
        Scalar (mean) or per-token (none) cross-entropy.

    Backward: gradient on `lm_head_weight` flows only through the gathered
    rows (positives + drawn negatives). No full V x d gradient.
    """
    if x_flat.dim() != 2:
        raise ValueError(f"sampled_softmax_loss: x_flat must be 2-D, got {x_flat.shape}")
    if y_flat.shape != (x_flat.shape[0],):
        raise ValueError(
            f"sampled_softmax_loss: y_flat shape {tuple(y_flat.shape)} "
            f"does not match x_flat batch {x_flat.shape[0]}"
        )
    V, d = lm_head_weight.shape
    if x_flat.shape[1] != d:
        raise ValueError(
            f"sampled_softmax_loss: x_flat.shape[-1]={x_flat.shape[1]} != lm_head dim {d}"
        )
    if K <= 0 or K > V:
        raise ValueError(f"sampled_softmax_loss: K={K} out of range [1, V={V}]")

    n = x_flat.shape[0]
    device = x_flat.device

    if valid_mask is None:
        valid_mask = (y_flat >= 0)
    y_safe = torch.where(valid_mask, y_flat, torch.zeros_like(y_flat))

    if shared_negatives:
        # Shared-batch path: (n, d) x (d, K) matmul + per-row positive dot.
        # This is what the production loss path actually wants: the (n, K, d)
        # gather of the per-row path costs O(nKd) memory and beats the full
        # softmax matmul only when n is small AND d is large.
        K_neg = K - 1  # `K` total candidates includes the positive at column 0.
        neg_idx = sampler.sample((K_neg,), device=device)                 # (K-1,)
        # Positive logit: (n, d) * (n, d) -> (n,)
        pos_w = F.embedding(y_safe.view(n, 1), lm_head_weight).squeeze(1) # (n, d)
        pos_logit = (x_flat * pos_w).sum(-1)                              # (n,)
        # Negative logits: shared (K-1) negatives
        neg_w = F.embedding(neg_idx, lm_head_weight)                      # (K-1, d)
        neg_logits = x_flat @ neg_w.t()                                   # (n, K-1)

        if softcap is not None and softcap > 0:
            if softcap_clamp:
                pos_logit = torch.clamp(pos_logit, -softcap, softcap)
                neg_logits = torch.clamp(neg_logits, -softcap, softcap)
            else:
                pos_logit = softcap * torch.tanh(pos_logit / softcap)
                neg_logits = softcap * torch.tanh(neg_logits / softcap)

        # log-q correction.
        log_q_pos = sampler.log_q[y_safe]                                 # (n,)
        log_q_neg = sampler.log_q[neg_idx]                                # (K-1,)
        pos_logit = pos_logit - log_q_pos
        neg_logits = neg_logits - log_q_neg                               # broadcasts

        logits = torch.cat([pos_logit.unsqueeze(-1), neg_logits], dim=1).float()  # (n, K)
    else:
        # Per-row independent negatives. (n, K-1) negatives, gather (n, K, d).
        neg = sampler.sample((n, K - 1), device=device)
        cand_idx = torch.cat([y_safe.view(n, 1), neg], dim=1)             # (n, K)
        cand_w = F.embedding(cand_idx, lm_head_weight)                    # (n, K, d)
        logits = torch.einsum("nd,nkd->nk", x_flat, cand_w)               # (n, K)

        if softcap is not None and softcap > 0:
            if softcap_clamp:
                logits = torch.clamp(logits, -softcap, softcap)
            else:
                logits = softcap * torch.tanh(logits / softcap)

        log_q = sampler.log_q[cand_idx]                                   # (n, K)
        logits = (logits - log_q).float()

    # CE with positive at column 0.
    ce_targets = torch.zeros(n, dtype=torch.long, device=device)
    per_tok = F.cross_entropy(
        logits, ce_targets, reduction="none", label_smoothing=label_smoothing
    )

    valid_f = valid_mask.to(per_tok.dtype)
    per_tok = per_tok * valid_f

    if reduction == "none":
        return per_tok
    if reduction == "mean":
        denom = valid_f.sum().clamp(min=1)
        return per_tok.sum() / denom
    raise ValueError(f"sampled_softmax_loss: unknown reduction {reduction!r}")