""" GRPOTrainer extension: scale (or drop) the reference-policy KL term on tokens inside ``...`` (inner body only). When ``thinking_inner_kl_weight == 1.0`` (default), behavior matches upstream TRL. Requires ``use_liger_kernel=False`` whenever ``thinking_inner_kl_weight != 1.0`` and ``beta != 0``, because the Liger fused loss cannot apply this mask. Synced against TRL v1.0.x ``GRPOTrainer._compute_loss`` — re-sync if upgrading TRL breaks this module. """ from __future__ import annotations import torch from trl import GRPOTrainer from trl.trainer.utils import nanmax, nanmin from .thinking_kl_mask import redacted_thinking_kl_scale class ThinkingKLMaskedGRPOTrainer(GRPOTrainer): thinking_inner_kl_weight: float = 1.0 thinking_kl_tokenizer = None def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") if self.use_liger_kernel: w = float(getattr(self, "thinking_inner_kl_weight", 1.0)) if w != 1.0 and self.beta != 0.0: raise RuntimeError( "thinking_inner_kl_weight != 1.0 with beta != 0 requires use_liger_kernel=False " "(Liger fused GRPO cannot mask KL inside )." ) return super().compute_loss( model, inputs, return_outputs=return_outputs, num_items_in_batch=num_items_in_batch, ) def _compute_loss(self, model, inputs): w = float(getattr(self, "thinking_inner_kl_weight", 1.0)) tok = getattr(self, "thinking_kl_tokenizer", None) if w == 1.0 or tok is None or self.beta == 0.0: return super()._compute_loss(model, inputs) return self._compute_loss_with_thinking_kl_mask(model, inputs) def _compute_loss_with_thinking_kl_mask(self, model, inputs): tok = getattr(self, "thinking_kl_tokenizer", None) if tok is None: return super()._compute_loss(model, inputs) prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] per_token_logps, entropies = self._get_per_token_logps_and_entropies( model, input_ids, attention_mask, logits_to_keep, compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), token_type_ids=inputs.get("token_type_ids"), mm_token_type_ids=inputs.get("mm_token_type_ids"), pixel_position_ids=inputs.get("pixel_position_ids"), ) if self.top_entropy_quantile < 1.0: entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile) else: entropy_mask = None advantages = inputs["advantages"] if advantages.dim() == 1: advantages = advantages.unsqueeze(1) old_per_token_logps = inputs.get("old_per_token_logps") old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps if self.off_policy_mask_threshold is not None: sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps) off_policy_mask = self.get_off_policy_mask( advantages=advantages, per_token_logps=per_token_logps, sampling_per_token_logps=sampling_per_token_logps, mask=mask, off_policy_threshold=self.off_policy_mask_threshold, ) else: off_policy_mask = None log_ratio = per_token_logps - old_per_token_logps if self.importance_sampling_level == "token": log_importance_weights = log_ratio elif self.importance_sampling_level == "sequence": log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) log_importance_weights = log_importance_weights.unsqueeze(-1) else: raise ValueError( f"Unknown importance sampling level: {self.importance_sampling_level}. " "Possible values are 'token' and 'sequence'." ) coef_1 = torch.exp(log_importance_weights) if self.beta != 0.0: ref_per_token_logps = inputs["ref_per_token_logps"] per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - ( ref_per_token_logps - per_token_logps ) - 1 if self.args.use_bias_correction_kl: per_token_kl = per_token_kl * coef_1 scale = redacted_thinking_kl_scale( inputs["completion_ids"], inputs["completion_mask"], tok, self.thinking_inner_kl_weight, ).to(device=per_token_kl.device, dtype=per_token_kl.dtype) per_token_kl = per_token_kl * scale else: per_token_kl = None if self.loss_type == "cispo": clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() per_token_loss = -clamped_ratios * advantages * per_token_logps elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) if self.args.delta is not None: coef_1 = torch.clamp(coef_1, max=self.args.delta) per_token_loss1 = coef_1 * advantages per_token_loss2 = coef_2 * advantages per_token_loss = -torch.min(per_token_loss1, per_token_loss2) elif self.loss_type == "sapo": temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg) soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures per_token_loss = -soft_coef_1 * advantages elif self.loss_type == "vespo": phi_seq = self.get_gamma_weights( advantages=advantages, log_ratio_per_token=log_ratio, mask=mask, importance_sampling_ratio=inputs.get("importance_sampling_ratio"), k_pos=self.args.vespo_k_pos, lambda_pos=self.args.vespo_lambda_pos, k_neg=self.args.vespo_k_neg, lambda_neg=self.args.vespo_lambda_neg, ) per_token_loss = -phi_seq * advantages * per_token_logps else: raise ValueError(f"Unknown loss type: {self.loss_type}") if off_policy_mask is not None: per_token_loss = per_token_loss * off_policy_mask if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo": per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl mode = "train" if self.model.training else "eval" if self.loss_type in ["grpo", "sapo"]: loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 loss = loss / normalizer elif self.loss_type == "bnpo": loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 loss = loss / normalizer elif self.loss_type == "dr_grpo": loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 loss = loss / normalizer elif self.loss_type in ["cispo", "dapo", "vespo"]: normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes loss = (per_token_loss * mask).sum() / normalizer elif self.loss_type == "luspo": loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 loss = loss / normalizer else: raise ValueError(f"Unknown loss type: {self.loss_type}") completion_token_count = mask.sum().clamp(min=1.0) def masked_batch_mean(x): if x.shape[1] == 1: return x.mean() return (x * mask).sum() / completion_token_count if self.beta != 0.0: mean_kl = masked_batch_mean(per_token_kl) self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) mean_entropy = masked_batch_mean(entropies) self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) is_region_clipped = is_low_clipped | is_high_clipped low_clip = masked_batch_mean(is_low_clipped.float()) high_clip = masked_batch_mean(is_high_clipped.float()) clip_ratio = masked_batch_mean(is_region_clipped.float()) gathered_low_clip = self.accelerator.gather(low_clip) self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) gathered_high_clip = self.accelerator.gather(high_clip) self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) gathered_clip_ratio = self.accelerator.gather(clip_ratio) self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) elif self.loss_type == "cispo": is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) elif self.loss_type == "vespo": gathered_phi_seq = self.accelerator.gather(phi_seq) self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item()) return loss