| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |

"""
Core functions to implement PPO algorithms.

The functions implemented in this file are meant to be used by trainers with
different distributed strategies to implement PPO-like algorithms.
"""


# Names exported by ``from <this module> import *``; every other top-level
# function and class defined here remains importable by its explicit name.
__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"]
|
|
| from collections import defaultdict |
| from enum import Enum |
| import math |
| from typing import Any, Callable, Optional |
|
|
| import numpy as np |
| import torch |
| from omegaconf import DictConfig |
|
|
| import verl.utils.torch_functional as verl_F |
| from verl.trainer.config import AlgoConfig |
| from verl.utils import as_torch_index, group_mean_std |
| from verl.utils.import_utils import deprecated |
| from verl.workers.config import ActorConfig |
|
|
# Call signature shared by every function stored in POLICY_LOSS_REGISTRY:
# four tensors, a string, an optional actor config (DictConfig or ActorConfig)
# and an optional tensor, returning a tensor plus a dict of extra outputs.
# NOTE(review): the meaning of each positional argument (and that the returned
# tensor is the loss itself) is fixed by the registered loss implementations,
# which are not visible here — confirm against them.
PolicyLossFn = Callable[
    [
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        str,
        Optional[DictConfig | ActorConfig],
        torch.Tensor | None,
    ],
    tuple[torch.Tensor, dict[str, Any]],
]


# Loss-mode name -> policy loss function. Filled by `register_policy_loss` and
# queried by `get_policy_loss_fn`.
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
|
|
|
|
def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]:
    """Return a decorator that stores a policy loss function under ``name``.

    The decorated function is written into ``POLICY_LOSS_REGISTRY``, replacing
    any earlier entry with the same name. It is handed back unchanged, so it can
    still be called directly.

    Args:
        name (str): Key under which the loss function is registered.

    Returns:
        Callable: A decorator that records its argument and returns it as-is.
    """

    def _register(loss_fn: PolicyLossFn) -> PolicyLossFn:
        POLICY_LOSS_REGISTRY.update({name: loss_fn})
        return loss_fn

    return _register
|
|
|
|
def get_policy_loss_fn(name):
    """Look up a registered policy loss function by name.

    Args:
        name (str): Name the loss was registered under via ``register_policy_loss``.

    Returns:
        Callable: The policy loss function stored for ``name``.

    Raises:
        ValueError: If no loss function is registered under ``name``; the
            message lists every registered name.
    """
    if name in POLICY_LOSS_REGISTRY:
        return POLICY_LOSS_REGISTRY[name]
    raise ValueError(
        f"Unsupported loss mode: {name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}"
    )
|
|
|
|
class AdvantageEstimator(str, Enum):
    """Using an enumeration class to avoid spelling errors in adv_estimator.

    Members subclass ``str``, so each member compares equal to its string value
    (e.g. ``AdvantageEstimator.GRPO == "grpo"``). Note that ``QAE`` uses the
    uppercase value ``"QAE"``.

    Note(haibin.lin): this enum class is immutable after creation. Extending this
    enum for new estimators may not be necessary, since users can always call
    `verl.trainer.ppo.core_algos.register_adv_est` with a string name to register
    a custom advantage estimator instead.
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
    OPO = "opo"
    GRPO_PASSK = "grpo_passk"
    GPG = "gpg"
    RLOO_VECTORIZED = "rloo_vectorized"
    GRPO_VECTORIZED = "grpo_vectorized"
    QAE = "QAE"
    OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
    TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"
    GDPO = "gdpo"
|
|
|
|
# Estimator name -> advantage estimator function. Keys are plain strings (an
# AdvantageEstimator member is stored under its ``.value``). Filled by
# `register_adv_est` and queried by `get_adv_estimator_fn`.
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
|
|
|
|
def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any:
    """Build a decorator that registers an advantage estimator function.

    Enum members are keyed by their ``.value``; plain strings are used as-is.
    Registering the very same function again under a name is allowed, while
    binding a different function to an already-used name is rejected.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            Registry key, or an enum member whose value is used as the key.

    Returns:
        Callable: Decorator that records the function and returns it unchanged.

    Raises:
        ValueError: Raised by the returned decorator if the key is already
            bound to a different function.
    """

    def _register(fn):
        key = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
        if key in ADV_ESTIMATOR_REGISTRY:
            previous = ADV_ESTIMATOR_REGISTRY[key]
            if previous != fn:
                raise ValueError(f"Adv estimator {key} has already been registered: {previous} vs {fn}")
        ADV_ESTIMATOR_REGISTRY[key] = fn
        return fn

    return _register
|
|
|
|
def get_adv_estimator_fn(name_or_enum):
    """Get the advantage estimator function registered under a given name.

    Args:
        name_or_enum: `(str)` or `(AdvantageEstimator)`
            The name, or an enum member whose ``.value`` is used as the name.

    Returns:
        `(callable)`: The advantage estimator function.

    Raises:
        ValueError: If nothing is registered under the name; the message lists
            all registered estimator names (mirroring ``get_policy_loss_fn``).
    """
    name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum
    if name not in ADV_ESTIMATOR_REGISTRY:
        raise ValueError(
            f"Unknown advantage estimator: {name}. "
            f"Supported estimators are: {list(ADV_ESTIMATOR_REGISTRY.keys())}"
        )
    return ADV_ESTIMATOR_REGISTRY[name]
|
|
|
|
class AdaptiveKLController:
    """KL-penalty coefficient that adapts toward a target KL divergence.

    Follows the adaptive controller described in
    https://arxiv.org/pdf/1909.08593.pdf. On each update the coefficient is
    multiplied by ``1 + e * n_steps / horizon``, where ``e`` is the relative gap
    ``current_kl / target - 1`` clipped to ``[-0.2, 0.2]``.
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        # Current coefficient; rescaled by every call to ``update``.
        self.value = init_kl_coef
        # KL divergence the controller steers toward.
        self.target = target_kl
        # Larger horizons make each multiplicative step smaller.
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        """Rescale the KL coefficient from an observed KL value.

        Args:
            current_kl (float): Current KL divergence value.
            n_steps (int): Number of steps taken.
        """
        relative_gap = np.clip(current_kl / self.target - 1, -0.2, 0.2)
        self.value *= 1 + relative_gap * n_steps / self.horizon
|
|
|
|
class FixedKLController:
    """KL controller whose coefficient never changes.

    Exposes a ``value`` attribute and an ``update(current_kl, n_steps)`` method,
    but ``update`` leaves ``value`` untouched.
    """

    def __init__(self, kl_coef):
        # Stored once and never modified afterwards.
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        """Accept a KL observation and ignore it; the coefficient stays fixed.

        Args:
            current_kl (float): Current KL divergence value (unused).
            n_steps (int): Number of steps taken (unused).
        """
        return None
|
|
|
|
def get_kl_controller(kl_ctrl):
    """Factory function to create the appropriate KL controller from configuration.

    Args:
        kl_ctrl: Configuration object exposing ``type`` and ``kl_coef``, plus
            ``target_kl`` and ``horizon`` when ``type == "adaptive"``.

    Returns:
        KL controller instance (FixedKLController or AdaptiveKLController).

    Raises:
        NotImplementedError: If ``kl_ctrl.type`` is neither ``"fixed"`` nor
            ``"adaptive"``; the message names the rejected type.
        AssertionError: If the adaptive controller's horizon is not positive.
    """
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    elif kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    else:
        raise NotImplementedError(
            f"Unsupported KL controller type: {kl_ctrl.type!r}. Supported types are: 'fixed', 'adaptive'."
        )
|
|
|
|
@register_adv_est(AdvantageEstimator.GAE)
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
):
    """Generalized Advantage Estimation over a batch of responses.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    The recursion runs backwards over the response. Where ``response_mask`` is 0,
    neither the carried next-value nor the GAE accumulator is updated, so those
    positions are skipped and the values from the following valid token carry over.

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape is (bs, response_length)
        values: `(torch.Tensor)`
            shape is (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape is (bs, response_length). [EOS] mask; tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value for Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape (bs, response_length); the raw GAE whitened over masked positions
        returns: `(torch.Tensor)`
            shape (bs, response_length); the raw (un-whitened) GAE plus ``values``
    """
    with torch.no_grad():
        seq_len = token_level_rewards.shape[-1]
        next_value = 0
        running_gae = 0
        per_step = [None] * seq_len

        for t in range(seq_len - 1, -1, -1):
            keep = response_mask[:, t]
            td_error = token_level_rewards[:, t] + gamma * next_value - values[:, t]
            candidate = td_error + gamma * lam * running_gae
            # Blend by mask: valid tokens take the new value, masked tokens keep the carried one.
            next_value = values[:, t] * keep + (1 - keep) * next_value
            running_gae = candidate * keep + (1 - keep) * running_gae
            per_step[t] = running_gae

        raw_advantages = torch.stack(per_step, dim=1)
        returns = raw_advantages + values
        advantages = verl_F.masked_whiten(raw_advantages, response_mask)
    return advantages, returns
