"""kl_logging.py — dual_kl_logger (ADR-013, framework-side, generic).

The washout/amplification instrument. Given per-token vocabulary
distributions (logits or log-probs) from three forward passes on the SAME
answer+reasoning tokens:

  - policy:         the model currently being RL-trained
  - altered_init:   the altered SFT checkpoint the run STARTED from (the locus
                    of the cognitive-distortion signature)
  - unaltered_base: the original base model BEFORE personality SFT

``dual_kl_logger`` returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``,
with the policy as the first KL argument in both (``KL(policy || ref)``).

NEITHER KL is optimized by default — both are diagnostics:
  - ``kl_to_altered_init`` rising means the policy is moving AWAY from the
    altered checkpoint (task-RL is *changing* the alteration).
  - ``kl_to_base`` measures distance to the unaltered base. If
    ``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the alteration
    is WASHING OUT (the policy drifts back toward base). If ``kl_to_base``
    GROWS faster than ``kl_to_altered_init``, the alteration is being AMPLIFIED
    (the policy moves further from base than the altered init already was) —
    the ADR-013 amplification hypothesis, most likely on the SDPO channel.
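
A minimal sketch of how two logged KL series could be read against the rules
above. It assumes per-step histories of both KLs are available. The helper
``classify_drift`` is hypothetical and not part of this module. It compares
only the net change between the first and last logged values, with no
smoothing or significance test:

```python
from typing import Sequence


def classify_drift(kl_to_altered_init: Sequence[float],
                   kl_to_base: Sequence[float]) -> str:
    # Hypothetical helper, not part of kl_logging: applies the washout /
    # amplification rules using the net change over the logged window.
    d_init = kl_to_altered_init[-1] - kl_to_altered_init[0]
    d_base = kl_to_base[-1] - kl_to_base[0]
    if d_init > 0 and d_base < 0:
        # Moving away from the altered init while drifting back toward base.
        return "washout"
    if d_base > 0 and d_base > d_init:
        # Moving away from base faster than away from the altered init.
        return "amplification"
    return "inconclusive"


print(classify_drift([0.0, 0.02, 0.05], [0.30, 0.27, 0.24]))  # washout
print(classify_drift([0.0, 0.02, 0.05], [0.30, 0.34, 0.41]))  # amplification
```

In practice a windowed or smoothed trend would likely be more robust than two
endpoints. The endpoint comparison only keeps the sketch short.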

Token-mean KL is used (mean over the masked answer+reasoning tokens), the
standard diagnostic convention. The math is the discrete KL between the two
softmax distributions implied by the logprob tensors:

    KL(p || q) = sum_v p_v (log p_v - log q_v)

where ``p`` is the policy's per-token distribution and ``q`` the reference's.
This is unit-testable on toy tensors: KL(p || p) == 0, and KL(p || q) > 0
whenever p != q (Gibbs' inequality).
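
The toy-tensor checks described above can be sketched as a standalone snippet.
``toy_token_mean_kl`` is a hypothetical stand-in that reimplements the same
math as ``token_mean_kl`` so the snippet runs without importing this module.
It also illustrates that passing raw logits or their (float64) log-softmax
yields the same KL:

```python
import torch


def toy_token_mean_kl(policy, ref, mask=None):
    # Hypothetical stand-in for token_mean_kl: log_softmax both inputs, then
    # KL(p || q) = sum_v p_v (log p_v - log q_v), averaged over kept tokens.
    log_p = torch.log_softmax(policy.double(), dim=-1)
    log_q = torch.log_softmax(ref.double(), dim=-1)
    per_token = (log_p.exp() * (log_p - log_q)).sum(dim=-1)
    if mask is None:
        return float(per_token.mean())
    m = mask.double()
    return float((per_token * m).sum() / m.sum())


torch.manual_seed(0)
logits = torch.randn(2, 5, 11)                   # (batch, tokens, vocab)
other = logits + 0.5 * torch.randn_like(logits)  # a perturbed "reference"
mask = torch.ones(2, 5)
mask[:, -2:] = 0                                 # drop the last two positions

print(toy_token_mean_kl(logits, logits))         # 0.0
print(toy_token_mean_kl(logits, other) > 0)      # True

# Raw logits and their log-probs describe the same distribution.
log_probs = torch.log_softmax(logits.double(), dim=-1)
same = toy_token_mean_kl(log_probs, other, mask)
print(abs(same - toy_token_mean_kl(logits, other, mask)) < 1e-9)  # True
```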
"""
from __future__ import annotations

from typing import Any

import torch

__all__ = ["dual_kl_logger", "token_mean_kl"]


def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
    """Normalize an input that may be raw logits OR already-log-probs to valid
    log-probabilities along the last (vocab) dim.

    We re-apply ``log_softmax`` defensively: it converts raw logits correctly
    and is idempotent on a genuine log-prob tensor up to floating point
    (``log_softmax(x) = x - logsumexp(x)``, and the logsumexp of valid
    log-probs is log(1) = 0). Inputs are upcast to float64 first. This makes
    the logger robust to either calling convention.
    """
    return torch.log_softmax(logprobs.to(torch.float64), dim=-1)


def token_mean_kl(
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> float:
    """Token-mean KL(policy || ref) over distributions on the last dim.

    Args:
        policy_logprobs: (..., V) logits or log-probs for the policy.
        ref_logprobs:    (..., V) logits or log-probs for the reference, with
            the same shape as ``policy_logprobs`` (the same tokens scored by
            the reference model).
        mask: optional (...,) mask of tokens to include (1/True = include). If
            None, all tokens count.

    Returns:
        scalar token-mean KL as a python float (>= 0 up to float error; 0.0 if
        the mask selects no tokens).
    """
    log_p = _as_log_probs(policy_logprobs)
    log_q = _as_log_probs(ref_logprobs)
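    p = log_p.exp()
    # Per-token KL: sum over vocab of p * (log p - log q). Vocab entries where
    # p == 0 contribute 0 by convention; zeroing them explicitly avoids
    # (-inf) - (-inf) = nan when both inputs carry -inf (masked) logits.
    per_token = torch.where(
        p > 0, p * (log_p - log_q), torch.zeros_like(p)
    ).sum(dim=-1)  # (...,)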

    if mask is not None:
        m = mask.to(device=per_token.device, dtype=per_token.dtype)
        denom = m.sum()
        if float(denom) == 0.0:
            return 0.0
        return float((per_token * m).sum() / denom)
    return float(per_token.mean())


def dual_kl_logger(
    policy_logprobs: torch.Tensor,
    altered_init_logprobs: torch.Tensor,
    unaltered_base_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
    **_: Any,
) -> dict[str, float]:
    """Compute the two diagnostic KLs for a step.

    Args:
        policy_logprobs: (..., V) policy logits/log-probs on the
            answer+reasoning tokens.
        altered_init_logprobs: (..., V) logits/log-probs for the altered SFT
            init on the same tokens.
        unaltered_base_logprobs: (..., V) logits/log-probs for the unaltered
            base on the same tokens.
        mask: optional (...,) token mask (answer+reasoning tokens to score).
        **_: extra keyword arguments are accepted and ignored.

    Returns:
        ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
    """
    return {
        "kl_to_altered_init": token_mean_kl(
            policy_logprobs, altered_init_logprobs, mask
        ),
        "kl_to_base": token_mean_kl(
            policy_logprobs, unaltered_base_logprobs, mask
        ),
    }