File size: 25,404 Bytes
d7b3a74 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 | # Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py
from argparse import Namespace
import torch
import torch.distributed as dist
import torch.nn.functional as F
@torch.compile(dynamic=True)
def compute_approx_kl(
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
kl_loss_type: str,
importance_ratio: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
Args:
log_probs: Log probabilities of the new distribution.
log_probs_base: Log probabilities of the base distribution.
kl_loss_type: Type of KL estimator (k1, k2, k3, low_var_kl).
importance_ratio: Optional IS ratio (蟺_胃/蟺_old) for unbiased KL estimation.
"""
log_ratio = log_probs.float() - log_probs_base.float()
if kl_loss_type == "k1":
kl = log_ratio
elif kl_loss_type == "k2":
kl = log_ratio**2 / 2.0
elif kl_loss_type in ["k3", "low_var_kl"]:
# The non negative kl approximation in
# http://joschu.net/blog/kl-approx.html
# Besides non negative, it is also unbiased and have lower variance.
log_ratio = -log_ratio
kl = log_ratio.exp() - 1 - log_ratio
else:
raise ValueError(f"Unknown kl_loss_type: {kl_loss_type}")
# Apply IS ratio for unbiased KL estimation (DeepSeek-V3.2)
if importance_ratio is not None:
kl = importance_ratio * kl
# Clamp only for low_var_kl for numerical stability
if kl_loss_type == "low_var_kl":
kl = torch.clamp(kl, min=-10, max=10)
return kl
def compute_opsm_mask(
args: Namespace,
full_log_probs: list[torch.Tensor],
full_old_log_probs: list[torch.Tensor],
advantages: list[torch.Tensor],
loss_masks: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute Off-Policy Sequence Masking (OPSM) mask.
Args:
args: Configuration containing `opsm_delta` threshold.
full_log_probs: Current policy log-probs per sample.
full_old_log_probs: Old policy log-probs per sample.
advantages: Advantage values per sample.
loss_masks: Loss masks per sample.
Returns:
Tuple of `(opsm_mask, opsm_clipfrac)` where `opsm_mask` is a
concatenated tensor of per-token masks and
`opsm_clipfrac` is the count of masked sequences.
"""
opsm_mask_list = []
device = advantages[0].device
opsm_clipfrac = torch.tensor(0.0, device=device)
for full_log_prob, full_old_log_prob, advantage, loss_mask in zip(
full_log_probs, full_old_log_probs, advantages, loss_masks, strict=False
):
# Calculate sequence-level KL
seq_kl = ((full_old_log_prob - full_log_prob) * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
# Create mask: 0 if (advantage < 0 and seq_kl > delta), else 1
mask = ((advantage < 0) & (seq_kl > args.opsm_delta)).float()
opsm_clipfrac += mask.sum() / torch.clamp_min(loss_mask.sum(), 1)
opsm_mask_list.append(1 - mask)
opsm_mask = torch.cat(opsm_mask_list, dim=0)
return opsm_mask, opsm_clipfrac
def compute_gspo_kl(
full_log_probs: list[torch.Tensor],
full_old_log_probs: list[torch.Tensor],
local_log_probs: list[torch.Tensor],
loss_masks: list[torch.Tensor],
) -> torch.Tensor:
"""Compute GSPO-style per-sequence KL divergence.
Args:
full_log_probs: Current policy log-probs per sample (full or CP-local).
full_old_log_probs: Old policy log-probs per sample (full or CP-local).
local_log_probs: Local (CP-local) log-probs for expansion shape reference.
loss_masks: Loss masks per sample.
Returns:
Concatenated tensor of per-token KL values where each token in a
sequence has the same KL value (the sequence-level KL).
"""
# Compute sequence-level KL and expand to per-token
ppo_kl = [
((old_logprob - log_prob) * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
for log_prob, old_logprob, loss_mask in zip(full_log_probs, full_old_log_probs, loss_masks, strict=False)
]
ppo_kl = [kl.expand_as(log_prob) for kl, log_prob in zip(ppo_kl, local_log_probs, strict=False)]
ppo_kl = torch.cat(ppo_kl, dim=0)
return ppo_kl
@torch.compile(dynamic=True)
def compute_policy_loss(
ppo_kl: torch.Tensor,
advantages: torch.Tensor,
eps_clip: float,
eps_clip_high: float,
eps_clip_c: float | None = None,
):
ratio = (-ppo_kl).exp()
pg_losses1 = -ratio * advantages
pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
clipfrac = torch.gt(pg_losses2, pg_losses1).float()
if eps_clip_c is not None:
assert (
eps_clip_c > 1.0
), f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {eps_clip_c}."
pg_losses3 = -eps_clip_c * advantages
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
else:
pg_losses = clip_pg_losses1
return pg_losses, clipfrac
def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
    """Return the negated Megatron fused vocab-parallel cross entropy.

    Args:
        logits: Logits tensor; a singleton axis is inserted at dim 1 before
            the fused kernel is called.
        tokens: Target token ids; a singleton axis is inserted at dim 1 too.
        process_group: Process group forwarded to the fused kernel.

    Returns:
        ``-fused_vocab_parallel_cross_entropy(...)`` for the reshaped inputs.
    """
    # TODO: when megatron is not installed, fall back to naive implementation
    from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy

    # The fused kernel expects logits as [seq_len, batch_size, vocab_size],
    # so add a batch axis of size one to both inputs.
    batched_logits = logits.unsqueeze(1)
    batched_tokens = tokens.unsqueeze(1)
    cross_entropy = fused_vocab_parallel_cross_entropy(batched_logits, batched_tokens, process_group)
    return -cross_entropy
# from https://github.com/volcengine/verl/blob/0bdf7f469854815177e73dcfe9e420836c952e6e/verl/utils/megatron/tensor_parallel.py#L99
class _VocabParallelEntropy(torch.autograd.Function):
    """Autograd function computing softmax entropy over vocab-sharded logits.

    The last dimension of ``vocab_parallel_logits`` is treated as the local
    vocabulary shard; partial reductions over it are combined across
    ``process_group`` with ``dist.all_reduce``.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits: torch.Tensor, process_group: dist.ProcessGroup) -> torch.Tensor:
        """Return ``logsumexp(logits) - sum(softmax * logits)`` with the last dim squeezed."""

        @torch.compile(dynamic=True)
        def mul_reduce(a, b):
            # Element-wise product reduced over the last dim (kept as size 1).
            return (a * b).sum(dim=-1, keepdim=True)

        # Max over the local shard, then MAX-reduced over the group, used to
        # shift the logits before exponentiation.
        logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
        # The subtraction allocates a new tensor; the following exp_/div_ reuse
        # that buffer in place, so the input logits are not modified here.
        normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
        # Softmax denominator: local partial sum, then reduced over the group.
        normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
        dist.all_reduce(normalized_sum_exp_logits, group=process_group)
        softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits)
        # sum(softmax * logits): local partial sum, then reduced over the group.
        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
        dist.all_reduce(sum_softmax_times_logits, group=process_group)
        # entropy = max + log(sum exp(logits - max)) - sum(softmax * logits)
        entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
        ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
        return entropy.squeeze(dim=-1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
        """Return ``-softmax * (logits - sum(softmax * logits)) * grad_output`` for the logits.

        The second returned value is ``None`` because ``process_group`` takes
        no gradient. The saved ``softmax_logits`` buffer is overwritten in
        place and returned as the gradient.
        """
        vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
        # reuse softmax_logits as grad
        # Temporarily center the saved logits in place (undone below).
        vocab_parallel_logits.sub_(sum_softmax_times_logits)
        softmax_logits.mul_(vocab_parallel_logits)
        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
        # recover vocab_parallel_logits
        vocab_parallel_logits.add_(sum_softmax_times_logits)
        # Apply the leading minus sign of the gradient.
        softmax_logits.mul_(-1)
        return softmax_logits, None
def compute_entropy_from_logits(logits: torch.Tensor, process_group) -> torch.Tensor:
    """Apply ``_VocabParallelEntropy`` to ``logits`` using ``process_group``.

    Args:
        logits: Logits tensor forwarded to the autograd function.
        process_group: Process group forwarded to the autograd function.

    Returns:
        The tensor produced by ``_VocabParallelEntropy.apply``.
    """
    entropy = _VocabParallelEntropy.apply(logits, process_group)
    return entropy
def get_grpo_returns(
    rewards: torch.Tensor,
    kl: list[torch.Tensor],
):
    """Broadcast each sample's reward to the shape of its KL tensor.

    Args:
        rewards: One reward per sample.
        kl: Per-sample tensors; only their shape/dtype/device are used, as
            the template for the broadcast.

    Returns:
        A list with one tensor per reward, shaped like ``kl[i]`` and filled
        with ``rewards[i]``.
    """
    return [torch.ones_like(kl[idx]) * reward for idx, reward in enumerate(rewards)]
def get_reinforce_plus_plus_returns(
    rewards: torch.Tensor,
    kl: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    response_lengths: list[int],
    total_lengths: list[int],
    kl_coef: float,
    gamma: float,
) -> list[torch.Tensor]:
    """Compute discounted returns for REINFORCE++ (https://arxiv.org/pdf/2501.03262).

    Per sample, the token-level reward is ``-kl_coef * kl * loss_mask``, with
    the scalar sequence reward added at the last position where the loss mask
    is non-zero. Returns follow ``G_t = r_t + gamma * G_{t+1}``.

    Args:
        rewards: Scalar reward for each sequence.
        kl: Per-token KL tensors; with context parallelism these are the
            local chunks and are gathered into the full response first.
        loss_masks: Response loss mask for each full sequence.
        response_lengths: Response length of each sequence, forwarded to the
            context-parallel gather/slice helpers.
        total_lengths: Total length of each sequence (prompt + response),
            forwarded to the context-parallel gather/slice helpers.
        kl_coef: Coefficient of the KL penalty.
        gamma: Discount factor.

    Returns:
        One return tensor per sequence; with context parallelism, only the
        slice produced by ``slice_log_prob_with_cp`` for this rank.

    Raises:
        AssertionError: If a sequence's loss mask sums to zero.
    """
    from megatron.core import mpu

    cp_size = mpu.get_context_parallel_world_size()
    context_parallel = cp_size > 1

    returns_per_sample = []
    for i, seq_reward in enumerate(rewards):
        kl_chunk = kl[i]
        total_len, response_len = total_lengths[i], response_lengths[i]

        # Steps 1-2: with context parallelism, rebuild the full response KL
        # from the chunks held by all ranks.
        if context_parallel:
            from slime.backends.megatron_utils.cp_utils import all_gather_with_cp

            response_kl = all_gather_with_cp(kl_chunk, total_len, response_len)
        else:
            response_kl = kl_chunk

        # Step 3: token-level rewards and their discounted suffix sums.
        response_mask = loss_masks[i]
        assert response_mask.sum().item() > 0, f"Sequence at index {i} is fully masked."

        step_rewards = -kl_coef * (response_kl * response_mask)
        # The sequence reward is credited to the last unmasked token.
        final_pos = response_mask.nonzero(as_tuple=True)[0][-1]
        step_rewards[final_pos] += seq_reward

        discounted = torch.zeros_like(step_rewards)
        carry = 0.0
        for t in range(step_rewards.size(0) - 1, -1, -1):
            # G_t = r_t + gamma * G_{t+1}
            carry = step_rewards[t] + gamma * carry
            discounted[t] = carry

        # Step 4: keep only the part that belongs to this rank's chunk.
        if context_parallel:
            from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp

            local_returns = slice_log_prob_with_cp(discounted, total_len, response_len)
        else:
            local_returns = discounted

        returns_per_sample.append(local_returns)

    return returns_per_sample
def get_reinforce_plus_plus_baseline_advantages(
rewards: torch.Tensor,
kl: list[torch.Tensor],
loss_masks: list[torch.Tensor],
kl_coef: float,
) -> list[torch.Tensor]:
"""
Calculates the unwhitened advantages for the REINFORCE++-baseline algorithm.
Broadcasting the scalar (reward - group_baseline) to each token.
Args:
rewards (Tensor): A tensor of scalar rewards, where the group-wise
baseline has already been subtracted.
kl (list[Tensor]): A list of per-token KL divergence tensors. Used to
get the shape for broadcasting.
loss_masks (list[Tensor]): A list of per-token loss masks.
kl_coef (float): Coefficient for the KL penalty.
Returns:
list[Tensor]: A list of tensors containing the unwhitened advantages.
"""
# Broadcast to get unwhitened advantages
unwhitened_advantages = [
torch.ones_like(kl_tensor) * reward_val - kl_coef * kl_tensor
for kl_tensor, reward_val in zip(kl, rewards, strict=False)
]
return unwhitened_advantages
def get_advantages_and_returns(
    total_len: int,
    response_len: int,
    values: torch.Tensor,
    rewards: torch.Tensor,
    gamma: float,
    lambd: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and returns from rewards and values.

    Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
    Note that rewards may include a KL divergence loss term.

    Advantages look like this:
        Adv1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
              - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

    Returns look like this:
        Ret1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
                   + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

    With context parallelism (CP size > 1) ``values`` and ``rewards`` are
    first gathered with ``all_gather_with_cp`` and the results are sliced
    back with ``slice_log_prob_with_cp``.

    Args:
        total_len: Total sequence length, forwarded to the CP helpers.
        response_len: Number of response steps to iterate over.
        values: Tensor of shape (response_size,).
        rewards: Tensor of shape (response_size,).
        gamma: Discount factor.
        lambd: GAE lambda.

    Returns:
        Tuple ``(advantages, returns)``, each of shape (response_size,);
        ``advantages`` is detached.
    """
    from megatron.core import mpu

    cp_size = mpu.get_context_parallel_world_size()
    context_parallel = cp_size > 1

    if context_parallel:
        from slime.backends.megatron_utils.cp_utils import all_gather_with_cp

        full_rewards = all_gather_with_cp(rewards, total_len, response_len)
        full_values = all_gather_with_cp(values, total_len, response_len)
    else:
        full_rewards, full_values = rewards, values

    # Backward recursion: A_t = delta_t + gamma * lambd * A_{t+1}.
    gae = 0
    gae_trace = []
    final_step = response_len - 1
    for t in range(final_step, -1, -1):
        # The value after the final step is taken to be zero.
        bootstrap = full_values[t + 1] if t < final_step else 0.0
        td_error = full_rewards[t] + gamma * bootstrap - full_values[t]
        gae = td_error + gamma * lambd * gae
        gae_trace.append(gae)
    # The trace was collected from the last step to the first; restore time order.
    gae_trace.reverse()

    full_advantages = torch.tensor(gae_trace, dtype=full_values.dtype, device=full_values.device)
    full_returns = full_advantages + full_values

    if context_parallel:
        from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp

        advantages = slice_log_prob_with_cp(full_advantages, total_len, response_len)
        returns = slice_log_prob_with_cp(full_returns, total_len, response_len)
    else:
        advantages, returns = full_advantages, full_returns

    return advantages.detach(), returns
def get_advantages_and_returns_batch(
    total_lengths,
    response_lengths,
    values_list,
    rewards_list,
    gamma,
    lambd,
    chunked: bool = True,
):
    """Batched GAE with context-parallel (CP) support.

    Samples are zero-padded to the longest response, processed together by
    ``chunked_gae`` (or ``vanilla_gae`` when ``chunked`` is false), and then
    cut back to their own length. Everything runs under ``torch.no_grad()``.

    Args:
        total_lengths: list[int], each sample's total_len.
        response_lengths: list[int], each sample's response_len.
        values_list: list[Tensor], each of shape [resp_len_i] per the original
            documentation (with CP > 1 these are gathered first).
        rewards_list: list[Tensor], same shapes as ``values_list``.
        gamma: Discount factor.
        lambd: GAE lambda.
        chunked: Selects ``chunked_gae`` when true, ``vanilla_gae`` otherwise.

    Returns:
        Tuple ``(advantages_list, returns_list)`` of per-sample tensors; with
        CP > 1 each entry is the slice from ``slice_log_prob_with_cp``.
    """
    from megatron.core import mpu

    with torch.no_grad():
        B = len(response_lengths)
        assert B == len(values_list)
        assert B == len(rewards_list)

        cp_size = mpu.get_context_parallel_world_size()
        context_parallel = cp_size > 1
        device = values_list[0].device
        dtype = values_list[0].dtype

        if context_parallel:
            from slime.backends.megatron_utils.cp_utils import all_gather_with_cp

            # Rebuild each sample's full tensors from the per-rank chunks.
            full_values_list, full_rewards_list = [], []
            for total_len, resp_len, local_values, local_rewards in zip(
                total_lengths, response_lengths, values_list, rewards_list, strict=False
            ):
                full_values_list.append(all_gather_with_cp(local_values, total_len, resp_len))
                full_rewards_list.append(all_gather_with_cp(local_rewards, total_len, resp_len))
        else:
            full_values_list, full_rewards_list = values_list, rewards_list

        # Pad to max_len so all samples can be processed as one [B, max_len] batch.
        max_len = max(response_lengths)
        full_values = torch.zeros(B, max_len, device=device, dtype=dtype)
        full_rewards = torch.zeros(B, max_len, device=device, dtype=dtype)
        for row, length in enumerate(response_lengths):
            full_values[row, :length] = full_values_list[row][:length]
            full_rewards[row, :length] = full_rewards_list[row][:length]

        gae_fn = chunked_gae if chunked else vanilla_gae
        full_advantages, full_returns = gae_fn(
            rewards=full_rewards,
            values=full_values,
            gamma=gamma,
            lambd=lambd,
        )

        advantages_list = []
        returns_list = []
        if context_parallel:
            from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp

            for total_len, resp_len, adv_row, ret_row in zip(
                total_lengths,
                response_lengths,
                full_advantages,
                full_returns,
                strict=False,
            ):
                # Drop the padding, then keep only this rank's part.
                advantages_list.append(slice_log_prob_with_cp(adv_row[:resp_len], total_len, resp_len))
                returns_list.append(slice_log_prob_with_cp(ret_row[:resp_len], total_len, resp_len))
        else:
            for row, length in enumerate(response_lengths):
                advantages_list.append(full_advantages[row, :length])
                returns_list.append(full_returns[row, :length])

    return advantages_list, returns_list
def vanilla_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float,
    lambd: float,
):
    """Compute batched GAE with a plain backward loop over time.

    Args:
        rewards: [B, T] reward sequence.
        values: [B, T] value predictions; the value after the last step is
            taken to be zero.
        gamma: Discount factor.
        lambd: GAE lambda.

    Returns:
        Tuple ``(advantages, returns)``, both [B, T], with
        ``returns = advantages + values``.
    """
    B, T = rewards.shape

    running = torch.zeros(B, device=rewards.device, dtype=rewards.dtype)
    per_step = [None] * T
    for t in range(T - 1, -1, -1):
        bootstrap = values[:, t + 1] if t < T - 1 else 0.0
        td_error = rewards[:, t] + gamma * bootstrap - values[:, t]
        # A_t = delta_t + gamma * lambd * A_{t+1}
        running = td_error + gamma * lambd * running
        per_step[t] = running

    full_advantages = torch.stack(per_step, dim=1)  # [B, T]
    full_returns = full_advantages + values  # [B, T]
    return full_advantages, full_returns
def chunked_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    gamma: float,
    lambd: float,
    chunk_size: int = 128,
):
    """Compute Generalized Advantage Estimation (GAE) chunk by chunk.

    Uses a FlashLinearAttention-inspired scheme: a parallel prefix scan
    inside every chunk (one matmul with a [C, C] kernel) plus a recurrent
    state carried from chunk to chunk. The Python-level loop therefore runs
    over T / chunk_size chunks instead of T time steps.

    Args:
        rewards: [B, T] reward sequence.
        values: [B, T] value predictions. The next-value of the final step
            is taken to be zero.
        gamma: Discount factor.
        lambd: GAE lambda.
        chunk_size: Chunk length C used for the parallel scan.

    Returns:
        Tuple ``(advantages, returns)``, both [B, T], with
        ``returns = advantages + values``.
    """
    # Validate inputs.
    assert rewards.ndim == 2 and values.ndim == 2
    B, T = rewards.shape
    assert values.shape == (B, T)

    device, dtype = rewards.device, rewards.dtype

    # TD residuals: δ_t = r_t + γ * V_{t+1} - V_t, with V_T = 0.
    terminal_value = torch.zeros(B, 1, device=device, dtype=dtype)
    shifted_values = torch.cat([values[:, 1:], terminal_value], dim=1)
    deltas = rewards + gamma * shifted_values - values

    # Backward GAE becomes a forward scan on the time-reversed residuals:
    #     S[i] = Δ[i] + w * S[i - 1],   w = γ * λ
    w = gamma * lambd
    deltas_rev = torch.flip(deltas, dims=[1])  # [B, T]

    # Zero-pad the reversed sequence up to a multiple of chunk_size.
    remainder = T % chunk_size
    if remainder != 0:
        pad = chunk_size - remainder
        deltas_rev = F.pad(deltas_rev, (0, pad))
    else:
        pad = 0

    B, T_pad = deltas_rev.shape
    n_chunks = T_pad // chunk_size
    deltas_chunks = deltas_rev.view(B, n_chunks, chunk_size)

    # Intra-chunk scan kernel M with M[i, j] = w^(j - i) for j >= i, else 0,
    # so that (Δ @ M)[t] = sum_{k=0..t} w^(t - k) * Δ[k].
    # pow_vec[t] = w^(t + 1) injects the state carried in from the previous chunk.
    positions = torch.arange(chunk_size, device=device)
    diff = positions[None, :] - positions[:, None]
    upper = diff >= 0
    M = torch.zeros(chunk_size, chunk_size, device=device, dtype=dtype)
    if w == 0.0:
        # No decay weight at all: only the diagonal survives and nothing is carried.
        M[upper & (diff == 0)] = 1.0
        pow_vec = torch.zeros(chunk_size, device=device, dtype=dtype)
    else:
        M[upper] = w ** diff[upper].to(dtype)
        pow_vec = w ** torch.arange(1, chunk_size + 1, device=device, dtype=dtype)

    # Local scan of every chunk at once, assuming a zero incoming state.
    flat_chunks = deltas_chunks.reshape(B * n_chunks, chunk_size)
    S_local_chunks = (flat_chunks @ M).view(B, n_chunks, chunk_size)

    # Sequential pass over chunks:
    #     S_global[t] = S_local[t] + w^(t + 1) * carry,   carry <- S_global[last]
    S_rev = deltas_rev.new_zeros(B, T_pad)
    carry = torch.zeros(B, device=device, dtype=dtype)
    last_chunk = n_chunks - 1
    for c in range(n_chunks):
        # Only the last chunk can contain padding; use its effective length.
        Lc = chunk_size - pad if (pad > 0 and c == last_chunk) else chunk_size
        start = c * chunk_size

        S_global = S_local_chunks[:, c, :Lc] + carry.unsqueeze(1) * pow_vec[:Lc]
        S_rev[:, start : start + Lc] = S_global
        carry = S_global[:, -1]  # state for the next chunk

    # Remove padding and flip back to the original time order.
    if pad > 0:
        S_rev = S_rev[:, :T]
    advantages = torch.flip(S_rev, dims=[1])
    returns = advantages + values
    return advantages, returns
def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1):
    """Compute token log-probs, and optionally entropy, from logits.

    Args:
        logits: Logits tensor; made contiguous before use. Its first
            dimension is the one that gets chunked.
        tokens: Target token ids, chunked along dim 0 together with ``logits``.
        tp_group: Process group forwarded to ``compute_log_probs`` and
            ``compute_entropy_from_logits``.
        with_entropy: When true, the entropy is computed as well.
        chunk_size: When positive, inputs are processed in
            ``ceil(logits.size(0) / chunk_size)`` requested chunks and the
            results are concatenated along dim 0; otherwise in one call.

    Returns:
        Tuple ``(log_prob, entropy)``; ``entropy`` is ``None`` unless
        ``with_entropy`` is true. For empty logits both are empty tensors
        of shape ``(0,)`` (entropy only when requested).
    """
    logits = logits.contiguous()

    # Nothing to compute for an empty input.
    if logits.size(0) == 0:
        log_prob = logits.new_zeros((0,))
        entropy = logits.new_zeros((0,)) if with_entropy else None
        return log_prob, entropy

    # TODO: not sure why we need to clone the logits here.
    # Without the clone, the backward will trigger inplace edit error.
    # It seems that the function with tp will modify the logits inplace.
    entropy = None
    if chunk_size > 0:
        num_chunks = (logits.size(0) - 1) // chunk_size + 1
        tokens_chunks = tokens.chunk(num_chunks, dim=0)
        logits_chunks = logits.chunk(num_chunks, dim=0)

        # All log-prob chunks are computed first, then all entropy chunks.
        log_prob = torch.cat(
            [
                compute_log_probs(logits_part.clone(), tokens_part, tp_group)
                for tokens_part, logits_part in zip(tokens_chunks, logits_chunks, strict=True)
            ],
            dim=0,
        )
        if with_entropy:
            entropy = torch.cat(
                [compute_entropy_from_logits(logits_part.clone(), tp_group) for logits_part in logits_chunks],
                dim=0,
            )
    else:
        log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
        if with_entropy:
            entropy = compute_entropy_from_logits(logits.clone(), tp_group)

    return log_prob, entropy
|