"""
TODO: Add coefficients for losses (depend on total number of tokens or batch)
TODO: adapt reinforce step for torch.compile
TODO: add lr schedulers support
"""
import logging
import os
import pickle
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Literal, Union
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from peft import LoraConfig
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer
from mllm.markov_games.rollout_tree import *
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
from mllm.training.annealing_methods import sigmoid_annealing
from mllm.training.credit_methods import (
get_discounted_returns,
get_generalized_advantage_estimates,
get_rloo_credits,
whiten_advantages,
whiten_advantages_time_step_wise,
)
from mllm.training.tally_metrics import Tally
from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
from mllm.training.tokenize_chats import *
from mllm.training.tokenize_chats import process_training_chat
from mllm.training.training_data_utils import *
from mllm.training.training_data_utils import (
TrainingBatch,
TrajectoryBatch,
get_tokenwise_credits,
)
from mllm.utils.resource_context import resource_logger_context
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(sys.stdout))
@dataclass
class TrainerAnnealingState:
annealing_step_counter: int = 0
class BaseTrainer(ABC):
"""
Abstract base trainer for REINFORCE-style policy-gradient updates, with
optional critic/GAE advantage estimation, RLOO baselines, KL and entropy
regularization, and truncated importance sampling against the rollout engine.
"""
def __init__(
self,
policy: AutoModelForCausalLM,
policy_optimizer: torch.optim.Optimizer,
critic: Union[AutoModelForCausalLM, None],
critic_optimizer: Union[torch.optim.Optimizer, None],
tokenizer: AutoTokenizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None],
######################################################################
entropy_coeff: float,
entropy_topk: int,
entropy_mask_regex: Union[str, None],
kl_coeff: float,
gradient_clipping: Union[float, None],
restrict_tokens: Union[list[str], None],
mini_batch_size: int,
use_gradient_checkpointing: bool,
temperature: float,
device: str,
whiten_advantages: bool,
whiten_advantages_time_step_wise: bool,
use_gae: bool,
use_gae_lambda_annealing: bool,
gae_lambda_annealing_limit: float,
gae_lambda_annealing_method: Literal["sigmoid_annealing"],
gae_lambda_annealing_method_params: dict,
pg_loss_normalization: Literal["batch", "nb_tokens"],
use_rloo: bool,
skip_discounted_state_visitation: bool,
discount_factor: float,
enable_tokenwise_logging: bool,
save_path: str,
reward_normalizing_constant: float = 1.0,
critic_loss_type: Literal["mse", "huber"] = "huber",
exploration_prompts_to_remove: Union[list[str], None] = None,
filter_higher_refprob_tokens_kl: bool = False,
truncated_importance_sampling_ratio_cap: float = 0.0,
importance_sampling_strategy: Literal[
"per_token", "per_sequence"
] = "per_token",
):
"""
Initialize the REINFORCE trainer with reward shaping for multi-agent or single-agent training.
Args:
model (AutoModelForCausalLM): The main policy model.
tokenizer (AutoTokenizer): Tokenizer for the model.
optimizer (torch.optim.Optimizer): Optimizer for the policy model.
lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Learning rate scheduler for the policy model.
critic (AutoModelForCausalLM or None): Critic model for value estimation (optional).
critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic model (optional).
critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic (optional).
config (RtConfig): Configuration object for training.
"""
self.tokenizer = tokenizer
# self.tokenizer.padding_side = "left" # needed for flash attention
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.lr_scheduler = lr_scheduler
self.accelerator = Accelerator()
(
self.policy,
self.policy_optimizer,
self.critic,
self.critic_optimizer,
) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer)
self.critic_lr_scheduler = critic_lr_scheduler
if use_gradient_checkpointing:
self.policy.gradient_checkpointing_enable(dict(use_reentrant=False))
if critic is not None:
self.critic.gradient_checkpointing_enable(dict(use_reentrant=False))
self.save_path = save_path
# Load trainer state if it exists
self.trainer_annealing_state_path = os.path.join(
self.save_path, "trainer_annealing_state.pkl"
)
if os.path.exists(self.trainer_annealing_state_path):
logger.info(
f"Loading trainer state from {self.trainer_annealing_state_path}"
)
with open(self.trainer_annealing_state_path, "rb") as f:
self.trainer_annealing_state = pickle.load(f)
else:
self.trainer_annealing_state = TrainerAnnealingState()
# Load policy optimizer state if it exists
self.policy_optimizer_path = os.path.join(
self.save_path, "policy_optimizer_state.pt"
)
if os.path.exists(self.policy_optimizer_path):
logger.info(
f"Loading policy optimizer state from {self.policy_optimizer_path}"
)
self.policy_optimizer.load_state_dict(
torch.load(self.policy_optimizer_path)
)
# Load critic optimizer state if it exists
self.critic_optimizer_path = os.path.join(
self.save_path, "critic_optimizer_state.pt"
)
if (
os.path.exists(self.critic_optimizer_path)
and self.critic_optimizer is not None
):
logger.info(
f"Loading critic optimizer state from {self.critic_optimizer_path}"
)
self.critic_optimizer.load_state_dict(
torch.load(self.critic_optimizer_path)
)
self.device = self.accelerator.device
self.entropy_coeff = entropy_coeff
self.entropy_topk = entropy_topk
self.entropy_mask_regex = entropy_mask_regex
self.kl_coeff = kl_coeff
self.gradient_clipping = gradient_clipping
self.restrict_tokens = restrict_tokens
self.mini_batch_size = mini_batch_size
self.use_gradient_checkpointing = use_gradient_checkpointing
self.temperature = temperature
self.use_gae = use_gae
self.whiten_advantages = whiten_advantages
self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise
self.use_rloo = use_rloo
self.skip_discounted_state_visitation = skip_discounted_state_visitation
self.use_gae_lambda_annealing = use_gae_lambda_annealing
self.gae_lambda_annealing_limit = gae_lambda_annealing_limit
if use_gae_lambda_annealing:
# Resolve the annealing method by name (only sigmoid_annealing is supported).
annealing_methods: dict[str, Callable] = {
"sigmoid_annealing": sigmoid_annealing,
}
annealing_fn = annealing_methods[gae_lambda_annealing_method]
self.gae_lambda_annealing_method: Callable[[int], float] = lambda step: annealing_fn(
step=step, **gae_lambda_annealing_method_params
)
self.discount_factor = discount_factor
self.enable_tokenwise_logging = enable_tokenwise_logging
self.reward_normalizing_constant = reward_normalizing_constant
self.pg_loss_normalization = pg_loss_normalization
self.critic_loss_type = critic_loss_type
self.exploration_prompts_to_remove = exploration_prompts_to_remove or []
# Common containers used by all trainers
self.training_data: dict = {}
self.debug_path_list: list[str] = []
self.policy_gradient_data = None
self.tally = Tally()
self.rollout_tally = RolloutTally()
self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None
self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl
self.truncated_importance_sampling_ratio_cap = (
truncated_importance_sampling_ratio_cap
)
self.importance_sampling_strategy = importance_sampling_strategy
def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""
Masks logits so that only allowed tokens (as specified in config.restrict_tokens)
and the EOS token are active.
All other logits are set to -inf, effectively removing them from the softmax.
Args:
logits (torch.Tensor): The logits tensor of shape (B, S, V).
Returns:
torch.Tensor: The masked logits tensor.
"""
# TODO: verify. Not sure what we do here is differentiable
# also, we recompute for nothing
if self.restrict_tokens is not None:
allowed_token_ids = []
for token in self.restrict_tokens:
token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"]
allowed_token_ids.append(token_ids[0])
allowed_token_ids.append(
self.tokenizer.eos_token_id
) # This token should always be active
allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device)
# Mask log_probs and probs to only allowed tokens
mask = torch.zeros_like(logits).bool() # (B, S, V)
mask[..., allowed_token_ids] = True
logits = torch.where(
mask,
logits,
torch.tensor(-float("inf"), device=logits.device),
)
return logits
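# Illustrative sketch (not executed here; toy values are assumptions): what the
# restriction does on a small vocabulary. Assuming restrict_tokens resolves to
# token ids {2, 7} and eos_token_id == 0, only those three columns survive:
#
#   toy_logits = torch.randn(1, 4, 10)                               # (B, S, V)
#   allowed = torch.tensor([2, 7, 0])
#   mask = torch.zeros_like(toy_logits, dtype=torch.bool)
#   mask[..., allowed] = True
#   masked = torch.where(mask, toy_logits, torch.tensor(float("-inf")))
#   probs = F.softmax(masked, dim=-1)  # zero probability outside {0, 2, 7}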
# def get_gradient_magnitude(self, loss_term: torch.Tensor) -> float:
# """
# Computes the L2 norm of the gradients of the given loss term with respect to the model parameters.
# Args:
# loss_term (torch.Tensor): The loss tensor to compute gradients for.
# Returns:
# float: The L2 norm of the gradients, or 0.0 if no gradients are present.
# """
# with torch.no_grad():
# grads = torch.autograd.grad(
# loss_term,
# [p for p in self.policy.parameters() if p.requires_grad],
# retain_graph=True,
# allow_unused=True,
# )
# grads = [g for g in grads if g is not None]
# if not grads:
# return torch.tensor(0.0, device=loss_term.device)
# return torch.norm(torch.stack([g.norm(2) for g in grads])).item()
def apply_reinforce_step(
self,
training_batch: TrainingBatch,
) -> dict:
"""
Applies a single REINFORCE policy gradient step using the provided batch of rollouts.
Handles mini-batching, loss computation (including entropy and KL regularization,
and truncated importance sampling), gradient accumulation, and the optimizer step.
Optionally logs various metrics and statistics.
Args:
training_batch (TrainingBatch): Batched rollout data (input ids, action and
entropy masks, credits, engine log-probs, and timesteps).
Returns:
dict: Running-mean training logs for this step.
"""
with resource_logger_context(logger, "Apply reinforce step"):
self.policy.train()
mb_size = self.mini_batch_size
nb_rollouts = len(training_batch)
# Initialize running mean logs
running_mean_logs = {
"rl_objective": 0.0,
"policy_gradient_loss": 0.0,
"policy_gradient_norm": 0.0,
"log_probs": 0.0,
"credits": 0.0,
"entropy": 0.0,
"engine_log_probs_diff_clampfrac": 0.0,
"tis_imp_ratio": 0.0,
"ref_log_probs_diff_clampfrac": 0.0,
"higher_refprob_frac": 0.0,
"tis_imp_ratio_clampfrac": 0.0,
}
if self.kl_coeff != 0.0:
running_mean_logs["kl_divergence"] = 0.0
# Get total number of tokens generated
total_tokens_generated = 0
for att_mask in training_batch.batch_action_mask:
total_tokens_generated += att_mask.sum()
# Obtain loss normalization
if self.pg_loss_normalization == "nb_tokens":
normalization_factor = total_tokens_generated
elif self.pg_loss_normalization == "batch":
normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int)
else:
raise ValueError(
f"Invalid pg_loss_normalization: {self.pg_loss_normalization}"
)
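# Worked example (illustrative arithmetic only): with nb_rollouts = 10 and
# mb_size = 4, "batch" normalization gives ceil(10 / 4) = 3 accumulation
# chunks, so each chunk's token-mean loss is divided by 3 before backward();
# with "nb_tokens" the divisor is instead the total number of generated tokens.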
# Gradient accumulation for each mini-batch
for mb in range(0, nb_rollouts, mb_size):
logger.info(f"Processing mini-batch {mb} of {nb_rollouts}")
loss = 0.0
training_mb = training_batch[mb : mb + mb_size]
training_mb = training_mb.get_padded_tensors()
training_mb.to(self.device)
(
tokens_mb,
action_mask_mb,
entropy_mask_mb,
credits_mb,
engine_log_probs_mb,
timesteps_mb,
) = (
training_mb.batch_input_ids,
training_mb.batch_action_mask,
training_mb.batch_entropy_mask,
training_mb.batch_credits,
training_mb.batch_engine_log_probs,
training_mb.batch_timesteps,
)
# Next token prediction
contexts_mb = tokens_mb[:, :-1]
shifted_contexts_mb = tokens_mb[:, 1:]
action_mask_mb = action_mask_mb[:, 1:]
entropy_mask_mb = entropy_mask_mb[:, 1:]
credits_mb = credits_mb[:, 1:]
engine_log_probs_mb = engine_log_probs_mb[:, 1:]
timesteps_mb = timesteps_mb[:, 1:]
if self.enable_tokenwise_logging:
self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb)
self.tokenwise_tally.set_range(range=(mb, mb + mb_size))
self.tokenwise_tally.add_contexts(contexts=contexts_mb)
self.tokenwise_tally.add_data(
metric_id="next_token",
metrics=shifted_contexts_mb,
to_tids=True,
)
self.tokenwise_tally.add_data(
metric_id="entropy_mask",
metrics=entropy_mask_mb,
)
if self.enable_tokenwise_logging:
self.tokenwise_tally.add_data(
metric_id="next_token_credit", metrics=credits_mb
)
# Forward pass through the policy (logits are used directly; no FP32 cast here)
# TODO: create attention mask if not relying on default (assume causal llm)
logits = self.policy(input_ids=contexts_mb)[0] # (B, S, V)
# Mask non-restricted tokens
if self.restrict_tokens is not None:
logits = self.mask_non_restricted_token_logits(logits)
logits /= self.temperature # (B, S, V)
# Compute new log probabilities
log_probs = F.log_softmax(logits, dim=-1) # (B, S, V)
# Get log probabilities of actions taken during rollouts
action_log_probs = log_probs.gather(
dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
).squeeze(
-1
) # (B, S)
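# Illustrative sketch (not executed; toy tensors are assumptions): how gather
# picks the realized next token's log-probability at each position:
#
#   toy_lp = F.log_softmax(torch.randn(1, 3, 5), dim=-1)             # (B, S, V)
#   toy_next = torch.tensor([[2, 0, 4]])                             # (B, S)
#   picked = toy_lp.gather(-1, toy_next.unsqueeze(-1)).squeeze(-1)   # (B, S)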
if self.pg_loss_normalization == "batch":
den_running_mean = action_mask_mb.sum() * normalization_factor
else:
den_running_mean = normalization_factor
running_mean_logs["log_probs"] += (
action_log_probs * action_mask_mb
).sum().item() / den_running_mean
running_mean_logs["credits"] += (
credits_mb * action_mask_mb
).sum().item() / den_running_mean
if self.enable_tokenwise_logging:
self.tokenwise_tally.add_data(
metric_id="next_token_log_prob",
metrics=action_log_probs,
)
self.tokenwise_tally.add_data(
metric_id="engine_next_token_log_prob",
metrics=engine_log_probs_mb,
)
self.tokenwise_tally.add_data(
metric_id="next_token_prob",
metrics=torch.exp(action_log_probs),
)
top_k_indices = torch.topk(logits, k=5, dim=-1).indices
self.tokenwise_tally.add_data(
metric_id=f"top_{5}_tids",
metrics=top_k_indices,
to_tids=True,
)
self.tokenwise_tally.add_data(
metric_id=f"top_{5}_probs",
metrics=torch.exp(log_probs).gather(
dim=-1, index=top_k_indices
),
)
rewarded_action_log_probs = (
action_mask_mb * credits_mb * action_log_probs
)
# (B, S)
INVALID_LOGPROB = 1.0
CLAMP_VALUE = 40.0
masked_action_log_probs = torch.masked_fill(
action_log_probs, ~action_mask_mb, INVALID_LOGPROB
)
masked_engine_log_probs = torch.masked_fill(
engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB
)
with torch.no_grad():
action_engine_log_probs_diff = (
masked_action_log_probs - masked_engine_log_probs
).clamp(-CLAMP_VALUE, CLAMP_VALUE)
running_mean_logs["engine_log_probs_diff_clampfrac"] += (
action_engine_log_probs_diff.abs()
.eq(CLAMP_VALUE)
.float()
.sum()
.item()
/ den_running_mean
)
if self.importance_sampling_strategy == "per_sequence":
tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff)
for mb_idx in range(action_engine_log_probs_diff.shape[0]):
valid_token_mask = action_mask_mb[mb_idx]
timestep_ids = timesteps_mb[mb_idx][valid_token_mask]
timestep_logprob_diffs = action_engine_log_probs_diff[mb_idx][
valid_token_mask
]
max_timestep = int(timestep_ids.max().item()) + 1
timestep_sums = torch.zeros(
max_timestep,
device=action_engine_log_probs_diff.device,
dtype=action_engine_log_probs_diff.dtype,
)
timestep_sums.scatter_add_(
0, timestep_ids, timestep_logprob_diffs
)
timestep_ratios = torch.exp(timestep_sums)
tis_imp_ratio[
mb_idx, valid_token_mask
] = timestep_ratios.gather(0, timestep_ids)
else:
tis_imp_ratio = torch.exp(action_engine_log_probs_diff)
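# Worked example (illustrative numbers only) contrasting the two strategies:
# if one timestep emitted two tokens whose trainer-vs-engine log-prob gaps are
# +0.2 and -0.1, "per_token" uses exp(0.2) ≈ 1.22 and exp(-0.1) ≈ 0.90 as
# separate ratios, while "per_sequence" sums the gaps first and assigns
# exp(0.1) ≈ 1.11 to every token of that timestep.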
running_mean_logs["tis_imp_ratio"] += (
tis_imp_ratio * action_mask_mb
).sum().item() / den_running_mean
if self.truncated_importance_sampling_ratio_cap > 0.0:
tis_imp_ratio = torch.clamp(
tis_imp_ratio, max=self.truncated_importance_sampling_ratio_cap
)
running_mean_logs["tis_imp_ratio_clampfrac"] += (
tis_imp_ratio.eq(self.truncated_importance_sampling_ratio_cap)
.float()
.sum()
.item()
) / den_running_mean
rewarded_action_log_probs = (
rewarded_action_log_probs * tis_imp_ratio
)
if self.enable_tokenwise_logging:
self.tokenwise_tally.add_data(
metric_id="next_token_clogπ",
metrics=rewarded_action_log_probs,
)
# Add value term to loss
if self.pg_loss_normalization == "batch":
nb_act_tokens = action_mask_mb.sum()
mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens
else:
mb_value = -rewarded_action_log_probs.sum()
loss += mb_value
running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean
# -------------------------------------------------
# Entropy Regularization
# -------------------------------------------------
# Only apply entropy on distribution defined over most probable tokens
if self.entropy_topk is not None:
top_k_indices = torch.topk(
logits, k=self.entropy_topk, dim=-1
).indices
entropy_logits = logits.gather(dim=-1, index=top_k_indices)
else:
entropy_logits = logits
token_entropy_terms = -F.softmax(
entropy_logits, dim=-1
) * F.log_softmax(
entropy_logits, dim=-1
) # (B, S, K) with K = entropy_topk if set, else V
token_entropy_terms *= (
action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None]
) # only get loss on specific action tokens
mb_entropy = token_entropy_terms.sum(dim=-1)
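# For reference: the quantity accumulated above is the per-position entropy
# H = -sum_v p(v) * log p(v), taken over the full vocabulary, or over the
# renormalized top-k slice of the logits when entropy_topk is set.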
if self.enable_tokenwise_logging:
self.tokenwise_tally.add_data(
metric_id="entropy",
metrics=mb_entropy,
)
if self.pg_loss_normalization == "batch":
nb_act_tokens = action_mask_mb.sum()
mb_entropy = -mb_entropy.sum() / nb_act_tokens
else:
mb_entropy = -mb_entropy.sum()
running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean
if self.entropy_coeff != 0.0:
mb_entropy *= self.entropy_coeff
loss += mb_entropy
# -------------------------------------------------
# KL-DIVERGENCE
# -------------------------------------------------
if self.kl_coeff != 0.0:
ref_model_logits = self.policy.get_base_model_logits(contexts_mb)
ref_model_logits = ref_model_logits / self.temperature
# (B, S, V)
ref_model_logits = self.mask_non_restricted_token_logits(
logits=ref_model_logits
)
# (B, S, V)
ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1)
# (B, S, V)
ref_model_action_log_probs = ref_model_log_probs.gather(
dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
).squeeze(
-1
) # (B,S)
# Approximating the KL divergence with the k3 estimator (see refs below)
# Ref 1: http://joschu.net/blog/kl-approx.html
# Ref 2: https://github.dev/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1332
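# For reference, with d = log π_ref(a|s) - log π(a|s) as computed below, the
# estimator used is k3 = exp(d) - 1 - d = expm1(d) - d, a non-negative,
# low-variance per-token estimate of KL(π || π_ref).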
masked_ref_model_action_log_probs = torch.masked_fill(
ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB
)
action_log_probs_diff = (
masked_ref_model_action_log_probs - masked_action_log_probs
).clamp(-CLAMP_VALUE, CLAMP_VALUE)
running_mean_logs["ref_log_probs_diff_clampfrac"] += (
action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item()
/ den_running_mean
)
if self.filter_higher_refprob_tokens_kl:
higher_refprob_tokens_mask = action_log_probs_diff > 0.0
running_mean_logs["higher_refprob_frac"] += (
higher_refprob_tokens_mask.sum().item() / den_running_mean
)
action_log_probs_diff = action_log_probs_diff * (
~higher_refprob_tokens_mask
)
kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff
kl_div *= action_mask_mb # We only care about KLD of action tokens
if self.truncated_importance_sampling_ratio_cap > 0.0:
kl_div = kl_div * tis_imp_ratio
kl_div *= self.kl_coeff
if self.enable_tokenwise_logging:
self.tokenwise_tally.add_data(
metric_id="ref_model_next_token_log_prob",
metrics=ref_model_action_log_probs,
)
self.tokenwise_tally.add_data(
metric_id="kl_divergence",
metrics=kl_div,
)
if self.pg_loss_normalization == "batch":
nb_act_tokens = action_mask_mb.sum()
mb_kl = kl_div.sum() / nb_act_tokens
else:
mb_kl = kl_div.sum()
running_mean_logs["kl_divergence"] += (
mb_kl.item() / den_running_mean
)
loss += mb_kl
# Accumulate gradient
running_mean_logs["policy_gradient_loss"] += (
loss.item() / den_running_mean
)
loss /= normalization_factor
self.accelerator.backward(loss)
# ensure gpu memory is freed
del training_mb
del log_probs
del logits
del loss
del action_log_probs
del rewarded_action_log_probs
logger.info(
f"Accumulated the policy gradient loss for {total_tokens_generated} tokens."
)
# Clip gradients and take step
if self.gradient_clipping is not None:
grad_norm = self.accelerator.clip_grad_norm_(
self.policy.parameters(), self.gradient_clipping
)
running_mean_logs["policy_gradient_norm"] += grad_norm.item()
# Take step
self.policy_optimizer.step()
self.policy_optimizer.zero_grad()
# Store logs
for key, value in running_mean_logs.items():
self.tally.add_metric(path=key, metric=value)
# Clear
# TODO: verify
self.accelerator.clear(self.policy, self.policy_optimizer)
import gc
gc.collect()
torch.cuda.empty_cache()
return running_mean_logs
def get_advantages_with_critic_gradient_accumulation(
self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0
) -> list[torch.Tensor]:
"""
Computes per-timestep advantages for a batch of trajectories.
Uses GAE (and accumulates critic gradients) if enabled, otherwise uses
discounted Monte Carlo returns, optionally with an RLOO baseline.
Advantages may additionally be whitened batch-wise or time-step-wise.
Returns:
advantages: list of per-trajectory advantage tensors (jagged lengths).
"""
mb_size = self.mini_batch_size
batch_size = trajectories.rollout_ids.shape[0]
agent_id = trajectories.agent_ids[0]
batch_rewards = trajectories.batch_rewards
######################################
# use critic for advantage estimation
######################################
if self.use_gae:
if "buffer" in agent_id:
self.critic.eval()
training = False
else:
self.critic.train()
training = True
advantages = []
# critic_loss_scaling_factor accounts for a single critic being trained for two agents
normalization_factor = (
np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor
)
# For each minibatch
for mb in range(0, batch_size, mb_size):
trajectory_mb = trajectories[mb : mb + mb_size]
trajectory_mb.to(self.device)
rewards_mb = trajectory_mb.batch_rewards
(
tokens_mb,
state_ends_mask_mb,
timestep_counts,
) = trajectory_mb.get_padded_tensors_for_critic()
# critic causal attention up to end flags
if training:
vals_estimate_full = self.critic(tokens_mb)
else:
with torch.no_grad():
vals_estimate_full = self.critic(tokens_mb)
# if vals_estimate_full.dim() == 3:
# vals_estimate_full = vals_estimate_full.squeeze(-1)
# Select only positions where states end, per sample → list of (jT,)
B = tokens_mb.shape[0]
vals_list = [
vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B)
]
# Pad to (B, max_jT) = (B, S)
vals_estimate_mb = pad_sequence(
vals_list, batch_first=True, padding_value=0.0
)
dtype = vals_estimate_mb.dtype
rewards_mb = pad_sequence(
rewards_mb, batch_first=True, padding_value=0.0
).to(
dtype=dtype
) # (B, S)
self.rollout_tally.add_metric(
path=["batch_rewards"],
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectory_mb.crn_ids,
rollout_ids=trajectory_mb.rollout_ids,
agent_ids=trajectory_mb.agent_ids,
metric_matrix=rewards_mb,
),
)
if self.reward_normalizing_constant != 1.0:
rewards_mb /= self.reward_normalizing_constant
det_vals_estimate_mb = vals_estimate_mb.detach() # (B, max_jT)
self.rollout_tally.add_metric(
path=["mb_value_estimates_critic"],
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectory_mb.crn_ids,
rollout_ids=trajectory_mb.rollout_ids,
agent_ids=trajectory_mb.agent_ids,
metric_matrix=det_vals_estimate_mb,
),
)
# Append a 0 value to the end of the value estimates
if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]:
Bsize = det_vals_estimate_mb.shape[0]
device = det_vals_estimate_mb.device
dtype = det_vals_estimate_mb.dtype
det_vals_estimate_mb = torch.cat(
[
det_vals_estimate_mb,
torch.zeros((Bsize, 1), device=device, dtype=dtype),
],
dim=1,
) # (B, max_jT+1)
else:
raise ValueError(
"Incompatible shapes for value estimates and rewards."
)
# Get annealed lambda
if self.use_gae_lambda_annealing:
annealing_constant = self.gae_lambda_annealing_method(
step=self.trainer_annealing_state.annealing_step_counter
)
annealed_lambda = (
self.gae_lambda_annealing_limit * annealing_constant
)
self.tally.add_metric(
path="annealed_lambda", metric=annealed_lambda
)
else:
annealed_lambda = self.gae_lambda_annealing_limit
# Get GAE advantages
gae_advantages = get_generalized_advantage_estimates(
rewards=rewards_mb,
value_estimates=det_vals_estimate_mb,
discount_factor=self.discount_factor,
lambda_coef=annealed_lambda,
) # (B, max_jT)
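# GAE recursion implemented by the helper (standard formulation, for reference):
# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and
# A_t = delta_t + gamma * lambda * A_{t+1}, with A after the last timestep
# taken as 0 and V at the appended terminal position set to 0 above.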
self.rollout_tally.add_metric(
path=["mb_gae_advantages"],
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectory_mb.crn_ids,
rollout_ids=trajectory_mb.rollout_ids,
agent_ids=trajectory_mb.agent_ids,
metric_matrix=gae_advantages,
),
)
if training:
targets = (
gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1]
) # (B, max_jT) # A(s, a, b) + V(s) = Q(s, a, b)
self.rollout_tally.add_metric(
path=["mb_targets_critic"],
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectory_mb.crn_ids,
rollout_ids=trajectory_mb.rollout_ids,
agent_ids=trajectory_mb.agent_ids,
metric_matrix=targets,
),
)
if self.critic_loss_type == "mse":
loss = F.mse_loss(
input=vals_estimate_mb,
target=targets,
)
elif self.critic_loss_type == "huber":
loss = F.huber_loss(
input=vals_estimate_mb,
target=targets,
)
self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item())
# Accumulate gradient
loss /= normalization_factor
self.accelerator.backward(loss)
del loss
del targets
del vals_estimate_mb
del trajectory_mb
del vals_estimate_full
# Get jagged back using timestep_counts
advantages.extend(
[gae_advantages[i, : timestep_counts[i]] for i in range(B)]
)
######################################
# use exclusively Monte Carlo returns & rloo for advantage estimation
######################################
else:
lengths = [len(c) for c in batch_rewards]
padded_rewards = pad_sequence(
batch_rewards, batch_first=True, padding_value=0.0
)
self.rollout_tally.add_metric(
path=["mb_rewards"],
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectories.crn_ids,
rollout_ids=trajectories.rollout_ids,
agent_ids=trajectories.agent_ids,
metric_matrix=padded_rewards,
),
)
if self.reward_normalizing_constant != 1.0:
padded_rewards /= self.reward_normalizing_constant
padded_advantages = get_discounted_returns(
rewards=padded_rewards,
discount_factor=self.discount_factor,
) # no baseline for now
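# Discounted-return credit (standard definition, for reference):
# G_t = sum_{k>=0} gamma^k * r_{t+k}; e.g. with rewards [1, 0, 2] and
# gamma = 0.9: G_2 = 2, G_1 = 0 + 0.9 * 2 = 1.8, G_0 = 1 + 0.9 * 1.8 = 2.62.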
if self.use_rloo:
is_grouped_by_rng = (
trajectories.crn_ids.unique().shape[0]
!= trajectories.crn_ids.shape[0]
)
if is_grouped_by_rng:
for crn_id in trajectories.crn_ids.unique():
rng_mask = trajectories.crn_ids == crn_id
rng_advantages = padded_advantages[rng_mask]
rng_advantages, _ = get_rloo_credits(credits=rng_advantages)
padded_advantages[rng_mask] = rng_advantages
else:
padded_advantages, _ = get_rloo_credits(credits=padded_advantages)
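# RLOO baseline (standard leave-one-out form, for reference): with n rollouts
# in a group, A_i = R_i - mean_{j != i}(R_j), i.e. each rollout's return minus
# the average return of the other rollouts; grouping is per crn_id when the
# batch shares common random numbers, otherwise over the whole batch.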
self.rollout_tally.add_metric(
path=["mb_rloo_advantages"],
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectories.crn_ids,
rollout_ids=trajectories.rollout_ids,
agent_ids=trajectories.agent_ids,
metric_matrix=padded_advantages,
),
)
advantages = [
padded_advantages[i, : lengths[i]]
for i in range(padded_advantages.shape[0])
]
if self.whiten_advantages_time_step_wise or self.whiten_advantages:
lengths = [len(c) for c in advantages]
padded_advantages = pad_sequence(
advantages, batch_first=True, padding_value=0.0
)
if self.whiten_advantages_time_step_wise:
whitened_padded_advantages = whiten_advantages_time_step_wise(
padded_advantages
)
path = ["mb_whitened_advantages_time_step_wise"]
elif self.whiten_advantages:
whitened_padded_advantages = whiten_advantages(padded_advantages)
path = ["mb_whitened_advantages"]
self.rollout_tally.add_metric(
path=path,
rollout_tally_item=RolloutTallyItem(
crn_ids=trajectories.crn_ids,
rollout_ids=trajectories.rollout_ids,
agent_ids=trajectories.agent_ids,
metric_matrix=whitened_padded_advantages,
),
)
advantages = [
whitened_padded_advantages[i, : lengths[i]]
for i in range(whitened_padded_advantages.shape[0])
]
self.trainer_annealing_state.annealing_step_counter += 1
return advantages
@abstractmethod
def set_agent_trajectory_data(
self, agent_id: str, roots: list[RolloutTreeRootNode]
) -> None:
"""
TOWRITE
"""
pass
def set_trajectory_data(
self, roots: list[RolloutTreeRootNode], agent_ids: list[str]
) -> None:
"""
TOWRITE
"""
for agent_id in agent_ids:
self.set_agent_trajectory_data(agent_id, roots)
@abstractmethod
def share_advantage_data(self) -> list[AdvantagePacket]:
pass
@abstractmethod
def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
pass
def set_policy_gradient_data(self, agent_ids: list[str]) -> None:
"""
Already set earlier # TODO: make it separate and clean
"""
self.policy_gradient_data = None
# for agent_id, trajectory_batch in self.training_data.items():
# if "buffer" in agent_id:
# continue
for agent_id in agent_ids:
assert "buffer" not in agent_id, "Buffer agents do not train policy"
trajectory_batch = self.training_data[agent_id]
tokenwise_batch_credits = get_tokenwise_credits(
batch_timesteps=trajectory_batch.batch_timesteps,
batch_credits=trajectory_batch.batch_credits,
)
policy_gradient_data = TrainingBatch(
rollout_ids=trajectory_batch.rollout_ids,
batch_input_ids=trajectory_batch.batch_input_ids,
batch_action_mask=trajectory_batch.batch_action_mask,
batch_entropy_mask=trajectory_batch.batch_entropy_mask,
batch_credits=tokenwise_batch_credits,
batch_engine_log_probs=trajectory_batch.batch_engine_log_probs,
batch_timesteps=trajectory_batch.batch_timesteps,
)
if self.policy_gradient_data is None:
self.policy_gradient_data = policy_gradient_data
else:
self.policy_gradient_data.append(policy_gradient_data)
self.training_data = {}
self.tokenwise_tally = ContextualizedTokenwiseTally(
tokenizer=self.tokenizer,
paths=self.debug_path_list,
)
def train(self) -> dict:
"""
Takes the critic optimizer step (if a critic is present), then applies a
REINFORCE step on the accumulated policy-gradient data.
Returns:
dict: Running-mean training logs from the REINFORCE step.
"""
assert self.policy_gradient_data is not None, "Policy gradient data is not set"
if self.critic_optimizer is not None:
if self.gradient_clipping is not None:
grad_norm = self.accelerator.clip_grad_norm_(
self.critic.parameters(), self.gradient_clipping
)
self.tally.add_metric(
path="gradient_norm_critic", metric=grad_norm.item()
)
# Take step
self.critic_optimizer.step()
self.critic_optimizer.zero_grad()
self.accelerator.clear(self.critic, self.critic_optimizer)
import gc
gc.collect()
torch.cuda.empty_cache()
running_mean_logs = self.apply_reinforce_step(
training_batch=self.policy_gradient_data
)
return running_mean_logs
def export_training_tally(self, identifier: str, folder: str) -> None:
"""
Saves and resets the collected training metrics using the tally object.
"""
os.makedirs(folder, exist_ok=True)
self.tally.save(identifier=identifier, folder=folder)
self.tokenwise_tally.save(
path=os.path.join(folder, f"{identifier}_tokenwise.csv")
)
self.rollout_tally.save(identifier=identifier, folder=folder)
self.tally.reset()
self.tokenwise_tally = None
self.rollout_tally.reset()
self.debug_path_list = []
def export_optimizer_states(self) -> None:
"""
Saves the optimizer states for both the main model and critic (if it exists).
"""
try:
os.makedirs(self.save_path, exist_ok=True)
torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path)
logger.info(f"Saved main optimizer state to {self.policy_optimizer_path}")
if self.critic_optimizer is not None:
torch.save(
self.critic_optimizer.state_dict(), self.critic_optimizer_path
)
logger.info(
f"Saved critic optimizer state to {self.critic_optimizer_path}"
)
except Exception as e:
logger.error(f"Error saving optimizer states: {str(e)}")
raise
def export_trainer_annealing_state(self) -> None:
"""
Saves the trainer state.
"""
with open(self.trainer_annealing_state_path, "wb") as f:
pickle.dump(self.trainer_annealing_state, f)
logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}")
def export_trainer_states(self) -> None:
"""
Saves the trainer states.
"""
self.export_optimizer_states()
self.export_trainer_annealing_state()