|
|
|
|
| |
@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """GRPO advantage for outcome-only rewards (one scalar reward per response).

    Each response's scalar reward is the sum of its token-level rewards. Within
    every group of responses sharing an ``index`` value, the reward is centered
    on the group mean and, optionally, divided by ``group std + epsilon``.
    Groups containing a single response use mean 0 and std 1. The resulting
    scalar is broadcast across the response and multiplied by ``response_mask``.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        index: `(np.ndarray)` group id of each response
        epsilon: `(float)` added to the std to avoid division by zero
        norm_adv_by_std_in_grpo: `(bool)` divide by the group std when True
            (original GRPO); only center when False, as in Dr.GRPO
            (https://arxiv.org/abs/2503.20783)
        config: `(Optional[AlgoConfig])` not read by this function

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), identical to advantages
    """
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        batch_size = scores.shape[0]

        # Row positions belonging to each group, in first-seen order.
        rows_by_group = defaultdict(list)
        for row in range(batch_size):
            rows_by_group[index[row]].append(row)

        # Per-group (mean, std), computed from the untouched scores before any
        # row is overwritten below.
        group_stats = {}
        for gid, rows in rows_by_group.items():
            if not rows:
                raise ValueError(f"no score in prompt index: {gid}")
            if len(rows) == 1:
                group_stats[gid] = (torch.tensor(0.0), torch.tensor(1.0))
            else:
                members = torch.stack([scores[r] for r in rows])
                group_stats[gid] = (torch.mean(members), torch.std(members))

        for row in range(batch_size):
            mean, std = group_stats[index[row]]
            centered = scores[row] - mean
            if norm_adv_by_std_in_grpo:
                centered = centered / (std + epsilon)
            scores[row] = centered
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
|
|
|
@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED)
def compute_grpo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Vectorized GRPO for outcome-only rewards.

    For each group g: a_i = (r_i - mu_g) / (sigma_g + epsilon), or just
    r_i - mu_g when ``norm_adv_by_std_in_grpo`` is False. Group statistics come
    from ``group_mean_std``. The per-response scalar is then broadcast across
    the token dimension and multiplied by ``response_mask``.

    Returns:
        ``(advantages, advantages)``, each of shape (bs, response_length).
    """
    with torch.no_grad():
        outcome = token_level_rewards.sum(dim=-1)
        group_ids = as_torch_index(index, device=outcome.device)
        group_mean, group_std, _ = group_mean_std(outcome, group_ids, eps=epsilon, device=outcome.device)
        centered = outcome - group_mean[group_ids]
        per_sample = centered / (group_std[group_ids] + epsilon) if norm_adv_by_std_in_grpo else centered
        advantages = per_sample.unsqueeze(-1) * response_mask
    return advantages, advantages
|
|
|
|
@register_adv_est(AdvantageEstimator.QAE)
def compute_qae_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute Quantile Advantage Estimation (QAE) for outcome-only rewards.

    This implements Eq. 3 from Wu et al. (2026), using the right-continuous
    empirical K-quantile as the group baseline. For binary rewards, this reduces
    to the hard/easy query gate in Eq. 4. For general scalar outcome rewards, we
    keep the paper's empirical quantile definition.

    For a group of ``n`` responses sorted ascending, the baseline is the element
    at 0-based rank ``max(ceil(K * n) - 1, 0)``. ``K * n`` is rounded to 9
    decimal places before ``ceil`` so binary floating-point error (e.g.
    ``0.07 * 100 == 7.000000000000001``) cannot select a rank one too high.
    A singleton group uses its own score as baseline (advantage 0) and std 1.

    Args:
        token_level_rewards: Token-level rewards of shape ``(bs, response_length)``.
        response_mask: Response mask of shape ``(bs, response_length)``.
        index: Group ids that map responses belonging to the same prompt.
        epsilon: Small value to avoid division by zero.
        config: Algorithm config. Supports:
            - ``qae_quantile`` in ``(0, 1)``, defaults to ``0.4``.
            - ``qae_norm_by_std``, defaults to ``True``.

    Returns:
        A tuple of ``(advantages, returns)`` with the same shape as ``response_mask``.

    Raises:
        ValueError: If ``qae_quantile`` is not strictly between 0 and 1.
    """
    scores = token_level_rewards.sum(dim=-1)
    quantile = config.get("qae_quantile", 0.4) if config is not None else 0.4
    normalize_by_std = config.get("qae_norm_by_std", True) if config is not None else True

    if not 0.0 < quantile < 1.0:
        raise ValueError(f"QAE requires algorithm.qae_quantile to be in (0, 1). Got: {quantile}")

    id2score = defaultdict(list)
    id2baseline = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])

        for idx, grouped_scores in id2score.items():
            if len(grouped_scores) == 0:
                raise ValueError(f"no score in prompt index: {idx}")

            # torch.stack copies, so the baseline is unaffected by the in-place
            # updates to ``scores`` below.
            scores_tensor = torch.stack(grouped_scores)
            if len(grouped_scores) == 1:
                # A lone response is its own baseline, so its advantage is zero.
                id2baseline[idx] = scores_tensor[0]
                id2std[idx] = scores.new_tensor(1.0)
                continue

            sorted_scores = torch.sort(scores_tensor).values
            # Round K * n before ceil() to absorb float drift above an exact integer.
            scaled_rank = round(quantile * len(grouped_scores), 9)
            quantile_rank = max(math.ceil(scaled_rank) - 1, 0)
            id2baseline[idx] = sorted_scores[quantile_rank]
            id2std[idx] = torch.std(scores_tensor)

        for i in range(bsz):
            scores[i] = scores[i] - id2baseline[index[i]]
            if normalize_by_std:
                scores[i] = scores[i] / (id2std[index[i]] + epsilon)

        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
|
|
|
# Also register QAE under the lowercase name "qae". The enum member
# AdvantageEstimator.QAE carries the uppercase value "QAE", so with this alias
# both spellings resolve to the same estimator.
register_adv_est("qae")(compute_qae_outcome_advantage)
|
|
|
|
@register_adv_est(AdvantageEstimator.GDPO)
def compute_gdpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    non_tensor_batch: Optional[dict] = None,
    batch: Optional[dict] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GDPO: Group reward-Decoupled Normalization Policy Optimization.

    Instead of summing all reward dimensions first (like GRPO), GDPO normalizes
    each reward dimension independently within each group before aggregation.
    This prevents a dominant reward signal from drowning out weaker ones.

    Mathematical formulation:
        Step 1 – Group-wise decoupled normalization (via GRPO per dimension):
            For each reward dimension k, within each group g:
                A_k = (r_k - μ_group(r_k)) / (σ_group(r_k) + ε)

        Step 2 – Weighted aggregation:
            A_sum = Σ_k w_k · A_k

        Step 3 – Batch-level normalization (via masked_whiten):
            A_final = whiten(A_sum, response_mask)

    Args:
        token_level_rewards: (bs, response_length) – standard token-level rewards.
            Used as fallback when per-dimension rewards are not provided.
        response_mask: (bs, response_length)
        index: (bs,) – group id per sample (from ``uid``).
        epsilon: Numerical stability constant.
        norm_adv_by_std_in_grpo: Whether to normalize by std in GRPO.
        config: Algorithm configuration (optional). Reads ``gdpo_reward_keys``
            and, optionally, ``gdpo_reward_weights`` (one weight per key).
        non_tensor_batch: Non-tensor batch data containing per-dimension reward scores.
        batch: Batch data containing prompts, attention_mask, etc.

    Note:
        Ref GDPO (https://arxiv.org/abs/2601.05242).
        Per-dimension rewards are used only when ``config``, ``non_tensor_batch``
        and ``batch`` are all given; otherwise ``token_level_rewards`` is the
        single dimension and all weights are 1.

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length) – same as advantages (outcome-only).

    Raises:
        AssertionError: If ``gdpo_reward_keys`` is empty/missing, or a key is
            absent from ``non_tensor_batch``.
        ValueError: If the number of ``gdpo_reward_weights`` differs from the
            number of ``gdpo_reward_keys``.
    """
    score_list = None
    reward_weights = None

    if config is not None and non_tensor_batch is not None and batch is not None:
        gdpo_reward_keys = config.get("gdpo_reward_keys", None)
        assert gdpo_reward_keys, (
            "GDPO requires 'algorithm.gdpo_reward_keys' listing the individual reward "
            "component keys returned by compute_score (e.g. ['format_reward', 'accuracy_reward'])."
        )
        device = token_level_rewards.device
        prompt_length = batch["prompts"].size(1)
        # Position of the last valid response token per sample; each component's
        # scalar reward is written there, every other position stays zero.
        valid_response_length = batch["attention_mask"][:, prompt_length:].sum(dim=1) - 1
        sample_rows = torch.arange(response_mask.size(0), device=device)

        score_list = []
        for key in gdpo_reward_keys:
            assert key in non_tensor_batch, (
                f"GDPO reward key '{key}' not found in non_tensor_batch. "
                f"Available keys: {list(non_tensor_batch.keys())}. "
                f"Make sure your compute_score returns a dict containing '{key}'."
            )
            rm_score = torch.tensor(np.asarray(non_tensor_batch[key], dtype=np.float32), device=device)
            rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
            rm_scores[sample_rows, valid_response_length] = rm_score
            score_list.append(rm_scores)

        gdpo_weights = config.get("gdpo_reward_weights", None)
        if gdpo_weights is not None:
            reward_weights = list(gdpo_weights)

    if score_list is None:
        # No per-dimension rewards available: treat the summed reward as one dimension.
        score_list = [token_level_rewards]

    num_scores = len(score_list)

    if reward_weights is not None:
        # Without this check, extra weights were silently ignored and missing
        # weights surfaced as an opaque IndexError.
        if len(reward_weights) != num_scores:
            raise ValueError(
                f"GDPO got {len(reward_weights)} reward weights for {num_scores} reward keys; "
                "'algorithm.gdpo_reward_weights' must contain exactly one weight per entry "
                "of 'algorithm.gdpo_reward_keys'."
            )
        weights = torch.tensor(reward_weights, dtype=torch.float32, device=token_level_rewards.device)
    else:
        weights = torch.ones(num_scores, dtype=torch.float32, device=token_level_rewards.device)

    # Steps 1 + 2: GRPO-normalize each dimension within its group, then take the weighted sum.
    new_advantage = None
    for weight, dimension_rewards in zip(weights, score_list):
        normalized_score, _ = compute_grpo_outcome_advantage(
            token_level_rewards=dimension_rewards,
            response_mask=response_mask,
            index=index,
            epsilon=epsilon,
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
            config=config,
        )
        weighted = weight * normalized_score
        new_advantage = weighted if new_advantage is None else new_advantage + weighted

    # Step 3: batch-level whitening over valid tokens; re-mask so padding stays zero.
    advantages = verl_F.masked_whiten(new_advantage, response_mask) * response_mask

    return advantages, advantages
|
|
|
|
@register_adv_est(AdvantageEstimator.GRPO_PASSK)
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pass@k advantage using a GRPO-style outcome reward formulation.

    Implemented as described in https://arxiv.org/abs/2503.19595. In each group,
    only the top-ranked response (per ``torch.topk``) is assigned
    ``r_max - r_second_max``, optionally divided by ``group std + epsilon``.
    Every other response gets 0.

    Note:
        The std-normalization switch is read from
        ``config["norm_adv_by_std_in_grpo"]`` (default True); the
        ``norm_adv_by_std_in_grpo`` argument is overwritten by that value.

    Args:
        token_level_rewards: (bs, response_length)
        response_mask: (bs, response_length)
        index: (bs,) group id per sample
        epsilon: float for numerical stability
        config: (AlgoConfig) algorithm settings; must not be None

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length), identical to advantages

    Raises:
        ValueError: If any group holds fewer than 2 samples.
    """
    assert config is not None

    norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True)
    scores = token_level_rewards.sum(dim=-1)
    advantages = torch.zeros_like(scores)

    rows_by_group = defaultdict(list)

    with torch.no_grad():
        for row in range(scores.shape[0]):
            rows_by_group[index[row]].append(row)

        for gid, rows in rows_by_group.items():
            group_rewards = torch.stack([scores[r] for r in rows])
            if group_rewards.numel() < 2:
                raise ValueError(
                    f"Pass@k requires at least 2 samples per group. Got {group_rewards.numel()} for group {gid}."
                )
            best_two, best_pos = torch.topk(group_rewards, 2)
            gap = best_two[0] - best_two[1]
            if norm_adv_by_std_in_grpo:
                gap = gap / (torch.std(group_rewards) + epsilon)
            advantages[rows[best_pos[0].item()]] = gap

    advantages = advantages.unsqueeze(-1) * response_mask
    return advantages, advantages
