| """ |
| File: mllm/training/trainer_independent.py |
| Summary: Trainer for independently optimizing each agent. |
| """ |
|
|
| import logging |
| import os |
| import sys |
| from typing import Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from accelerate import Accelerator |
| from pandas._libs.tslibs.offsets import CBMonthBegin |
| from peft import LoraConfig |
| from torch.nn.utils.rnn import pad_sequence |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from mllm.markov_games.rollout_tree import * |
| from mllm.markov_games.rollout_tree import RolloutTreeRootNode |
| from mllm.training.credit_methods import ( |
| get_discounted_returns, |
| get_discounted_state_visitation_credits, |
| get_generalized_advantage_estimates, |
| get_rloo_credits, |
| ) |
| from mllm.training.tally_metrics import Tally |
| from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally |
| from mllm.training.tokenize_chats import * |
| from mllm.training.tokenize_chats import process_training_chat |
| from mllm.training.trainer_common import BaseTrainer |
| from mllm.training.training_data_utils import * |
| from mllm.training.training_data_utils import ( |
| TrainingBatch, |
| TrajectoryBatch, |
| get_tokenwise_credits, |
| ) |
| from mllm.utils.resource_context import resource_logger_context |
|
|
logger = logging.getLogger(__name__)

# Attach the stdout handler at most once. getLogger() returns the same Logger
# object every time this module body runs (e.g. under importlib.reload), so an
# unconditional addHandler() would stack handlers and duplicate every record.
_STDOUT_HANDLER_NAME = f"{__name__}.stdout"
if not any(
    handler.get_name() == _STDOUT_HANDLER_NAME for handler in logger.handlers
):
    _stdout_handler = logging.StreamHandler(sys.stdout)
    _stdout_handler.set_name(_STDOUT_HANDLER_NAME)
    logger.addHandler(_stdout_handler)
    del _stdout_handler
|
|
|
|
@dataclass
class TrainingData:
    """
    Per-agent record pairing tokenized trajectories with their advantages.

    One instance is kept per agent. ``main_advantages`` defaults to ``None``
    when it is not supplied at construction time.
    """

    # Identifier of the agent the trajectories below belong to.
    agent_id: str
    # Tokenized trajectory tensors for this agent.
    main_data: TrajectoryBatch
    # Advantages for the trajectories in ``main_data``; optional, defaults to None.
    main_advantages: list[torch.FloatTensor] | None = None
