"""
|
|
|
FSDP PPO Trainer with Ray-based single controller.
|
|
|
This trainer supports model-agonistic model initialization with huggingface
|
|
|
"""
|
|
|
|
|
|
import json
import os
import uuid
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Dict, Optional, Type

import numpy as np
import ray
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
    process_validation_metrics,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.metric import reduce_metrics
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
from verl.workers.rollout.async_server import AsyncLLMServerManager

WorkerType = Type[Worker]


class Role(Enum):
    """
    To create more roles dynamically, you can subclass Role and add new members.
    """

    Actor = 0
    Rollout = 1
    ActorRollout = 2
    Critic = 3
    RefPolicy = 4
    RewardModel = 5
    ActorRolloutRef = 6


class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator.
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
    REMAX = "remax"
    RLOO = "rloo"
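
# Illustrative sketch (not part of the trainer): because AdvantageEstimator also
# subclasses str, config strings compare and convert to members directly, e.g.:
#
#     assert AdvantageEstimator("grpo") is AdvantageEstimator.GRPO
#     assert AdvantageEstimator.GAE == "gae"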


@dataclass
class ResourcePoolManager:
    """
    Define a resource pool specification. Resource pool will be initialized first.
    """

    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name)
            self.resource_pool_dict[resource_pool_name] = resource_pool

        self._check_resource_available()

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker_cls"""
        return self.resource_pool_dict[self.mapping[role]]

    def get_n_gpus(self) -> int:
        """Get the number of gpus in this cluster."""
        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

    def _check_resource_available(self):
        """Check if the resource pool can be satisfied in this ray cluster."""
        node_available_resources = ray.state.available_resources_per_node()
        node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}

        # check that the total requirement can be satisfied at all
        total_available_gpus = sum(node_available_gpus.values())
        total_required_gpus = sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
        if total_available_gpus < total_required_gpus:
            raise ValueError(f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}")

        # greedily check that each pool can be placed on the available nodes
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
            for node, available_gpus in node_available_gpus.items():
                if available_gpus >= num_gpus:
                    node_available_gpus[node] -= num_gpus
                    num_nodes -= 1
                    if num_nodes == 0:
                        break
            if num_nodes > 0:
                raise ValueError(f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster")
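
# Illustrative sketch (not executed): a single shared pool spanning two 8-GPU nodes,
# with colocated roles mapped onto it. The pool name is arbitrary.
#
#     pool_spec = {"global_pool": [8, 8]}  # one list entry per node: processes (GPUs) on that node
#     mapping = {Role.ActorRollout: "global_pool", Role.Critic: "global_pool"}
#     manager = ResourcePoolManager(resource_pool_spec=pool_spec, mapping=mapping)
#     manager.create_resource_pool()  # requires a running Ray cluster with 16 free GPUs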


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False):
    responses = data.batch["responses"]
    response_length = responses.size(1)
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]

    if multi_turn:
        loss_mask = data.batch["loss_mask"]
        response_mask = loss_mask[:, -response_length:]
    else:
        attention_mask = data.batch["attention_mask"]
        response_mask = attention_mask[:, -response_length:]

    # compute the per-token KL between the current policy and the reference policy
    kld = core_algos.kl_penalty(data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty)
    kld = kld * response_mask
    beta = kl_ctrl.value

    token_level_rewards = token_level_scores - beta * kld

    current_kl = masked_mean(kld, mask=response_mask, axis=-1)  # average over sequence
    current_kl = torch.mean(current_kl, dim=0).item()

    # update the adaptive KL controller with the observed mean KL
    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    data.batch["token_level_rewards"] = token_level_rewards

    metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}

    return data, metrics
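
# Reward shaping above is, per response token t: r_t = score_t - beta * kl_t.
# Numeric sketch (illustrative values): score_t = 1.0, kl_t = 0.2, beta = 0.05
# gives r_t = 1.0 - 0.05 * 0.2 = 0.99; beta itself adapts via kl_ctrl.update.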


def compute_response_mask(data: DataProto):
    responses = data.batch["responses"]
    response_length = responses.size(1)
    attention_mask = data.batch["attention_mask"]
    return attention_mask[:, -response_length:]


def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True):
    # back-compatible with trainers that do not compute response mask in fit
    if "response_mask" not in data.batch:
        data.batch["response_mask"] = compute_response_mask(data)

    if adv_estimator == AdvantageEstimator.GAE:
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards=data.batch["token_level_rewards"],
            values=data.batch["values"],
            response_mask=data.batch["response_mask"],
            gamma=gamma,
            lam=lam,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    elif adv_estimator == AdvantageEstimator.GRPO:
        # initialize the mask for GRPO calculation
        grpo_calculation_mask = data.batch["response_mask"]
        if multi_turn:
            # if multi-turn, replace the mask with the relevant part of loss_mask
            response_length = grpo_calculation_mask.size(1)
            grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:]

        advantages, returns = core_algos.compute_grpo_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            response_mask=grpo_calculation_mask,
            index=data.non_tensor_batch["uid"],
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE:
        advantages, returns = core_algos.compute_reinforce_plus_plus_baseline_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            response_mask=data.batch["response_mask"],
            index=data.non_tensor_batch["uid"],
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            response_mask=data.batch["response_mask"],
            gamma=gamma,
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    elif adv_estimator == AdvantageEstimator.REMAX:
        advantages, returns = core_algos.compute_remax_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            reward_baselines=data.batch["reward_baselines"],
            response_mask=data.batch["response_mask"],
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    elif adv_estimator == AdvantageEstimator.RLOO:
        advantages, returns = core_algos.compute_rloo_outcome_advantage(
            token_level_rewards=data.batch["token_level_rewards"],
            response_mask=data.batch["response_mask"],
            index=data.non_tensor_batch["uid"],
        )
        data.batch["advantages"] = advantages
        data.batch["returns"] = returns
    else:
        raise NotImplementedError
    return data
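
# GRPO advantage sketch: within each prompt group (shared "uid"), outcome rewards are
# standardized, A_i = (R_i - mean(R)) / std(R) when norm_adv_by_std_in_grpo=True, then
# broadcast over the response tokens. Illustrative numbers: rewards [1, 0, 1, 0] in one
# group give mean 0.5 and sample std ~0.577, so A ~ [+0.87, -0.87, +0.87, -0.87].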


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
    with Timer(name=name, logger=None) as timer:
        yield
    if name not in timing_raw:
        timing_raw[name] = 0
    timing_raw[name] += timer.last
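
# Usage sketch (illustrative): repeated phases with the same name accumulate.
#
#     timing_raw: Dict[str, float] = {}
#     with _timer("gen", timing_raw):
#         ...  # first generation call
#     with _timer("gen", timing_raw):
#         ...  # second call; timing_raw["gen"] now sums both durations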


class RayPPOTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def __init__(
        self,
        config,
        tokenizer,
        role_worker_mapping: dict[Role, WorkerType],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
        processor=None,
        reward_fn=None,
        val_reward_fn=None,
        train_dataset: Optional[Dataset] = None,
        val_dataset: Optional[Dataset] = None,
        collate_fn=None,
        train_sampler: Optional[Sampler] = None,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, "Currently, only the hybrid engine is supported"

        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reference_policy = Role.RefPolicy in role_worker_mapping
        self.use_rm = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls
        self.validation_generations_logger = ValidationGenerationsLogger()

        # in-reward KL control (only used when algorithm.use_kl_in_reward is enabled)
        if config.algorithm.use_kl_in_reward:
            self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)

        if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
        elif self.config.algorithm.adv_estimator in [
            AdvantageEstimator.GRPO,
            AdvantageEstimator.REINFORCE_PLUS_PLUS,
            AdvantageEstimator.REMAX,
            AdvantageEstimator.RLOO,
            AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
        ]:
            self.use_critic = False
        else:
            raise NotImplementedError

        self._validate_config()
        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

    def _validate_config(self):
        config = self.config
        # total number of GPUs
        n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

        # 1. check total batch size for data correctness
        real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
        assert real_train_batch_size % n_gpus == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."

        # A helper to check "micro_batch_size" vs "micro_batch_size_per_gpu":
        # raise an error if the user sets both. The new convention is "*_micro_batch_size_per_gpu".
        def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
            settings = {
                "actor_rollout_ref.actor": "micro_batch_size",
                "critic": "micro_batch_size",
                "reward_model": "micro_batch_size",
                "actor_rollout_ref.ref": "log_prob_micro_batch_size",
                "actor_rollout_ref.rollout": "log_prob_micro_batch_size",
            }

            if name in settings:
                param = settings[name]
                param_per_gpu = f"{param}_per_gpu"

                if mbs is None and mbs_per_gpu is None:
                    raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")

                if mbs is not None and mbs_per_gpu is not None:
                    raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).")

        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
            check_mutually_exclusive(
                config.actor_rollout_ref.actor.ppo_micro_batch_size,
                config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
                "actor_rollout_ref.actor",
            )

            if self.use_reference_policy:
                # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
                check_mutually_exclusive(
                    config.actor_rollout_ref.ref.log_prob_micro_batch_size,
                    config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
                    "actor_rollout_ref.ref",
                )

            # rollout: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
            check_mutually_exclusive(
                config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
                config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
                "actor_rollout_ref.rollout",
            )

        if self.use_critic and not config.critic.use_dynamic_bsz:
            # critic: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
            check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic")

        # reward model: micro_batch_size vs. micro_batch_size_per_gpu
        if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
            check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model")

        # actor: if NOT using dynamic batch size, check mini/micro batch consistency
        if not config.actor_rollout_ref.actor.use_dynamic_bsz:
            assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
            sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
            if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
                assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
                assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus

        assert config.actor_rollout_ref.actor.loss_agg_mode in [
            "token-mean",
            "seq-mean-token-sum",
            "seq-mean-token-mean",
            "seq-mean-token-sum-norm",
        ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"

        if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
            print("NOTICE: You have both enabled in-reward kl and kl loss.")

        # critic
        if self.use_critic and not config.critic.use_dynamic_bsz:
            assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
            sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
            if config.critic.ppo_micro_batch_size is not None:
                assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
                assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

        # check if use_remove_padding is enabled when using sequence parallelism for fsdp
        if config.actor_rollout_ref.actor.strategy == "fsdp" and (config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1):
            assert config.actor_rollout_ref.model.use_remove_padding, "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."

        if self.use_critic and config.critic.strategy == "fsdp":
            if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
                assert config.critic.model.use_remove_padding, "When using sequence parallelism for critic, you must enable `use_remove_padding`."

        if config.data.get("val_batch_size", None) is not None:
            print("WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves.")

        # check eval config
        if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
            assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample"

        # check multi-turn with tool config
        if config.actor_rollout_ref.rollout.multi_turn.enable:
            assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None, "tool_config_path must be set when enabling multi_turn with tool, due to no role-playing support"
            assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool"

        print("[validate_config] All configuration checks passed successfully!")
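
    # Config sketch (YAML-style, illustrative) of the mutually-exclusive pair checked
    # above; exactly one of the two may be set per section:
    #
    #     actor_rollout_ref:
    #       actor:
    #         ppo_micro_batch_size: null        # deprecated global form, leave unset
    #         ppo_micro_batch_size_per_gpu: 4   # preferred per-GPU form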

    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
        """
        Creates the train and validation dataloaders.
        """
        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler

        if train_dataset is None:
            train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor)
        if val_dataset is None:
            val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor)
        self.train_dataset, self.val_dataset = train_dataset, val_dataset

        if train_sampler is None:
            train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
        if collate_fn is None:
            from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

            collate_fn = default_collate_fn

        self.train_dataloader = StatefulDataLoader(
            dataset=self.train_dataset,
            batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
            num_workers=self.config.data.get("dataloader_num_workers", 8),
            drop_last=True,
            collate_fn=collate_fn,
            sampler=train_sampler,
        )

        val_batch_size = self.config.data.val_batch_size
        if val_batch_size is None:
            val_batch_size = len(self.val_dataset)

        self.val_dataloader = StatefulDataLoader(
            dataset=self.val_dataset,
            batch_size=val_batch_size,
            num_workers=self.config.data.get("dataloader_num_workers", 8),
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn,
        )

        assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
        assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"

        print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}")

        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        try:
            OmegaConf.set_struct(self.config, True)
            with open_dict(self.config):
                if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
                    self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
                if OmegaConf.select(self.config, "critic.optim"):
                    self.config.critic.optim.total_training_steps = total_training_steps
        except Exception as e:
            print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
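
    # StatefulDataLoader (torchdata) exposes state_dict()/load_state_dict(), which
    # _save_checkpoint/_load_checkpoint below rely on to resume mid-epoch. Sketch:
    #
    #     state = self.train_dataloader.state_dict()    # capture iteration position
    #     self.train_dataloader.load_state_dict(state)  # restore it after a restart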

    def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path):
        """Dump rollout/validation samples as JSONL."""
        os.makedirs(dump_path, exist_ok=True)
        filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")

        n = len(inputs)
        base_data = {
            "input": inputs,
            "output": outputs,
            "score": scores,
            "step": [self.global_steps] * n,
        }

        for k, v in reward_extra_infos_dict.items():
            if len(v) == n:
                base_data[k] = v

        with open(filename, "w") as f:
            for i in range(n):
                entry = {k: v[i] for k, v in base_data.items()}
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

        print(f"Dumped generations to {filename}")
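
    # Each JSONL line pairs one sample's fields, e.g. (illustrative values):
    #     {"input": "2+2=", "output": "4", "score": 1.0, "step": 10, "reward": 1.0}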

    def _maybe_log_val_generations(self, inputs, outputs, scores):
        """Log a table of validation samples to the configured logger (wandb or swanlab)"""

        generations_to_log = self.config.trainer.log_val_generations

        if generations_to_log == 0:
            return

        # create tuples of (input, output, score) and sort by input text
        samples = list(zip(inputs, outputs, scores))
        samples.sort(key=lambda x: x[0])

        # use a fixed random seed for deterministic shuffling
        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        # take the first N samples after shuffling
        samples = samples[:generations_to_log]

        self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)

    def _validate(self):
        data_source_lst = []
        reward_extra_infos_dict: dict[str, list] = defaultdict(list)

        # lists to collect samples for the table
        sample_inputs = []
        sample_outputs = []
        sample_scores = []

        for test_data in self.val_dataloader:
            test_batch = DataProto.from_single_dict(test_data)

            # repeat test batch
            test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True)

            # we only do validation on rule-based rm
            if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
                return {}

            # store original inputs
            input_ids = test_batch.batch["input_ids"]
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
            if "multi_modal_inputs" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
            if "raw_prompt" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("raw_prompt")
            if "tools_kwargs" in test_batch.non_tensor_batch:
                non_tensor_batch_keys_to_pop.append("tools_kwargs")
            test_gen_batch = test_batch.pop(
                batch_keys=batch_keys_to_pop,
                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
            )

            test_gen_batch.meta_info = {
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": self.tokenizer.pad_token_id,
                "recompute_log_prob": False,
                "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
                "validate": True,
            }
            print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

            # pad to be divisible by dp_size
            test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            if not self.async_rollout_mode:
                test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
            else:
                self.async_rollout_manager.wake_up()
                test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
                self.async_rollout_manager.sleep()

            # unpad
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
            print("validation generation end")

            # store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)

            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using the reward function
            result = self.val_reward_fn(test_batch, return_dict=True)
            reward_tensor = result["reward_tensor"]
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            print("val-scores in batch: ", scores)

            reward_extra_infos_dict["reward"].extend(scores)
            if "reward_extra_info" in result:
                for key, lst in result["reward_extra_info"].items():
                    reward_extra_infos_dict[key].extend(lst)

            data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))

        self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

        # dump generations
        val_data_dir = self.config.trainer.get("validation_data_dir", None)
        if val_data_dir:
            self._dump_generations(
                inputs=sample_inputs,
                outputs=sample_outputs,
                scores=sample_scores,
                reward_extra_infos_dict=reward_extra_infos_dict,
                dump_path=val_data_dir,
            )

        for key_info, lst in reward_extra_infos_dict.items():
            assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"

        data_sources = np.concatenate(data_source_lst, axis=0)

        data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
        metric_dict = {}
        for data_source, var2metric2val in data_src2var2metric2val.items():
            core_var = "acc" if "acc" in var2metric2val else "reward"
            for var_name, metric2val in var2metric2val.items():
                n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
                for metric_name, metric_val in metric2val.items():
                    if (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}" in metric_name):
                        metric_sec = "val-core"
                    else:
                        metric_sec = "val-aux"
                    pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
                    metric_dict[pfx] = metric_val

        return metric_dict
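
    # Metric keys assembled above look like (illustrative names):
    #     "val-core/gsm8k/acc/mean@8"    <- core variable at the largest sample count
    #     "val-aux/gsm8k/reward/best@4"  <- everything else lands under val-aux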

    def init_workers(self):
        """Init resource pool and worker group"""
        self.resource_pool_manager.create_resource_pool()

        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.ActorRollout],
                config=self.config.actor_rollout_ref,
                role="actor_rollout",
            )
            self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.use_critic:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
            self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref")
            self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

        # create reward model worker if enabled
        if self.use_rm:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
            self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls

        # initialize WorkerGroups
        all_wg = {}
        wg_kwargs = {}  # kwargs for RayWorkerGroup
        if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
            wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout

        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs)
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)

        if self.use_critic:
            self.critic_wg = all_wg["critic"]
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg["ref"]
            self.ref_policy_wg.init_model()

        if self.use_rm:
            self.rm_wg = all_wg["rm"]
            self.rm_wg.init_model()

        # create the rollout last so that the inference engine can better estimate KV-cache memory
        self.actor_rollout_wg = all_wg["actor_rollout"]
        self.actor_rollout_wg.init_model()

        # create async rollout manager and request scheduler
        self.async_rollout_mode = False
        if self.config.actor_rollout_ref.rollout.mode == "async":
            self.async_rollout_mode = True
            self.async_rollout_manager = AsyncLLMServerManager(
                config=self.config.actor_rollout_ref,
                worker_group=self.actor_rollout_wg,
            )
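
    # Colocation note (sketch): roles mapped to the same resource pool are fused into one
    # Ray actor per process by create_colocated_worker_cls, then split back into per-role
    # handles via wg_dict.spawn(), so e.g. "actor_rollout" and "ref" share GPU processes.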

    def _save_checkpoint(self):
        # path: given_path + `/global_step_{global_steps}`
        local_global_step_folder = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")

        print(f"local_global_step_folder: {local_global_step_folder}")

        actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")

        remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
        if remove_previous_ckpt_in_save:
            print("Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead")
        max_actor_ckpt_to_keep = self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
        max_critic_ckpt_to_keep = self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1

        self.actor_rollout_wg.save_checkpoint(local_global_step_folder, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep)

        if self.use_critic:
            critic_local_path = os.path.join(local_global_step_folder, "critic")
            critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
            self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep)

        # save dataloader state so training can resume mid-epoch
        dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
        dataloader_state_dict = self.train_dataloader.state_dict()
        torch.save(dataloader_state_dict, dataloader_local_path)
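
    # Resulting checkpoint layout (illustrative; the actor worker defines its own
    # sub-layout under the step folder):
    #     {default_local_dir}/global_step_{N}/critic/  <- critic state (if use_critic)
    #     {default_local_dir}/global_step_{N}/data.pt  <- dataloader state for resumption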

    def _load_checkpoint(self):
        if self.config.trainer.resume_mode == "disable":
            return 0

        # load from hdfs
        if self.config.trainer.default_hdfs_dir is not None:
            raise NotImplementedError("load from hdfs is not implemented yet")
        else:
            checkpoint_folder = self.config.trainer.default_local_dir
            if not os.path.isabs(checkpoint_folder):
                working_dir = os.getcwd()
                checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
            global_step_folder = find_latest_ckpt_path(checkpoint_folder)

        # find the global_step_folder to resume from
        if self.config.trainer.resume_mode == "auto":
            if global_step_folder is None:
                print("Training from scratch")
                return 0
        else:
            if self.config.trainer.resume_mode == "resume_path":
                assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
                assert "global_step_" in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
                global_step_folder = self.config.trainer.resume_from_path
                if not os.path.isabs(global_step_folder):
                    working_dir = os.getcwd()
                    global_step_folder = os.path.join(working_dir, global_step_folder)
        print(f"Load from checkpoint folder: {global_step_folder}")

        # set global step
        self.global_steps = int(global_step_folder.split("global_step_")[-1])

        print(f"Setting global step to {self.global_steps}")
        print(f"Resuming from {global_step_folder}")

        actor_path = os.path.join(global_step_folder, "actor")
        critic_path = os.path.join(global_step_folder, "critic")

        # load actor
        self.actor_rollout_wg.load_checkpoint(actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

        # load critic
        if self.use_critic:
            self.critic_wg.load_checkpoint(critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)

        # load dataloader state
        dataloader_local_path = os.path.join(global_step_folder, "data.pt")
        if os.path.exists(dataloader_local_path):
            dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state_dict)
        else:
            print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()  # (train_batch_size,)
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst, k_partitions=world_size, equal_size=True)
        # reorder based on index; the data will be automatically equally partitioned by the dispatch function
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix)
        metrics.update(global_balance_stats)
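
    # Balancing sketch (illustrative): with per-sample token counts [9, 1, 5, 5] and
    # world_size=2, an equal-size partition such as [[0, 1], [2, 3]] gives each rank
    # 10 tokens; batch.reorder then lays each partition's rows out contiguously.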

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The lightweight advantage computation is done on the driver process.
        """
        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        # currently, we only support validation using the reward_function
        if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)

            if self.config.trainer.get("val_only", False):
                return

        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None
        best_val_rewards = float("-inf")
        for epoch in range(self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # pop those keys for generation
                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
                if "multi_modal_inputs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
                if "raw_prompt" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("raw_prompt")
                if "tools_kwargs" in batch.non_tensor_batch:
                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
                gen_batch = batch.pop(
                    batch_keys=batch_keys_to_pop,
                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
                )

                is_last_step = self.global_steps >= self.total_training_steps

                with _timer("step", timing_raw):
                    # generate a batch
                    with _timer("gen", timing_raw):
                        if not self.async_rollout_mode:
                            gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
                        else:
                            self.async_rollout_manager.wake_up()
                            gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
                            self.async_rollout_manager.sleep()

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        with _timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["do_sample"] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                            batch.batch["reward_baselines"] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    batch.batch["response_mask"] = compute_response_mask(batch)

                    # Balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch; take care
                    # with group-based advantage computation such as GRPO and RLOO.
                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    # compute global valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    with _timer("reward", timing_raw):
                        # compute reward model score
                        if self.use_rm:
                            reward_tensor = self.rm_wg.compute_rm_score(batch)
                            batch = batch.union(reward_tensor)

                        if self.config.reward_model.launch_reward_fn_async:
                            future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
                        else:
                            reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

                    # recompute old_log_probs
                    with _timer("old_log_prob", timing_raw):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        entropys = old_log_prob.batch["entropys"]
                        response_masks = batch.batch["response_mask"]
                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
                        entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
                        old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
                        metrics.update(old_log_prob_metrics)
                        old_log_prob.batch.pop("entropys")
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with _timer("ref", timing_raw):
                            ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    # compute values
                    if self.use_critic:
                        with _timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with _timer("adv", timing_raw):
                        # we combine with rule-based rm
                        reward_extra_infos_dict: dict[str, list]
                        if self.config.reward_model.launch_reward_fn_async:
                            reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                        batch.batch["token_level_scores"] = reward_tensor

                        print(f"{list(reward_extra_infos_dict.keys())=}")
                        if reward_extra_infos_dict:
                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # compute rewards; apply_kl_penalty if available
                        if self.config.algorithm.use_kl_in_reward:
                            batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)

                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                            num_repeat=self.config.actor_rollout_ref.rollout.n,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                            multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable,
                        )

                    # update critic
                    if self.use_critic:
                        with _timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
                        metrics.update(critic_output_metrics)

                    # implement critic warmup
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        # update actor
                        with _timer("update_actor", timing_raw):
                            batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                        metrics.update(actor_output_metrics)

                    # log rollout generations if enabled
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir:
                        with _timer("dump_rollout_generations", timing_raw):
                            print(batch.batch.keys())
                            inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
                            outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
                            scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
                            self._dump_generations(
                                inputs=inputs,
                                outputs=outputs,
                                scores=scores,
                                reward_extra_infos_dict=reward_extra_infos_dict,
                                dump_path=rollout_data_dir,
                            )

                    # validate
                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with _timer("testing", timing_raw):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                        # track the best validation reward seen so far
                        for key in val_metrics:
                            if "val-aux" in key and "mean@" in key:
                                if val_metrics[key] > best_val_rewards:
                                    best_val_rewards = val_metrics[key]

                    if self.config.trainer.save_freq > 0:
                        with _timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # training metrics
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/epoch": epoch,
                    }
                )
                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

                logger.log(data=metrics, step=self.global_steps)

                if is_last_step:
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                progress_bar.update(1)
                self.global_steps += 1
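
# Driver-side usage sketch (illustrative; the real wiring lives in verl.trainer.main_ppo):
#
#     trainer = RayPPOTrainer(
#         config=config,
#         tokenizer=tokenizer,
#         role_worker_mapping=role_worker_mapping,
#         resource_pool_manager=resource_pool_manager,
#         reward_fn=reward_fn,
#         val_reward_fn=val_reward_fn,
#     )
#     trainer.init_workers()
#     trainer.fit()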