| """ |
| GRPOTrainer extension: scale (or drop) the reference-policy KL term on tokens inside |
| ``<redacted_thinking>...</redacted_thinking>`` (inner body only). |
| |
| When ``thinking_inner_kl_weight == 1.0`` (default), behavior matches upstream TRL. |
| |
| Requires ``use_liger_kernel=False`` whenever ``thinking_inner_kl_weight != 1.0`` and ``beta != 0``, |
| because the Liger fused loss cannot apply this mask. |
| |
| Synced against TRL v1.0.x ``GRPOTrainer._compute_loss`` — re-sync if upgrading TRL breaks this module. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from trl import GRPOTrainer |
| from trl.trainer.utils import nanmax, nanmin |
|
|
| from .thinking_kl_mask import redacted_thinking_kl_scale |
|
|
|
|
| class ThinkingKLMaskedGRPOTrainer(GRPOTrainer): |
| thinking_inner_kl_weight: float = 1.0 |
| thinking_kl_tokenizer = None |
|
|
| def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): |
| if return_outputs: |
| raise ValueError("The GRPOTrainer does not support returning outputs") |
| if self.use_liger_kernel: |
| w = float(getattr(self, "thinking_inner_kl_weight", 1.0)) |
| if w != 1.0 and self.beta != 0.0: |
| raise RuntimeError( |
| "thinking_inner_kl_weight != 1.0 with beta != 0 requires use_liger_kernel=False " |
| "(Liger fused GRPO cannot mask KL inside <redacted_thinking>)." |
| ) |
| return super().compute_loss( |
| model, |
| inputs, |
| return_outputs=return_outputs, |
| num_items_in_batch=num_items_in_batch, |
| ) |
|
|
| def _compute_loss(self, model, inputs): |
| w = float(getattr(self, "thinking_inner_kl_weight", 1.0)) |
| tok = getattr(self, "thinking_kl_tokenizer", None) |
| if w == 1.0 or tok is None or self.beta == 0.0: |
| return super()._compute_loss(model, inputs) |
| return self._compute_loss_with_thinking_kl_mask(model, inputs) |
|
|
| def _compute_loss_with_thinking_kl_mask(self, model, inputs): |
| tok = getattr(self, "thinking_kl_tokenizer", None) |
| if tok is None: |
| return super()._compute_loss(model, inputs) |
|
|
| prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] |
| completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] |
| input_ids = torch.cat([prompt_ids, completion_ids], dim=1) |
| attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) |
| logits_to_keep = completion_ids.size(1) |
| mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] |
|
|
| per_token_logps, entropies = self._get_per_token_logps_and_entropies( |
| model, |
| input_ids, |
| attention_mask, |
| logits_to_keep, |
| compute_entropy=True, |
| pixel_values=inputs.get("pixel_values"), |
| image_grid_thw=inputs.get("image_grid_thw"), |
| num_images=inputs.get("num_images"), |
| pixel_attention_mask=inputs.get("pixel_attention_mask"), |
| image_sizes=inputs.get("image_sizes"), |
| token_type_ids=inputs.get("token_type_ids"), |
| mm_token_type_ids=inputs.get("mm_token_type_ids"), |
| pixel_position_ids=inputs.get("pixel_position_ids"), |
| ) |
|
|
| if self.top_entropy_quantile < 1.0: |
| entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile) |
| else: |
| entropy_mask = None |
|
|
| advantages = inputs["advantages"] |
| if advantages.dim() == 1: |
| advantages = advantages.unsqueeze(1) |
| old_per_token_logps = inputs.get("old_per_token_logps") |
| old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps |
|
|
| if self.off_policy_mask_threshold is not None: |
| sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps) |
| off_policy_mask = self.get_off_policy_mask( |
| advantages=advantages, |
| per_token_logps=per_token_logps, |
| sampling_per_token_logps=sampling_per_token_logps, |
| mask=mask, |
| off_policy_threshold=self.off_policy_mask_threshold, |
| ) |
| else: |
| off_policy_mask = None |
|
|
| log_ratio = per_token_logps - old_per_token_logps |
| if self.importance_sampling_level == "token": |
| log_importance_weights = log_ratio |
| elif self.importance_sampling_level == "sequence": |
| log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) |
| log_importance_weights = log_importance_weights.unsqueeze(-1) |
| else: |
| raise ValueError( |
| f"Unknown importance sampling level: {self.importance_sampling_level}. " |
| "Possible values are 'token' and 'sequence'." |
| ) |
|
|
| coef_1 = torch.exp(log_importance_weights) |
|
|
| if self.beta != 0.0: |
| ref_per_token_logps = inputs["ref_per_token_logps"] |
| per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - ( |
| ref_per_token_logps - per_token_logps |
| ) - 1 |
| if self.args.use_bias_correction_kl: |
| per_token_kl = per_token_kl * coef_1 |
|
|
| scale = redacted_thinking_kl_scale( |
| inputs["completion_ids"], |
| inputs["completion_mask"], |
| tok, |
| self.thinking_inner_kl_weight, |
| ).to(device=per_token_kl.device, dtype=per_token_kl.dtype) |
| per_token_kl = per_token_kl * scale |
| else: |
| per_token_kl = None |
|
|
| if self.loss_type == "cispo": |
| clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() |
| per_token_loss = -clamped_ratios * advantages * per_token_logps |
| elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: |
| coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) |
| if self.args.delta is not None: |
| coef_1 = torch.clamp(coef_1, max=self.args.delta) |
| per_token_loss1 = coef_1 * advantages |
| per_token_loss2 = coef_2 * advantages |
| per_token_loss = -torch.min(per_token_loss1, per_token_loss2) |
| elif self.loss_type == "sapo": |
| temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg) |
| soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures |
| per_token_loss = -soft_coef_1 * advantages |
| elif self.loss_type == "vespo": |
| phi_seq = self.get_gamma_weights( |
| advantages=advantages, |
| log_ratio_per_token=log_ratio, |
| mask=mask, |
| importance_sampling_ratio=inputs.get("importance_sampling_ratio"), |
| k_pos=self.args.vespo_k_pos, |
| lambda_pos=self.args.vespo_lambda_pos, |
| k_neg=self.args.vespo_k_neg, |
| lambda_neg=self.args.vespo_lambda_neg, |
| ) |
| per_token_loss = -phi_seq * advantages * per_token_logps |
| else: |
| raise ValueError(f"Unknown loss type: {self.loss_type}") |
|
|
| if off_policy_mask is not None: |
| per_token_loss = per_token_loss * off_policy_mask |
| if entropy_mask is not None: |
| per_token_loss = per_token_loss * entropy_mask |
| if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo": |
| per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] |
|
|
| if self.beta != 0.0: |
| per_token_loss = per_token_loss + self.beta * per_token_kl |
|
|
| mode = "train" if self.model.training else "eval" |
| if self.loss_type in ["grpo", "sapo"]: |
| loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() |
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 |
| loss = loss / normalizer |
| elif self.loss_type == "bnpo": |
| loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) |
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 |
| loss = loss / normalizer |
| elif self.loss_type == "dr_grpo": |
| loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) |
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 |
| loss = loss / normalizer |
| elif self.loss_type in ["cispo", "dapo", "vespo"]: |
| normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes |
| loss = (per_token_loss * mask).sum() / normalizer |
| elif self.loss_type == "luspo": |
| loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() |
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 |
| loss = loss / normalizer |
| else: |
| raise ValueError(f"Unknown loss type: {self.loss_type}") |
|
|
| completion_token_count = mask.sum().clamp(min=1.0) |
|
|
| def masked_batch_mean(x): |
| if x.shape[1] == 1: |
| return x.mean() |
| return (x * mask).sum() / completion_token_count |
|
|
| if self.beta != 0.0: |
| mean_kl = masked_batch_mean(per_token_kl) |
| self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) |
|
|
| mean_entropy = masked_batch_mean(entropies) |
| self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) |
|
|
| if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: |
| is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) |
| is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) |
| is_region_clipped = is_low_clipped | is_high_clipped |
| low_clip = masked_batch_mean(is_low_clipped.float()) |
| high_clip = masked_batch_mean(is_high_clipped.float()) |
| clip_ratio = masked_batch_mean(is_region_clipped.float()) |
| gathered_low_clip = self.accelerator.gather(low_clip) |
| self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) |
| self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) |
| gathered_high_clip = self.accelerator.gather(high_clip) |
| self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) |
| self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) |
| gathered_clip_ratio = self.accelerator.gather(clip_ratio) |
| self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) |
| elif self.loss_type == "cispo": |
| is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) |
| cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) |
| gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) |
| self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) |
| elif self.loss_type == "vespo": |
| gathered_phi_seq = self.accelerator.gather(phi_seq) |
| self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item()) |
|
|
| return loss |
|
|