"""kl_logging.py: dual_kl_logger (ADR-013, framework-side, generic).

The washout/amplification instrument. Given per-token distributions (logits or
log-probs over the full vocabulary, shape ``(..., V)``) from three forward
passes on the SAME answer+reasoning tokens:

  - policy:         the model currently being RL-trained
  - altered_init:   the altered SFT checkpoint the run STARTED from (the locus
                    of the cognitive-distortion signature)
  - unaltered_base: the original base model BEFORE personality SFT

returns ``{'kl_to_altered_init': float, 'kl_to_base': float}``.

NEITHER KL is optimized by default; both are diagnostics:

  - ``kl_to_altered_init`` rising means the policy is moving AWAY from the
    altered checkpoint (task-RL is *changing* the alteration).
  - ``kl_to_base`` measures distance to the unaltered base. If
    ``kl_to_base`` SHRINKS while ``kl_to_altered_init`` grows, the alteration
    is WASHING OUT (the policy drifts back toward base). If ``kl_to_base``
    GROWS faster than ``kl_to_altered_init``, the alteration is being AMPLIFIED
    (the policy moves further from base than the altered init already was).
    This is the ADR-013 amplification hypothesis, most likely on the SDPO
    channel.

Token-mean KL is used (mean over the masked answer+reasoning tokens), the
standard diagnostic convention. The math is the discrete KL between the two
softmax distributions implied by the logprob tensors:

    KL(p || q) = sum_v p_v (log p_v - log q_v)

where ``p`` is the policy's per-token distribution. This is unit-testable on
toy tensors: KL(p || p) == 0, and KL(p || q) > 0 whenever p and q differ.
"""
from __future__ import annotations

from typing import Any

import torch

__all__ = ["dual_kl_logger", "token_mean_kl"]


def _as_log_probs(logprobs: torch.Tensor) -> torch.Tensor:
    """Normalize an input that may be raw logits OR already-log-probs to valid
    log-probabilities along the last (vocab) dim.

    We re-apply ``log_softmax`` defensively: it is idempotent on a genuine
    log-prob tensor up to floating point (log_softmax of log-probs == log-probs
    since they already sum-exp to 1), and converts raw logits correctly. This
    makes the logger robust to either calling convention.

    The last dim must be the full vocabulary. Log-probs already gathered at
    the sampled token ids (no vocab dim) are NOT a valid input, because
    ``log_softmax`` would then normalize over the wrong axis.
    """
    # Upcast to float64 before normalizing so the KL sum is computed in double
    # precision.
    return torch.log_softmax(logprobs.to(torch.float64), dim=-1)


def token_mean_kl(
    policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> float:
    """Token-mean KL(policy || ref) over distributions on the last dim.

    Args:
        policy_logprobs: (..., V) logits or log-probs for the policy.
        ref_logprobs: (..., V) logits or log-probs for the reference.
        mask: optional (...,) mask of tokens to include (1/True = include). If
            None, all tokens count.

    Returns:
        scalar token-mean KL as a python float (>= 0 up to float error);
        0.0 if the mask selects no tokens.
    """
    log_p = _as_log_probs(policy_logprobs)
    log_q = _as_log_probs(ref_logprobs)
    p = log_p.exp()
    # per-token KL: sum over vocab of p * (log p - log q)
    per_token = (p * (log_p - log_q)).sum(dim=-1)  # (...,)
    if mask is not None:
        m = mask.to(per_token.dtype)
        denom = m.sum()
        # An all-zero mask has no tokens to average over; report 0.0, not NaN.
        if float(denom) == 0.0:
            return 0.0
        return float((per_token * m).sum() / denom)
    return float(per_token.mean())


def dual_kl_logger(
    policy_logprobs: torch.Tensor,
    altered_init_logprobs: torch.Tensor,
    unaltered_base_logprobs: torch.Tensor,
    mask: torch.Tensor | None = None,
    **_: Any,
) -> dict[str, float]:
    """Compute the two diagnostic KLs for a step.

    Args:
        policy_logprobs: (..., V) policy logits/log-probs on the
            answer+reasoning tokens.
        altered_init_logprobs: (..., V) for the altered SFT init.
        unaltered_base_logprobs: (..., V) for the unaltered base.
        mask: optional (...,) token mask (answer+reasoning tokens to score).
        **_: extra keyword arguments are accepted and ignored.

    Returns:
        ``{'kl_to_altered_init': float, 'kl_to_base': float}``.
    """
    return {
        "kl_to_altered_init": token_mean_kl(
            policy_logprobs, altered_init_logprobs, mask
        ),
        "kl_to_base": token_mean_kl(
            policy_logprobs, unaltered_base_logprobs, mask
        ),
    }
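
The washout and amplification readings described in the module docstring can be checked on toy inputs. The sketch below is a plain-Python restatement of the token-mean KL formula, written with only the standard library so it runs without `torch`; it does not import or replace the module. The helper names `log_softmax`, `token_mean_kl_ref`, and `dual_kl`, and all logit values, are hypothetical and chosen only for illustration.

```python
import math


def log_softmax(logits):
    """Log-softmax of one list of logits (max-shifted for stability)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def token_mean_kl_ref(policy, ref, mask=None):
    """Token-mean KL(policy || ref); inputs are lists of per-token logit lists."""
    if mask is None:
        mask = [1] * len(policy)
    total, count = 0.0, 0
    for p_logits, q_logits, keep in zip(policy, ref, mask):
        if not keep:
            continue
        log_p, log_q = log_softmax(p_logits), log_softmax(q_logits)
        # KL(p || q) = sum_v p_v (log p_v - log q_v)
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
        count += 1
    # Same convention as the module: an empty mask yields 0.0.
    return total / count if count else 0.0


# Hypothetical toy logits: 2 tokens, vocabulary of 3.
base = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]          # unaltered base (uniform)
altered_init = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]  # altered SFT init
washed = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]        # policy drifting back to base
amplified = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]     # policy moving past the init


def dual_kl(policy):
    """Both diagnostics for one policy snapshot (assumed toy analogue)."""
    return {
        "kl_to_altered_init": token_mean_kl_ref(policy, altered_init),
        "kl_to_base": token_mean_kl_ref(policy, base),
    }


start = dual_kl(altered_init)  # step 0: the policy IS the altered init
wash = dual_kl(washed)
amp = dual_kl(amplified)

for name, row in [("start", start), ("washout", wash), ("amplified", amp)]:
    print(name, {k: round(v, 4) for k, v in row.items()})
```

In this toy setup the washout snapshot shows `kl_to_base` falling below its starting value while `kl_to_altered_init` rises from zero, and the amplified snapshot shows `kl_to_base` rising by more than `kl_to_altered_init` does, which matches the two readings the docstring describes. Real runs would compare these values across training steps rather than across hand-built snapshots.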