| | """ |
| | TODO: Add coefficients for losses (depend on total number of tokens or batch) |
| | TODO: adapt reinforce step for torch.compile |
| | TODO: add lr schedulers support |
| | """ |
| | import logging |
| | import os |
| | import pickle |
| | import sys |
| | from abc import ABC, abstractmethod |
| | from typing import Callable, Literal, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from accelerate import Accelerator |
| | from pandas._libs.tslibs.offsets import CBMonthBegin |
| | from peft import LoraConfig |
| | from torch.nn.utils.rnn import pad_sequence |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
|
| | from mllm.markov_games.rollout_tree import * |
| | from mllm.markov_games.rollout_tree import RolloutTreeRootNode |
| | from mllm.training.annealing_methods import sigmoid_annealing |
| | from mllm.training.credit_methods import ( |
| | get_discounted_returns, |
| | get_generalized_advantage_estimates, |
| | get_rloo_credits, |
| | whiten_advantages, |
| | whiten_advantages_time_step_wise, |
| | ) |
| | from mllm.training.tally_metrics import Tally |
| | from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem |
| | from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally |
| | from mllm.training.tokenize_chats import * |
| | from mllm.training.tokenize_chats import process_training_chat |
| | from mllm.training.training_data_utils import * |
| | from mllm.training.training_data_utils import ( |
| | TrainingBatch, |
| | TrajectoryBatch, |
| | get_tokenwise_credits, |
| | ) |
| | from mllm.utils.resource_context import resource_logger_context |
| |
|
| | logger = logging.getLogger(__name__) |
| | logger.addHandler(logging.StreamHandler(sys.stdout)) |


@dataclass
class TrainerAnnealingState:
    annealing_step_counter: int = 0


class BaseTrainer(ABC):
    """
    Abstract base trainer for REINFORCE-style policy-gradient training, with
    optional critic-based GAE, RLOO baselines, KL regularization against a
    reference model, and truncated importance sampling.
    """

    def __init__(
        self,
        policy: AutoModelForCausalLM,
        policy_optimizer: torch.optim.Optimizer,
        critic: Union[AutoModelForCausalLM, None],
        critic_optimizer: Union[torch.optim.Optimizer, None],
        tokenizer: AutoTokenizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None],
        entropy_coeff: float,
        entropy_topk: int,
        entropy_mask_regex: Union[str, None],
        kl_coeff: float,
        gradient_clipping: Union[float, None],
        restrict_tokens: Union[list[str], None],
        mini_batch_size: int,
        use_gradient_checkpointing: bool,
        temperature: float,
        device: str,
        whiten_advantages: bool,
        whiten_advantages_time_step_wise: bool,
        use_gae: bool,
        use_gae_lambda_annealing: bool,
        gae_lambda_annealing_limit: float,
        gae_lambda_annealing_method: Literal["sigmoid_annealing"],
        gae_lambda_annealing_method_params: dict,
        pg_loss_normalization: Literal["batch", "nb_tokens"],
        use_rloo: bool,
        skip_discounted_state_visitation: bool,
        discount_factor: float,
        enable_tokenwise_logging: bool,
        save_path: str,
        reward_normalizing_constant: float = 1.0,
        critic_loss_type: Literal["mse", "huber"] = "huber",
        exploration_prompts_to_remove: Union[list[str], None] = None,
        filter_higher_refprob_tokens_kl: bool = False,
        truncated_importance_sampling_ratio_cap: float = 0.0,
        importance_sampling_strategy: Literal[
            "per_token", "per_sequence"
        ] = "per_token",
    ):
        """
        Initialize the REINFORCE trainer with reward shaping for multi-agent or
        single-agent training.

        Args:
            policy (AutoModelForCausalLM): The main policy model.
            policy_optimizer (torch.optim.Optimizer): Optimizer for the policy model.
            critic (AutoModelForCausalLM or None): Critic model for value estimation (optional).
            critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic model (optional).
            tokenizer (AutoTokenizer): Tokenizer for the policy model.
            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Learning rate scheduler for the policy model.
            critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic (optional).

        The remaining arguments are training hyperparameters; each is stored as
        an attribute of the same name below.
        """
        self.tokenizer = tokenizer

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.lr_scheduler = lr_scheduler
        self.accelerator = Accelerator()
        (
            self.policy,
            self.policy_optimizer,
            self.critic,
            self.critic_optimizer,
        ) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer)

        self.critic_lr_scheduler = critic_lr_scheduler
        self.tally = Tally()

        if use_gradient_checkpointing:
            self.policy.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
            if critic is not None:
                self.critic.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )

        self.save_path = save_path

        self.trainer_annealing_state_path = os.path.join(
            self.save_path, "trainer_annealing_state.pkl"
        )
        if os.path.exists(self.trainer_annealing_state_path):
            logger.info(
                f"Loading trainer state from {self.trainer_annealing_state_path}"
            )
            with open(self.trainer_annealing_state_path, "rb") as f:
                self.trainer_annealing_state = pickle.load(f)
        else:
            self.trainer_annealing_state = TrainerAnnealingState()

        self.policy_optimizer_path = os.path.join(
            self.save_path, "policy_optimizer_state.pt"
        )
        if os.path.exists(self.policy_optimizer_path):
            logger.info(
                f"Loading policy optimizer state from {self.policy_optimizer_path}"
            )
            self.policy_optimizer.load_state_dict(
                torch.load(self.policy_optimizer_path)
            )

        self.critic_optimizer_path = os.path.join(
            self.save_path, "critic_optimizer_state.pt"
        )
        if (
            os.path.exists(self.critic_optimizer_path)
            and self.critic_optimizer is not None
        ):
            logger.info(
                f"Loading critic optimizer state from {self.critic_optimizer_path}"
            )
            self.critic_optimizer.load_state_dict(
                torch.load(self.critic_optimizer_path)
            )
        self.device = self.accelerator.device
        self.entropy_coeff = entropy_coeff
        self.entropy_topk = entropy_topk
        self.entropy_mask_regex = entropy_mask_regex
        self.kl_coeff = kl_coeff
        self.gradient_clipping = gradient_clipping
        self.restrict_tokens = restrict_tokens
        self.mini_batch_size = mini_batch_size
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.temperature = temperature
        self.use_gae = use_gae
        self.whiten_advantages = whiten_advantages
        self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise
        self.use_rloo = use_rloo
        self.skip_discounted_state_visitation = skip_discounted_state_visitation
        self.use_gae_lambda_annealing = use_gae_lambda_annealing
        self.gae_lambda_annealing_limit = gae_lambda_annealing_limit
        if use_gae_lambda_annealing:
            # Resolve the schedule by name instead of eval-ing a config string.
            annealing_methods = {"sigmoid_annealing": sigmoid_annealing}
            annealing_fn = annealing_methods[gae_lambda_annealing_method]
            self.gae_lambda_annealing_method: Callable[[int], float] = (
                lambda step: annealing_fn(
                    step=step, **gae_lambda_annealing_method_params
                )
            )
        self.discount_factor = discount_factor
        self.enable_tokenwise_logging = enable_tokenwise_logging
        self.reward_normalizing_constant = reward_normalizing_constant
        self.pg_loss_normalization = pg_loss_normalization
        self.critic_loss_type = critic_loss_type
        self.exploration_prompts_to_remove = exploration_prompts_to_remove or []

        self.training_data: dict = {}
        self.debug_path_list: list[str] = []
        self.policy_gradient_data = None
        self.rollout_tally = RolloutTally()
        self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None
        self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl
        self.truncated_importance_sampling_ratio_cap = (
            truncated_importance_sampling_ratio_cap
        )
        self.importance_sampling_strategy = importance_sampling_strategy

    def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Masks logits so that only allowed tokens (as specified in restrict_tokens)
        and the EOS token remain active.
        All other logits are set to -inf, effectively removing them from the softmax.

        Args:
            logits (torch.Tensor): The logits tensor of shape (B, S, V).

        Returns:
            torch.Tensor: The masked logits tensor.
        """
        if self.restrict_tokens is not None:
            allowed_token_ids = []
            for token in self.restrict_tokens:
                # Note: only the first token id of each (possibly multi-token)
                # entry is kept.
                token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"]
                allowed_token_ids.append(token_ids[0])
            allowed_token_ids.append(self.tokenizer.eos_token_id)
            allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device)

            mask = torch.zeros_like(logits).bool()
            mask[..., allowed_token_ids] = True
            logits = torch.where(
                mask,
                logits,
                torch.tensor(-float("inf"), device=logits.device),
            )

        return logits
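
    # Illustrative sketch (hypothetical tokens and shapes): with
    # restrict_tokens=["yes", "no"], only the first token id of "yes", the
    # first token id of "no", and the EOS id keep finite logits, so the
    # softmax puts zero mass everywhere else:
    #
    #   logits = torch.randn(2, 7, len(tokenizer))  # (B, S, V)
    #   masked = trainer.mask_non_restricted_token_logits(logits)
    #   assert (masked.isfinite().sum(dim=-1) == 3).all()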

    def apply_reinforce_step(
        self,
        training_batch: TrainingBatch,
    ) -> dict:
        """
        Applies a single REINFORCE policy gradient step using the provided batch
        of rollouts. Handles mini-batching, loss computation (including entropy
        and KL regularization), gradient accumulation, and the optimizer step.
        Optionally logs various metrics and statistics.

        Args:
            training_batch (TrainingBatch): Batched rollout data (input ids,
                action masks, entropy masks, credits, engine log-probs, and
                timesteps).

        Returns:
            dict: Running means of the logged training metrics.
        """
        with resource_logger_context(logger, "Apply reinforce step"):
            self.policy.train()
            mb_size = self.mini_batch_size
            nb_rollouts = len(training_batch)

            running_mean_logs = {
                "rl_objective": 0.0,
                "policy_gradient_loss": 0.0,
                "policy_gradient_norm": 0.0,
                "log_probs": 0.0,
                "credits": 0.0,
                "entropy": 0.0,
                "engine_log_probs_diff_clampfrac": 0.0,
                "tis_imp_ratio": 0.0,
                "ref_log_probs_diff_clampfrac": 0.0,
                "higher_refprob_frac": 0.0,
                "tis_imp_ratio_clampfrac": 0.0,
            }
            if self.kl_coeff != 0.0:
                running_mean_logs["kl_divergence"] = 0.0

            total_tokens_generated = 0
            for att_mask in training_batch.batch_action_mask:
                total_tokens_generated += att_mask.sum()

            if self.pg_loss_normalization == "nb_tokens":
                normalization_factor = total_tokens_generated
            elif self.pg_loss_normalization == "batch":
                normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int)
            else:
                raise ValueError(
                    f"Invalid pg_loss_normalization: {self.pg_loss_normalization}"
                )
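
            # Sketch of the two modes with made-up numbers: for 10 rollouts,
            # mini-batches of 4, and 500 generated tokens in total,
            # "nb_tokens" divides the accumulated sum by 500 (one global
            # per-token mean), while "batch" first averages each mini-batch
            # over its own token count and then divides by
            # ceil(10 / 4) = 3 mini-batches.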

            for mb in range(0, nb_rollouts, mb_size):
                logger.info(f"Processing mini-batch {mb} of {nb_rollouts}")
                loss = 0.0
                training_mb = training_batch[mb : mb + mb_size]
                training_mb = training_mb.get_padded_tensors()
                training_mb.to(self.device)
                (
                    tokens_mb,
                    action_mask_mb,
                    entropy_mask_mb,
                    credits_mb,
                    engine_log_probs_mb,
                    timesteps_mb,
                ) = (
                    training_mb.batch_input_ids,
                    training_mb.batch_action_mask,
                    training_mb.batch_entropy_mask,
                    training_mb.batch_credits,
                    training_mb.batch_engine_log_probs,
                    training_mb.batch_timesteps,
                )

                # Shift by one position: the logits at position t score token t + 1.
                contexts_mb = tokens_mb[:, :-1]
                shifted_contexts_mb = tokens_mb[:, 1:]
                action_mask_mb = action_mask_mb[:, 1:]
                entropy_mask_mb = entropy_mask_mb[:, 1:]
                credits_mb = credits_mb[:, 1:]
                engine_log_probs_mb = engine_log_probs_mb[:, 1:]
                timesteps_mb = timesteps_mb[:, 1:]

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb)
                    self.tokenwise_tally.set_range(range=(mb, mb + mb_size))
                    self.tokenwise_tally.add_contexts(contexts=contexts_mb)
                    self.tokenwise_tally.add_data(
                        metric_id="next_token",
                        metrics=shifted_contexts_mb,
                        to_tids=True,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="entropy_mask",
                        metrics=entropy_mask_mb,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_credit", metrics=credits_mb
                    )

                logits = self.policy(input_ids=contexts_mb)[0]

                if self.restrict_tokens is not None:
                    logits = self.mask_non_restricted_token_logits(logits)

                logits /= self.temperature

                log_probs = F.log_softmax(logits, dim=-1)

                action_log_probs = log_probs.gather(
                    dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
                ).squeeze(-1)
                if self.pg_loss_normalization == "batch":
                    den_running_mean = action_mask_mb.sum() * normalization_factor
                else:
                    den_running_mean = normalization_factor
                running_mean_logs["log_probs"] += (
                    action_log_probs * action_mask_mb
                ).sum().item() / den_running_mean
                running_mean_logs["credits"] += (
                    credits_mb * action_mask_mb
                ).sum().item() / den_running_mean

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_log_prob",
                        metrics=action_log_probs,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="engine_next_token_log_prob",
                        metrics=engine_log_probs_mb,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_prob",
                        metrics=torch.exp(action_log_probs),
                    )
                    top_k_indices = torch.topk(logits, k=5, dim=-1).indices
                    self.tokenwise_tally.add_data(
                        metric_id="top_5_tids",
                        metrics=top_k_indices,
                        to_tids=True,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="top_5_probs",
                        metrics=torch.exp(log_probs).gather(
                            dim=-1, index=top_k_indices
                        ),
                    )

                rewarded_action_log_probs = (
                    action_mask_mb * credits_mb * action_log_probs
                )
                # Sentinel for masked positions; these entries never reach the
                # loss because every consumer re-applies the action mask.
                INVALID_LOGPROB = 1.0
                CLAMP_VALUE = 40.0
                masked_action_log_probs = torch.masked_fill(
                    action_log_probs, ~action_mask_mb, INVALID_LOGPROB
                )
                masked_engine_log_probs = torch.masked_fill(
                    engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB
                )
                with torch.no_grad():
                    action_engine_log_probs_diff = (
                        masked_action_log_probs - masked_engine_log_probs
                    ).clamp(-CLAMP_VALUE, CLAMP_VALUE)
                    running_mean_logs["engine_log_probs_diff_clampfrac"] += (
                        action_engine_log_probs_diff.abs()
                        .eq(CLAMP_VALUE)
                        .float()
                        .sum()
                        .item()
                        / den_running_mean
                    )
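
                    # Ratio construction, in brief: "per_token" uses
                    # exp(log pi_theta(a_t) - log pi_engine(a_t)) independently
                    # at each token, while "per_sequence" first sums the
                    # log-prob differences over all tokens of a timestep (one
                    # action), so every token of that action shares the
                    # sequence-level ratio exp(sum_t diff_t). E.g. two tokens
                    # with diffs +0.1 and -0.3 give a shared ratio exp(-0.2).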
                    if self.importance_sampling_strategy == "per_sequence":
                        tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff)
                        for mb_idx in range(action_engine_log_probs_diff.shape[0]):
                            valid_token_mask = action_mask_mb[mb_idx]
                            timestep_ids = timesteps_mb[mb_idx][valid_token_mask]
                            timestep_logprob_diffs = action_engine_log_probs_diff[
                                mb_idx
                            ][valid_token_mask]
                            max_timestep = int(timestep_ids.max().item()) + 1
                            timestep_sums = torch.zeros(
                                max_timestep,
                                device=action_engine_log_probs_diff.device,
                                dtype=action_engine_log_probs_diff.dtype,
                            )
                            timestep_sums.scatter_add_(
                                0, timestep_ids, timestep_logprob_diffs
                            )
                            timestep_ratios = torch.exp(timestep_sums)
                            tis_imp_ratio[
                                mb_idx, valid_token_mask
                            ] = timestep_ratios.gather(0, timestep_ids)
                    else:
                        tis_imp_ratio = torch.exp(action_engine_log_probs_diff)
                    running_mean_logs["tis_imp_ratio"] += (
                        tis_imp_ratio * action_mask_mb
                    ).sum().item() / den_running_mean
                    if self.truncated_importance_sampling_ratio_cap > 0.0:
                        tis_imp_ratio = torch.clamp(
                            tis_imp_ratio,
                            max=self.truncated_importance_sampling_ratio_cap,
                        )
                        running_mean_logs["tis_imp_ratio_clampfrac"] += (
                            tis_imp_ratio.eq(
                                self.truncated_importance_sampling_ratio_cap
                            )
                            .float()
                            .sum()
                            .item()
                        ) / den_running_mean
                rewarded_action_log_probs = (
                    rewarded_action_log_probs * tis_imp_ratio
                )
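
                # Note: capping the ratio above (truncated importance
                # sampling) bounds the variance of the off-policy correction
                # at the cost of a bias toward the engine policy. The ratios
                # are computed under no_grad, so they only rescale the
                # policy-gradient terms without contributing gradients.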

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_clogπ",
                        metrics=rewarded_action_log_probs,
                    )

                if self.pg_loss_normalization == "batch":
                    nb_act_tokens = action_mask_mb.sum()
                    mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens
                else:
                    mb_value = -rewarded_action_log_probs.sum()

                loss += mb_value
                running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean

                if self.entropy_topk is not None:
                    top_k_indices = torch.topk(
                        logits, k=self.entropy_topk, dim=-1
                    ).indices
                    entropy_logits = logits.gather(dim=-1, index=top_k_indices)
                else:
                    entropy_logits = logits
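
                # Entropy bonus, for reference: H_t = -sum_v p(v) log p(v),
                # computed from the (temperature-scaled) logits. When
                # entropy_topk is set, the sum runs over the renormalized
                # top-k token distribution only, a cheaper surrogate for the
                # full-vocabulary entropy.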

                token_entropy_terms = -F.softmax(
                    entropy_logits, dim=-1
                ) * F.log_softmax(entropy_logits, dim=-1)
                token_entropy_terms *= (
                    action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None]
                )

                mb_entropy = token_entropy_terms.sum(dim=-1)

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="entropy",
                        metrics=mb_entropy,
                    )
                if self.pg_loss_normalization == "batch":
                    nb_act_tokens = action_mask_mb.sum()
                    mb_entropy = -mb_entropy.sum() / nb_act_tokens
                else:
                    mb_entropy = -mb_entropy.sum()
                running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean
                if self.entropy_coeff != 0.0:
                    mb_entropy *= self.entropy_coeff
                    loss += mb_entropy

                if self.kl_coeff != 0.0:
                    ref_model_logits = self.policy.get_base_model_logits(contexts_mb)
                    ref_model_logits = ref_model_logits / self.temperature
                    ref_model_logits = self.mask_non_restricted_token_logits(
                        logits=ref_model_logits
                    )
                    ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1)
                    ref_model_action_log_probs = ref_model_log_probs.gather(
                        dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
                    ).squeeze(-1)

                    masked_ref_model_action_log_probs = torch.masked_fill(
                        ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB
                    )
                    action_log_probs_diff = (
                        masked_ref_model_action_log_probs - masked_action_log_probs
                    ).clamp(-CLAMP_VALUE, CLAMP_VALUE)
                    running_mean_logs["ref_log_probs_diff_clampfrac"] += (
                        action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item()
                        / den_running_mean
                    )
                    if self.filter_higher_refprob_tokens_kl:
                        higher_refprob_tokens_mask = action_log_probs_diff > 0.0
                        running_mean_logs["higher_refprob_frac"] += (
                            higher_refprob_tokens_mask.sum().item() / den_running_mean
                        )
                        action_log_probs_diff = action_log_probs_diff * (
                            ~higher_refprob_tokens_mask
                        )
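
                    # KL estimator below: with r = pi_ref(a) / pi_theta(a),
                    # expm1(log r) - log r = (r - 1) - log r, the non-negative
                    # low-variance "k3" estimator of KL(pi_theta || pi_ref)
                    # (cf. Schulman's "Approximating KL Divergence" note).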
                    kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff
                    kl_div *= action_mask_mb
                    if self.truncated_importance_sampling_ratio_cap > 0.0:
                        kl_div = kl_div * tis_imp_ratio
                    kl_div *= self.kl_coeff
                    if self.enable_tokenwise_logging:
                        self.tokenwise_tally.add_data(
                            metric_id="ref_model_next_token_log_prob",
                            metrics=ref_model_action_log_probs,
                        )
                        self.tokenwise_tally.add_data(
                            metric_id="kl_divergence",
                            metrics=kl_div,
                        )

                    if self.pg_loss_normalization == "batch":
                        nb_act_tokens = action_mask_mb.sum()
                        mb_kl = kl_div.sum() / nb_act_tokens
                    else:
                        mb_kl = kl_div.sum()
                    running_mean_logs["kl_divergence"] += (
                        mb_kl.item() / den_running_mean
                    )
                    loss += mb_kl

                running_mean_logs["policy_gradient_loss"] += (
                    loss.item() / den_running_mean
                )
                loss /= normalization_factor
                self.accelerator.backward(loss)

                del training_mb
                del log_probs
                del logits
                del loss
                del action_log_probs
                del rewarded_action_log_probs

            logger.info(
                f"Accumulated the policy gradient loss for {total_tokens_generated} tokens."
            )

            if self.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.policy.parameters(), self.gradient_clipping
                )
                running_mean_logs["policy_gradient_norm"] += grad_norm.item()

            self.policy_optimizer.step()
            self.policy_optimizer.zero_grad()

            for key, value in running_mean_logs.items():
                self.tally.add_metric(path=key, metric=value)

            self.accelerator.clear(self.policy, self.policy_optimizer)
            gc.collect()
            torch.cuda.empty_cache()
            return running_mean_logs

    def get_advantages_with_critic_gradient_accumulation(
        self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0
    ) -> list[torch.Tensor]:
        """
        Computes per-trajectory advantages for a batch of trajectories.
        Uses GAE if enabled, otherwise uses Monte Carlo (discounted) returns.
        Optionally accumulates critic gradients if GAE is used, and applies
        RLOO baselines and advantage whitening when configured.

        Returns:
            advantages: list of per-trajectory advantage tensors.
        """
        mb_size = self.mini_batch_size
        batch_size = trajectories.rollout_ids.shape[0]
        agent_id = trajectories.agent_ids[0]
        batch_rewards = trajectories.batch_rewards

        if self.use_gae:
            # The critic is not trained on data from "buffer" agents.
            if "buffer" in agent_id:
                self.critic.eval()
                training = False
            else:
                self.critic.train()
                training = True
            advantages = []
            normalization_factor = (
                np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor
            )
            for mb in range(0, batch_size, mb_size):
                trajectory_mb = trajectories[mb : mb + mb_size]
                trajectory_mb.to(self.device)
                rewards_mb = trajectory_mb.batch_rewards
                (
                    tokens_mb,
                    state_ends_mask_mb,
                    timestep_counts,
                ) = trajectory_mb.get_padded_tensors_for_critic()

                if training:
                    vals_estimate_full = self.critic(tokens_mb)
                else:
                    with torch.no_grad():
                        vals_estimate_full = self.critic(tokens_mb)

                # Keep one value estimate per state end.
                B = tokens_mb.shape[0]
                vals_list = [
                    vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B)
                ]

                vals_estimate_mb = pad_sequence(
                    vals_list, batch_first=True, padding_value=0.0
                )
                dtype = vals_estimate_mb.dtype
                rewards_mb = pad_sequence(
                    rewards_mb, batch_first=True, padding_value=0.0
                ).to(dtype=dtype)
                self.rollout_tally.add_metric(
                    path=["batch_rewards"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectory_mb.crn_ids,
                        rollout_ids=trajectory_mb.rollout_ids,
                        agent_ids=trajectory_mb.agent_ids,
                        metric_matrix=rewards_mb,
                    ),
                )
                if self.reward_normalizing_constant != 1.0:
                    rewards_mb /= self.reward_normalizing_constant

                det_vals_estimate_mb = vals_estimate_mb.detach()
                self.rollout_tally.add_metric(
                    path=["mb_value_estimates_critic"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectory_mb.crn_ids,
                        rollout_ids=trajectory_mb.rollout_ids,
                        agent_ids=trajectory_mb.agent_ids,
                        metric_matrix=det_vals_estimate_mb,
                    ),
                )

                if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]:
                    Bsize = det_vals_estimate_mb.shape[0]
                    device = det_vals_estimate_mb.device
                    dtype = det_vals_estimate_mb.dtype
                    det_vals_estimate_mb = torch.cat(
                        [
                            det_vals_estimate_mb,
                            torch.zeros((Bsize, 1), device=device, dtype=dtype),
                        ],
                        dim=1,
                    )
                else:
                    raise ValueError(
                        "Incompatible shapes for value estimates and rewards."
                    )
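
                # The appended zero column plays the role of the terminal
                # bootstrap value V(s_T) = 0, the standard choice for episodic
                # tasks where no reward follows the last state.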

                if self.use_gae_lambda_annealing:
                    annealing_constant = self.gae_lambda_annealing_method(
                        step=self.trainer_annealing_state.annealing_step_counter
                    )
                    annealed_lambda = (
                        self.gae_lambda_annealing_limit * annealing_constant
                    )
                    self.tally.add_metric(
                        path="annealed_lambda", metric=annealed_lambda
                    )
                else:
                    annealed_lambda = self.gae_lambda_annealing_limit

                gae_advantages = get_generalized_advantage_estimates(
                    rewards=rewards_mb,
                    value_estimates=det_vals_estimate_mb,
                    discount_factor=self.discount_factor,
                    lambda_coef=annealed_lambda,
                )
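
                # For reference, GAE computes
                #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
                #   A_t     = delta_t + gamma * lambda * A_{t+1}
                # so lambda = 0 recovers one-step TD errors and lambda = 1
                # recovers Monte Carlo returns minus the value baseline.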
                self.rollout_tally.add_metric(
                    path=["mb_gae_advantages"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectory_mb.crn_ids,
                        rollout_ids=trajectory_mb.rollout_ids,
                        agent_ids=trajectory_mb.agent_ids,
                        metric_matrix=gae_advantages,
                    ),
                )
                if training:
                    targets = (
                        gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1]
                    )
                    self.rollout_tally.add_metric(
                        path=["mb_targets_critic"],
                        rollout_tally_item=RolloutTallyItem(
                            crn_ids=trajectory_mb.crn_ids,
                            rollout_ids=trajectory_mb.rollout_ids,
                            agent_ids=trajectory_mb.agent_ids,
                            metric_matrix=targets,
                        ),
                    )
                    if self.critic_loss_type == "mse":
                        loss = F.mse_loss(
                            input=vals_estimate_mb,
                            target=targets,
                        )
                    elif self.critic_loss_type == "huber":
                        loss = F.huber_loss(
                            input=vals_estimate_mb,
                            target=targets,
                        )
                    else:
                        raise ValueError(
                            f"Invalid critic_loss_type: {self.critic_loss_type}"
                        )
                    self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item())
                    loss /= normalization_factor
                    self.accelerator.backward(loss)
                    del loss
                    del targets
                del vals_estimate_mb
                del trajectory_mb
                del vals_estimate_full

                advantages.extend(
                    [gae_advantages[i, : timestep_counts[i]] for i in range(B)]
                )

        else:
            lengths = [len(c) for c in batch_rewards]
            padded_rewards = pad_sequence(
                batch_rewards, batch_first=True, padding_value=0.0
            )
            self.rollout_tally.add_metric(
                path=["mb_rewards"],
                rollout_tally_item=RolloutTallyItem(
                    crn_ids=trajectories.crn_ids,
                    rollout_ids=trajectories.rollout_ids,
                    agent_ids=trajectories.agent_ids,
                    metric_matrix=padded_rewards,
                ),
            )
            if self.reward_normalizing_constant != 1.0:
                padded_rewards /= self.reward_normalizing_constant
            padded_advantages = get_discounted_returns(
                rewards=padded_rewards,
                discount_factor=self.discount_factor,
            )
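
            # RLOO in one line: with K rollouts and returns R_1..R_K, each
            # rollout is credited R_i - mean_{j != i} R_j, a leave-one-out
            # baseline that stays unbiased because the baseline never depends
            # on the rollout it is applied to.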
            if self.use_rloo:
                # Rollouts sharing a crn_id were generated with common random
                # numbers, so the leave-one-out baseline is taken within each
                # crn group when such grouping exists.
                is_grouped_by_rng = (
                    trajectories.crn_ids.unique().shape[0]
                    != trajectories.crn_ids.shape[0]
                )
                if is_grouped_by_rng:
                    for crn_id in trajectories.crn_ids.unique():
                        rng_mask = trajectories.crn_ids == crn_id
                        rng_advantages = padded_advantages[rng_mask]
                        rng_advantages, _ = get_rloo_credits(credits=rng_advantages)
                        padded_advantages[rng_mask] = rng_advantages
                else:
                    padded_advantages, _ = get_rloo_credits(credits=padded_advantages)
                self.rollout_tally.add_metric(
                    path=["mb_rloo_advantages"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectories.crn_ids,
                        rollout_ids=trajectories.rollout_ids,
                        agent_ids=trajectories.agent_ids,
                        metric_matrix=padded_advantages,
                    ),
                )
            advantages = [
                padded_advantages[i, : lengths[i]]
                for i in range(padded_advantages.shape[0])
            ]

        if self.whiten_advantages_time_step_wise or self.whiten_advantages:
            lengths = [len(c) for c in advantages]
            padded_advantages = pad_sequence(
                advantages, batch_first=True, padding_value=0.0
            )
            if self.whiten_advantages_time_step_wise:
                whitened_padded_advantages = whiten_advantages_time_step_wise(
                    padded_advantages
                )
                path = ["mb_whitened_advantages_time_step_wise"]
            elif self.whiten_advantages:
                whitened_padded_advantages = whiten_advantages(padded_advantages)
                path = ["mb_whitened_advantages"]
            self.rollout_tally.add_metric(
                path=path,
                rollout_tally_item=RolloutTallyItem(
                    crn_ids=trajectories.crn_ids,
                    rollout_ids=trajectories.rollout_ids,
                    agent_ids=trajectories.agent_ids,
                    metric_matrix=whitened_padded_advantages,
                ),
            )
            advantages = [
                whitened_padded_advantages[i, : lengths[i]]
                for i in range(whitened_padded_advantages.shape[0])
            ]

        self.trainer_annealing_state.annealing_step_counter += 1

        return advantages

    @abstractmethod
    def set_agent_trajectory_data(
        self, agent_id: str, roots: list[RolloutTreeRootNode]
    ) -> None:
        """
        Builds and stores trajectory data for a single agent from the rollout trees.
        """
        pass

    def set_trajectory_data(
        self, roots: list[RolloutTreeRootNode], agent_ids: list[str]
    ) -> None:
        """
        Builds and stores trajectory data for every agent in agent_ids.
        """
        for agent_id in agent_ids:
            self.set_agent_trajectory_data(agent_id, roots)

    @abstractmethod
    def share_advantage_data(self) -> list[AdvantagePacket]:
        pass

    @abstractmethod
    def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
        pass

    def set_policy_gradient_data(self, agent_ids: list[str]) -> None:
        """
        Assembles the policy-gradient training batch from the trajectory data
        collected earlier.  # TODO: make it separate and clean
        """
        self.policy_gradient_data = None
        for agent_id in agent_ids:
            assert "buffer" not in agent_id, "Buffer agents do not train policy"
            trajectory_batch = self.training_data[agent_id]
            # Expand the per-timestep credits to one credit per action token.
            tokenwise_batch_credits = get_tokenwise_credits(
                batch_timesteps=trajectory_batch.batch_timesteps,
                batch_credits=trajectory_batch.batch_credits,
            )
            policy_gradient_data = TrainingBatch(
                rollout_ids=trajectory_batch.rollout_ids,
                batch_input_ids=trajectory_batch.batch_input_ids,
                batch_action_mask=trajectory_batch.batch_action_mask,
                batch_entropy_mask=trajectory_batch.batch_entropy_mask,
                batch_credits=tokenwise_batch_credits,
                batch_engine_log_probs=trajectory_batch.batch_engine_log_probs,
                batch_timesteps=trajectory_batch.batch_timesteps,
            )
            if self.policy_gradient_data is None:
                self.policy_gradient_data = policy_gradient_data
            else:
                self.policy_gradient_data.append(policy_gradient_data)

        self.training_data = {}
        self.tokenwise_tally = ContextualizedTokenwiseTally(
            tokenizer=self.tokenizer,
            paths=self.debug_path_list,
        )
|
| | def train(self) -> None: |
| | """ |
| | TOWRITE |
| | """ |
| | assert self.policy_gradient_data is not None, "Policy gradient data is not set" |
| | if self.critic_optimizer is not None: |
| | if self.gradient_clipping is not None: |
| | grad_norm = self.accelerator.clip_grad_norm_( |
| | self.critic.parameters(), self.gradient_clipping |
| | ) |
| | self.tally.add_metric( |
| | path="gradient_norm_critic", metric=grad_norm.item() |
| | ) |
| | |
| | self.critic_optimizer.step() |
| | self.critic_optimizer.zero_grad() |
| | self.accelerator.clear(self.critic, self.critic_optimizer) |
| | import gc |
| |
|
| | gc.collect() |
| | torch.cuda.empty_cache() |
| | running_mean_logs = self.apply_reinforce_step( |
| | training_batch=self.policy_gradient_data |
| | ) |
| | return running_mean_logs |
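
    # Typical driving loop, as a hedged sketch (the subclass and variable
    # names are hypothetical):
    #
    #   trainer = MyTrainer(policy=policy, ...)
    #   trainer.set_trajectory_data(roots=rollout_roots, agent_ids=["alice"])
    #   trainer.set_policy_gradient_data(agent_ids=["alice"])
    #   logs = trainer.train()
    #   trainer.export_training_tally(identifier="step_0", folder="metrics")
    #   trainer.export_trainer_states()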

    def export_training_tally(self, identifier: str, folder: str) -> None:
        """
        Saves and resets the collected training metrics held by the tally objects.
        """
        os.makedirs(folder, exist_ok=True)
        self.tally.save(identifier=identifier, folder=folder)
        self.tokenwise_tally.save(
            path=os.path.join(folder, f"{identifier}_tokenwise.csv")
        )
        self.rollout_tally.save(identifier=identifier, folder=folder)
        self.tally.reset()
        self.tokenwise_tally = None
        self.rollout_tally.reset()
        self.debug_path_list = []

    def export_optimizer_states(self) -> None:
        """
        Saves the optimizer states for both the policy and the critic (if it exists).
        """
        try:
            os.makedirs(self.save_path, exist_ok=True)

            torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path)
            logger.info(f"Saved policy optimizer state to {self.policy_optimizer_path}")

            if self.critic_optimizer is not None:
                torch.save(
                    self.critic_optimizer.state_dict(), self.critic_optimizer_path
                )
                logger.info(
                    f"Saved critic optimizer state to {self.critic_optimizer_path}"
                )
        except Exception as e:
            logger.error(f"Error saving optimizer states: {e}")
            raise

    def export_trainer_annealing_state(self) -> None:
        """
        Saves the trainer annealing state.
        """
        with open(self.trainer_annealing_state_path, "wb") as f:
            pickle.dump(self.trainer_annealing_state, f)
        logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}")

    def export_trainer_states(self) -> None:
        """
        Saves all trainer states (optimizer states and annealing state).
        """
        self.export_optimizer_states()
        self.export_trainer_annealing_state()