diff --git a/seed_1/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json b/seed_1/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..a8500e6eb97e94ff82b30e4b2ea7b6eaba53fc50 --- /dev/null +++ b/seed_1/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json @@ -0,0 +1,46 @@ +{ + "alora_invocation_tokens": null, + "alpha_pattern": {}, + "arrow_config": null, + "auto_mapping": null, + "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "bias": "none", + "corda_config": null, + "ensure_weight_tying": false, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "peft_version": "0.18.1", + "qalora_group_size": 16, + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "down_proj", + "up_proj", + "gate_proj", + "o_proj", + "k_proj", + "q_proj", + "v_proj" + ], + "target_parameters": null, + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git a/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5ad10de57b76f8584a2f944f44527cbac30497 Binary files /dev/null and b/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/chat_utils/apply_template.py b/src_code_for_reproducibility/chat_utils/apply_template.py new file mode 100644 index 0000000000000000000000000000000000000000..6bbdc32dbb1df0407ff24ae90395dba0d162bf7d --- /dev/null +++ b/src_code_for_reproducibility/chat_utils/apply_template.py @@ -0,0 +1,89 @@ +""" +File: mllm/chat_utils/apply_template.py +Summary: Applies tokenizer-specific chat templates and stitches chat token IDs. +""" + +import torch + +from mllm.chat_utils.chat_turn import ChatTurn +from mllm.chat_utils.template_specific import ( + custom_gemma3_template, + custom_llama3_template, + custom_qwen2_template, + custom_qwen3_template, + gemma3_assistant_postfix, + qwen2_assistant_postfix, + qwen3_assistant_postfix, +) + + +def get_custom_chat_template(tokenizer) -> str: + """ + Get the chat template for the tokenizer. + """ + if "qwen2" in tokenizer.name_or_path.lower(): + return custom_qwen2_template + elif "llama" in tokenizer.name_or_path.lower(): + return custom_llama3_template + elif "qwen3" in tokenizer.name_or_path.lower(): + return custom_qwen3_template + elif "gemma" in tokenizer.name_or_path.lower(): + return custom_gemma3_template + else: + raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported") + + +def get_custom_assistant_postfix(tokenizer) -> torch.Tensor: + """ + Get the custom assistant postfix for the tokenizer. + """ + if "qwen2" in tokenizer.name_or_path.lower(): + return qwen2_assistant_postfix + elif "qwen3" in tokenizer.name_or_path.lower(): + return qwen3_assistant_postfix + elif "gemma" in tokenizer.name_or_path.lower(): + return gemma3_assistant_postfix + return torch.tensor([], dtype=torch.long) + + +def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None: + """ + Set the chat_template_token_ids for each chat turn. + We rely on tokenizer-side templates because engine-provided cached tokens are not exposed yet. + """ + custom_template = get_custom_chat_template(tokenizer) + custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer) + for i, chat in enumerate(chats): + if chat.chat_template_token_ids is None: + if chat.role == "user": + next_chat = chats[i + 1] if i + 1 < len(chats) else None + add_generation_prompt = True + if next_chat and next_chat.role == "user": + add_generation_prompt = False + encoded_chat = tokenizer.apply_chat_template( + [chat], + return_tensors="pt", + chat_template=custom_template, + add_generation_prompt=add_generation_prompt, + add_system_prompt=True if i == 0 else False, + enable_thinking=enable_thinking, + ).flatten() + previous_chat = chats[i - 1] if i > 0 else None + if previous_chat and previous_chat.role == "assistant": + encoded_chat = torch.cat([custom_assistant_postfix, encoded_chat]) + elif chat.role == "assistant": + encoded_chat = chat.out_token_ids + chat.chat_template_token_ids = encoded_chat + + +def chat_turns_to_token_ids( + chats: list[ChatTurn], tokenizer, enable_thinking +) -> list[int]: + """ + Tokenize the chat turns and set the chat_template_token_ids for each chat turn. + """ + tokenize_chats(chats=chats, tokenizer=tokenizer, enable_thinking=enable_thinking) + token_ids = [] + for chat in chats: + token_ids.append(chat.chat_template_token_ids) + return torch.cat(token_ids) diff --git a/src_code_for_reproducibility/chat_utils/chat_turn.py b/src_code_for_reproducibility/chat_utils/chat_turn.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc0d9422a6070c86b1da8abce17ad28816fb2eb --- /dev/null +++ b/src_code_for_reproducibility/chat_utils/chat_turn.py @@ -0,0 +1,32 @@ +""" +File: mllm/chat_utils/chat_turn.py +Summary: Defines the ChatTurn schema plus helpers for serialization and validation. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Literal, Optional, Tuple + +import jsonschema +import torch +from pydantic import BaseModel, ConfigDict, Field, model_validator + +AgentId = str + + +class ChatTurn(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) # needed for torch tensors + + role: str = Field(pattern="^(user|assistant)$") + agent_id: AgentId # ID of the agent with which the chat occured + content: str + reasoning_content: str | None = None + chat_template_token_ids: torch.LongTensor | None = None # Token ids of chat template format. For example, token ids of "{content}"" + out_token_ids: torch.LongTensor | None = ( + None # tokens generated from inference engine + ) + log_probs: torch.FloatTensor | None = None + is_state_end: bool = False # indicates whether this chat turn marks the end of a state in the trajectory diff --git a/src_code_for_reproducibility/chat_utils/template_specific.py b/src_code_for_reproducibility/chat_utils/template_specific.py new file mode 100644 index 0000000000000000000000000000000000000000..c22328455c55f0b0a02439efdacf6b09234d7981 --- /dev/null +++ b/src_code_for_reproducibility/chat_utils/template_specific.py @@ -0,0 +1,114 @@ +""" +File: mllm/chat_utils/template_specific.py +Summary: Stores chat template variants and assistant postfix tensors per tokenizer. +""" + +import huggingface_hub +import torch +from transformers import AutoTokenizer + +custom_llama3_template = """ +{%- if add_system_prompt %} + {{- '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|>' }} +{%- endif %} +{%- for message in messages %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }} +{%- endif %} +""" + +qwen2_assistant_postfix = ( + AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct") + .encode("\n", return_tensors="pt") + .flatten() +) +qwen3_assistant_postfix = ( + AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + .encode("\n", return_tensors="pt") + .flatten() +) +gemma3_assistant_postfix = ( + AutoTokenizer.from_pretrained("google/gemma-3-4b-it") + .encode("\n", return_tensors="pt") + .flatten() +) +custom_qwen2_template = """ +{%- if add_system_prompt %} + {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if reasoning_content %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} +""" + +custom_qwen3_template = """ +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} +""" + +custom_gemma3_template = """ +{%- if add_system_prompt %} +{{- bos_token -}} +{%- endif %} +{%- for message in messages -%} +{%- if message['role'] == 'assistant' -%} +{%- set role = 'model' -%} +{%- else -%} +{%- set role = message['role'] -%} +{%- endif -%} +{{ '' + role + '\n' + message['content'] | trim + '\n' }} +{%- endfor -%} +{%- if add_generation_prompt -%} +{{ 'model\n' }} +{%- endif -%} +""" diff --git a/src_code_for_reproducibility/markov_games/__init__.py b/src_code_for_reproducibility/markov_games/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7015344d8ac9b53f2660d4e837f709908213db --- /dev/null +++ b/src_code_for_reproducibility/markov_games/__init__.py @@ -0,0 +1,4 @@ +""" +File: mllm/markov_games/__init__.py +Summary: Makes Markov-game subpackages importable from the top-level namespace. +""" diff --git a/src_code_for_reproducibility/markov_games/agent.py b/src_code_for_reproducibility/markov_games/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..56406ae2695ce97ad7fa4fc436908904ee11be9f --- /dev/null +++ b/src_code_for_reproducibility/markov_games/agent.py @@ -0,0 +1,72 @@ +""" +File: mllm/markov_games/agent.py +Summary: Declares the base Agent interface connecting simulations to policy calls. +""" + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, Tuple + +from numpy.random import default_rng + +from mllm.markov_games.rollout_tree import AgentActLog + + +class Agent(ABC): + """Abstract policy wrapper that bridges simulations with arbitrary backends.""" + + @abstractmethod + def __init__( + self, + seed: int, + agent_id: str, + agent_name: str, + agent_policy: Callable[[list[dict]], str], + *args, + **kwargs, + ): + """ + Initialize the agent state and seed its RNG. + + Subclasses typically store extra handles (tokenizers, inference clients, etc.) + but they should always call ``super().__init__`` so sampling remains reproducible. + """ + self.seed = seed + self.agent_id = agent_id + self.agent_name = agent_name + self.policy = policy + self.rng = default_rng(self.seed) + raise NotImplementedError + + async def act(self, observation) -> Tuple[Any, AgentActLog]: + """ + Produce the next action (and associated chat log) given an environment observation. + + Implementations can iterate with rejection sampling, multi-call deliberation, etc. + Returns both the chosen action and an `AgentActLog` describing how it was produced. + """ + raise NotImplementedError + + def get_safe_copy(self): + """ + Return a deep copy whose future calls do not mutate the original agent. + + Needed for branch exploration/reruns with alternative actions. + """ + raise NotImplementedError + + def reset(self): + """Reset any internal state between rollouts.""" + raise NotImplementedError + + def render(self): + """Optional human-readable visualization of the agent (CLI/UI).""" + raise NotImplementedError + + def close(self): + """Release any external resources (network sockets, subprocesses, etc.).""" + raise NotImplementedError + + def get_agent_info(self): + """Return diagnostic metadata to embed inside rollout logs.""" + raise NotImplementedError diff --git a/src_code_for_reproducibility/markov_games/alternative_actions_runner.py b/src_code_for_reproducibility/markov_games/alternative_actions_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..d5165a2552019aefdf281c2bd41e50d204713921 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/alternative_actions_runner.py @@ -0,0 +1,146 @@ +""" +File: mllm/markov_games/alternative_actions_runner.py +Summary: Generates rollout branches by replaying trajectories with unilateral action changes. +""" + +import asyncio +import copy +import json +import os.path +from typing import Any, Tuple + +from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame +from mllm.markov_games.rollout_tree import ( + AgentActLog, + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + StepLog, +) + +AgentId = str + + +async def run_with_unilateral_alt_action( + markov_game: MarkovGame, + agent_id: AgentId, + time_step: int, + branch_node: RolloutTreeBranchNode, + max_depth: int, +): + """ + Roll out a counterfactual branch where ``agent_id`` deviates unilaterally. + + Starting from ``branch_node`` (which already contains the main trajectory), + we replay the simulation with the deviating agent's action while freezing + all other agents/actions, then continue for ``max_depth`` steps. + """ + + # Generate alternative action and take a step + await markov_game.set_action_of_agent(agent_id) + terminated: bool = markov_game.take_simulation_step() + step_log = markov_game.get_step_log() + first_alternative_node = RolloutTreeNode( + step_log=step_log, + time_step=time_step, + ) + + # Generate rest of trajectory up to max depth + time_step += 1 + counter = 1 + previous_node = first_alternative_node + while not terminated and counter <= max_depth: + terminated, step_log = await markov_game.step() + current_node = RolloutTreeNode(step_log=step_log, time_step=time_step) + previous_node.child = current_node + previous_node = current_node + counter += 1 + time_step += 1 + + if branch_node.branches == None: + branch_node.branches = {agent_id: [first_alternative_node]} + else: + agent_branches = branch_node.branches.get(agent_id, []) + agent_branches.append(first_alternative_node) + branch_node.branches[agent_id] = agent_branches + + +async def AlternativeActionsRunner( + markov_game: MarkovGame, + output_folder: str, + nb_alternative_actions: int, + max_depth: int, + branch_only_on_new_round: bool = False, +): + """ + Generate a rollout tree containing the main path plus unilateral deviation branches. + + For each timestep we: + 1. Cache agent actions without side effects. + 2. Advance the main trajectory. + 3. Spawn ``nb_alternative_actions`` asynchronous deviations per agent, + each replaying up to ``max_depth`` steps from the cached pre-action state. + The resulting branches feed advantage-alignment estimators. + """ + + tasks = [] + time_step = 0 + terminated = False + root = RolloutTreeRootNode(id=markov_game.get_id(), crn_id=markov_game.get_crn_id()) + previous_node = root + + while not terminated: + mg_before_action = markov_game.get_safe_copy() + + # Get safe copies for main branch + agent_action_safe_copies: dict[ + AgentId, AgentAndActionSafeCopy + ] = await markov_game.get_actions_of_agents_without_side_effects() + + markov_game.set_actions_of_agents_manually(agent_action_safe_copies) + terminated = markov_game.take_simulation_step() + main_node = RolloutTreeNode( + step_log=markov_game.get_step_log(), time_step=time_step + ) + branch_node = RolloutTreeBranchNode(main_child=main_node) + previous_node.child = branch_node + previous_node = main_node + + # Get alternative branches by generating new unilateral actions + for agent_id in markov_game.agent_ids: + for _ in range(nb_alternative_actions): + # Get safe copies for branches + branch_agent_action_safe_copies: dict[ + AgentId, AgentAndActionSafeCopy + ] = { + agent_id: AgentAndActionSafeCopy( + action=copy.deepcopy(agent_action_safe_copy.action), + action_info=copy.deepcopy(agent_action_safe_copy.action_info), + agent_after_action=agent_action_safe_copy.agent_after_action.get_safe_copy(), + ) + for agent_id, agent_action_safe_copy in agent_action_safe_copies.items() + } + mg_branch: MarkovGame = mg_before_action.get_safe_copy() + other_agent_id = [id for id in mg_branch.agent_ids if id != agent_id][0] + mg_branch.set_action_and_agent_after_action_manually( + agent_id=other_agent_id, + agent_action_safe_copy=branch_agent_action_safe_copies[ + other_agent_id + ], + ) + task = asyncio.create_task( + run_with_unilateral_alt_action( + markov_game=mg_branch, + time_step=time_step, + agent_id=agent_id, + branch_node=branch_node, + max_depth=max_depth, + ) + ) + tasks.append(task) + time_step += 1 + + # wait for all branches to complete + await asyncio.gather(*tasks) + + return root diff --git a/src_code_for_reproducibility/markov_games/group_timesteps.py b/src_code_for_reproducibility/markov_games/group_timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..48b5882a632ba858787befaac306195af959b376 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/group_timesteps.py @@ -0,0 +1,133 @@ +""" +File: mllm/markov_games/group_timesteps.py +Summary: Provides timestep-grouping utilities for rollout trees and training. +""" + +import copy +from typing import Callable + +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.rollout_tree import ( + AgentActLog, + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + StepLog, +) +from mllm.markov_games.simulation import SimulationStepLog + +AgentId = str + + +def group_time_steps( + rollout_tree: RolloutTreeRootNode, + accumulation_stop_condition: Callable[[StepLog], bool], +) -> RolloutTreeRootNode: + """ + During generation, we create rollout trees according to the real time steps. + However, during training, we might want to treat groups of time steps as a single time step. + As a concrete example, take Trust-and-Split. At each round, say we have X time steps of communication and then one time step for the split. + Then the communication actions will not get any reward, and the split action will get the reward. During REINFORCE training, with discounting, this + can cause training instability. We could instead treat every action in the round as being part of a single action, and give it the reward of the split action. + This method helps to do this sort of grouping. + It accumulates actions until the accumulation_stop_condition is met, and then creates a new node with the accumulated actions. + It then recursively calls itself on the child node. + Details: + - The reward for the group is the reward of the last time step in the group. + - The simulation log for the group is the simulation log of the last time step in the group. + - The state end for the group becomes the first state end in the group. + - The agent info for the group is the agent info of the last time step in the group. + """ + + def group_step_logs(step_logs: list[StepLog]) -> StepLog: + """ + Concatenate per-agent chat turns across steps; keep only the first is_state_end. + """ + last_sim_log = step_logs[-1].simulation_step_log + agent_ids = {aid for s in step_logs for aid in s.action_logs.keys()} + grouped_logs: dict[AgentId, AgentActLog] = {} + for aid in agent_ids: + turns = [] + for s in step_logs: + act = s.action_logs.get(aid) + if act and act.chat_turns: + turns.extend(copy.deepcopy(act.chat_turns)) + disable_is_state_end = False + # Only the first state_end should be True, the rest should be False + for t in turns: + if t.is_state_end: + if disable_is_state_end: + t.is_state_end = False + else: + disable_is_state_end = True + continue + grouped_logs[aid] = AgentActLog( + chat_turns=turns, info=step_logs[-1].action_logs[aid].info + ) + return StepLog(action_logs=grouped_logs, simulation_step_log=last_sim_log) + + def group_time_steps_rec( + current_node: RolloutTreeNode | RolloutTreeBranchNode, + group_time_step: int, + accumulation_step_logs: list[StepLog], + ) -> RolloutTreeNode | RolloutTreeBranchNode: + """ + Groups time steps. Recursion is used to handle branches. + """ + assert isinstance(current_node, RolloutTreeNode) or isinstance( + current_node, RolloutTreeBranchNode + ), "Current node must be a tree node or a branch node. Is of type: " + str( + type(current_node) + ) + first_group_node = None + current_group_node = None + while current_node is not None: + if isinstance(current_node, RolloutTreeBranchNode): + raise Exception( + "Grouping timesteps by round is not supported for branching trajectories yet." + ) + + # Accumulate + accumulation_step_logs.append(current_node.step_log) + if accumulation_stop_condition(current_node.step_log): + grouped_step_logs = group_step_logs(accumulation_step_logs) + accumulation_step_logs = [] + new_group_node = RolloutTreeNode( + step_log=grouped_step_logs, time_step=group_time_step, child=None + ) + if first_group_node == None: + first_group_node = new_group_node + group_time_step += 1 + if current_group_node is not None: + current_group_node.child = new_group_node + current_group_node = new_group_node + current_node = current_node.child + return first_group_node + + node = group_time_steps_rec( + current_node=rollout_tree.child, group_time_step=0, accumulation_step_logs=[] + ) + return RolloutTreeRootNode( + id=rollout_tree.id, + crn_id=rollout_tree.crn_id, + child=node, + agent_ids=rollout_tree.agent_ids, + ) + + +def stop_when_round_ends(step_log: StepLog) -> bool: + """ + Simplest stop condition. Will return True if step log is the last time step of a round. + This will throw an error if this information is not available in the simulation info. + """ + assert ( + "is_last_timestep_in_round" in step_log.simulation_step_log.info.keys() + ), "To group by round, is_last_timestep_in_round must be set in the info of your simulation step log at each time step." + return step_log.simulation_step_log.info["is_last_timestep_in_round"] + + +def group_by_round(rollout_tree: RolloutTreeRootNode) -> RolloutTreeRootNode: + """ + Groups time steps by round. + """ + return group_time_steps(rollout_tree, stop_when_round_ends) diff --git a/src_code_for_reproducibility/markov_games/linear_runner.py b/src_code_for_reproducibility/markov_games/linear_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e14a3eda72cf4e620db5ab8ed0d3f7d552e9fe --- /dev/null +++ b/src_code_for_reproducibility/markov_games/linear_runner.py @@ -0,0 +1,42 @@ +""" +File: mllm/markov_games/linear_runner.py +Summary: Simulates a single unbranched Markov-game rollout and records it. +""" + +import asyncio +import json +import os.path + +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.rollout_tree import RolloutTreeNode, RolloutTreeRootNode + + +async def LinearRunner( + markov_game: MarkovGame, output_folder: str +) -> RolloutTreeRootNode: + """ + Generate a single main-path rollout (no branching) for the provided Markov game. + + Parameters + ---------- + markov_game: + Initialized ``MarkovGame`` with agents + simulation ready to step. + output_folder: + Unused placeholder in the legacy API (kept for compatibility). + """ + time_step = 0 + terminated = False + root = RolloutTreeRootNode( + id=markov_game.get_id(), + crn_id=markov_game.get_crn_id(), + agent_ids=markov_game.get_agent_ids(), + ) + previous_node = root + while not terminated: + terminated, step_log = await markov_game.step() + current_node = RolloutTreeNode(step_log=step_log, time_step=time_step) + previous_node.child = current_node + previous_node = current_node + time_step += 1 + + return root diff --git a/src_code_for_reproducibility/markov_games/markov_game.py b/src_code_for_reproducibility/markov_games/markov_game.py new file mode 100644 index 0000000000000000000000000000000000000000..7964fd69d24f617c76e36f852491b1e6141f6c48 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/markov_game.py @@ -0,0 +1,217 @@ +""" +File: mllm/markov_games/markov_game.py +Summary: Defines the MarkovGame base class plus shared simulation interfaces. +""" + +import asyncio +import copy +import json +import os +from dataclasses import dataclass +from typing import Any, List, Literal, Optional, Tuple + +from transformers.models.idefics2 import Idefics2Config + +from mllm.markov_games.agent import Agent +from mllm.markov_games.rollout_tree import AgentActLog, StepLog +from mllm.markov_games.simulation import Simulation + +AgentId = str + + +@dataclass +class AgentAndActionSafeCopy: + """Snapshot of an agent, its action, and metadata used for branch replay.""" + + action: Any + action_info: AgentActLog + agent_after_action: type[Agent] + + +class MarkovGame(object): + def __init__( + self, + id: int, + agents: dict[AgentId, type[Agent]], + simulation: type[Simulation], + crn_id: int, + ): + """ + Initialize the Markov game wrapper. + + Parameters + ---------- + id: + Unique rollout identifier (logged into rollout trees). + agents: + Mapping of agent_id -> Agent instance. + simulation: + Environment implementing the ``Simulation`` interface (IPD, TAS, etc.). + crn_id: + Identifier for the common random number stream used by this rollout. + """ + self.agents = agents + self.agent_ids = self.agents.keys() + self.simulation = simulation + self.simulation_step_log = None + self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids} + self.actions = {} + self.id = id + self.crn_id = crn_id + + def get_id(self) -> str: + return self.id + + def get_crn_id(self) -> int: + return self.crn_id + + def get_agent_ids(self) -> List[AgentId]: + return list(self.agent_ids) + + async def get_action_of_agent_without_side_effects( + self, agent_id: AgentId + ) -> Tuple[Any, AgentActLog]: + """ + Safe function to get an action of an agent without modifying the agent or the simulation. + """ + agent = self.agents[agent_id] + agent_before_action = agent.get_safe_copy() + obs = self.simulation.get_obs_agent(agent_id) + action, action_info = await agent.act(observation=obs) + self.agents[agent_id] = agent_before_action + agent_after_action = agent.get_safe_copy() + return AgentAndActionSafeCopy(action, action_info, agent_after_action) + + async def get_actions_of_agents_without_side_effects( + self, + ) -> dict[AgentId, AgentAndActionSafeCopy]: + """ + Safe function to get an action of an agent without modifying the agent or the simulation. + """ + tasks = [] + for agent_id in self.agent_ids: + task = asyncio.create_task( + self.get_action_of_agent_without_side_effects(agent_id) + ) + tasks.append(task) + agent_and_action_safe_copies: list[ + AgentAndActionSafeCopy + ] = await asyncio.gather(*tasks) + return { + agent_id: agent_and_action_safe_copy + for agent_id, agent_and_action_safe_copy in zip( + self.agent_ids, agent_and_action_safe_copies + ) + } + + def set_action_and_agent_after_action_manually( + self, + agent_id: AgentId, + agent_action_safe_copy: AgentAndActionSafeCopy, + ): + """ + Set the action and the agent after action manually. + """ + self.actions[agent_id] = agent_action_safe_copy.action + self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info + self.agents[agent_id] = agent_action_safe_copy.agent_after_action + + def set_actions_of_agents_manually( + self, actions: dict[AgentId, AgentAndActionSafeCopy] + ): + """ + Set the actions of agents manually. + """ + for agent_id, agent_action_safe_copy in actions.items(): + self.set_action_and_agent_after_action_manually( + agent_id, agent_action_safe_copy + ) + + async def set_action_of_agent(self, agent_id: AgentId): + """ + Query a single agent for its next action and store the result locally. + """ + agent = self.agents[agent_id] + obs = self.simulation.get_obs_agent(agent_id) + action, action_info = await agent.act(observation=obs) + self.actions[agent_id] = action + self.agent_step_logs[agent_id] = action_info + + async def set_actions(self): + """ + Query every agent concurrently and populate the cached actions/logs. + """ + # background_tasks = set() + tasks = [] + for agent_id in self.agent_ids: + task = asyncio.create_task(self.set_action_of_agent(agent_id)) + tasks.append(task) + await asyncio.gather(*tasks) + + def take_simulation_step(self): + """ + Advance the simulation by one step using the cached actions. + """ + terminated, self.simulation_step_log = self.simulation.step(self.actions) + return terminated + + def get_step_log(self) -> StepLog: + """ + Package the most recent simulation step and agent logs into a StepLog. + """ + if self.simulation_step_log is None: + raise RuntimeError( + "Simulation step log is empty; call take_simulation_step() first." + ) + missing_logs = [ + agent_id for agent_id, log in self.agent_step_logs.items() if log is None + ] + if missing_logs: + raise RuntimeError( + f"Agent action logs missing for: {', '.join(missing_logs)}. " + "Ensure set_actions() ran before requesting the step log." + ) + step_log = StepLog( + simulation_step_log=self.simulation_step_log, + action_logs=self.agent_step_logs, + ) + return step_log + + async def step(self) -> Tuple[bool, StepLog]: + """ + Convenience step that collects actions, advances the simulation, and returns the log. + """ + await self.set_actions() + terminated = self.take_simulation_step() + step_log = self.get_step_log() + return terminated, step_log + + def get_safe_copy(self): + """ + Create a shallow copy of the game with deep-copied agents/simulation for branching. + """ + + new_markov_game = copy.copy(self) + new_simulation = self.simulation.get_safe_copy() + new_agents = { + agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items() + } + + # Reassign copied components + new_markov_game.simulation = new_simulation + new_markov_game.agents = new_agents + + # IMPORTANT: ensure agent_ids references the new agents dict, not the original + new_markov_game.agent_ids = new_markov_game.agents.keys() + + # Deep-copy step data to avoid correlation + new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log) + new_markov_game.actions = copy.deepcopy(self.actions) + # Rebuild logs to align exactly with new agent ids + old_agent_step_logs = copy.deepcopy(self.agent_step_logs) + new_markov_game.agent_step_logs = { + agent_id: old_agent_step_logs.get(agent_id) + for agent_id in new_markov_game.agent_ids + } + + return new_markov_game diff --git a/src_code_for_reproducibility/markov_games/mg_utils.py b/src_code_for_reproducibility/markov_games/mg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc406cd1f0cba593daad1108de2746b6a1d7678 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/mg_utils.py @@ -0,0 +1,97 @@ +""" +File: mllm/markov_games/mg_utils.py +Summary: Holds miscellaneous helpers shared across Markov-game modules. +""" + +import asyncio +import copy +from collections.abc import Callable +from dataclasses import dataclass + +from mllm.markov_games.ipd.ipd_agent import IPDAgent +from mllm.markov_games.ipd.Ipd_hard_coded_agents import ( + AlwaysCooperateIPDAgent, + AlwaysDefectIPDAgent, +) +from mllm.markov_games.ipd.ipd_simulation import IPD +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent +from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation +from mllm.markov_games.negotiation.nego_hard_coded_policies import ( + HardCodedNegoGreedyPolicy, + HardCodedNegoWelfareMaximizingPolicy, +) +from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent +from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation +from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent +from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation +from mllm.markov_games.rollout_tree import ( + AgentActLog, + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + StepLog, +) +from mllm.markov_games.simulation import SimulationStepLog + +AgentId = str + + +@dataclass +class AgentConfig: + """Configuration blob describing one agent in a Markov game spec.""" + + agent_id: str + agent_name: str + agent_class_name: str + policy_id: str + init_kwargs: dict + + +@dataclass +class MarkovGameConfig: + """Top-level config that ties together simulation settings and agent configs.""" + + id: int + seed: int + simulation_class_name: str + simulation_init_args: dict + agent_configs: list[AgentConfig] + + +def init_markov_game_components( + config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]] +): + """ + Materialize Agents and the Simulation described by ``config`` and return a MarkovGame. + + `policies` is a mapping of policy_id -> callable retrieved from the hosting trainer. + """ + agents = {} + agent_names = [] + for agent_config in config.agent_configs: + agent_id = agent_config.agent_id + agent_name = agent_config.agent_name + agent_class = eval(agent_config.agent_class_name) + agent = agent_class( + seed=config.seed, + agent_id=agent_id, + agent_name=agent_name, + policy=policies[agent_config.policy_id], + **agent_config.init_kwargs, + ) + agents[agent_id] = agent + agent_names.append(agent_name) + simulation = eval(config.simulation_class_name)( + seed=config.seed, + agent_ids=list(agents.keys()), + agent_names=agent_names, + **config.simulation_init_args, + ) + markov_game = MarkovGame( + id=config.id, + crn_id=config.seed, + agents=agents, + simulation=simulation, + ) + return markov_game diff --git a/src_code_for_reproducibility/markov_games/rollout_tree.py b/src_code_for_reproducibility/markov_games/rollout_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..c9feb0e92f3bcf19255d80c6ff2dcd9a045c6c21 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/rollout_tree.py @@ -0,0 +1,95 @@ +""" +File: mllm/markov_games/rollout_tree.py +Summary: Defines rollout tree data structures and serialization helpers. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Literal, Optional, Tuple + +import jsonschema +from pydantic import BaseModel, Field, model_validator + +from mllm.chat_utils.chat_turn import ChatTurn + +AgentId = str + + +class SimulationStepLog(BaseModel): + """Minimal snapshot of environment-side rewards and auxiliary info.""" + + rewards: dict[AgentId, float] + info: Any = None + + +class AgentActLog(BaseModel): + """LLM-side provenance for an action (chat turns + metadata).""" + + chat_turns: list[ChatTurn] | None + info: Any = None + + @model_validator(mode="after") + def _exactly_one_state_end(self): + """ + This method is used to enforce that for each AgentActLog, there is exactly one ChatTurn which is a state end. + """ + if self.chat_turns != []: + n = sum(1 for t in self.chat_turns if t.is_state_end) + if n != 1: + raise ValueError( + f"AgentActLog must have exactly one ChatTurn with is_state_end=True; got {self.chat_turns}." + ) + return self + else: + return self + + +class StepLog(BaseModel): + action_logs: dict[AgentId, AgentActLog] + simulation_step_log: SimulationStepLog + + +# BranchType = Literal["unilateral_deviation", "common_deviation"] # might not be necessary +# class BranchNodeInfo(BaseModel): +# branch_id: str +# branch_for: AgentId +# branch_type: BranchType + + +class RolloutTreeNode(BaseModel): + """Single timestep of the main trajectory (or a branch) plus linkage.""" + + step_log: StepLog + time_step: int + child: RolloutTreeNode | RolloutTreeBranchNode | None = None + + +class RolloutTreeBranchNode(BaseModel): + """ + First item of the tuple indicates which agent "called" for an alternative branch. + """ + + main_child: RolloutTreeNode + branches: dict[AgentId, list[RolloutTreeNode]] | None = None + + +class RolloutTreeRootNode(BaseModel): + """Entry point for serialized rollouts (main path plus optional branches).""" + + id: int + crn_id: int # ID of the rng used to generate this rollout tree + child: RolloutTreeNode | RolloutTreeBranchNode | None = None + agent_ids: List[AgentId] = Field(min_length=1) + + +# class RolloutTreeLeafNode(BaseModel): +# step_log: StepLog +# time_step: int + + +# Necessary for self-referential stuff in pydantic +RolloutTreeBranchNode.model_rebuild() +RolloutTreeNode.model_rebuild() diff --git a/src_code_for_reproducibility/markov_games/run_markov_games.py b/src_code_for_reproducibility/markov_games/run_markov_games.py new file mode 100644 index 0000000000000000000000000000000000000000..4a686ca98104595bbbd85b0f519c981bac952c70 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/run_markov_games.py @@ -0,0 +1,52 @@ +""" +File: mllm/markov_games/run_markov_games.py +Summary: CLI entry point for running configured Markov-game experiments. +""" + +import asyncio +from collections.abc import Callable +from dataclasses import dataclass + +from torch._C import ClassType + +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.rollout_tree import RolloutTreeRootNode + + +async def run_markov_games( + runner: Callable[[MarkovGame], RolloutTreeRootNode], + runner_kwargs: dict, + output_folder: str, + markov_games: list[MarkovGame], +) -> list[RolloutTreeRootNode]: + """ + Kick off multiple Markov game rollouts concurrently and return their trees. + + Parameters mirror the Hydra configs (runner callable + kwargs) so callers can + choose ``LinearRunner``, ``AlternativeActionsRunner`` or future variants. + """ + runner_kwargs = dict(runner_kwargs) + max_parallel_games = runner_kwargs.pop("max_parallel_games", None) + + async def run_game(markov_game: MarkovGame) -> RolloutTreeRootNode: + return await runner( + markov_game=markov_game, + output_folder=output_folder, + **runner_kwargs, + ) + + if max_parallel_games is not None: + semaphore = asyncio.Semaphore(max(1, int(max_parallel_games))) + + async def run_game(markov_game: MarkovGame) -> RolloutTreeRootNode: + async with semaphore: + return await runner( + markov_game=markov_game, + output_folder=output_folder, + **runner_kwargs, + ) + + tasks = [] + for mg in markov_games: + tasks.append(asyncio.create_task(run_game(mg))) + return await asyncio.gather(*tasks) diff --git a/src_code_for_reproducibility/markov_games/simulation.py b/src_code_for_reproducibility/markov_games/simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0a2e61924f9a79aee3229ed8d7aa20827ae859 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/simulation.py @@ -0,0 +1,94 @@ +""" +File: mllm/markov_games/simulation.py +Summary: Core simulation loop utilities and step logging for Markov games. +""" + +from abc import ABC, abstractmethod +from typing import Any, Tuple + +from numpy.random import default_rng + +from mllm.markov_games.rollout_tree import SimulationStepLog + + +class Simulation(ABC): + @abstractmethod + def __init__(self, seed: int, *args, **kwargs): + self.seed = seed + self.rng = default_rng(self.seed) + + @abstractmethod + def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]: + """ + Advance the environment by one logical tick using ``actions``. + + Returns + ------- + terminated: bool + Whether the episode has finished. + SimulationStepLog + Reward/info bundle describing this transition. + """ + raise NotImplementedError + + def get_obs(self): + """Return a dict mapping agent_id -> observation for *all* agents.""" + raise NotImplementedError + + def get_obs_agent(self, agent_id): + """Return the observation for a single agent.""" + raise NotImplementedError + + def get_obs_size(self): + """Describe the observation tensor shape (useful for critic heads).""" + raise NotImplementedError + + def get_state(self): + """Return the privileged simulator state if available.""" + raise NotImplementedError + + def get_state_size(self): + """Describe the state tensor shape.""" + raise NotImplementedError + + def get_avail_actions(self): + """Return the global action mask/tensor if the space is discrete.""" + raise NotImplementedError + + def get_avail_agent_actions(self, agent_id): + """Return the available action mask for a given agent.""" + raise NotImplementedError + + def get_total_actions(self): + """Returns the total number of actions an agent could ever take. + + Implementations currently assume a discrete, one-dimensional action space per agent. + """ + raise NotImplementedError + + def get_safe_copy(self): + """ + Return copy of the simulator that shares no mutable state with the original. + """ + raise NotImplementedError + + def reset(self): + """Reset to the initial state and return the starting observations.""" + raise NotImplementedError + + def render(self): + """Optional human-facing visualization.""" + raise NotImplementedError + + def close(self): + """Release any owned resources (files, processes, etc.).""" + raise NotImplementedError + + # def seed(self): + # raise NotImplementedError + + def save_replay(self): + raise NotImplementedError + + def get_simulation_info(self): + raise NotImplementedError diff --git a/src_code_for_reproducibility/markov_games/statistics_runner.py b/src_code_for_reproducibility/markov_games/statistics_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..e58131fc505806a758936978a46e4f8faefacad3 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/statistics_runner.py @@ -0,0 +1,415 @@ +""" +File: mllm/markov_games/statistics_runner.py +Summary: Executes multiple rollouts to compute experiment statistics. +""" + +from __future__ import annotations + +import gc +import json +import pickle +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional + +from basic_render import find_iteration_folders + +from mllm.markov_games.rollout_tree import ( + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + SimulationStepLog, +) + + +def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]: + """ + Iterate the main path nodes without materializing full path lists. + """ + current = root.child + while current is not None: + if isinstance(current, RolloutTreeNode): + yield current + current = current.child + elif isinstance(current, RolloutTreeBranchNode): + # Follow only the main child on the main trajectory + current = current.main_child + else: + break + + +def iterate_main_simulation_logs( + root: RolloutTreeRootNode, +) -> Iterator[SimulationStepLog]: + """Yield ``SimulationStepLog`` objects along the main (non-branch) path.""" + for node in _iterate_main_nodes(root): + yield node.step_log.simulation_step_log + + +def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]: + """Iterate over every ``*.rt.pkl`` file under an iteration directory.""" + for p in iteration_folder.rglob("*.rt.pkl"): + if p.is_file(): + yield p + + +def load_root(path: Path) -> RolloutTreeRootNode: + """Load and validate a rollout tree from disk.""" + with open(path, "rb") as f: + data = pickle.load(f) + return RolloutTreeRootNode.model_validate(data) + + +@dataclass +class StatRecord: + """Convenience container for serialized stat rows.""" + + mgid: int + crn_id: Optional[int] + iteration: str + values: Dict[str, Any] + + +class StatComputer: + """ + Stateful stat computer that consumes SimulationStepLog instances + and produces final aggregated values for one rollout (mgid). + """ + + def update(self, sl: SimulationStepLog) -> None: # pragma: no cover - interface + raise NotImplementedError + + def finalize(self) -> Dict[str, Any]: # pragma: no cover - interface + raise NotImplementedError + + +def run_stats( + data_root: Path, + game_name: str, + make_computers: Callable[[], List[StatComputer]], + output_filename: Optional[str] = None, + output_format: str = "json", # "json" (dict of lists) or "jsonl" +) -> Path: + """ + Compute stats across all iteration_* folders under data_root. + Writes JSONL to data_root/statistics/. + """ + data_root = Path(data_root) + outdir = data_root / "statistics" + outdir.mkdir(parents=True, exist_ok=True) + # Choose extension by format + default_name = ( + f"{game_name}.stats.json" + if output_format == "json" + else f"{game_name}.stats.jsonl" + ) + outfile = outdir / ( + output_filename if output_filename is not None else default_name + ) + + # Rewrite file each run to keep it clean and small + if outfile.exists(): + outfile.unlink() + + iteration_folders = find_iteration_folders(str(data_root)) + + # If writing JSONL, stream directly; otherwise accumulate minimal records + if output_format == "jsonl": + with open(outfile, "w", encoding="utf-8") as w: + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + computers = make_computers() + for sl in iterate_main_simulation_logs(root): + for comp in computers: + try: + comp.update(sl) + except Exception: + continue + + values: Dict[str, Any] = {} + for comp in computers: + try: + values.update(comp.finalize()) + except Exception: + continue + + rec = { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + w.write(json.dumps(rec, ensure_ascii=False) + "\n") + + del root + del computers + gc.collect() + else: + # Aggregate to dict-of-lists for easier plotting + records: List[Dict[str, Any]] = [] + # Process in deterministic order + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + computers = make_computers() + for sl in iterate_main_simulation_logs(root): + for comp in computers: + try: + comp.update(sl) + except Exception: + continue + + values: Dict[str, Any] = {} + for comp in computers: + try: + values.update(comp.finalize()) + except Exception: + continue + + records.append( + { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + ) + + del root + del computers + gc.collect() + + # Build dict-of-lists with nested stats preserved + # Collect all stat keys and nested agent keys where needed + mgids: List[Any] = [] + crn_ids: List[Any] = [] + iterations_out: List[str] = [] + # stats_out is a nested structure mirroring keys but with lists + stats_out: Dict[str, Any] = {} + + # First pass to collect union of keys + stat_keys: set[str] = set() + nested_agent_keys: Dict[str, set[str]] = {} + for r in records: + stats = r.get("stats", {}) or {} + for k, v in stats.items(): + stat_keys.add(k) + if isinstance(v, dict): + nested = nested_agent_keys.setdefault(k, set()) + for ak in v.keys(): + nested.add(str(ak)) + + # Initialize structure + for k in stat_keys: + if k in nested_agent_keys: + stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])} + else: + stats_out[k] = [] + + # Fill lists + for r in records: + mgids.append(r.get("mgid")) + crn_ids.append(r.get("crn_id")) + iterations_out.append(r.get("iteration")) + stats = r.get("stats", {}) or {} + for k in stat_keys: + val = stats.get(k) + if isinstance(stats_out[k], dict): + # per-agent dict + agent_dict = val if isinstance(val, dict) else {} + for ak in stats_out[k].keys(): + stats_out[k][ak].append(agent_dict.get(ak)) + else: + stats_out[k].append(val) + + with open(outfile, "w", encoding="utf-8") as w: + json.dump( + { + "mgid": mgids, + "crn_id": crn_ids, + "iteration": iterations_out, + "stats": stats_out, + }, + w, + ensure_ascii=False, + ) + + return outfile + + +def run_stats_functional( + data_root: Path, + game_name: str, + metrics: Dict[str, Callable[[SimulationStepLog], Optional[Dict[str, float]]]], + output_filename: Optional[str] = None, + output_format: str = "json", +) -> Path: + """ + Functional variant where metrics is a dict of name -> f(SimulationStepLog) -> {agent_id: value}. + Aggregates per rollout by averaging over steps where a metric produced a value. + Writes a single consolidated file in data_root/statistics/. + """ + data_root = Path(data_root) + outdir = data_root / "statistics" + outdir.mkdir(parents=True, exist_ok=True) + default_name = ( + f"{game_name}.stats.json" + if output_format == "json" + else f"{game_name}.stats.jsonl" + ) + outfile = outdir / ( + output_filename if output_filename is not None else default_name + ) + + if outfile.exists(): + outfile.unlink() + + iteration_folders = find_iteration_folders(str(data_root)) + + def finalize_rollout( + agg: Dict[str, Dict[str, List[float]]] + ) -> Dict[str, Dict[str, float]]: + # avg per metric per agent + result: Dict[str, Dict[str, float]] = {} + for mname, agent_values in agg.items(): + result[mname] = {} + for aid, vals in agent_values.items(): + if not vals: + result[mname][aid] = None # keep alignment; could be None + else: + result[mname][aid] = sum(vals) / len(vals) + return result + + if output_format == "jsonl": + with open(outfile, "w", encoding="utf-8") as w: + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + # aggregator structure: metric -> agent_id -> list of values + agg: Dict[str, Dict[str, List[float]]] = { + m: {} for m in metrics.keys() + } + + for sl in iterate_main_simulation_logs(root): + for mname, fn in metrics.items(): + try: + vals = fn(sl) + except Exception: + vals = None + if not vals: + continue + for aid, v in vals.items(): + if v is None: + continue + lst = agg[mname].setdefault(str(aid), []) + try: + lst.append(float(v)) + except Exception: + continue + + values = finalize_rollout(agg) + rec = { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + w.write(json.dumps(rec, ensure_ascii=False) + "\n") + + del root + gc.collect() + else: + records: List[Dict[str, Any]] = [] + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + agg: Dict[str, Dict[str, List[float]]] = {m: {} for m in metrics.keys()} + for sl in iterate_main_simulation_logs(root): + for mname, fn in metrics.items(): + try: + vals = fn(sl) + except Exception: + vals = None + if not vals: + continue + for aid, v in vals.items(): + if v is None: + continue + lst = agg[mname].setdefault(str(aid), []) + try: + lst.append(float(v)) + except Exception: + continue + + values = finalize_rollout(agg) + records.append( + { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + ) + + del root + gc.collect() + + # Build dict-of-lists output + mgids: List[Any] = [] + crn_ids: List[Any] = [] + iterations_out: List[str] = [] + stats_out: Dict[str, Any] = {} + + stat_keys: set[str] = set() + nested_agent_keys: Dict[str, set[str]] = {} + for r in records: + stats = r.get("stats", {}) or {} + for k, v in stats.items(): + stat_keys.add(k) + if isinstance(v, dict): + nested = nested_agent_keys.setdefault(k, set()) + for ak in v.keys(): + nested.add(str(ak)) + + for k in stat_keys: + if k in nested_agent_keys: + stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])} + else: + stats_out[k] = [] + + for r in records: + mgids.append(r.get("mgid")) + crn_ids.append(r.get("crn_id")) + iterations_out.append(r.get("iteration")) + stats = r.get("stats", {}) or {} + for k in stat_keys: + val = stats.get(k) + if isinstance(stats_out[k], dict): + agent_dict = val if isinstance(val, dict) else {} + for ak in stats_out[k].keys(): + stats_out[k][ak].append(agent_dict.get(ak)) + else: + stats_out[k].append(val) + + with open(outfile, "w", encoding="utf-8") as w: + json.dump( + { + "mgid": mgids, + "crn_id": crn_ids, + "iteration": iterations_out, + "stats": stats_out, + }, + w, + ensure_ascii=False, + ) + + return outfile diff --git a/src_code_for_reproducibility/models/__init__.py b/src_code_for_reproducibility/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46d40ee71acb4c10a596d4107d18fd3e890df610 --- /dev/null +++ b/src_code_for_reproducibility/models/__init__.py @@ -0,0 +1,4 @@ +""" +File: mllm/models/__init__.py +Summary: Exports model-layer utilities from the models package. +""" diff --git a/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ad33b5680253431b63bb77c966e58c880585a2d Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/human_policy.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be4492709842bfaedaa4a74be83368b4946faa6a Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7061a0770c18aa38827d50ecf85807ccc8f749e5 Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8fea9ea8ac688ef22a8e977eaa027844ff05aa2 Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/large_language_model_api.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/large_language_model_gemini_api.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/large_language_model_gemini_api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..732e5615870eb0adff91c461cc72281ef309b585 Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/large_language_model_gemini_api.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..594ac70a032211e50955c9992b1a778a6d74dcf7 Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..410d3fffa1819c04efe55d12f28f1acae0c7686b Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/adapter_training_wrapper.py b/src_code_for_reproducibility/models/adapter_training_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f99c768935cca8203f7a9482d25aeade0dee0d59 --- /dev/null +++ b/src_code_for_reproducibility/models/adapter_training_wrapper.py @@ -0,0 +1,104 @@ +""" +File: mllm/models/adapter_training_wrapper.py +Summary: Wraps a shared LLM with adapter-specific PEFT handling for training. +""" + +import logging +from typing import Union + +import torch +import torch.nn as nn +from peft import LoraConfig, get_peft_model + +logger = logging.getLogger(__name__) + + +class AdapterWrapper(nn.Module): + """ + A thin façade that + • keeps a reference to a *shared* PEFT-wrapped model, + • ensures `set_adapter(adapter)` is called on every forward, + • exposes only the parameters that should be trained for that adapter + (plus whatever extra modules you name). + """ + + def __init__( + self, + shared_llm: nn.Module, + adapter_id: str, + lora_config: dict, + path: Union[str, None] = None, + ): + super().__init__() + self.shared_llm = shared_llm + self.adapter_id = adapter_id + lora_config = LoraConfig(**lora_config) + # this modifies the shared llm in place, adding a lora adapter inside + self.shared_llm = get_peft_model( + model=shared_llm, + peft_config=lora_config, + adapter_name=adapter_id, + ) + self.shared_llm.train() + # Load external adapter weights if provided + loaded_from: str | None = None + if path: + try: + # Supports both local filesystem paths and HF Hub repo IDs + self.shared_llm.load_adapter( + is_trainable=True, + model_id=path, + adapter_name=adapter_id, + ) + loaded_from = path + except ( + Exception + ) as exc: # noqa: BLE001 - want to log any load failure context + logger.warning( + f"Adapter '{adapter_id}': failed to load from '{path}': {exc}" + ) + + if loaded_from: + logger.info( + f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'." + ) + else: + logger.info( + f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)." + ) + + def parameters(self, recurse: bool = True): + """ + "recurse" is just for pytorch compatibility + """ + self.shared_llm.set_adapter(self.adapter_id) + params = [p for p in self.shared_llm.parameters() if p.requires_grad] + + return params + + def get_base_model_logits(self, contexts): + """ + Run the base model (without adapter) in inference mode, without tracking gradients. + This is useful to get reference logits for KL-divergence computation. + """ + with torch.no_grad(): + with self.shared_llm.disable_adapter(): + return self.shared_llm(input_ids=contexts)[0] + + def forward(self, *args, **kwargs): + self.shared_llm.set_adapter(self.adapter_id) + return self.shared_llm(*args, **kwargs) + + def save_pretrained(self, save_path): + self.shared_llm.save_pretrained(save_path) + + def gradient_checkpointing_enable(self, *args, **kwargs): + self.shared_llm.gradient_checkpointing_enable(*args, **kwargs) + + @property + def dtype(self): + return self.shared_llm.dtype + + @property + def device(self): + return self.shared_llm.device diff --git a/src_code_for_reproducibility/models/human_policy.py b/src_code_for_reproducibility/models/human_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..699c2d817c17abfdd0e144e99e14fb2b3ba06872 --- /dev/null +++ b/src_code_for_reproducibility/models/human_policy.py @@ -0,0 +1,260 @@ +""" +File: mllm/models/human_policy.py +Summary: Implements an interactive human-in-the-loop policy for experiments. +""" + +import asyncio +import os +import re +import shutil +import sys +from typing import Callable, Dict, List, Optional + +from mllm.markov_games.rollout_tree import ChatTurn + +try: + import rstr # For generating example strings from regex +except Exception: # pragma: no cover + rstr = None + + +def _clear_terminal() -> None: + """ + Clear the terminal screen in a cross-platform manner. + """ + if sys.stdout.isatty(): + os.system("cls" if os.name == "nt" else "clear") + + +def _terminal_width(default: int = 100) -> int: + try: + return shutil.get_terminal_size().columns + except Exception: + return default + + +def _horizontal_rule(char: str = "─") -> str: + width = max(20, _terminal_width() - 2) + return char * width + + +class _Style: + # ANSI colors (bright, readable) + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + # Foreground colors + FG_BLUE = "\033[94m" # user/system headers + FG_GREEN = "\033[92m" # human response header + FG_YELLOW = "\033[93m" # notices + FG_RED = "\033[91m" # errors + FG_MAGENTA = "\033[95m" # regex + FG_CYAN = "\033[96m" # tips + + +def _render_chat(state) -> str: + """ + Render prior messages in a compact, readable terminal format. + + Expected message dict keys: {"role": str, "content": str, ...} + """ + lines: List[str] = [] + lines.append(_horizontal_rule()) + lines.append(f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}") + lines.append(_horizontal_rule()) + for chat in state: + role = chat.role + content = str(chat.content).strip() + # Map roles to display names and colors/emojis + if role == "assistant": + header = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑‍💻{_Style.RESET}" + elif role == "user": + header = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}" + else: + header = f"[{_Style.DIM}{role.upper()}{_Style.RESET}]" + lines.append(header) + # Indent content for readability + for line in content.splitlines() or [""]: + lines.append(f" {line}") + lines.append("") + lines.append(_horizontal_rule()) + return "\n".join(lines) + + +async def _async_input(prompt_text: str) -> str: + """Non-blocking input using a background thread.""" + return await asyncio.to_thread(input, prompt_text) + + +def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]: + """ + Try to produce a short example string that matches the regex. + We attempt multiple times and pick the first <= max_len. + """ + if rstr is None: + return None + try: + for _ in range(20): + candidate = rstr.xeger(regex) + if len(candidate) <= max_len: + return candidate + # Fallback to truncation (may break match, so don't return) + return None + except Exception: + return None + + +def _detect_input_type(regex: str | None) -> tuple[str, str, str]: + """ + Detect what type of input is expected based on the regex pattern. + Returns (input_type, start_tag, end_tag) + """ + if regex is None: + return "text", "", "" + + if "message_start" in regex and "message_end" in regex: + return "message", "<>", "<>" + elif "proposal_start" in regex and "proposal_end" in regex: + return "proposal", "<>", "<>" + else: + return "text", "", "" + + +async def human_policy(state, agent_id, regex: str | None = None) -> str: + """ + Async human-in-the-loop policy. + + - Displays prior conversation context in the terminal. + - Prompts the user for a response. + - If a regex is provided, validates and re-prompts until it matches. + - Automatically adds formatting tags based on expected input type. + + Args: + prompt: Chat history as a list of {role, content} dicts. + regex: Optional fullmatch validation pattern. + + Returns: + The user's validated response string. + """ + # Detect input type and formatting + input_type, start_tag, end_tag = _detect_input_type(regex) + + while True: + _clear_terminal() + print(_render_chat(state)) + + if regex: + example = _short_regex_example(regex, max_len=30) + print( + f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}" + ) + print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}") + if example: + print( + f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}" + ) + print(_horizontal_rule(".")) + + # Custom prompt based on input type + if input_type == "message": + print( + f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}" + ) + elif input_type == "proposal": + print( + f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}" + ) + else: + print( + f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}" + ) + + print( + f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}" + ) + else: + print( + f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}" + ) + + user_in = (await _async_input("> ")).rstrip("\n") + + # Commands + if user_in.strip().lower() in {"/help", "/h"}: + print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}") + print( + f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help" + ) + print( + f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt" + ) + print( + f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)" + ) + await asyncio.sleep(1.0) + continue + if user_in.strip().lower() in {"/refresh", "/r"}: + continue + if user_in.strip().lower() in {"/quit", "/q"}: + raise KeyboardInterrupt("Human aborted run from human_policy") + + # Add formatting tags if needed + if start_tag and end_tag: + formatted_input = f"{start_tag}{user_in}{end_tag}" + else: + formatted_input = user_in + + if regex is None: + return ChatTurn( + role="assistant", agent_id=agent_id, content=formatted_input + ) + + # Validate against regex (fullmatch) + try: + pattern = re.compile(regex) + except re.error as e: + # If regex is invalid, fall back to accepting any input + print( + f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {e}. Accepting input without validation." + ) + await asyncio.sleep(0.5) + return ChatTurn( + role="assistant", agent_id=agent_id, content=formatted_input + ) + + if pattern.fullmatch(formatted_input): + return ChatTurn( + role="assistant", agent_id=agent_id, content=formatted_input + ) + + # Show validation error and re-prompt + print("") + print( + f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again." + ) + + if input_type == "message": + print( + f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}" + ) + print(f"Just type the message content without tags.") + elif input_type == "proposal": + print( + f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}" + ) + print(f"Just type the number without tags.") + else: + print(f"Expected (regex):") + print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}") + + print(_horizontal_rule(".")) + print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}") + await _async_input("") + + +def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]: + """ + Expose the human policy in the same map shape used elsewhere. + """ + # Type hint says Callable[[List[Dict]], str] but we intentionally return the async callable. + return {"human_policy": human_policy} # type: ignore[return-value] diff --git a/src_code_for_reproducibility/models/inference_backend.py b/src_code_for_reproducibility/models/inference_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..c204482170d5a4418870805b620295cab294fab6 --- /dev/null +++ b/src_code_for_reproducibility/models/inference_backend.py @@ -0,0 +1,44 @@ +""" +File: mllm/models/inference_backend.py +Summary: Declares the inference backend interface and shared dataclasses. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass +class LLMInferenceOutput: + content: str + reasoning_content: str | None = None + log_probs: list[float] | None = None + out_token_ids: list[int] | None = None + + +class LLMInferenceBackend(ABC): + @abstractmethod + def __init__(self, **kwargs): + ... + + @abstractmethod + def prepare_adapter( + self, adapter_id: str, weights_got_updated: bool = False + ) -> None: + """Ensure adapter is ready/loaded for next generation call.""" + + @abstractmethod + async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> str: + ... + + @abstractmethod + def toggle_training_mode(self) -> None: + ... + + @abstractmethod + def toggle_eval_mode(self) -> None: + ... + + @abstractmethod + def shutdown(self) -> None: + ... diff --git a/src_code_for_reproducibility/models/inference_backend_dummy.py b/src_code_for_reproducibility/models/inference_backend_dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..22dd123f5fbcf9a976282b0657097edc680c6ac3 --- /dev/null +++ b/src_code_for_reproducibility/models/inference_backend_dummy.py @@ -0,0 +1,59 @@ +""" +File: mllm/models/inference_backend_dummy.py +Summary: Stub inference backend that returns synthetic completions for tests. +""" + +import asyncio +from typing import Optional + +import rstr +from transformers import AutoTokenizer + +from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput +from mllm.utils.short_id_gen import generate_short_id + + +class DummyInferenceBackend(LLMInferenceBackend): + def __init__( + self, + *args, + **kwargs, + ): + pass + + def prepare_adapter( + self, + adapter_id: Optional[str], + weights_got_updated: bool, + adapter_path: Optional[str] = None, + ) -> None: + pass + + async def toggle_training_mode(self) -> None: + await asyncio.sleep(0) + pass + + async def toggle_eval_mode(self) -> None: + await asyncio.sleep(0) + pass + + def shutdown(self) -> None: + pass + + async def generate( + self, + prompt_text: str, + regex: Optional[str] = None, + extract_thinking: bool = False, + ) -> LLMInferenceOutput: + if regex: + # Create random string that respects the regex + return LLMInferenceOutput( + content=rstr.xeger(regex), + reasoning_content="I don't think, I am a dummy backend.", + ) + else: + return LLMInferenceOutput( + content="I am a dummy backend without a regex.", + reasoning_content="I don't think, I am a dummy backend.", + ) diff --git a/src_code_for_reproducibility/models/inference_backend_vllm.py b/src_code_for_reproducibility/models/inference_backend_vllm.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a7fc73287cb676ce56beea5de77cf03fc24555 --- /dev/null +++ b/src_code_for_reproducibility/models/inference_backend_vllm.py @@ -0,0 +1,111 @@ +""" +File: mllm/models/inference_backend_vllm.py +Summary: Connects to in-process vLLM instances for batched generation. +""" + +import asyncio +import re +from typing import Optional + +import torch +from transformers import AutoTokenizer +from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams +from vllm.inputs import TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind + +from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput +from mllm.utils.short_id_gen import generate_short_id + + +class VLLMAsyncBackend(LLMInferenceBackend): + def __init__( + self, + model_name: str, + tokenizer: AutoTokenizer, + # adapter_paths: dict[str, str], + engine_init_kwargs: dict = {}, + sampling_params: dict = {}, + ): + self.model_name = model_name + self.vllm_adapter_ids = {} + ea = dict(model=model_name, **engine_init_kwargs) + self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea)) + + self.sampling_params = sampling_params + self.tokenizer = tokenizer + + def prepare_adapter( + self, + adapter_id: Optional[str], + adapter_path: Optional[str], + weights_got_updated: bool, + ) -> None: + if weights_got_updated: + self.vllm_adapter_ids[adapter_id] = generate_short_id() + self.current_lora_request = LoRARequest( + adapter_id, + self.vllm_adapter_ids[adapter_id], + adapter_path, + ) + + async def toggle_training_mode(self) -> None: + await self.engine.sleep(level=1) + + async def toggle_eval_mode(self) -> None: + await self.engine.wake_up() + + def shutdown(self) -> None: + # No explicit close call; engine stops when process exits. + pass + + async def generate( + self, + input_token_ids: list[int], + regex: Optional[str] = None, + extract_thinking: bool = False, + ) -> LLMInferenceOutput: + # Build SamplingParams correctly + guided = GuidedDecodingParams(regex=regex) if regex else None + sp = SamplingParams( + **self.sampling_params, + guided_decoding=guided, + output_kind=RequestOutputKind.FINAL_ONLY, + ) + + prompt = TokensPrompt(prompt_token_ids=input_token_ids) + request_id = f"req-{asyncio.get_running_loop().time()}" + result_generator = self.engine.generate( + prompt, + sp, # SamplingParams(...) + request_id, + lora_request=self.current_lora_request, + ) + + async for out in result_generator: # with FINAL_ONLY this runs once + res = out + + raw_text = res.outputs[0].text + out_token_ids = res.outputs[0].token_ids + log_probs = [ + logprob_dict[token_id].logprob + for token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs) + ] + log_probs = torch.tensor(log_probs) + out_token_ids = torch.tensor(out_token_ids, dtype=torch.long) + content = raw_text + reasoning_content = None + + if extract_thinking: + m = re.match( + r"^\n\n([\s\S]*?)\n\n(.*)$", raw_text, flags=re.DOTALL + ) + if m: + reasoning_content = m.group(1) + content = m.group(2) + return LLMInferenceOutput( + content=content, + reasoning_content=reasoning_content, + log_probs=log_probs, + out_token_ids=out_token_ids, + ) diff --git a/src_code_for_reproducibility/models/large_language_model_api.py b/src_code_for_reproducibility/models/large_language_model_api.py new file mode 100644 index 0000000000000000000000000000000000000000..11858911e06cce5fc7a97eddf24e95a7616dca95 --- /dev/null +++ b/src_code_for_reproducibility/models/large_language_model_api.py @@ -0,0 +1,184 @@ +""" +File: mllm/models/large_language_model_api.py +Summary: Implements API-based large-language-model inference adapters. +""" + +from __future__ import annotations + +import asyncio +import copy +import os +import random +import re +from typing import Any, Callable, Dict, List, Optional, Sequence + +import backoff +from openai import AsyncOpenAI, OpenAIError + +from mllm.markov_games.rollout_tree import ChatTurn +from mllm.models.inference_backend import LLMInferenceOutput + +# Static list copied from the public OpenAI docs until a discovery endpoint is exposed. +reasoning_models = [ + "gpt-5-nano", + "gpt-5-mini", + "gpt-5", + "o1-mini", + "o1", + "o1-pro", + "o3-mini", + "o3", + "o3-pro", + "o4-mini", + "o4", + "o4-pro", + "openai/gpt-oss-20b", + "openai/gpt-oss-120b", +] + + +class LargeLanguageModelOpenAI: + """Tiny async wrapper for OpenAI Chat Completions.""" + + def __init__( + self, + llm_id: str = "", + model: str = "gpt-4.1-mini", + reasoning_effort: str = "low", + add_constraint_msg: bool = True, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout_s: float = 300.0, + regex_max_attempts: int = 10, + sampling_params: Optional[Dict[str, Any]] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + output_directory: Optional[str] = None, + ) -> None: + self.llm_id = llm_id + self.model = model + key = api_key or os.getenv("OPENAI_API_KEY") + if not key: + raise RuntimeError( + "Set OPENAI_API_KEY as global environment variable or pass api_key." + ) + client_kwargs: Dict[str, Any] = {"api_key": key, "timeout": timeout_s} + if base_url: + client_kwargs["base_url"] = base_url + self.client = AsyncOpenAI(**client_kwargs) + + # Sampling/default request params set at init + self.sampling_params = sampling_params + self.use_reasoning = model in reasoning_models + if self.use_reasoning: + self.sampling_params["reasoning"] = { + "effort": reasoning_effort, + "summary": "detailed", + } + self.regex_max_attempts = max(1, int(regex_max_attempts)) + self.add_constraint_msg = add_constraint_msg + + def get_inference_policies(self) -> Dict[str, Callable]: + return { + self.llm_id: self.get_action, + } + + async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def export_adapters(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + def extract_output_from_response(self, resp: Response) -> LLMInferenceOutput: + if len(resp.output) > 1: + reasoning_content = resp.output[0].content + summary = resp.output[0].summary + if reasoning_content is not None: + reasoning_content = ( + f"OpenAI Reasoning Content: {reasoning_content[0].text}" + ) + elif summary != []: + reasoning_content = f"OpenAI Reasoning Summary: {summary[0].text}" + else: + reasoning_content = None + content = resp.output[1].content[0].text + else: + reasoning_content = None + content = resp.output[0].content[0].text + + return LLMInferenceOutput( + content=content, + reasoning_content=reasoning_content, + ) + + @backoff.on_exception( + backoff.expo, Exception, max_time=10**10, max_tries=10**10 + ) + async def get_action( + self, + state: list[ChatTurn], + agent_id: str, + regex: Optional[str] = None, + ) -> LLMInferenceOutput: + # Remove any non-role/content keys from the prompt else openai will error. + prompt = [{"role": p.role, "content": p.content} for p in state] + + # if self.sleep_between_requests: + # await self.wait_random_time() + + # If regex is required, prime the model and validate client-side + if regex: + if self.add_constraint_msg: + constraint_msg = { + "role": "user", + "content": ( + f"Output must match this regex exactly: {regex} \n" + "Return only the matching string, with no quotes or extra text." + ), + } + prompt = [constraint_msg, *prompt] + pattern = re.compile(regex) + for _ in range(self.regex_max_attempts): + resp = await self.client.responses.create( + model=self.model, + input=prompt, + **self.sampling_params, + ) + policy_output = self.extract_output_from_response(resp) + if pattern.fullmatch(policy_output.content): + return policy_output + prompt = [ + *prompt, + { + "role": "user", + "content": ( + f"Invalid response format. Expected format (regex): {regex}\n Please try again and provide ONLY a response that matches this regex." + ), + }, + ] + return policy_output + + # Simple, unconstrained generation + resp = await self.client.responses.create( + model=self.model, + input=prompt, + **self.sampling_params, + ) + policy_output = self.extract_output_from_response(resp) + return policy_output + + def shutdown(self) -> None: + self.client = None diff --git a/src_code_for_reproducibility/models/large_language_model_gemini_api.py b/src_code_for_reproducibility/models/large_language_model_gemini_api.py new file mode 100644 index 0000000000000000000000000000000000000000..aa28767610e0982d89e0c21a261235b01c60a03c --- /dev/null +++ b/src_code_for_reproducibility/models/large_language_model_gemini_api.py @@ -0,0 +1,197 @@ +""" +File: mllm/models/large_language_model_gemini_api.py +Summary: Implements native Gemini API-based large-language-model inference adapters. +""" + +from __future__ import annotations + +import asyncio +import os +import re +from typing import Any, Callable, Dict, List, Optional + +import backoff +from google import genai +from google.genai import types + +from mllm.markov_games.rollout_tree import ChatTurn +from mllm.models.inference_backend import LLMInferenceOutput + + +class LargeLanguageModelGemini: + """Tiny async wrapper for the native Gemini API.""" + + def __init__( + self, + llm_id: str = "", + model: str = "gemini-3.1-flash-lite-preview", + api_key: Optional[str] = None, + timeout_s: float = 300.0, + regex_max_attempts: int = 10, + sampling_params: Optional[Dict[str, Any]] = None, + thinking_level: str = "low", + include_thoughts: bool = True, + init_kwargs: Optional[Dict[str, Any]] = None, + output_directory: Optional[str] = None, + ) -> None: + self.llm_id = llm_id + self.model = model + self.timeout_s = timeout_s + key = api_key or os.getenv("GEMINI_API_KEY") + if not key: + raise RuntimeError( + "Set GEMINI_API_KEY as global environment variable or pass api_key." + ) + self.client = genai.Client(api_key=key) + self.sampling_params = sampling_params or {} + self.thinking_level = thinking_level + self.include_thoughts = include_thoughts + self.regex_max_attempts = max(1, int(regex_max_attempts)) + + def get_inference_policies(self) -> Dict[str, Callable]: + return { + self.llm_id: self.get_action, + } + + async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def export_adapters(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + def messages_to_contents(self, messages: List[Dict[str, str]]) -> List[types.Content]: + contents: List[types.Content] = [] + system_chunks: List[str] = [] + + for message in messages: + role = message["role"] + text = message["content"] + + if role == "system": + system_chunks.append(text) + continue + + gemini_role = "model" if role == "assistant" else "user" + contents.append( + types.Content( + role=gemini_role, + parts=[types.Part.from_text(text=text)], + ) + ) + + if system_chunks: + system_text = "\n\n".join(system_chunks) + contents.insert( + 0, + types.Content( + role="user", + parts=[ + types.Part.from_text( + text=( + "System instruction:\n" + f"{system_text}\n\n" + "Follow the system instruction for the rest of this conversation." + ) + ) + ], + ), + ) + + return contents + + def build_generate_config(self) -> types.GenerateContentConfig: + return types.GenerateContentConfig( + thinking_config=types.ThinkingConfig( + thinking_level=self.thinking_level, + include_thoughts=self.include_thoughts, + ), + **self.sampling_params, + ) + + def extract_output_from_response(self, response: Any) -> LLMInferenceOutput: + reasoning_parts: List[str] = [] + content_parts: List[str] = [] + + if response.candidates: + for part in response.candidates[0].content.parts: + text = getattr(part, "text", None) + if not text: + continue + if getattr(part, "thought", False): + reasoning_parts.append(text) + else: + content_parts.append(text) + + content = "\n".join(content_parts) if content_parts else (response.text or "") + reasoning_content = "\n".join(reasoning_parts) if reasoning_parts else None + + return LLMInferenceOutput( + content=content, + reasoning_content=reasoning_content, + ) + + @backoff.on_exception( + backoff.expo, Exception, max_time=10**10, max_tries=10**10 + ) + async def get_action( + self, + state: list[ChatTurn], + agent_id: str, + regex: Optional[str] = None, + ) -> LLMInferenceOutput: + prompt = [{"role": p.role, "content": p.content} for p in state] + + if regex: + constraint_msg = { + "role": "user", + "content": ( + f"Output must match this regex exactly: {regex} \n" + "Return only the matching string, with no quotes or extra text." + ), + } + prompt = [constraint_msg, *prompt] + pattern = re.compile(regex) + for _ in range(self.regex_max_attempts): + response = await self.client.aio.models.generate_content( + model=self.model, + contents=self.messages_to_contents(prompt), + config=self.build_generate_config(), + ) + policy_output = self.extract_output_from_response(response) + if pattern.fullmatch(policy_output.content): + return policy_output + prompt = [ + *prompt, + { + "role": "user", + "content": ( + f"Invalid response format. Expected format (regex): {regex}\n" + "Please try again and provide ONLY a response that matches this regex." + ), + }, + ] + return policy_output + + response = await self.client.aio.models.generate_content( + model=self.model, + contents=self.messages_to_contents(prompt), + config=self.build_generate_config(), + ) + return self.extract_output_from_response(response) + + def shutdown(self) -> None: + self.client = None diff --git a/src_code_for_reproducibility/models/large_language_model_local.py b/src_code_for_reproducibility/models/large_language_model_local.py new file mode 100644 index 0000000000000000000000000000000000000000..1f00590397bd0e6d431e35ea9d4320505b3b641c --- /dev/null +++ b/src_code_for_reproducibility/models/large_language_model_local.py @@ -0,0 +1,361 @@ +""" +File: mllm/models/large_language_model_local.py +Summary: Provides a local large language model wrapper over inference backends. +""" + +import logging +import os +import re +import sys +import uuid +from collections.abc import Callable +from copy import deepcopy +from datetime import datetime +from typing import Literal + +import httpx +import requests +import torch +import torch.nn as nn +from torch.optim import SGD, Adam, AdamW, RMSprop +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.chat_utils.apply_template import chat_turns_to_token_ids +from mllm.markov_games.rollout_tree import ChatTurn +from mllm.models.adapter_training_wrapper import AdapterWrapper +from mllm.models.inference_backend import LLMInferenceOutput +from mllm.models.inference_backend_dummy import DummyInferenceBackend +from mllm.models.inference_backend_vllm import VLLMAsyncBackend + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +AdapterID = str +PolicyID = str + + +class LeanLocalLLM: + """ + Wrapper that manages local HuggingFace models, adapters, and inference backends. + """ + + def __init__( + self, + llm_id: str = "base_llm", + model_name: str = "Qwen/Qwen3-4B-Instruct-2507", + device: str = "cuda", + hf_kwargs: dict = {}, + adapter_configs: dict = {}, + output_directory: str = "./models/", + inference_backend: Literal["vllm", "dummy"] = "vllm", + inference_backend_sampling_params: dict = {}, + inference_backend_init_kwargs: dict = {}, + initial_adapter_paths: dict[str, str] | None = None, + initial_buffer_paths: list[str] | None = None, + enable_thinking: bool = None, + regex_max_attempts: int = -1, + max_thinking_characters: int = 0, + ): + self.inference_backend_name = inference_backend + self.output_directory = output_directory + self.llm_id = llm_id + self.device = torch.device(device) if device else torch.device("cuda") + self.model_name = model_name + self.adapter_configs = adapter_configs + self.adapter_ids = list(adapter_configs.keys()) + self.enable_thinking = enable_thinking + self.regex_max_attempts = regex_max_attempts + self.initial_buffer_paths = initial_buffer_paths + self.max_thinking_characters = max_thinking_characters + self.regex_retries_count = 0 + + # Optional user-specified initial adapter weight locations (local or HF Hub) + # Format: {adapter_id: path_or_repo_id} + self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths + + # Path management / imports + self.save_path = str(os.path.join(output_directory, model_name, "adapters")) + self.adapter_paths = { + adapter_id: os.path.join(self.save_path, adapter_id) + for adapter_id in self.adapter_ids + } + checkpoints_dir = os.path.join(self.output_directory, "checkpoints") + self.past_agent_adapter_paths = {} + if os.path.isdir(checkpoints_dir): + for dirname in os.listdir(checkpoints_dir): + dirpath = os.path.join(checkpoints_dir, dirname) + if os.path.isdir(dirpath): + self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join( + dirpath, "agent_adapter" + ) + logger.info( + f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory." + ) + if self.initial_buffer_paths is not None: + previous_count = len(self.past_agent_adapter_paths) + for path in self.initial_buffer_paths: + if os.path.isdir(path): + for dirname in os.listdir(path): + dirpath = os.path.join(path, dirname) + if os.path.isdir(dirpath): + self.past_agent_adapter_paths[ + f"{dirname}_buffer" + ] = os.path.join(dirpath, "agent_adapter") + else: + logger.warning( + f"Initial buffer path {path} does not exist or is not a directory." + ) + logger.info( + f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths." + ) + self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys()) + + # ID management for tracking adapter versions + self.adapter_train_ids = { + adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids + } + # Initialize tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + # Setup padding token to be same as EOS token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.weights_got_updated: dict[AdapterID, bool] = { + adapter_id: False for adapter_id in self.adapter_ids + } + self.weights_got_updated.update( + {adapter_id: False for adapter_id in self.past_agent_adapter_ids} + ) + self.current_lora_request = None + self.currently_loaded_adapter_id = None + + # --------------------------------------------------------- + # Init HF model, peft adapters + # --------------------------------------------------------- + self.shared_hf_llm = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=model_name, + **hf_kwargs, + ) + self.hf_adapters = {} + self.optimizers = {} + for adapter_id in self.adapter_ids: + # Prefer output-folder path if it exists; else fall back to user-specified initial path if provided + output_path = os.path.join(self.save_path, adapter_id) + chosen_path: str | None = None + if os.path.isdir(output_path) and os.listdir(output_path): + chosen_path = output_path + logger.info( + f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'." + ) + elif ( + self.initial_adapter_paths and adapter_id in self.initial_adapter_paths + ): + chosen_path = self.initial_adapter_paths[adapter_id] + logger.info( + f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'." + ) + else: + logger.info( + f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch." + ) + hf_adapter = AdapterWrapper( + shared_llm=self.shared_hf_llm, + adapter_id=adapter_id, + lora_config=adapter_configs[adapter_id], + path=chosen_path, + ).to(device) + self.hf_adapters[adapter_id] = hf_adapter + # Persist current state of all adapters (ensures remote loads are cached to disk) + self.export_adapters() + + # --------------------------------------------------------- + # Init inference inference_backend + # --------------------------------------------------------- + + if inference_backend == "vllm": + self.inference_backend = VLLMAsyncBackend( + model_name=self.model_name, + # adapter_paths=self.adapter_paths, + tokenizer=self.tokenizer, + engine_init_kwargs=inference_backend_init_kwargs, + sampling_params=inference_backend_sampling_params, + ) + elif inference_backend == "dummy": + self.inference_backend = DummyInferenceBackend() + else: + raise ValueError(f"Unknown inference_backend: {inference_backend}") + + def reset_regex_retries_count(self) -> None: + self.regex_retries_count = 0 + + def get_inference_policies(self) -> dict[PolicyID, Callable]: + """ + Build async policy callables keyed by adapter id for inference-only usage. + """ + policies = {} + for adapter_id in self.adapter_ids: + # define policy func + async def policy( + state: list[ChatTurn], + agent_id: str, + regex: str | None = None, + _adapter_id=adapter_id, + ): + self.prepare_adapter_for_inference(adapter_id=_adapter_id) + response = await self.get_action(state, agent_id, regex) + return response + + policies[self.llm_id + "/" + adapter_id] = policy + + for adapter_id in self.past_agent_adapter_ids: + # define policy func + async def policy( + state: list[ChatTurn], + agent_id: str, + regex: str | None = None, + _adapter_id=adapter_id, + ): + self.prepare_adapter_for_inference(adapter_id=_adapter_id) + response = await self.get_action(state, agent_id, regex) + return response + + policies[self.llm_id + "/" + adapter_id] = policy + return policies + + def get_adapter_modules(self) -> dict[PolicyID, nn.Module]: + """ + Returns wrappers over the adapters which allows them be + interfaced like regular PyTorch models. + AdapterWrapper lives in adapter_wrapper.py; the huggingface modules already wrap + parameters here, so we surface them directly until an extra shim is required. + """ + trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids} + return trainable_objects + + async def toggle_training_mode(self) -> None: + for adn in self.adapter_ids: + self.adapter_train_ids[adn] = self.short_id_generator() + await self.inference_backend.toggle_training_mode() + + async def toggle_eval_mode(self) -> None: + await self.inference_backend.toggle_eval_mode() + + def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None: + self.inference_backend.prepare_adapter( + adapter_id, + adapter_path=self.adapter_paths.get( + adapter_id, self.past_agent_adapter_paths.get(adapter_id, None) + ), + weights_got_updated=self.weights_got_updated[adapter_id], + ) + self.currently_loaded_adapter_id = adapter_id + self.weights_got_updated[adapter_id] = False + + # def _make_prompt_text(self, prompt: list[dict]) -> str: + # if self.enable_thinking is not None: + # prompt_text = self.tokenizer.apply_chat_template( + # prompt, + # tokenize=False, + # add_generation_prompt=True, + # enable_thinking=self.enable_thinking, + # ) + # else: + # prompt_text = self.tokenizer.apply_chat_template( + # prompt, + # tokenize=False, + # add_generation_prompt=True, + # ) + + # return prompt_text + + async def get_action( + self, state: list[ChatTurn], agent_id: str, regex: str | None = None + ) -> ChatTurn: + current_regex = regex if self.regex_max_attempts == -1 else None + pattern = re.compile(regex) if regex else None + nb_attempts = 0 + state = state[:] + while True: + context_token_ids = chat_turns_to_token_ids( + chats=state, + tokenizer=self.tokenizer, + enable_thinking=self.enable_thinking, + ) + policy_output = await self.inference_backend.generate( + input_token_ids=context_token_ids.tolist(), + extract_thinking=(self.max_thinking_characters > 0), + regex=current_regex, + ) + if ( + pattern is None + or (nb_attempts >= self.regex_max_attempts) + or (pattern.fullmatch(policy_output.content)) + ): + return ChatTurn( + agent_id=agent_id, + role="assistant", + content=policy_output.content, + reasoning_content=policy_output.reasoning_content, + out_token_ids=policy_output.out_token_ids, + log_probs=policy_output.log_probs, + is_state_end=False, + ) + else: + self.regex_retries_count += 1 + nb_attempts += 1 + logger.warning( + f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}" + ) + if nb_attempts == self.regex_max_attempts: + current_regex = regex + # regex_prompt = ChatTurn( + # role="user", + # content=f"Invalid response format. Expected format (regex): {current_regex}\n Please try again and provide ONLY a response that matches this regex.", + # reasoning_content=None, + # log_probs=None, + # out_token_ids=None, + # is_state_end=False, + # ) + # state.append(regex_prompt) + + def export_adapters(self) -> None: + """ + Any peft wrapper, by default, saves all adapters, not just the one currently loaded. + """ + + # New version of the adapters available + for adapter_id in self.adapter_ids: + self.weights_got_updated[adapter_id] = True + for adapter_id in self.past_agent_adapter_ids: + self.weights_got_updated[adapter_id] = True + + adapter_id = self.adapter_ids[0] + self.hf_adapters[adapter_id].save_pretrained(self.save_path) + + def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None: + """ + Checkpoints all adapters to the configured output directory. + """ + adapter_id = self.adapter_ids[0] + output_dir = os.path.join(self.output_directory, "checkpoints") + os.makedirs(output_dir, exist_ok=True) + date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}" + export_path = os.path.join(output_dir, agent_adapter_dir) + for adapter_id in self.adapter_ids: + if "agent" in adapter_id: + self.past_agent_adapter_paths[ + f"{agent_adapter_dir}_buffer" + ] = os.path.join(export_path, adapter_id) + self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer") + self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False + self.hf_adapters[adapter_id].save_pretrained(export_path) + + def short_id_generator(self) -> str: + """ + Generates a short unique ID for tracking adapter versions. + + Returns: + int: An 8-digit integer ID. + """ + return str(uuid.uuid4().int)[:8] diff --git a/src_code_for_reproducibility/models/scalar_critic.py b/src_code_for_reproducibility/models/scalar_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..0b704dcc78fdfbed1c68b1ac469e9c7b51758211 --- /dev/null +++ b/src_code_for_reproducibility/models/scalar_critic.py @@ -0,0 +1,59 @@ +""" +File: mllm/models/scalar_critic.py +Summary: Defines a scalar critic network and helper utilities. +""" + +import torch +import torch.nn as nn +import torch.optim as optim +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.models.adapter_training_wrapper import AdapterWrapper + + +class ScalarCritic(nn.Module): + """ + A causal-LM critic_adapter + a scalar value head: + V_φ(s) = wᵀ h_last + b + Only LoRA adapters (inside critic_adapter) and the value head are trainable. + """ + + def __init__(self, critic_adapter: AdapterWrapper): + super().__init__() + self.critic_adapter = critic_adapter + hidden_size = self.critic_adapter.shared_llm.config.hidden_size + self.value_head = nn.Linear(hidden_size, 1).to( + dtype=critic_adapter.dtype, device=critic_adapter.device + ) + + def forward(self, input_ids, attention_mask=None, **kwargs): + # AdapterWrapper activates its own adapter internally + outputs = self.critic_adapter( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + h_last = outputs.hidden_states[-1] # (B, S, H) + values = self.value_head(h_last).squeeze(-1) # (B, S) + return values + + def parameters(self, recurse: bool = True): + """Iterator over *trainable* parameters for this critic.""" + # 1) LoRA params for *this* adapter + for p in self.critic_adapter.parameters(): + yield p + # 2) scalar head + yield from self.value_head.parameters() + + def gradient_checkpointing_enable(self, *args, **kwargs): + self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs) + + @property + def dtype(self): + return self.critic_adapter.dtype + + @property + def device(self): + return self.critic_adapter.device diff --git a/src_code_for_reproducibility/training/__init__.py b/src_code_for_reproducibility/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..188fde562de5d8f658ef12708df9967f45cb2a7a --- /dev/null +++ b/src_code_for_reproducibility/training/__init__.py @@ -0,0 +1,4 @@ +""" +File: mllm/training/__init__.py +Summary: Exposes training submodules through the package namespace. +""" diff --git a/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d55921e90d12fdba53e18af9e9ba6f4ade46dbce Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/annealing_methods.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b300d4684025c45fae389a6c781bb7998797e948 Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/credit_methods.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a74fafa69144a369eb17c344e592471248ecd57 Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/tally_metrics.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb87aa66b77f4d963cd4bdb8724b975478dbc70b Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/tally_rollout.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85894c34c0923ed9c682e64764f8b4352eb2842f Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/tally_tokenwise.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63f5b30ad7c037fbf59ded7c36d4b5c8c81f9cfa Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/trainer_ad_align.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c40d5df2692e96042675871bc6aae7d0ee12749 Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/trainer_common.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47bb7e6601714fea22c8ab1084ed460dc231d97c Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/trainer_independent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc b/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d0796467210383f881ababcc0932705feb81e8b Binary files /dev/null and b/src_code_for_reproducibility/training/__pycache__/training_data_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/training/annealing_methods.py b/src_code_for_reproducibility/training/annealing_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..591d91f7720880fc202b116b27b15b996c256dc4 --- /dev/null +++ b/src_code_for_reproducibility/training/annealing_methods.py @@ -0,0 +1,20 @@ +""" +File: mllm/training/annealing_methods.py +Summary: Implements annealing schedules used across training loops. +""" + +import numpy as np + + +def sigmoid_annealing(step: int, temperature: float) -> float: + """ + Smoothly ramp a scalar from 0 → 1 using a temperature-controlled sigmoid. + + Args: + step: Current training step or iteration. + temperature: Controls how sharp the transition is; larger values flatten the curve. + + Returns: + Float in [-1, 1] that can be rescaled for annealing schedules. + """ + return 2 / (1 + np.exp(-step / temperature)) - 1 diff --git a/src_code_for_reproducibility/training/credit_methods.py b/src_code_for_reproducibility/training/credit_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..c29032630d06536e7efe6ceaae04092a616ce53a --- /dev/null +++ b/src_code_for_reproducibility/training/credit_methods.py @@ -0,0 +1,307 @@ +""" +File: mllm/training/credit_methods.py +Summary: Holds credit-assignment routines for reinforcement learning updates. +""" + +import torch + + +def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor: + """ + Normalize a vector of advantages to zero mean / unit variance (global). + + Useful for variance reduction before computing gradients. + """ + whitened_advantages = (advantages - torch.mean(advantages)) / ( + torch.std(advantages) + 1e-9 + ) + return whitened_advantages + + +def whiten_advantages_time_step_wise( + advantages: torch.Tensor, # (B, T) +) -> torch.Tensor: + """ + Whiten advantages independently per timestep (column-wise mean/std). + + Helps when rollout lengths differ or certain positions have very different scales. + """ + assert advantages.dim() == 2, "Wrong dimensions." + whitened_advantages_time_step_wise = ( + advantages - advantages.mean(dim=0, keepdim=True) + ) / (advantages.std(dim=0, keepdim=True) + 1e-9) + return whitened_advantages_time_step_wise + + +def get_discounted_state_visitation_credits( + credits: torch.Tensor, discount_factor: float # (B, T) +) -> torch.Tensor: + """ + Apply geometric discounting to credits so earlier visits count less. + + Equivalent to per-timestep multiplication by ``gamma^t``. + """ + return credits * ( + discount_factor ** torch.arange(credits.shape[1], device=credits.device) + ) + + +def get_discounted_returns( + rewards: torch.Tensor, # (B, T) + discount_factor: float, +) -> torch.Tensor: + """ + Computes Monte Carlo discounted returns for a sequence of rewards. + + Args: + rewards (torch.Tensor): Array of rewards for each timestep. + + Returns: + torch.Tensor: Array of discounted returns. + """ + assert rewards.dim() == 2, "Wrong dimensions." + B, T = rewards.shape + discounted_returns = torch.zeros_like(rewards) + accumulator = torch.zeros(B, device=rewards.device, dtype=rewards.dtype) + for t in reversed(range(T)): + accumulator = rewards[:, t] + discount_factor * accumulator + discounted_returns[:, t] = accumulator + return discounted_returns + + +def get_rloo_credits(credits: torch.Tensor): # (B, S) + """Compute leave-one-out baselines for a batch of credits.""" + assert credits.dim() == 2, "Wrong dimensions." + rloo_baselines = torch.zeros_like(credits) + n = credits.shape[0] + if n == 1: + return credits, rloo_baselines + rloo_baselines = (torch.sum(credits, dim=0, keepdim=True) - credits) / (n - 1) + rloo_credits = credits - rloo_baselines + return rloo_credits, rloo_baselines + + +def get_generalized_advantage_estimates( + rewards: torch.Tensor, # (B, T) + value_estimates: torch.Tensor, # (B, T+1) + discount_factor: float, + lambda_coef: float, +) -> torch.Tensor: + """ + Compute Generalized Advantage Estimates (GAE). + + See https://arxiv.org/pdf/1506.02438 for derivation. + """ + assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions." + + assert ( + rewards.shape[0] == value_estimates.shape[0] + ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates." + assert ( + rewards.shape[1] == value_estimates.shape[1] - 1 + ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates." + + T = rewards.shape[1] + tds = rewards + discount_factor * value_estimates[:, 1:] - value_estimates[:, :-1] + gaes = torch.zeros_like(tds) + acc = 0.0 + for t in reversed(range(T)): + acc = tds[:, t] + lambda_coef * discount_factor * acc + gaes[:, t] = acc + return gaes + + +def get_advantage_alignment_weights( + advantages: torch.Tensor, # (B, T) + exclude_k_equals_t: bool, + gamma: float, + discount_t: bool, +) -> torch.Tensor: + """ + The advantage alignment credit is calculated as + + \[ + A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot + \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right) + A^2(s_t, a_t, b_t) + \] + + Here, the weights are defined as \( \beta \cdot + \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \) + """ + T = advantages.shape[1] + discounted_advantages = advantages * ( + gamma * torch.ones((1, T), device=advantages.device) + ) ** (-torch.arange(0, T, 1, device=advantages.device)) + if exclude_k_equals_t: + sub = torch.eye(T, device=advantages.device) + else: + sub = torch.zeros((T, T), device=advantages.device) + # Identity is for \( k < t \), remove for \( k \leq t \) + ad_align_weights = discounted_advantages @ ( + torch.triu(torch.ones((T, T), device=advantages.device)) - sub + ) + t_discounts = (gamma * torch.ones((1, T), device=advantages.device)) ** ( + torch.arange(0, T, 1, device=advantages.device) + ) + ad_align_weights = t_discounts * ad_align_weights + if discount_t: + time_discounted_advantages = advantages * ( + gamma * torch.ones((1, T), device=advantages.device) + ) ** (torch.arange(0, T, 1, device=advantages.device)) + ad_align_weights = ad_align_weights - advantages + time_discounted_advantages + return ad_align_weights + + +def get_advantage_alignment_credits( + a1: torch.Tensor, # (B, S) + a1_alternative: torch.Tensor, # (B, S, A) + a2: torch.Tensor, # (B, S) + exclude_k_equals_t: bool, + beta: float, + gamma: float = 1.0, + use_old_ad_align: bool = False, + use_sign: bool = False, + clipping: float | None = None, + use_time_regularization: bool = False, + force_coop_first_step: bool = False, + use_variance_regularization: bool = False, + rloo_branch: bool = False, + reuse_baseline: bool = False, + mean_normalize_ad_align: bool = False, + whiten_adalign_advantages: bool = False, + whiten_adalign_advantages_time_step_wise: bool = False, + discount_t: bool = False, +) -> torch.Tensor: + """ + Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662. + + Recall that the advantage opponent shaping term of the AdAlign policy gradient is: + \[ + \beta \mathbb{E}_{\substack{ + \tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\ + a_t' \sim \pi^1(\cdot \mid s_t) + }} + \left[\sum_{t=0}^\infty \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right] + \] + + This method computes the following: + \[ + Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right] + \] + + Args: + a1: Advantages of the main trajectories for the current agent. + a1_alternative: Advantages of the alternative trajectories for the current agent. + a2: Advantages of the main trajectories for the other agent. + discount_factor: Discount factor for the advantage alignment. + beta: Beta parameter for the advantage alignment. + gamma: Gamma parameter for the advantage alignment. + use_sign_in_ad_align: Whether to use sign in the advantage alignment. + + Returns: + torch.Tensor: The advantage alignment credits. + """ + + assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)" + if a1_alternative is not None: + assert ( + a1_alternative.dim() == 3 + ), "Alternative advantages must be of shape (B, S, A)" + B, T, A = a1_alternative.shape + else: + B, T = a1.shape + assert a1.shape == a2.shape, "Not the same shape" + + sub_tensors = {} + + if use_old_ad_align: + ad_align_weights = get_advantage_alignment_weights( + advantages=a1, + exclude_k_equals_t=exclude_k_equals_t, + gamma=gamma, + discount_t=discount_t, + ) + sub_tensors["ad_align_weights_prev"] = ad_align_weights + if exclude_k_equals_t: + ad_align_weights = gamma * ad_align_weights + else: + assert a1_alternative is not None, "Alternative advantages must be provided" + if rloo_branch: + a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2) + a1_alternative = a1_alternative.mean(dim=2) + a1, baseline = get_rloo_credits(a1) + if reuse_baseline: + a1_alternative = a1_alternative - baseline + else: + a1_alternative, _ = get_rloo_credits(a1_alternative) + assert a1.shape == a1_alternative.shape, "Not the same shape" + ad_align_weights = get_advantage_alignment_weights( + advantages=a1_alternative, + exclude_k_equals_t=exclude_k_equals_t, + gamma=gamma, + ) + sub_tensors["ad_align_weights"] = ad_align_weights + + # Use sign + if use_sign: + assert beta == 1.0, "beta should be 1.0 when using sign" + positive_signs = ad_align_weights > 0 + negative_signs = ad_align_weights < 0 + ad_align_weights[positive_signs] = 1 + ad_align_weights[negative_signs] = -1 + sub_tensors["ad_align_weights_sign"] = ad_align_weights + # (rest are 0) + + ################### + # Process weights + ################### + + # Use clipping + if clipping not in [0.0, None]: + upper_mask = ad_align_weights > 1 + lower_mask = ad_align_weights < -1 + + ad_align_weights = torch.clip( + ad_align_weights, + -clipping, + clipping, + ) + clipping_ratio = ( + torch.sum(upper_mask) + torch.sum(lower_mask) + ) / upper_mask.size + sub_tensors["clipped_ad_align_weights"] = ad_align_weights + + # 1/1+t Regularization + if use_time_regularization: + t_values = torch.arange(1, T + 1).to(ad_align_weights.device) + ad_align_weights = ad_align_weights / t_values + sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights + + # Use coop on t=0 + if force_coop_first_step: + ad_align_weights[:, 0] = 1 + sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights + + #################################### + # Compose elements together + #################################### + + opp_shaping_terms = beta * ad_align_weights * a2 + sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms + + credits = a1 + opp_shaping_terms + if mean_normalize_ad_align: + credits = credits - credits.mean(dim=0) + sub_tensors["mean_normalized_ad_align_credits"] = credits + if whiten_adalign_advantages: + credits = (credits - credits.mean()) / (credits.std() + 1e-9) + sub_tensors["whitened_ad_align_credits"] = credits + if whiten_adalign_advantages_time_step_wise: + credits = (credits - credits.mean(dim=0, keepdim=True)) / ( + credits.std(dim=0, keepdim=True) + 1e-9 + ) + sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits + sub_tensors["final_ad_align_credits"] = credits + + return credits, sub_tensors diff --git a/src_code_for_reproducibility/training/tally_metrics.py b/src_code_for_reproducibility/training/tally_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a0073ab3be86bbaf448b6a898341a4f5a3087b24 --- /dev/null +++ b/src_code_for_reproducibility/training/tally_metrics.py @@ -0,0 +1,64 @@ +""" +File: mllm/training/tally_metrics.py +Summary: Transforms tally files into aggregated metric summaries. +""" + +import os +from numbers import Number +from typing import Union + +import wandb + + +class Tally: + """ + Minimal scalar-first tally. + - Keys are strings. + - First add stores a scalar; subsequent adds upgrade to a list of scalars. + """ + + def __init__(self): + self.stats = {} + + def reset(self): + """Reset all recorded metrics back to an empty dictionary.""" + self.stats = {} + + def _coerce_scalar(self, value: Union[int, float]) -> Union[int, float]: + """Ensure ``value`` is a plain Python scalar (detach tensors, etc.).""" + if hasattr(value, "item") and callable(getattr(value, "item")): + try: + value = value.item() + except Exception: + pass + if isinstance(value, Number): + return value + raise AssertionError("Metric must be a scalar number") + + def add_metric(self, path: str, metric: Union[int, float]): + """Accumulate a metric under ``path`` (scalar on first add, list thereafter).""" + metric = float(metric) + assert isinstance(path, str), "Path must be a string." + assert isinstance(metric, float), "Metric must be a scalar number." + + scalar = self._coerce_scalar(metric) + existing = self.stats.get(path) + if existing is None: + self.stats[path] = scalar + elif isinstance(existing, list): + existing.append(scalar) + else: + self.stats[path] = [existing, scalar] + + def save(self, identifier: str, folder: str): + """Persist the tally as a pickle file under ``folder``.""" + os.makedirs(name=folder, exist_ok=True) + try: + import pickle + + pkl_path = os.path.join(folder, f"{identifier}.tally.pkl") + payload = self.stats + with open(pkl_path, "wb") as f: + pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception: + pass diff --git a/src_code_for_reproducibility/training/tally_rollout.py b/src_code_for_reproducibility/training/tally_rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..04bb4f36d7f1c6759c3fb0f0102f29b117ea57c1 --- /dev/null +++ b/src_code_for_reproducibility/training/tally_rollout.py @@ -0,0 +1,116 @@ +""" +File: mllm/training/tally_rollout.py +Summary: Serializes rollout data into tallies for downstream processing. +""" + +import json +import os +from copy import deepcopy +from typing import Union + +import numpy as np +import pandas as pd +import torch +from transformers import AutoTokenizer + + +class RolloutTallyItem: + def __init__( + self, + crn_ids: list[str], + rollout_ids: list[str], + agent_ids: list[str], + metric_matrix: torch.Tensor, + ): + """Lightweight data container that keeps rollout-aligned metric matrices.""" + if isinstance(crn_ids, torch.Tensor): + crn_ids = crn_ids.detach().cpu().numpy() + if isinstance(rollout_ids, torch.Tensor): + rollout_ids = rollout_ids.detach().cpu().numpy() + if isinstance(agent_ids, torch.Tensor): + agent_ids = agent_ids.detach().cpu().numpy() + self.crn_ids = crn_ids + self.rollout_ids = rollout_ids + self.agent_ids = agent_ids + metric_matrix = metric_matrix.detach().cpu() + assert ( + 0 < metric_matrix.ndim <= 2 + ), "Metric matrix must have less than or equal to 2 dimensions" + if metric_matrix.ndim == 1: + metric_matrix = metric_matrix.reshape(1, -1) + # Convert to float32 if tensor is in BFloat16 format (not supported by numpy) + if metric_matrix.dtype == torch.bfloat16: + metric_matrix = metric_matrix.float() + self.metric_matrix = metric_matrix.numpy() + + +class RolloutTally: + """ + Tally is a utility class for collecting and storing training metrics. + It supports adding metrics at specified paths and saving them to disk. + """ + + def __init__(self): + """ + Initializes the RolloutTally object. + + Args: + tokenizer (AutoTokenizer): Tokenizer for converting token IDs to strings. + max_context_length (int, optional): Maximum context length for contextualized metrics. Defaults to 30. + """ + # Array-preserving structure (leaf lists hold numpy arrays / scalars) + self.metrics = {} + # Global ordered list of sample identifiers (crn_id, rollout_id) added in the order samples are processed + + def reset(self): + """Reset the tally to an empty dict.""" + self.metrics = {} + + def get_from_nested_dict(self, dictio: dict, path: str): + """Retrieve a nested entry, creating intermediate dicts as needed.""" + assert isinstance(path, list), "Path must be list." + for sp in path[:-1]: + dictio = dictio.setdefault(sp, {}) + return dictio.get(path[-1], None) + + def set_at_path(self, dictio: dict, path: str, value): + """Store ``value`` at ``path``; helper used by ``add_metric``.""" + for sp in path[:-1]: + dictio = dictio.setdefault(sp, {}) + dictio[path[-1]] = value + + def add_metric(self, path: list[str], rollout_tally_item: RolloutTallyItem): + """ + Adds a metric to the base tally at the specified path. + + Args: + path (list): List of keys representing the path in the base tally. + rollout_tally_item (RolloutTallyItem): The rollout tally item to add. + """ + rollout_tally_item = deepcopy(rollout_tally_item) + + # Update array-preserving tally + array_list = self.get_from_nested_dict(dictio=self.metrics, path=path) + if array_list is None: + self.set_at_path(dictio=self.metrics, path=path, value=[rollout_tally_item]) + else: + array_list.append(rollout_tally_item) + + def save(self, identifier: str, folder: str): + """Persist the tally as a pickle (metrics only) under ``folder``.""" + os.makedirs(name=folder, exist_ok=True) + + from datetime import datetime + + now = datetime.now() + + # Pickle only (fastest, exact structure with numpy/scalars at leaves) + try: + import pickle + + pkl_path = os.path.join(folder, f"{identifier}.rt_tally.pkl") + payload = {"metrics": self.metrics} + with open(pkl_path, "wb") as f: + pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception: + pass diff --git a/src_code_for_reproducibility/training/tally_tokenwise.py b/src_code_for_reproducibility/training/tally_tokenwise.py new file mode 100644 index 0000000000000000000000000000000000000000..b7770e0cb79d5ed4e56a3f66b6982582c72e0bb7 --- /dev/null +++ b/src_code_for_reproducibility/training/tally_tokenwise.py @@ -0,0 +1,278 @@ +""" +File: mllm/training/tally_tokenwise.py +Summary: Converts token-level tallies into per-token statistics. +""" + +import json +import os +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from transformers import AutoTokenizer + + +class ContextualizedTokenwiseTally: + """ + Collect, store, and save token-level metrics per rollout. + + - One DataFrame per rollout_id in `paths` + - Index = timestep (int) + - Columns are added incrementally via `add_contexts()` and `add_data()` + - Cells may contain scalars, strings, or lists (dtype=object) + """ + + def __init__( + self, + tokenizer: AutoTokenizer, + paths: List[str], + max_context_length: int = 30, + ): + """ + Args: + tokenizer: HuggingFace tokenizer used to convert tids -> tokens + paths: rollout identifiers (parallel to batch dimension) + max_context_length: truncate context token lists to this length + """ + self.tokenizer = tokenizer + self.paths = paths + self.max_context_length = max_context_length + self.tally: Dict[str, pd.DataFrame] = {path: pd.DataFrame() for path in paths} + + # set later by setters + self.contexts: torch.Tensor | None = None + self.action_mask: torch.Tensor | None = None + self.range: Tuple[int, int] | None = None + + # --------- Utilities --------- + + def tids_to_str(self, tids: List[int]) -> List[str]: + """Convert a list of token IDs to a list of token strings.""" + return self.tokenizer.convert_ids_to_tokens(tids) + + def _ensure_ready(self): + """Validate that action mask and range are configured prior to writes.""" + assert self.action_mask is not None, "call set_action_mask(mask) first" + assert self.range is not None, "call set_range((start, end)) first" + + @staticmethod + def _sanitize_filename(name: Any) -> str: + """Make a safe filename from any rollout_id.""" + s = str(name) + bad = {os.sep, " ", ":", "|", "<", ">", '"', "'"} + if os.altsep is not None: + bad.add(os.altsep) + for ch in bad: + s = s.replace(ch, "_") + return s + + @staticmethod + def _pad_left(seq: List[Any], length: int, pad_val: Any = "") -> List[Any]: + """Left-pad a sequence to `length` with `pad_val`.""" + if len(seq) >= length: + return seq[-length:] + return [pad_val] * (length - len(seq)) + list(seq) + + # --------- Setters --------- + + def set_action_mask(self, action_mask: torch.Tensor): + """Register the (B, S) mask indicating which tokens correspond to actions.""" + self.action_mask = action_mask + + def set_range(self, range: Tuple[int, int]): + """Record which subset of ``paths`` the current mini-batch corresponds to.""" + self.range = range + + # --------- Column builders --------- + + def add_contexts(self, contexts: torch.Tensor): + """ + Add a single 'context' column (list[str]) for valid steps. + + Expects `contexts` with shape (B, S): token id at each timestep. + For each valid timestep t, we use the last N tokens up to and including t: + window = contexts[i, max(0, t - N + 1) : t + 1] + The list is left-padded with "" to always be length N. + """ + self._ensure_ready() + + current_paths = self.paths[self.range[0] : self.range[1]] + B, S = contexts.shape + N = self.max_context_length + + # to CPU ints once + contexts_cpu = contexts.detach().to("cpu") + + for i in range(B): + rollout_id = current_paths[i] + df = self.tally.get(rollout_id, pd.DataFrame()) + + valid_idx = torch.nonzero( + self.action_mask[i].bool(), as_tuple=False + ).squeeze(-1) + if valid_idx.numel() == 0: + self.tally[rollout_id] = df + continue + + idx_list = valid_idx.tolist() + + # ensure index contains valid steps + if df.empty: + df = pd.DataFrame(index=idx_list) + else: + new_index = sorted(set(df.index.tolist()) | set(idx_list)) + if list(df.index) != new_index: + df = df.reindex(new_index) + + # build context windows + ctx_token_lists = [] + for t in idx_list: + start = max(0, t - N + 1) + window_ids = contexts_cpu[i, start : t + 1].tolist() + window_toks = self.tids_to_str([int(x) for x in window_ids]) + if len(window_toks) < N: + window_toks = [""] * (N - len(window_toks)) + window_toks + else: + window_toks = window_toks[-N:] + ctx_token_lists.append(window_toks) + + # single 'context' column + if "context" not in df.columns: + df["context"] = pd.Series(index=df.index, dtype=object) + df.loc[idx_list, "context"] = pd.Series( + ctx_token_lists, index=idx_list, dtype=object + ) + + self.tally[rollout_id] = df + + def add_data( + self, + metric_id: str, + metrics: torch.Tensor, + to_tids: bool = False, + ): + """ + Add a metric column for valid steps. + + Args: + metric_id: column name + metrics: shape (B, S) for scalars/ids or (B, S, K) for top-k vectors + to_tids: if True, treat ints/lists of ints as tids and convert to tokens + """ + self._ensure_ready() + current_paths = self.paths[self.range[0] : self.range[1]] + + if metrics.dim() == 2: + B, S = metrics.shape + elif metrics.dim() == 3: + B, S, _ = metrics.shape + else: + raise ValueError("metrics must be (B, S) or (B, S, K)") + + for i in range(B): + rollout_id = current_paths[i] + df = self.tally.get(rollout_id, pd.DataFrame()) + + valid_idx = torch.nonzero( + self.action_mask[i].bool(), as_tuple=False + ).squeeze(-1) + if valid_idx.numel() == 0: + self.tally[rollout_id] = df + continue + + idx_list = valid_idx.detach().cpu().tolist() + + # Ensure index contains valid steps + if df.empty: + df = pd.DataFrame(index=idx_list) + else: + new_index = sorted(set(df.index.tolist()) | set(idx_list)) + if list(df.index) != new_index: + df = df.reindex(new_index) + + # Slice metrics at valid steps + m_valid = metrics[i][valid_idx] + + # -> pure python lists (1D list or list-of-lists) + values = m_valid.detach().cpu().tolist() + + # optional tids -> tokens + if to_tids: + + def _to_tokish(x): + if isinstance(x, list): + return self.tids_to_str([int(v) for v in x]) + else: + return self.tids_to_str([int(x)])[0] + + values = [_to_tokish(v) for v in values] + + # Ensure column exists with object dtype, then assign via aligned Series + if metric_id not in df.columns: + df[metric_id] = pd.Series(index=df.index, dtype=object) + + if isinstance(values, np.ndarray): + values = values.tolist() + + if len(values) != len(idx_list): + raise ValueError( + f"Length mismatch for '{metric_id}': values={len(values)} vs idx_list={len(idx_list)}" + ) + + df.loc[idx_list, metric_id] = pd.Series( + values, index=idx_list, dtype=object + ) + self.tally[rollout_id] = df + + # --------- Saving --------- + + def save(self, path: str): + """ + Write a manifest JSON and one CSV per rollout. + + - Manifest includes metadata only (safe to JSON). + - Each rollout CSV is written with index label 'timestep'. + - Only a single 'context' column (list[str]). + """ + if not self.tally or all(df.empty for df in self.tally.values()): + return + + os.makedirs(path, exist_ok=True) + from datetime import datetime + + now = datetime.now() + + manifest = { + "created_at": f"{now:%Y-%m-%d %H:%M:%S}", + "max_context_length": self.max_context_length, + "num_rollouts": len(self.tally), + "rollouts": [], + } + + for rid, df in self.tally.items(): + rid_str = str(rid) + safe_name = self._sanitize_filename(rid_str) + csv_path = os.path.join(path, f"{safe_name}_tokenwise.csv") + + # Put 'context' first, then the rest + cols = ["context"] + [c for c in df.columns if c != "context"] + try: + df[cols].to_csv(csv_path, index=True, index_label="timestep") + except Exception as e: + continue + + manifest["rollouts"].append( + { + "rollout_id": rid_str, + "csv": csv_path, + "num_rows": int(df.shape[0]), + "columns": cols, + } + ) + + manifest_path = os.path.join( + path, f"tokenwise_manifest_{now:%Y-%m-%d___%H-%M-%S}.json" + ) + with open(manifest_path, "w") as fp: + json.dump(manifest, fp, indent=2) diff --git a/src_code_for_reproducibility/training/tokenize_chats.py b/src_code_for_reproducibility/training/tokenize_chats.py new file mode 100644 index 0000000000000000000000000000000000000000..94da0030ec2afe19d5e5cd8a9a9e39b595d19975 --- /dev/null +++ b/src_code_for_reproducibility/training/tokenize_chats.py @@ -0,0 +1,128 @@ +""" +File: mllm/training/tokenize_chats.py +Summary: Tokenizes chat datasets and prepares tensors for training. +""" + +import logging +import sys + +import regex +import torch +from transformers import AutoTokenizer + +from mllm.training.training_data_utils import TrainingChatTurn, TrajectoryBatch + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + + +def process_training_chat( + tokenizer: AutoTokenizer, + chat_history: list[TrainingChatTurn], + entropy_mask_regex: str | None = None, + exploration_prompts_to_remove: list[str] = [], + use_engine_out_token_ids: bool = False, +) -> tuple[torch.IntTensor, torch.BoolTensor, torch.IntTensor, torch.BoolTensor]: + """Tokenize a single training chat and build aligned per-token masks. + + Given an ordered list of `TrainingChatTurn`, this function tokenizes each + turn independently using the tokenizer's chat template, then concatenates + all resulting token sequences. It also constructs three parallel 1D masks + that align with the concatenated tokens: + + - input_ids: token ids for the entire chat, turn by turn + - action_mask: True for tokens that belong to assistant turns (i.e., model + actions), False for tokens from other roles + - timesteps: per-token time step copied from the originating turn's + `time_step` + - state_ends_mask: True for the last token of any turn where + `is_state_end` is True, otherwise False + + Important details: + - Each turn is passed as a single-message list to + `tokenizer.apply_chat_template` and flattened; the per-turn outputs are + then concatenated in the original order. + - Turn boundaries are not explicitly encoded beyond what the chat template + inserts; masks provide alignment for learning signals and state endings. + - No truncation or padding is performed here; downstream code should handle + batching/padding as needed. + - Note on dtypes: `input_ids` will be a LongTensor (int64). `action_mask` + and `state_ends_mask` are BoolTensors. `timesteps` is currently created + as a float tensor; adjust the implementation if integer dtype is + required downstream. + + Args: + tokenizer: A Hugging Face tokenizer supporting `apply_chat_template`. + chat_history: Ordered list of `TrainingChatTurn` forming one dialogue. + + Returns: + A tuple of four 1D tensors, all of equal length N (the total number of + tokens across all turns), in the following order: + - input_ids (LongTensor) + - action_mask (BoolTensor) + - timesteps (FloatTensor as implemented; see note above) + - state_ends_mask (BoolTensor) + """ + state_ends_mask = [] + input_ids = [] + action_mask = [] + timesteps = [] + entropy_mask = [] + engine_log_probs = [] + for train_chat_turn in chat_history: + is_state_end = train_chat_turn.is_state_end + time_step = train_chat_turn.time_step + is_action = train_chat_turn.role == "assistant" + + # Remove exploration prompts from training data + for exploration_prompt in exploration_prompts_to_remove: + if exploration_prompt in train_chat_turn.content: + train_chat_turn.content = train_chat_turn.content.replace( + exploration_prompt, "" + ) + + chat_turn = { + "role": train_chat_turn.role, + "content": train_chat_turn.content, + } + if entropy_mask_regex is not None: + is_entropy_mask_true = ( + regex.search(entropy_mask_regex, train_chat_turn.content) is not None + ) + else: + is_entropy_mask_true = True + if is_action: + chat_turn_ids = train_chat_turn.out_token_ids + nb_chat_turns_ids = chat_turn_ids.numel() + action_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool)) + engine_log_probs.append(train_chat_turn.log_probs) + else: + chat_turn_ids = train_chat_turn.chat_template_token_ids + nb_chat_turns_ids = chat_turn_ids.numel() + action_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool)) + engine_log_probs.append(torch.zeros(nb_chat_turns_ids, dtype=torch.float)) + nb_chat_turns_ids = chat_turn_ids.numel() + state_ends_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool)) + if is_state_end: + state_ends_mask[-1][-1] = True # last token is state end + input_ids.append(chat_turn_ids) + entropy_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool)) + if not is_entropy_mask_true: + entropy_mask[-1] = entropy_mask[-1] * False + timesteps.append(torch.ones(nb_chat_turns_ids) * time_step) + input_ids = torch.cat(input_ids) + action_mask = torch.cat(action_mask) + entropy_mask = torch.cat(entropy_mask) + timesteps = torch.cat(timesteps) + timesteps = timesteps.to(torch.long) + state_ends_mask = torch.cat(state_ends_mask) + engine_log_probs = torch.cat(engine_log_probs) + + return ( + input_ids, + action_mask, + entropy_mask, + timesteps, + state_ends_mask, + engine_log_probs, + ) diff --git a/src_code_for_reproducibility/training/trainer_ad_align.py b/src_code_for_reproducibility/training/trainer_ad_align.py new file mode 100644 index 0000000000000000000000000000000000000000..14e18e51480e594355b3416555011223ff0e8f36 --- /dev/null +++ b/src_code_for_reproducibility/training/trainer_ad_align.py @@ -0,0 +1,505 @@ +""" +File: mllm/training/trainer_ad_align.py +Summary: Trainer specialized for the advantage-alignment objective. +""" + +import copy +import logging +import sys +from dataclasses import dataclass +from typing import Tuple + +import torch +from torch.nn.utils.rnn import pad_sequence + +from mllm.markov_games.rollout_tree import ( + ChatTurn, + RolloutTreeBranchNode, + RolloutTreeRootNode, +) +from mllm.training.credit_methods import ( + get_advantage_alignment_credits, + get_discounted_state_visitation_credits, +) +from mllm.training.tally_metrics import Tally +from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem +from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally +from mllm.training.tokenize_chats import process_training_chat +from mllm.training.trainer_common import BaseTrainer +from mllm.training.training_data_utils import ( + AdvantagePacket, + TrainingBatch, + TrainingChatTurn, + TrajectoryBatch, + get_main_chat_list_and_rewards, + get_tokenwise_credits, +) +from mllm.utils.resource_context import resource_logger_context + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +RolloutId = int +AgentId = str + + +@dataclass +class AdAlignTrainingData: + """Holds tensorized rollouts plus precomputed advantages for one agent.""" + + agent_id: str + main_data: TrajectoryBatch + # list-of-tensors: per rollout advantages with length jT + main_advantages: list[torch.FloatTensor] | None = None + # list-of-tensors: per rollout matrix (jT, A) + alternative_advantages: list[torch.FloatTensor] | None = None + advantage_alignment_credits: list[torch.FloatTensor] | None = None + + +def get_alternative_chat_histories( + agent_id: str, root: RolloutTreeRootNode +) -> list[list[TrainingChatTurn], list[torch.FloatTensor]]: + """ + Traverse every unilateral branch under ``root`` and collect chat/reward histories. + + Returns + ------- + alternative_chats: + Flattened list of chat turns for each branch (ordered by branch depth). + alternative_rewards: + Matching list of reward tensors aligned with the chat history. + """ + current_node = root.child + branches = current_node.branches + pre_branch_chat = [] + pre_branch_rewards = [] + alternative_rewards = [] + alternative_chats = [] + while current_node is not None: + assert isinstance( + current_node, RolloutTreeBranchNode + ), "Current node should be a branch node." + main_node = current_node.main_child + branches = current_node.branches + current_node = main_node.child + + # Get the `A` alternative trajectories + alternative_nodes = branches[agent_id] + for alt_node in alternative_nodes: + post_branch_chat, post_branch_rewards = get_main_chat_list_and_rewards( + agent_id=agent_id, root=alt_node + ) + branch_chat = pre_branch_chat + post_branch_chat + alternative_chats.append(branch_chat) + alternative_rewards.append( + torch.cat([torch.tensor(pre_branch_rewards), post_branch_rewards]) + ) + + chat_turns: list[ChatTurn] = main_node.step_log.action_logs[agent_id].chat_turns + chat_turns: list[TrainingChatTurn] = [ + TrainingChatTurn(time_step=main_node.time_step, **turn.model_dump()) + for turn in chat_turns + ] + + pre_branch_chat.extend(chat_turns) + pre_branch_rewards.append( + main_node.step_log.simulation_step_log.rewards[agent_id] + ) + + return alternative_chats, alternative_rewards + + +class TrainerAdAlign(BaseTrainer): + """ + Extends the reinforce trainer to support Advantage Alignment. + """ + + def __init__( + self, + ad_align_beta: float, + ad_align_gamma: float, + ad_align_exclude_k_equals_t: bool, + ad_align_use_sign: bool, + ad_align_clipping: float, + ad_align_force_coop_first_step: bool, + use_old_ad_align: bool, + use_time_regularization: bool, + rloo_branch: bool, + reuse_baseline: bool, + ad_align_beta_anneal_step: int = -1, + ad_align_beta_anneal_rate: float = 0.5, + min_ad_align_beta: float = 0.1, + mean_normalize_ad_align: bool = False, + whiten_adalign_advantages: bool = False, + whiten_adalign_advantages_time_step_wise: bool = False, + ad_align_discount_t: bool = False, + *args, + **kwargs, + ): + """ + Initialize the advantage alignment trainer. + Args: + ad_align_beta: Beta parameter for the advantage alignment. + ad_align_gamma: Gamma parameter for the advantage alignment. + ad_align_exclude_k_equals_t: Whether to include k = t in the advantage alignment. + ad_align_use_sign: Whether to use sign in the advantage alignment. + ad_align_clipping: Clipping value for the advantage alignment. + ad_align_force_coop_first_step: Whether to force coop on the first step of the advantage alignment. + """ + super().__init__(*args, **kwargs) + self.ad_align_beta = ad_align_beta + self.ad_align_gamma = ad_align_gamma + self.ad_align_exclude_k_equals_t = ad_align_exclude_k_equals_t + self.ad_align_use_sign = ad_align_use_sign + self.ad_align_clipping = ad_align_clipping + self.ad_align_force_coop_first_step = ad_align_force_coop_first_step + self.use_old_ad_align = use_old_ad_align + self.use_time_regularization = use_time_regularization + self.rloo_branch = rloo_branch + self.reuse_baseline = reuse_baseline + self.ad_align_beta_anneal_step = ad_align_beta_anneal_step + self.ad_align_beta_anneal_rate = ad_align_beta_anneal_rate + self.min_ad_align_beta = min_ad_align_beta + self.past_ad_align_step = -1 + self.mean_normalize_ad_align = mean_normalize_ad_align + self.whiten_adalign_advantages = whiten_adalign_advantages + self.whiten_adalign_advantages_time_step_wise = ( + whiten_adalign_advantages_time_step_wise + ) + self.ad_align_discount_t = ad_align_discount_t + self.training_data: dict[AgentId, AdAlignTrainingData] = {} + self.debug_path_list: list[str] = [] + + def set_agent_trajectory_data( + self, agent_id: str, roots: list[RolloutTreeRootNode] + ): + """ + Materialize main and alternative trajectory tensors used by the advantage-alignment trainer. + """ + + B = len(roots) # Number of rollouts + + # For main rollouts + batch_rollout_ids = [] + batch_crn_ids = [] + batch_input_ids = [] + batch_action_mask = [] + batch_entropy_mask = [] + batch_timesteps = [] + batch_state_ends_mask = [] + batch_engine_log_probs = [] + batch_rewards = [] + + # For alternative actions rollouts + batch_branching_time_steps = [] + alternative_batch_input_ids = [] + alternative_batch_action_mask = [] + alternative_batch_entropy_mask = [] + alternative_batch_timesteps = [] + alternative_batch_state_ends_mask = [] + alternative_batch_engine_log_probs = [] + alternative_batch_rewards = [] + jT_list = [] + + try: + A = len(roots[0].child.branches[agent_id]) # Number of alternative actions + except: + A = 0 + + for root in roots: + rollout_id = root.id + self.debug_path_list.append( + "mgid:" + str(rollout_id) + "_agent_id:" + agent_id + ) + # Get main trajectory + batch_rollout_ids.append(rollout_id) + batch_crn_ids.append(root.crn_id) + main_chat, main_rewards = get_main_chat_list_and_rewards( + agent_id=agent_id, root=root + ) + ( + input_ids, + action_mask, + entropy_mask, + timesteps, + state_ends_mask, + engine_log_probs, + ) = process_training_chat( + tokenizer=self.tokenizer, + chat_history=main_chat, + entropy_mask_regex=self.entropy_mask_regex, + exploration_prompts_to_remove=self.exploration_prompts_to_remove, + ) + batch_input_ids.append(input_ids) + batch_action_mask.append(action_mask) + batch_entropy_mask.append(entropy_mask) + batch_timesteps.append(timesteps) + batch_state_ends_mask.append(state_ends_mask) + batch_engine_log_probs.append(engine_log_probs) + batch_rewards.append(main_rewards) + jT = ( + main_rewards.numel() + ) # Number of timesteps inferred from reward tensor length. + jT_list.append(jT) + if A > 0: + # We get the branching time steps for each of the `jT` time steps in the main trajectory. + branching_time_steps = [bt for item in range(jT) for bt in A * [item]] + batch_branching_time_steps.extend(branching_time_steps) + + # Get all of the (jT*A) alternative trajectories in the tree + # (jT is the number of time steps in the main trajectory, A is the number of alternative actions) + alternative_chats, alternative_rewards = get_alternative_chat_histories( + agent_id=agent_id, root=root + ) + assert ( + len(alternative_chats) == A * jT + ), "Incorrect number of alternative trajectories." + + for chat, rewards in zip(alternative_chats, alternative_rewards): + ( + input_ids, + action_mask, + entropy_mask, + timesteps, + state_ends_mask, + engine_log_probs, + ) = process_training_chat( + tokenizer=self.tokenizer, + chat_history=chat, + entropy_mask_regex=self.entropy_mask_regex, + exploration_prompts_to_remove=self.exploration_prompts_to_remove, + ) + alternative_batch_input_ids.append(input_ids) + alternative_batch_action_mask.append(action_mask) + alternative_batch_entropy_mask.append(entropy_mask) + alternative_batch_timesteps.append(timesteps) + alternative_batch_state_ends_mask.append(state_ends_mask) + alternative_batch_engine_log_probs.append(engine_log_probs) + alternative_batch_rewards.append(rewards) + + jT_list = torch.Tensor(jT_list) + + # Assert that number of alternative actions is constant + # assert len(set(nb_alternative_actions)) == 1, "Number of alternative actions must be constant" + # A = nb_alternative_actions[0] + + trajectory_batch = TrajectoryBatch( + rollout_ids=torch.tensor(batch_rollout_ids, dtype=torch.int32), # (B,) + crn_ids=torch.tensor(batch_crn_ids, dtype=torch.int32), + agent_ids=[agent_id] * len(batch_rollout_ids), + batch_input_ids=batch_input_ids, + batch_action_mask=batch_action_mask, + batch_entropy_mask=batch_entropy_mask, + batch_timesteps=batch_timesteps, + batch_state_ends_mask=batch_state_ends_mask, + batch_engine_log_probs=batch_engine_log_probs, + batch_rewards=batch_rewards, + ) + # Get Advantages & Train Critic + with resource_logger_context( + logger, "Get advantages with critic gradient accumulation" + ): + self.batch_advantages: torch.FloatTensor = ( + self.get_advantages_with_critic_gradient_accumulation(trajectory_batch) + ) # (B, jT) + + if A > 0: + # Here, `A` is the number of alternative actions / trajectories taken at each time step. + # For each of the `B` rollout perspectives, at each of its jT (`j` is for jagged, since each main rollout may be of a different length) steps, we take A alternate trajectories (from different actions). + # Therefore, we have ∑jT * A trajectories to process. If each of the main trajectories have T steps, we will have `B*T*A` to process. + with resource_logger_context(logger, "Create alternative trajectory batch"): + sum_jT = int(torch.sum(jT_list).item()) + jT_list = ( + jT_list.int().tolist() + ) # (jT,) # (we only want the advantages where we branched out) + alternative_trajectory_batch = TrajectoryBatch( + rollout_ids=torch.zeros(A * sum_jT, dtype=torch.int32), + crn_ids=torch.zeros(A * sum_jT, dtype=torch.int32), + agent_ids=[agent_id] * (A * sum_jT), + batch_input_ids=alternative_batch_input_ids, + batch_action_mask=alternative_batch_action_mask, + batch_entropy_mask=alternative_batch_entropy_mask, + batch_timesteps=alternative_batch_timesteps, + batch_state_ends_mask=alternative_batch_state_ends_mask, + batch_engine_log_probs=alternative_batch_engine_log_probs, + batch_rewards=alternative_batch_rewards, + ) + + # Get alternative advantages + # BAAs stands for batch alternative advantages + # (torch nested tensors have very little api support, so we have to do some odd manual work here) + with resource_logger_context( + logger, "Compute alternative advantage estimates" + ): + BAAs_list = self.get_advantages_with_critic_gradient_accumulation( + alternative_trajectory_batch + ) # list length (∑jT * A), each (jT',) + # Pad alternative advantages to (∑jT*A, P) + + BAAs_padded = pad_sequence( + BAAs_list, batch_first=True, padding_value=0.0 + ) + branch_idx = torch.tensor( + batch_branching_time_steps, + device=BAAs_padded.device, + dtype=torch.long, + ) + gathered = BAAs_padded.gather( + dim=1, index=branch_idx.unsqueeze(1) + ).squeeze(1) + # Reshape and split per rollout, then transpose to (jT_i, A) + gathered = gathered.view(A, sum_jT) # (A, ∑jT) + blocks = list( + torch.split(gathered, jT_list, dim=1) + ) # len B, shapes (A, jT_i) + BAAs = [ + blk.transpose(0, 1).contiguous() for blk in blocks + ] # list of (jT_i, A) + if self.ad_align_beta_anneal_step > 0: + max_rollout_id = torch.max(trajectory_batch.rollout_ids) + 1 + if ( + max_rollout_id % self.ad_align_beta_anneal_step == 0 + and self.past_ad_align_step != max_rollout_id + ): + self.ad_align_beta = max( + self.ad_align_beta * self.ad_align_beta_anneal_rate, + self.min_ad_align_beta, + ) + logger.info(f"Annealing ad_align_beta to {self.ad_align_beta}") + self.past_ad_align_step = max_rollout_id + self.training_data[agent_id] = AdAlignTrainingData( + agent_id=agent_id, + main_data=trajectory_batch, + main_advantages=self.batch_advantages, + alternative_advantages=BAAs if A > 0 else None, + ) + + def share_advantage_data(self) -> list[AdvantagePacket]: + """ + Share the advantage alignment data with other agents. + Returns: + AdvantagePacket: The advantage packet containing the agent's advantages. + """ + logger.info(f"Sharing advantage alignment data.") + advantage_packets = [] + for _, agent_data in self.training_data.items(): + advantage_packets.append( + AdvantagePacket( + agent_id=agent_data.agent_id, + rollout_ids=agent_data.main_data.rollout_ids, + main_advantages=agent_data.main_advantages, + ) + ) + return advantage_packets + + def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]): + """ + Receive advantage packets from other players. + These contain the advantages of the other players' rollouts estimated by them. + """ + logger.info(f"Receiving advantage packets.") + + assert ( + len(advantage_packets) > 0 + ), "At least one advantage packet must be provided." + + for agent_id, agent_data in self.training_data.items(): + coagent_advantage_packets = [ + packet for packet in advantage_packets if packet.agent_id != agent_id + ] + agent_rollout_ids = agent_data.main_data.rollout_ids + agent_advantages = agent_data.main_advantages + co_agent_advantages = [] + for rollout_id in agent_rollout_ids: + for co_agent_packet in coagent_advantage_packets: + if rollout_id in co_agent_packet.rollout_ids: + index = torch.where(rollout_id == co_agent_packet.rollout_ids)[ + 0 + ].item() + co_agent_advantages.append( + co_agent_packet.main_advantages[index] + ) + # assumes that its two player game, with one co-agent + break + assert len(co_agent_advantages) == len(agent_advantages) + B = len(agent_advantages) + assert all( + a.shape[0] == b.shape[0] + for a, b in zip(co_agent_advantages, agent_advantages) + ), "Number of advantages must match for advantage alignment." + + # Get padded tensors (advantage alignment is invariant to padding) + lengths = torch.tensor( + [len(t) for t in agent_advantages], + device=self.device, + dtype=torch.long, + ) + padded_main_advantages = pad_sequence( + agent_advantages, batch_first=True, padding_value=0.0 + ) + if agent_data.alternative_advantages: + padded_alternative_advantages = pad_sequence( + agent_data.alternative_advantages, + batch_first=True, + padding_value=0.0, + ) # (B, P, A) + else: + padded_alternative_advantages = None + padded_co_agent_advantages = pad_sequence( + co_agent_advantages, batch_first=True, padding_value=0.0 + ) + + # Create training batch data + credits, sub_tensors = get_advantage_alignment_credits( + a1=padded_main_advantages, + a1_alternative=padded_alternative_advantages, + a2=padded_co_agent_advantages, + beta=self.ad_align_beta, + gamma=self.ad_align_gamma, + exclude_k_equals_t=self.ad_align_exclude_k_equals_t, + use_sign=self.ad_align_use_sign, + clipping=self.ad_align_clipping, + force_coop_first_step=self.ad_align_force_coop_first_step, + use_old_ad_align=self.use_old_ad_align, + use_time_regularization=self.use_time_regularization, + rloo_branch=self.rloo_branch, + reuse_baseline=self.reuse_baseline, + mean_normalize_ad_align=self.mean_normalize_ad_align, + whiten_adalign_advantages=self.whiten_adalign_advantages, + whiten_adalign_advantages_time_step_wise=self.whiten_adalign_advantages_time_step_wise, + discount_t=self.ad_align_discount_t, + ) + for key, value in sub_tensors.items(): + self.rollout_tally.add_metric( + path=[key], + rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=value, + ), + ) + + if not self.skip_discounted_state_visitation: + credits = get_discounted_state_visitation_credits( + credits, + self.discount_factor, + ) + self.rollout_tally.add_metric( + path=["discounted_state_visitation_credits"], + rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=sub_tensors[ + "discounted_state_visitation_credits" + ], + ), + ) + + # Slice back to jagged + advantage_alignment_credits = [credits[i, : lengths[i]] for i in range(B)] + # Replace stored training data for this agent by the concrete trajectory batch + # and attach the computed credits for policy gradient. + self.training_data[agent_id] = agent_data.main_data + self.training_data[agent_id].batch_credits = advantage_alignment_credits diff --git a/src_code_for_reproducibility/training/trainer_common.py b/src_code_for_reproducibility/training/trainer_common.py new file mode 100644 index 0000000000000000000000000000000000000000..59d3c97b975aad5cdb2727ff2c4e487de70fa43d --- /dev/null +++ b/src_code_for_reproducibility/training/trainer_common.py @@ -0,0 +1,1040 @@ +""" +File: mllm/training/trainer_common.py +Summary: Shared trainer utilities, base classes, and gradient helpers. +""" + +import logging +import os +import pickle +import sys +from abc import ABC, abstractmethod +from typing import Callable, Literal, Union + +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from pandas._libs.tslibs.offsets import CBMonthBegin +from peft import LoraConfig +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.markov_games.rollout_tree import * +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.training.annealing_methods import sigmoid_annealing +from mllm.training.credit_methods import ( + get_discounted_returns, + get_generalized_advantage_estimates, + get_rloo_credits, + whiten_advantages, + whiten_advantages_time_step_wise, +) +from mllm.training.tally_metrics import Tally +from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem +from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally +from mllm.training.tokenize_chats import * +from mllm.training.tokenize_chats import process_training_chat +from mllm.training.training_data_utils import * +from mllm.training.training_data_utils import ( + TrainingBatch, + TrajectoryBatch, + get_tokenwise_credits, +) +from mllm.utils.resource_context import resource_logger_context + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + + +@dataclass +class TrainerAnnealingState: + annealing_step_counter: int = 0 + + +class BaseTrainer(ABC): + """ + Shared scaffolding for policy-gradient trainers (optimizer wiring, logging, etc.). + + Subclasses implement `set_agent_trajectory_data` / `share_advantage_data` + to plug in algorithm-specific behavior. + """ + + def __init__( + self, + policy: AutoModelForCausalLM, + policy_optimizer: torch.optim.Optimizer, + critic: Union[AutoModelForCausalLM, None], + critic_optimizer: Union[torch.optim.Optimizer, None], + tokenizer: AutoTokenizer, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None], + ###################################################################### + entropy_coeff: float, + entropy_topk: int, + entropy_mask_regex: Union[str, None], + kl_coeff: float, + gradient_clipping: Union[float, None], + restrict_tokens: Union[list[str], None], + mini_batch_size: int, + use_gradient_checkpointing: bool, + temperature: float, + device: str, + whiten_advantages: bool, + whiten_advantages_time_step_wise: bool, + use_gae: bool, + use_gae_lambda_annealing: bool, + gae_lambda_annealing_limit: float, + gae_lambda_annealing_method: Literal["sigmoid_annealing"], + gae_lambda_annealing_method_params: dict, + pg_loss_normalization: Literal["batch", "nb_tokens"], + use_rloo: bool, + skip_discounted_state_visitation: bool, + discount_factor: float, + enable_tokenwise_logging: bool, + save_path: str, + reward_agent_id: str | None = None, + reward_scale: float = 1.0, + reward_peer_agent_id: str | None = None, + reward_peer_scale: float = 0.0, + reward_normalizing_constant: float = 1.0, + critic_loss_type: Literal["mse", "huber"] = "huber", + exploration_prompts_to_remove: list[str] = [], + filter_higher_refprob_tokens_kl: bool = False, + truncated_importance_sampling_ratio_cap: float = 0.0, + importance_sampling_strategy: Literal[ + "per_token", "per_sequence" + ] = "per_token", + no_rloo_grouping: bool = False, + ): + """ + Initialize the REINFORCE trainer with reward shaping for multi-agent or single-agent training. + + Args: + model (AutoModelForCausalLM): The main policy model. + tokenizer (AutoTokenizer): Tokenizer for the model. + optimizer (torch.optim.Optimizer): Optimizer for the policy model. + lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Learning rate scheduler for the policy model. + critic (AutoModelForCausalLM or None): Critic model for value estimation (optional). + critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic model (optional). + critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic (optional). + config (RtConfig): Configuration object for training. + """ + self.tokenizer = tokenizer + # self.tokenizer.padding_side = "left" # needed for flash attention + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.lr_scheduler = lr_scheduler + self.accelerator = Accelerator() + ( + self.policy, + self.policy_optimizer, + self.critic, + self.critic_optimizer, + ) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer) + + self.critic_lr_scheduler = critic_lr_scheduler + self.tally = Tally() + + if use_gradient_checkpointing == True: + self.policy.gradient_checkpointing_enable(dict(use_reentrant=False)) + if critic is not None: + self.critic.gradient_checkpointing_enable(dict(use_reentrant=False)) + + self.save_path = save_path + + # Load trainer state if it exists + self.trainer_annealing_state_path = os.path.join( + self.save_path, "trainer_annealing_state.pkl" + ) + if os.path.exists(self.trainer_annealing_state_path): + logger.info( + f"Loading trainer state from {self.trainer_annealing_state_path}" + ) + self.trainer_annealing_state = pickle.load( + open(self.trainer_annealing_state_path, "rb") + ) + else: + self.trainer_annealing_state = TrainerAnnealingState() + + # Load policy optimizer state if it exists + self.policy_optimizer_path = os.path.join( + self.save_path, "policy_optimizer_state.pt" + ) + if os.path.exists(self.policy_optimizer_path): + logger.info( + f"Loading policy optimizer state from {self.policy_optimizer_path}" + ) + self.policy_optimizer.load_state_dict( + torch.load(self.policy_optimizer_path) + ) + + # Load critic optimizer state if it exists + self.critic_optimizer_path = os.path.join( + self.save_path, "critic_optimizer_state.pt" + ) + if ( + os.path.exists(self.critic_optimizer_path) + and self.critic_optimizer is not None + ): + logger.info( + f"Loading critic optimizer state from {self.critic_optimizer_path}" + ) + self.critic_optimizer.load_state_dict( + torch.load(self.critic_optimizer_path) + ) + self.device = self.accelerator.device + self.entropy_coeff = entropy_coeff + self.entropy_topk = entropy_topk + self.entropy_mask_regex = entropy_mask_regex + self.kl_coeff = kl_coeff + self.gradient_clipping = gradient_clipping + self.restrict_tokens = restrict_tokens + self.mini_batch_size = mini_batch_size + self.use_gradient_checkpointing = use_gradient_checkpointing + self.temperature = temperature + self.use_gae = use_gae + self.whiten_advantages = whiten_advantages + self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise + self.use_rloo = use_rloo + self.skip_discounted_state_visitation = skip_discounted_state_visitation + self.use_gae_lambda_annealing = use_gae_lambda_annealing + self.gae_lambda_annealing_limit = gae_lambda_annealing_limit + if use_gae_lambda_annealing: + self.gae_lambda_annealing_method: Callable[ + [int], float + ] = lambda step: eval(gae_lambda_annealing_method)( + step=step, **gae_lambda_annealing_method_params + ) + self.discount_factor = discount_factor + self.enable_tokenwise_logging = enable_tokenwise_logging + self.reward_agent_id = reward_agent_id + self.reward_scale = reward_scale + self.reward_peer_agent_id = reward_peer_agent_id + self.reward_peer_scale = reward_peer_scale + self.reward_normalizing_constant = reward_normalizing_constant + self.pg_loss_normalization = pg_loss_normalization + self.critic_loss_type = critic_loss_type + self.exploration_prompts_to_remove = exploration_prompts_to_remove + # Common containers used by all trainers + self.training_data: dict = {} + self.debug_path_list: list[str] = [] + self.policy_gradient_data = None + self.tally = Tally() + self.rollout_tally = RolloutTally() + self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None + self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl + self.truncated_importance_sampling_ratio_cap = ( + truncated_importance_sampling_ratio_cap + ) + self.importance_sampling_strategy = importance_sampling_strategy + self.no_rloo_grouping = no_rloo_grouping + + def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor: + """ + Masks logits so that only allowed tokens (as specified in config.restrict_tokens) + and the EOS token are active. + All other logits are set to -inf, effectively removing them from the softmax. + + Args: + logits (torch.Tensor): The logits tensor of shape (B, S, V). + + Returns: + torch.Tensor: The masked logits tensor. + """ + # Gradients flow only through the kept logits; masking is recomputed per batch for clarity. + + if self.restrict_tokens is not None: + allowed_token_ids = [] + for token in self.restrict_tokens: + token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"] + allowed_token_ids.append(token_ids[0]) + allowed_token_ids.append( + self.tokenizer.eos_token_id + ) # This token should always be active + allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device) + # Mask log_probs and probs to only allowed tokens + mask = torch.zeros_like(logits).bool() # (B, S, V) + mask[..., allowed_token_ids] = True + logits = torch.where( + mask, + logits, + torch.tensor(-float("inf"), device=logits.device), + ) + + return logits + + def apply_reinforce_step( + self, + training_batch: TrainingBatch, + ) -> None: + """ + Applies a single REINFORCE policy gradient step using the provided batch of rollouts. + Handles batching, loss computation (including entropy and KL regularization), gradient accumulation, and optimizer step. + Optionally logs various metrics and statistics. + + Args: + paths (list[str]): List of game complete file paths for each rollout. + contexts (list[torch.Tensor]): List of context tensors for each rollout. + credits (list[torch.Tensor]): List of credit tensors (rewards/advantages) for each rollout. + action_masks (list[torch.Tensor]): List of action mask tensors for each rollout. + """ + with resource_logger_context(logger, "Apply reinforce step"): + self.policy.train() + mb_size = self.mini_batch_size + nb_rollouts = len(training_batch) + + # Initialize running mean logs + running_mean_logs = { + "rl_objective": 0.0, + "policy_gradient_loss": 0.0, + "policy_gradient_norm": 0.0, + "log_probs": 0.0, + "credits": 0.0, + "entropy": 0.0, + "engine_log_probs_diff_clampfrac": 0.0, + "tis_imp_ratio": 0.0, + "ref_log_probs_diff_clampfrac": 0.0, + "higher_refprob_frac": 0.0, + "tis_imp_ratio_clampfrac": 0.0, + } + if self.entropy_coeff != 0.0: + running_mean_logs["entropy"] = 0.0 + if self.kl_coeff != 0.0: + running_mean_logs["kl_divergence"] = 0.0 + + # Get total number of tokens generated + total_tokens_generated = 0 + for att_mask in training_batch.batch_action_mask: + total_tokens_generated += att_mask.sum() + + # Obtain loss normalization + if self.pg_loss_normalization == "nb_tokens": + normalization_factor = total_tokens_generated + elif self.pg_loss_normalization == "batch": + normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int) + else: + raise ValueError( + f"Invalid pg_loss_normalization: {self.pg_loss_normalization}" + ) + + # Gradient accumulation for each mini-batch + for mb in range(0, nb_rollouts, mb_size): + logger.info(f"Processing mini-batch {mb} of {nb_rollouts}") + loss = 0.0 + training_mb = training_batch[mb : mb + mb_size] + training_mb = training_mb.get_padded_tensors() + training_mb.to(self.device) + ( + tokens_mb, + action_mask_mb, + entropy_mask_mb, + credits_mb, + engine_log_probs_mb, + timesteps_mb, + ) = ( + training_mb.batch_input_ids, + training_mb.batch_action_mask, + training_mb.batch_entropy_mask, + training_mb.batch_credits, + training_mb.batch_engine_log_probs, + training_mb.batch_timesteps, + ) + + # Next token prediction + contexts_mb = tokens_mb[:, :-1] + shifted_contexts_mb = tokens_mb[:, 1:] + action_mask_mb = action_mask_mb[:, 1:] + entropy_mask_mb = entropy_mask_mb[:, 1:] + credits_mb = credits_mb[:, 1:] + engine_log_probs_mb = engine_log_probs_mb[:, 1:] + timesteps_mb = timesteps_mb[:, 1:] + + if self.enable_tokenwise_logging: + self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb) + self.tokenwise_tally.set_range(range=(mb, mb + mb_size)) + self.tokenwise_tally.add_contexts(contexts=contexts_mb) + self.tokenwise_tally.add_data( + metric_id="next_token", + metrics=shifted_contexts_mb, + to_tids=True, + ) + self.tokenwise_tally.add_data( + metric_id="entropy_mask", + metrics=entropy_mask_mb, + ) + + if self.enable_tokenwise_logging: + self.tokenwise_tally.add_data( + metric_id="next_token_credit", metrics=credits_mb + ) + + # Forward pass + cast to FP-32 for higher prec. Causal LM attention masks are implicit; + # wire up a custom mask here only if the policy deviates from standard autoregressive behavior. + logits = self.policy(input_ids=contexts_mb)[0] # (B, S, V) + + # Mask non-restricted tokens + if self.restrict_tokens is not None: + logits = self.mask_non_restricted_token_logits(logits) + + logits /= self.temperature # (B, S, V) + + # Compute new log probabilities + log_probs = F.log_softmax(logits, dim=-1) # (B, S, V) + + # Get log probabilities of actions taken during rollouts + action_log_probs = log_probs.gather( + dim=-1, index=shifted_contexts_mb.unsqueeze(-1) + ).squeeze( + -1 + ) # (B, S) + if self.pg_loss_normalization == "batch": + den_running_mean = action_mask_mb.sum() * normalization_factor + else: + den_running_mean = normalization_factor + running_mean_logs["log_probs"] += ( + action_log_probs * action_mask_mb + ).sum().item() / den_running_mean + running_mean_logs["credits"] += ( + credits_mb * action_mask_mb + ).sum().item() / den_running_mean + + if self.enable_tokenwise_logging: + self.tokenwise_tally.add_data( + metric_id="next_token_log_prob", + metrics=action_log_probs, + ) + self.tokenwise_tally.add_data( + metric_id="engine_next_token_log_prob", + metrics=engine_log_probs_mb, + ) + self.tokenwise_tally.add_data( + metric_id="next_token_prob", + metrics=torch.exp(action_log_probs), + ) + top_k_indices = torch.topk(logits, k=5, dim=-1).indices + self.tokenwise_tally.add_data( + metric_id=f"top_{5}_tids", + metrics=top_k_indices, + to_tids=True, + ) + self.tokenwise_tally.add_data( + metric_id=f"top_{5}_probs", + metrics=torch.exp(log_probs).gather( + dim=-1, index=top_k_indices + ), + ) + + rewarded_action_log_probs = ( + action_mask_mb * credits_mb * action_log_probs + ) + # (B, S) + INVALID_LOGPROB = 1.0 + CLAMP_VALUE = 40.0 + masked_action_log_probs = torch.masked_fill( + action_log_probs, ~action_mask_mb, INVALID_LOGPROB + ) + masked_engine_log_probs = torch.masked_fill( + engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB + ) + with torch.no_grad(): + action_engine_log_probs_diff = ( + masked_action_log_probs - masked_engine_log_probs + ).clamp(-CLAMP_VALUE, CLAMP_VALUE) + running_mean_logs["engine_log_probs_diff_clampfrac"] += ( + action_engine_log_probs_diff.abs() + .eq(CLAMP_VALUE) + .float() + .sum() + .item() + / den_running_mean + ) + if self.importance_sampling_strategy == "per_sequence": + tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff) + for mb_idx in range(action_engine_log_probs_diff.shape[0]): + valid_token_mask = action_mask_mb[mb_idx] + timestep_ids = timesteps_mb[mb_idx][valid_token_mask] + timestep_logprob_diffs = action_engine_log_probs_diff[mb_idx][ + valid_token_mask + ] + max_timestep = int(timestep_ids.max().item()) + 1 + timestep_sums = torch.zeros( + max_timestep, + device=action_engine_log_probs_diff.device, + dtype=action_engine_log_probs_diff.dtype, + ) + timestep_sums.scatter_add_( + 0, timestep_ids, timestep_logprob_diffs + ) + timestep_ratios = torch.exp(timestep_sums) + tis_imp_ratio[ + mb_idx, valid_token_mask + ] = timestep_ratios.gather(0, timestep_ids) + else: + tis_imp_ratio = torch.exp(action_engine_log_probs_diff) + running_mean_logs["tis_imp_ratio"] += ( + tis_imp_ratio * action_mask_mb + ).sum().item() / den_running_mean + if self.truncated_importance_sampling_ratio_cap > 0.0: + tis_imp_ratio = torch.clamp( + tis_imp_ratio, max=self.truncated_importance_sampling_ratio_cap + ) + running_mean_logs["tis_imp_ratio_clampfrac"] += ( + tis_imp_ratio.eq(self.truncated_importance_sampling_ratio_cap) + .float() + .sum() + .item() + ) / den_running_mean + rewarded_action_log_probs = ( + rewarded_action_log_probs * tis_imp_ratio + ) + + if self.enable_tokenwise_logging: + self.tokenwise_tally.add_data( + metric_id="next_token_clogπ", + metrics=rewarded_action_log_probs, + ) + + # Add value term to loss + if self.pg_loss_normalization == "batch": + nb_act_tokens = action_mask_mb.sum() + mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens + else: + mb_value = -rewarded_action_log_probs.sum() + + loss += mb_value + running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean + + # ------------------------------------------------- + # Entropy Regularization + # ------------------------------------------------- + # Only apply entropy on distribution defined over most probable tokens + if self.entropy_topk is not None: + top_k_indices = torch.topk( + logits, k=self.entropy_topk, dim=-1 + ).indices + entropy_logits = logits.gather(dim=-1, index=top_k_indices) + else: + entropy_logits = logits + + token_entropy_terms = -F.softmax( + entropy_logits, dim=-1 + ) * F.log_softmax( + entropy_logits, dim=-1 + ) # (B, S, T) + token_entropy_terms *= ( + action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None] + ) # only get loss on specific action tokens + + mb_entropy = token_entropy_terms.sum(dim=-1) + + if self.enable_tokenwise_logging: + self.tokenwise_tally.add_data( + metric_id="entropy", + metrics=mb_entropy, + ) + if self.pg_loss_normalization == "batch": + nb_act_tokens = action_mask_mb.sum() + mb_entropy = -mb_entropy.sum() / nb_act_tokens + else: + mb_entropy = -mb_entropy.sum() + running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean + if self.entropy_coeff != 0.0: + mb_entropy *= self.entropy_coeff + loss += mb_entropy + + # ------------------------------------------------- + # KL-DIVERGENCE + # ------------------------------------------------- + if self.kl_coeff != 0.0: + ref_model_logits = self.policy.get_base_model_logits(contexts_mb) + ref_model_logits = ref_model_logits / self.temperature + # (B, S, V) + ref_model_logits = self.mask_non_restricted_token_logits( + logits=ref_model_logits + ) + # (B, S, V) + ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1) + # (B, S, V) + ref_model_action_log_probs = ref_model_log_probs.gather( + dim=-1, index=shifted_contexts_mb.unsqueeze(-1) + ).squeeze( + -1 + ) # (B,S) + # Approximating KL Divergence (see refs in docstring) + # Ref 1: http://joschu.net/blog/kl-approx.html + # Ref 2: https://github.dev/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1332 + masked_ref_model_action_log_probs = torch.masked_fill( + ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB + ) + action_log_probs_diff = ( + masked_ref_model_action_log_probs - masked_action_log_probs + ).clamp(-CLAMP_VALUE, CLAMP_VALUE) + running_mean_logs["ref_log_probs_diff_clampfrac"] += ( + action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item() + / den_running_mean + ) + if self.filter_higher_refprob_tokens_kl: + higher_refprob_tokens_mask = action_log_probs_diff > 0.0 + running_mean_logs["higher_refprob_frac"] += ( + higher_refprob_tokens_mask.sum().item() / den_running_mean + ) + action_log_probs_diff = action_log_probs_diff * ( + ~higher_refprob_tokens_mask + ) + kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff + kl_div *= action_mask_mb # We only care about KLD of action tokens + if self.truncated_importance_sampling_ratio_cap > 0.0: + kl_div = kl_div * tis_imp_ratio + kl_div *= self.kl_coeff + if self.enable_tokenwise_logging: + self.tokenwise_tally.add_data( + metric_id="ref_model_next_token_log_prob", + metrics=ref_model_action_log_probs, + ) + self.tokenwise_tally.add_data( + metric_id="kl_divergence", + metrics=kl_div, + ) + + if self.pg_loss_normalization == "batch": + nb_act_tokens = action_mask_mb.sum() + mb_kl = kl_div.sum() / nb_act_tokens + else: + mb_kl = kl_div.sum() + running_mean_logs["kl_divergence"] += ( + mb_kl.item() / den_running_mean + ) + loss += mb_kl + + # Accumulate gradient + running_mean_logs["policy_gradient_loss"] += ( + loss.item() / den_running_mean + ) + loss /= normalization_factor + self.accelerator.backward(loss) + + # ensure gpu memory is freed + del training_mb + del log_probs + del logits + del loss + del action_log_probs + del rewarded_action_log_probs + + logger.info( + f"Accumulated the policy gradient loss for {total_tokens_generated} tokens." + ) + + # Clip gradients and take step + if self.gradient_clipping is not None: + grad_norm = self.accelerator.clip_grad_norm_( + self.policy.parameters(), self.gradient_clipping + ) + running_mean_logs["policy_gradient_norm"] += grad_norm.item() + + # Take step + self.policy_optimizer.step() + self.policy_optimizer.zero_grad() + + # Store logs + for key, value in running_mean_logs.items(): + self.tally.add_metric(path=key, metric=value) + + # Clear accelerator state so we do not accumulate references between optimizer steps. + self.accelerator.clear(self.policy, self.policy_optimizer) + import gc + + gc.collect() + torch.cuda.empty_cache() + return running_mean_logs + + def get_advantages_with_critic_gradient_accumulation( + self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0 + ) -> torch.FloatTensor: + """ + Compute (and optionally whiten) advantages while training the critic in mini-batches. + Uses GAE if enabled, otherwise uses Monte Carlo returns. + Optionally trains the critic if GAE is used. + Returns: + advantages: NestedFloatTensors + """ + + mb_size = self.mini_batch_size + batch_size = trajectories.rollout_ids.shape[0] + agent_id = trajectories.agent_ids[0] + batch_rewards = trajectories.batch_rewards + + ###################################### + # use critic for advantage estimation + ###################################### + if self.use_gae: + if "buffer" in agent_id: + self.critic.eval() + training = False + else: + self.critic.train() + training = True + advantages = [] + # critic_loss_scaling_factor comes learning single critic for two agents + normalization_factor = ( + np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor + ) + # For each minibatch + for mb in range(0, batch_size, mb_size): + trajectory_mb = trajectories[mb : mb + mb_size] + trajectory_mb.to(self.device) + rewards_mb = trajectory_mb.batch_rewards + ( + tokens_mb, + state_ends_mask_mb, + timestep_counts, + ) = trajectory_mb.get_padded_tensors_for_critic() + # critic causal attention up to end flags + if training: + vals_estimate_full = self.critic(tokens_mb) + else: + with torch.no_grad(): + vals_estimate_full = self.critic(tokens_mb) + + # if vals_estimate_full.dim() == 3: + # vals_estimate_full = vals_estimate_full.squeeze(-1) + + # Select only positions where states end, per sample → list of (jT,) + B = tokens_mb.shape[0] + vals_list = [ + vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B) + ] + + # Pad to (B, max_jT) = (B, S) + vals_estimate_mb = pad_sequence( + vals_list, batch_first=True, padding_value=0.0 + ) + dtype = vals_estimate_mb.dtype + rewards_mb = pad_sequence( + rewards_mb, batch_first=True, padding_value=0.0 + ).to( + dtype=dtype + ) # (B, S) + self.rollout_tally.add_metric( + path=["batch_rewards"], + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectory_mb.crn_ids, + rollout_ids=trajectory_mb.rollout_ids, + agent_ids=trajectory_mb.agent_ids, + metric_matrix=rewards_mb, + ), + ) + if self.reward_normalizing_constant != 1.0: + rewards_mb /= self.reward_normalizing_constant + + det_vals_estimate_mb = vals_estimate_mb.detach() # (B, max_jT) + self.rollout_tally.add_metric( + path=["mb_value_estimates_critic"], + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectory_mb.crn_ids, + rollout_ids=trajectory_mb.rollout_ids, + agent_ids=trajectory_mb.agent_ids, + metric_matrix=det_vals_estimate_mb, + ), + ) + + # Append a 0 value to the end of the value estimates + if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]: + Bsize = det_vals_estimate_mb.shape[0] + device = det_vals_estimate_mb.device + dtype = det_vals_estimate_mb.dtype + det_vals_estimate_mb = torch.cat( + [ + det_vals_estimate_mb, + torch.zeros((Bsize, 1), device=device, dtype=dtype), + ], + dim=1, + ) # (B, max_jT+1) + else: + raise ValueError( + "Incompatible shapes for value estimates and rewards." + ) + + # Get annealed lambda + if self.use_gae_lambda_annealing: + annealing_constant = self.gae_lambda_annealing_method( + step=self.trainer_annealing_state.annealing_step_counter + ) + annealed_lambda = ( + self.gae_lambda_annealing_limit * annealing_constant + ) + self.tally.add_metric( + path="annealed_lambda", metric=annealed_lambda + ) + else: + annealed_lambda = self.gae_lambda_annealing_limit + + # Get GAE advantages + gae_advantages = get_generalized_advantage_estimates( + rewards=rewards_mb, + value_estimates=det_vals_estimate_mb, + discount_factor=self.discount_factor, + lambda_coef=annealed_lambda, + ) # (B, max_jT) + self.rollout_tally.add_metric( + path=["mb_gae_advantages"], + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectory_mb.crn_ids, + rollout_ids=trajectory_mb.rollout_ids, + agent_ids=trajectory_mb.agent_ids, + metric_matrix=gae_advantages, + ), + ) + if training: + targets = ( + gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1] + ) # (B, max_jT) # A(s, a, b) + V(s) = Q(s, a, b) + self.rollout_tally.add_metric( + path=["mb_targets_critic"], + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectory_mb.crn_ids, + rollout_ids=trajectory_mb.rollout_ids, + agent_ids=trajectory_mb.agent_ids, + metric_matrix=targets, + ), + ) + if self.critic_loss_type == "mse": + loss = F.mse_loss( + input=vals_estimate_mb, + target=targets, + ) + elif self.critic_loss_type == "huber": + loss = F.huber_loss( + input=vals_estimate_mb, + target=targets, + ) + self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item()) + # Accumulate gradient + loss /= normalization_factor + self.accelerator.backward(loss) + del loss + del targets + del vals_estimate_mb + del trajectory_mb + del vals_estimate_full + + # Get jagged back using timestep_counts + advantages.extend( + [gae_advantages[i, : timestep_counts[i]] for i in range(B)] + ) + + ###################################### + # use exclusively Monte Carlo returns & rloo for advantage estimation + ###################################### + else: + lengths = [len(c) for c in batch_rewards] + padded_rewards = pad_sequence( + batch_rewards, batch_first=True, padding_value=0.0 + ) + self.rollout_tally.add_metric( + path=["mb_rewards"], + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectories.crn_ids, + rollout_ids=trajectories.rollout_ids, + agent_ids=trajectories.agent_ids, + metric_matrix=padded_rewards, + ), + ) + if self.reward_normalizing_constant != 1.0: + padded_rewards /= self.reward_normalizing_constant + padded_advantages = get_discounted_returns( + rewards=padded_rewards, + discount_factor=self.discount_factor, + ) # no baseline for now + if self.use_rloo: + is_grouped_by_rng = ( + trajectories.crn_ids.unique().shape[0] + != trajectories.crn_ids.shape[0] + ) + if is_grouped_by_rng and not self.no_rloo_grouping: + for crn_id in trajectories.crn_ids.unique(): + rng_mask = trajectories.crn_ids == crn_id + rng_advantages = padded_advantages[rng_mask] + rng_advantages, _ = get_rloo_credits(credits=rng_advantages) + padded_advantages[rng_mask] = rng_advantages + else: + padded_advantages, _ = get_rloo_credits(credits=padded_advantages) + self.rollout_tally.add_metric( + path=["mb_rloo_advantages"], + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectories.crn_ids, + rollout_ids=trajectories.rollout_ids, + agent_ids=trajectories.agent_ids, + metric_matrix=padded_advantages, + ), + ) + advantages = [ + padded_advantages[i, : lengths[i]] + for i in range(padded_advantages.shape[0]) + ] + + if self.whiten_advantages_time_step_wise or self.whiten_advantages: + lengths = [len(c) for c in advantages] + padded_advantages = pad_sequence( + advantages, batch_first=True, padding_value=0.0 + ) + if self.whiten_advantages_time_step_wise: + whitened_padded_advantages = whiten_advantages_time_step_wise( + padded_advantages + ) + path = ["mb_whitened_advantages_time_step_wise"] + elif self.whiten_advantages: + whitened_padded_advantages = whiten_advantages(padded_advantages) + path = ["mb_whitened_advantages"] + self.rollout_tally.add_metric( + path=path, + rollout_tally_item=RolloutTallyItem( + crn_ids=trajectories.crn_ids, + rollout_ids=trajectories.rollout_ids, + agent_ids=trajectories.agent_ids, + metric_matrix=whitened_padded_advantages, + ), + ) + advantages = [ + whitened_padded_advantages[i, : lengths[i]] + for i in range(whitened_padded_advantages.shape[0]) + ] + + self.trainer_annealing_state.annealing_step_counter += 1 + + return advantages + + @abstractmethod + def set_agent_trajectory_data( + self, agent_id: str, roots: list[RolloutTreeRootNode] + ) -> None: + """ + Populate self.training_data for a single agent using the provided rollout trees. + """ + pass + + def set_trajectory_data( + self, roots: list[RolloutTreeRootNode], agent_ids: list[str] + ) -> None: + """ + Convenience wrapper to ingest trajectory data for every training agent. + """ + for agent_id in agent_ids: + self.set_agent_trajectory_data(agent_id, roots) + + @abstractmethod + def share_advantage_data(self) -> list[AdvantagePacket]: + pass + + @abstractmethod + def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None: + pass + + def set_policy_gradient_data(self, agent_ids: list[str]) -> None: + """ + Reset and rebuild the policy-gradient minibatches before iterating through agents. + """ + self.policy_gradient_data = None + for agent_id in agent_ids: + assert "buffer" not in agent_id, "Buffer agents do not train policy" + trajectory_batch = self.training_data[agent_id] + tokenwise_batch_credits = get_tokenwise_credits( + batch_timesteps=trajectory_batch.batch_timesteps, + batch_credits=trajectory_batch.batch_credits, + ) + policy_gradient_data = TrainingBatch( + rollout_ids=trajectory_batch.rollout_ids, + batch_input_ids=trajectory_batch.batch_input_ids, + batch_action_mask=trajectory_batch.batch_action_mask, + batch_entropy_mask=trajectory_batch.batch_entropy_mask, + batch_credits=tokenwise_batch_credits, + batch_engine_log_probs=trajectory_batch.batch_engine_log_probs, + batch_timesteps=trajectory_batch.batch_timesteps, + ) + if self.policy_gradient_data is None: + self.policy_gradient_data = policy_gradient_data + else: + self.policy_gradient_data.append(policy_gradient_data) + + self.training_data = {} + self.tokenwise_tally = ContextualizedTokenwiseTally( + tokenizer=self.tokenizer, + paths=self.debug_path_list, + ) + + def train(self) -> None: + """ + Entry point for policy updates: prepare batches, compute gradients, and update parameters. + """ + assert self.policy_gradient_data is not None, "Policy gradient data is not set" + if self.critic_optimizer is not None: + if self.gradient_clipping is not None: + grad_norm = self.accelerator.clip_grad_norm_( + self.critic.parameters(), self.gradient_clipping + ) + self.tally.add_metric( + path="gradient_norm_critic", metric=grad_norm.item() + ) + # Take step + self.critic_optimizer.step() + self.critic_optimizer.zero_grad() + self.accelerator.clear(self.critic, self.critic_optimizer) + import gc + + gc.collect() + torch.cuda.empty_cache() + running_mean_logs = self.apply_reinforce_step( + training_batch=self.policy_gradient_data + ) + return running_mean_logs + + def export_training_tally(self, identifier: str, folder: str) -> None: + """ + Saves and resets the collected training metrics using the tally object. + """ + os.makedirs(folder, exist_ok=True) + self.tally.save(identifier=identifier, folder=folder) + self.tokenwise_tally.save( + path=os.path.join(folder, f"{identifier}_tokenwise.csv") + ) + self.rollout_tally.save(identifier=identifier, folder=folder) + self.tally.reset() + self.tokenwise_tally = None + self.rollout_tally.reset() + self.debug_path_list = [] + + def export_optimizer_states(self) -> None: + """ + Saves the optimizer states for both the main model and critic (if it exists). + """ + try: + os.makedirs(self.save_path, exist_ok=True) + + torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path) + logger.info(f"Saved main optimizer state to {self.policy_optimizer_path}") + + if self.critic_optimizer is not None: + torch.save( + self.critic_optimizer.state_dict(), self.critic_optimizer_path + ) + logger.info( + f"Saved critic optimizer state to {self.critic_optimizer_path}" + ) + except Exception as e: + logger.error(f"Error saving optimizer states: {str(e)}") + raise + + def export_trainer_annealing_state(self) -> None: + """ + Saves the trainer state. + """ + with open(self.trainer_annealing_state_path, "wb") as f: + pickle.dump(self.trainer_annealing_state, f) + logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}") + + def export_trainer_states(self) -> None: + """ + Saves the trainer states. + """ + self.export_optimizer_states() + self.export_trainer_annealing_state() diff --git a/src_code_for_reproducibility/training/trainer_independent.py b/src_code_for_reproducibility/training/trainer_independent.py new file mode 100644 index 0000000000000000000000000000000000000000..e03f4b23fb8aa8c3cd9533d57d703636eb890a2c --- /dev/null +++ b/src_code_for_reproducibility/training/trainer_independent.py @@ -0,0 +1,166 @@ +""" +File: mllm/training/trainer_independent.py +Summary: Trainer for independently optimizing each agent. +""" + +import logging +import os +import sys +from typing import Union + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from pandas._libs.tslibs.offsets import CBMonthBegin +from peft import LoraConfig +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.markov_games.rollout_tree import * +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.training.credit_methods import ( + get_discounted_returns, + get_discounted_state_visitation_credits, + get_generalized_advantage_estimates, + get_rloo_credits, +) +from mllm.training.tally_metrics import Tally +from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally +from mllm.training.tokenize_chats import * +from mllm.training.tokenize_chats import process_training_chat +from mllm.training.trainer_common import BaseTrainer +from mllm.training.training_data_utils import * +from mllm.training.training_data_utils import ( + TrainingBatch, + TrajectoryBatch, + get_tokenwise_credits, +) +from mllm.utils.resource_context import resource_logger_context + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + + +@dataclass +class TrainingData: + """Caches per-agent trajectory tensors plus their computed advantages.""" + + agent_id: str + main_data: TrajectoryBatch + # list-of-tensors: per rollout advantages with length jT + main_advantages: list[torch.FloatTensor] | None = None + + +class TrainerNaive(BaseTrainer): + def set_agent_trajectory_data( + self, agent_id: str, roots: list[RolloutTreeRootNode] + ) -> None: + """ + Tokenize rollouts for a given agent and cache the tensors used for training. + """ + # Reset per-agent buffers; extend this logic if joint training batches are needed. + self.policy_gradient_data = None + + # Tensorize Chats + rollout_ids = [] + crn_ids = [] # common random number id + batch_input_ids = [] + batch_action_mask = [] + batch_entropy_mask = [] + batch_timesteps = [] + batch_state_ends_mask = [] + batch_engine_log_probs = [] + batch_rewards = [] + for root in roots: + rollout_id = root.id + self.debug_path_list.append( + "mgid:" + str(rollout_id) + "_agent_id:" + agent_id + ) + rollout_ids.append(rollout_id) + crn_ids.append(root.crn_id) + chat, rewards = get_main_chat_list_and_rewards( + agent_id=agent_id, + root=root, + reward_agent_id=self.reward_agent_id, + reward_scale=self.reward_scale, + reward_peer_agent_id=self.reward_peer_agent_id, + reward_peer_scale=self.reward_peer_scale, + ) + ( + input_ids, + action_mask, + entropy_mask, + timesteps, + state_ends_mask, + engine_log_probs, + ) = process_training_chat( + tokenizer=self.tokenizer, + chat_history=chat, + entropy_mask_regex=self.entropy_mask_regex, + exploration_prompts_to_remove=self.exploration_prompts_to_remove, + ) + batch_input_ids.append(input_ids) + batch_action_mask.append(action_mask) + batch_entropy_mask.append(entropy_mask) + batch_timesteps.append(timesteps) + batch_state_ends_mask.append(state_ends_mask) + batch_engine_log_probs.append(engine_log_probs) + batch_rewards.append(rewards) + + trajectory_batch = TrajectoryBatch( + rollout_ids=torch.tensor(rollout_ids, dtype=torch.int32), + crn_ids=torch.tensor(crn_ids, dtype=torch.int32), + agent_ids=[agent_id] * len(rollout_ids), + batch_input_ids=batch_input_ids, + batch_action_mask=batch_action_mask, + batch_entropy_mask=batch_entropy_mask, + batch_timesteps=batch_timesteps, + batch_state_ends_mask=batch_state_ends_mask, + batch_rewards=batch_rewards, + batch_engine_log_probs=batch_engine_log_probs, + ) + + # Get Advantages + batch_advantages: torch.FloatTensor = ( + self.get_advantages_with_critic_gradient_accumulation(trajectory_batch) + ) + + # Discount state visitation (the mathematically correct way) + if not self.skip_discounted_state_visitation: + for i in range(len(batch_advantages)): + batch_advantages[i] = get_discounted_state_visitation_credits( + batch_advantages[i].unsqueeze(0), + self.discount_factor, + ).squeeze(0) + + self.training_data[agent_id] = TrainingData( + agent_id=agent_id, + main_data=trajectory_batch, + main_advantages=batch_advantages, + ) + + def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]): + """ + This trainer ignores the advantages of the other trainers. + """ + for agent_id, agent_data in self.training_data.items(): + self.training_data[agent_id] = agent_data.main_data + self.training_data[agent_id].batch_credits = agent_data.main_advantages + + def share_advantage_data(self) -> list[AdvantagePacket]: + """ + Share the advantage data with other agents. + Returns: + AdvantagePacket: The advantage packet containing the agent's advantages. + """ + logger.info(f"Sharing advantage data.") + advantage_packets = [] + for agent_id, agent_data in self.training_data.items(): + advantage_packets.append( + AdvantagePacket( + agent_id=agent_id, + rollout_ids=agent_data.main_data.rollout_ids, + main_advantages=agent_data.main_advantages, + ) + ) + return advantage_packets diff --git a/src_code_for_reproducibility/training/trainer_sum_rewards.py b/src_code_for_reproducibility/training/trainer_sum_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..336a542bbf13691a9041bcf15da063f3183db4fe --- /dev/null +++ b/src_code_for_reproducibility/training/trainer_sum_rewards.py @@ -0,0 +1,127 @@ +""" +File: mllm/training/trainer_sum_rewards.py +Summary: Trainer that optimizes the sum-of-rewards objective. +""" + +import logging +import os +import sys +from typing import Union + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from pandas._libs.tslibs.offsets import CBMonthBegin +from peft import LoraConfig +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.markov_games.rollout_tree import * +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.training.credit_methods import ( + get_discounted_returns, + get_discounted_state_visitation_credits, + get_generalized_advantage_estimates, + get_rloo_credits, +) +from mllm.training.tally_metrics import Tally +from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem +from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally +from mllm.training.tokenize_chats import * +from mllm.training.tokenize_chats import process_training_chat +from mllm.training.trainer_common import BaseTrainer +from mllm.training.trainer_independent import TrainerNaive, TrainingData +from mllm.training.training_data_utils import * +from mllm.training.training_data_utils import ( + AdvantagePacket, + TrainingBatch, + TrajectoryBatch, + get_tokenwise_credits, +) +from mllm.utils.resource_context import resource_logger_context + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + + +class TrainerSumRewards(TrainerNaive): + def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]): + """Sum peer advantages onto this agent's advantages to optimize joint reward.""" + logger.info(f"Receiving advantage packets.") + + assert ( + len(advantage_packets) > 0 + ), "At least one advantage packet must be provided." + + for agent_id, agent_data in self.training_data.items(): + coagent_advantage_packets = [ + packet for packet in advantage_packets if packet.agent_id != agent_id + ] + agent_rollout_ids = agent_data.main_data.rollout_ids + agent_advantages = agent_data.main_advantages + co_agent_advantages = [] + for rollout_id in agent_rollout_ids: + for co_agent_packet in coagent_advantage_packets: + if rollout_id in co_agent_packet.rollout_ids: + index = torch.where(rollout_id == co_agent_packet.rollout_ids)[ + 0 + ].item() + co_agent_advantages.append( + co_agent_packet.main_advantages[index] + ) + # assumes that its two player game, with one co-agent + break + assert len(co_agent_advantages) == len(agent_advantages) + B = len(agent_advantages) + assert all( + a.shape[0] == b.shape[0] + for a, b in zip(co_agent_advantages, agent_advantages) + ), "Number of advantages must match in order to sum them up." + + # Get padded tensors (advantage alignment is invariant to padding) + lengths = torch.tensor( + [len(t) for t in agent_advantages], + device=self.device, + dtype=torch.long, + ) + padded_main_advantages = pad_sequence( + agent_advantages, batch_first=True, padding_value=0.0 + ) + + padded_co_agent_advantages = pad_sequence( + co_agent_advantages, batch_first=True, padding_value=0.0 + ) + + # Create training batch data + sum_of_ad_credits = padded_main_advantages + padded_co_agent_advantages + self.rollout_tally.add_metric( + path=["sum_of_ad_credits"], + rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=sum_of_ad_credits, + ), + ) + + if not self.skip_discounted_state_visitation: + sum_of_ad_credits = get_discounted_state_visitation_credits( + sum_of_ad_credits, + self.discount_factor, + ) + self.rollout_tally.add_metric( + path=["discounted_state_visitation_credits"], + rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=sub_tensors[ + "discounted_state_visitation_credits" + ], + ), + ) + + # Slice back to jagged and convert to tokenwise credits + sum_of_ad_credits = [sum_of_ad_credits[i, : lengths[i]] for i in range(B)] + self.training_data[agent_id] = agent_data.main_data + self.training_data[agent_id].batch_credits = sum_of_ad_credits diff --git a/src_code_for_reproducibility/training/training_data_utils.py b/src_code_for_reproducibility/training/training_data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81c0daf6020ed0c2003e958706b86c85e279b109 --- /dev/null +++ b/src_code_for_reproducibility/training/training_data_utils.py @@ -0,0 +1,409 @@ +""" +File: mllm/training/training_data_utils.py +Summary: Utilities for loading, filtering, and batching training data. +""" + +from dataclasses import dataclass +from typing import Literal, Optional, Tuple + +import torch +from torch.nn.utils.rnn import pad_sequence + +from mllm.markov_games.rollout_tree import ( + ChatTurn, + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, +) + + +@dataclass +class AdvantagePacket: + """Message used by trainers to share per-rollout advantages.""" + + agent_id: str + rollout_ids: torch.IntTensor # (B,) + # list-of-tensors + main_advantages: list[torch.FloatTensor] + + +class TrainingChatTurn: + """ + Lightweight ChatTurn variant that records the timestep index alongside role/content. + """ + + def __init__( + self, + time_step: int, + role: str, + agent_id: str, + content: str, + chat_template_token_ids: list[int], + reasoning_content: str, + is_state_end: bool, + out_token_ids: Optional[list[int]] = None, + log_probs: Optional[list[float]] = None, + ) -> None: + self.time_step = time_step + self.role = role + self.agent_id = agent_id + self.content = content + self.chat_template_token_ids = chat_template_token_ids + self.reasoning_content = reasoning_content + self.is_state_end = is_state_end + self.out_token_ids = out_token_ids + self.log_probs = log_probs + + def dict(self): + return { + "time_step": self.time_step, + "role": self.role, + "agent_id": self.agent_id, + "content": self.content, + "chat_template_token_ids": self.chat_template_token_ids, + "reasoning_content": self.reasoning_content, + "is_state_end": self.is_state_end, + "out_token_ids": self.out_token_ids, + "log_probs": self.log_probs, + } + + +def get_main_chat_list_and_rewards( + agent_id: str, + root: RolloutTreeRootNode | RolloutTreeNode, + reward_agent_id: Optional[str] = None, + reward_scale: float = 1.0, + reward_peer_agent_id: Optional[str] = None, + reward_peer_scale: float = 0.0, +) -> Tuple[list[TrainingChatTurn], torch.FloatTensor]: + """ + This method traverses a rollout tree and returns a the list of ChatTurn + for an agent. If it encounters a branch node, it follows the main path. + """ + # Currently follows only the main branch; extend if side branches must be included. + if isinstance(root, RolloutTreeRootNode): + current_node = root.child + else: + current_node = root + + chat = [] + rewards = [] + reward_agent_id = reward_agent_id or agent_id + while current_node is not None: + if isinstance(current_node, RolloutTreeBranchNode): + current_node = current_node.main_child + reward: float = current_node.step_log.simulation_step_log.rewards[ + reward_agent_id + ] + total_reward = reward_scale * reward + if reward_peer_agent_id is not None: + peer_reward: float = current_node.step_log.simulation_step_log.rewards[ + reward_peer_agent_id + ] + total_reward += reward_peer_scale * peer_reward + rewards.append(total_reward) + chat_turns: list[TrainingChatTurn] = current_node.step_log.action_logs[ + agent_id + ].chat_turns + chat_turns = [ + TrainingChatTurn(time_step=current_node.time_step, **turn.model_dump()) + for turn in chat_turns + ] + chat.extend(chat_turns) + current_node = current_node.child + return chat, torch.FloatTensor(rewards) + + +def get_tokenwise_credits( + # B := batch size, S := number of tokens / seq. length, T := number of states. `j` stands for jagged (see pytorch nested tensors.) + batch_timesteps: torch.IntTensor | torch.Tensor, # (B, jS), + batch_credits: torch.FloatTensor | torch.Tensor, # (B, jT) +) -> torch.FloatTensor | torch.Tensor: # (B, jS) + """ + Expand per-state credits so every token at that timestep receives the same value. + """ + # The explicit loops keep jagged tensor semantics simple; optimize later if profiling warrants it. + batch_token_credits = [] + for credits, timesteps in zip(batch_credits, batch_timesteps): + token_credits = torch.zeros_like( + timesteps, + dtype=credits.dtype, + device=timesteps.device, + ) + for idx, credit in enumerate(credits): + token_credits[timesteps == idx] = credit + batch_token_credits.append(token_credits) + return batch_token_credits + + +@dataclass +class TrajectoryBatch: + """ + Tensorized batch of trajectories using list-of-tensors for jagged dimensions. + """ + + # B := batch size, S := number of tokens / seq. length, T := number of states. + rollout_ids: torch.IntTensor # (B,) + crn_ids: torch.IntTensor # (B,) + agent_ids: list[str] # (B,) + batch_input_ids: list[torch.LongTensor] # List[(jS,)] + batch_action_mask: list[torch.BoolTensor] # List[(jS,)] + batch_entropy_mask: list[torch.BoolTensor] # List[(jS,)] + batch_timesteps: list[torch.IntTensor] # List[(jS,)] + batch_state_ends_mask: list[torch.BoolTensor] # List[(jS,)] + batch_engine_log_probs: Optional[list[torch.FloatTensor]] # List[(jS,)] + batch_rewards: list[torch.FloatTensor] # List[(jT,)] + batch_credits: Optional[list[torch.FloatTensor]] = None # List[(jS,)] + + def __post_init__(self): + """ + Validate per-sample consistency. + """ + B = self.rollout_ids.shape[0] + assert ( + self.crn_ids.shape[0] == B + ), "RNG IDs must have length equal to batch size." + assert ( + len(self.agent_ids) == B + ), "agent_ids must have length equal to batch size." + assert ( + len(self.batch_input_ids) + == len(self.batch_action_mask) + == len(self.batch_entropy_mask) + == len(self.batch_timesteps) + == len(self.batch_state_ends_mask) + == len(self.batch_engine_log_probs) + == len(self.batch_rewards) + == B + ), "Jagged lists must all have length equal to batch size." + + for b in range(B): + nb_rewards = int(self.batch_rewards[b].shape[0]) + nb_timesteps = int(torch.max(self.batch_timesteps[b]).item()) + 1 + assert ( + nb_rewards == nb_timesteps + ), "Number of rewards and timesteps mismatch." + assert ( + self.batch_input_ids[b].shape[0] + == self.batch_action_mask[b].shape[0] + == self.batch_entropy_mask[b].shape[0] + == self.batch_engine_log_probs[b].shape[0] + == self.batch_timesteps[b].shape[0] + ), "Tensors must have the same shape along the jagged dimension." + assert ( + int(self.batch_state_ends_mask[b].sum()) + == self.batch_rewards[b].shape[0] + ), "Number of rewards must match number of state ends." + + """ + Entries: + Here, we ignore the batch dimension. + input_ids: + All of the tokens of both the user and the assistant, flattened. + action_mask: + Set to true on the tokens of the assistant (tokens generated by the model). + timesteps: + Therefore, max(timesteps) = Ns - 1. + state_ends_idx: + Indices of the tokens at which state descriptions end. + rewards: + rewards[t] := R_t(s_t, a_t) + Example: + position: "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14" + input_ids: "U U U a a a U a U a a a U U U" (U := User, a := Assistant) + action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x" + timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2" + state_ends_dx: [2, 6, 14] + rewards: [r0, r1, r2] + """ + + def __getitem__(self, key) -> "TrajectoryBatch": + if isinstance(key, slice): + return TrajectoryBatch( + rollout_ids=self.rollout_ids.__getitem__(key), + crn_ids=self.crn_ids.__getitem__(key), + agent_ids=self.agent_ids[key], + batch_input_ids=self.batch_input_ids[key], + batch_action_mask=self.batch_action_mask[key], + batch_entropy_mask=self.batch_entropy_mask[key], + batch_timesteps=self.batch_timesteps[key], + batch_state_ends_mask=self.batch_state_ends_mask[key], + batch_engine_log_probs=self.batch_engine_log_probs[key], + batch_rewards=self.batch_rewards[key], + batch_credits=self.batch_credits[key] if self.batch_credits else None, + ) + + def __len__(self): + return len(self.batch_input_ids) + + def to(self, device): + self.rollout_ids = self.rollout_ids.to(device) + self.crn_ids = self.crn_ids.to(device) + self.batch_input_ids = [t.to(device) for t in self.batch_input_ids] + self.batch_action_mask = [t.to(device) for t in self.batch_action_mask] + self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask] + self.batch_timesteps = [t.to(device) for t in self.batch_timesteps] + self.batch_state_ends_mask = [t.to(device) for t in self.batch_state_ends_mask] + self.batch_engine_log_probs = [ + t.to(device) for t in self.batch_engine_log_probs + ] + self.batch_rewards = [t.to(device) for t in self.batch_rewards] + self.batch_credits = ( + [t.to(device) for t in self.batch_credits] if self.batch_credits else None + ) + + def get_padded_tensors_for_critic(self): + """ + Returns: + padded_batch_input_ids: (B, P) + padded_batch_state_ends_mask: (B, P) + timestep_counts: (B,) tensor of ints indicating number of states per sample + """ + padded_batch_input_ids = pad_sequence( + self.batch_input_ids, batch_first=True, padding_value=0 + ) + padded_batch_state_ends_mask = pad_sequence( + self.batch_state_ends_mask, batch_first=True, padding_value=0 + ).bool() + # number of states equals number of True in state_ends_mask + timestep_counts = torch.tensor( + [int(mask.sum().item()) for mask in self.batch_state_ends_mask], + device=padded_batch_input_ids.device, + dtype=torch.long, + ) + return padded_batch_input_ids, padded_batch_state_ends_mask, timestep_counts + + +timestep = int + + +@dataclass +class PaddedTensorTrainingBatch: + """Helper struct returned by ``TrainingBatch.get_padded_tensors``.""" + + batch_input_ids: torch.LongTensor | torch.Tensor + batch_action_mask: torch.BoolTensor | torch.Tensor + batch_entropy_mask: Optional[torch.BoolTensor | torch.Tensor] + batch_credits: torch.FloatTensor | torch.Tensor + batch_engine_log_probs: torch.FloatTensor | torch.Tensor + batch_timesteps: torch.IntTensor | torch.Tensor + + def __len__(self): + return self.batch_input_ids.shape[0] + + def to(self, device): + self.batch_input_ids = self.batch_input_ids.to(device) + self.batch_action_mask = self.batch_action_mask.to(device) + self.batch_entropy_mask = self.batch_entropy_mask.to(device) + self.batch_credits = self.batch_credits.to(device) + self.batch_engine_log_probs = self.batch_engine_log_probs.to(device) + self.batch_timesteps = self.batch_timesteps.to(device) + + +@dataclass +class TrainingBatch: + rollout_ids: torch.IntTensor | torch.Tensor # (B,) + batch_input_ids: list[torch.LongTensor] # List[(jS,)] + batch_action_mask: list[torch.BoolTensor] # List[(jS,)] + batch_entropy_mask: Optional[list[torch.BoolTensor]] # List[(jS,)] + batch_credits: list[torch.FloatTensor] # List[(jS,)] + batch_engine_log_probs: list[torch.FloatTensor] # List[(jS,)] + batch_timesteps: list[torch.IntTensor] # List[(jS,)] + + def __post_init__(self): + # Ensure batch dimension is present + assert ( + len(self.batch_input_ids) + == len(self.batch_action_mask) + == len(self.batch_entropy_mask) + == len(self.batch_credits) + == len(self.batch_engine_log_probs) + == len(self.batch_timesteps) + == self.rollout_ids.shape[0] + ), "Jagged lists must all have length equal to batch size." + for inp, mask, cred, engine_log_prob, timestep in zip( + self.batch_input_ids, + self.batch_action_mask, + self.batch_credits, + self.batch_engine_log_probs, + self.batch_timesteps, + ): + assert ( + inp.shape[0] + == mask.shape[0] + == cred.shape[0] + == engine_log_prob.shape[0] + == timestep.shape[0] + ), "Tensors must have the same shapes along the jagged dimension." + + def __getitem__(self, key) -> "TrainingBatch": + if isinstance(key, slice): + return TrainingBatch( + rollout_ids=self.rollout_ids.__getitem__(key), + batch_input_ids=self.batch_input_ids[key], + batch_action_mask=self.batch_action_mask[key], + batch_entropy_mask=self.batch_entropy_mask[key], + batch_credits=self.batch_credits[key], + batch_engine_log_probs=self.batch_engine_log_probs[key], + batch_timesteps=self.batch_timesteps[key], + ) + + def __len__(self): + return len(self.batch_input_ids) + + def to(self, device): + self.rollout_ids = self.rollout_ids.to(device) + self.batch_input_ids = [t.to(device) for t in self.batch_input_ids] + self.batch_action_mask = [t.to(device) for t in self.batch_action_mask] + self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask] + self.batch_credits = [t.to(device) for t in self.batch_credits] + self.batch_engine_log_probs = [ + t.to(device) for t in self.batch_engine_log_probs + ] + self.batch_timesteps = [t.to(device) for t in self.batch_timesteps] + + def get_padded_tensors(self, padding: float = 0.0): + """ + Materialize right-padded tensors so PyTorch ops can run on uniform shapes. + """ + padded_batch_input_ids = pad_sequence( + self.batch_input_ids, batch_first=True, padding_value=int(padding) + ) + padded_batch_action_mask = pad_sequence( + [m.to(dtype=torch.bool) for m in self.batch_action_mask], + batch_first=True, + padding_value=False, + ) + padded_batch_entropy_mask = pad_sequence( + self.batch_entropy_mask, batch_first=True, padding_value=False + ) + padded_batch_credits = pad_sequence( + self.batch_credits, batch_first=True, padding_value=float(padding) + ) + padded_batch_engine_log_probs = pad_sequence( + self.batch_engine_log_probs, batch_first=True, padding_value=float(padding) + ) + padded_batch_timesteps = pad_sequence( + self.batch_timesteps, batch_first=True, padding_value=0 + ) + + return PaddedTensorTrainingBatch( + padded_batch_input_ids, + padded_batch_action_mask, + padded_batch_entropy_mask, + padded_batch_credits, + padded_batch_engine_log_probs, + padded_batch_timesteps, + ) + + def append(self, other: "TrainingBatch"): + self.rollout_ids = torch.cat([self.rollout_ids, other.rollout_ids]) + self.batch_input_ids.extend(other.batch_input_ids) + self.batch_action_mask.extend(other.batch_action_mask) + self.batch_entropy_mask.extend(other.batch_entropy_mask) + self.batch_credits.extend(other.batch_credits) + self.batch_engine_log_probs.extend(other.batch_engine_log_probs) + self.batch_timesteps.extend(other.batch_timesteps) + + +timestep = int diff --git a/src_code_for_reproducibility/utils/__init__.py b/src_code_for_reproducibility/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f8f86a6250af7b404a43dc38a7d58aef50dbeb6d --- /dev/null +++ b/src_code_for_reproducibility/utils/__init__.py @@ -0,0 +1,4 @@ +""" +File: mllm/utils/__init__.py +Summary: Utility package exposing helper modules. +""" diff --git a/src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d48d74878ee839af2ccf1ec5669b79bab5bef28 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df43d9f111270ad98a105cba5243f21559c79cee Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8440f01edaa7af149a0be633b00f4f82fc749ccb Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64262d72d7f478c011adb21060b21d496ea5f520 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24330c84e32739ea0196b0590355a6e30be9d936 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f49125d23803bb37ccebe81d8fb9ec6e99484c Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1070575327aefef5786460a81cddca82f3dc98b Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a92745c5bf316b3ab0ae3758b5cd69ae470fb079 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8b32accd0e1864c1fff187b80a6f8e2beee2ac4 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22bdb19540ae591468a63fedd0c7f34d7ae66568 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/dict_get_path.py b/src_code_for_reproducibility/utils/dict_get_path.py new file mode 100644 index 0000000000000000000000000000000000000000..16b91ec7ec8ecf4e5ed96af29945f44d27bd0276 --- /dev/null +++ b/src_code_for_reproducibility/utils/dict_get_path.py @@ -0,0 +1,17 @@ +""" +File: mllm/utils/dict_get_path.py +Summary: Retrieves nested dictionary values using dotted key paths. +""" + + +def get_from_nested_dict(a: dict, path) -> any: + # path is string or list of string + try: + if isinstance(path, str): + return a[path] + else: + for p in path: + a = a[p] + return a + except Exception: + return None diff --git a/src_code_for_reproducibility/utils/gather_training_stats.py b/src_code_for_reproducibility/utils/gather_training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..067fc238c3899ade78a0a4622d002a2c99e337aa --- /dev/null +++ b/src_code_for_reproducibility/utils/gather_training_stats.py @@ -0,0 +1,262 @@ +""" +File: mllm/utils/gather_training_stats.py +Summary: Aggregates training statistics from rollouts and exports artifacts. +""" + +import copy +import csv +import gc +import json +import logging +import os +import pickle +import random +import re +import subprocess +import sys +import time +from datetime import datetime +from statistics import mean +from typing import Any, Dict + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +from omegaconf import OmegaConf + +from mllm.training.tally_metrics import Tally +from mllm.utils.stat_pack import StatPack + + +def get_from_nested_dict(dictio: dict, path: list[str]): + for sp in path[:-1]: + dictio = dictio[sp] + return dictio.get(path[-1]) + + +def set_at_path(dictio: dict, path: list[str], value): + for sp in path[:-1]: + if sp not in dictio: + dictio[sp] = {} + dictio = dictio[sp] + dictio[path[-1]] = value + + +def produce_tabular_render(inpath: str, outpath: str = None): + """ + Convert a JSON metrics dump into per-rollout CSV tables for easier inspection. + """ + with open(inpath, "r") as f: + data = json.load(f) + rollout_paths = data.keys() + for rollout_path in rollout_paths: + if outpath is None: + m_path = rollout_path.replace("/", "|") + m_path = m_path.replace(".json", "") + m_path = ( + os.path.split(inpath)[0] + + "/contextualized_tabular_renders/" + + m_path + + "_tabular_render.render.csv" + ) + # import pdb; pdb.set_trace() + os.makedirs(os.path.split(m_path)[0], exist_ok=True) + metrics = data[rollout_path] + d = {k: [] for k in metrics[0].keys()} + for m in metrics: + for k, v in m.items(): + d[k].append(v) + d = pd.DataFrame(d) + d.to_csv(m_path) + + +def get_metric_paths(data: list[dict]): + d = data[0] + paths = [] + + def traverse_dict(d, current_path=[]): + for key, value in d.items(): + new_path = current_path + [key] + if isinstance(value, dict): + traverse_dict(value, new_path) + else: + paths.append(new_path) + + traverse_dict(d) + return paths + + +def print_metric_paths(data: list[dict]): + paths = get_metric_paths(data) + for p in paths: + print(p) + + +def get_metric_iteration_list(data: list[dict], metric_path: list[str]): + if isinstance(metric_path, str): + metric_path = [metric_path] + sgl = [] + for d in data: + sgl.append(get_from_nested_dict(d, metric_path)) + return sgl + + +def to_1d_numeric(x): + """Return a 1-D float array (or None if not numeric). Accepts scalars, numpy arrays, or nested list/tuple of them.""" + if x is None: + return None + if isinstance(x, (int, float, np.number)): + return np.array([float(x)], dtype=float) + if isinstance(x, np.ndarray): + try: + return x.astype(float).ravel() + except Exception: + return None + if isinstance(x, (list, tuple)): + parts = [] + for e in x: + arr = to_1d_numeric(e) + if arr is not None and arr.size > 0: + parts.append(arr) + if parts: + return np.concatenate(parts) + return None + return None + + +def get_single_metric_vector(data, metric_path, iterations=None): + if isinstance(metric_path, str): + metric_path = [metric_path] + if iterations == None: + iterations = len(data) + vecs = [] + for d in data: + ar = get_from_nested_dict(d, metric_path) + arr = to_1d_numeric(ar) + if arr is not None: + vecs.append(arr) + + return np.concatenate(vecs) if vecs else np.empty(0, dtype=float) + + +def _load_metrics_file(file_path: str): + if not (file_path.endswith(".tally.pkl") or file_path.endswith(".pkl")): + raise ValueError("Only *.tally.pkl files are supported.") + import pickle + + with open(file_path, "rb") as f: + tree = pickle.load(f) + return tree + + +def get_leaf_items(array_tally: dict, prefix: list[str] = None): + if prefix is None: + prefix = [] + for key, value in array_tally.items(): + next_prefix = prefix + [str(key)] + if isinstance(value, dict): + yield from get_leaf_items(value, next_prefix) + else: + yield next_prefix, value + + +def _sanitize_filename_part(part: str) -> str: + s = part.replace("/", "|") + s = s.replace(" ", "_") + return s + + +def render_rt_tally_pkl_to_csvs(pkl_path: str, outdir: str): + """ + This method takes care of tokenwise logging. + """ + with open(pkl_path, "rb") as f: + payload = pickle.load(f) + # Backward compatibility: older tallies stored the dict directly + if isinstance(payload, dict) and "array_tally" in payload: + array_tally = payload.get("array_tally", {}) + else: + array_tally = payload + + os.makedirs(outdir, exist_ok=True) + trainer_id = os.path.basename(pkl_path).replace(".rt_tally.pkl", "") + for path_list, rollout_tally_items in get_leaf_items(array_tally): + # Create file and initiate writer + path_part = ".".join(_sanitize_filename_part(p) for p in path_list) + filename = f"{trainer_id}__{path_part}.render.csv" + out_path = os.path.join(outdir, filename) + + # Write metric rows to CSV + with open(out_path, "w", newline="") as f: + writer = csv.writer(f) + + # Write header row - need to determine metric column count from first rollout_tally_item + first_item = rollout_tally_items[0] + metric_cols = ( + first_item.metric_matrix.shape[1] + if first_item.metric_matrix.ndim > 1 + else 1 + ) + header = ["agent_id", "crn_id", "rollout_id"] + [ + f"t_{i}" for i in range(metric_cols) + ] + writer.writerow(header) + + for rollout_tally_item in rollout_tally_items: + crn_ids = rollout_tally_item.crn_ids + rollout_ids = rollout_tally_item.rollout_ids + agent_ids = rollout_tally_item.agent_ids + metric_matrix = rollout_tally_item.metric_matrix + for i in range(metric_matrix.shape[0]): + row_vals = metric_matrix[i].reshape(-1) + # Convert row_vals to a list to avoid numpy concatenation issues + row_vals = ( + row_vals.tolist() + if hasattr(row_vals, "tolist") + else list(row_vals) + ) + row_prefix = [ + agent_ids[i], + crn_ids[i], + rollout_ids[i], + ] + writer.writerow(row_prefix + row_vals) + + +def tally_to_stat_pack(tally: Dict[str, Any]): + stat_pack = StatPack() + if "array_tally" in tally: + tally = tally["array_tally"] + + # backward compatibility: will remove later, flatten keys in tally + def get_from_nested_dict(dictio: dict, path: list[str]): + for sp in path[:-1]: + dictio = dictio[sp] + return dictio.get(path[-1]) + + def get_metric_paths(tally: dict): + paths = [] + + def traverse_dict(tally, current_path=[]): + for key, value in tally.items(): + new_path = current_path + [key] + if isinstance(value, dict): + traverse_dict(value, new_path) + else: + paths.append(new_path) + + traverse_dict(tally) + return paths + + paths = get_metric_paths(tally) + modified_tally = {} + for p in paths: + val = get_from_nested_dict(tally, p) + modified_tally["_".join(p)] = np.mean(val) + del tally + tally = modified_tally + for key, value in tally.items(): + stat_pack.add_stat(key, value) + return stat_pack diff --git a/src_code_for_reproducibility/utils/get_coagent_id.py b/src_code_for_reproducibility/utils/get_coagent_id.py new file mode 100644 index 0000000000000000000000000000000000000000..f51674757ebb4ba1b0c18a36dd4ea9257564f890 --- /dev/null +++ b/src_code_for_reproducibility/utils/get_coagent_id.py @@ -0,0 +1,10 @@ +""" +File: mllm/utils/get_coagent_id.py +Summary: Helper for deriving co-agent identifiers from rollout metadata. +""" + + +def get_coagent_id(ids: list[str], agent_id: str) -> str | None: + for id in ids: + if id != agent_id: + return id diff --git a/src_code_for_reproducibility/utils/get_stochastic_game_lengths.py b/src_code_for_reproducibility/utils/get_stochastic_game_lengths.py new file mode 100644 index 0000000000000000000000000000000000000000..98a01013b063e7d2504f5f85b1c4a4f9d145412b --- /dev/null +++ b/src_code_for_reproducibility/utils/get_stochastic_game_lengths.py @@ -0,0 +1,33 @@ +""" +File: mllm/utils/get_stochastic_game_lengths.py +Summary: Computes distributions over stochastic game lengths. +""" + +import numpy as np + + +def get_stochastic_game_lengths( + max_length, nb_games, continuation_prob, same_length_batch=False +): + """ + Generates stochastic game lengths based on a geometric distribution. + + Args: + max_length (int): The maximum length a game can have. + nb_games (int): The number of games to generate lengths for. + continuation_prob (float): The probability of the game continuing after each round. + same_length_batch (bool): If True, all games will have the same length. + + Returns: + Array: An array of game lengths. + """ + if continuation_prob == 1: + return [max_length] * nb_games + if same_length_batch: + length = np.random.geometric(1 - continuation_prob, 1) + game_lengths = np.repeat(length, nb_games) + else: + game_lengths = np.random.geometric(1 - continuation_prob, nb_games) + + game_lengths = np.where(game_lengths > max_length, max_length, game_lengths) + return game_lengths.tolist() diff --git a/src_code_for_reproducibility/utils/resource_context.py b/src_code_for_reproducibility/utils/resource_context.py new file mode 100644 index 0000000000000000000000000000000000000000..e0713364ce54d2d20745162329fea9dec2665efd --- /dev/null +++ b/src_code_for_reproducibility/utils/resource_context.py @@ -0,0 +1,83 @@ +""" +File: mllm/utils/resource_context.py +Summary: Tracks system resource usage via a context manager. +""" + +import logging +import time +from contextlib import contextmanager + +import torch + + +def vram_usage(): + output = "" + for i in range(torch.cuda.device_count()): + gpu_memory_allocated = torch.cuda.memory_allocated(i) / ( + 1024**3 + ) # Convert bytes to GB + gpu_memory_reserved = torch.cuda.memory_reserved(i) / ( + 1024**3 + ) # Convert bytes to GB + output += f"GPU {i}: Memory Allocated: {gpu_memory_allocated:.2f} GB, Memory Reserved: {gpu_memory_reserved:.2f} GB" + return output + + +def ram_usage(): + import psutil + + process = psutil.Process() + memory_info = process.memory_info() + ram_used = memory_info.rss / (1024**3) # Convert bytes to GB + return f"RAM Usage: {ram_used:.2f} GB" + + +@contextmanager +def resource_logger_context(logger: logging.Logger, task_description: str): + """ + Context manager to log the resource usage of the current task. + Args: + logger: The logger to use to log the resource usage. + task_description: The description of the task to log. + Returns: + None + """ + try: + initial_time = time.time() + # Assume CUDA is available and use device 0 only + total_mem_bytes = torch.cuda.get_device_properties(0).total_memory + initial_total_bytes = torch.cuda.memory_allocated( + 0 + ) + torch.cuda.memory_reserved(0) + torch.cuda.reset_peak_memory_stats(0) + yield None + finally: + final_time = time.time() + # Ensure kernels within the block are accounted for + torch.cuda.synchronize() + + # Compute metrics + final_allocated_bytes = torch.cuda.memory_allocated(0) + final_reserved_bytes = torch.cuda.memory_reserved(0) + final_total_bytes = final_allocated_bytes + final_reserved_bytes + + delta_vram_percent_total = ( + 100 * (final_total_bytes - initial_total_bytes) / total_mem_bytes + if total_mem_bytes + else 0.0 + ) + current_percent_vram_taken = ( + 100 * final_total_bytes / total_mem_bytes if total_mem_bytes else 0.0 + ) + block_peak_percent = ( + 100 * torch.cuda.max_memory_allocated(0) / total_mem_bytes + if total_mem_bytes + else 0.0 + ) + delta_time_str = time.strftime( + "%H:%M:%S", time.gmtime(final_time - initial_time) + ) + + logger.info( + f"For task: {task_description}, ΔVRAM % (total): {delta_vram_percent_total:.2f}%, Current % of VRAM taken: {current_percent_vram_taken:.2f}%, Block Peak % of device VRAM: {block_peak_percent:.2f}%, ΔTime: {delta_time_str}" + ) diff --git a/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py b/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py new file mode 100644 index 0000000000000000000000000000000000000000..8806c1a9df9412c8bce5ae42c3de031b81db52f5 --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py @@ -0,0 +1,1597 @@ +""" +File: mllm/utils/rollout_tree_chat_htmls.py +Summary: Renders rollout tree chat transcripts into HTML artifacts. +""" + +from pathlib import Path +from typing import List + +from mllm.utils.rollout_tree_gather_utils import * + + +def html_from_chat_turns(chat_turns: List[ChatTurnLog]) -> str: + """ + Render chat turns as a single, wrapping sequence of messages in time order. + Keep badge and message bubble styles, include time on every badge and + include rewards on assistant badges. Each message is individually + hide/show by click; when hidden, only the badge remains and "(...)" is + shown inline (not inside a bubble). + """ + import html + import re as _re + + # Prepare ordering: sort by (time_step, original_index) to keep stable order within same step + indexed_turns = list(enumerate(chat_turns)) + indexed_turns.sort(key=lambda t: (t[1].time_step, t[0])) + + # Get unique agent IDs and sort alphabetically for consistent assignment + # Agent with alphabetically lower name gets agent-0 (left, green) + # Agent with alphabetically higher name gets agent-1 (right, orange) + unique_agent_ids = sorted( + set(turn.agent_id for turn in chat_turns if turn.role == "assistant") + ) + agent_id_to_index = {aid: idx for idx, aid in enumerate(unique_agent_ids)} + + # CSS styles (simplified layout; no time-step or agent-column backgrounds) + css = """ + + """ + + # HTML structure + html_parts = [ + "", + "", + "", + "", + "Chat Turns", + css, + "", + "", + "", + '
', + '
', + '
', + '', + '', + '', + '', + '', + '900px', + "", + '', + '', + '", + "", + '', + '', + '', + "px", + "", + '', + '', + f'", + f'', + '|', + f'", + f'', + '', + "", + "
", + "
", + ] + + # Add Chat View + import html as _html_mod + + html_parts.append('
') + + # Helper function to add context annotation areas + def add_context_area(position: str, time_step: int): + context_key = f"round-context-{position}-{time_step}" + placeholder = f"Add context {position} round {time_step}..." + color_buttons = "" + # Add default/reset color button first + color_buttons += ( + f'
' + ) + for color_name, color_value in [ + ("red", "#d32f2f"), + ("orange", "#f57c00"), + ("yellow", "#f9a825"), + ("green", "#388e3c"), + ("blue", "#1976d2"), + ("purple", "#7b1fa2"), + ("gray", "#666666"), + ]: + color_buttons += ( + f'
' + ) + + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f"
" + ) + + # Helper function to add split agent context boxes + def add_split_agent_contexts(position: str, time_step: int): + color_buttons = "" + # Add default/reset color button first + color_buttons += ( + f'
' + ) + for color_name, color_value in [ + ("red", "#d32f2f"), + ("orange", "#f57c00"), + ("yellow", "#f9a825"), + ("green", "#388e3c"), + ("blue", "#1976d2"), + ("purple", "#7b1fa2"), + ("gray", "#666666"), + ]: + color_buttons += ( + f'
' + ) + + html_parts.append('
') + + # Agent 0 box + agent0_key = f"agent-context-0-{position}-{time_step}" + agent0_placeholder = f"..." + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f"
" + ) + + # Agent 1 box + agent1_key = f"agent-context-1-{position}-{time_step}" + agent1_placeholder = f"..." + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f"
" + ) + + html_parts.append("
") # split-agent-context + + last_time_step_chat = None + for original_index, turn in indexed_turns: + # Use agent index for CSS class (agent-0 or agent-1) instead of agent ID + agent_index = agent_id_to_index.get(turn.agent_id, 0) + agent_class = f"agent-{agent_index}" + role_class = f"role-{turn.role}" + + # Add time step divider and beginning context + if last_time_step_chat is None or turn.time_step != last_time_step_chat: + # Add end contexts for previous round (only regular context, not prompt summary) + if last_time_step_chat is not None: + add_context_area("end", last_time_step_chat) + + html_parts.append( + f'
' + f'⏱ Round {turn.time_step + 1}' + f"
" + ) + + # Add beginning contexts for new round (both context and prompt summary) + add_context_area("beginning", turn.time_step) + add_split_agent_contexts("beginning", turn.time_step) + + last_time_step_chat = turn.time_step + + # Build chat message with merge controls + html_parts.append( + f'
' + ) + + # Add merge control button + html_parts.append( + f'' + ) + + html_parts.append('
') + + # Header with agent name and reward (always show reward) + if turn.role == "assistant": + name = _html_mod.escape(turn.agent_id) + raw_val = turn.reward + if isinstance(raw_val, (int, float)): + reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".") + if len(reward_val) > 8: + reward_val = reward_val[:8] + "…" + else: + reward_val = str(raw_val) + header_html = ( + f'
' + f'🤖 {name}' + f'⚑ {reward_val}' + f"
" + ) + else: + name = _html_mod.escape(turn.agent_id) + header_html = f'
Prompt of {name}
' + + html_parts.append(header_html) + + # Reasoning content if present + if turn.reasoning_content: + _raw_reasoning = turn.reasoning_content.replace("\r\n", "\n") + _raw_reasoning = _re.sub(r"^\s*\n+", "", _raw_reasoning) + esc_reasoning = _html_mod.escape(_raw_reasoning) + html_parts.append( + f'" + ) + + # Message bubble + esc_content = _html_mod.escape(turn.content) + html_parts.append(f'
{esc_content}
') + + html_parts.append("
") # chat-message-content + html_parts.append("
") # chat-message + + # Add end contexts for the last round (only regular context, not prompt summary) + if last_time_step_chat is not None: + add_context_area("end", last_time_step_chat) + + html_parts.append("
") # flow-chat + html_parts.extend(["", ""]) + + return "\n".join(html_parts) + + +def export_html_from_rollout_tree(path: Path, outdir: Path, main_only: bool = False): + """Process a rollout tree file and generate HTML files for each path. + Creates separate HTML files for the main path and each branch path. + The main path is saved in the root output directory, while branch paths + are saved in a 'branches' subdirectory. + + Args: + path: Path to the rollout tree JSON file + outdir: Output directory for HTML files + main_only: If True, only export the main trajectory (default: False) + """ + root = load_rollout_tree(path) + mgid = root.id + + main_path, branch_paths = get_rollout_tree_paths(root) + + outdir.mkdir(parents=True, exist_ok=True) + + # Create branches subdirectory if we have branch paths + if not main_only and branch_paths: + branches_dir = outdir / f"mgid:{mgid}_branches_html_renders" + branches_dir.mkdir(parents=True, exist_ok=True) + + # Generate HTML for the main path + chat_turns = gather_all_chat_turns_for_path(main_path) + html_content = html_from_chat_turns(chat_turns) + output_file = outdir / f"mgid:{mgid}_main_html_render.render.html" + with open(output_file, "w", encoding="utf-8") as f: + f.write(html_content) + + # Generate HTML for each branch path + for path_obj in branch_paths: + chat_turns = gather_all_chat_turns_for_path(path_obj) + + html_content = html_from_chat_turns(chat_turns) + + path_id: str = path_obj.id + output_filename = f"{path_id}_html_render.render.html" + + output_file = branches_dir / output_filename + + with open(output_file, "w", encoding="utf-8") as f: + f.write(html_content) diff --git a/src_code_for_reproducibility/utils/rollout_tree_stats.py b/src_code_for_reproducibility/utils/rollout_tree_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..4725160156230d7efb89588c765fb5b63a7bbbe1 --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_stats.py @@ -0,0 +1,55 @@ +""" +File: mllm/utils/rollout_tree_stats.py +Summary: Computes descriptive statistics from rollout tree collections. +""" + +from typing import Any, Callable, List, Tuple + +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.markov_games.simulation import SimulationStepLog +from mllm.utils.rollout_tree_gather_utils import ( + gather_simulation_step_logs, + get_rollout_tree_paths, +) +from mllm.utils.stat_pack import StatPack + + +def get_rollout_tree_stat_tally( + rollout_tree: RolloutTreeRootNode, + metrics: List[Callable[[SimulationStepLog], List[Tuple[str, float]]]], +) -> StatPack: + stat_tally = StatPack() + # get simulation step logs + node_list = get_rollout_tree_paths(rollout_tree)[0] + simulation_step_logs = gather_simulation_step_logs(node_list) + for simulation_step_log in simulation_step_logs: + for metric in metrics: + metric_result = metric(simulation_step_log) + if metric_result is not None: + for key, value in metric_result: + stat_tally.add_stat(key, value) + return stat_tally + + +def get_rollout_tree_mean_stats( + rollout_tree: RolloutTreeRootNode, metrics: List[Callable[[SimulationStepLog], Any]] +) -> StatPack: + """Get the mean stats for a rollout tree.""" + stat_tally = get_rollout_tree_stat_tally(rollout_tree, metrics) + return stat_tally.mean() + + +def get_mean_rollout_tree_stats( + rollout_trees: List[RolloutTreeRootNode], + metrics: List[Callable[[SimulationStepLog], Any]], +) -> StatPack: + """Get the mean stats for a list of rollout trees.""" + # Compute per-rollout means first, then aggregate them across the entire batch. + stat_tallies = [ + get_rollout_tree_mean_stats(rollout_tree, metrics) + for rollout_tree in rollout_trees + ] + mean_stat_tally = StatPack() + for stat_tally in stat_tallies: + mean_stat_tally.add_stats(stat_tally) + return mean_stat_tally.mean() diff --git a/src_code_for_reproducibility/utils/short_id_gen.py b/src_code_for_reproducibility/utils/short_id_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..6c08ffdc3362c767ea8916496ea5b0e1c01dbd7e --- /dev/null +++ b/src_code_for_reproducibility/utils/short_id_gen.py @@ -0,0 +1,16 @@ +""" +File: mllm/utils/short_id_gen.py +Summary: Generates short unique identifiers for experiment assets. +""" + +import uuid + + +def generate_short_id() -> int: + """ + Generates a short unique ID for tracking adapter versions. + + Returns: + int: An 8-digit integer ID. + """ + return int(str(uuid.uuid4().int)[:8]) diff --git a/src_code_for_reproducibility/utils/stat_pack.py b/src_code_for_reproducibility/utils/stat_pack.py new file mode 100644 index 0000000000000000000000000000000000000000..d4da475dafa8e3290ba9be10922be5687ac2c862 --- /dev/null +++ b/src_code_for_reproducibility/utils/stat_pack.py @@ -0,0 +1,117 @@ +""" +File: mllm/utils/stat_pack.py +Summary: Implements the StatPack container for incremental statistics. +""" + +import csv +import json +import os +import pickle +from collections import Counter +from copy import deepcopy +from locale import strcoll +from statistics import mean +from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict + +import matplotlib.pyplot as plt +import numpy as np + +style_path = os.environ.get("ADALIGN_MPLSTYLE") +if style_path: + plt.style.use(style_path) + +import wandb + +from . import wandb_utils + + +class StatPack: + def __init__(self): + self.data = {} + + def add_stat(self, key: str, value: float | int | None): + assert ( + isinstance(value, float) or isinstance(value, int) or value is None + ), f"Value {value} is not a valid type" + if key not in self.data: + self.data[key] = [] + self.data[key].append(value) + + def add_stats(self, other: "StatPack"): + for key in other.keys(): + self.add_stat(key, other[key]) + + def __getitem__(self, key: str): + return self.data[key] + + def __setitem__(self, key: str, value: Any): + self.data[key] = value + + def __contains__(self, key: str): + return key in self.data + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + def mean(self): + mean_st = StatPack() + for key in self.keys(): + if isinstance(self[key], list): + # Ignore None entries so missing measurements do not bias the mean. + non_none_values = [v for v in self[key] if v is not None] + if non_none_values: + mean_st[key] = np.mean(np.array(non_none_values)) + else: + mean_st[key] = None + return mean_st + + def store_plots(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + plt.figure(figsize=(10, 5)) + plt.plot(self[key]) + plt.title(key) + plt.savefig(os.path.join(folder, f"{key}.pdf")) + plt.close() + + def store_numpy(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + # Sanitize filename components (avoid slashes, spaces, etc.) + safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_") + values = self[key] + # Convert None to NaN for numpy compatibility + arr = np.array( + [(np.nan if (v is None) else v) for v in values], dtype=float + ) + np.save(os.path.join(folder, f"{safe_key}.npy"), arr) + + def store_json(self, folder: str, filename: str = "stats.json"): + os.makedirs(folder, exist_ok=True) + with open(os.path.join(folder, filename), "w") as f: + json.dump(self.data, f, indent=4) + + def store_csv(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + with open(os.path.join(folder, f"stats.csv"), "w") as f: + writer = csv.writer(f) + writer.writerow([key] + self[key]) + + def store_pickle(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + with open(os.path.join(folder, f"stats.pkl"), "wb") as f: + pickle.dump(self[key], f) diff --git a/src_code_for_reproducibility/utils/wandb_utils.py b/src_code_for_reproducibility/utils/wandb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..46289bfdbb48b72fa3fe3b531d150447cfc1eb01 --- /dev/null +++ b/src_code_for_reproducibility/utils/wandb_utils.py @@ -0,0 +1,170 @@ +""" +File: mllm/utils/wandb_utils.py +Summary: Shared Weights & Biases helper functions. +""" + +import os +from typing import Any, Dict, Optional + +_WANDB_AVAILABLE = False +_WANDB_RUN = None + + +def _try_import_wandb(): + global _WANDB_AVAILABLE + if _WANDB_AVAILABLE: + return True + try: + import wandb # type: ignore + + _WANDB_AVAILABLE = True + return True + except Exception: + _WANDB_AVAILABLE = False + return False + + +def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any: + cur: Any = cfg + for key in path: + if not isinstance(cur, dict) or key not in cur: + return default + cur = cur[key] + return cur + + +def is_enabled(cfg: Dict[str, Any]) -> bool: + return bool(_safe_get(cfg, ["logging", "wandb", "enabled"], False)) + + +def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None: + """ + Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed. + """ + global _WANDB_RUN + if not is_enabled(cfg): + return + if not _try_import_wandb(): + return + + import wandb # type: ignore + + project = _safe_get(cfg, ["logging", "wandb", "project"], "llm-negotiation") + entity = _safe_get(cfg, ["logging", "wandb", "entity"], None) + mode = _safe_get(cfg, ["logging", "wandb", "mode"], "online") + tags = _safe_get(cfg, ["logging", "wandb", "tags"], []) or [] + notes = _safe_get(cfg, ["logging", "wandb", "notes"], None) + group = _safe_get(cfg, ["logging", "wandb", "group"], None) + name = _safe_get(cfg, ["logging", "wandb", "name"], run_name) + + # Ensure files are written into the hydra run directory + os.makedirs(run_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", run_dir) + + # Convert cfg to plain types for W&B config; fallback to minimal dictionary + try: + from omegaconf import OmegaConf # type: ignore + + cfg_container = OmegaConf.to_container(cfg, resolve=True) # type: ignore + except Exception: + cfg_container = cfg + + _WANDB_RUN = wandb.init( + project=project, + entity=entity, + mode=mode, + name=name, + group=group, + tags=tags, + notes=notes, + config=cfg_container, + dir=run_dir, + reinit=True, + ) + + +def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None: + """Log a flat dictionary of metrics to W&B if active.""" + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + try: + import wandb # type: ignore + + wandb.log(metrics if step is None else dict(metrics, step=step)) + except Exception: + pass + + +def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None: + for k, v in data.items(): + key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + _flatten(key, v, out) + else: + out[key] = v + + +def _summarize_value(value: Any) -> Dict[str, Any]: + import numpy as np # local import to avoid hard dependency during disabled mode + + if value is None: + return {"none": 1} + # Scalars + if isinstance(value, (int, float)): + return {"value": float(value)} + # Lists or arrays + try: + arr = np.asarray(value) + if arr.size == 0: + return {"size": 0} + return { + "mean": float(np.nanmean(arr)), + "min": float(np.nanmin(arr)), + "max": float(np.nanmax(arr)), + "last": float(arr.reshape(-1)[-1]), + "size": int(arr.size), + } + except Exception: + # Fallback: string repr + return {"text": str(value)} + + +def log_tally( + array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None +) -> None: + """ + Flatten and summarize Tally.array_tally and log to WandB. + Each leaf list/array is summarized with mean/min/max/last/size. + """ + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + summarized: Dict[str, Any] = {} + + def walk(node: Any, path: list[str]): + if isinstance(node, dict): + for k, v in node.items(): + walk(v, path + [k]) + return + # node is a list of values accumulated over time + key = ".".join([p for p in ([prefix] if prefix else []) + path]) + try: + summary = _summarize_value(node) + for sk, sv in summary.items(): + summarized[f"{key}.{sk}"] = sv + except Exception: + summarized[f"{key}.error"] = 1 + + walk(array_tally, []) + if summarized: + log(summarized, step=step) + + +def log_flat_stats( + stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None +) -> None: + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + flat: Dict[str, Any] = {} + _flatten(prefix, stats, flat) + if flat: + log(flat, step=step)