Muqeeth's picture
Add files using upload-large-folder tool
9ba32f5 verified
"""
File: mllm/training/trainer_independent.py
Summary: Trainer for independently optimizing each agent.
"""
import logging
import os
import sys
from typing import Union
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from pandas._libs.tslibs.offsets import CBMonthBegin
from peft import LoraConfig
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer
from mllm.markov_games.rollout_tree import *
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
from mllm.training.credit_methods import (
get_discounted_returns,
get_discounted_state_visitation_credits,
get_generalized_advantage_estimates,
get_rloo_credits,
)
from mllm.training.tally_metrics import Tally
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
from mllm.training.tokenize_chats import *
from mllm.training.tokenize_chats import process_training_chat
from mllm.training.trainer_common import BaseTrainer
from mllm.training.training_data_utils import *
from mllm.training.training_data_utils import (
TrainingBatch,
TrajectoryBatch,
get_tokenwise_credits,
)
from mllm.utils.resource_context import resource_logger_context
# Module logger; messages are additionally echoed to stdout via a dedicated handler.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.StreamHandler(stream=sys.stdout))
@dataclass
class TrainingData:
    """Per-agent cache of tokenized trajectories and the advantages computed for them."""

    # Identifier of the agent whose trajectories are cached.
    agent_id: str
    # Tokenized rollouts of this agent.
    main_data: TrajectoryBatch
    # List of tensors, one advantage tensor per rollout ("length jT" in the
    # original note -- NOTE(review): confirm what jT denotes). Defaults to None.
    main_advantages: Union[list[torch.FloatTensor], None] = None
class TrainerNaive(BaseTrainer):
    """Trainer that optimizes each agent independently.

    Advantages are computed from each agent's own rollouts only; advantage
    packets received from other trainers are ignored.
    """

    def set_agent_trajectory_data(
        self, agent_id: str, roots: list[RolloutTreeRootNode]
    ) -> None:
        """
        Tokenize rollouts for a given agent and cache the tensors used for training.

        For each root, the agent's main chat and rewards are extracted and
        tokenized with ``process_training_chat``. The per-rollout results are
        packed into a ``TrajectoryBatch``, advantages are computed for it
        (optionally followed by discounted state-visitation weighting), and
        both are stored as a ``TrainingData`` entry in
        ``self.training_data[agent_id]``.

        Args:
            agent_id: Identifier of the agent whose trajectories are cached.
            roots: Rollout tree roots to extract the agent's chats from.
        """
        # Reset per-agent buffers; extend this logic if joint training batches are needed.
        self.policy_gradient_data = None

        # Tensorize chats: each list below holds one entry per rollout.
        rollout_ids = []
        crn_ids = []  # common random number id
        batch_input_ids = []
        batch_action_mask = []
        batch_entropy_mask = []
        batch_timesteps = []
        batch_state_ends_mask = []
        batch_engine_log_probs = []
        batch_rewards = []
        for root in roots:
            rollout_id = root.id
            self.debug_path_list.append(f"mgid:{rollout_id!s}_agent_id:{agent_id}")
            rollout_ids.append(rollout_id)
            crn_ids.append(root.crn_id)
            chat, rewards = get_main_chat_list_and_rewards(agent_id=agent_id, root=root)
            (
                input_ids,
                action_mask,
                entropy_mask,
                timesteps,
                state_ends_mask,
                engine_log_probs,
            ) = process_training_chat(
                tokenizer=self.tokenizer,
                chat_history=chat,
                entropy_mask_regex=self.entropy_mask_regex,
                exploration_prompts_to_remove=self.exploration_prompts_to_remove,
            )
            batch_input_ids.append(input_ids)
            batch_action_mask.append(action_mask)
            batch_entropy_mask.append(entropy_mask)
            batch_timesteps.append(timesteps)
            batch_state_ends_mask.append(state_ends_mask)
            batch_engine_log_probs.append(engine_log_probs)
            batch_rewards.append(rewards)

        trajectory_batch = TrajectoryBatch(
            rollout_ids=torch.tensor(rollout_ids, dtype=torch.int32),
            crn_ids=torch.tensor(crn_ids, dtype=torch.int32),
            agent_ids=[agent_id] * len(rollout_ids),
            batch_input_ids=batch_input_ids,
            batch_action_mask=batch_action_mask,
            batch_entropy_mask=batch_entropy_mask,
            batch_timesteps=batch_timesteps,
            batch_state_ends_mask=batch_state_ends_mask,
            batch_rewards=batch_rewards,
            batch_engine_log_probs=batch_engine_log_probs,
        )

        # Get advantages: indexed per rollout and stored as
        # TrainingData.main_advantages, hence the list-of-tensors annotation.
        batch_advantages: list[torch.FloatTensor] = (
            self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
        )

        # Discount state visitation (the mathematically correct way).
        # Each rollout is given a leading dim of 1 for the helper (presumably it
        # expects batched input -- confirm) and squeezed back afterwards.
        # Elements are replaced in place so the returned container is reused.
        if not self.skip_discounted_state_visitation:
            for idx, rollout_advantages in enumerate(batch_advantages):
                batch_advantages[idx] = get_discounted_state_visitation_credits(
                    rollout_advantages.unsqueeze(0),
                    self.discount_factor,
                ).squeeze(0)

        self.training_data[agent_id] = TrainingData(
            agent_id=agent_id,
            main_data=trajectory_batch,
            main_advantages=batch_advantages,
        )

    def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
        """
        This trainer ignores the advantages of the other trainers.

        Each cached ``TrainingData`` entry in ``self.training_data`` is replaced
        by its ``TrajectoryBatch``, whose ``batch_credits`` is set to the agent's
        own advantages. After this call the entries are no longer
        ``TrainingData`` objects, so ``share_advantage_data`` must run first.

        Args:
            advantage_packets: Packets shared by other trainers (unused).
        """
        # Only existing keys are reassigned, so mutating during iteration is safe.
        for agent_id, agent_data in self.training_data.items():
            trajectory_batch = agent_data.main_data
            trajectory_batch.batch_credits = agent_data.main_advantages
            self.training_data[agent_id] = trajectory_batch

    def share_advantage_data(self) -> list[AdvantagePacket]:
        """
        Share the advantage data with other agents.

        Returns:
            list[AdvantagePacket]: One packet per cached agent, carrying the
            agent id, its rollout ids and its advantages.
        """
        logger.info("Sharing advantage data.")
        return [
            AdvantagePacket(
                agent_id=agent_id,
                rollout_ids=agent_data.main_data.rollout_ids,
                main_advantages=agent_data.main_advantages,
            )
            for agent_id, agent_data in self.training_data.items()
        ]