|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
FSDP PPO Trainer with Ray-based single controller. |
|
|
This trainer supports model-agnostic model initialization with HuggingFace.
|
|
""" |
|
|
|
|
|
import os |
|
|
import uuid |
|
|
from collections import defaultdict |
|
|
from contextlib import contextmanager |
|
|
from copy import deepcopy |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum, IntEnum, auto |
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type |
|
|
|
|
|
import numpy as np |
|
|
import ray |
|
|
import torch |
|
|
from codetiming import Timer |
|
|
from ray.experimental.tqdm_ray import tqdm |
|
|
from torch.utils.data import RandomSampler, SequentialSampler |
|
|
from torchdata.stateful_dataloader import StatefulDataLoader |
|
|
from transformers import PreTrainedTokenizer, ProcessorMixin |
|
|
|
|
|
from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto |
|
|
from ..single_controller.base import Worker |
|
|
from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup |
|
|
from ..single_controller.ray.base import create_colocated_worker_cls |
|
|
from ..utils import torch_functional as VF |
|
|
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt |
|
|
from ..utils.dataset import RLHFDataset, collate_fn, CurriculumCollator |
|
|
from ..utils.logger import Tracker |
|
|
from ..utils.py_functional import convert_dict_to_str |
|
|
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance |
|
|
from ..workers.fsdp_workers import FSDPWorker |
|
|
from . import core_algos |
|
|
|
|
|
from .config import PPOConfig |
|
|
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics |
|
|
from .model_merger import merge_and_save_model, reorganize_folders |
|
|
|
|
|
import itertools |
|
|
|
|
|
class Role(IntEnum): |
|
|
""" |
|
|
To create more roles dynamically, you can subclass Role and add new members |
|
|
""" |
|
|
|
|
|
Actor = auto() |
|
|
Rollout = auto() |
|
|
ActorRollout = auto() |
|
|
Critic = auto() |
|
|
RefPolicy = auto() |
|
|
RewardModel = auto() |
|
|
ActorRolloutRef = auto() |
|
|
|
|
|
|
|
|
class AdvantageEstimator(str, Enum): |
|
|
""" |
|
|
Using an enumeration class to avoid spelling errors in adv_estimator |
|
|
""" |
|
|
|
|
|
GAE = "gae" |
|
|
GRPO = "grpo" |
|
|
REINFORCE_PLUS_PLUS = "reinforce_plus_plus" |
|
|
REMAX = "remax" |
|
|
RLOO = "rloo" |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ResourcePoolManager: |
|
|
""" |
|
|
Define a resource pool specification. Resource pool will be initialized first. |
|
|
""" |
|
|
|
|
|
resource_pool_spec: dict[str, list[int]] |
|
|
mapping: dict[Role, str] |
|
|
resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) |
|
|
|
|
|
def create_resource_pool(self): |
|
|
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): |
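            # Build one RayResourcePool per named spec. `max_colocate_count=1` allows at most one
            # colocated worker group per pool (the usual choice for the FSDP backend, following
            # upstream verl defaults).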
|
|
|
|
|
|
|
|
|
|
|
resource_pool = RayResourcePool( |
|
|
process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name |
|
|
) |
|
|
self.resource_pool_dict[resource_pool_name] = resource_pool |
|
|
|
|
|
self._check_resource_available() |
|
|
|
|
|
def get_resource_pool(self, role: Role) -> RayResourcePool: |
|
|
"""Get the resource pool of the worker.""" |
|
|
return self.resource_pool_dict[self.mapping[role]] |
|
|
|
|
|
def get_n_gpus(self) -> int: |
|
|
"""Get the number of gpus in this cluster.""" |
|
|
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) |
|
|
|
|
|
def _check_resource_available(self): |
|
|
"""Check if the resource pool can be satisfied in this ray cluster.""" |
|
|
node_available_resources = ray.state.available_resources_per_node() |
|
|
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} |
|
|
|
|
|
|
|
|
total_available_gpus = sum(node_available_gpus.values()) |
|
|
total_required_gpus = sum( |
|
|
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] |
|
|
) |
|
|
if total_available_gpus < total_required_gpus: |
|
|
raise ValueError( |
|
|
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}." |
|
|
) |
|
|
|
|
|
|
|
|
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"): |
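    """Subtract a per-token KL penalty (old policy vs. reference policy) from `token_level_scores` to form `token_level_rewards`, and update the adaptive KL controller."""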
|
|
token_level_scores = data.batch["token_level_scores"] |
|
|
batch_size = data.batch.batch_size[0] |
|
|
response_mask = data.batch["response_mask"] |
|
|
|
|
|
|
|
|
if "ref_log_probs" in data.batch.keys(): |
|
|
kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty) |
|
|
kld = kld * response_mask |
|
|
else: |
|
|
kld = torch.zeros_like(response_mask, dtype=torch.float32) |
|
|
|
|
|
data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld |
|
|
|
|
|
current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1) |
|
|
current_kl = torch.mean(current_kl, dim=0).item() |
|
|
metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef} |
|
|
|
|
|
|
|
|
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) |
|
|
return data, metrics |
|
|
|
|
|
|
|
|
def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0): |
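    """Compute `advantages` and `returns` with the configured estimator (GAE, GRPO, REINFORCE++, ReMax, or RLOO) and attach them to the batch."""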
|
|
token_level_rewards = data.batch["token_level_rewards"] |
|
|
response_mask = data.batch["response_mask"] |
|
|
index = data.non_tensor_batch["uid"] |
|
|
if adv_estimator == AdvantageEstimator.GAE: |
|
|
values = data.batch["values"] |
|
|
advantages, returns = core_algos.compute_gae_advantage_return( |
|
|
token_level_rewards, values, response_mask, gamma, lam |
|
|
) |
|
|
elif adv_estimator == AdvantageEstimator.GRPO: |
|
|
advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index) |
|
|
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: |
|
|
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( |
|
|
token_level_rewards, response_mask, gamma |
|
|
) |
|
|
elif adv_estimator == AdvantageEstimator.REMAX: |
|
|
reward_baselines = data.batch["reward_baselines"] |
|
|
advantages, returns = core_algos.compute_remax_outcome_advantage( |
|
|
token_level_rewards, reward_baselines, response_mask |
|
|
) |
|
|
elif adv_estimator == AdvantageEstimator.RLOO: |
|
|
advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
data.batch["advantages"] = advantages |
|
|
data.batch["returns"] = returns |
|
|
return data |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def _timer(name: str, timing_raw: Dict[str, float]): |
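    """Context manager that times the enclosed block and records the elapsed seconds in `timing_raw[name]`."""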
|
|
with Timer(name=name, logger=None) as timer: |
|
|
yield |
|
|
|
|
|
timing_raw[name] = timer.last |
|
|
|
|
|
|
|
|
class RayPPOTrainer: |
|
|
""" |
|
|
Note that this trainer runs on the driver process on a single CPU/GPU node. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: PPOConfig, |
|
|
tokenizer: PreTrainedTokenizer, |
|
|
processor: Optional[ProcessorMixin], |
|
|
role_worker_mapping: dict[Role, Type[Worker]], |
|
|
resource_pool_manager: ResourcePoolManager, |
|
|
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup, |
|
|
reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None, |
|
|
val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None, |
|
|
): |
|
|
self.tokenizer = tokenizer |
|
|
self.processor = processor |
|
|
self.config = config |
|
|
self.reward_fn = reward_fn |
|
|
self.val_reward_fn = val_reward_fn |
|
|
|
|
|
self.hybrid_engine = config.worker.hybrid_engine |
|
|
if self.hybrid_engine: |
|
|
assert Role.ActorRollout in role_worker_mapping, ( |
|
|
f"ActorRollout should be included in {role_worker_mapping.keys()}." |
|
|
) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
self.role_worker_mapping = role_worker_mapping |
|
|
self.resource_pool_manager = resource_pool_manager |
|
|
self.use_reward_model = Role.RewardModel in role_worker_mapping |
|
|
self.ray_worker_group_cls = ray_worker_group_cls |
|
|
|
|
|
|
|
|
if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl: |
|
|
self.use_reference_policy = True |
|
|
self.kl_ctrl = core_algos.get_kl_controller(config.algorithm) |
|
|
else: |
|
|
self.use_reference_policy = False |
|
|
self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0) |
|
|
print("KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics.") |
|
|
|
|
|
if config.algorithm.adv_estimator == AdvantageEstimator.GAE: |
|
|
self.use_critic = True |
|
|
else: |
|
|
self.use_critic = False |
|
|
|
|
|
if config.algorithm.adv_estimator not in list(AdvantageEstimator): |
|
|
raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.") |
|
|
|
|
|
if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0: |
|
|
raise ValueError("Rollout batch size must be divisible by actor global batch size.") |
|
|
|
|
|
if ( |
|
|
config.data.rollout_batch_size * config.worker.rollout.n |
|
|
) % config.worker.actor.micro_batch_size_per_device_for_experience != 0: |
|
|
raise ValueError( |
|
|
"Rollout batch size * rollout.n must be divisible by actor micro batch size for experience." |
|
|
) |
|
|
|
|
|
if self.use_critic: |
|
|
if config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0: |
|
|
raise ValueError("Rollout batch size must be divisible by critic global batch size.") |
|
|
|
|
|
if ( |
|
|
config.data.rollout_batch_size * config.worker.rollout.n |
|
|
) % config.worker.critic.micro_batch_size_per_device_for_experience != 0: |
|
|
raise ValueError( |
|
|
"Rollout batch size * rollout.n must be divisible by critic micro batch size for experience." |
|
|
) |
|
|
|
|
|
if ( |
|
|
config.algorithm.adv_estimator in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO) |
|
|
and config.worker.rollout.n == 1 |
|
|
): |
|
|
raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.") |
|
|
|
|
|
self._create_val_dataloader() |
|
|
self.max_accu = 0 |
|
|
        self.current_reward_accu = -1
|
|
|
|
|
def _create_val_dataloader(self) -> None: |
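        """Build the validation dataset and a sequential dataloader; `val_batch_size == -1` loads the whole set in a single batch."""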
|
|
|
|
|
self.val_dataset = RLHFDataset( |
|
|
data_path=self.config.data.val_files, |
|
|
tokenizer=self.tokenizer, |
|
|
processor=self.processor, |
|
|
prompt_key=self.config.data.prompt_key, |
|
|
answer_key=self.config.data.answer_key, |
|
|
image_key=self.config.data.image_key, |
|
|
max_prompt_length=self.config.data.max_prompt_length, |
|
|
truncation="right", |
|
|
format_prompt=self.config.data.format_prompt, |
|
|
min_pixels=self.config.data.min_pixels, |
|
|
max_pixels=self.config.data.max_pixels, |
|
|
) |
|
|
self.val_dataloader = StatefulDataLoader( |
|
|
dataset=self.val_dataset, |
|
|
batch_size=len(self.val_dataset) |
|
|
if self.config.data.val_batch_size == -1 |
|
|
else self.config.data.val_batch_size, |
|
|
shuffle=False, |
|
|
num_workers=8, |
|
|
collate_fn=collate_fn, |
|
|
|
|
|
pin_memory=False, |
|
|
drop_last=False, |
|
|
) |
|
|
|
|
|
assert len(self.val_dataloader) >= 1 |
|
|
print(f"Size of val dataloader: {len(self.val_dataloader)}") |
|
|
|
|
|
|
|
|
def _create_dataloader(self, current_epoch) -> None: |
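        """Rebuild the training dataset and dataloader with a `CurriculumCollator` configured for the given epoch, and derive the total number of training steps."""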
|
|
|
|
|
        self.collator = CurriculumCollator(total_epoches=self.config.trainer.total_episodes, current_epoch=current_epoch)
|
|
|
|
|
self.train_dataset = RLHFDataset( |
|
|
data_path=self.config.data.train_files, |
|
|
tokenizer=self.tokenizer, |
|
|
processor=self.processor, |
|
|
prompt_key=self.config.data.prompt_key, |
|
|
answer_key=self.config.data.answer_key, |
|
|
image_key=self.config.data.image_key, |
|
|
max_prompt_length=self.config.data.max_prompt_length, |
|
|
truncation="right", |
|
|
format_prompt=self.config.data.format_prompt, |
|
|
min_pixels=self.config.data.min_pixels, |
|
|
max_pixels=self.config.data.max_pixels, |
|
|
) |
|
|
|
|
|
if self.config.data.shuffle: |
|
|
train_dataloader_generator = torch.Generator() |
|
|
train_dataloader_generator.manual_seed(self.config.data.seed) |
|
|
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) |
|
|
else: |
|
|
sampler = SequentialSampler(data_source=self.train_dataset) |
|
|
|
|
|
self.train_dataloader = StatefulDataLoader( |
|
|
dataset=self.train_dataset, |
|
|
batch_size=self.config.data.rollout_batch_size, |
|
|
sampler=sampler, |
|
|
num_workers=8, |
|
|
|
|
|
collate_fn=self.collator, |
|
|
pin_memory=False, |
|
|
drop_last=True, |
|
|
) |
|
|
|
|
|
|
|
|
assert len(self.train_dataloader) >= 1 |
|
|
print(f"Size of train dataloader: {len(self.train_dataloader)}") |
|
|
|
|
|
if self.config.trainer.max_steps is not None: |
|
|
training_steps = self.config.trainer.max_steps |
|
|
else: |
|
|
training_steps = len(self.train_dataloader) * self.config.trainer.total_episodes |
|
|
|
|
|
self.training_steps = training_steps |
|
|
self.config.worker.actor.optim.training_steps = training_steps |
|
|
self.config.worker.critic.optim.training_steps = training_steps |
|
|
print(f"Total training steps: {self.training_steps}") |
|
|
|
|
|
def _maybe_log_val_generations( |
|
|
self, inputs: List[str], outputs: List[str], labels: List[str], scores: List[float] |
|
|
) -> None: |
|
|
"""Log a table of validation samples""" |
|
|
if self.config.trainer.val_generations_to_log <= 0: |
|
|
return |
|
|
|
|
|
|
|
|
samples = list(zip(inputs, outputs, labels, scores)) |
|
|
samples.sort(key=lambda x: x[0]) |
|
|
|
|
|
|
|
|
rng = np.random.RandomState(42) |
|
|
rng.shuffle(samples) |
|
|
|
|
|
samples = samples[: self.config.trainer.val_generations_to_log] |
|
|
self.logger.log_generation(samples, self.global_step) |
|
|
|
|
|
def _validate(self) -> Dict[str, Any]: |
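        """Run generation and scoring over the validation set (with the `stage` env var forced to "2"), log sample generations, and track the best accuracy seen so far."""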
|
|
ori_stage_env = os.environ.get("stage", "1") |
|
|
|
|
|
os.environ['stage'] = "2" |
|
|
print(f"stage for validation: {os.environ['stage']}") |
|
|
reward_tensor_lst = [] |
|
|
|
|
|
sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], [] |
|
|
reward_metrics_lst = defaultdict(list) |
|
|
for batch_dict in self.val_dataloader: |
|
|
test_batch = DataProto.from_single_dict(batch_dict) |
|
|
|
|
|
input_ids = test_batch.batch["input_ids"] |
|
|
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] |
|
|
sample_inputs.extend(input_texts) |
|
|
|
|
|
if "multi_modal_inputs" in test_batch.non_tensor_batch.keys(): |
|
|
test_gen_batch = test_batch.pop( |
|
|
batch_keys=["input_ids", "attention_mask", "position_ids"], |
|
|
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs", "stage"], |
|
|
) |
|
|
else: |
|
|
test_gen_batch = test_batch.pop( |
|
|
batch_keys=["input_ids", "attention_mask", "position_ids"], |
|
|
non_tensor_batch_keys=["raw_prompt_ids", "stage"], |
|
|
) |
|
|
|
|
|
test_gen_batch.non_tensor_batch['budget'] = test_batch.non_tensor_batch['budget'] |
|
|
|
|
|
test_gen_batch.meta_info = self.config.worker.rollout.val_override_config |
|
|
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) |
|
|
test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch) |
|
|
test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size) |
|
|
print("validation generation end") |
|
|
|
|
|
|
|
|
output_ids = test_output_gen_batch.batch["responses"] |
|
|
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] |
|
|
sample_outputs.extend(output_texts) |
|
|
sample_labels.extend(test_batch.non_tensor_batch["ground_truth"].tolist()) |
|
|
test_batch = test_batch.union(test_output_gen_batch) |
|
|
|
|
|
|
|
|
reward_tensor, reward_metrics = self.val_reward_fn(test_batch) |
|
|
|
|
|
|
|
|
scores = reward_tensor.sum(-1).cpu().tolist() |
|
|
sample_scores.extend(scores) |
|
|
|
|
|
reward_tensor_lst.append(reward_tensor) |
|
|
for key, value in reward_metrics.items(): |
|
|
reward_metrics_lst[key].extend(value) |
|
|
|
|
|
self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores) |
|
|
reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item() |
|
|
val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()} |
|
|
|
|
|
|
|
|
self.current_reward_accu = val_reward_metrics['val/accuracy_reward'] |
|
|
self.max_accu = max(self.max_accu, self.current_reward_accu) |
|
|
|
|
|
os.environ['stage'] = ori_stage_env |
|
|
print(f"stage for training: {os.environ['stage']}") |
|
|
return {"val/reward_score": reward_score, **val_reward_metrics} |
|
|
|
|
|
def init_workers(self) -> None: |
|
|
"""Init resource pool and worker group""" |
|
|
self.resource_pool_manager.create_resource_pool() |
|
|
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} |
|
|
|
|
|
|
|
|
if self.hybrid_engine: |
|
|
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) |
|
|
actor_rollout_cls = RayClassWithInitArgs( |
|
|
cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.worker, role="actor_rollout" |
|
|
) |
|
|
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
if self.use_critic: |
|
|
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) |
|
|
critic_cls = RayClassWithInitArgs( |
|
|
cls=self.role_worker_mapping[Role.Critic], config=self.config.worker, role="critic" |
|
|
) |
|
|
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls |
|
|
|
|
|
|
|
|
if self.use_reference_policy: |
|
|
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) |
|
|
ref_policy_cls = RayClassWithInitArgs( |
|
|
self.role_worker_mapping[Role.RefPolicy], config=self.config.worker, role="ref" |
|
|
) |
|
|
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls |
|
|
|
|
|
|
|
|
if self.use_reward_model: |
|
|
|
|
|
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) |
|
|
rm_cls = RayClassWithInitArgs( |
|
|
cls=self.role_worker_mapping[Role.RewardModel], config=self.config.worker, role="reward" |
|
|
) |
|
|
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
all_wg: Dict[str, FSDPWorker] = {} |
|
|
self.wg_dicts = [] |
|
|
for resource_pool, class_dict in self.resource_pool_to_cls.items(): |
|
|
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) |
|
|
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) |
|
|
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) |
|
|
all_wg.update(spawn_wg) |
|
|
|
|
|
self.wg_dicts.append(wg_dict) |
|
|
|
|
|
if self.use_critic: |
|
|
self.critic_wg = all_wg["critic"] |
|
|
self.critic_wg.init_model() |
|
|
|
|
|
if self.use_reference_policy: |
|
|
self.ref_policy_wg = all_wg["ref"] |
|
|
self.ref_policy_wg.init_model() |
|
|
|
|
|
if self.use_reward_model: |
|
|
self.rm_wg = all_wg["rm"] |
|
|
self.rm_wg.init_model() |
|
|
|
|
|
|
|
|
self.actor_rollout_wg = all_wg["actor_rollout"] |
|
|
self.actor_rollout_wg.init_model() |
|
|
|
|
|
def _save_checkpoint(self) -> None: |
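        """Save actor (and critic) checkpoints plus the dataloader state under `global_step_{N}`, pruning obsolete checkpoints and updating the checkpoint tracker file."""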
|
|
|
|
|
remove_obsolete_ckpt( |
|
|
self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit |
|
|
) |
|
|
folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}") |
|
|
actor_path = os.path.join(folder_path, "actor") |
|
|
self.actor_rollout_wg.save_checkpoint(actor_path) |
|
|
|
|
|
if self.use_critic: |
|
|
critic_path = os.path.join(folder_path, "critic") |
|
|
self.critic_wg.save_checkpoint(critic_path) |
|
|
|
|
|
dataloader_path = os.path.join(folder_path, "dataloader.pt") |
|
|
dataloader_state_dict = self.train_dataloader.state_dict() |
|
|
torch.save(dataloader_state_dict, dataloader_path) |
|
|
|
|
|
last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER) |
|
|
with open(last_global_step_path, "w") as f: |
|
|
f.write(str(self.global_step)) |
|
|
|
|
|
    def _save_checkpoint_max_accu(self) -> None:
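        """Save a checkpoint named `step_{global_step}_reward_{max_accu}` and merge/reorganize the saved actor weights into a standalone model."""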
|
|
|
|
|
|
|
|
|
|
folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"step_{self.global_step}_reward_{self.max_accu}") |
|
|
actor_path = os.path.join(folder_path, "actor") |
|
|
self.actor_rollout_wg.save_checkpoint(actor_path) |
|
|
|
|
|
if self.use_critic: |
|
|
critic_path = os.path.join(folder_path, "critic") |
|
|
self.critic_wg.save_checkpoint(critic_path) |
|
|
|
|
|
dataloader_path = os.path.join(folder_path, "dataloader.pt") |
|
|
dataloader_state_dict = self.train_dataloader.state_dict() |
|
|
torch.save(dataloader_state_dict, dataloader_path) |
|
|
|
|
|
|
|
merge_and_save_model(actor_path) |
|
|
reorganize_folders(folder_path) |
|
|
|
|
|
|
|
|
def _load_checkpoint(self) -> None: |
|
|
if self.config.trainer.load_checkpoint_path is None: |
|
|
return |
|
|
|
|
|
if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]: |
|
|
raise ValueError("`load_checkpoint_path` should end with `global_step_*`.") |
|
|
|
|
|
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.") |
|
|
self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1]) |
|
|
actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor") |
|
|
self.actor_rollout_wg.load_checkpoint(actor_path) |
|
|
if self.use_critic: |
|
|
critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic") |
|
|
self.critic_wg.load_checkpoint(critic_path) |
|
|
|
|
|
dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt") |
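        # NOTE: `_save_checkpoint` stores the dataloader state at this path, but it is not restored
        # here; `fit` rebuilds the train dataloader every epoch and skips the already-consumed steps
        # of the resumed epoch via `itertools.islice` instead.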
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None: |
|
|
"""Reorder the data on single controller such that each dp rank gets similar total tokens""" |
|
|
attention_mask = batch.batch["attention_mask"] |
|
|
batch_size = attention_mask.shape[0] |
|
|
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() |
|
|
world_size = self.actor_rollout_wg.world_size |
|
|
global_partition_lst = get_seqlen_balanced_partitions( |
|
|
global_seqlen_lst, k_partitions=world_size, equal_size=True |
|
|
) |
|
|
|
|
|
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) |
|
|
batch.reorder(global_idx) |
|
|
global_balance_stats = log_seqlen_unbalance( |
|
|
seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix |
|
|
) |
|
|
metrics.update(global_balance_stats) |
|
|
|
|
|
def fit(self): |
|
|
""" |
|
|
The training loop of PPO with DAPO-style dynamic sampling added. |
|
|
""" |
|
|
reward_score_function = self.config.worker.reward.score_function |
|
|
self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict()) |
|
|
self.global_step = 0 |
|
|
val_metrics: Optional[Dict[str, Any]] = None |
|
|
|
|
|
|
|
|
self._load_checkpoint() |
|
|
|
|
|
|
|
|
if self.val_reward_fn is not None and self.config.trainer.val_before_train: |
|
|
val_metrics = self._validate() |
|
|
self.logger.log(data=val_metrics, step=self.global_step) |
|
|
if self.config.trainer.val_only: |
|
|
return |
|
|
|
|
|
ori_epoch = 0 |
|
|
self._create_dataloader(ori_epoch) |
|
|
steps_per_epoch = len(self.train_dataloader) |
|
|
|
|
|
now_epoch = self.global_step // steps_per_epoch |
|
|
new_step_in_now_epoch = self.global_step % steps_per_epoch |
|
|
print(f"now_epoch: {now_epoch}, steps_per_epoch: {steps_per_epoch}, global_step: {self.global_step}, new_step_in_now_epoch: {new_step_in_now_epoch}") |
|
|
|
|
|
|
|
|
accumulated_batch = None |
|
|
num_prompt_in_batch = 0 |
|
|
num_gen_batches_accumulated = 0 |
|
|
|
|
|
for current_epoch in tqdm(range(now_epoch, self.config.trainer.total_episodes), desc="Episode", position=0): |
|
|
current_epoch_copy = current_epoch + 1 |
|
|
self._create_dataloader(current_epoch_copy) |
|
|
|
|
|
            step_iterator = itertools.islice(self.train_dataloader, new_step_in_now_epoch, steps_per_epoch)
            # Only the resumed epoch should skip already-consumed steps; later epochs start from 0.
            new_step_in_now_epoch = 0
            for batch_dict in tqdm(step_iterator, desc="Running step", position=1):
|
|
self.global_step += 1 |
|
|
print("!" * 100 + f"global_step: {self.global_step}" + "!" * 100) |
|
|
if self.global_step > self.training_steps: |
|
|
break |
|
|
|
|
|
metrics, timing_raw = {}, {} |
|
|
batch: DataProto = DataProto.from_single_dict(batch_dict) |
|
|
                num_gen_batches_accumulated += 1
|
|
|
|
|
|
|
|
if "multi_modal_inputs" in batch.non_tensor_batch.keys(): |
|
|
gen_batch = batch.pop( |
|
|
batch_keys=["input_ids", "attention_mask", "position_ids"], |
|
|
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs", "stage"], |
|
|
) |
|
|
else: |
|
|
gen_batch = batch.pop( |
|
|
batch_keys=["input_ids", "attention_mask", "position_ids"], |
|
|
non_tensor_batch_keys=["raw_prompt_ids", "stage"], |
|
|
) |
|
|
|
|
|
gen_batch.non_tensor_batch['budget'] = batch.non_tensor_batch['budget'] |
|
|
|
|
|
with _timer("step", timing_raw): |
|
|
|
|
|
with _timer("gen", timing_raw): |
|
|
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) |
|
|
|
|
|
if self.config.algorithm.adv_estimator == "remax": |
|
|
with _timer("gen_max", timing_raw): |
|
|
gen_baseline_batch = deepcopy(gen_batch) |
|
|
gen_baseline_batch.meta_info["temperature"] = 0 |
|
|
gen_baseline_batch.meta_info["n"] = 1 |
|
|
|
|
|
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) |
|
|
|
|
|
batch = batch.union(gen_baseline_output) |
|
|
reward_baseline_tensor, _ = self.reward_fn(batch) |
|
|
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) |
|
|
|
|
|
batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) |
|
|
batch.batch["reward_baselines"] = reward_baseline_tensor |
|
|
del gen_baseline_batch, gen_baseline_output |
|
|
|
|
|
batch.non_tensor_batch["uid"] = np.array( |
|
|
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object |
|
|
) |
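                    # Repeat each prompt `rollout.n` times (interleaved) so that all responses
                    # sharing a `uid` form one group for the group-based advantage estimators.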
|
|
|
|
|
batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True) |
|
|
batch = batch.union(gen_batch_output) |
|
|
|
|
|
|
|
|
with _timer("reward", timing_raw): |
|
|
if self.use_reward_model: |
|
|
raise NotImplementedError("Reward model is not supported yet.") |
|
|
|
|
|
reward_tensor, reward_metrics = self.reward_fn(batch) |
|
|
batch.batch["token_level_scores"] = reward_tensor |
|
|
|
|
|
|
|
|
|
|
|
reward_metrics = { |
|
|
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items() |
|
|
} |
|
|
metrics.update(reward_metrics) |
|
|
|
|
|
|
|
|
if hasattr(self.config.algorithm, "dynamic_sampling") and self.config.algorithm.dynamic_sampling.enable: |
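                        # Skip generation batches whose sequence-level rewards are degenerate
                        # (all close to 0, all close to 1, or with near-zero variance), since they
                        # provide no learning signal for group-based advantage estimation.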
|
|
|
|
|
|
|
|
token_level_scores = batch.batch["token_level_scores"] |
|
|
seq_rewards = token_level_scores.sum(dim=-1) |
|
|
|
|
|
|
|
|
if torch.allclose(seq_rewards, torch.zeros_like(seq_rewards), atol=1e-5): |
|
|
print("All rewards close to 0, skipping this batch.") |
|
|
continue |
|
|
|
|
|
if torch.allclose(seq_rewards, torch.ones_like(seq_rewards), atol=1e-5): |
|
|
print("All rewards close to 1, skipping this batch.") |
|
|
continue |
|
|
|
|
|
if torch.var(seq_rewards) < 1e-4: |
|
|
print("Low variance in reward scores, skipping.") |
|
|
continue |
|
|
|
|
|
|
|
|
if accumulated_batch is None: |
|
|
accumulated_batch = batch |
|
|
else: |
|
|
accumulated_batch = DataProto.concat([accumulated_batch, batch]) |
|
|
|
|
|
prompt_bsz = self.config.data.rollout_batch_size |
|
|
rollout_n = self.config.worker.rollout.n |
|
|
total_prompt_num = len(accumulated_batch) // rollout_n |
|
|
|
|
|
if total_prompt_num < prompt_bsz: |
|
|
max_batches = self.config.algorithm.dynamic_sampling.max_num_gen_batches |
|
|
if num_gen_batches_accumulated < max_batches: |
|
|
print(f"Accumulating... {total_prompt_num}/{prompt_bsz} prompts") |
|
|
continue |
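                        # Hedged reconstruction: the original file does not show how training
                        # proceeds once enough prompts are accumulated (or `max_num_gen_batches`
                        # is reached). Following the DAPO recipe, switch to the accumulated batch,
                        # truncated to `prompt_bsz * rollout_n` trajectories. `reorder` (already
                        # used by `_balance_batch`) selects the leading trajectories while keeping
                        # whole prompt groups, because responses are interleaved per prompt.
                        traj_bsz = prompt_bsz * rollout_n
                        if len(accumulated_batch) > traj_bsz:
                            accumulated_batch.reorder(torch.arange(traj_bsz))
                        batch = accumulated_batch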
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_prompt_in_batch = batch.batch["input_ids"].shape[0] // rollout_n |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._balance_batch(batch, metrics=metrics) |
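                    # Attach per-sequence token counts so downstream timing/throughput metrics can
                    # account for the actual number of tokens processed in this step.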
|
|
|
|
|
|
|
|
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() |
|
|
|
|
|
|
|
|
with _timer("old", timing_raw): |
|
|
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch) |
|
|
batch = batch.union(old_log_probs) |
|
|
|
|
|
|
|
|
if self.use_reference_policy: |
|
|
with _timer("ref", timing_raw): |
|
|
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch) |
|
|
batch = batch.union(ref_log_probs) |
|
|
|
|
|
|
|
|
if self.use_critic: |
|
|
with _timer("values", timing_raw): |
|
|
values = self.critic_wg.compute_values(batch) |
|
|
batch = batch.union(values) |
|
|
|
|
|
with _timer("adv", timing_raw): |
|
|
if not self.config.algorithm.use_kl_loss and self.use_reference_policy: |
|
|
batch, kl_metrics = apply_kl_penalty( |
|
|
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty |
|
|
) |
|
|
metrics.update(kl_metrics) |
|
|
else: |
|
|
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] |
|
|
|
|
|
batch = compute_advantage( |
|
|
batch, |
|
|
adv_estimator=self.config.algorithm.adv_estimator, |
|
|
gamma=self.config.algorithm.gamma, |
|
|
lam=self.config.algorithm.lam, |
|
|
) |
|
|
|
|
|
|
|
|
if self.use_critic: |
|
|
with _timer("update_critic", timing_raw): |
|
|
critic_output = self.critic_wg.update_critic(batch) |
|
|
|
|
|
critic_metrics = reduce_metrics(critic_output.non_tensor_batch) |
|
|
metrics.update(critic_metrics) |
|
|
|
|
|
|
|
|
if self.config.trainer.critic_warmup <= self.global_step: |
|
|
with _timer("update_actor", timing_raw): |
|
|
actor_output = self.actor_rollout_wg.update_actor(batch) |
|
|
|
|
|
actor_metrics = reduce_metrics(actor_output.non_tensor_batch) |
|
|
metrics.update(actor_metrics) |
|
|
|
|
|
|
|
|
if ( |
|
|
self.val_reward_fn is not None |
|
|
and self.config.trainer.val_freq > 0 |
|
|
and self.global_step % self.config.trainer.val_freq == 0 |
|
|
): |
|
|
with _timer("validation", timing_raw): |
|
|
val_metrics = self._validate() |
|
|
|
|
|
metrics.update(val_metrics) |
|
|
|
|
|
if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0: |
|
|
with _timer("save_checkpoint", timing_raw): |
|
|
self._save_checkpoint() |
|
|
|
|
|
|
|
|
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.val_freq > 0
                        and self.global_step % self.config.trainer.val_freq == 0
                        and self.current_reward_accu == self.max_accu
                    ):
|
|
with _timer("save_checkpoint", timing_raw): |
|
|
                            self._save_checkpoint_max_accu()
|
|
|
|
|
|
|
|
n_gpus = self.resource_pool_manager.get_n_gpus() |
|
|
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) |
|
|
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) |
|
|
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(self.config.algorithm, "dynamic_sampling") and self.config.algorithm.dynamic_sampling.enable: |
|
|
metrics["dynamic_sampling/num_gen_batches"] = num_gen_batches_accumulated |
|
|
metrics["dynamic_sampling/num_prompt_in_batch"] = num_prompt_in_batch |
|
|
|
|
|
|
|
|
self.logger.log(data=metrics, step=self.global_step) |
|
|
|
|
|
|
|
|
accumulated_batch = None |
|
|
num_prompt_in_batch = 0 |
|
|
|
|
|
|
|
|
if self.val_reward_fn is not None: |
|
|
if ( |
|
|
val_metrics is None |
|
|
or self.config.trainer.val_freq <= 0 |
|
|
or self.global_step % self.config.trainer.val_freq != 0 |
|
|
): |
|
|
val_metrics = self._validate() |
|
|
|
|
|
self.logger.log(data=val_metrics, step=self.global_step) |
|
|
|
|
|
print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}") |
|
|
|
|
|
if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0: |
|
|
self._save_checkpoint() |
|
|
|