Codeseys's picture
feat(wave-b): ADR-013 LMA integration + B4 end-to-end SDPO-fires proof + doc refresh
21647a4
Raw
History Blame Contribute Delete
4.21 kB
"""kl_logging.py — dual_kl_logger (ADR-013, framework-side, generic).
The washout/amplification instrument. Given per-token logprobs from three
forward passes on the SAME answer+reasoning tokens:
- policy: the model currently being RL-trained
- altered_init: the altered SFT checkpoint the run STARTED from (the locus
of the cognitive-distortion signature)
- unaltered_base: the original base model BEFORE personality SFT
returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
NEITHER KL is optimized by default — both are diagnostics:
- ``kl_to_altered_init`` rising means the policy is moving AWAY from the
altered checkpoint (task-RL is *changing* the alteration).
- ``kl_to_base`` measures distance to the unaltered base. If
``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the alteration
is WASHING OUT (the policy drifts back toward base). If ``kl_to_base``
GROWS faster than ``kl_to_altered_init``, the alteration is being AMPLIFIED
(the policy moves further from base than the altered init already was) —
the ADR-013 amplification hypothesis, most likely on the SDPO channel.
Token-mean KL is used (mean over the masked answer+reasoning tokens), the
standard diagnostic convention. The math is the discrete KL between the two
softmax distributions implied by the logprob tensors:
KL(p || q) = sum_v p_v (log p_v - log q_v)
where ``p`` is the policy's per-token distribution and ``q`` the reference's.
This is unit-testable on toy tensors: KL(p || p) == 0, and KL(p || q) > 0
whenever p != q.
"""
from __future__ import annotations
from typing import Any
import torch
__all__ = ["dual_kl_logger", "token_mean_kl"]
def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
"""Normalize an input that may be raw logits OR already-log-probs to valid
log-probabilities along the last (vocab) dim.
We re-apply ``log_softmax`` defensively: it is idempotent on a genuine
log-prob tensor up to floating point (log_softmax of log-probs == log-probs
since they already sum-exp to 1), and converts raw logits correctly. This
makes the logger robust to either calling convention.
"""
return torch.log_softmax(logprobs.to(torch.float64), dim=-1)
def token_mean_kl(
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> float:
    """Token-mean KL(policy || ref) over distributions on the last dim.

    Per token this is ``sum_v p_v (log p_v - log q_v)``. Vocab entries the
    policy assigns exactly zero probability (e.g. logits set to ``-inf``)
    contribute nothing, per the ``0 * log 0 = 0`` convention, so ``-inf``
    vocab slots do not turn the result into NaN.

    Args:
        policy_logprobs: (..., V) logits or log-probs for the policy.
        ref_logprobs: (..., V) logits or log-probs for the reference.
        mask: optional mask/weights over the token dims (1/True = include).
            It is broadcast against the per-token KL, so the denominator
            counts every token the mask actually covers. Positions where the
            mask is 0 are excluded outright, even if their per-token KL is
            NaN/inf (e.g. all ``-inf`` padding rows). If None, all tokens
            count.

    Returns:
        scalar token-mean KL as a python float (>= 0 up to float error).
        Returns 0.0 when there are no tokens or the mask selects none.
    """
    log_p = _as_log_probs(policy_logprobs)
    log_q = _as_log_probs(ref_logprobs)
    p = log_p.exp()
    # Naively, p_v == 0 with log p_v == -inf gives 0 * (-inf - log q_v) = NaN.
    # Such entries contribute exactly 0 to the KL sum, so zero them explicitly.
    terms = torch.where(p == 0, torch.zeros_like(p), p * (log_p - log_q))
    per_token = terms.sum(dim=-1)  # (...,)

    if mask is None:
        if per_token.numel() == 0:
            # mean() of an empty tensor is NaN; match the all-masked-out case.
            return 0.0
        return float(per_token.mean())

    m = mask.to(device=per_token.device, dtype=per_token.dtype)
    # Broadcast explicitly so a lower-rank mask is counted once per token it
    # covers in the denominator (a bare m.sum() would under-count).
    per_token, m = torch.broadcast_tensors(per_token, m)
    denom = m.sum()
    if float(denom) == 0.0:
        return 0.0
    # where(), not plain multiplication: NaN * 0 is still NaN, so masked-out
    # positions must be dropped rather than merely zero-weighted.
    weighted = torch.where(m != 0, per_token * m, torch.zeros_like(per_token))
    return float(weighted.sum() / denom)
def dual_kl_logger(
    policy_logprobs: torch.Tensor,
    altered_init_logprobs: torch.Tensor,
    unaltered_base_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
    **_: Any,
) -> dict[str, float]:
    """Return both diagnostic KLs of the policy for one training step.

    Each value is ``token_mean_kl(policy, reference, mask)`` against one of
    the two reference models. Any extra keyword arguments are accepted and
    ignored.

    Args:
        policy_logprobs: (..., V) policy logits/log-probs on the
            answer+reasoning tokens.
        altered_init_logprobs: (..., V) logits/log-probs of the altered SFT
            init the run started from.
        unaltered_base_logprobs: (..., V) logits/log-probs of the unaltered
            base model.
        mask: optional (...,) mask selecting which tokens are scored.

    Returns:
        ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
    """
    # Evaluated in this order: altered init first, then base.
    reference_by_key = (
        ("kl_to_altered_init", altered_init_logprobs),
        ("kl_to_base", unaltered_base_logprobs),
    )
    return {
        key: token_mean_kl(policy_logprobs, reference, mask)
        for key, reference in reference_by_key
    }