"""kl_logging.py: dual_kl_logger (ADR-013, framework-side, generic).

The washout/amplification instrument. Given per-token logprobs from three
forward passes on the SAME answer+reasoning tokens:

- policy: the model currently being RL-trained
- altered_init: the altered SFT checkpoint the run STARTED from (the locus
  of the cognitive-distortion signature)
- unaltered_base: the original base model BEFORE personality SFT

it returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``.

NEITHER KL is optimized by default; both are diagnostics:

- ``kl_to_altered_init`` rising means the policy is moving AWAY from the
  altered checkpoint (task-RL is *changing* the alteration).
- ``kl_to_base`` measures how far the policy has diverged from the unaltered
  base. If ``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the
  alteration is WASHING OUT (the policy drifts back toward base). If
  ``kl_to_base`` GROWS faster than ``kl_to_altered_init``, the alteration is
  being AMPLIFIED (the policy moves further from base than the altered init
  already was). This is the ADR-013 amplification hypothesis, most likely on
  the SDPO channel.

Token-mean KL is used (the mean over the masked answer+reasoning tokens),
the standard diagnostic convention. The math is the discrete KL between the
two softmax distributions implied by the logprob tensors:

    KL(p || q) = sum_v p_v (log p_v - log q_v)

where ``p`` is the policy's per-token distribution, ``q`` is the reference's,
and ``v`` ranges over the vocabulary. Both inputs must therefore carry the
full vocabulary dimension, not just the logprob of the sampled token. This is
unit-testable on toy tensors: KL(p || p) == 0, KL(p || q) >= 0, and KL grows
monotonically as the policy's logits are pushed further from the reference's
along a fixed direction.
"""

from __future__ import annotations

from typing import Any

import torch

__all__ = ["dual_kl_logger", "token_mean_kl"]


def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
    """Normalize an input that may be raw logits OR already-log-probs to valid
    log-probabilities along the last (vocab) dim.

    We re-apply ``log_softmax`` defensively: it is idempotent on a genuine
    log-prob tensor up to floating point (log_softmax(x) = x - logsumexp(x),
    and logsumexp(x) == 0 when exp(x) already sums to 1), and it converts raw
    logits correctly. This makes the logger robust to either calling
    convention. The input is upcast to float64 before normalizing.
    """
    return torch.log_softmax(logprobs.to(torch.float64), dim=-1)


def token_mean_kl(
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> float:
    """Token-mean KL(policy || ref) over distributions on the last dim.

    Args:
        policy_logprobs: (..., V) logits or log-probs for the policy.
        ref_logprobs: (..., V) logits or log-probs for the reference.
        mask: optional (...,) mask of tokens to include (1/True = include).
            If None, all tokens count.

    Returns:
        Scalar token-mean KL as a Python float (>= 0 up to float error).
    """
    log_p = _as_log_probs(policy_logprobs)
    log_q = _as_log_probs(ref_logprobs)
    p = log_p.exp()
    # per-token KL: sum over vocab of p * (log p - log q)
    per_token = (p * (log_p - log_q)).sum(dim=-1)  # (...,)
    if mask is not None:
        m = mask.to(per_token.dtype)
        denom = m.sum()
        # No selected tokens: report 0.0 instead of dividing by zero.
        if float(denom) == 0.0:
            return 0.0
        # Average the per-token KL over the selected positions only.
        return float((per_token * m).sum() / denom)
    return float(per_token.mean())


def dual_kl_logger(
    policy_logprobs: torch.Tensor,
    altered_init_logprobs: torch.Tensor,
    unaltered_base_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
    **_: Any,  # extra keyword arguments are accepted and ignored
) -> dict[str, float]:
    """Compute the two diagnostic KLs for a step.

    Args:
        policy_logprobs: (..., V) policy logits/log-probs on the
            answer+reasoning tokens.
        altered_init_logprobs: (..., V) for the altered SFT init.
        unaltered_base_logprobs: (..., V) for the unaltered base.
        mask: optional (...,) token mask (answer+reasoning tokens to score).

    Returns:
        ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
    """
    return {
        "kl_to_altered_init": token_mean_kl(
            policy_logprobs, altered_init_logprobs, mask
        ),
        "kl_to_base": token_mean_kl(
            policy_logprobs, unaltered_base_logprobs, mask
        ),
    }
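The docstring's unit-test claims can be checked on toy numbers without PyTorch. The following is a minimal, dependency-free sketch that mirrors `_as_log_probs` and the per-token KL inside `token_mean_kl` for a single token; `log_softmax`, `kl`, and the toy logits are illustrative stand-ins, not part of the module.

```python
from __future__ import annotations

import math


def log_softmax(xs: list[float]) -> list[float]:
    # Numerically stable log-softmax over one vocab row: x_i - logsumexp(xs).
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def kl(policy: list[float], ref: list[float]) -> float:
    # KL(p || q) = sum_v p_v (log p_v - log q_v) for a single token.
    # Both rows are re-normalized first, so logits and log-probs are accepted.
    log_p = log_softmax(policy)
    log_q = log_softmax(ref)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


ref_logits = [2.0, 0.5, -1.0, 0.0]
direction = [1.0, -1.0, 0.5, 0.0]  # an arbitrary fixed direction in logit space
uniform = [0.0] * 4                # logits of a uniform reference

# Re-applying log_softmax to log-probs changes them only by round-off ...
log_probs = log_softmax(ref_logits)
drift = max(abs(a - b) for a, b in zip(log_probs, log_softmax(log_probs)))
print("idempotence error:", drift)

# ... so logits and their log-probs give the same KL.
print(kl(ref_logits, uniform), kl(log_probs, uniform))

# KL(p || p) == 0, and KL grows as the policy is pushed along `direction`.
path = [
    kl([r + t * d for r, d in zip(ref_logits, direction)], ref_logits)
    for t in (0.0, 0.5, 1.0, 2.0)
]
print(path)
```

The monotone check is not specific to these numbers: if the policy's logits are `ref + t * d`, the derivative of KL(p_t || q) with respect to `t` is `t * Var_{p_t}(d)`, which is non-negative for `t >= 0` and strictly positive whenever `d` is not constant across the vocabulary.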
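To see how the two numbers from `dual_kl_logger` separate washout from amplification, the sketch below mirrors `token_mean_kl` and `dual_kl_logger` in plain Python on a toy three-token sequence. It assumes a hypothetical setup in which the altered init's logits are the base's plus a fixed per-token `shift`, and the policy's logits are `base + alpha * shift`: `alpha < 1` moves the policy back toward base, `alpha > 1` pushes it past the altered init. The helper names (`masked_token_mean_kl`, `dual_kl`, `policy_at`) and all numbers are illustrative only.

```python
from __future__ import annotations

import math


def log_softmax(xs: list[float]) -> list[float]:
    # Numerically stable log-softmax over one vocab row.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def kl(policy: list[float], ref: list[float]) -> float:
    # Per-token KL(p || q) = sum_v p_v (log p_v - log q_v).
    log_p, log_q = log_softmax(policy), log_softmax(ref)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def masked_token_mean_kl(
    policy: list[list[float]],
    ref: list[list[float]],
    mask: list[int] | None = None,
) -> float:
    # Token-mean of the per-token KLs; an all-zero mask returns 0.0.
    per_token = [kl(p, q) for p, q in zip(policy, ref)]
    if mask is None:
        return sum(per_token) / len(per_token)
    denom = sum(mask)
    if denom == 0:
        return 0.0
    return sum(k * m for k, m in zip(per_token, mask)) / denom


def dual_kl(
    policy: list[list[float]],
    altered_init: list[list[float]],
    base: list[list[float]],
    mask: list[int] | None = None,
) -> dict[str, float]:
    # The same two diagnostics as dual_kl_logger, on nested Python lists.
    return {
        "kl_to_altered_init": masked_token_mean_kl(policy, altered_init, mask),
        "kl_to_base": masked_token_mean_kl(policy, base, mask),
    }


# Toy 3-token sequence over a 4-word vocab. The third token is masked out
# (e.g. a prompt token outside the answer+reasoning span).
base = [[1.0, 0.0, -0.5, 0.2], [0.3, 0.3, 0.0, -1.0], [5.0, -5.0, 0.0, 0.0]]
shift = [[-1.0, 1.0, 0.0, 0.0], [0.5, -0.5, 1.0, 0.0], [-9.0, 9.0, 0.0, 0.0]]
mask = [1, 1, 0]
altered = [[b + s for b, s in zip(bt, st)] for bt, st in zip(base, shift)]


def policy_at(alpha: float) -> list[list[float]]:
    # alpha = 0 is the base, 1 is the altered init; < 1 drifts back, > 1 overshoots.
    return [[b + alpha * s for b, s in zip(bt, st)] for bt, st in zip(base, shift)]


for alpha in (0.5, 1.0, 1.5):
    print(alpha, dual_kl(policy_at(alpha), altered, base, mask))
```

At `alpha = 1.0` the policy equals the altered init, so `kl_to_altered_init` is 0. At both `alpha = 0.5` and `alpha = 1.5`, `kl_to_altered_init` is positive, so on its own it cannot tell the two cases apart; `kl_to_base` can, being smaller at 0.5 (washout) and larger at 1.5 (amplification). The masked third token has by far the largest shift but contributes nothing to either value.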
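The washout/amplification reading in the module docstring compares how the two logged KLs trend over training, which the logger itself does not do. A minimal sketch of that comparison follows, assuming the per-step dicts returned by `dual_kl_logger` are collected into a list, oldest first. `ols_slope` and `classify_drift` are hypothetical helpers. Using a least-squares slope as the growth rate, with a small tolerance `tol`, is an assumption of this sketch, not something ADR-013 specifies here.

```python
from __future__ import annotations


def ols_slope(ys: list[float]) -> float:
    # Least-squares slope of ys against the step index 0..n-1 (needs n >= 2).
    n = len(ys)
    x_mean = (n - 1) / 2
    y_mean = sum(ys) / n
    num = sum((i - x_mean) * (y - y_mean) for i, y in enumerate(ys))
    den = sum((i - x_mean) ** 2 for i in range(n))
    return num / den


def classify_drift(history: list[dict[str, float]], tol: float = 1e-6) -> str:
    # history: one dual_kl_logger-style dict per logged step, oldest first.
    s_init = ols_slope([h["kl_to_altered_init"] for h in history])
    s_base = ols_slope([h["kl_to_base"] for h in history])
    if s_base < -tol and s_init > tol:
        return "washout"        # drifting back toward the unaltered base
    if s_base > tol and s_base > s_init:
        return "amplification"  # leaving base faster than leaving the init
    return "inconclusive"


# Synthetic trajectories for illustration only (not measured values).
washout_run = [
    {"kl_to_altered_init": 0.00, "kl_to_base": 0.30},
    {"kl_to_altered_init": 0.05, "kl_to_base": 0.22},
    {"kl_to_altered_init": 0.09, "kl_to_base": 0.15},
    {"kl_to_altered_init": 0.12, "kl_to_base": 0.10},
]
amplified_run = [
    {"kl_to_altered_init": 0.00, "kl_to_base": 0.30},
    {"kl_to_altered_init": 0.04, "kl_to_base": 0.38},
    {"kl_to_altered_init": 0.08, "kl_to_base": 0.47},
    {"kl_to_altered_init": 0.12, "kl_to_base": 0.55},
]
print(classify_drift(washout_run), classify_drift(amplified_run))
```

A slope over the whole window is less sensitive to a single noisy step than comparing only the first and last entries; the window length and `tol` would likely need tuning to the noise level of a real run.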