|
|
|
|
@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE)
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: torch.Tensor,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RF++-baseline advantage (https://arxiv.org/abs/2501.03262) for outcome-only rewards.

    Each response's summed reward has its group mean subtracted (singleton
    groups subtract 0). The result is tiled across the response, masked, and
    whitened over valid tokens with ``verl_F.masked_whiten``, then masked again.
    ``epsilon`` and ``config`` are not read.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        index: group id per sample, used as a dict key
        config: (AlgoConfig) algorithm config

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), identical to advantages
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        batch_size = scores.shape[0]
        rows_by_group = defaultdict(list)
        for row in range(batch_size):
            rows_by_group[index[row]].append(row)

        group_baseline = {}
        for gid, rows in rows_by_group.items():
            if len(rows) > 1:
                group_baseline[gid] = torch.mean(torch.stack([scores[r] for r in rows]))
            elif len(rows) == 1:
                group_baseline[gid] = torch.tensor(0.0)
            else:
                raise ValueError(f"no score in prompt index: {gid}")

        for row in range(batch_size):
            scores[row] = scores[row] - group_baseline[index[row]]

        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask) * response_mask

    return scores, scores
|
|
|
|
@register_adv_est(AdvantageEstimator.RLOO)
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RLOO advantage based on https://arxiv.org/abs/2402.14740.

    For a group of ``n > 1`` responses with mean ``m``, response ``i`` gets
    ``s_i * n / (n - 1) - m * n / (n - 1)``, which equals ``s_i`` minus the mean
    of the other ``n - 1`` scores. A response alone in its group keeps its raw
    summed reward. The scalar is broadcast across the response and masked.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        index: group id per sample
        config: (AlgoConfig) algorithm config (not read)

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), identical to advantages
    """
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        batch_size = scores.shape[0]
        rows_by_group = defaultdict(list)
        for row in range(batch_size):
            rows_by_group[index[row]].append(row)

        # Group means from the untouched scores (only needed for groups of 2+).
        group_mean = {}
        for gid, rows in rows_by_group.items():
            if not rows:
                raise ValueError(f"no score in prompt index: {gid}")
            if len(rows) > 1:
                group_mean[gid] = torch.mean(torch.stack([scores[r] for r in rows]))

        for row in range(batch_size):
            gid = index[row]
            n = len(rows_by_group[gid])
            if n > 1:
                scores[row] = scores[row] * n / (n - 1) - group_mean[gid] * n / (n - 1)
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
|
|
|
@register_adv_est(AdvantageEstimator.OPO)
def compute_opo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """OPO advantage based on https://arxiv.org/pdf/2505.23585.

    The baseline of a group with 2+ responses is the length-weighted mean of its
    summed rewards, ``sum(len_i * s_i) / sum(len_i)``, where ``len_i`` is the
    number of unmasked tokens. Singleton groups use a baseline of 0. The centered
    scalar is broadcast across the response and masked.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        index: group id per sample
        config: (AlgoConfig) algorithm config (not read)

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), identical to advantages
    """
    response_length = response_mask.sum(dim=-1)
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        batch_size = scores.shape[0]
        rows_by_group = defaultdict(list)
        for row in range(batch_size):
            rows_by_group[index[row]].append(row)

        baseline = {}
        for gid, rows in rows_by_group.items():
            if len(rows) > 1:
                group_scores = torch.stack([scores[r] for r in rows])
                group_lengths = torch.stack([response_length[r] for r in rows])
                baseline[gid] = (group_lengths * group_scores).sum() / group_lengths.sum()
            elif len(rows) == 1:
                baseline[gid] = torch.tensor(0.0)
            else:
                raise ValueError(f"no score in prompt index: {gid}")

        for row in range(batch_size):
            scores[row] = scores[row] - baseline[index[row]]
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
|
|
|
@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS)
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """REINFORCE++ advantage (https://arxiv.org/abs/2501.03262).

    Discounted returns are accumulated backwards with ``config.gamma``. After a
    position's return is stored, the carried return is multiplied by that
    position's mask, so a masked position passes nothing on to earlier tokens.
    Advantages are the returns whitened over valid tokens, then masked.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        config: (AlgoConfig) algorithm config; must provide ``gamma``

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), un-whitened
    """
    assert config is not None
    gamma = config.gamma
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        carried = 0

        for t in range(token_level_rewards.shape[1] - 1, -1, -1):
            carried = token_level_rewards[:, t] + gamma * carried
            returns[:, t] = carried
            # Reset the carry at masked positions (e.g. after EOS).
            carried = carried * response_mask[:, t]

        advantages = verl_F.masked_whiten(returns, response_mask) * response_mask

    return advantages, returns
|
|
|
|
@register_adv_est(AdvantageEstimator.REMAX)
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor,
    reward_baselines: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """ReMax advantage (https://arxiv.org/abs/2310.10505) for outcome-only rewards.

    ``returns[:, t]`` is the sum of masked rewards from position ``t`` to the end
    of the response. The advantage subtracts each sample's baseline from every
    valid position.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        reward_baselines: `(torch.Tensor)` shape (bs,)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        config: (AlgoConfig) algorithm config (not read)

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length)
    """
    with torch.no_grad():
        masked_rewards = token_level_rewards * response_mask
        # Reverse cumulative sum: flip, cumsum, flip back.
        returns = masked_rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
        per_token_baseline = reward_baselines.unsqueeze(-1) * response_mask
        advantages = returns - per_token_baseline

    return advantages, returns
|
|
|
|
@register_adv_est(AdvantageEstimator.GPG)
def compute_gpg_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    f_norm: float = 1.0,
    alpha: float = 1.0,
    config=None,
    **kwargs,
):
    """GPG advantage for outcome-only rewards (one scalar reward per response).

    Each summed reward is centered on its group mean (0 for singleton groups),
    scaled by ``alpha`` and divided by ``f_norm``. ``alpha`` is recomputed as
    ``batch_size / max(#non-zero scores, 1)``, so the ``alpha`` argument is
    overwritten. ``epsilon`` and ``config`` are not read.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        index: `(np.ndarray)` shape (bs,)
        epsilon: (float)
        f_norm: (float) divisor applied to every advantage
        alpha: (float) overwritten as described above
        config: (dict) algorithm config

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), identical to advantages
    """
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        batch_size = scores.shape[0]
        alpha = batch_size / torch.count_nonzero(scores).clamp(min=1)

        rows_by_group = defaultdict(list)
        for row in range(batch_size):
            rows_by_group[index[row]].append(row)

        group_mean = {}
        for gid, rows in rows_by_group.items():
            if len(rows) > 1:
                group_mean[gid] = torch.mean(torch.stack([scores[r] for r in rows]))
            elif len(rows) == 1:
                group_mean[gid] = torch.tensor(0.0)
            else:
                raise ValueError(f"no score in prompt index: {gid}")

        for row in range(batch_size):
            scores[row] = alpha * (scores[row] - group_mean[index[row]]) / (f_norm)
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores
|
|
|
|
@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED)
def compute_rloo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    config: Optional[AlgoConfig] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Vectorized RLOO advantage based on https://arxiv.org/abs/2402.14740.

    With ``c`` the size of a response's group and ``S`` the group's score sum,
    the advantage is ``(c * s - S) / (c - 1)``, computed with ``bincount`` for
    all groups at once. Singleton groups receive an advantage of 0.

    NOTE(review): the loop-based RLOO estimator in this module instead leaves a
    singleton's raw score unchanged — confirm which behavior is intended.

    Args:
        token_level_rewards: `(torch.Tensor)` shape (bs, response_length)
        response_mask: `(torch.Tensor)` shape (bs, response_length)
        index: group id per sample
        config: (AlgoConfig) algorithm config (not read)

    Returns:
        advantages: `(torch.Tensor)` shape (bs, response_length)
        returns: `(torch.Tensor)` shape (bs, response_length), identical to advantages
    """
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        _, group_of = np.unique(index, return_inverse=True)
        group_of = torch.from_numpy(group_of).to(scores.device)

        group_size = torch.bincount(group_of)[group_of].to(scores.dtype)
        group_total = torch.bincount(group_of, weights=scores)[group_of]
        # clamp_min(1) avoids 0-division for singletons, which the mask then zeroes.
        leave_one_out = (group_size * scores - group_total) / (group_size - 1).clamp_min(1)
        adv = leave_one_out * (group_size > 1)

        adv = adv.unsqueeze(-1) * response_mask

    return adv, adv
