# neuralese_temp/src/hackable/thinking_kl_grpo_trainer.py
# psidharth567 - Export neuralese codebase (cache and .env excluded). (commit dbc69f3)
"""
GRPOTrainer extension: scale (or drop) the reference-policy KL term on tokens inside
``<redacted_thinking>...</redacted_thinking>`` (inner body only).
When ``thinking_inner_kl_weight == 1.0`` (default), ``beta == 0``, or no ``thinking_kl_tokenizer`` is set,
``_compute_loss`` defers to upstream TRL unchanged.
Requires ``use_liger_kernel=False`` whenever ``thinking_inner_kl_weight != 1.0`` and ``beta != 0``,
because the Liger fused loss cannot apply this mask.
Synced against TRL v1.0.x ``GRPOTrainer._compute_loss`` — re-sync if upgrading TRL breaks this module.
"""
from __future__ import annotations

import warnings

import torch
from trl import GRPOTrainer
from trl.trainer.utils import nanmax, nanmin

from .thinking_kl_mask import redacted_thinking_kl_scale
class ThinkingKLMaskedGRPOTrainer(GRPOTrainer):
    """``GRPOTrainer`` variant that rescales the reference-policy KL term per token.

    The per-token scale comes from ``redacted_thinking_kl_scale``. The masked loss path is
    taken only when ``thinking_inner_kl_weight != 1.0``, ``beta != 0`` and
    ``thinking_kl_tokenizer`` is set; otherwise the upstream ``GRPOTrainer`` loss is used
    unchanged (with a one-time warning if only the tokenizer is missing).

    ``_compute_loss_with_thinking_kl_mask`` follows the upstream ``_compute_loss`` body
    (TRL v1.0.x) with one extra step, so its statement order is deliberately kept as is.
    """

    # Weight handed to ``redacted_thinking_kl_scale``; 1.0 keeps the upstream loss.
    thinking_inner_kl_weight: float = 1.0
    # Tokenizer handed to ``redacted_thinking_kl_scale``; ``None`` disables the masked path.
    thinking_kl_tokenizer = None
    # Flipped to True (per instance) once the missing-tokenizer warning has been emitted.
    _thinking_kl_missing_tokenizer_warned: bool = False

    def _thinking_kl_weight(self) -> float:
        """Return ``thinking_inner_kl_weight`` coerced to ``float`` (1.0 if the attribute is absent)."""
        return float(getattr(self, "thinking_inner_kl_weight", 1.0))

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Validate the configuration, then delegate to ``GRPOTrainer.compute_loss``.

        Raises:
            ValueError: if ``return_outputs`` is truthy.
            RuntimeError: if the Liger kernel is enabled while a non-default
                ``thinking_inner_kl_weight`` is combined with ``beta != 0``.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Per the module docstring, the fused Liger loss cannot apply the mask, so fail loudly.
        if self.use_liger_kernel and self._thinking_kl_weight() != 1.0 and self.beta != 0.0:
            raise RuntimeError(
                "thinking_inner_kl_weight != 1.0 with beta != 0 requires use_liger_kernel=False "
                "(Liger fused GRPO cannot mask KL inside <redacted_thinking>)."
            )
        return super().compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )

    def _compute_loss(self, model, inputs):
        """Dispatch to the masked loss when KL rescaling is active, else to the upstream loss.

        A configured weight without a tokenizer cannot be honoured; that case still falls back
        to the upstream loss but now emits a ``UserWarning`` once per trainer instance instead
        of being silently ignored.
        """
        weight = self._thinking_kl_weight()
        # Default weight or no KL term at all: nothing to rescale.
        if weight == 1.0 or self.beta == 0.0:
            return super()._compute_loss(model, inputs)

        if getattr(self, "thinking_kl_tokenizer", None) is None:
            if not self._thinking_kl_missing_tokenizer_warned:
                self._thinking_kl_missing_tokenizer_warned = True
                warnings.warn(
                    f"thinking_inner_kl_weight={weight} is set but thinking_kl_tokenizer is None; "
                    "the KL term is not rescaled inside <redacted_thinking> and the upstream "
                    "GRPO loss is used instead.",
                    UserWarning,
                    stacklevel=2,
                )
            return super()._compute_loss(model, inputs)

        return self._compute_loss_with_thinking_kl_mask(model, inputs)

    def _compute_loss_with_thinking_kl_mask(self, model, inputs):
        """Compute the GRPO loss with the per-token KL multiplied by the thinking scale.

        Reads ``prompt_ids``, ``prompt_mask``, ``completion_ids``, ``completion_mask`` and
        ``advantages`` from ``inputs`` (plus ``ref_per_token_logps`` when ``beta != 0`` and
        several optional keys), appends metrics to ``self._metrics[mode]`` and returns the
        scalar loss. Falls back to the upstream loss if no tokenizer is configured.

        Raises:
            ValueError: on an unknown ``importance_sampling_level`` or ``loss_type``.
        """
        tok = getattr(self, "thinking_kl_tokenizer", None)
        if tok is None:
            # Defensive guard for direct callers; ``_compute_loss`` already filters this case.
            return super()._compute_loss(model, inputs)

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        # Only the completion positions are scored.
        logits_to_keep = completion_ids.size(1)
        mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"]

        per_token_logps, entropies = self._get_per_token_logps_and_entropies(
            model,
            input_ids,
            attention_mask,
            logits_to_keep,
            compute_entropy=True,
            pixel_values=inputs.get("pixel_values"),
            image_grid_thw=inputs.get("image_grid_thw"),
            num_images=inputs.get("num_images"),
            pixel_attention_mask=inputs.get("pixel_attention_mask"),
            image_sizes=inputs.get("image_sizes"),
            token_type_ids=inputs.get("token_type_ids"),
            mm_token_type_ids=inputs.get("mm_token_type_ids"),
            pixel_position_ids=inputs.get("pixel_position_ids"),
        )

        if self.top_entropy_quantile < 1.0:
            entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile)
        else:
            entropy_mask = None

        advantages = inputs["advantages"]
        # 1-D advantages get a trailing axis so they broadcast against per-token tensors.
        if advantages.dim() == 1:
            advantages = advantages.unsqueeze(1)

        # Without stored old log-probs, the detached current ones are used (ratio == 1).
        old_per_token_logps = inputs.get("old_per_token_logps")
        old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps

        if self.off_policy_mask_threshold is not None:
            sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps)
            off_policy_mask = self.get_off_policy_mask(
                advantages=advantages,
                per_token_logps=per_token_logps,
                sampling_per_token_logps=sampling_per_token_logps,
                mask=mask,
                off_policy_threshold=self.off_policy_mask_threshold,
            )
        else:
            off_policy_mask = None

        log_ratio = per_token_logps - old_per_token_logps
        if self.importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif self.importance_sampling_level == "sequence":
            # Masked mean of the log-ratio over each row, kept as a trailing size-1 axis.
            log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {self.importance_sampling_level}. "
                "Possible values are 'token' and 'sequence'."
            )
        coef_1 = torch.exp(log_importance_weights)

        if self.beta != 0.0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            # exp(d) - d - 1 with d = ref_logp - logp.
            per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (
                ref_per_token_logps - per_token_logps
            ) - 1
            if self.args.use_bias_correction_kl:
                per_token_kl = per_token_kl * coef_1
            # The only step added on top of upstream: rescale the KL per token. The weight is the
            # same float-coerced value the dispatch gates compare against.
            scale = redacted_thinking_kl_scale(
                completion_ids,
                completion_mask,
                tok,
                self._thinking_kl_weight(),
            ).to(device=per_token_kl.device, dtype=per_token_kl.dtype)
            per_token_kl = per_token_kl * scale
        else:
            per_token_kl = None

        if self.loss_type == "cispo":
            clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
            per_token_loss = -clamped_ratios * advantages * per_token_logps
        elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
            # NOTE: coef_1 is rebound here, so the clip metrics below see the delta-clamped value.
            if self.args.delta is not None:
                coef_1 = torch.clamp(coef_1, max=self.args.delta)
            per_token_loss1 = coef_1 * advantages
            per_token_loss2 = coef_2 * advantages
            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        elif self.loss_type == "sapo":
            temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
            soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
            per_token_loss = -soft_coef_1 * advantages
        elif self.loss_type == "vespo":
            phi_seq = self.get_gamma_weights(
                advantages=advantages,
                log_ratio_per_token=log_ratio,
                mask=mask,
                importance_sampling_ratio=inputs.get("importance_sampling_ratio"),
                k_pos=self.args.vespo_k_pos,
                lambda_pos=self.args.vespo_lambda_pos,
                k_neg=self.args.vespo_k_neg,
                lambda_neg=self.args.vespo_lambda_neg,
            )
            per_token_loss = -phi_seq * advantages * per_token_logps
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        if off_policy_mask is not None:
            per_token_loss = per_token_loss * off_policy_mask
        if entropy_mask is not None:
            per_token_loss = per_token_loss * entropy_mask
        if self.use_vllm and self.vllm_importance_sampling_correction and self.loss_type != "vespo":
            per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
        # The (already rescaled) KL penalty is added after all policy-term masks.
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        mode = "train" if self.model.training else "eval"
        if self.loss_type in ["grpo", "sapo"]:
            # Per-row masked mean, then mean over rows.
            loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
            loss = loss / normalizer
        elif self.loss_type == "bnpo":
            # Masked mean over every unmasked token in the batch.
            loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
            loss = loss / normalizer
        elif self.loss_type == "dr_grpo":
            # Fixed denominator: number of rows times ``max_completion_length``.
            loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
            loss = loss / normalizer
        elif self.loss_type in ["cispo", "dapo", "vespo"]:
            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
            loss = (per_token_loss * mask).sum() / normalizer
        elif self.loss_type == "luspo":
            loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
            normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
            loss = loss / normalizer
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        completion_token_count = mask.sum().clamp(min=1.0)

        def masked_batch_mean(x):
            """Plain mean for width-1 tensors, otherwise the masked mean over the batch."""
            if x.shape[1] == 1:
                return x.mean()
            return (x * mask).sum() / completion_token_count

        if self.beta != 0.0:
            # NOTE: this is the KL after rescaling, not the raw per-token KL.
            mean_kl = masked_batch_mean(per_token_kl)
            self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())

        mean_entropy = masked_batch_mean(entropies)
        self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

        if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
            is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
            is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
            is_region_clipped = is_low_clipped | is_high_clipped

            low_clip = masked_batch_mean(is_low_clipped.float())
            high_clip = masked_batch_mean(is_high_clipped.float())
            clip_ratio = masked_batch_mean(is_region_clipped.float())

            gathered_low_clip = self.accelerator.gather(low_clip)
            self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
            self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
            gathered_high_clip = self.accelerator.gather(high_clip)
            self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
            self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
            gathered_clip_ratio = self.accelerator.gather(clip_ratio)
            self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
        elif self.loss_type == "cispo":
            is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
            cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
            gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
            self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())
        elif self.loss_type == "vespo":
            gathered_phi_seq = self.accelerator.gather(phi_seq)
            self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item())

        return loss