# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement Actor
"""
import os
from collections import defaultdict
from typing import Any, Dict, Optional

import torch
from ray.experimental.tqdm_ray import tqdm
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from ...protocol import DataProto
from ...trainer import core_algos
from ...utils import torch_functional as VF
from ...utils.py_functional import append_to_dict
from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from .base import BasePPOActor
from .config import ActorConfig
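
# flash_attn's padding helpers are an optional import; they are only needed when
# config.padding_free is enabled (packed, padding-free forward passes below)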
try:
    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
except ImportError:
    pass


__all__ = ["DataParallelPPOActor"]

class DataParallelPPOActor(BasePPOActor):
def __init__(
self,
config: ActorConfig,
actor_module: nn.Module,
actor_optimizer: Optional[torch.optim.Optimizer] = None,
):
"""
When optimizer is None, it is Reference Policy
"""
super().__init__(config)
self.rank = int(os.getenv("RANK", "0"))
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
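        # optionally compile the log-prob computation; dynamic=True tolerates varying sequence lengths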
if config.use_torch_compile:
self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True)
else:
self.log_probs_from_logits = VF.log_probs_from_logits

    def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature: float) -> torch.Tensor:
        """
        Returns:
            log_probs: tensor of shape (batch_size, response_length)
        """
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
responses = micro_batch["responses"]
response_length = responses.size(-1)
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
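        # gather the multimodal model inputs (e.g., pixel_values for a VLM),
        # concatenating the per-sample tensors along the batch dim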
multi_modal_inputs = {}
if "multi_modal_inputs" in micro_batch:
for key in micro_batch["multi_modal_inputs"][0].keys():
multi_modal_inputs[key] = torch.cat(
[inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
)
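        # padding-free path: strip padding tokens and run the model over packed sequences,
        # avoiding wasted compute on pads (relies on flash-attn's varlen attention)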
if self.config.padding_free:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
if position_ids.dim() == 3:
position_ids_rmpad = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
.transpose(0, 1)
.unsqueeze(1)
) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
else:
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
            # roll left by one so position t holds the next-token label for the logit at position t
            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)
# pad and slice the inputs if sp > 1
if self.config.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled, None, self.config.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
# only pass input_ids and position_ids to enable flash_attn_varlen
            output = self.actor_module(
                input_ids=input_ids_rmpad,
                attention_mask=None,
                position_ids=position_ids_rmpad,
                **multi_modal_inputs,
                use_cache=False,
            )  # use_cache=False prevents the model from assuming we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
logits_rmpad.div_(temperature)
# ((total_nnz / sp) + pad)
log_probs = self.log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
# gather log_prob if sp > 1
if self.config.ulysses_sequence_parallel_size > 1:
# gather and unpad for the ulysses sp
log_probs = gather_outputs_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# pad back to (bsz, seqlen)
full_log_probs = pad_input(
hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
)
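            # the logit at position t predicts token t + 1, so the response log probs
            # live in the slice [-response_length - 1 : -1]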
log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length)
else:
output = self.actor_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**multi_modal_inputs,
use_cache=False,
)
logits: torch.Tensor = output.logits
logits.div_(temperature)
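            # shift left by one: the logit at position t predicts token t + 1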
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
log_probs = self.log_probs_from_logits(logits, responses) # (bsz, response_length)
return log_probs

    def _optimizer_step(self) -> torch.Tensor:
if isinstance(self.actor_module, FSDP):
grad_norm = self.actor_module.clip_grad_norm_(self.config.max_grad_norm)
else:
grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm)
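        # skip the optimizer step when the gradient norm is inf/NaN (e.g., a mixed-precision
        # overflow); the stale gradients are still cleared by zero_grad below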
if not torch.isfinite(grad_norm):
print("Gradient norm is not finite. Skip update.")
else:
self.actor_optimizer.step()
self.actor_optimizer.zero_grad()
return grad_norm

    @torch.no_grad()
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
Returns:
torch.Tensor: the log_prob tensor
"""
self.actor_module.eval()
temperature = data.meta_info["temperature"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
if "multi_modal_inputs" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["multi_modal_inputs"]
else:
non_tensor_select_keys = []
micro_batches = data.select(select_keys, non_tensor_select_keys).split(
self.config.micro_batch_size_per_device_for_experience
)
log_probs_lst = []
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Compute log probs", position=2)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
log_probs_lst.append(log_probs)
log_probs = torch.concat(log_probs_lst, dim=0)
return log_probs

    def update_policy(self, data: DataProto) -> Dict[str, Any]:
self.actor_module.train()
        temperature = data.meta_info["temperature"]  # temperature must be in data.meta_info to avoid silent errors
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
if self.config.use_kl_loss and not self.config.disable_kl:
select_keys.append("ref_log_probs")
if "multi_modal_inputs" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["multi_modal_inputs"]
else:
non_tensor_select_keys = []
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
metrics = defaultdict(list)
for _ in range(self.config.ppo_epochs):
if self.rank == 0:
mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
for mini_batch in mini_batches:
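                # number of micro-batch gradients to accumulate before each optimizer step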
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
if self.rank == 0:
micro_batches = tqdm(micro_batches, desc="Update policy", position=3)
for micro_batch in micro_batches:
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
response_length = responses.size(1)
attention_mask = model_inputs["attention_mask"]
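                    # only response tokens contribute to the loss; prompt tokens are masked out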
response_mask = attention_mask[:, -response_length:]
old_log_probs = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]
                    # all of the tensors below have shape (bsz, response_length)
                    log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
                    # -mean(log_probs) is a Monte Carlo estimate of the policy entropy over response tokens
                    entropy_loss = -VF.masked_mean(log_probs, response_mask)
pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl = core_algos.compute_policy_loss(
old_log_probs=old_log_probs,
log_probs=log_probs,
advantages=advantages,
response_mask=response_mask,
clip_ratio_low=self.config.clip_ratio_low,
clip_ratio_high=self.config.clip_ratio_high,
clip_ratio_dual=self.config.clip_ratio_dual,
)
if "ref_log_probs" in model_inputs:
ref_log_probs = model_inputs["ref_log_probs"]
# compute kl loss
kld = core_algos.compute_kl(
log_probs=log_probs,
ref_log_probs=ref_log_probs,
kl_penalty=self.config.kl_penalty,
)
kl_loss = VF.masked_mean(kld, response_mask)
pg_loss = pg_loss + kl_loss * self.config.kl_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_coef
loss = pg_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac_higher": pg_clipfrac_higher.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
"actor/entropy_loss": entropy_loss.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
}
append_to_dict(metrics, batch_metrics)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
return metrics