|
|
|
|
| @register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE) |
def compute_optimal_token_baseline_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    old_log_probs: torch.Tensor,
    sum_pi_squared: torch.Tensor,
    rollout_is_weights: torch.Tensor = None,
    handle_zero_tail: bool = True,
    epsilon: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute per-token advantages with the Optimal Token Baseline (OTB).

    A group-mean baseline assigns one scalar per trajectory. OTB instead assigns a separate
    baseline to every timestep, weighting each trajectory by its cumulative path variance.

    Theory:
        For each timestep t within a prompt group:
            B_t* = E[G_t × W_t] / E[W_t]
        with W_t = Σ_{j=1}^t ||s_j||²  (cumulative path-variance proxy)
        and  ||s_j||² = 1 - 2π_j + Σπ².

        W_t measures how much "energy" a trajectory has accumulated up to step t, so
        high-variance paths dominate the baseline estimate at that step.

    Args:
        token_level_rewards: Per-token rewards [shape: (bs, response_length)].
        response_mask: 1 for valid tokens, 0 for padding [shape: (bs, response_length)].
        index: Prompt id of each row; rows sharing an id form one group [shape: (bs,)].
        old_log_probs: Log-probabilities of the sampled tokens under the training policy at
            generation time [shape: (bs, response_length)].
        sum_pi_squared: Σπ² over the vocabulary at each position [shape: (bs, response_length)].
        rollout_is_weights: Optional rollout IS weights [shape: (bs, response_length)], or None.
            When given, each per-step term ||s_j||² is multiplied by ρ_j² before the cumulative
            sum, i.e. W_t = Σ_{j≤t} ρ_j² ||s_j||².
        handle_zero_tail: When True and a group has exactly one strictly-longest trajectory, that
            trajectory's baseline is zeroed from the second-longest length onward. Default: True.
        epsilon: Added to the denominator of B_t* to avoid division by zero. Default: 1e-8.

    Returns:
        advantages: (returns - baseline) * response_mask [shape: (bs, response_length)].
        returns: Reward-to-go from each position [shape: (bs, response_length)].

    Note:
        Groups containing a single trajectory keep a zero baseline, so their advantages equal
        their returns on valid tokens.
    """
    with torch.no_grad():
        n_rows, _ = token_level_rewards.shape
        dev = token_level_rewards.device

        # Reward-to-go G_t: reverse cumulative sum of the masked rewards.
        masked_rewards = token_level_rewards * response_mask
        returns = masked_rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        # Per-step variance proxy ||s_t||^2 = 1 - 2*pi_t + sum(pi^2), optionally IS-weighted.
        step_energy = 1 - 2 * torch.exp(old_log_probs) + sum_pi_squared
        if rollout_is_weights is not None:
            step_energy = step_energy * (rollout_is_weights**2)
        # W_t: running total over valid tokens only.
        path_energy = (step_energy * response_mask).cumsum(dim=-1)

        groups = defaultdict(list)
        for row in range(n_rows):
            groups[index[row]].append(row)

        baselines = torch.zeros_like(returns)
        for members in groups.values():
            if len(members) < 2:
                # A lone trajectory has nothing to compare against; its baseline stays zero.
                continue

            rows = torch.tensor(members, device=dev)
            grp_returns = returns[rows]
            grp_energy = path_energy[rows]
            grp_mask = response_mask[rows]

            weighted_sum = (grp_returns * grp_energy * grp_mask).sum(dim=0)
            weight_total = (grp_energy * grp_mask).sum(dim=0) + epsilon
            # The per-step baseline (shape: response_length) broadcasts over every group row.
            baselines[rows] = weighted_sum / weight_total

            if not handle_zero_tail:
                continue

            # Beyond the second-longest length only one trajectory is unmasked, so B_t* would be
            # ~its own return and its advantage would vanish; use a zero baseline there instead.
            lengths = grp_mask.sum(dim=-1)
            ranked = torch.sort(lengths).values
            longest = int(ranked[-1].item())
            runner_up = int(ranked[-2].item())
            holders = (lengths == longest).nonzero(as_tuple=True)[0]
            if holders.numel() == 1 and longest > runner_up:
                baselines[members[int(holders[0])], runner_up:] = 0.0

        advantages = (returns - baselines) * response_mask
    return advantages, returns
|
|
|
|
@register_adv_est(AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE)
def compute_multi_turn_optimal_token_baseline_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    old_log_probs: torch.Tensor,
    sum_pi_squared: torch.Tensor,
    rollout_is_weights: torch.Tensor = None,
    handle_zero_tail: bool = True,
    epsilon: float = 1e-8,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute advantages using Optimal Token Baseline (OTB) for masks that may contain holes.

    This is the variant of the OTB estimator for responses whose ``response_mask`` can have
    interior zeros (presumably non-policy tokens in multi-turn / tool-integrated rollouts —
    confirm with the caller). Trajectories are not aligned by absolute position. Instead, each
    row's valid tokens are first compacted to the left, so the baseline at compact step k is
    shared by the k-th valid token of every trajectory in the prompt group. The resulting
    advantages are then scattered back to the original valid positions.

    Theory:
        For each compact step t in each prompt group:
            B_t* = E[G_t × W_t] / E[W_t]
        where W_t = Σ_{j=1}^t ||s_j||² (cumulative path-variance proxy, over valid tokens only)
        and ||s_j||² = 1 - 2π_j + Σπ²

        The cumulative sum W_t captures the "realized energy" a trajectory has accumulated up to
        step t, giving higher weight to predicting rewards on high-variance paths.

    Args:
        token_level_rewards: Rewards at each token position [shape: (bs, response_length)].
            Rewards at mask-0 positions are ignored.
        response_mask: 1 for tokens that receive advantages, 0 otherwise; zeros may appear in
            the interior of a row [shape: (bs, response_length)].
        index: Prompt indices for grouping trajectories from the same prompt [shape: (bs,)].
        old_log_probs: Log probabilities from the training policy during generation
            [shape: (bs, response_length)].
        sum_pi_squared: Sum of squared probabilities over the vocabulary Σπ²
            [shape: (bs, response_length)].
        rollout_is_weights: Pre-computed IS weights for the W correction
            [shape: (bs, response_length)], or None if IS is not used.
        handle_zero_tail: If True, zero baselines are used for the portion of the (uniquely)
            longest trajectory that extends beyond the second-longest trajectory of its prompt
            group, measured in valid-token counts. Default: True.
        epsilon: Small constant added to the denominator for numerical stability (default: 1e-8).

    Returns:
        advantages: OTB advantage estimates, zero at mask-0 positions
            [shape: (bs, response_length)].
        returns: Cumulative rewards (reward-to-go) from each position
            [shape: (bs, response_length)].

    Notes:
        - Rows with no valid tokens are excluded from grouping and get all-zero advantages.
        - Groups with a single non-empty trajectory keep a zero baseline.
        - Rollout importance sampling: when ``rollout_is_weights`` is provided, each per-step
          term ||s_j||² is multiplied by ρ_j² before the cumulative sum, so
          W_t = Σ_{j≤t} ρ_j² ||s_j||² and B_t* = Σ[G_t × W_t] / Σ[W_t].
    """
    with torch.no_grad():
        # Reward-to-go over masked rewards (mask-0 tokens contribute no reward).
        token_returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        # Per-step variance proxy ||s_t||^2 = 1 - 2*pi_t + sum(pi^2).
        pi_t = torch.exp(old_log_probs)
        w_per_timestep = 1 - 2 * pi_t + sum_pi_squared

        # Optional rollout-IS correction: scale each step's term by rho_t^2 before accumulating.
        if rollout_is_weights is not None:
            w_per_timestep = w_per_timestep * (rollout_is_weights**2)

        # W_t: running total over valid tokens only (mask-0 steps add nothing).
        w_cumulative = (w_per_timestep * response_mask).cumsum(dim=-1)

        # Compact each row's valid tokens to the left: column k holds the k-th valid token.
        response_lengths = response_mask.sum(dim=-1).to(dtype=torch.long)
        max_response_length = int(response_lengths.max().item()) if response_lengths.numel() > 0 else 0
        all_w_values = w_cumulative.new_zeros(
            (len(response_lengths), max_response_length)
        )
        all_returns = torch.zeros_like(all_w_values)
        for i in range(len(response_lengths)):
            length = int(response_lengths[i].item())
            if length == 0:
                continue
            mask = response_mask[i].bool()
            all_w_values[i, :length] = w_cumulative[i, mask]
            all_returns[i, :length] = token_returns[i, mask]

        # Group non-empty rows by prompt; empty rows keep zero advantages.
        prompt_groups = defaultdict(list)
        for i in range(len(response_lengths)):
            if response_lengths[i] == 0:
                continue
            prompt_groups[index[i]].append(i)

        baselines = torch.zeros_like(all_returns)

        for _, trajectory_indices in prompt_groups.items():
            N = len(trajectory_indices)
            traj_idx = torch.tensor(trajectory_indices, device=all_returns.device)

            if N == 1:
                # Singleton group: nothing to compare against, the baseline stays zero.
                baselines[traj_idx[0]] = 0.0
                continue

            # Compact padding carries W = 0, so shorter rows drop out of both sums naturally.
            w_group = all_w_values[traj_idx]
            R_group = all_returns[traj_idx]

            b_star = (R_group * w_group).sum(dim=0) / (w_group.sum(dim=0) + epsilon)

            baselines[traj_idx] = b_star.to(baselines.dtype)

            if handle_zero_tail:
                # Beyond the second-longest length only one trajectory contributes, so B_t* would
                # be ~its own return and its advantage would vanish; use a zero baseline there.
                response_lengths_group = response_lengths[traj_idx]
                sorted_lengths, _ = torch.sort(response_lengths_group)
                max_length = int(sorted_lengths[-1].item())
                second_max_length = int(sorted_lengths[-2].item())
                max_length_idx = (response_lengths_group == max_length).nonzero(as_tuple=True)[0]
                if max_length_idx.numel() == 1 and max_length > second_max_length:
                    max_length_traj_idx = trajectory_indices[int(max_length_idx[0])]
                    baselines[max_length_traj_idx, second_max_length:] = 0.0

        all_advantages = all_returns - baselines

        # Scatter compact advantages back to each row's original valid positions.
        advantages = torch.zeros_like(token_returns)
        for i in range(len(response_lengths)):
            if response_lengths[i] == 0:
                continue
            advantages[i, response_mask[i].bool()] = all_advantages[i, : response_lengths[i]]

        advantages = advantages * response_mask

    return advantages, token_returns
|
|
|
|
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    """Subtract a KL penalty from token-level scores.

    The per-token KL estimate is ``old_log_prob - ref_log_prob``; it is scaled by
    ``kl_ratio`` and subtracted from ``token_level_scores``.

    Args:
        token_level_scores (torch.Tensor): Raw token-level reward scores.
        old_log_prob (torch.Tensor): Log-probabilities under the current policy.
        ref_log_prob (torch.Tensor): Log-probabilities under the reference policy.
        kl_ratio (float): Coefficient applied to the KL estimate.

    Returns:
        torch.Tensor: ``token_level_scores - kl_ratio * (old_log_prob - ref_log_prob)``.
    """
    penalty = kl_ratio * (old_log_prob - ref_log_prob)
    return token_level_scores - penalty
|
|
|
|
def agg_loss(
    loss_mat: torch.Tensor,
    loss_mask: torch.Tensor,
    loss_agg_mode: str,
    dp_size: int = 1,
    batch_num_tokens: Optional[int] = None,
    global_batch_size: Optional[int] = None,
    loss_scale_factor: Optional[int] = None,
):
    """
    Reduce a per-token loss matrix to a scalar that is invariant to fsdp/megatron parallelism.

    NOTE: how the returned loss is consumed differs per backend:
        - FSDP: the loss is directly used for backward.
        - Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.

    Args:
        loss_mat: Micro-batch per-token loss, (bs, response_length).
        loss_mask: Micro-batch loss mask, (bs, response_length).
        loss_agg_mode: One of "token-mean", "seq-mean-token-sum", "seq-mean-token-sum-norm",
            "seq-mean-token-mean".
        dp_size: Data-parallel size; the locally normalized loss is multiplied by it.
        batch_num_tokens: Number of valid tokens in the global batch ("token-mean" only).
            Required when dp_size > 1; otherwise defaults to ``loss_mask.sum()``.
        global_batch_size: Number of sequences in the global batch (seq-* modes). Required when
            dp_size > 1; otherwise defaults to the number of rows with at least one valid token.
        loss_scale_factor: Divisor for "seq-mean-token-sum-norm". If None, uses
            ``loss_mask.shape[-1]``. Set a constant to keep normalization consistent across training.

    Returns:
        loss: `a scalar torch.Tensor`, the aggregated loss.

    Raises:
        ValueError: on an unknown ``loss_agg_mode``, or when a required global count is missing
            while ``dp_size > 1``.
    """
    if loss_agg_mode == "token-mean":
        if batch_num_tokens is None:
            if dp_size > 1:
                raise ValueError("(global) batch_num_tokens is required when dp_size > 1")
            batch_num_tokens = loss_mask.sum()
        return verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size

    sequence_modes = ("seq-mean-token-sum", "seq-mean-token-sum-norm", "seq-mean-token-mean")
    if loss_agg_mode not in sequence_modes:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

    tokens_per_seq = torch.sum(loss_mask, dim=-1)
    per_seq_loss = torch.sum(loss_mat * loss_mask, dim=-1)
    if loss_agg_mode == "seq-mean-token-mean":
        # Average over each sequence's own valid tokens; the 1e-8 guards fully-masked rows.
        per_seq_loss = per_seq_loss / (tokens_per_seq + 1e-8)
    # Only sequences with at least one valid token count toward the mean.
    seq_valid = (tokens_per_seq > 0).float()

    if global_batch_size is None:
        if dp_size > 1:
            raise ValueError("global_batch_size is required when dp_size > 1")
        global_batch_size = seq_valid.sum()
    loss = verl_F.masked_sum(per_seq_loss, seq_valid) / global_batch_size * dp_size

    if loss_agg_mode == "seq-mean-token-sum-norm":
        if loss_scale_factor is None:
            loss_scale_factor = loss_mask.shape[-1]
        loss = loss / loss_scale_factor
    return loss
