"""kl_logging.py — ``dual_kl_logger`` (ADR-013; framework-side, generic).

The washout/amplification instrument. Given per-token logprobs from three
forward passes over the SAME answer+reasoning tokens:

- ``policy``: the model currently being RL-trained;
- ``altered_init``: the altered SFT checkpoint the run STARTED from (the locus
  of the cognitive-distortion signature);
- ``unaltered_base``: the original base model BEFORE personality SFT;

it returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``.

NEITHER KL is optimized by default — both are diagnostics:

- ``kl_to_altered_init`` rising means the policy is moving AWAY from the
  altered checkpoint (task-RL is *changing* the alteration).
- ``kl_to_base`` measures distance to the unaltered base.

If ``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the alteration is
WASHING OUT (the policy drifts back toward base). If ``kl_to_base`` GROWS faster
than ``kl_to_altered_init``, the alteration is being AMPLIFIED (the policy moves
further from base than the altered init already was) — the ADR-013
amplification hypothesis, most likely on the SDPO channel.

Token-mean KL is used (mean over the masked answer+reasoning tokens), the
standard diagnostic convention. The math is the discrete KL between the two
softmax distributions implied by the logprob tensors:

    KL(p || q) = sum_v p_v (log p_v - log q_v)

where ``p`` is the policy's per-token distribution. Vocab entries with
``p_v == 0`` (e.g. ``-inf`` logits) contribute 0, the usual ``0 log 0 = 0``
convention. This is unit-testable on toy tensors: KL(p || p) == 0, and KL grows
as the policy moves away from the reference.
"""

from __future__ import annotations

from typing import Any

import torch

__all__ = ["dual_kl_logger", "token_mean_kl"]


def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
    """Normalize raw logits OR already-log-probs to float64 log-probabilities.

    ``log_softmax`` is re-applied defensively along the last (vocab) dim. It is
    idempotent on a genuine log-prob tensor up to floating point (its entries
    already log-sum-exp to 0), and it converts raw logits correctly. This makes
    the logger robust to either calling convention. The result is float64.
    """
    return torch.log_softmax(logprobs.to(torch.float64), dim=-1)


def _per_token_kl(log_p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor:
    """Return KL(p || q) per position, summed over the last (vocab) dim.

    Both inputs must already be normalized log-probabilities. Vocab entries
    where ``p == 0`` (e.g. ``-inf`` logits for padded/banned tokens) are forced
    to contribute exactly 0. Otherwise ``0 * (-inf - log_q)`` would make the
    whole token NaN, even for KL(p || p). NaN inputs still propagate.
    """
    p = log_p.exp()
    terms = p * (log_p - log_q)
    return torch.where(p == 0, torch.zeros_like(terms), terms).sum(dim=-1)


def _masked_token_mean(
    per_token: torch.Tensor, mask: torch.Tensor | None
) -> float:
    """Average ``per_token`` over the positions selected by ``mask``.

    Returns 0.0 when nothing is selected (empty input or all-zero mask).

    The mask is broadcast against ``per_token`` *before* the denominator is
    counted, so the count matches the number of terms actually summed.
    Incompatible shapes raise ``RuntimeError``, as torch broadcasting does.

    Masked-out positions are dropped with ``torch.where`` rather than being
    multiplied by 0. That way a NaN/inf KL at an excluded (e.g. padding)
    position cannot poison the mean.
    """
    if mask is None:
        if per_token.numel() == 0:
            return 0.0
        return float(per_token.mean())
    weights = mask.to(device=per_token.device, dtype=per_token.dtype)
    values, weights = torch.broadcast_tensors(per_token, weights)
    denom = weights.sum()
    if float(denom) == 0.0:
        return 0.0
    kept = torch.where(weights == 0, torch.zeros_like(values), values * weights)
    return float(kept.sum() / denom)


def _kl_against(
    log_p: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor | None,
) -> float:
    """Token-mean KL(p || ref) given the policy's already-normalized ``log_p``."""
    # Move the reference onto the policy's device in case its forward pass ran
    # elsewhere; this is a no-op when the devices already match.
    log_q = _as_log_probs(ref_logprobs.to(log_p.device))
    return _masked_token_mean(_per_token_kl(log_p, log_q), mask)


def token_mean_kl(
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> float:
    """Token-mean KL(policy || ref) over distributions on the last dim.

    Args:
        policy_logprobs: (..., V) logits or log-probs for the policy.
        ref_logprobs: (..., V) logits or log-probs for the reference; moved to
            the policy's device if needed.
        mask: optional (...,) mask of tokens to include (1/True = include),
            broadcastable against the token dims. If None, all tokens count.

    Returns:
        Scalar token-mean KL as a python float (>= 0 up to float error).
        Returns 0.0 when no token is selected.
    """
    # Diagnostic only: the result is a detached float, so skip autograd and
    # avoid holding float64 (..., V) intermediates for a backward that never
    # happens.
    with torch.no_grad():
        return _kl_against(_as_log_probs(policy_logprobs), ref_logprobs, mask)


def dual_kl_logger(
    policy_logprobs: torch.Tensor,
    altered_init_logprobs: torch.Tensor,
    unaltered_base_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
    **_: Any,
) -> dict[str, float]:
    """Compute the two diagnostic KLs for a step.

    Args:
        policy_logprobs: (..., V) policy logits/log-probs on the
            answer+reasoning tokens.
        altered_init_logprobs: (..., V) for the altered SFT init.
        unaltered_base_logprobs: (..., V) for the unaltered base.
        mask: optional (...,) token mask (answer+reasoning tokens to score).
        **_: extra keyword arguments are accepted and ignored.

    Returns:
        ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
    """
    with torch.no_grad():
        # Normalize the policy once; both KLs share the same log_p.
        log_p = _as_log_probs(policy_logprobs)
        return {
            "kl_to_altered_init": _kl_against(
                log_p, altered_init_logprobs, mask
            ),
            "kl_to_base": _kl_against(log_p, unaltered_base_logprobs, mask),
        }