| """ |
| FSDP PPO Trainer with Ray-based single controller. |
| This trainer supports model-agonistic model initialization with huggingface |
| """ |

import os
import uuid
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Type

import numpy as np
import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto
from ..single_controller.base import Worker
from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str, timer
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics


class Role(IntEnum):
    """
    To create more roles dynamically, you can subclass Role and add new members.
    """

    Actor = auto()
    Rollout = auto()
    ActorRollout = auto()
    Critic = auto()
    RefPolicy = auto()
    RewardModel = auto()
    ActorRolloutRef = auto()


class AdvantageEstimator(str, Enum):
    """
    Using an enumeration class to avoid spelling errors in adv_estimator.
    """

    GAE = "gae"
    GRPO = "grpo"
    REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
    REMAX = "remax"
    RLOO = "rloo"


@dataclass
class ResourcePoolManager:
    """
    Defines a resource pool specification. The resource pool will be initialized first.
    """

    resource_pool_spec: dict[str, list[int]]
    mapping: dict[Role, str]
    resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
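    # Illustrative spec (hypothetical pool name and GPU counts): one pool named "global_pool"
    # spanning two nodes with 8 GPUs each, shared by all roles. The actual values come from the
    # trainer configuration.
    #   resource_pool_spec = {"global_pool": [8, 8]}
    #   mapping = {Role.ActorRollout: "global_pool", Role.Critic: "global_pool", Role.RefPolicy: "global_pool"}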

    def create_resource_pool(self):
        for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
            resource_pool = RayResourcePool(
                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
            )
            self.resource_pool_dict[resource_pool_name] = resource_pool

        self._check_resource_available()

    def get_resource_pool(self, role: Role) -> RayResourcePool:
        """Get the resource pool of the worker."""
        return self.resource_pool_dict[self.mapping[role]]

    def get_num_gpus(self) -> int:
        """Get the number of GPUs in this cluster."""
        return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

    def _check_resource_available(self):
        """Check if the resource pool can be satisfied in this ray cluster."""
        gpus_available = ray.available_resources().get("GPU", 0)
        gpus_required = self.get_num_gpus()
        if gpus_available < gpus_required:
            raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")


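# apply_kl_penalty folds the KL divergence to the reference policy into the token-level reward:
#   token_level_rewards[t] = token_level_scores[t] - kl_coef * KL(old_log_probs[t], ref_log_probs[t])
# where the KL term is masked by response_mask and kl_coef comes from the KL controller.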
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
    token_level_scores = data.batch["token_level_scores"]
    batch_size = data.batch.batch_size[0]
    response_mask = data.batch["response_mask"]

    kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
    kld = kld * response_mask

    data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld

    current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1)
    current_kl = torch.mean(current_kl, dim=0).item()
    metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef}

    kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
    return data, metrics


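# compute_advantage dispatches to the estimator selected in the config and writes "advantages"
# and "returns" into the batch. Rough sketch (see core_algos for the exact implementations):
#   - GAE: uses critic values; delta_t = r_t + gamma * V_{t+1} - V_t, A_t = delta_t + gamma * lam * A_{t+1}
#   - GRPO / RLOO: critic-free; responses sharing the same "uid" (same prompt) form a group and the
#     group statistics provide the baseline
#   - REMAX: critic-free; uses the greedy-rollout reward stored in "reward_baselines" as the baseline
#   - REINFORCE++: critic-free; see core_algos.compute_reinforce_plus_plus_outcome_advantage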
def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0):
    token_level_rewards = data.batch["token_level_rewards"]
    response_mask = data.batch["response_mask"]
    index = data.non_tensor_batch["uid"]
    if adv_estimator == AdvantageEstimator.GAE:
        values = data.batch["values"]
        advantages, returns = core_algos.compute_gae_advantage_return(
            token_level_rewards, values, response_mask, gamma, lam
        )
    elif adv_estimator == AdvantageEstimator.GRPO:
        advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index)
    elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
            token_level_rewards, response_mask, gamma
        )
    elif adv_estimator == AdvantageEstimator.REMAX:
        reward_baselines = data.batch["reward_baselines"]
        advantages, returns = core_algos.compute_remax_outcome_advantage(
            token_level_rewards, reward_baselines, response_mask
        )
    elif adv_estimator == AdvantageEstimator.RLOO:
        advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index)
    else:
        raise NotImplementedError

    data.batch["advantages"] = advantages
    data.batch["returns"] = returns
    return data


class RayPPOTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def __init__(
        self,
        config: PPOConfig,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        train_dataloader: StatefulDataLoader,
        val_dataloader: StatefulDataLoader,
        role_worker_mapping: dict[Role, Type[Worker]],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
        reward_fn: Optional[FunctionRewardManager] = None,
        val_reward_fn: Optional[FunctionRewardManager] = None,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.config = config
        self.reward_fn = reward_fn
        self.val_reward_fn = val_reward_fn

        self.hybrid_engine = config.worker.hybrid_engine
        if self.hybrid_engine:
            assert Role.ActorRollout in role_worker_mapping, (
                f"ActorRollout should be included in {role_worker_mapping.keys()}."
            )
        else:
            raise NotImplementedError

        self.role_worker_mapping = role_worker_mapping
        self.resource_pool_manager = resource_pool_manager
        self.use_reward_model = Role.RewardModel in role_worker_mapping
        self.ray_worker_group_cls = ray_worker_group_cls

        # define KL control
        if Role.RefPolicy in role_worker_mapping and not config.algorithm.disable_kl:
            self.use_reference_policy = True
            self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
        else:
            self.use_reference_policy = False
            self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0)
            print("KL is disabled, no KL metrics will be logged. To log KL metrics, keep the reference policy enabled and set `kl_coef=0`.")

        if config.algorithm.adv_estimator == AdvantageEstimator.GAE:
            self.use_critic = True
        else:
            self.use_critic = False

        if config.algorithm.adv_estimator not in list(AdvantageEstimator):
            raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.")

        if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0:
            raise ValueError("Rollout batch size must be divisible by actor global batch size.")

        if (
            config.data.rollout_batch_size * config.worker.rollout.n
        ) % config.worker.actor.micro_batch_size_per_device_for_experience != 0:
            raise ValueError(
                "Rollout batch size * rollout.n must be divisible by actor micro batch size for experience."
            )

        if self.use_critic:
            if config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0:
                raise ValueError("Rollout batch size must be divisible by critic global batch size.")

            if (
                config.data.rollout_batch_size * config.worker.rollout.n
            ) % config.worker.critic.micro_batch_size_per_device_for_experience != 0:
                raise ValueError(
                    "Rollout batch size * rollout.n must be divisible by critic micro batch size for experience."
                )

        if (
            config.algorithm.adv_estimator in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO)
            and config.worker.rollout.n == 1
        ):
            raise ValueError("GRPO and RLOO algorithms need `config.worker.rollout.n > 1`.")

        if config.trainer.max_steps is not None:
            self.training_steps = config.trainer.max_steps
        else:
            self.training_steps = len(train_dataloader) * config.trainer.total_epochs

        config.worker.actor.optim.training_steps = self.training_steps
        config.worker.critic.optim.training_steps = self.training_steps
        print(f"Total training steps: {self.training_steps}")

    def _maybe_log_val_generations(
        self, inputs: List[str], outputs: List[str], labels: List[str], scores: List[float]
    ) -> None:
        """Log a table of validation samples."""
        if self.config.trainer.val_generations_to_log <= 0:
            return

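        # Sort by prompt and shuffle with a fixed seed so the same subset of samples is selected
        # each time validation runs, which keeps the logged generations comparable across steps.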
        samples = list(zip(inputs, outputs, labels, scores))
        samples.sort(key=lambda x: x[0])

        rng = np.random.RandomState(42)
        rng.shuffle(samples)

        samples = samples[: self.config.trainer.val_generations_to_log]
        self.logger.log_generation(samples, self.global_step)

    def _validate(self) -> Dict[str, Any]:
        reward_tensor_lst = []
        sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
        reward_metrics_lst = defaultdict(list)
        for batch_dict in self.val_dataloader:
            test_batch = DataProto.from_single_dict(batch_dict)
            input_ids = test_batch.batch["input_ids"]
            input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
            sample_inputs.extend(input_texts)

            test_gen_batch = test_batch.pop(
                batch_keys=["input_ids", "attention_mask", "position_ids"],
                non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
            )
            test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
            test_gen_batch.meta_info["min_pixels"] = self.config.data.min_pixels
            test_gen_batch.meta_info["max_pixels"] = self.config.data.max_pixels
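            # pad the batch so it divides evenly across the rollout workers, then strip the padding
            # from the generated output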
            test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)

            output_ids = test_output_gen_batch.batch["responses"]
            output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
            sample_outputs.extend(output_texts)
            sample_labels.extend(test_batch.non_tensor_batch["ground_truth"].tolist())
            test_batch = test_batch.union(test_output_gen_batch)

            reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch))
            scores = reward_tensor.sum(-1).cpu().tolist()
            sample_scores.extend(scores)

            reward_tensor_lst.append(reward_tensor)
            for key, value in reward_metrics.items():
                reward_metrics_lst[key].extend(value)

        self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
        reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
        val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
        return {"val/reward_score": reward_score, **val_reward_metrics}

    def init_workers(self) -> None:
        """Init resource pool and worker group."""
        self.resource_pool_manager.create_resource_pool()
        self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

        # create actor and rollout
        if self.hybrid_engine:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
            actor_rollout_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.worker, role="actor_rollout"
            )
            self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
        else:
            raise NotImplementedError

        # create critic
        if self.use_critic:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
            critic_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.Critic], config=self.config.worker, role="critic"
            )
            self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls

        # create reference policy if needed
        if self.use_reference_policy:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
            ref_policy_cls = RayClassWithInitArgs(
                self.role_worker_mapping[Role.RefPolicy], config=self.config.worker, role="ref"
            )
            self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls

        # create a reward model worker if one is configured
        if self.use_reward_model:
            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
            rm_cls = RayClassWithInitArgs(
                cls=self.role_worker_mapping[Role.RewardModel], config=self.config.worker, role="reward"
            )
            self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls

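        # Each resource pool gets a single colocated worker class that fuses all roles assigned to
        # that pool; spawn() then splits it back into per-role worker group handles that share the
        # same underlying processes (and therefore the same GPUs).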
        all_wg: Dict[str, FSDPWorker] = {}
        self.wg_dicts = []
        for resource_pool, class_dict in self.resource_pool_to_cls.items():
            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
            wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
            all_wg.update(spawn_wg)
            # keep references to the colocated worker dicts
            self.wg_dicts.append(wg_dict)

        if self.use_critic:
            self.critic_wg = all_wg["critic"]
            self.critic_wg.init_model()

        if self.use_reference_policy:
            self.ref_policy_wg = all_wg["ref"]
            self.ref_policy_wg.init_model()

        if self.use_reward_model:
            self.rm_wg = all_wg["rm"]
            self.rm_wg.init_model()

        # initialize the actor/rollout worker group last
        self.actor_rollout_wg = all_wg["actor_rollout"]
        self.actor_rollout_wg.init_model()

    def _save_checkpoint(self) -> None:
        # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
        remove_obsolete_ckpt(
            self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
        )
        folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
        actor_path = os.path.join(folder_path, "actor")
        self.actor_rollout_wg.save_checkpoint(actor_path)

        if self.use_critic:
            critic_path = os.path.join(folder_path, "critic")
            self.critic_wg.save_checkpoint(critic_path)

        dataloader_path = os.path.join(folder_path, "dataloader.pt")
        dataloader_state_dict = self.train_dataloader.state_dict()
        torch.save(dataloader_state_dict, dataloader_path)

        last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
        with open(last_global_step_path, "w") as f:
            f.write(str(self.global_step))

    def _load_checkpoint(self) -> None:
        if self.config.trainer.load_checkpoint_path is None:
            return

        if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
            raise ValueError("`load_checkpoint_path` should end with `global_step_*`.")

        print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.")
        self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
        actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
        self.actor_rollout_wg.load_checkpoint(actor_path)
        if self.use_critic:
            critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
            self.critic_wg.load_checkpoint(critic_path)

        dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt")
        if os.path.exists(dataloader_path):
            dataloader_state_dict = torch.load(dataloader_path, weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state_dict)
        else:
            print(f"No dataloader state found at {dataloader_path}, will start from scratch.")

    def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None:
        """Reorder the data on the single controller so that each dp rank gets a similar number of tokens."""
        attention_mask = batch.batch["attention_mask"]
        batch_size = attention_mask.shape[0]
        global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist()
        world_size = self.actor_rollout_wg.world_size
        global_partition_lst = get_seqlen_balanced_partitions(
            global_seqlen_lst, k_partitions=world_size, equal_size=True
        )
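        # equal_size=True gives every rank the same number of sequences while balancing the token
        # sums, e.g. (illustrative) seqlens [9, 7, 5, 3] over 2 ranks -> partitions [[0, 3], [1, 2]]
        # with 12 tokens each.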
        global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
        batch.reorder(global_idx)
        global_balance_stats = log_seqlen_unbalance(
            seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
        )
        metrics.update(global_balance_stats)

    def fit(self):
        """
        The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC
        to construct the PPO dataflow. The lightweight advantage computation is done on the driver process.
        """
        self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict())
        self.global_step = 0
        val_metrics: Optional[Dict[str, Any]] = None

        # load checkpoint before doing anything
        self._load_checkpoint()

        # perform validation before training
        if self.val_reward_fn is not None and self.config.trainer.val_before_train:
            val_metrics = self._validate()
            self.logger.log(data=val_metrics, step=self.global_step)
            if self.config.trainer.val_only:
                return

        for _ in tqdm(range(self.config.trainer.total_epochs), desc="Epoch", position=0):
            for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
                self.global_step += 1
                if self.global_step > self.training_steps:
                    break

                metrics, timing_raw = {}, {}
                meta_info = {"min_pixels": self.config.data.min_pixels, "max_pixels": self.config.data.max_pixels}
                batch: DataProto = DataProto.from_single_dict(batch_dict, meta_info=meta_info)

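                # pop the generation inputs into a separate DataProto so only those tensors are
                # shipped to the rollout workers; the remaining fields are merged back via union below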
                gen_batch = batch.pop(
                    batch_keys=["input_ids", "attention_mask", "position_ids"],
                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                    meta_info_keys=["min_pixels", "max_pixels"],
                )
                with timer("step", timing_raw):
                    # generate a batch
                    with timer("gen", timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

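                    # REMAX baseline: an extra greedy rollout (temperature=0, n=1) is scored and its
                    # sequence-level reward is stored as "reward_baselines" for advantage computation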
                    if self.config.algorithm.adv_estimator == "remax":
                        with timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["temperature"] = 0
                            gen_baseline_batch.meta_info["n"] = 1
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch))
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch["reward_baselines"] = reward_baseline_tensor
                            del gen_baseline_batch, gen_baseline_output

                    batch.non_tensor_batch["uid"] = np.array(
                        [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                    )
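                    # repeat the batch rollout.n times so each sampled response is matched to its
                    # prompt group; the uid assigned above identifies the group (used by GRPO/RLOO)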
                    batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    # balance the number of valid tokens on each dp rank;
                    # note that this reorders the data inside the batch
                    self._balance_batch(batch, metrics=metrics)

                    # compute the global number of valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

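                    # launch reward computation as an async remote task; it overlaps with the
                    # log-prob and value computations below and is gathered via ray.get in the adv step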
| with timer("reward", timing_raw): |
| reward_ref = self.reward_fn.compute_reward.remote(batch) |
|
|
| |
| with timer("old", timing_raw): |
| old_log_probs = self.actor_rollout_wg.compute_log_probs(batch) |
| batch = batch.union(old_log_probs) |
|
|
| |
| if self.use_reference_policy: |
| with timer("ref", timing_raw): |
| ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch) |
| batch = batch.union(ref_log_probs) |
|
|
| |
| if self.use_critic: |
| with timer("values", timing_raw): |
| values = self.critic_wg.compute_values(batch) |
| batch = batch.union(values) |
|
|
| with timer("adv", timing_raw): |
| |
| reward_tensor, reward_metrics = ray.get(reward_ref) |
| batch.batch["token_level_scores"] = reward_tensor |
| reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()} |
| metrics.update(reward_metrics) |
|
|
| |
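                        # KL can be applied either as a penalty on the token-level reward (here) or
                        # as a loss term inside the actor update (use_kl_loss); only one of the two is used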
                        if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
                            batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                        # compute advantages, executed on the driver process
                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            gamma=self.config.algorithm.gamma,
                            lam=self.config.algorithm.lam,
                        )

                    # update critic
                    if self.use_critic:
                        with timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)

                        critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
                        metrics.update(critic_metrics)

                    # update actor after the critic warmup steps
                    if self.config.trainer.critic_warmup <= self.global_step:
                        with timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)

                        actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
                        metrics.update(actor_metrics)

                    # validate
                    if (
                        self.val_reward_fn is not None
                        and self.config.trainer.val_freq > 0
                        and self.global_step % self.config.trainer.val_freq == 0
                    ):
                        with timer("validation", timing_raw):
                            val_metrics = self._validate()

                        metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
                        with timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
                num_gpus = self.resource_pool_manager.get_num_gpus()
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))

                self.logger.log(data=metrics, step=self.global_step)

        # perform validation after training
        if self.val_reward_fn is not None:
            if (
                val_metrics is None
                or self.config.trainer.val_freq <= 0
                or self.global_step % self.config.trainer.val_freq != 0
            ):
                val_metrics = self._validate()
                self.logger.log(data=val_metrics, step=self.global_step)

            print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}")

        if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0:
            self._save_checkpoint()