|
|
|
|
@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla")
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode: str = "token-mean",
):
    """
    Deprecated: compute the dual-clip PPO policy objective and its metrics.

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        cliprange (float, optional):
            Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347.
            Defaults to None (must be provided unless both one-sided ranges are given).
        cliprange_low (float, optional):
            Lower clip range. Defaults to `cliprange`.
        cliprange_high (float, optional):
            Upper clip range. Defaults to `cliprange`.
        clip_ratio_c (float, optional):
            Dual-clip bound applied to tokens with negative advantage; must be > 1.0.
            See https://arxiv.org/pdf/1912.09729. Defaults to 3.0.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        tuple: (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower), all torch scalars.
    """
    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )

    low = cliprange if cliprange_low is None else cliprange_low
    high = cliprange if cliprange_high is None else cliprange_high

    # log(pi_new / pi_old), clamped so that exp() cannot overflow.
    log_ratio = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    ratio = torch.exp(log_ratio)
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    neg_adv = -advantages
    unclipped = neg_adv * ratio
    clipped = neg_adv * torch.clamp(ratio, 1 - low, 1 + high)
    ppo_objective = torch.maximum(unclipped, clipped)
    pg_clipfrac = verl_F.masked_mean(torch.gt(clipped, unclipped).float(), response_mask)

    # Dual clip: for A < 0 the per-token loss is additionally capped at -A * clip_ratio_c.
    dual_bound = neg_adv * clip_ratio_c
    is_negative = advantages < 0
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(ppo_objective, dual_bound) * is_negative.float(), response_mask
    )
    pg_losses = torch.where(is_negative, torch.min(dual_bound, ppo_objective), ppo_objective)

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
|
|
|
|
@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped (dual-clip) PPO policy objective and related metrics.

    Adapted from
    https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config (verl.workers.config.ActorConfig):
            Actor config (required; an AlgoConfig is rejected). Reads ``clip_ratio``,
            ``clip_ratio_low`` / ``clip_ratio_high`` (each falls back to ``clip_ratio`` when None),
            ``clip_ratio_c`` (dual-clip bound, default 3.0, must be > 1.0) and
            ``global_batch_info`` (forwarded to `agg_loss`).
        rollout_is_weights (torch.Tensor, optional):
            Per-token importance-sampling weights, shape (batch_size, response_length). When
            given, they multiply the per-token policy loss before aggregation.

    Returns:
        pg_loss: Scalar loss produced by `agg_loss`.
        pg_metrics: Dict with "actor/pg_clipfrac", "actor/ppo_kl" and "actor/pg_clipfrac_lower".
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)
    clip_ratio = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    clip_ratio_c = config.get("clip_ratio_c", 3.0)

    assert clip_ratio_c > 1.0, (
        "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
        + f" but get the value: {clip_ratio_c}."
    )

    negative_approx_kl = log_prob - old_log_prob
    # Clamp the log-ratio so exp() cannot overflow.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Standard PPO clipping: pessimistic max of the unclipped and clipped surrogate losses.
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual clip: for A < 0 the loss is additionally capped at -A * clip_ratio_c.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )

    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
    )

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("dppo_tv")
def compute_policy_loss_dppo_tv(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped policy objective and related metrics for DPPO-Binary-TV.

    See https://arxiv.org/pdf/2602.04879 for more details.

    Instead of clipping the importance ratio, a token is removed from the gradient when its
    probability has moved too far from the old policy in the direction the advantage pushes it.
    For A > 0 a token is kept while ``prob - old_prob <= clip_ratio_high``. Otherwise it is kept
    while ``prob - old_prob >= -clip_ratio_low``. Each kept token contributes
    ``-A * sg[min(ratio, clip_ratio_c)] * log_prob``, where sg denotes stop-gradient.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config (verl.workers.config.ActorConfig):
            Actor config (required; an AlgoConfig is rejected). ``clip_ratio`` /
            ``clip_ratio_low`` / ``clip_ratio_high`` are used as probability-difference
            thresholds (low/high fall back to ``clip_ratio`` when None). ``clip_ratio_c``
            (default 20.0) caps the detached ratio. ``global_batch_info`` is forwarded to `agg_loss`.
        rollout_is_weights (torch.Tensor, optional):
            Per-token importance-sampling weights, shape (batch_size, response_length). When
            given, they multiply the per-token policy loss before aggregation.

    Returns:
        pg_loss: Scalar loss produced by `agg_loss`.
        pg_metrics: "actor/pg_clipfrac" is the fraction of valid tokens masked out by the
            divergence test. "actor/pg_clipfrac_lower" is the fraction kept but with ratio
            above ``clip_ratio_c``. "actor/ppo_kl" is the masked mean of the clamped
            ``old_log_prob - log_prob``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)

    clip_divergence = config.clip_ratio
    clip_divergence_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_divergence
    clip_divergence_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_divergence

    negative_approx_kl = log_prob - old_log_prob
    # Clamp the log-ratio so exp() cannot overflow.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Truncated ratio used as a constant per-token coefficient (no gradient flows through it).
    clip_ratio_c = config.get("clip_ratio_c", 20.0)
    truncated_ratio = torch.clamp(ratio, max=clip_ratio_c)
    truncated_ratio = truncated_ratio.detach()

    # Binary total-variation test on the sampled token's probability.
    prob = torch.exp(log_prob)
    old_prob = torch.exp(old_log_prob)
    valid_positive_mask = (prob - old_prob) <= clip_divergence_high
    valid_negative_mask = (prob - old_prob) >= -clip_divergence_low
    valid_mask = torch.where(advantages > 0, valid_positive_mask, valid_negative_mask)
    valid_mask = valid_mask.detach().float()

    pg_losses = -advantages * truncated_ratio * log_prob * valid_mask

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
    )

    pg_clipfrac = verl_F.masked_mean((1.0 - valid_mask).float(), response_mask)
    pg_clipfrac_lower = verl_F.masked_mean((ratio > clip_ratio_c).float() * valid_mask, response_mask)

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("dppo_kl")
def compute_policy_loss_dppo_kl(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped policy objective and related metrics for DPPO-Binary-KL.

    See https://arxiv.org/pdf/2602.04879 for more details.

    The binary KL between old and new token probabilities is computed as
    ``KL = p_old * log(p_old / p) + (1 - p_old) * log((1 - p_old) / (1 - p))``
    (with 1e-8 added inside the second log). For A > 0 a token is kept while
    ``KL <= clip_ratio_high`` or while its probability has not increased. Otherwise it is kept
    while ``KL <= clip_ratio_low`` or while its probability has not decreased. Each kept token
    contributes ``-A * sg[min(ratio, clip_ratio_c)] * log_prob``, where sg denotes stop-gradient.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config (verl.workers.config.ActorConfig):
            Actor config (required; an AlgoConfig is rejected). ``clip_ratio`` /
            ``clip_ratio_low`` / ``clip_ratio_high`` are used as binary-KL thresholds (low/high
            fall back to ``clip_ratio`` when None). ``clip_ratio_c`` (default 20.0) caps the
            detached ratio. ``global_batch_info`` is forwarded to `agg_loss`.
        rollout_is_weights (torch.Tensor, optional):
            Per-token importance-sampling weights, shape (batch_size, response_length). When
            given, they multiply the per-token policy loss before aggregation.

    Returns:
        pg_loss: Scalar loss produced by `agg_loss`.
        pg_metrics: "actor/pg_clipfrac" is the fraction of valid tokens masked out by the
            divergence test. "actor/pg_clipfrac_lower" is the fraction kept but with ratio
            above ``clip_ratio_c``. "actor/ppo_kl" is the masked mean of the clamped
            ``old_log_prob - log_prob``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig)

    clip_divergence = config.clip_ratio
    clip_divergence_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_divergence
    clip_divergence_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_divergence

    negative_approx_kl = log_prob - old_log_prob
    # Clamp the log-ratio so exp() cannot overflow.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Truncated ratio used as a constant per-token coefficient (no gradient flows through it).
    clip_ratio_c = config.get("clip_ratio_c", 20.0)
    truncated_ratio = torch.clamp(ratio, max=clip_ratio_c)
    truncated_ratio = truncated_ratio.detach()

    # Binary KL(old || new) on the sampled token's Bernoulli distribution.
    prob = torch.exp(log_prob)
    old_prob = torch.exp(old_log_prob)
    binary_kl = old_prob * (old_log_prob - log_prob) + (1 - old_prob) * torch.log(
        (1.0 - old_prob + 1e-8) / (1.0 - prob + 1e-8)
    )
    # A token is only masked when it diverged too far in the direction its advantage pushes it.
    valid_positive_mask = (binary_kl <= clip_divergence_high) | (prob <= old_prob)
    valid_negative_mask = (binary_kl <= clip_divergence_low) | (prob >= old_prob)
    valid_mask = torch.where(advantages > 0, valid_positive_mask, valid_negative_mask)
    valid_mask = valid_mask.detach().float()

    pg_losses = -advantages * truncated_ratio * log_prob * valid_mask

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
    )

    pg_clipfrac = verl_F.masked_mean((1.0 - valid_mask).float(), response_mask)
    pg_clipfrac_lower = verl_F.masked_mean((ratio > clip_ratio_c).float() * valid_mask, response_mask)

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("gspo")
def compute_policy_loss_gspo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "seq-mean-token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped policy objective and related metrics for GSPO.

    See https://arxiv.org/pdf/2507.18071 for more details.

    The importance ratio is the sequence-level geometric mean of the token ratios. It is
    attached to each token so that the value equals the sequence ratio while the gradient
    flows through that token's own log-probability:
        log s_{i,t} = sg[mean_t(log_prob - old_log_prob)] + log_prob - sg[log_prob]
    The log-ratio is clamped from above at 10.0, and the result is clipped PPO-style to
    [1 - clip_ratio_low, 1 + clip_ratio_high].

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Accepted for interface compatibility but IGNORED: the loss is always aggregated
            with "seq-mean-token-mean".
        config (verl.workers.config.ActorConfig):
            Actor config (required). Reads ``clip_ratio`` / ``clip_ratio_low`` /
            ``clip_ratio_high`` (low/high fall back to ``clip_ratio`` when None) and
            ``global_batch_info`` (forwarded to `agg_loss`).
        rollout_is_weights (torch.Tensor, optional):
            Per-token importance-sampling weights, shape (batch_size, response_length). When
            given, they multiply the per-token policy loss before aggregation.

    Returns:
        pg_loss: Scalar loss produced by `agg_loss`.
        pg_metrics: "actor/pg_clipfrac", "actor/ppo_kl" and "actor/pg_clipfrac_lower"
            (the latter is always 0.0 because GSPO has no dual clip).
    """
    assert config is not None
    assert isinstance(config, ActorConfig)
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio

    negative_approx_kl = log_prob - old_log_prob

    # Sequence-level mean log-ratio over valid tokens; clamp(min=1) guards empty rows.
    seq_lengths = torch.sum(response_mask, dim=-1).clamp(min=1)
    negative_approx_kl_seq = torch.sum(negative_approx_kl * response_mask, dim=-1) / seq_lengths

    # Value: the sequence log-ratio; gradient: d log_prob of the token itself.
    log_seq_importance_ratio = log_prob - log_prob.detach() + negative_approx_kl_seq.detach().unsqueeze(-1)
    log_seq_importance_ratio = torch.clamp(log_seq_importance_ratio, max=10.0)

    seq_importance_ratio = torch.exp(log_seq_importance_ratio)

    pg_losses1 = -advantages * seq_importance_ratio
    pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    pg_losses = torch.maximum(pg_losses1, pg_losses2)

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    # GSPO aggregates at sequence level regardless of the loss_agg_mode argument.
    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean", **config.global_batch_info
    )

    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("sapo")
def compute_policy_loss_sapo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "seq-mean-token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the smoothed policy objective and related metrics for SAPO.

    See https://arxiv.org/pdf/2511.20347 for more details.

    SAPO has no hard clip. The importance ratio r is passed through a smooth gate
        f(r, tau) = sigmoid(tau * (r - 1)) * 4 / tau
    with tau = ``tau_pos`` where A > 0 and ``tau_neg`` otherwise. The per-token loss is -f(r) * A.

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Accepted for interface compatibility but IGNORED: the loss is always aggregated
            with "seq-mean-token-mean".
        config (verl.workers.config.ActorConfig):
            Actor config (required). Reads ``tau_pos``, ``tau_neg`` and ``global_batch_info``
            (forwarded to `agg_loss`).
        rollout_is_weights (torch.Tensor, optional):
            Per-token importance-sampling weights, shape (batch_size, response_length). When
            given, they multiply the per-token policy loss before aggregation.

    Returns:
        pg_loss: Scalar loss produced by `agg_loss`.
        pg_metrics: "actor/ppo_kl" plus "actor/pg_clipfrac" and "actor/pg_clipfrac_lower",
            which are always 0.0 because nothing is clipped.
    """
    assert config is not None
    assert isinstance(config, ActorConfig)

    # Temperatures as tensors on the advantages' dtype/device so torch.where can select them.
    tau_pos = torch.as_tensor(config.tau_pos, dtype=advantages.dtype, device=advantages.device)
    tau_neg = torch.as_tensor(config.tau_neg, dtype=advantages.dtype, device=advantages.device)

    def gate_function(x, tau):
        """The gating function used in SAPO"""
        return torch.sigmoid(tau * (x - 1.0)) * (4.0 / tau)

    negative_approx_kl = log_prob - old_log_prob
    # Clamp the log-ratio so exp() cannot overflow.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)

    # Per-token temperature: tau_pos for positive advantages, tau_neg otherwise.
    taus = torch.where(
        condition=advantages > 0,
        input=tau_pos,
        other=tau_neg,
    )

    gates = gate_function(ratio, taus)

    pg_losses = -gates * advantages

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    # SAPO aggregates at sequence level regardless of the loss_agg_mode argument.
    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean", **config.global_batch_info
    )

    pg_clipfrac = torch.tensor(0.0, device=pg_loss.device)
    pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }

    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("gpg")
