|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Implement Actor |
|
|
""" |
|
|
|
|
|
import os |
|
|
from collections import defaultdict |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
import torch |
|
|
from ray.experimental.tqdm_ray import tqdm |
|
|
from torch import nn |
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
|
|
|
from ...protocol import DataProto |
|
|
from ...trainer import core_algos |
|
|
from ...utils import torch_functional as VF |
|
|
from ...utils.py_functional import append_to_dict |
|
|
from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs |
|
|
from .base import BasePPOActor |
|
|
from .config import ActorConfig |
|
|
|
|
|
|
|
|
try: |
|
|
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
|
|
|
__all__ = ["DataParallelPPOActor"] |
|
|
|
|
|
|
|
|
class DataParallelPPOActor(BasePPOActor):
    """Data-parallel PPO actor / reference policy.

    Wraps a (possibly FSDP-sharded) causal LM and provides:
      - ``compute_log_prob``: per-token log-probs of responses (no grad), and
      - ``update_policy``: PPO policy-gradient updates with optional KL loss.

    Supports a padding-free forward path (flash-attn ``unpad_input``) and
    Ulysses sequence parallelism when enabled in ``config``.
    """

    def __init__(
        self,
        config: ActorConfig,
        actor_module: nn.Module,
        actor_optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """
        When optimizer is None, it is Reference Policy
        """
        super().__init__(config)
        self.rank = int(os.getenv("RANK", "0"))
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
        if config.use_torch_compile:
            self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True)
        else:
            self.log_probs_from_logits = VF.log_probs_from_logits

    def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature: float) -> torch.Tensor:
        """Run one forward pass and return per-token log-probs of the response.

        Args:
            micro_batch: dict with ``input_ids``, ``attention_mask``,
                ``position_ids``, ``responses`` and optionally
                ``multi_modal_inputs`` (a sequence of per-sample dicts).
            temperature: sampling temperature used to rescale logits.

        Returns:
            log_probs: # (bs, response_len)
        """
        input_ids = micro_batch["input_ids"]
        batch_size, seqlen = input_ids.shape
        attention_mask = micro_batch["attention_mask"]
        position_ids = micro_batch["position_ids"]
        responses = micro_batch["responses"]
        response_length = responses.size(-1)
        if position_ids.dim() == 3:
            # 3-D position ids (presumably multi-rope, e.g. qwen2-vl):
            # move the rope-channel dim first — TODO confirm expected layout
            position_ids = position_ids.transpose(0, 1)

        # Concatenate per-sample multi-modal tensors along the batch dim.
        multi_modal_inputs = {}
        if "multi_modal_inputs" in micro_batch:
            for key in micro_batch["multi_modal_inputs"][0].keys():
                multi_modal_inputs[key] = torch.cat(
                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
                )

        if self.config.padding_free:
            # Strip padding so the model only computes on real tokens.
            input_ids_rmpad, indices, *_ = unpad_input(
                input_ids.unsqueeze(-1), attention_mask
            )
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

            # Unpad position_ids so rotary embeddings stay aligned with tokens.
            if position_ids.dim() == 3:
                position_ids_rmpad = (
                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                    .transpose(0, 1)
                    .unsqueeze(1)
                )
            else:
                position_ids_rmpad = index_first_axis(
                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                ).transpose(0, 1)

            # Shift labels left by one: position t predicts token t + 1.
            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)

            # Pad to a multiple of sp_size and slice this rank's shard.
            if self.config.ulysses_sequence_parallel_size > 1:
                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
                )
                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad_rolled, None, self.config.ulysses_sequence_parallel_size
                )

            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # (total_nnz,)

            output = self.actor_module(
                input_ids=input_ids_rmpad,
                attention_mask=None,  # not needed on the padding-free path
                position_ids=position_ids_rmpad,
                **multi_modal_inputs,
                use_cache=False,
            )
            logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
            logits_rmpad.div_(temperature)  # in-place: avoids a logits-sized copy

            log_probs = self.log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)

            # Gather outputs across the sequence-parallel group and drop SP padding.
            if self.config.ulysses_sequence_parallel_size > 1:
                log_probs = gather_outputs_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)

            # Re-pad back to (bs, seqlen), then keep only response positions.
            full_log_probs = pad_input(
                hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
            )
            log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bs, response_length)
        else:
            output = self.actor_module(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **multi_modal_inputs,
                use_cache=False,
            )
            logits: torch.Tensor = output.logits
            logits.div_(temperature)
            # Logits at positions that predict the response tokens.
            logits = logits[:, -response_length - 1 : -1, :]
            log_probs = self.log_probs_from_logits(logits, responses)  # (bs, response_length)

        return log_probs

    def _optimizer_step(self) -> torch.Tensor:
        """Clip gradients, step the optimizer (skipping non-finite grads), zero grads.

        Returns:
            The pre-clip gradient norm.
        """
        if isinstance(self.actor_module, FSDP):
            # FSDP shards parameters, so it provides its own clip implementation.
            grad_norm = self.actor_module.clip_grad_norm_(self.config.max_grad_norm)
        else:
            grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm)

        if not torch.isfinite(grad_norm):
            print("Gradient norm is not finite. Skip update.")
        else:
            self.actor_optimizer.step()

        # Always clear gradients, including the non-finite ones we skipped.
        self.actor_optimizer.zero_grad()
        return grad_norm

    @torch.no_grad()
    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64.

        Returns:
            torch.Tensor: the log_prob tensor
        """
        self.actor_module.eval()

        temperature = data.meta_info["temperature"]
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
        if "multi_modal_inputs" in data.non_tensor_batch.keys():
            non_tensor_select_keys = ["multi_modal_inputs"]
        else:
            non_tensor_select_keys = []

        micro_batches = data.select(select_keys, non_tensor_select_keys).split(
            self.config.micro_batch_size_per_device_for_experience
        )
        log_probs_lst = []
        if self.rank == 0:
            # Progress bar only on rank 0 to avoid duplicated output.
            micro_batches = tqdm(micro_batches, desc="Compute log probs", position=2)

        for micro_batch in micro_batches:
            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
            log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
            log_probs_lst.append(log_probs)

        log_probs = torch.concat(log_probs_lst, dim=0)
        return log_probs

    def update_policy(self, data: DataProto) -> Dict[str, Any]:
        """Run ``ppo_epochs`` of PPO updates over ``data`` and return metrics.

        Returns:
            Dict mapping metric names to lists of per-micro-batch (or
            per-optimizer-step) scalar values.
        """
        self.actor_module.train()

        temperature = data.meta_info["temperature"]
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
        if self.config.use_kl_loss and not self.config.disable_kl:
            select_keys.append("ref_log_probs")

        if "multi_modal_inputs" in data.non_tensor_batch.keys():
            non_tensor_select_keys = ["multi_modal_inputs"]
        else:
            non_tensor_select_keys = []

        mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)

        # Micro batches accumulated per optimizer step; loop-invariant, so hoisted.
        gradient_accumulation = (
            self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
        )

        metrics = defaultdict(list)
        for _ in range(self.config.ppo_epochs):
            if self.rank == 0:
                # Wrap a fresh tqdm each epoch without clobbering `mini_batches`;
                # rebinding it would nest progress bars when ppo_epochs > 1.
                epoch_iter = tqdm(mini_batches, desc="Train mini-batches", position=2)
            else:
                epoch_iter = mini_batches

            for mini_batch in epoch_iter:
                micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
                if self.rank == 0:
                    micro_batches = tqdm(micro_batches, desc="Update policy", position=3)

                for micro_batch in micro_batches:
                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                    responses = model_inputs["responses"]
                    response_length = responses.size(1)
                    attention_mask = model_inputs["attention_mask"]
                    response_mask = attention_mask[:, -response_length:]
                    old_log_probs = model_inputs["old_log_probs"]
                    advantages = model_inputs["advantages"]

                    log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
                    # NOTE(review): mean NLL over response tokens, reported as
                    # "entropy_loss" — an estimate, not the exact entropy.
                    entropy_loss = -VF.masked_mean(log_probs, response_mask)

                    pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl = core_algos.compute_policy_loss(
                        old_log_probs=old_log_probs,
                        log_probs=log_probs,
                        advantages=advantages,
                        response_mask=response_mask,
                        clip_ratio_low=self.config.clip_ratio_low,
                        clip_ratio_high=self.config.clip_ratio_high,
                        clip_ratio_dual=self.config.clip_ratio_dual,
                    )
                    if "ref_log_probs" in model_inputs:
                        ref_log_probs = model_inputs["ref_log_probs"]

                        kld = core_algos.compute_kl(
                            log_probs=log_probs,
                            ref_log_probs=ref_log_probs,
                            kl_penalty=self.config.kl_penalty,
                        )
                        kl_loss = VF.masked_mean(kld, response_mask)
                        pg_loss = pg_loss + kl_loss * self.config.kl_coef
                        # Accumulate per micro batch via append_to_dict, consistent
                        # with the other metrics (a direct assignment into the
                        # defaultdict would overwrite a scalar every iteration).
                        append_to_dict(
                            metrics,
                            {
                                "actor/kl_loss": kl_loss.detach().item(),
                                "actor/kl_coef": self.config.kl_coef,
                            },
                        )

                    # Scale so gradients average over the accumulation window.
                    loss = pg_loss / gradient_accumulation
                    loss.backward()

                    batch_metrics = {
                        "actor/pg_loss": pg_loss.detach().item(),
                        "actor/pg_clipfrac_higher": pg_clipfrac_higher.detach().item(),
                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                        "actor/entropy_loss": entropy_loss.detach().item(),
                        "actor/ppo_kl": ppo_kl.detach().item(),
                    }
                    append_to_dict(metrics, batch_metrics)

                grad_norm = self._optimizer_step()
                append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})

        return metrics
|
|
|