File size: 22,323 Bytes
bcdf9fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 |
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Megatron Actor.
In megatron actor, the differences are:
1. We only make minibatch
Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer
"""
import logging
import os
from functools import partial
from typing import Dict, Iterable, Tuple

import torch
import torch.distributed
from megatron.core import parallel_state as mpu
from megatron.core.distributed import finalize_model_grads
# from megatron.core.optimizer import DistributedOptimizer
from megatron.core.optimizer import DistributedOptimizer
from megatron.core.pipeline_parallel import get_forward_backward_func
from omegaconf import OmegaConf
from torch import nn

from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty
from verl.utils.debug import GPUMemoryLogger
from verl.utils.debug.profile import Profiler
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
from verl.utils.megatron_utils import get_model_config
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.workers.actor import BasePPOActor
# Names exported by ``from <this module> import *``.
__all__ = ["MegatronPPOActor"]

# Module logger; its level is read from the VERL_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
class MegatronPPOActor(BasePPOActor):
    """PPO actor whose policy model is built with Megatron."""

    def __init__(
        self,
        config,
        model_config,
        hf_config,
        tf_config,
        actor_module: nn.ModuleList,
        actor_optimizer: DistributedOptimizer,
    ):
        """MegatronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron.

        Args:
            config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain

                ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo.

                ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data.

                ``ppo_epochs``: number of epochs to update the actor using the batch data.

                ``shuffle``: whether to shuffle the data after each ppo epoch.

                ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347.

                ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347.
            model_config (OmegaConf): model configuration. It must contain ``model_config.vocab_size`` and
                ``model_config.hidden_size``
            hf_config (PretrainedConfig): huggingface config
            tf_config (TransformerConfig): mcore transformer config
            actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this
                pp stage. Each nn.Module in this rank holds a vpp module chunk.
                See https://arxiv.org/pdf/2104.04473.pdf for more details.
                The actor module has some constraints to follow in order to use the updating logics implemented here

                1. It must implement unpad_input before any computation and pad_input after all the computation.
                   Remove padding is an optimization that removes the padding tokens. See unpad_input and pad_input
                   function in flash-attn
                   (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py).

                2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size],
                   where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled,
                   the size of the hidden state is [total_nnz // tp, 1, hidden_size].
            actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron.
                It implements zero1 optimizer that shards the optimizer state across dp ranks.

        Example:
            >>> from megatron.training import get_model
            >>> from megatron.optimizer import get_megatron_optimizer
            >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True)
            >>> actor_module = nn.ModuleList(actor_module)
            >>> actor_optimizer = get_megatron_optimizer(actor_module)
            >>> actor = MegatronPPOActor(config=config,
            ...                          model_config=actor_model_config,
            ...                          hf_config=hf_config,
            ...                          tf_config=tf_config,
            ...                          actor_module=actor_module,
            ...                          actor_optimizer=actor_optimizer)
        """
        super().__init__(config)
        # _validate_config may adjust `config` in place and (re)binds self.config.
        self._validate_config(config)
        self.model_config = model_config
        self.hf_config = hf_config
        self.tf_config = tf_config
        self.actor_module = actor_module
        self.actor_optimizer: DistributedOptimizer = actor_optimizer
        self.prof = Profiler(self.config.profile)

        self.optimizer_step_args = OmegaConf.create(
            {
                "skip_grad": None,
                "overlap_dp_param_comm": False,
                "overlap_dp_grad_comm": False,
                "gradient_accumulation_steps": 1,
                "sequence_parallel": self.tf_config.sequence_parallel,
                "DDP_impl": "local",
                "layernorm_allreduce_bucket_threshold": 0,
                "pipeline_model_parallel_split_rank": None,
                "reduce_grads_use_alltoall": False,
            }
        )

        # Use a distinct local name so the PPO `config` argument is not shadowed by
        # the Megatron model config fetched from the first module chunk.
        megatron_model_config = get_model_config(self.actor_module[0])
        print(megatron_model_config)
        megatron_model_config.finalize_model_grads_func = finalize_model_grads

    def _validate_config(self, config) -> None:
        """Validate config options not implemented for Megatron backend.

        Besides validating, this forces ``config.megatron.sequence_parallel`` to
        ``False`` when the tensor-parallel size is 1, and stores ``config`` on
        ``self.config``.

        Raises:
            AssertionError: if ``ulysses_sequence_parallel_size`` is not 1, or if
                ``shuffle`` is enabled while ``data_loader_seed`` is ``None``.
        """
        assert config.get("ulysses_sequence_parallel_size", 1) == 1, (
            "ulysses_sequence_parallel_size must be 1: it is not implemented for the Megatron backend"
        )
        if config.get("shuffle", False):
            assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set"
        if config.megatron.tensor_model_parallel_size == 1:
            print("[Warning] Because actor tp size == 1, set sp to False")
            config.megatron.sequence_parallel = False
        self.config = config

    @staticmethod
    def _broadcast_from_last_pp_stage(tensor: torch.Tensor) -> None:
        """Synchronously broadcast ``tensor`` in place from the last pipeline rank to its pipeline group."""
        torch.distributed.broadcast(
            tensor=tensor,
            src=mpu.get_pipeline_model_parallel_last_rank(),
            group=mpu.get_pipeline_model_parallel_group(),
            async_op=False,
        )

    @GPUMemoryLogger(role="megatron actor", logger=logger)
    def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the log probability of the responses given input_ids, attention_mask and position_ids

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
                concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.

                ``responses``:  tensor of shape [batch_size, response_length]. torch.int64.
            calculate_entropy (bool): whether to also compute and return the entropy.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(log_probs, entropys)``. Both are float32 tensors of shape
            [batch_size, response_length] that are broadcast from the last pp stage to the other pp ranks.
            ``entropys`` is an empty tensor when ``calculate_entropy`` is False, and both are empty tensors when
            ``recompute_old_log_prob`` is disabled in the config.
        """
        data.batch = data.batch.contiguous()

        def compute_logprobs_fn(output, data):
            response = data["responses"]
            response_length = response.size(1)
            # Use the logits at the positions immediately preceding each response token.
            logits = output[:, -response_length - 1 : -1].contiguous()
            return {"log_probs": vocab_parallel_log_probs_from_logits(logits, response)}

        # We make recompute_old_log_prob by default here.
        # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be
        # handled by user outside
        recompute_old_log_prob = self.config.get("recompute_old_log_prob", True)

        # Empty placeholders are returned for whatever is not computed below. Binding `log_probs`
        # here keeps the final `return` valid even when recomputation is disabled.
        log_probs = torch.Tensor()
        entropys = torch.Tensor()
        if recompute_old_log_prob:
            select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
            batch = data.select(batch_keys=select_keys).batch
            input_ids = batch["input_ids"]
            batch_size = input_ids.size(0)
            response_length = batch["responses"].size(1)
            with torch.no_grad():
                output = self.forward_backward_batch(
                    data,
                    forward_only=True,
                    post_process_fn=compute_logprobs_fn,
                    calculate_entropy=calculate_entropy,
                )
                is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
                if is_last_stage:
                    # only on last rank. It should be on every tp rank
                    if calculate_entropy:
                        # each entry is [metrics, entropy]
                        log_probs = torch.cat([o[0]["log_probs"] for o in output], dim=0)  # (bs, seq_size)
                    else:
                        # each entry is the metrics dict itself
                        log_probs = torch.cat([o["log_probs"] for o in output], dim=0)  # (bs, seq_size)
                    log_probs = log_probs.to(torch.float32)
                else:
                    # receive buffer for the broadcast below
                    log_probs = torch.empty(
                        size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
                    )

                # broadcast across pp ranks
                self._broadcast_from_last_pp_stage(log_probs)

                if calculate_entropy:
                    # Note that o[0] is metrics, o[1] is entropy
                    if is_last_stage:
                        entropys = torch.cat([o[1] for o in output], dim=0)
                        entropys = entropys.to(torch.float32)
                    else:
                        entropys = torch.empty(
                            size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
                        )
                    # broadcast across pp ranks
                    self._broadcast_from_last_pp_stage(entropys)

        # add empty cache after each compute
        torch.cuda.empty_cache()

        return log_probs, entropys

    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
        """Make minibatch iterator for updating the actor

        Args:
            data (DataProto): a DataProto containing keys

                ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where
                ``sequence_length = prompt_length + response_length``

                ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64

                ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64

                ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that
                responses = input_ids[:, -response_length:]

                ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability
                of responses.

                ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of
                responses. See PPO paper for details. https://arxiv.org/abs/1707.06347

                ``ref_log_prob``: only selected when ``config.use_kl_loss`` is set.

        Returns:
            Iterable[DataProto]: the iterator produced by ``DataProto.make_iterator`` over the selected keys, using
            ``ppo_mini_batch_size``, ``ppo_epochs``, ``data_loader_seed`` and ``shuffle`` from the config.
        """
        select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
        if self.config.use_kl_loss:
            select_keys.append("ref_log_prob")
        data = data.select(batch_keys=select_keys)
        return data.make_iterator(
            mini_batch_size=self.config.ppo_mini_batch_size,
            epochs=self.config.ppo_epochs,
            seed=self.config.data_loader_seed,
            dataloader_kwargs={"shuffle": self.config.shuffle},
        )

    def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None, calculate_entropy=False):
        """Run Megatron's forward/backward schedule over ``data`` split into micro-batches.

        We assume:
        - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input
        - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled

        Args:
            data (DataProto): the batch to process. ``data.batch`` is modified in place: it is broadcast from the
                last pp rank and its ``attention_mask`` is converted to bool. ``data.meta_info["micro_batch_size"]``,
                when present and not None, overrides ``config.ppo_micro_batch_size_per_gpu``.
            forward_only (bool): passed to the schedule; when True no PPO loss is computed and a constant
                placeholder loss is returned per micro-batch.
            post_process_fn (callable, optional): only used when ``forward_only`` is True. Called as
                ``post_process_fn(output, micro_batch)``; the dict it returns is merged into the metrics. When it
                is None, the raw model output is stored under ``"logits"``.
            calculate_entropy (bool): whether to compute the entropy. When training, it is folded into the loss
                with ``entropy_coeff``; when ``forward_only``, it is returned alongside the metrics.

        Returns:
            The value returned by Megatron's forward-backward function, which contains the stats returned from
            ``loss_func`` for the micro-batches: the metrics dict when ``forward_only and not calculate_entropy``,
            otherwise ``[metrics, entropy_or_None]``.
        """
        # broadcast from last pp rank to all other pp ranks
        # TODO: actually, we just need to control the sampling order.
        broadcast_dict_tensor(
            data.batch,
            src=mpu.get_pipeline_model_parallel_last_rank(),
            group=mpu.get_pipeline_model_parallel_group(),
        )
        # split into micro-batches
        data.batch["attention_mask"] = data.batch["attention_mask"].to(bool)

        batch_size = data.meta_info.get("micro_batch_size", None)
        if batch_size is None:
            batch_size = self.config.ppo_micro_batch_size_per_gpu
        batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)
        # compute input shapes for pp stages
        n_micro_batch = len(batches)
        seq_len = batches[0]["input_ids"].shape[1]

        forward_backward_func = get_forward_backward_func()

        def loss_func(output, data, meta_info):
            # For memory efficiency
            # We move calculation of entropy to compute_log_probs, forward_only == True
            metrics = {}
            if forward_only:
                if post_process_fn is None:
                    metrics["logits"] = output
                else:
                    metrics.update(post_process_fn(output, data))
                if not calculate_entropy:
                    # Nothing else to compute: return a placeholder loss and the bare metrics dict.
                    return torch.tensor(1.0, device=output.device), metrics

            responses = data["responses"]
            response_length = responses.size(1)
            attention_mask = data["attention_mask"]
            response_mask = attention_mask[:, -response_length:]
            loss_agg_mode = self.config.loss_agg_mode

            # compute policy loss
            logits = output[:, -response_length - 1 : -1].contiguous()
            ret_entropy = None
            if not forward_only:
                old_log_prob = data["old_log_probs"]
                advantages = data["advantages"]

                clip_ratio = meta_info["clip_ratio"]
                # Asymmetric clip bounds fall back to the symmetric clip_ratio when unset.
                clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
                clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
                clip_ratio_c = meta_info["clip_ratio_c"]
                log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
                    old_log_prob=old_log_prob,
                    log_prob=log_prob,
                    advantages=advantages,
                    response_mask=response_mask,
                    cliprange=clip_ratio,
                    cliprange_low=clip_ratio_low,
                    cliprange_high=clip_ratio_high,
                    clip_ratio_c=clip_ratio_c,
                    loss_agg_mode=loss_agg_mode,
                )
                policy_loss = pg_loss
            if calculate_entropy:
                entropy = vocab_parallel_entropy(logits)
                if not forward_only:
                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
                    entropy_coeff = meta_info["entropy_coeff"]
                    policy_loss = pg_loss - entropy_coeff * entropy_loss
                else:
                    ret_entropy = entropy

            stats = {}
            if forward_only:
                # placeholder loss; the schedule does not run backward in this mode
                policy_loss = torch.tensor(1.0, device=output.device)
            else:
                if self.config.use_kl_loss:
                    ref_log_prob = data["ref_log_prob"]
                    # compute kl loss
                    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)

                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                    metrics["actor/kl_loss"] = kl_loss.detach().item()
                    metrics["actor/kl_coef"] = self.config.kl_loss_coef

                # return loss and stats
                stats.update(
                    {
                        "actor/pg_loss": pg_loss.detach().item(),
                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
                        "actor/ppo_kl": ppo_kl.detach().item(),
                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                    }
                )
            append_to_dict(metrics, stats)
            return policy_loss, [metrics, ret_entropy]

        def forward_step(batch_iter, model):
            batch = next(batch_iter)
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            position_ids = batch["position_ids"]
            from verl.models.mcore import get_mcore_forward_fn

            forward_fn = get_mcore_forward_fn(self.hf_config)

            output = forward_fn(
                model,
                input_ids,
                attention_mask,
                position_ids,
                sequence_parallel=self.tf_config.sequence_parallel,
            )
            if forward_only:
                meta_info = None
            else:
                meta_info = {
                    "clip_ratio": self.config.clip_ratio,
                    "entropy_coeff": self.config.entropy_coeff,
                    "clip_ratio_c": self.config.get("clip_ratio_c", 3.0),
                }
            return output, partial(loss_func, data=batch, meta_info=meta_info)

        # batch should be a list of batches inside micro-batches
        batch_generator = make_batch_generator(batches, vpp_size=len(self.actor_module))

        # TODO: we may use the new schedule instead
        # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size)
        # The call is identical for every pipeline size: seq_length / micro_batch_size are in use
        # for pp = 1 and of no use when input_shapes was set (pp > 1).
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step,
            data_iterator=batch_generator,
            model=self.actor_module,
            num_microbatches=n_micro_batch,
            seq_length=batch_size * seq_len,
            micro_batch_size=1,
            forward_only=forward_only,
        )
        # loss_reduces contains the stats returned from loss_func
        return losses_reduced

    @GPUMemoryLogger(role="megatron actor", logger=logger)
    def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
        """Update the policy with an iterator of DataProto

        Args:
            dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by
                ``make_minibatch_iterator``. The keys of each data batch is described in the
                make_minibatch_iterator.

        Returns:
            Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last
            pp stage and users have to combine the output in each dp rank manually.

        Raises:
            NotImplementedError: if the optimizer reports that its step was not successful.
        """
        metrics = {}
        self.prof.start()
        for data in dataloader:
            # data = data.batch.to(self.actor_module.device)
            self.actor_optimizer.zero_grad()
            # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
            for chunk in self.actor_module:
                # if use distributed optimizer, zero grad buffer will be handled by optimizer
                chunk.zero_grad_buffer()

            # The entropy term is only worth computing when it contributes to the loss.
            calculate_entropy = self.config.entropy_coeff != 0
            metric_micro_batch = self.forward_backward_batch(data, calculate_entropy=calculate_entropy)
            for metric in metric_micro_batch:
                # Each entry is [metrics, entropy]; only the metrics are accumulated here.
                append_to_dict(metrics, metric[0])  # append the metric from this micro-batch to global metrics.

            update_successful, _grad_norm, _num_zeros_in_grad = self.actor_optimizer.step()
            if not update_successful:
                raise NotImplementedError("Handling of an unsuccessful actor optimizer step is not implemented")
            # allgather already execute in optimizer.step in new megatron
            self.prof.step()
        # add empty cache after each compute
        self.prof.stop_and_save()
        self.prof.stop_trace()
        torch.cuda.empty_cache()
        return metrics
|