"""
Core functions to implement PPO algorithms.

The functions implemented in this file are used by trainers with different distributed strategies to
implement PPO.
"""
from collections import defaultdict

import numpy as np
import torch

import verl.utils.torch_functional as verl_F

class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
    https://arxiv.org/pdf/1909.08593.pdf
    """

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        target = self.target
        # Proportional error between the measured KL and the target, clipped for stability.
        proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult

class FixedKLController:
    """Fixed KL controller."""

    def __init__(self, kl_coef):
        self.value = kl_coef

    def update(self, current_kl, n_steps):
        pass

def get_kl_controller(kl_ctrl):
    if kl_ctrl.type == "fixed":
        return FixedKLController(kl_coef=kl_ctrl.kl_coef)
    elif kl_ctrl.type == "adaptive":
        assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
        return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
    else:
        raise NotImplementedError

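# Illustrative usage sketch (not part of verl's API): building a KL controller from a
# minimal config-like object and updating it with a measured KL. The SimpleNamespace
# stands in for the real config object; all numbers are made up.
def _example_kl_controller():
    from types import SimpleNamespace

    cfg = SimpleNamespace(type="adaptive", kl_coef=0.2, target_kl=6.0, horizon=10000)
    ctrl = get_kl_controller(cfg)
    # Measured KL above the target -> the coefficient grows; below the target -> it shrinks.
    ctrl.update(current_kl=9.0, n_steps=256)
    return ctrl.value
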
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
    lam: torch.Tensor,
):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        lastgaelam = 0
        advantages_reversed = []
        gen_len = token_level_rewards.shape[-1]

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
            # GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}
            lastgaelam = delta + gamma * lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        returns = advantages + values
        advantages = verl_F.masked_whiten(advantages, response_mask)

    return advantages, returns

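# Illustrative usage sketch (not part of verl's API): GAE on a toy batch where only the
# last response token carries an outcome reward. Shapes and numbers are made up.
def _example_gae_advantage():
    rewards = torch.zeros(2, 4)
    rewards[:, -1] = torch.tensor([1.0, 0.0])  # one scalar outcome reward per response
    values = torch.zeros(2, 4)                 # pretend the critic predicts zero everywhere
    mask = torch.ones(2, 4)
    advantages, returns = compute_gae_advantage_return(rewards, values, mask, gamma=1.0, lam=0.95)
    return advantages, returns
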
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
):
    """
    Compute advantage for GRPO, operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(np.ndarray)`
            prompt index of each response; responses sharing an index form one GRPO group
        norm_adv_by_std_in_grpo: (bool)
            whether to scale the GRPO advantage.
            If True, the advantage is scaled by the std, as in the original GRPO.
            If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        # Group outcome scores by prompt index.
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
                id2std[idx] = torch.std(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        # Normalize each score within its group.
        for i in range(bsz):
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[index[i]]
        # Broadcast the per-response advantage to every valid response token.
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores

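# Illustrative usage sketch (not part of verl's API): group-relative advantages for two
# prompts with two sampled responses each. Index values and rewards are made up.
def _example_grpo_advantage():
    rewards = torch.zeros(4, 3)
    rewards[:, -1] = torch.tensor([1.0, 0.0, 0.5, 0.5])  # outcome reward per response
    mask = torch.ones(4, 3)
    index = np.array(["prompt_a", "prompt_a", "prompt_b", "prompt_b"])
    advantages, returns = compute_grpo_outcome_advantage(rewards, mask, index)
    # Within "prompt_a" the two responses get symmetric positive/negative advantages;
    # within "prompt_b" both responses score the same, so their advantages are (near) zero.
    return advantages, returns
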
def compute_reinforce_plus_plus_baseline_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: torch.Tensor,
    epsilon: float = 1e-6,
):
    """
    Compute advantage for REINFORCE++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
    (with only one scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        # Group outcome scores by prompt index and subtract the per-group mean as a baseline.
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            scores[i] = scores[i] - id2mean[index[i]]

        # Broadcast to token level, then whiten over the valid response tokens.
        scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
        scores = verl_F.masked_whiten(scores, response_mask)

    return scores, scores

def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
):
    """
    Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2mean = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            id2score[index[i]].append(scores[i])
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
            elif len(id2score[idx]) > 1:
                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
            else:
                raise ValueError(f"no score in prompt index: {idx}")
        for i in range(bsz):
            response_num = len(id2score[index[i]])
            if response_num > 1:
                # Leave-one-out baseline: score_i - mean(scores of the other responses)
                # = (score_i - group_mean) * n / (n - 1).
                scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / (response_num - 1)
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores

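# Illustrative usage sketch (not part of verl's API): RLOO leave-one-out baseline for one
# prompt with three sampled responses. Rewards are made up.
def _example_rloo_advantage():
    rewards = torch.zeros(3, 2)
    rewards[:, -1] = torch.tensor([1.0, 0.0, 0.0])  # outcome reward per response
    mask = torch.ones(3, 2)
    index = np.array(["prompt_a", "prompt_a", "prompt_a"])
    advantages, _ = compute_rloo_outcome_advantage(rewards, mask, index)
    # Response 0 is compared against the mean of responses 1 and 2 (baseline 0.0), so its
    # advantage is 1.0; responses 1 and 2 each get an advantage of -0.5.
    return advantages
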
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: torch.Tensor,
):
    """
    Compute advantage for REINFORCE++.
    This implementation is based on the paper: https://arxiv.org/abs/2501.03262

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        running_return = 0

        for t in reversed(range(token_level_rewards.shape[1])):
            # Discounted return-to-go: G_t = r_t + gamma * G_{t+1}
            running_return = token_level_rewards[:, t] + gamma * running_return
            returns[:, t] = running_return
            # Reset the running return on masked (padding) positions so it does not
            # propagate across invalid tokens.
            running_return = running_return * response_mask[:, t]

        advantages = verl_F.masked_whiten(returns, response_mask)
        advantages = advantages * response_mask

    return advantages, returns

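# Illustrative usage sketch (not part of verl's API): REINFORCE++ discounted returns-to-go
# for a single response with one terminal reward. Numbers are made up.
def _example_reinforce_plus_plus_returns():
    rewards = torch.tensor([[0.0, 0.0, 1.0]])
    mask = torch.ones(1, 3)
    advantages, returns = compute_reinforce_plus_plus_outcome_advantage(rewards, mask, gamma=0.9)
    # returns == [[0.81, 0.9, 1.0]]; advantages are the whitened returns.
    return advantages, returns
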
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor,
    reward_baselines: torch.Tensor,
    response_mask: torch.Tensor,
):
    """
    Compute advantage for ReMax, operating only on Outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    with torch.no_grad():
        # Reverse cumulative sum gives the (undiscounted) return-to-go at each token.
        returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
        # Subtract the per-sequence baseline reward (e.g. from a greedy rollout) on valid tokens.
        advantages = returns - reward_baselines.unsqueeze(-1) * response_mask

    return advantages, returns

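# Illustrative usage sketch (not part of verl's API): ReMax advantage with a per-sequence
# baseline reward. Numbers are made up.
def _example_remax_advantage():
    rewards = torch.tensor([[0.0, 0.0, 1.0]])
    baselines = torch.tensor([0.6])  # e.g. reward of a greedy rollout for the same prompt
    mask = torch.ones(1, 3)
    advantages, returns = compute_remax_outcome_advantage(rewards, baselines, mask)
    # returns == [[1.0, 1.0, 1.0]], advantages == [[0.4, 0.4, 0.4]]
    return advantages, returns
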
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
    # Penalize token-level scores by the KL between the actor and the reference policy.
    kl = old_log_prob - ref_log_prob
    return token_level_scores - kl * kl_ratio

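# Illustrative usage sketch (not part of verl's API): folding the KL penalty into
# token-level rewards before computing advantages. Tensors are made up.
def _example_kl_shaped_rewards():
    token_level_scores = torch.tensor([[0.0, 0.0, 1.0]])
    old_log_prob = torch.tensor([[-1.0, -1.2, -0.8]])  # log-probs under the rollout policy
    ref_log_prob = torch.tensor([[-1.1, -1.0, -0.9]])  # log-probs under the reference policy
    return compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio=0.1)
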
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
    """
    Aggregate the loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) choices:
            "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean" / "seq-mean-token-sum-norm"
            "token-mean" is the default behavior

    Returns:
        loss: `a scalar torch.Tensor`
            aggregated loss
    """
    if loss_agg_mode == "token-mean":
        loss = verl_F.masked_mean(loss_mat, loss_mask)
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean
        loss = torch.mean(seq_losses)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        # Normalize by the (constant) maximum response length rather than the number of valid tokens.
        loss = torch.sum(seq_losses) / loss_mask.shape[-1]
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

    return loss

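# Illustrative usage sketch (not part of verl's API): how the aggregation modes differ on a
# toy loss matrix with one padded token. Numbers are made up.
def _example_loss_aggregation():
    loss_mat = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 0.0]])
    loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
    modes = ("token-mean", "seq-mean-token-sum", "seq-mean-token-mean", "seq-mean-token-sum-norm")
    return {mode: agg_loss(loss_mat, loss_mask, mode) for mode in modes}
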
def compute_policy_loss(
    old_log_prob,
    log_prob,
    advantages,
    response_mask,
    cliprange=None,
    cliprange_low=None,
    cliprange_high=None,
    clip_ratio_c=3.0,
    loss_agg_mode="token-mean",
):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122

    Args:
        old_log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        log_prob: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange: (float)
            The clip range used in PPO. See https://arxiv.org/abs/1707.06347
        cliprange_low: (float)
            The lower clip range used in PPO. Defaults to `cliprange` when not set.
        cliprange_high: (float)
            The higher clip range used in PPO. Defaults to `cliprange` when not set.
        clip_ratio_c: (float) default: 3.0
            The lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729
        loss_agg_mode: (str) choices:
            "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean" / "seq-mean-token-sum-norm"
            "token-mean" is the default behavior

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac: (float)
            the fraction of policy gradient loss being clipped
        ppo_kl: (float)
            the estimated KL divergence between the latest updating policy and the old sampling policy
        pg_clipfrac_lower: (float)
            the fraction of policy gradient loss being clipped when the advantage is negative
    """
    assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but got the value: {clip_ratio_c}."

    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # Standard PPO clipped surrogate objective (written as a loss, hence the negation).
    pg_losses1 = -advantages * ratio
    if cliprange_low is None:
        cliprange_low = cliprange
    if cliprange_high is None:
        cliprange_high = cliprange
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
    pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual-clip PPO: additionally bound the loss by -advantages * clip_ratio_c when the advantage is negative.
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = verl_F.masked_mean(torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)

    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower

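# Illustrative usage sketch (not part of verl's API): dual-clip PPO loss on a toy batch.
# Log-probabilities, advantages and clip ranges are made up.
def _example_policy_loss():
    old_log_prob = torch.tensor([[-1.0, -1.0, -1.0]])
    log_prob = torch.tensor([[-0.7, -1.4, -1.0]])  # the new policy moved on the first two tokens
    advantages = torch.tensor([[1.0, -1.0, 0.5]])
    mask = torch.ones(1, 3)
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
        old_log_prob, log_prob, advantages, mask, cliprange=0.2, clip_ratio_c=3.0
    )
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
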
def compute_entropy_loss(logits, response_mask):
    """Compute Categorical entropy loss

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        entropy: a scalar torch.Tensor
    """
    entropy = verl_F.entropy_from_logits(logits)
    entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
    return entropy_loss

def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
    """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        returns (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        response_mask (`torch.FloatTensor`):
            Mask of valid response tokens, shape (`batch_size`, `response_length`)
        cliprange_value (`float`):
            Clip range for the value predictions

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a float
            The fraction of the value loss that was clipped
    """
    # Clip the new value predictions to stay within cliprange_value of the old values.
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
    vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
    return vf_loss, vf_clipfrac

def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
    """Compute KL divergence given logprob and ref_logprob.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104

    Args:
        logprob: `(torch.FloatTensor)`
            token-level log-probabilities under the actor policy, shape (bs, response_length)
        ref_logprob: `(torch.FloatTensor)`
            token-level log-probabilities under the reference policy, shape (bs, response_length)
        kl_penalty: (str) choices: "kl" / "abs" / "mse" / "low_var_kl" / "full"

    Returns:
        the token-level KL penalty, shape (bs, response_length)
    """
    if kl_penalty == "kl":
        return logprob - ref_logprob

    if kl_penalty == "abs":
        return (logprob - ref_logprob).abs()

    if kl_penalty == "mse":
        return 0.5 * (logprob - ref_logprob).square()

    # Low-variance estimator of the KL divergence.
    # See J. Schulman, "Approximating KL Divergence", http://joschu.net/blog/kl-approx.html
    if kl_penalty == "low_var_kl":
        kl = ref_logprob - logprob
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # Computing the exact KL requires the full distributions over the vocabulary,
        # not just the log-probabilities of the sampled tokens.
        raise NotImplementedError

    raise NotImplementedError

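# Illustrative usage sketch (not part of verl's API): the different token-level KL penalty
# estimators evaluated on the same made-up log-probabilities.
def _example_kl_penalties():
    logprob = torch.tensor([[-1.0, -1.5, -0.5]])
    ref_logprob = torch.tensor([[-1.2, -1.0, -0.5]])
    return {name: kl_penalty(logprob, ref_logprob, name) for name in ("kl", "abs", "mse", "low_var_kl")}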