""" |
|
|
Core functions to implement PPO algorithms. |
|
|
The function implemented in this file should be used by trainer with different distributed strategies to |
|
|
implement PPO |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from collections import defaultdict |
|
|
from typing import TYPE_CHECKING, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from ..utils import torch_functional as VF |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from .config import AlgorithmConfig |
|
|
|
|
|
|
|
|
class KLController(ABC): |
|
|
kl_coef: float |
|
|
"""KL coefficient.""" |
|
|
|
|
|
@abstractmethod |
|
|
def update(self, current_kl: float, n_steps: int) -> None: |
|
|
"""Update kl_coef according to current KL.""" |
|
|
... |
|
|
|
|
|
|
|
|
class AdaptiveKLController(KLController): |
|
|
"""Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf |
|
|
|
|
|
Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54""" |
|
|
|
|
|
def __init__(self, init_kl_coef: float, target_kl: float, horizon: float): |
|
|
self.kl_coef = init_kl_coef |
|
|
self.target = target_kl |
|
|
self.horizon = horizon |
|
|
|
|
|
def update(self, current_kl: float, n_steps: int) -> None: |
|
|
target = self.target |
|
|
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) |
|
|
mult = 1 + proportional_error * n_steps / self.horizon |
|
|
self.kl_coef *= mult |
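
# Illustrative sketch of the control law above (values are hypothetical, not
# defaults): with target_kl=6.0 and horizon=10000, observing current_kl=7.5
# over n_steps=256 gives proportional_error = clip(7.5 / 6.0 - 1, -0.2, 0.2)
# = 0.2, so kl_coef is multiplied by 1 + 0.2 * 256 / 10000 = 1.00512. A
# current_kl below target would shrink the coefficient instead.
#
#     ctrl = AdaptiveKLController(init_kl_coef=0.2, target_kl=6.0, horizon=10000)
#     ctrl.update(current_kl=7.5, n_steps=256)
#     # ctrl.kl_coef == 0.2 * 1.00512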


class FixedKLController(KLController):
    """Fixed KL controller.

    Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72
    """

    def __init__(self, init_kl_coef: float):
        self.kl_coef = init_kl_coef

    def update(self, current_kl: float, n_steps: int) -> None:
        pass


def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController:
    """Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L319"""
    if algorithm_config.kl_type == "fixed":
        kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef)
    elif algorithm_config.kl_type == "adaptive":
        assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
        kl_ctrl = AdaptiveKLController(
            init_kl_coef=algorithm_config.kl_coef,
            target_kl=algorithm_config.kl_target,
            horizon=algorithm_config.kl_horizon,
        )
    else:
        raise ValueError(f"Unknown kl type: {algorithm_config.kl_type}.")

    return kl_ctrl
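
# Usage sketch (assumes an AlgorithmConfig carrying the kl_* fields read above):
#
#     config = AlgorithmConfig(kl_type="adaptive", kl_coef=0.2, kl_target=6.0, kl_horizon=10000)
#     kl_ctrl = get_kl_controller(config)
#     kl_ctrl.update(current_kl=measured_kl, n_steps=batch_size)  # once per rollout batch
#     penalty_coef = kl_ctrl.kl_coef  # feed into compute_rewards as kl_ratio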


@torch.no_grad()
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor,
    values: torch.Tensor,
    response_mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Adapted from https://github.com/huggingface/trl/blob/v0.16.0/trl/trainer/ppo_trainer.py#L513

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length). Tokens after the EOS token have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value used in Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    lastgaelam = 0
    advantages_reversed = []
    gen_len = token_level_rewards.shape[-1]
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        # TD residual: delta_t = r_t + gamma * V_{t+1} - V_t
        delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
        # GAE recursion: A_t = delta_t + gamma * lam * A_{t+1}
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)

    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values
    advantages = VF.masked_whiten(advantages, response_mask)
    return advantages, returns
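
# Worked sketch of the recursion (one response, gamma=lam=1.0): with rewards
# [0, 0, 1] and values [0.5, 0.5, 0.5], going right to left gives deltas
# [0, 0, 0.5] and unwhitened advantages [0.5, 0.5, 0.5]; returns = advantages
# + values = [1, 1, 1], the return-to-go at every token, before masked_whiten
# normalizes the advantages over unmasked positions.
#
#     adv, ret = compute_gae_advantage_return(
#         token_level_rewards=torch.tensor([[0.0, 0.0, 1.0]]),
#         values=torch.tensor([[0.5, 0.5, 0.5]]),
#         response_mask=torch.ones(1, 3),
#         gamma=1.0,
#         lam=1.0,
#     )  # ret == [[1.0, 1.0, 1.0]]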


@torch.no_grad()
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for GRPO, operating only on the outcome reward
    (a single scalar reward for each response).

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Hashable group identifiers; responses sharing an index were sampled
            from the same prompt and are normalized together.
        eps: `(float)`
            small constant added to the group std for numerical stability

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)
    id2score = defaultdict(list)
    id2mean, id2std = {}, {}

    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        assert len(id2score[idx]) > 1, "GRPO needs rollout.n > 1."
        id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
        id2std[idx] = torch.std(torch.tensor(id2score[idx]))

    for i in range(bsz):
        # normalize each outcome score by its group's mean and std
        scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + eps)

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
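
# Sketch: two rollouts of one prompt (shared index) with outcome rewards 1.0
# and 0.0 have group mean 0.5 and sample std ~0.7071, so their advantages are
# ~+0.707 and ~-0.707, broadcast over unmasked tokens. Indices are assumed to
# be hashable by value (e.g. a numpy array of prompt uids):
#
#     rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0]])  # per-response sums: 1.0, 0.0
#     adv, ret = compute_grpo_outcome_advantage(rewards, torch.ones(2, 2), index=np.array([0, 0]))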


@torch.no_grad()
def compute_rloo_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        index: `(torch.Tensor)`
            shape: (bs,). Hashable group identifiers; responses sharing an index were sampled
            from the same prompt.

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1)

    id2score = defaultdict(list)
    id2sum = {}
    bsz = scores.shape[0]
    for i in range(bsz):
        id2score[index[i]].append(scores[i])

    for idx in id2score:
        id2sum[idx] = torch.sum(torch.tensor(id2score[idx]))

    for i in range(bsz):
        sample_num = len(id2score[index[i]])
        assert sample_num > 1, "RLOO needs rollout.n > 1."
        # leave-one-out baseline: mean score of the other responses in the group
        baseline = (id2sum[index[i]] - scores[i]) / (sample_num - 1)
        scores[i] = scores[i] - baseline

    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
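
# Sketch of the leave-one-out baseline: for three rollouts of one prompt with
# outcome rewards [1.0, 0.0, 0.0], each baseline is the mean of the other two
# scores, i.e. [0.0, 0.5, 0.5], giving advantages [1.0, -0.5, -0.5] broadcast
# over unmasked tokens:
#
#     rewards = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
#     adv, ret = compute_rloo_outcome_advantage(rewards, torch.ones(3, 2), index=np.array([0, 0, 0]))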


@torch.no_grad()
def compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, gamma: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for REINFORCE++.

    This implementation is based on the paper: https://arxiv.org/abs/2501.03262

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        gamma: `(float)`
            discount factor used in RL

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    returns = torch.zeros_like(token_level_rewards)
    running_return = 0
    for t in reversed(range(token_level_rewards.shape[1])):
        running_return = token_level_rewards[:, t] + gamma * running_return
        returns[:, t] = running_return
        # zero the running return at masked positions so padding does not accumulate reward
        running_return = running_return * response_mask[:, t]

    advantages = VF.masked_whiten(returns, response_mask)
    return advantages, returns
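
# Sketch: with gamma=0.9 and token rewards [0, 0, 1] on a fully unmasked
# response, the discounted returns-to-go are [0.81, 0.9, 1.0]; masked_whiten
# then turns them into zero-mean, unit-variance advantages:
#
#     adv, ret = compute_reinforce_plus_plus_outcome_advantage(
#         token_level_rewards=torch.tensor([[0.0, 0.0, 1.0]]),
#         response_mask=torch.ones(1, 3),
#         gamma=0.9,
#     )  # ret == [[0.81, 0.9, 1.0]]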


@torch.no_grad()
def compute_remax_outcome_advantage(
    token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute advantage for ReMax, operating only on the outcome reward
    (a single scalar reward for each response).

    This implementation is based on the paper: https://arxiv.org/abs/2310.10505

    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        reward_baselines: `(torch.Tensor)`
            shape: (bs,)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
    scores = token_level_rewards.sum(dim=-1) - reward_baselines
    returns = scores.unsqueeze(-1) * response_mask
    return returns, returns
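
# Sketch: reward_baselines holds the outcome reward of a baseline rollout for
# the same prompt (the ReMax paper uses the greedy completion). A sampled
# response scoring 1.0 against a baseline of 0.6 yields an advantage of 0.4
# on every unmasked token:
#
#     adv, ret = compute_remax_outcome_advantage(
#         token_level_rewards=torch.tensor([[0.0, 1.0]]),
#         reward_baselines=torch.tensor([0.6]),
#         response_mask=torch.ones(1, 2),
#     )  # ret == [[0.4, 0.4]]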


def compute_rewards(
    token_level_scores: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    kl_ratio: float,
) -> torch.Tensor:
    """Subtract a per-token KL penalty, scaled by kl_ratio, from the token-level scores."""
    kl = log_probs - ref_log_probs
    return token_level_scores - kl * kl_ratio
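
# Usage sketch: kl_ratio is typically the live coefficient held by the KL
# controller above, so the penalty adapts as the actor drifts from the
# reference policy:
#
#     rewards = compute_rewards(token_level_scores, log_probs, ref_log_probs, kl_ctrl.kl_coef)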


def compute_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float,
    clip_ratio_high: float,
    clip_ratio_dual: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the policy loss.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L568

    Args:
        old_log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        log_probs: `(torch.Tensor)`
            shape: (bs, response_length)
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        clip_ratio_low: (float)
            The lower clip range used in PPO. See https://arxiv.org/abs/1707.06347
        clip_ratio_high: (float)
            The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
        clip_ratio_dual: (float)
            The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729

    Returns:
        pg_loss: `a scalar torch.Tensor`
            policy gradient loss computed via PPO
        pg_clipfrac_higher: (float)
            fraction of tokens clipped by the standard PPO (upper) clip range
        pg_clipfrac_lower: (float)
            fraction of tokens clipped by the dual (lower) clip range
        ppo_kl: (float)
            mean KL divergence between the old policy and the new policy
    """
    negative_approx_kl = log_probs - old_log_probs
    ratio = torch.exp(negative_approx_kl)
    # clip in log space before exponentiating to avoid overflow for extreme ratios
    clipped_ratio = torch.exp(
        torch.clamp(negative_approx_kl, np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high))
    )

    pg_loss = -advantages * ratio  # -A * ratio
    pg_loss2 = -advantages * clipped_ratio  # -A * clip(ratio, 1 - low, 1 + high)
    pg_loss3 = -advantages * clip_ratio_dual  # -A * c, the dual-clip bound

    clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2)  # standard PPO clipping
    pg_clipfrac_higher = (pg_loss < pg_loss2).float()
    clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3)  # dual clipping
    # the dual clip only applies to tokens with negative advantages
    final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher)
    pg_clipfrac_lower = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float()

    final_pg_loss = VF.masked_mean(final_pg_loss, response_mask)
    pg_clipfrac_higher = VF.masked_mean(pg_clipfrac_higher, response_mask)
    pg_clipfrac_lower = VF.masked_mean(pg_clipfrac_lower, response_mask)
    ppo_kl = VF.masked_mean(-negative_approx_kl, response_mask)
    return final_pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl
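
# Numeric sketch of the dual clip (hypothetical values): with advantage A = -1,
# ratio = 5.0, clip_ratio_low = clip_ratio_high = 0.2 and clip_ratio_dual = 3.0:
#   pg_loss  = -A * ratio       = 5.0
#   pg_loss2 = -A * clip(ratio) = 1.2   (ratio clipped to [0.8, 1.2])
#   max(pg_loss, pg_loss2)      = 5.0   (standard PPO keeps the pessimistic max)
#   min(5.0, -A * 3.0)          = 3.0   (dual clip caps the loss at 3.0)
# Since A < 0, the final per-token loss is 3.0: dual clipping bounds how hard a
# very off-policy token with negative advantage can be pushed down.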


def compute_value_loss(
    vpreds: torch.Tensor,
    returns: torch.Tensor,
    values: torch.Tensor,
    action_mask: torch.Tensor,
    cliprange_value: float,
) -> Tuple[torch.Tensor, float]:
    """Compute the value loss.

    Adapted from https://github.com/huggingface/trl/blob/v0.15.0/trl/trainer/ppo_trainer.py#L556

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of the value head, shape (`batch_size`, `response_length`)
        action_mask (`torch.Tensor`):
            shape: (bs, response_length)
        cliprange_value: (float)
            The clip range for the value net used in PPO. See https://arxiv.org/abs/1707.06347

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`):
            value function loss
        vf_clipfrac: a float
            fraction of tokens where the clipped value loss was used
    """
    # keep new value predictions within cliprange_value of the old values
    vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value)
    vf_loss1 = torch.square(vpreds - returns)
    vf_loss2 = torch.square(vpredclipped - returns)
    # pessimistic (elementwise max) of the clipped and unclipped squared errors
    vf_loss = 0.5 * VF.masked_mean(torch.max(vf_loss1, vf_loss2), action_mask)
    vf_clipfrac = VF.masked_mean((vf_loss1 < vf_loss2).float(), action_mask)
    return vf_loss, vf_clipfrac
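
# Numeric sketch of value clipping (hypothetical values): with old value 1.0,
# cliprange_value 0.2, new prediction 1.5 and return 2.0, the clipped
# prediction is 1.2, so vf_loss1 = 0.25 and vf_loss2 = 0.64; the pessimistic
# per-token loss is 0.5 * 0.64 = 0.32, and the token counts toward
# vf_clipfrac because the clipped branch dominates.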


def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor:
    """Compute KL divergence given log_probs and ref_log_probs.

    Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150

    Args:
        log_probs: torch.Tensor
            log probabilities under the actor policy
        ref_log_probs: torch.Tensor
            log probabilities under the reference policy
        kl_penalty: str
            one of "kl", "abs", "mse", "low_var_kl", or "full"

    Returns:
        kl_div: torch.Tensor
    """
    log_probs, ref_log_probs = log_probs.float(), ref_log_probs.float()
    if kl_penalty == "kl":
        return log_probs - ref_log_probs

    if kl_penalty == "abs":
        return (log_probs - ref_log_probs).abs()

    if kl_penalty == "mse":
        return 0.5 * (log_probs - ref_log_probs).square()

    if kl_penalty == "low_var_kl":
        # k3 estimator (http://joschu.net/blog/kl-approx.html): exp(x) - x - 1 with x = ref - actor
        kl = ref_log_probs - log_probs
        kld = (kl.exp() - kl - 1).contiguous()
        return torch.clamp(kld, min=-10, max=10)

    if kl_penalty == "full":
        # full KL over the last dimension; expects full log distributions, not per-token log probs
        return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1)

    raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.")
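
# Sketch of the estimators on a single token with log_probs - ref_log_probs = 0.5:
#   "kl"         -> 0.5
#   "abs"        -> 0.5
#   "mse"        -> 0.5 * 0.5**2        = 0.125
#   "low_var_kl" -> exp(-0.5) + 0.5 - 1 ~ 0.1065  (k3; non-negative, low variance)
# "full" differs from the rest: it expects full log distributions over the
# vocabulary in the last dimension rather than per-token log probabilities.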