|
|
|
|
class TrainerNaive(BaseTrainer):
    """
    Trainer that optimizes each agent independently.

    Each agent is trained on its own advantages only; advantage packets
    coming from other trainers are accepted but ignored.
    """

    def set_agent_trajectory_data(
        self, agent_id: str, roots: list[RolloutTreeRootNode]
    ) -> None:
        """
        Tokenize rollouts for a given agent and cache the tensors used for training.

        For every root, ``get_main_chat_list_and_rewards`` yields the chat and
        rewards for ``agent_id``. The chat is tokenized with
        ``process_training_chat``. The per-rollout outputs are bundled into a
        ``TrajectoryBatch``. Advantages are then computed with
        ``self.get_advantages_with_critic_gradient_accumulation``. Unless
        ``self.skip_discounted_state_visitation`` is set, each rollout's
        advantages are replaced by ``get_discounted_state_visitation_credits``
        applied with ``self.discount_factor``.

        Args:
            agent_id: Agent whose view of each rollout is extracted.
            roots: Root nodes of the rollout trees to train on.

        Side effects:
            Resets ``self.policy_gradient_data`` to ``None``, appends one entry
            per root to ``self.debug_path_list``, and overwrites
            ``self.training_data[agent_id]`` with a ``TrainingData``.
        """
        self.policy_gradient_data = None

        rollout_ids = []
        crn_ids = []
        batch_input_ids = []
        batch_action_mask = []
        batch_entropy_mask = []
        batch_timesteps = []
        batch_state_ends_mask = []
        batch_engine_log_probs = []
        batch_rewards = []
        for root in roots:
            rollout_id = root.id
            # ``!s`` keeps the exact str(rollout_id) rendering of the original.
            self.debug_path_list.append(f"mgid:{rollout_id!s}_agent_id:{agent_id}")
            rollout_ids.append(rollout_id)
            crn_ids.append(root.crn_id)
            chat, rewards = get_main_chat_list_and_rewards(agent_id=agent_id, root=root)
            (
                input_ids,
                action_mask,
                entropy_mask,
                timesteps,
                state_ends_mask,
                engine_log_probs,
            ) = process_training_chat(
                tokenizer=self.tokenizer,
                chat_history=chat,
                entropy_mask_regex=self.entropy_mask_regex,
                exploration_prompts_to_remove=self.exploration_prompts_to_remove,
            )
            batch_input_ids.append(input_ids)
            batch_action_mask.append(action_mask)
            batch_entropy_mask.append(entropy_mask)
            batch_timesteps.append(timesteps)
            batch_state_ends_mask.append(state_ends_mask)
            batch_engine_log_probs.append(engine_log_probs)
            batch_rewards.append(rewards)

        trajectory_batch = TrajectoryBatch(
            rollout_ids=torch.tensor(rollout_ids, dtype=torch.int32),
            crn_ids=torch.tensor(crn_ids, dtype=torch.int32),
            agent_ids=[agent_id] * len(rollout_ids),
            batch_input_ids=batch_input_ids,
            batch_action_mask=batch_action_mask,
            batch_entropy_mask=batch_entropy_mask,
            batch_timesteps=batch_timesteps,
            batch_state_ends_mask=batch_state_ends_mask,
            batch_rewards=batch_rewards,
            batch_engine_log_probs=batch_engine_log_probs,
        )

        # Indexed per rollout below and stored as TrainingData.main_advantages,
        # which is typed list[torch.FloatTensor] | None.
        # NOTE(review): confirm the critic helper returns that type.
        batch_advantages = self.get_advantages_with_critic_gradient_accumulation(
            trajectory_batch
        )

        if not self.skip_discounted_state_visitation:
            # Each rollout's advantages get a leading dimension via unsqueeze(0)
            # before the credit helper (presumably batch-first) and lose it
            # again via squeeze(0). Index assignment keeps this working whether
            # the container is a list or a tensor.
            for i in range(len(batch_advantages)):
                batch_advantages[i] = get_discounted_state_visitation_credits(
                    batch_advantages[i].unsqueeze(0),
                    self.discount_factor,
                ).squeeze(0)

        self.training_data[agent_id] = TrainingData(
            agent_id=agent_id,
            main_data=trajectory_batch,
            main_advantages=batch_advantages,
        )

    def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
        """
        Attach each agent's own advantages as its training credits.

        This trainer ignores the advantages of the other trainers:
        ``advantage_packets`` is accepted but never read.

        Each ``self.training_data[agent_id]`` entry is replaced by its
        ``TrajectoryBatch`` (the former ``main_data``). That batch gets
        ``batch_credits`` set to the agent's ``main_advantages``.

        NOTE(review): after this call the entries are no longer
        ``TrainingData`` objects. Calling this method or
        ``share_advantage_data`` again before ``set_agent_trajectory_data``
        will presumably fail; confirm the intended call order.
        """
        # Iterate over a snapshot because the dict's values are rebound below.
        for agent_id, agent_data in list(self.training_data.items()):
            trajectory_batch = agent_data.main_data
            self.training_data[agent_id] = trajectory_batch
            trajectory_batch.batch_credits = agent_data.main_advantages

    def share_advantage_data(self) -> list[AdvantagePacket]:
        """
        Share the advantage data with other agents.

        Returns:
            list[AdvantagePacket]: One packet per entry in
            ``self.training_data``. Each packet carries the agent id, the
            rollout ids of its ``main_data``, and its ``main_advantages``.
        """
        logger.info("Sharing advantage data.")
        return [
            AdvantagePacket(
                agent_id=agent_id,
                rollout_ids=agent_data.main_data.rollout_ids,
                main_advantages=agent_data.main_advantages,
            )
            for agent_id, agent_data in self.training_data.items()
        ]
|
|