JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
# Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py
from argparse import Namespace
import torch
import torch.distributed as dist
import torch.nn.functional as F
@torch.compile(dynamic=True)
def compute_approx_kl(
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
kl_loss_type: str,
importance_ratio: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
Args:
log_probs: Log probabilities of the new distribution.
log_probs_base: Log probabilities of the base distribution.
kl_loss_type: Type of KL estimator (k1, k2, k3, low_var_kl).
importance_ratio: Optional IS ratio (π_θ/π_old) for unbiased KL estimation.
"""
log_ratio = log_probs.float() - log_probs_base.float()
if kl_loss_type == "k1":
kl = log_ratio
elif kl_loss_type == "k2":
kl = log_ratio**2 / 2.0
elif kl_loss_type in ["k3", "low_var_kl"]:
# The non negative kl approximation in
# http://joschu.net/blog/kl-approx.html
# Besides non negative, it is also unbiased and have lower variance.
log_ratio = -log_ratio
kl = log_ratio.exp() - 1 - log_ratio
else:
raise ValueError(f"Unknown kl_loss_type: {kl_loss_type}")
# Apply IS ratio for unbiased KL estimation (DeepSeek-V3.2)
if importance_ratio is not None:
kl = importance_ratio * kl
# Clamp only for low_var_kl for numerical stability
if kl_loss_type == "low_var_kl":
kl = torch.clamp(kl, min=-10, max=10)
return kl
def compute_opsm_mask(
args: Namespace,
full_log_probs: list[torch.Tensor],
full_old_log_probs: list[torch.Tensor],
advantages: list[torch.Tensor],
loss_masks: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute Off-Policy Sequence Masking (OPSM) mask.
Args:
args: Configuration containing `opsm_delta` threshold.
full_log_probs: Current policy log-probs per sample.
full_old_log_probs: Old policy log-probs per sample.
advantages: Advantage values per sample.
loss_masks: Loss masks per sample.
Returns:
Tuple of `(opsm_mask, opsm_clipfrac)` where `opsm_mask` is a
concatenated tensor of per-token masks and
`opsm_clipfrac` is the count of masked sequences.
"""
opsm_mask_list = []
device = advantages[0].device
opsm_clipfrac = torch.tensor(0.0, device=device)
for full_log_prob, full_old_log_prob, advantage, loss_mask in zip(
full_log_probs, full_old_log_probs, advantages, loss_masks, strict=False
):
# Calculate sequence-level KL
seq_kl = ((full_old_log_prob - full_log_prob) * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
# Create mask: 0 if (advantage < 0 and seq_kl > delta), else 1
mask = ((advantage < 0) & (seq_kl > args.opsm_delta)).float()
opsm_clipfrac += mask.sum() / torch.clamp_min(loss_mask.sum(), 1)
opsm_mask_list.append(1 - mask)
opsm_mask = torch.cat(opsm_mask_list, dim=0)
return opsm_mask, opsm_clipfrac
def compute_gspo_kl(
full_log_probs: list[torch.Tensor],
full_old_log_probs: list[torch.Tensor],
local_log_probs: list[torch.Tensor],
loss_masks: list[torch.Tensor],
) -> torch.Tensor:
"""Compute GSPO-style per-sequence KL divergence.
Args:
full_log_probs: Current policy log-probs per sample (full or CP-local).
full_old_log_probs: Old policy log-probs per sample (full or CP-local).
local_log_probs: Local (CP-local) log-probs for expansion shape reference.
loss_masks: Loss masks per sample.
Returns:
Concatenated tensor of per-token KL values where each token in a
sequence has the same KL value (the sequence-level KL).
"""
# Compute sequence-level KL and expand to per-token
ppo_kl = [
((old_logprob - log_prob) * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
for log_prob, old_logprob, loss_mask in zip(full_log_probs, full_old_log_probs, loss_masks, strict=False)
]
ppo_kl = [kl.expand_as(log_prob) for kl, log_prob in zip(ppo_kl, local_log_probs, strict=False)]
ppo_kl = torch.cat(ppo_kl, dim=0)
return ppo_kl
@torch.compile(dynamic=True)
def compute_policy_loss(
ppo_kl: torch.Tensor,
advantages: torch.Tensor,
eps_clip: float,
eps_clip_high: float,
eps_clip_c: float | None = None,
):
ratio = (-ppo_kl).exp()
pg_losses1 = -ratio * advantages
pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
clipfrac = torch.gt(pg_losses2, pg_losses1).float()
if eps_clip_c is not None:
assert (
eps_clip_c > 1.0
), f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {eps_clip_c}."
pg_losses3 = -eps_clip_c * advantages
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
else:
pg_losses = clip_pg_losses1
return pg_losses, clipfrac
def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
# TODO: when megatron is not installed, fall back to naive implementation
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
# convert to [seq_len, batch_size, vocab_size] as expected by fused_vocab_parallel_cross_entropy
logits = logits.unsqueeze(1)
tokens = tokens.unsqueeze(1)
return -fused_vocab_parallel_cross_entropy(logits, tokens, process_group)
# from https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/verl/utils/megatron/tensor_parallel.py#L99
class _VocabParallelEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor:
@torch.compile(dynamic=True)
def mul_reduce(a, b):
return (a * b).sum(dim=-1, keepdim=True)
logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
dist.all_reduce(normalized_sum_exp_logits, group=process_group)
softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)
sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
dist.all_reduce(sum_softmax_times_logits, group=process_group)
entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
return entropy.squeeze(dim=-1)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
# reuse softmax_logits as grad
vocab_parallel_logits.sub_(sum_softmax_times_logits)
softmax_logits.mul_(vocab_parallel_logits)
softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
# recover vocab_parallel_logits
vocab_parallel_logits.add_(sum_softmax_times_logits)
softmax_logits.mul_(-1)
return softmax_logits, None
def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Tensor:
return _VocabParallelEntropy.apply(logits, process_group)
def get_grpo_returns(
rewards: torch.Tensor,
kl: list[torch.Tensor],
):
returns = []
for i in range(len(rewards)):
returns.append(torch.ones_like(kl[i]) * rewards[i])
return returns
def get_reinforce_plus_plus_returns(
rewards: torch.Tensor,
kl: list[torch.Tensor],
loss_masks: list[torch.Tensor],
response_lengths: list[int],
total_lengths: list[int],
kl_coef: float,
gamma: float,
) -> list[torch.Tensor]:
"""
Calculates discounted returns for REINFORCE++ (https://arxiv.org/pdf/2501.03262)
Args:
rewards (Tensor): A tensor of scalar rewards for each sequence.
kl (List[Tensor]): List of per-token KL divergence tensors for sequence chunks.
loss_masks (List[Tensor]): List of response-only loss masks for each full sequence.
response_lengths (List[int]): The full length of each response sequence.
total_lengths (List[int]): The full length of each sequence (prompt + response).
kl_coef (float): Coefficient for the KL penalty.
gamma (float): The discount factor.
Returns:
List[torch.Tensor]: A list of return (G_t) tensors for the
local sequence chunks owned by the current GPU rank.
"""
from megatron.core import mpu
cp_size = mpu.get_context_parallel_world_size()
final_returns_chunks = []
for i in range(len(rewards)):
local_kl_chunk = kl[i]
total_len, response_len = total_lengths[i], response_lengths[i]
if cp_size > 1:
# Step 1,2:Gather all chunks and token_offsets from all ranks and reconstruct the full response tensor by splitting and placing each part
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp
full_kl_response = all_gather_with_cp(local_kl_chunk, total_len, response_len)
else:
full_kl_response = local_kl_chunk
# Step 3: Compute returns on full response kl tensor.
full_mask = loss_masks[i]
assert full_mask.sum().item() > 0, f"Sequence at index {i} is fully masked."
masked_kl = full_kl_response * full_mask
token_level_rewards = -kl_coef * masked_kl
last_idx = full_mask.nonzero(as_tuple=True)[0][-1]
token_level_rewards[last_idx] += rewards[i]
returns_for_seq = torch.zeros_like(token_level_rewards)
running_return = 0.0
for t in reversed(range(token_level_rewards.size(0))):
# G_t = r_t + gamma * G_{t+1}
running_return = token_level_rewards[t] + gamma * running_return
returns_for_seq[t] = running_return
# Step 4: Pick up the results corresponding to our local chunk's parts.
if cp_size > 1:
from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp
local_returns_chunk = slice_log_prob_with_cp(returns_for_seq, total_len, response_len)
else:
local_returns_chunk = returns_for_seq
final_returns_chunks.append(local_returns_chunk)
return final_returns_chunks
def get_reinforce_plus_plus_baseline_advantages(
rewards: torch.Tensor,
kl: list[torch.Tensor],
loss_masks: list[torch.Tensor],
kl_coef: float,
) -> list[torch.Tensor]:
"""
Calculates the unwhitened advantages for the REINFORCE++-baseline algorithm.
Broadcasting the scalar (reward - group_baseline) to each token.
Args:
rewards (Tensor): A tensor of scalar rewards, where the group-wise
baseline has already been subtracted.
kl (list[Tensor]): A list of per-token KL divergence tensors. Used to
get the shape for broadcasting.
loss_masks (list[Tensor]): A list of per-token loss masks.
kl_coef (float): Coefficient for the KL penalty.
Returns:
list[Tensor]: A list of tensors containing the unwhitened advantages.
"""
# Broadcast to get unwhitened advantages
unwhitened_advantages = [
torch.ones_like(kl_tensor) * reward_val - kl_coef * kl_tensor
for kl_tensor, reward_val in zip(kl, rewards, strict=False)
]
return unwhitened_advantages
def get_advantages_and_returns(
total_len: int,
response_len: int,
values: torch.Tensor,
rewards: torch.Tensor,
gamma: float,
lambd: float,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function that computes advantages and returns from rewards and values.
Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
Note that rewards may include a KL divergence loss term.
Advantages looks like this:
Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
- V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
Returns looks like this:
Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
+ γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
Input:
- values: Tensor of shape (response_size,)
- rewards: Tensor of shape (response_size,)
Output:
- advantages: Tensor of shape (response_size,)
- returns: Tensor of shape (response_size,)
"""
from megatron.core import mpu
cp_size = mpu.get_context_parallel_world_size()
if cp_size > 1:
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp
full_rewards = all_gather_with_cp(rewards, total_len, response_len)
full_values = all_gather_with_cp(values, total_len, response_len)
else:
full_rewards = rewards
full_values = values
lastgaelam = 0
advantages_reversed = []
for t in reversed(range(response_len)):
nextvalues = full_values[t + 1] if t < response_len - 1 else 0.0
delta = full_rewards[t] + gamma * nextvalues - full_values[t]
lastgaelam = delta + gamma * lambd * lastgaelam
advantages_reversed.append(lastgaelam)
full_advantages = torch.tensor(advantages_reversed[::-1], dtype=full_values.dtype, device=full_values.device)
full_returns = full_advantages + full_values
if cp_size > 1:
from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp
advantages = slice_log_prob_with_cp(full_advantages, total_len, response_len)
returns = slice_log_prob_with_cp(full_returns, total_len, response_len)
else:
advantages = full_advantages
returns = full_returns
return advantages.detach(), returns
def get_advantages_and_returns_batch(
total_lengths,
response_lengths,
values_list,
rewards_list,
gamma,
lambd,
chunked: bool = True,
):
"""
Batched GAE with CP support.
Input:
total_lengths: list[int], each sample's total_len
response_lengths: list[int], each sample's response_len
values_list: list[Tensor], each shape = [resp_len_i]
rewards_list: list[Tensor], same shape
Output:
advantages_list: list[Tensor], each shape = [resp_len_i]
returns_list: list[Tensor], same shape
"""
from megatron.core import mpu
with torch.no_grad():
B = len(response_lengths)
assert B == len(values_list)
assert B == len(rewards_list)
cp_size = mpu.get_context_parallel_world_size()
device = values_list[0].device
dtype = values_list[0].dtype
if cp_size > 1:
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp
full_values_list = []
full_rewards_list = []
for total_len, resp_len, v, r in zip(
total_lengths, response_lengths, values_list, rewards_list, strict=False
):
full_v = all_gather_with_cp(v, total_len, resp_len)
full_r = all_gather_with_cp(r, total_len, resp_len)
full_values_list.append(full_v)
full_rewards_list.append(full_r)
# full_values_list[i].shape = [total_len_i]
else:
full_values_list = values_list
full_rewards_list = rewards_list
# pad to max_len for batched GAE
max_len = max(response_lengths)
full_values = torch.zeros(B, max_len, device=device, dtype=dtype)
full_rewards = torch.zeros(B, max_len, device=device, dtype=dtype)
for i in range(B):
L = response_lengths[i]
full_values[i, :L] = full_values_list[i][:L]
full_rewards[i, :L] = full_rewards_list[i][:L]
if not chunked:
full_advantages, full_returns = vanilla_gae(
rewards=full_rewards,
values=full_values,
gamma=gamma,
lambd=lambd,
)
else:
full_advantages, full_returns = chunked_gae(
rewards=full_rewards,
values=full_values,
gamma=gamma,
lambd=lambd,
)
advantages_list = []
returns_list = []
if cp_size > 1:
from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp
for total_len, resp_len, adv_row, ret_row in zip(
total_lengths,
response_lengths,
full_advantages,
full_returns,
strict=False,
):
adv_full = adv_row # shape = [resp_len_i padded to max_len]
ret_full = ret_row
adv_sliced = slice_log_prob_with_cp(adv_full[:resp_len], total_len, resp_len)
ret_sliced = slice_log_prob_with_cp(ret_full[:resp_len], total_len, resp_len)
advantages_list.append(adv_sliced)
returns_list.append(ret_sliced)
else:
for i in range(B):
L = response_lengths[i]
advantages_list.append(full_advantages[i, :L])
returns_list.append(full_returns[i, :L])
return advantages_list, returns_list
def vanilla_gae(
rewards: torch.Tensor,
values: torch.Tensor,
gamma: float,
lambd: float,
):
B, T = rewards.shape
device = rewards.device
dtype = rewards.dtype
lastgaelam = torch.zeros(B, device=device, dtype=dtype)
adv_rev = []
for t in reversed(range(T)):
next_value = values[:, t + 1] if t < T - 1 else 0.0
delta = rewards[:, t] + gamma * next_value - values[:, t]
lastgaelam = delta + gamma * lambd * lastgaelam
adv_rev.append(lastgaelam)
full_advantages = torch.stack(adv_rev[::-1], dim=1) # [B, max_len]
full_returns = full_advantages + values # [B, max_len]
return full_advantages, full_returns
def chunked_gae(
rewards: torch.Tensor,
values: torch.Tensor,
gamma: float,
lambd: float,
chunk_size: int = 128,
):
"""
Compute Generalized Advantage Estimation (GAE) using a FlashLinearAttention-
inspired algorithm: parallel prefix scan within chunks and recurrent state
propagation across chunks.
This reduces the sequential dependency length from O(T) to O(T / chunk_size),
while keeping chunk computations fully parallelizable (O(C^2) per chunk).
Args:
rewards (Tensor): [B, T] reward sequence.
values (Tensor): [B, T] value predictions. The next-value of the final
step is assumed to be zero (standard PPO convention).
gamma (float): discount factor.
lam (float): GAE lambda.
chunk_size (int): sequence chunk length for parallel scan.
Returns:
advantages (Tensor): [B, T] computed advantages.
returns (Tensor): [B, T] advantages + values.
"""
# -------------------------------------------------------------------------
# Validate inputs
# -------------------------------------------------------------------------
assert rewards.ndim == 2 and values.ndim == 2
B, T = rewards.shape
assert values.shape == (B, T)
device = rewards.device
dtype = rewards.dtype
# -------------------------------------------------------------------------
# Build δ_t = r_t + γ * V_{t+1} - V_t with V_{T} = 0
# -------------------------------------------------------------------------
next_values = torch.cat(
[values[:, 1:], torch.zeros(B, 1, device=device, dtype=dtype)],
dim=1,
)
deltas = rewards + gamma * next_values - values
# Reformulate backward GAE as a forward scan on the reversed sequence:
# S[i] = Δ[i] + w * S[i - 1], w = γλ
w = gamma * lambd
deltas_rev = torch.flip(deltas, dims=[1]) # [B, T]
# -------------------------------------------------------------------------
# Pad to a multiple of chunk_size
# -------------------------------------------------------------------------
if T % chunk_size != 0:
pad = chunk_size - (T % chunk_size)
deltas_rev = F.pad(deltas_rev, (0, pad))
else:
pad = 0
B, T_pad = deltas_rev.shape
n_chunks = T_pad // chunk_size
deltas_chunks = deltas_rev.view(B, n_chunks, chunk_size)
# -------------------------------------------------------------------------
# Construct the intra-chunk parallel scan kernel M
#
# For a chunk Δ[0..C-1], we want:
# S_local[t] = sum_{k=0..t} w^(t-k) * Δ[k]
#
# This is implemented as:
# S_local = Δ @ M
#
# where:
# M[i, j] = w^(j - i) if j >= i
# 0 otherwise
# -------------------------------------------------------------------------
idx = torch.arange(chunk_size, device=device)
row = idx[:, None]
col = idx[None, :]
diff = col - row
M = torch.zeros(chunk_size, chunk_size, device=device, dtype=dtype)
mask = diff >= 0
if w == 0.0:
M[mask & (diff == 0)] = 1.0
else:
M[mask] = w ** diff[mask].to(dtype)
# pow_vec[t] = w^(t+1), used to inject the recurrent state s_prev
if w == 0.0:
pow_vec = torch.zeros(chunk_size, device=device, dtype=dtype)
else:
pow_vec = w ** torch.arange(1, chunk_size + 1, device=device, dtype=dtype)
# -------------------------------------------------------------------------
# Parallel compute local chunk results (assuming initial state = 0)
# -------------------------------------------------------------------------
deltas_flat = deltas_chunks.reshape(B * n_chunks, chunk_size)
S_local_flat = deltas_flat @ M
S_local_chunks = S_local_flat.view(B, n_chunks, chunk_size)
# Effective length of each chunk (the last chunk may be padded)
lengths = [chunk_size] * n_chunks
if pad > 0:
lengths[-1] = chunk_size - pad
# -------------------------------------------------------------------------
# Recurrent propagation between chunks
#
# Each chunk contributes:
# S_global[t] = S_local[t] + w^(t+1) * s_prev
#
# And updates:
# s_prev = S_global[last_t]
# -------------------------------------------------------------------------
S_rev = deltas_rev.new_zeros(B, T_pad)
s_prev = torch.zeros(B, device=device, dtype=dtype)
for c in range(n_chunks):
Lc = lengths[c]
start = c * chunk_size
end = start + Lc
S_local = S_local_chunks[:, c, :Lc]
S_global = S_local + s_prev.unsqueeze(1) * pow_vec[:Lc]
S_rev[:, start:end] = S_global
s_prev = S_global[:, -1] # state for next chunk
# Remove padding and flip back to original time order
if pad > 0:
S_rev = S_rev[:, :T]
advantages = torch.flip(S_rev, dims=[1])
returns = advantages + values
return advantages, returns
def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1):
logits = logits.contiguous()
# TODO: not sure why we need to clone the logits here.
# Without the clone, the backward will trigger inplace edit error.
# It seems that the function with tp will modify the logits inplace.
entropy = None
if logits.size(0) != 0:
if chunk_size > 0:
num_chunks = (logits.size(0) - 1) // chunk_size + 1
tokens_chunks = tokens.chunk(num_chunks, dim=0)
logits_chunks = logits.chunk(num_chunks, dim=0)
log_probs = []
for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
log_prob = compute_log_probs(logits_chunk.clone(), tokens_chunk, tp_group)
log_probs.append(log_prob)
log_prob = torch.cat(log_probs, dim=0)
if with_entropy:
entropys = []
for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group)
entropys.append(entropy)
entropy = torch.cat(entropys, dim=0)
else:
log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
if with_entropy:
entropy = compute_entropy_from_logits(logits.clone(), tp_group)
else:
log_prob = logits.new_zeros((0,))
if with_entropy:
entropy = logits.new_zeros((0,))
return log_prob, entropy