def compute_policy_loss_gpg(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Compute the GPG policy-gradient loss ``-log_prob * advantages`` (no ratio, no clipping).

    Adapted from
    https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495

    Args:
        old_log_prob: Accepted for interface compatibility with the other policy losses; unused.
        log_prob: `(torch.Tensor)`, shape: (bs, response_length).
        advantages: `(torch.Tensor)`, shape: (bs, response_length).
        response_mask: `(torch.Tensor)`, shape: (bs, response_length).
        loss_agg_mode: Aggregation mode forwarded to `agg_loss`. Defaults to "token-mean".
        config: Actor config (required); its ``global_batch_info`` is forwarded to `agg_loss`.
        rollout_is_weights: Optional per-token weights multiplied into the loss before aggregation.

    Returns:
        pg_loss: `a scalar torch.Tensor`, the policy gradient loss computed via GPG.
        metrics: an empty dict (GPG reports no extra metrics).
    """
    assert config is not None
    token_losses = -(log_prob * advantages)
    if rollout_is_weights is not None:
        token_losses = token_losses * rollout_is_weights

    batch_info = config.global_batch_info
    pg_loss = agg_loss(
        loss_mat=token_losses,
        loss_mask=response_mask,
        loss_agg_mode=loss_agg_mode,
        **batch_info,
    )
    no_metrics: dict[str, Any] = {}
    return pg_loss, no_metrics
|
|
|
|
| @register_policy_loss("clip_cov") |
| def compute_policy_loss_clip_cov( |
| old_log_prob: torch.Tensor, |
| log_prob: torch.Tensor, |
| advantages: torch.Tensor, |
| response_mask: torch.Tensor, |
| loss_agg_mode: str = "token-mean", |
| config: Optional[ActorConfig] = None, |
| rollout_is_weights: torch.Tensor | None = None, |
| ) -> tuple[torch.Tensor, dict[str, Any]]: |
| """ |
| Compute the clipped policy objective and related metrics for Clip-Cov. |
| |
| Adapted from |
| https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py |
| |
| Args: |
| old_log_prob (torch.Tensor): |
| Log-probabilities of actions under the old policy, shape (batch_size, response_length). |
| log_prob (torch.Tensor): |
| Log-probabilities of actions under the current policy, shape (batch_size, response_length). |
| advantages (torch.Tensor): |
| Advantage estimates for each action, shape (batch_size, response_length). |
| response_mask (torch.Tensor): |
| Mask indicating which tokens to include in the loss, shape (batch_size, response_length). |
| cliprange (float, optional): |
| Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. |
| Defaults to None (must be provided). |
| cliprange_low (float, optional): |
| Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. |
| cliprange_high (float, optional): |
| Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. |
| loss_agg_mode (str, optional): |
| Aggregation mode for `agg_loss`. Defaults to "token-mean". |
| clip_cvo_ratio (float, optional): |
| Ratio for clipping the covariance. Defaults to 0.0002. |
| clip_cov_lb (float, optional): |
| Lower bound for clipping covariance. Defaults to 1.0. |
| clip_cov_ub (float, optional): |
| Upper bound for clipping covariance. Defaults to 5.0. |
| """ |
| assert config is not None |
| assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet" |
| assert config.policy_loss is not None |
|
|
| clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 |
| cliprange = config.clip_ratio |
| cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange |
| cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange |
| clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 |
| clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 |
|
|
| assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." |
|
|
| negative_approx_kl = log_prob - old_log_prob |
| ratio = torch.exp(negative_approx_kl) |
| ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) |
|
|
| pg_losses1 = -advantages * ratio |
|
|
| if cliprange_low is None: |
| cliprange_low = cliprange |
| if cliprange_high is None: |
| cliprange_high = cliprange |
|
|
| corr = torch.ones_like(advantages) |
| pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) |
| clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0) |
|
|
| cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * ( |
| log_prob - verl_F.masked_mean(log_prob.detach(), response_mask) |
| ) |
| cov_all[response_mask == 0] = -torch.inf |
| cov_all[clip_by_origin] = -torch.inf |
|
|
| clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1) |
| top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0) |
| top_k_idx = torch.nonzero(top_k_idx) |
|
|
| if len(top_k_idx) > 0: |
| perm = torch.randperm(len(top_k_idx)) |
| top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]] |
| else: |
| top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long) |
|
|
| corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0 |
|
|
| pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) |
|
|
| pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr |
|
|
| |
| if rollout_is_weights is not None: |
| pg_losses = pg_losses * rollout_is_weights |
|
|
| pg_loss = agg_loss( |
| loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info |
| ) |
| pg_metrics = { |
| "actor/pg_clipfrac": pg_clipfrac.detach().item(), |
| "actor/ppo_kl": ppo_kl.detach().item(), |
| } |
| return pg_loss, pg_metrics |
|
|
|
|
@register_policy_loss("kl_cov")
def compute_policy_loss_kl_cov(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the policy objective and related metrics for KL-Cov.

    Every valid token uses the unclipped surrogate ``-A * ratio``. For the small fraction of valid
    tokens with the largest centered covariance between advantage and log-probability, the loss is
    replaced by ``-A * ratio + ppo_kl_coef * |log_prob - old_log_prob|``.

    Adapted from
    https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".
        config (ActorConfig):
            Actor config (required). The following fields are read:
            - ``policy_loss.kl_cov_ratio``: fraction of valid tokens that get the KL penalty
              (default 0.0002; must be > 0).
            - ``policy_loss.ppo_kl_coef``: coefficient of the KL penalty (default 1.0).
            - ``global_batch_info``: forwarded to `agg_loss`.
        rollout_is_weights (torch.Tensor, optional):
            Per-token weights multiplied into the loss matrix before aggregation.

    Returns:
        tuple[torch.Tensor, dict[str, Any]]: The aggregated loss, plus a metrics dict. Its
        ``"actor/ppo_kl"`` entry is the masked mean of ``|log_prob - old_log_prob|``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
    assert config.policy_loss is not None

    kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002
    ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0

    assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0."

    negative_approx_kl = log_prob - old_log_prob
    abs_kl = negative_approx_kl.abs()
    ratio = torch.exp(negative_approx_kl)
    ppo_kl_abs = verl_F.masked_mean(abs_kl, response_mask)

    # Default surrogate for every token; the selected tokens are overwritten in place below.
    pg_losses = -advantages * ratio
    pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl

    # Flat (row-major) positions of valid tokens. Boolean-mask indexing below yields the values
    # in the same order, so topk positions map back through all_valid_idx.
    all_valid = response_mask > 0
    all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0]
    all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu()
    all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu()

    num_valid = all_valid_adv.numel()
    if num_valid > 0:
        # Per-token centered covariance between advantage and log-prob over all valid tokens.
        cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean())
        # Always penalize at least one token.
        k_percent_nums = max(1, int(num_valid * kl_cov_ratio))
        large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices

        flat_idxs = all_valid_idx[large_cov_idxs]
        seq_len = advantages.shape[1]
        rows, cols = flat_idxs // seq_len, flat_idxs % seq_len
        pg_losses[rows, cols] = pg_losses_kl[rows, cols]

    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
    )
    pg_metrics = {
        "actor/ppo_kl": ppo_kl_abs.detach().item(),
    }
    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("geo_mean")
def compute_policy_loss_geo_mean(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the clipped policy objective and related metrics for GMPO.

    Per-token log-ratios are clipped in log space on the side selected by sign(advantage).
    They are then combined per sequence into a geometric-mean ratio, i.e. the exponential of the
    masked mean log-ratio. That ratio multiplies the masked mean advantage of the sequence. The
    final loss is the plain mean over sequences.

    Adapted from paper https://arxiv.org/abs/2507.20673
    https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Not used; the loss is always ``torch.mean`` over per-sequence losses.
        config (ActorConfig):
            Actor config (required). The following fields are read:
            - ``clip_ratio``
            - ``clip_ratio_low`` and ``clip_ratio_high``, each falling back to ``clip_ratio`` when None.
        rollout_is_weights (torch.Tensor, optional):
            Per-token weights. They are collapsed into one weight per sequence, the masked geometric
            mean (1e-10 is added before the log), which multiplies that sequence's loss.

    Returns:
        tuple[torch.Tensor, dict[str, Any]]: The scalar loss and a metrics dict:
        - ``actor/pg_clipfrac`` / ``actor/pg_clipfrac_lower``: masked fraction of tokens whose
          log-ratio lies outside the clamp range, among tokens with positive / negative advantage.
        - ``actor/ppo_kl``: masked mean of ``old_log_prob - log_prob``.
    """
    assert config is not None
    assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet"
    cliprange = config.clip_ratio
    cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange
    cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange

    negative_approx_kl = log_prob - old_log_prob
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Clip in log space. Multiplying by sign(A) before and after the min picks the clipped or
    # unclipped log-ratio giving the smaller A * log-ratio. Tokens with A == 0 get 0 (sign(0) == 0).
    sgn_advantage = torch.sign(advantages)
    negative_approx_kl_clamp = torch.clamp(negative_approx_kl, -cliprange_low, cliprange_high)
    negative_approx_kl_min = torch.min(sgn_advantage * negative_approx_kl, sgn_advantage * negative_approx_kl_clamp)
    negative_approx_kl_min = sgn_advantage * negative_approx_kl_min

    # Per-sequence geometric-mean ratio and masked-mean advantage (1e-8 guards empty masks).
    response_mask_sum = response_mask.sum(dim=-1)
    ratio = torch.exp((negative_approx_kl_min * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8))
    advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
    pg_losses = -advantage * ratio

    if rollout_is_weights is not None:
        # Collapse per-token IS weights into a per-sequence geometric mean.
        seq_is_weights = torch.exp(
            (torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
        )
        pg_losses = pg_losses * seq_is_weights

    pg_loss = torch.mean(pg_losses)

    clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp)
    pg_clipfrac = verl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask)
    pg_clipfrac_lower = verl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask)
    pg_metrics = {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
    }
    return pg_loss, pg_metrics
