"""
File: mllm/training/trainer_independent.py
Summary: Trainer for independently optimizing each agent.
"""
import logging
import os
import sys
from typing import Union
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from pandas._libs.tslibs.offsets import CBMonthBegin
from peft import LoraConfig
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer
from mllm.markov_games.rollout_tree import *
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
from mllm.training.credit_methods import (
get_discounted_returns,
get_discounted_state_visitation_credits,
get_generalized_advantage_estimates,
get_rloo_credits,
)
from mllm.training.tally_metrics import Tally
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
from mllm.training.tokenize_chats import *
from mllm.training.tokenize_chats import process_training_chat
from mllm.training.trainer_common import BaseTrainer
from mllm.training.training_data_utils import *
from mllm.training.training_data_utils import (
TrainingBatch,
TrajectoryBatch,
get_tokenwise_credits,
)
from mllm.utils.resource_context import resource_logger_context
# Module-level logger; every record it handles is also written to stdout
# through a dedicated stream handler attached at import time.
logger = logging.getLogger(__name__)
_stdout_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(_stdout_handler)
@dataclass
class TrainingData:
    """
    Per-agent cache pairing a tokenized trajectory batch with its advantages.

    Attributes:
        agent_id: Identifier of the agent the cached data belongs to.
        main_data: Trajectory batch built from that agent's rollouts.
        main_advantages: One advantage tensor per rollout, or None when no
            advantages have been attached.
    """

    agent_id: str
    main_data: TrajectoryBatch
    # List of tensors: per-rollout advantages, each of length jT.
    main_advantages: list[torch.FloatTensor] | None = None
class TrainerNaive(BaseTrainer):
    """
    Trainer that optimizes each agent independently.

    Advantages are computed from an agent's own rollouts only; packets
    received from other trainers are ignored.
    """

    def set_agent_trajectory_data(
        self, agent_id: str, roots: list[RolloutTreeRootNode]
    ) -> None:
        """
        Tokenize an agent's rollouts, compute advantages, and cache both.

        Args:
            agent_id: Agent whose main chat and rewards are read from each root.
            roots: Rollout tree roots; each one contributes one trajectory.

        The outcome is stored in `self.training_data[agent_id]` as a
        `TrainingData` holding the trajectory batch and its advantages.
        """
        # Reset per-agent buffers; extend this logic if joint training batches are needed.
        self.policy_gradient_data = None

        # Tensorize chats: one entry per rollout in every accumulator below.
        ids_of_rollouts = []
        ids_of_crns = []  # common random number id
        all_input_ids = []
        all_action_masks = []
        all_entropy_masks = []
        all_timesteps = []
        all_state_ends_masks = []
        all_engine_log_probs = []
        all_rewards = []

        for root in roots:
            rollout_id = root.id
            debug_tag = "mgid:" + str(rollout_id) + "_agent_id:" + agent_id
            self.debug_path_list.append(debug_tag)
            ids_of_rollouts.append(rollout_id)
            ids_of_crns.append(root.crn_id)

            chat, rewards = get_main_chat_list_and_rewards(agent_id=agent_id, root=root)
            tokenized_chat = process_training_chat(
                tokenizer=self.tokenizer,
                chat_history=chat,
                entropy_mask_regex=self.entropy_mask_regex,
                exploration_prompts_to_remove=self.exploration_prompts_to_remove,
            )
            # Exactly six values are expected back, in this order.
            (
                input_ids,
                action_mask,
                entropy_mask,
                timesteps,
                state_ends_mask,
                engine_log_probs,
            ) = tokenized_chat

            all_input_ids.append(input_ids)
            all_action_masks.append(action_mask)
            all_entropy_masks.append(entropy_mask)
            all_timesteps.append(timesteps)
            all_state_ends_masks.append(state_ends_mask)
            all_engine_log_probs.append(engine_log_probs)
            all_rewards.append(rewards)

        trajectory_batch = TrajectoryBatch(
            rollout_ids=torch.tensor(ids_of_rollouts, dtype=torch.int32),
            crn_ids=torch.tensor(ids_of_crns, dtype=torch.int32),
            agent_ids=[agent_id] * len(ids_of_rollouts),
            batch_input_ids=all_input_ids,
            batch_action_mask=all_action_masks,
            batch_entropy_mask=all_entropy_masks,
            batch_timesteps=all_timesteps,
            batch_state_ends_mask=all_state_ends_masks,
            batch_rewards=all_rewards,
            batch_engine_log_probs=all_engine_log_probs,
        )

        # Get advantages
        batch_advantages: torch.FloatTensor = (
            self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
        )

        # Discount state visitation (the mathematically correct way)
        if not self.skip_discounted_state_visitation:
            for idx, rollout_advantages in enumerate(batch_advantages):
                # The helper is called on a view with a leading dimension of
                # size 1, which is removed again before storing the result.
                discounted = get_discounted_state_visitation_credits(
                    rollout_advantages.unsqueeze(0),
                    self.discount_factor,
                )
                batch_advantages[idx] = discounted.squeeze(0)

        self.training_data[agent_id] = TrainingData(
            agent_id=agent_id,
            main_data=trajectory_batch,
            main_advantages=batch_advantages,
        )

    def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
        """
        This trainer ignores the advantages of the other trainers.

        `advantage_packets` is not read. Each cached `TrainingData` entry is
        replaced by its own trajectory batch, whose `batch_credits` is set to
        the entry's `main_advantages`.
        """
        for agent_id, cached in self.training_data.items():
            own_batch = cached.main_data
            self.training_data[agent_id] = own_batch
            own_batch.batch_credits = cached.main_advantages

    def share_advantage_data(self) -> list[AdvantagePacket]:
        """
        Share the advantage data with other agents.

        Reads `main_data` and `main_advantages` from every cached entry, so the
        entries must still be `TrainingData` objects when this is called
        (`receive_advantage_data` replaces them with trajectory batches).

        Returns:
            list[AdvantagePacket]: One packet per cached agent, carrying the
            agent id, its rollout ids, and its advantages.
        """
        logger.info("Sharing advantage data.")
        return [
            AdvantagePacket(
                agent_id=agent_id,
                rollout_ids=cached.main_data.rollout_ids,
                main_advantages=cached.main_advantages,
            )
            for agent_id, cached in self.training_data.items()
        ]
|