| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Single Process Actor |
| """ |
|
|
| import logging |
| import os |
|
|
| import torch |
| from torch import nn |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.tensor import DTensor |
|
|
| import verl.utils.torch_functional as verl_F |
| from verl import DataProto |
| from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty |
| from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input |
| from verl.utils.device import get_device_id, get_device_name |
| from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ |
| from verl.utils.profiler import GPUMemoryLogger |
| from verl.utils.py_functional import append_to_dict |
| from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch |
| from verl.utils.torch_dtypes import PrecisionType |
| from verl.utils.torch_functional import logprobs_from_logits |
| from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs |
| from verl.workers.actor import BasePPOActor |
| from verl.workers.config import ActorConfig |
|
|
__all__ = ["DataParallelPPOActor"]

# Module-level logger named after this file's path. Its level is taken from the
# VERL_LOGGING_LEVEL environment variable and defaults to "WARN" when unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
class DataParallelPPOActor(BasePPOActor):
    """FSDP DataParallel PPO Actor or Ref worker.

    The same class serves as the trainable actor (when an optimizer is given)
    and as the frozen reference policy (when ``actor_optimizer`` is None).

    Args:
        config (ActorConfig): Actor config
        actor_module (nn.Module): Actor or ref module
        actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. Defaults to None.
    """

    def __init__(self, config: ActorConfig, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None):
        """When optimizer is None, it is Reference Policy"""
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
        role = "Ref" if actor_optimizer is None else "Actor"

        self.use_remove_padding = self.config.get("use_remove_padding", False)
        if torch.distributed.get_rank() == 0:
            print(f"{role} use_remove_padding={self.use_remove_padding}")
        self.use_fused_kernels = self.config.get("use_fused_kernels", False)
        if torch.distributed.get_rank() == 0:
            print(f"{role} use_fused_kernels={self.use_fused_kernels}")

        self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

        self.use_dynamic_bsz = self.config.get("use_dynamic_bsz", False)

        self.use_prefix_grouper = self.config.get("use_prefix_grouper", False)
        if torch.distributed.get_rank() == 0:
            print(f"{role} use_prefix_grouper={self.use_prefix_grouper}")

        # Pick the entropy kernel used on the remove-padding path.
        if self.config.entropy_from_logits_with_chunking:
            entropy_from_logits = verl_F.entropy_from_logits_with_chunking
        else:
            entropy_from_logits = verl_F.entropy_from_logits

        self.compute_entropy_from_logits = (
            torch.compile(entropy_from_logits, dynamic=True)
            if self.config.get("use_torch_compile", True)
            else entropy_from_logits
        )
        self.device_name = get_device_name()
        self.param_dtype = PrecisionType.to_dtype(self.config.fsdp_config.get("dtype", "bfloat16"))
        # fp16 training needs loss scaling; other dtypes run without a scaler.
        if self.param_dtype == torch.float16:
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

            self.scaler = ShardedGradScaler(growth_interval=400)
        else:
            self.scaler = None

        if self.config.get("calculate_sum_pi_squared", False):
            self.calculate_sum_pi_squared_from_logits = (
                torch.compile(verl_F.calculate_sum_pi_squared_from_logits, dynamic=True)
                if self.config.get("use_torch_compile", True)
                else verl_F.calculate_sum_pi_squared_from_logits
            )
            # The fused-kernel and prefix-grouper forward paths never produce sum_pi_squared.
            assert not (self.use_fused_kernels or self.use_prefix_grouper), (
                "calculate_sum_pi_squared is not supported with "
                f"{self.use_fused_kernels=} or {self.use_prefix_grouper=} for now."
            )

    def _forward_micro_batch(
        self, micro_batch: dict[str, torch.Tensor], temperature: float, calculate_entropy: bool = False
    ) -> dict[str, torch.Tensor]:
        """Run the model on one micro-batch and return per-response-token statistics.

        Args:
            micro_batch: model inputs; reads ``input_ids``, ``attention_mask``,
                ``position_ids``, ``responses`` and optionally ``multi_modal_inputs``
                (plus ``response_mask``/``uid`` for the prefix-grouper path).
            temperature: logits are divided by this before computing log-probs.
            calculate_entropy: whether to also return per-token entropy.

        Returns:
            dict[str, torch.Tensor]:
                log_probs: (bs, response_len)
                if calculate_entropy is True:
                    entropys: (bs, response_len)
                if config.calculate_sum_pi_squared is True:
                    sum_pi_squared: (bs, response_len)
        """
        calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False)
        sum_pi_squared_checkpointing = self.config.get("sum_pi_squared_checkpointing", False)

        if self.use_prefix_grouper:
            # Prefix grouper only works on the plain padded path.
            can_use_pg = (
                not self.use_remove_padding
                and not self.use_ulysses_sp
                and not self.use_fused_kernels
                and not self.use_dynamic_bsz
            )
            if can_use_pg and "response_mask" in micro_batch and "uid" in micro_batch:
                from verl.trainer.ppo.prefix_grouper_utils import forward_micro_batch_with_prefix_grouper

                return forward_micro_batch_with_prefix_grouper(
                    micro_batch=micro_batch,
                    model=self.actor_module,
                    temperature=temperature,
                    calculate_entropy=calculate_entropy,
                    device_name=self.device_name,
                    param_dtype=self.param_dtype,
                    use_chunking_entropy=self.config.get("entropy_from_logits_with_chunking", False),
                )

        response_length = micro_batch["responses"].size(-1)
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch.keys():
            from verl.utils.model import extract_multi_modal_inputs

            multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"])

        with torch.autocast(device_type=self.device_name, dtype=self.param_dtype):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]
            entropy = None
            if position_ids.dim() == 3:
                # 3-D position ids: swap the first two axes so the channel axis leads.
                position_ids = position_ids.transpose(0, 1)

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)

                # Unpad position ids with the same indices so they stay aligned with the tokens.
                if position_ids.dim() == 3:
                    position_ids_rmpad = (
                        index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                        .transpose(0, 1)
                        .unsqueeze(1)
                    )
                else:
                    position_ids_rmpad = index_first_axis(
                        rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                    ).transpose(0, 1)

                # An all-padding micro-batch would give the model zero tokens; feed dummy
                # zeros of length sp_size instead and drop their outputs further below.
                is_mask_all_zero = attention_mask.sum() == 0
                if is_mask_all_zero:
                    input_ids_rmpad = torch.zeros(
                        (1, self.ulysses_sequence_parallel_size),
                        device=input_ids.device,
                        dtype=input_ids.dtype,
                    )
                    if position_ids.dim() == 3:
                        position_ids_rmpad = torch.zeros(
                            (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size),
                            device=position_ids.device,
                            dtype=position_ids.dtype,
                        )
                    else:
                        position_ids_rmpad = torch.zeros(
                            (1, self.ulysses_sequence_parallel_size),
                            device=position_ids.device,
                            dtype=position_ids.dtype,
                        )

                if "image_bound" in multi_modal_inputs:
                    from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo

                    multi_modal_inputs = process_multi_modal_inputs_for_minicpmo(
                        input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs
                    )

                # Labels for next-token log-probs: token at position i+1 scored at position i.
                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)

                # With Ulysses sequence parallelism, pad (and, for non-VLM models, slice) the inputs.
                if self.use_ulysses_sp:
                    is_vlm_model = hasattr(
                        getattr(self.actor_module, "module", self.actor_module).config, "vision_config"
                    )
                    if is_vlm_model:
                        # VLM inputs are only padded here, not sliced.
                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
                            input_ids_rmpad,
                            position_ids_rmpad=position_ids_rmpad,
                            sp_size=self.ulysses_sequence_parallel_size,
                        )
                    else:
                        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                            input_ids_rmpad,
                            position_ids_rmpad=position_ids_rmpad,
                            sp_size=self.ulysses_sequence_parallel_size,
                        )
                    input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad_rolled,
                        position_ids_rmpad=None,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )

                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)

                # attention_mask=None: the packed (unpadded) sequence is described by position ids.
                extra_args = {}
                if self.use_fused_kernels:
                    extra_args["temperature"] = temperature
                    extra_args["return_dict"] = True

                output = self.actor_module(
                    input_ids=input_ids_rmpad,
                    attention_mask=None,
                    position_ids=position_ids_rmpad,
                    **multi_modal_inputs,
                    use_cache=False,
                    **extra_args,
                )

                if self.use_fused_kernels:
                    log_probs = output.log_probs.squeeze(0)
                    entropy_rmpad = output.entropy.squeeze(0)
                else:
                    logits_rmpad = output.logits.squeeze(0)
                    logits_rmpad.div_(temperature)

                    # logits_rmpad is reused for entropy below, so it must not be
                    # overwritten in place by the log-prob backward in that case.
                    inplace_backward = not calculate_entropy
                    log_probs = logprobs_from_logits(
                        logits=logits_rmpad,
                        labels=input_ids_rmpad_rolled,
                        inplace_backward=inplace_backward,
                    )

                    if calculate_entropy:
                        entropy_rmpad = (
                            self.compute_entropy_from_logits(logits_rmpad)
                            if not self.config.entropy_checkpointing
                            else torch.utils.checkpoint.checkpoint(self.compute_entropy_from_logits, logits_rmpad)
                        )

                    if calculate_sum_pi_squared:
                        sum_pi_squared_rmpad = (
                            self.calculate_sum_pi_squared_from_logits(logits_rmpad)
                            if not sum_pi_squared_checkpointing
                            else torch.utils.checkpoint.checkpoint(
                                self.calculate_sum_pi_squared_from_logits, logits_rmpad
                            )
                        )

                # Undo the Ulysses slicing/padding so every rank holds the full packed sequence.
                if self.use_ulysses_sp:
                    log_probs = gather_outputs_and_unpad(
                        log_probs,
                        gather_dim=0,
                        unpad_dim=0,
                        padding_size=pad_size,
                    )
                    if calculate_entropy:
                        entropy_rmpad = gather_outputs_and_unpad(
                            entropy_rmpad,
                            gather_dim=0,
                            unpad_dim=0,
                            padding_size=pad_size,
                        )
                    if calculate_sum_pi_squared:
                        sum_pi_squared_rmpad = gather_outputs_and_unpad(
                            sum_pi_squared_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
                        )

                # Drop the dummy-token outputs so they match the (empty) `indices`;
                # every tensor later passed to pad_input must be truncated here.
                if is_mask_all_zero:
                    log_probs = log_probs[:0]
                    if calculate_entropy:
                        entropy_rmpad = entropy_rmpad[:0]
                    if calculate_sum_pi_squared:
                        sum_pi_squared_rmpad = sum_pi_squared_rmpad[:0]

                # Scatter packed values back to (bs, seqlen).
                if calculate_entropy:
                    full_entropy = pad_input(
                        hidden_states=entropy_rmpad.unsqueeze(-1),
                        indices=indices,
                        batch=batch_size,
                        seqlen=seqlen,
                    )
                if calculate_sum_pi_squared:
                    full_sum_pi_squared = pad_input(
                        hidden_states=sum_pi_squared_rmpad.unsqueeze(-1),
                        indices=indices,
                        batch=batch_size,
                        seqlen=seqlen,
                    )
                full_log_probs = pad_input(
                    hidden_states=log_probs.unsqueeze(-1),
                    indices=indices,
                    batch=batch_size,
                    seqlen=seqlen,
                )

                # Keep only the positions that predict response tokens.
                if calculate_entropy:
                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1]
                if calculate_sum_pi_squared:
                    sum_pi_squared = full_sum_pi_squared.squeeze(-1)[:, -response_length - 1 : -1]
                log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]

            else:
                # Padded path: run the model on the full (bs, seqlen) batch.
                extra_args = {}
                if self.use_fused_kernels:
                    extra_args["temperature"] = temperature
                    extra_args["return_dict"] = True

                output = self.actor_module(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    **multi_modal_inputs,
                    use_cache=False,
                    **extra_args,
                )

                if self.use_fused_kernels:
                    log_probs = output.log_probs[:, -response_length - 1 : -1]
                    entropy = output.entropy[:, -response_length - 1 : -1]
                else:
                    logits = output.logits

                    logits.div_(temperature)
                    logits = logits[:, -response_length - 1 : -1, :]
                    log_probs = logprobs_from_logits(logits, micro_batch["responses"])
                    if calculate_entropy:
                        # NOTE(review): the plain (non-chunked, non-compiled) entropy is used here even
                        # when entropy_from_logits_with_chunking is set; presumably because the chunked
                        # variant expects 2-D logits while these are 3-D — confirm before changing.
                        if not self.config.entropy_checkpointing:
                            entropy = verl_F.entropy_from_logits(logits)
                        else:
                            entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)

                    if calculate_sum_pi_squared:
                        sum_pi_squared = (
                            self.calculate_sum_pi_squared_from_logits(logits)
                            if not sum_pi_squared_checkpointing
                            else torch.utils.checkpoint.checkpoint(self.calculate_sum_pi_squared_from_logits, logits)
                        )

            outputs = {"log_probs": log_probs}
            if calculate_entropy:
                outputs["entropys"] = entropy
            if calculate_sum_pi_squared:
                outputs["sum_pi_squared"] = sum_pi_squared
            return outputs

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer step, and return the gradient norm.

        Without a grad scaler, the step is skipped (and grads zeroed) when the
        norm is non-finite. With a scaler, skipping is left to ``scaler.step``.
        """
        assert self.config.grad_clip is not None
        if self.scaler is not None:
            # Grads must be unscaled before clipping so the norm is in true units.
            self.scaler.unscale_(self.actor_optimizer)
        if isinstance(self.actor_module, FSDP):
            grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip)
        elif isinstance(self.actor_module, FSDPModule):
            grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip)

        if isinstance(grad_norm, DTensor):
            grad_norm = grad_norm.full_tensor()

        if self.scaler is not None:
            self.scaler.step(self.actor_optimizer)
            self.scaler.update()
        else:
            if not torch.isfinite(grad_norm):
                print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}")
                self.actor_optimizer.zero_grad()
            else:
                self.actor_optimizer.step()

        # Weights may have changed, so drop QAT scales cached on the module.
        if getattr(self.actor_module, "_qat_fuse_enabled", False):
            from verl.utils.qat import invalidate_all_scales

            invalidate_all_scales(self.actor_module)

        return grad_norm

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def compute_log_prob(self, data: DataProto, calculate_entropy: bool = False) -> dict[str, torch.Tensor]:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Runs without gradients, with the module in eval mode. Reads
        ``micro_batch_size``, ``temperature`` and ``use_dynamic_bsz`` from
        ``data.meta_info``, plus ``max_token_len`` when dynamic batching is on
        and optionally ``pad_token_id`` (default 0).

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64.

            calculate_entropy: whether to also return per-token entropy.

        Returns:
            dict[str, torch.Tensor]: a dict containing keys
                - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32.
                - ``entropys`` (only if calculate_entropy): tensor of shape [batch_size, response_length].
                - ``sum_pi_squared`` (only if config.calculate_sum_pi_squared): tensor of shape
                  [batch_size, response_length].
        """
        calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False)

        self.actor_module.eval()

        micro_batch_size = data.meta_info["micro_batch_size"]
        temperature = data.meta_info["temperature"]
        use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
        pad_token_id = data.meta_info.get("pad_token_id", 0)
        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()

        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
        if self.use_prefix_grouper:
            select_keys += [k for k in ["prompts", "response_mask"] if k in data.batch]
            if "uid" in data.non_tensor_batch:
                non_tensor_select_keys.append("uid")

        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        if use_dynamic_bsz:
            # Token budget is per SP group, hence scaled by the Ulysses size.
            max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
            micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len)
        else:
            micro_batches = data.split(micro_batch_size)

        log_probs_lst = []
        entropy_lst = []
        sum_pi_squared_lst = []
        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id}
            with torch.no_grad():
                outputs = self._forward_micro_batch(
                    model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
                )
            log_probs_lst.append(outputs["log_probs"])
            if calculate_entropy:
                entropy_lst.append(outputs["entropys"])
            if calculate_sum_pi_squared:
                sum_pi_squared_lst.append(outputs["sum_pi_squared"])

        log_probs = torch.concat(log_probs_lst, dim=0)
        if calculate_entropy:
            entropys = torch.concat(entropy_lst, dim=0)
        if calculate_sum_pi_squared:
            sum_pi_squared = torch.concat(sum_pi_squared_lst, dim=0)

        # Dynamic batching reorders samples; put them back in the input order.
        if use_dynamic_bsz:
            log_probs = restore_dynamic_batch(log_probs, batch_idx_list)
            if calculate_entropy:
                entropys = restore_dynamic_batch(entropys, batch_idx_list)
            if calculate_sum_pi_squared:
                sum_pi_squared = restore_dynamic_batch(sum_pi_squared, batch_idx_list)

        outputs = {"log_probs": log_probs}
        if calculate_entropy:
            outputs["entropys"] = entropys
        if calculate_sum_pi_squared:
            outputs["sum_pi_squared"] = sum_pi_squared
        return outputs

    @GPUMemoryLogger(role="dp actor", logger=logger)
    def update_policy(self, data: DataProto):
        """Run PPO updates on ``data`` for ``config.ppo_epochs`` epochs.

        The batch is split into mini-batches of ``config.ppo_mini_batch_size``.
        Each mini-batch is split into micro-batches, whose scaled losses are
        accumulated before a single optimizer step.

        Returns:
            dict: ``actor/pg_loss`` and ``actor/kl_loss`` are floats summed (after loss
            scaling) over every micro-batch of every mini-batch and epoch. Other
            per-micro-batch metrics and the per-mini-batch ``actor/grad_norm`` are
            accumulated with ``append_to_dict``.
        """
        self.actor_module.train()

        temperature = data.meta_info["temperature"]
        pad_token_id = data.meta_info.get("pad_token_id", 0)

        select_keys = [
            "responses",
            "response_mask",
            "input_ids",
            "attention_mask",
            "position_ids",
            "old_log_probs",
            "advantages",
        ]
        if self.use_prefix_grouper and "prompts" in data.batch.keys():
            select_keys.append("prompts")
        if self.config.use_kl_loss:
            select_keys.append("ref_log_prob")
        # Optional pre-computed importance-sampling weights, forwarded to the policy loss.
        if "rollout_is_weights" in data.batch.keys():
            select_keys.append("rollout_is_weights")
        # Optional rollout log-probs, used for rollout-correction metrics below.
        if "rollout_log_probs" in data.batch.keys():
            select_keys.append("rollout_log_probs")

        has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
        non_tensor_select_keys = []
        if has_multi_modal_inputs:
            non_tensor_select_keys.append("multi_modal_inputs")
        if self.use_prefix_grouper and "uid" in data.non_tensor_batch.keys():
            non_tensor_select_keys.append("uid")

        data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)

        mini_batches = data.split(self.config.ppo_mini_batch_size)

        # With one mini-batch and one epoch, the policy has not moved since the
        # rollout, so the freshly computed log-probs can serve as old log-probs.
        on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1

        # Loop-invariant config lookups, hoisted out of the micro-batch loop.
        entropy_coeff = self.config.entropy_coeff
        loss_agg_mode = self.config.loss_agg_mode
        calculate_entropy = self.config.calculate_entropy or (entropy_coeff != 0)
        use_rollout_log_probs = getattr(self.config, "use_rollout_log_probs", False)
        loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
        policy_loss_fn = get_policy_loss_fn(loss_mode)

        metrics = {
            "actor/pg_loss": 0.0,
            "actor/kl_loss": 0.0,
        }
        for _ in range(self.config.ppo_epochs):
            for mini_batch in mini_batches:
                if self.config.use_dynamic_bsz:
                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                    micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len)
                else:
                    self.gradient_accumulation = (
                        self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                    )
                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

                self.actor_optimizer.zero_grad()

                for micro_batch in micro_batches:
                    micro_batch = micro_batch.to(get_device_id())
                    micro_batch_metrics = {}
                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id}
                    response_mask = model_inputs["response_mask"]
                    advantages = model_inputs["advantages"]

                    # Weight of this micro-batch within its mini-batch.
                    if self.config.use_dynamic_bsz:
                        loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size
                    else:
                        loss_scale_factor = 1 / self.gradient_accumulation

                    outputs = self._forward_micro_batch(
                        model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
                    )
                    log_prob = outputs["log_probs"]
                    entropy = outputs["entropys"] if calculate_entropy else None

                    if on_policy and not use_rollout_log_probs:
                        old_log_prob = log_prob.detach()
                    else:
                        old_log_prob = model_inputs["old_log_probs"]

                    rollout_is_weights = model_inputs.get("rollout_is_weights", None)

                    pg_loss, pg_metrics = policy_loss_fn(
                        old_log_prob=old_log_prob,
                        log_prob=log_prob,
                        advantages=advantages,
                        response_mask=response_mask,
                        loss_agg_mode=loss_agg_mode,
                        config=self.config,
                        rollout_is_weights=rollout_is_weights,
                    )
                    micro_batch_metrics.update(pg_metrics)

                    rollout_log_prob = model_inputs.get("rollout_log_probs", None)
                    if loss_mode != "bypass_mode" and rollout_log_prob is not None:
                        from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs

                        rollout_corr_metrics = compute_rollout_corr_metrics_from_logprobs(
                            log_prob=log_prob,
                            rollout_log_prob=rollout_log_prob,
                            response_mask=response_mask,
                        )
                        micro_batch_metrics.update(rollout_corr_metrics)

                    policy_loss = pg_loss
                    if calculate_entropy and entropy is not None:
                        entropy_agg = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
                        micro_batch_metrics["actor/entropy"] = entropy_agg.detach().item()
                        if entropy_coeff != 0:
                            # Out-of-place on purpose: `policy_loss -= ...` would mutate the
                            # pg_loss tensor (same object) and corrupt the actor/pg_loss metric.
                            policy_loss = policy_loss - entropy_agg * entropy_coeff

                    if self.config.use_kl_loss:
                        ref_log_prob = model_inputs["ref_log_prob"]
                        kld = kl_penalty(
                            logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
                        )
                        kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                        metrics["actor/kl_loss"] += kl_loss.detach().item() * loss_scale_factor
                        micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

                    loss = policy_loss * loss_scale_factor
                    if self.scaler is not None:
                        self.scaler.scale(loss).backward()
                    else:
                        loss.backward()

                    metrics["actor/pg_loss"] += pg_loss.detach().item() * loss_scale_factor
                    append_to_dict(metrics, micro_batch_metrics)

                grad_norm = self._optimizer_step()
                mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
                append_to_dict(metrics, mini_batch_metrics)
        self.actor_optimizer.zero_grad()
        return metrics
|
|