|
|
|
|
@register_policy_loss("cispo")
def compute_policy_loss_cispo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """
    Compute the CISPO policy objective and related metrics.

    The importance ratio ``exp(log_prob - old_log_prob)`` is clamped to
    ``[1 - clip_ratio_low, 1 + clip_ratio_high]`` and detached. It is then used purely as a weight
    on ``-advantages * log_prob``, so the gradient flows only through ``log_prob``.

    See https://arxiv.org/pdf/2506.13585 for more details.

    Args:
        old_log_prob: Log-probabilities under the old policy.
        log_prob: Log-probabilities under the current policy.
        advantages: Per-token advantage estimates.
        response_mask: Mask selecting the tokens that enter the loss and metrics.
        loss_agg_mode: Aggregation mode forwarded to `agg_loss`.
        config: Must be an ``ActorConfig``. Reads ``clip_ratio``, ``clip_ratio_low``,
            ``clip_ratio_high`` (the last two fall back to ``clip_ratio`` when None) and
            ``global_batch_info``.
        rollout_is_weights: Optional per-token weights multiplied into the loss matrix.

    Returns:
        The aggregated loss, plus a metrics dict with these entries:
        - ``actor/pg_clipfrac``: masked fraction of tokens whose ratio was clamped.
        - ``actor/ppo_kl``: masked mean of the clamped ``old_log_prob - log_prob``.
        - ``actor/pg_clipfrac_lower``: always 0.0.
    """
    assert config is not None
    assert isinstance(config, ActorConfig)
    low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio
    high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio

    # Bound the log-ratio before exponentiating to keep exp() finite.
    log_ratio = torch.clamp(log_prob - old_log_prob, min=-20.0, max=20.0)
    ratio = torch.exp(log_ratio)
    ppo_kl = verl_F.masked_mean(-log_ratio, response_mask)

    # Clamped ratio acts as a stop-gradient weight rather than a gate on the gradient.
    is_weight = torch.clamp(ratio, 1 - low, 1 + high)
    pg_clipfrac = verl_F.masked_mean((ratio != is_weight).float(), response_mask)

    pg_losses = -is_weight.detach() * advantages * log_prob
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode, **config.global_batch_info
    )

    return pg_loss, {
        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
        "actor/ppo_kl": ppo_kl.detach().item(),
        # This function has no separate lower clip statistic; the key is emitted as a constant 0.0.
        "actor/pg_clipfrac_lower": 0.0,
    }
|
|
|
|
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
    """Compute the aggregated categorical entropy of ``logits`` (kept for backward compatibility).

    Args:
        logits (torch.Tensor): shape is (bs, response_length, vocab_size).
        response_mask (torch.Tensor): shape is (bs, response_length).
        loss_agg_mode (str): aggregation mode forwarded to `agg_loss`.

    Returns:
        The per-token entropy from `verl_F.entropy_from_logits`, aggregated by `agg_loss` over
        the masked tokens.
    """
    return agg_loss(
        loss_mat=verl_F.entropy_from_logits(logits),
        loss_mask=response_mask,
        loss_agg_mode=loss_agg_mode,
    )
|
|
|
|
def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    cliprange_value: float,
    loss_agg_mode: str = "token-mean",
):
    """
    Compute the clipped value-function loss for PPO.

    The new predictions are clipped into ``[values - cliprange_value, values + cliprange_value]``.
    The per-token loss is the larger of the clipped and unclipped squared errors against
    ``returns``. It is aggregated with `agg_loss` and scaled by 0.5.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (torch.FloatTensor):
            Predicted values from the value head, shape (batch_size, response_length).
        returns (torch.FloatTensor):
            Ground-truth returns, shape (batch_size, response_length).
        values (torch.FloatTensor):
            Old (baseline) values from the value head, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the value loss calculation.
        cliprange_value (float):
            Half-width of the clipping window around ``values``.
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss`. Defaults to "token-mean".

    Returns:
        vf_loss (torch.Tensor):
            The aggregated value-function loss (0.5 * agg_loss of the per-token max squared error).
        vf_clipfrac (torch.Tensor):
            Masked fraction of tokens where the clipped squared error strictly exceeds the unclipped one.
    """
    lower_bound = values - cliprange_value
    upper_bound = values + cliprange_value
    clipped_preds = verl_F.clip_by_value(vpreds, lower_bound, upper_bound)

    unclipped_sq_err = (vpreds - returns) ** 2
    clipped_sq_err = (clipped_preds - returns) ** 2
    per_token_loss = torch.max(unclipped_sq_err, clipped_sq_err)

    vf_loss = 0.5 * agg_loss(loss_mat=per_token_loss, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    clip_used = (clipped_sq_err > unclipped_sq_err).float()
    vf_clipfrac = verl_F.masked_mean(clip_used, response_mask)
    return vf_loss, vf_clipfrac
|
|
|
|
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute a KL-divergence estimate from logprob and ref_logprob.

    A trailing "+" on the mode name (e.g. "k3+") enables a straight-through trick. The forward value
    is that of the base estimator ("k3"), while the gradient is that of the k2 estimator
    ``0.5 * (logprob - ref_logprob) ** 2``.
    See more description in http://joschu.net/blog/kl-approx.html

    Args:
        logprob: log-probabilities compared against ``ref_logprob``.
        ref_logprob: reference log-probabilities.
        kl_penalty (str): estimator name accepted by `kl_penalty_forward`, optionally suffixed with "+".

    Returns:
        kl_estimate: element-wise KL estimate.
    """
    use_straight_through = kl_penalty.endswith("+")
    base_mode = kl_penalty[:-1] if use_straight_through else kl_penalty
    # Select the forward estimator by its base name; the "+" only affects the gradient.
    forward_score = kl_penalty_forward(logprob, ref_logprob, base_mode)
    # k2/mse already has the k2 gradient, so the straight-through swap would be a no-op.
    if not use_straight_through or base_mode in ("mse", "k2"):
        return forward_score

    # The expectation of the k1 and k3 estimators is the KL value, but their expected gradient is
    # not the gradient of KL. The k2 estimator gives the right gradient, so we keep the forward
    # value of the base estimator while taking the gradient from k2 (straight-through).
    backward_score = 0.5 * (logprob - ref_logprob).square()
    return backward_score - backward_score.detach() + forward_score.detach()
|
|
|
|
def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute the forward value of a KL-divergence estimator from logprob and ref_logprob.

    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    See more description in http://joschu.net/blog/kl-approx.html

    Supported modes:
        - "kl" / "k1": ``logprob - ref_logprob``
        - "abs": ``|logprob - ref_logprob|``
        - "mse" / "k2": ``0.5 * (logprob - ref_logprob) ** 2``
        - "low_var_kl" / "k3": ``exp(d) - d - 1`` with ``d = ref_logprob - logprob``. Here ``d`` is
          clamped to [-20, 20] before ``exp`` and the result is clamped to [-10, 10].
    A single trailing "+" (straight-through variant, e.g. "k3+") is accepted and ignored. That
    variant changes only the gradient, so its forward value equals the base estimator's.

    Args:
        logprob: log-probabilities compared against ``ref_logprob``.
        ref_logprob: reference log-probabilities.
        kl_penalty (str): estimator name.

    Returns:
        kl_estimate: element-wise estimate with the broadcast shape of the two inputs.

    Raises:
        NotImplementedError: for "full" or an unrecognized mode.
    """
    if isinstance(kl_penalty, str):
        kl_penalty = kl_penalty.removesuffix("+")

    if kl_penalty in ("kl", "k1"):
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty in ("mse", "k2"):
        return 0.5 * (logprob - ref_logprob).square()

    # J. Schulman's low-variance, non-negative "k3" estimator of KL(pi || ref).
    if kl_penalty in ("low_var_kl", "k3"):
        kl = ref_logprob - logprob
        # Clamp the log-ratio so exp() cannot overflow, then bound the estimate for stability.
        kl = torch.clamp(kl, min=-20, max=20)
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # The exact KL needs full-vocabulary distributions, which token log-probs do not provide.
        raise NotImplementedError("kl_penalty='full' is not implemented")

    raise NotImplementedError(f"Unsupported kl_penalty mode: {kl_penalty!r}")
|
|
|
|
def compute_pf_ppo_reweight_data(
    data,
    reweight_method: str = "pow",
    weight_pow: float = 2.0,
):
    """Resample the data with replacement, weighting samples by their summed token_level_scores.

    Args:
        data: DataProto object, containing batch, non_tensor_batch and meta_info.
            ``data.batch["token_level_scores"]`` is summed over its last dim to get one score per sample.
        reweight_method: str, choices: "pow", "max_min", "max_random".
            - "pow": weight = |score| ** weight_pow
            - "max_min": weight 1 for samples whose score equals the batch max or min, else 0
            - "max_random": weight 0.4 for samples whose score equals the batch max, else 0.1
        weight_pow: float, the power of the weight (used only by "pow").

    Returns:
        A deep copy of ``data`` in which every ``batch`` tensor is re-indexed by the sampled
        indices, and so is every ``non_tensor_batch`` entry. A ``meta_info`` value is re-indexed
        only when it is a list whose length equals the batch size; all other ``meta_info`` values
        are copied unchanged. The number of samples equals the original batch size.

    Raises:
        ValueError: If reweight_method is not supported.
    """

    @torch.no_grad()
    def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor:
        """Compute unnormalized sampling weights from per-sample scores.

        Args:
            scores (torch.Tensor): Tensor of scores to compute weights from.
            reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random').
            weight_pow (float): Power exponent for 'pow' method.

        Returns:
            torch.Tensor: Computed (unnormalized) weights, same shape as ``scores``.

        Raises:
            ValueError: If reweight_method is not supported.
        """
        if reweight_method == "pow":
            weights = torch.pow(torch.abs(scores), weight_pow)
        elif reweight_method == "max_min":
            max_score = torch.max(scores)
            min_score = torch.min(scores)
            weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0)
        elif reweight_method == "max_random":
            max_score = torch.max(scores)
            weights = torch.where(scores == max_score, 0.4, 0.1)
        else:
            raise ValueError(f"Unsupported reweight_method: {reweight_method}")
        return weights

    scores = data.batch["token_level_scores"].sum(dim=-1)
    weights = compute_weights(scores, reweight_method, weight_pow)
    # Keep every weight strictly positive so torch.multinomial never sees an all-zero row.
    weights = torch.clamp(weights + 1e-8, min=1e-8)

    batch_size = scores.shape[0]
    sample_indices = torch.multinomial(weights, batch_size, replacement=True)

    resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()}

    # .cpu() is required before .numpy() when the scores (and hence indices) live on an accelerator.
    sample_indices_np = sample_indices.cpu().numpy()
    resampled_non_tensor_batch = {}
    for key, array in data.non_tensor_batch.items():
        if isinstance(array, np.ndarray):
            resampled_non_tensor_batch[key] = array[sample_indices_np]
        else:
            resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np]

    resampled_meta_info = {}
    for key, value in data.meta_info.items():
        # Only per-sample lists are re-indexed; everything else is batch-level metadata.
        if isinstance(value, list) and len(value) == batch_size:
            resampled_meta_info[key] = [value[i] for i in sample_indices_np]
        else:
            resampled_meta_info[key] = value

    from copy import deepcopy

    resampled_data = deepcopy(data)
    resampled_data.batch = type(data.batch)(resampled_batch)
    resampled_data.batch.batch_size = data.batch.batch_size
    resampled_data.non_tensor_batch = resampled_non_tensor_batch
    resampled_data.meta_info = resampled_meta_info

    return resampled_data
|
|
|
|
def compute_policy_loss_reinforce(
    rollout_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "seq-mean-token-sum",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Compute a REINFORCE-style policy-gradient loss, optionally with IS correction.

    Per-token loss:
        without weights:  ``-A * log pi(a|s)``
        with weights:     ``-A * log pi(a|s) * w``, where ``w = rollout_is_weights``
    The loss matrix is aggregated with `agg_loss`. No PPO ratio and no clipping are used. Any
    truncation of ``w`` is expected to have been done by the caller.

    Args:
        rollout_log_prob: Log probabilities from the rollout policy, shape (batch_size, seq_length).
            Used only for the KL metric.
        log_prob: Log probabilities from the current training policy, shape (batch_size, seq_length).
        advantages: Per-token advantage estimates, shape (batch_size, seq_length).
        response_mask: Valid-token mask (1 valid, 0 padding), shape (batch_size, seq_length).
            Should already include rejection sampling.
        loss_agg_mode: Aggregation strategy forwarded to `agg_loss`.
        config: Actor config; required, supplies ``global_batch_info`` for `agg_loss`.
        rollout_is_weights: Optional pre-computed IS weights (pi_current / pi_rollout),
            shape (batch_size, seq_length). None disables IS correction.

    Returns:
        Tuple of (loss, metrics). ``metrics["actor/ppo_kl"]`` is the masked mean of
        ``rollout_log_prob - log_prob``.
    """
    assert config is not None, "ActorConfig must be provided for REINFORCE loss"

    per_token_loss = -advantages * log_prob
    if rollout_is_weights is not None:
        per_token_loss = per_token_loss * rollout_is_weights

    pg_loss = agg_loss(
        loss_mat=per_token_loss,
        loss_mask=response_mask,
        loss_agg_mode=loss_agg_mode,
        **config.global_batch_info,
    )

    # Divergence between the rollout and training policies on the masked tokens.
    kl_divergence = verl_F.masked_mean(-(log_prob - rollout_log_prob), response_mask)

    return pg_loss, {"actor/ppo_kl": kl_divergence.detach().item()}
|
|
|
|
@register_policy_loss("bypass_mode")
def compute_policy_loss_bypass_mode(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[ActorConfig] = None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Policy loss for bypass mode, dispatching to REINFORCE or PPO-clip.

    In bypass mode the ``old_log_prob`` argument holds the rollout log-probabilities. Under
    ``torch.no_grad`` this function calls `compute_rollout_correction_and_rejection_mask` with the
    current ``log_prob`` and the rollout log-probs. That call yields IS weights, a (possibly
    rejection-modified) response mask and correction metrics. The loss is then computed by:

    - "ppo_clip" (default): `compute_policy_loss_vanilla` with ``old_log_prob = rollout_log_prob``
      and no IS weights. The PPO ratio pi_current / pi_rollout already carries the correction,
      so applying the weights too would double-count.
    - "reinforce": `compute_policy_loss_reinforce` with the computed IS weights.

    Both branches use the mask returned by the correction helper.

    Args:
        old_log_prob: Rollout log-probabilities (bypass mode), shape (batch_size, seq_length).
        log_prob: Current policy log-probabilities, shape (batch_size, seq_length).
        advantages: Advantage estimates, shape (batch_size, seq_length).
        response_mask: Valid-token mask (1=valid, 0=padding), shape (batch_size, seq_length).
        loss_agg_mode: Aggregation mode forwarded to the underlying loss.
        config: Actor config; ``config.policy_loss`` must provide a ``rollout_correction`` entry.
        rollout_is_weights: Ignored; IS weights are computed internally.

    Config keys read from ``config.policy_loss.rollout_correction`` (defaults in parentheses):
        loss_type ("ppo_clip"), rollout_is (None), rollout_is_threshold (2.0),
        rollout_is_batch_normalize (False), rollout_rs (None), rollout_rs_threshold (None).
        All but ``loss_type`` are forwarded to the correction helper.

    Returns:
        Tuple of (loss, metrics). The metrics are those of the underlying loss, updated with the
        rollout-correction metrics.

    Raises:
        ValueError: if the rollout_correction config is missing or ``loss_type`` is unknown.
    """
    from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask

    assert config is not None, "config is required for bypass_mode loss"

    corr_cfg = config.policy_loss.get("rollout_correction", None) if hasattr(config, "policy_loss") else None
    if corr_cfg is None:
        raise ValueError(
            "rollout_correction config not found in policy_loss. "
            "When using loss_mode='bypass_mode', ensure rollout_correction config is passed."
        )

    loss_type = corr_cfg.get("loss_type", "ppo_clip")

    # Bypass mode: the "old" log-probs passed in are really the rollout log-probs.
    rollout_log_prob = old_log_prob

    with torch.no_grad():
        is_weights_proto, effective_mask, rollout_metrics = compute_rollout_correction_and_rejection_mask(
            old_log_prob=log_prob,
            rollout_log_prob=rollout_log_prob,
            response_mask=response_mask,
            rollout_is=corr_cfg.get("rollout_is", None),
            rollout_is_threshold=corr_cfg.get("rollout_is_threshold", 2.0),
            rollout_is_batch_normalize=corr_cfg.get("rollout_is_batch_normalize", False),
            rollout_rs=corr_cfg.get("rollout_rs", None),
            rollout_rs_threshold=corr_cfg.get("rollout_rs_threshold", None),
        )

    # NOTE(review): truthiness test on the helper's first return value. Presumably it is None (or
    # empty) when IS is disabled; confirm before switching to an `is not None` check.
    computed_is_weights = is_weights_proto.batch["rollout_is_weights"] if is_weights_proto else None

    shared_kwargs = dict(
        log_prob=log_prob,
        advantages=advantages,
        response_mask=effective_mask,
        loss_agg_mode=loss_agg_mode,
        config=config,
    )
    if loss_type == "reinforce":
        pg_loss, pg_metrics = compute_policy_loss_reinforce(
            rollout_log_prob=rollout_log_prob,
            rollout_is_weights=computed_is_weights,
            **shared_kwargs,
        )
    elif loss_type == "ppo_clip":
        # The PPO ratio already is pi_current / pi_rollout, so no extra IS weights here.
        pg_loss, pg_metrics = compute_policy_loss_vanilla(
            old_log_prob=rollout_log_prob,
            rollout_is_weights=None,
            **shared_kwargs,
        )
    else:
        raise ValueError(f"Invalid loss_type: {loss_type}. Must be 'reinforce' or 'ppo_clip'.")

    pg_metrics.update(rollout_metrics)
    return pg_loss, pg_metrics
|
|