Add files using upload-large-folder tool
Browse files- .hydra/config.yaml +240 -0
- .hydra/hydra.yaml +154 -0
- .hydra/overrides.yaml +1 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
- src_code_for_reproducibility/__init__.py +4 -0
- src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/ipd/Ipd_hard_coded_agents.py +76 -0
- src_code_for_reproducibility/markov_games/ipd/__init__.py +11 -0
- src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc +0 -0
- src_code_for_reproducibility/models/adapter_training_wrapper.py +104 -0
- src_code_for_reproducibility/models/inference_backend.py +44 -0
- src_code_for_reproducibility/models/inference_backend_dummy.py +59 -0
- src_code_for_reproducibility/models/large_language_model_api.py +174 -0
- src_code_for_reproducibility/models/large_language_model_local.py +361 -0
- src_code_for_reproducibility/models/scalar_critic.py +59 -0
- src_code_for_reproducibility/training/tally_rollout.py +116 -0
- src_code_for_reproducibility/training/tally_tokenwise.py +278 -0
- src_code_for_reproducibility/training/trainer_ad_align.py +505 -0
- src_code_for_reproducibility/training/trainer_common.py +1032 -0
- src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/gather_training_stats.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/rollout_tree_chat_htmls.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc +0 -0
- src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc +0 -0
.hydra/config.yaml
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
experiment:
|
| 2 |
+
wandb_enabled: true
|
| 3 |
+
nb_epochs: 1000
|
| 4 |
+
nb_matches_per_iteration: 128
|
| 5 |
+
reinit_matches_each_it: true
|
| 6 |
+
checkpoint_every_n_iterations: 50
|
| 7 |
+
start_epoch: 0
|
| 8 |
+
resume_experiment: true
|
| 9 |
+
base_seed: 0
|
| 10 |
+
seed_group_size: 1
|
| 11 |
+
train: true
|
| 12 |
+
stat_methods_for_live_wandb: mllm.markov_games.ipd.ipd_statistics
|
| 13 |
+
name: ipd_vanilla_ad_align_no_agent_buffer
|
| 14 |
+
agent_buffer: false
|
| 15 |
+
keep_agent_buffer_count: ${lora_count}
|
| 16 |
+
agent_buffer_recent_k: -1
|
| 17 |
+
logging:
|
| 18 |
+
wandb:
|
| 19 |
+
enabled: false
|
| 20 |
+
project: llm-negotiation
|
| 21 |
+
entity: null
|
| 22 |
+
mode: online
|
| 23 |
+
name: null
|
| 24 |
+
group: null
|
| 25 |
+
tags: []
|
| 26 |
+
notes: null
|
| 27 |
+
temperature: 1.0
|
| 28 |
+
markov_games:
|
| 29 |
+
runner_method_name: LinearRunner
|
| 30 |
+
runner_kwargs: {}
|
| 31 |
+
simulation_class_name: IPD
|
| 32 |
+
simulation_init_args:
|
| 33 |
+
rounds_per_game: 10
|
| 34 |
+
cooperate_actions:
|
| 35 |
+
- <A>
|
| 36 |
+
defect_actions:
|
| 37 |
+
- <B>
|
| 38 |
+
reward: 3
|
| 39 |
+
punishment: 1
|
| 40 |
+
temptation: 5
|
| 41 |
+
sucker: 0
|
| 42 |
+
agents:
|
| 43 |
+
0:
|
| 44 |
+
agent_id: ${agent_0_id}
|
| 45 |
+
agent_name: Alice
|
| 46 |
+
agent_class_name: IPDAgent
|
| 47 |
+
policy_id: base_llm/agent_adapter
|
| 48 |
+
init_kwargs:
|
| 49 |
+
max_errors: 0
|
| 50 |
+
allow_reasoning: false
|
| 51 |
+
max_reasoning_chars: 0
|
| 52 |
+
cooperate_string: <A>
|
| 53 |
+
defect_string: <B>
|
| 54 |
+
strategy_prompt: ''
|
| 55 |
+
intro_prompt: "Welcome to an iterated game. You are playing against another\
|
| 56 |
+
\ agent.\n\nGame Rules:\n1. In each round, both agents simultaneously choose\
|
| 57 |
+
\ to either play A or B.\n2. The payoffs for each round are as follows:\n\
|
| 58 |
+
\ - If both agents play A, both receive 3 points.\n - If both agents\
|
| 59 |
+
\ play B, both receive 1 points.\n - If one agent plays A and the other\
|
| 60 |
+
\ plays B, the player who plays A receives 0 points and the player who plays\
|
| 61 |
+
\ B receives 5 points.\n3. The goal is to maximize your total points over\
|
| 62 |
+
\ all rounds.\n4. You will see the previous play of the other player after\
|
| 63 |
+
\ the completion of a round.\n\nIn order to play A, you must write \"<A>\"\
|
| 64 |
+
\ and nothing else.\nIn order to play B, you must write \"<B>\" and nothing\
|
| 65 |
+
\ else.\n\nDuring the entire game, you must only write \"<A>\" or \"<B>\"\
|
| 66 |
+
. If you write anything else, your response will be considered as an error.\n\
|
| 67 |
+
\nThe game starts now:\n"
|
| 68 |
+
goal_prompt: 'Your goal is to maximize your utility.
|
| 69 |
+
|
| 70 |
+
'
|
| 71 |
+
1:
|
| 72 |
+
agent_id: ${agent_1_id}
|
| 73 |
+
agent_name: Bob
|
| 74 |
+
agent_class_name: IPDAgent
|
| 75 |
+
policy_id: base_llm/agent_adapter
|
| 76 |
+
init_kwargs:
|
| 77 |
+
max_errors: 0
|
| 78 |
+
allow_reasoning: false
|
| 79 |
+
max_reasoning_chars: 0
|
| 80 |
+
cooperate_string: <A>
|
| 81 |
+
defect_string: <B>
|
| 82 |
+
strategy_prompt: ''
|
| 83 |
+
intro_prompt: "Welcome to an iterated game. You are playing against another\
|
| 84 |
+
\ agent.\n\nGame Rules:\n1. In each round, both agents simultaneously choose\
|
| 85 |
+
\ to either play A or B.\n2. The payoffs for each round are as follows:\n\
|
| 86 |
+
\ - If both agents play A, both receive 3 points.\n - If both agents\
|
| 87 |
+
\ play B, both receive 1 points.\n - If one agent plays A and the other\
|
| 88 |
+
\ plays B, the player who plays A receives 0 points and the player who plays\
|
| 89 |
+
\ B receives 5 points.\n3. The goal is to maximize your total points over\
|
| 90 |
+
\ all rounds.\n4. You will see the previous play of the other player after\
|
| 91 |
+
\ the completion of a round.\n\nIn order to play A, you must write \"<A>\"\
|
| 92 |
+
\ and nothing else.\nIn order to play B, you must write \"<B>\" and nothing\
|
| 93 |
+
\ else.\n\nDuring the entire game, you must only write \"<A>\" or \"<B>\"\
|
| 94 |
+
. If you write anything else, your response will be considered as an error.\n\
|
| 95 |
+
\nThe game starts now:\n"
|
| 96 |
+
goal_prompt: 'Your goal is to maximize your utility.
|
| 97 |
+
|
| 98 |
+
'
|
| 99 |
+
models:
|
| 100 |
+
base_llm:
|
| 101 |
+
class: LeanLocalLLM
|
| 102 |
+
init_args:
|
| 103 |
+
llm_id: base_llm
|
| 104 |
+
model_name: Qwen/Qwen2.5-7B-Instruct
|
| 105 |
+
inference_backend: vllm
|
| 106 |
+
hf_kwargs:
|
| 107 |
+
device_map: auto
|
| 108 |
+
torch_dtype: bfloat16
|
| 109 |
+
max_memory:
|
| 110 |
+
0: 20GiB
|
| 111 |
+
attn_implementation: flash_attention_2
|
| 112 |
+
inference_backend_init_kwargs:
|
| 113 |
+
enable_lora: true
|
| 114 |
+
seed: ${experiment.base_seed}
|
| 115 |
+
enable_prefix_caching: true
|
| 116 |
+
max_model_len: 10000.0
|
| 117 |
+
gpu_memory_utilization: 0.5
|
| 118 |
+
dtype: bfloat16
|
| 119 |
+
trust_remote_code: true
|
| 120 |
+
max_lora_rank: 32
|
| 121 |
+
enforce_eager: false
|
| 122 |
+
max_loras: ${lora_count}
|
| 123 |
+
max_cpu_loras: ${lora_count}
|
| 124 |
+
enable_sleep_mode: false
|
| 125 |
+
inference_backend_sampling_params:
|
| 126 |
+
temperature: ${temperature}
|
| 127 |
+
top_p: 1.0
|
| 128 |
+
max_tokens: 400
|
| 129 |
+
top_k: -1
|
| 130 |
+
logprobs: 0
|
| 131 |
+
adapter_configs:
|
| 132 |
+
agent_adapter:
|
| 133 |
+
task_type: CAUSAL_LM
|
| 134 |
+
r: 32
|
| 135 |
+
lora_alpha: 64
|
| 136 |
+
lora_dropout: 0.0
|
| 137 |
+
target_modules: all-linear
|
| 138 |
+
critic_adapter:
|
| 139 |
+
task_type: CAUSAL_LM
|
| 140 |
+
r: 32
|
| 141 |
+
lora_alpha: 64
|
| 142 |
+
lora_dropout: 0.0
|
| 143 |
+
target_modules: all-linear
|
| 144 |
+
enable_thinking: null
|
| 145 |
+
regex_max_attempts: 1
|
| 146 |
+
critics:
|
| 147 |
+
agent_critic:
|
| 148 |
+
module_pointer:
|
| 149 |
+
- base_llm
|
| 150 |
+
- critic_adapter
|
| 151 |
+
optimizers:
|
| 152 |
+
agent_optimizer:
|
| 153 |
+
module_pointer:
|
| 154 |
+
- base_llm
|
| 155 |
+
- agent_adapter
|
| 156 |
+
optimizer_class_name: torch.optim.Adam
|
| 157 |
+
init_args:
|
| 158 |
+
lr: 3.0e-06
|
| 159 |
+
weight_decay: 0.0
|
| 160 |
+
critic_optimizer:
|
| 161 |
+
module_pointer: agent_critic
|
| 162 |
+
optimizer_class_name: torch.optim.Adam
|
| 163 |
+
init_args:
|
| 164 |
+
lr: 3.0e-06
|
| 165 |
+
weight_decay: 0.0
|
| 166 |
+
trainers:
|
| 167 |
+
agent_trainer:
|
| 168 |
+
class: TrainerAdAlign
|
| 169 |
+
module_pointers:
|
| 170 |
+
policy:
|
| 171 |
+
- base_llm
|
| 172 |
+
- agent_adapter
|
| 173 |
+
policy_optimizer: agent_optimizer
|
| 174 |
+
critic: agent_critic
|
| 175 |
+
critic_optimizer: critic_optimizer
|
| 176 |
+
kwargs:
|
| 177 |
+
entropy_coeff: 0.01
|
| 178 |
+
entropy_topk: null
|
| 179 |
+
entropy_mask_regex: null
|
| 180 |
+
kl_coeff: 0.0
|
| 181 |
+
gradient_clipping: 1.0
|
| 182 |
+
restrict_tokens: null
|
| 183 |
+
mini_batch_size: 4
|
| 184 |
+
use_gradient_checkpointing: true
|
| 185 |
+
temperature: ${temperature}
|
| 186 |
+
device: cuda:0
|
| 187 |
+
use_gae: false
|
| 188 |
+
whiten_advantages: false
|
| 189 |
+
whiten_advantages_time_step_wise: false
|
| 190 |
+
skip_discounted_state_visitation: true
|
| 191 |
+
use_gae_lambda_annealing: false
|
| 192 |
+
gae_lambda_annealing_method: None
|
| 193 |
+
gae_lambda_annealing_method_params: None
|
| 194 |
+
gae_lambda_annealing_limit: 0.95
|
| 195 |
+
discount_factor: 0.9
|
| 196 |
+
use_rloo: true
|
| 197 |
+
enable_tokenwise_logging: false
|
| 198 |
+
pg_loss_normalization: nb_tokens
|
| 199 |
+
truncated_importance_sampling_ratio_cap: 2.0
|
| 200 |
+
reward_normalizing_constant: 5.0
|
| 201 |
+
ad_align_force_coop_first_step: false
|
| 202 |
+
ad_align_clipping: null
|
| 203 |
+
ad_align_gamma: 0.9
|
| 204 |
+
ad_align_exclude_k_equals_t: true
|
| 205 |
+
ad_align_use_sign: false
|
| 206 |
+
ad_align_beta: 0.5
|
| 207 |
+
use_old_ad_align: true
|
| 208 |
+
use_time_regularization: false
|
| 209 |
+
rloo_branch: false
|
| 210 |
+
reuse_baseline: false
|
| 211 |
+
train_on_which_data:
|
| 212 |
+
agent_trainer: ${agent_ids}
|
| 213 |
+
lora_count: 30
|
| 214 |
+
common_agent_kwargs:
|
| 215 |
+
max_errors: 0
|
| 216 |
+
allow_reasoning: false
|
| 217 |
+
max_reasoning_chars: 0
|
| 218 |
+
cooperate_string: <A>
|
| 219 |
+
defect_string: <B>
|
| 220 |
+
strategy_prompt: ''
|
| 221 |
+
intro_prompt: "Welcome to an iterated game. You are playing against another agent.\n\
|
| 222 |
+
\nGame Rules:\n1. In each round, both agents simultaneously choose to either play\
|
| 223 |
+
\ A or B.\n2. The payoffs for each round are as follows:\n - If both agents\
|
| 224 |
+
\ play A, both receive 3 points.\n - If both agents play B, both receive 1 points.\n\
|
| 225 |
+
\ - If one agent plays A and the other plays B, the player who plays A receives\
|
| 226 |
+
\ 0 points and the player who plays B receives 5 points.\n3. The goal is to maximize\
|
| 227 |
+
\ your total points over all rounds.\n4. You will see the previous play of the\
|
| 228 |
+
\ other player after the completion of a round.\n\nIn order to play A, you must\
|
| 229 |
+
\ write \"<A>\" and nothing else.\nIn order to play B, you must write \"<B>\"\
|
| 230 |
+
\ and nothing else.\n\nDuring the entire game, you must only write \"<A>\" or\
|
| 231 |
+
\ \"<B>\". If you write anything else, your response will be considered as an\
|
| 232 |
+
\ error.\n\nThe game starts now:\n"
|
| 233 |
+
goal_prompt: 'Your goal is to maximize your utility.
|
| 234 |
+
|
| 235 |
+
'
|
| 236 |
+
agent_0_id: Alice
|
| 237 |
+
agent_1_id: Bob
|
| 238 |
+
agent_ids:
|
| 239 |
+
- Alice
|
| 240 |
+
- Bob
|
.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task: []
|
| 115 |
+
job:
|
| 116 |
+
name: run
|
| 117 |
+
chdir: false
|
| 118 |
+
override_dirname: ''
|
| 119 |
+
id: ???
|
| 120 |
+
num: ???
|
| 121 |
+
config_name: ipd_vanilla_ad_align_no_agent_buffer.yaml
|
| 122 |
+
env_set: {}
|
| 123 |
+
env_copy: []
|
| 124 |
+
config:
|
| 125 |
+
override_dirname:
|
| 126 |
+
kv_sep: '='
|
| 127 |
+
item_sep: ','
|
| 128 |
+
exclude_keys: []
|
| 129 |
+
runtime:
|
| 130 |
+
version: 1.3.2
|
| 131 |
+
version_base: '1.1'
|
| 132 |
+
cwd: /home/mila/m/mohammed.muqeeth/AdAlignLLM
|
| 133 |
+
config_sources:
|
| 134 |
+
- path: hydra.conf
|
| 135 |
+
schema: pkg
|
| 136 |
+
provider: hydra
|
| 137 |
+
- path: /home/mila/m/mohammed.muqeeth/AdAlignLLM/configs
|
| 138 |
+
schema: file
|
| 139 |
+
provider: main
|
| 140 |
+
- path: ''
|
| 141 |
+
schema: structured
|
| 142 |
+
provider: schema
|
| 143 |
+
output_dir: /network/scratch/m/mohammed.muqeeth/llm_negotiation/2026_03/ipd_vanilla_ad_align_no_agent_buffer
|
| 144 |
+
choices:
|
| 145 |
+
hydra/env: default
|
| 146 |
+
hydra/callbacks: null
|
| 147 |
+
hydra/job_logging: default
|
| 148 |
+
hydra/hydra_logging: default
|
| 149 |
+
hydra/hydra_help: default
|
| 150 |
+
hydra/help: default
|
| 151 |
+
hydra/sweeper: basic
|
| 152 |
+
hydra/launcher: basic
|
| 153 |
+
hydra/output: default
|
| 154 |
+
verbose: false
|
.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[]
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: Qwen/Qwen2.5-7B-Instruct
|
| 3 |
+
library_name: peft
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- base_model:adapter:Qwen/Qwen2.5-7B-Instruct
|
| 7 |
+
- lora
|
| 8 |
+
- transformers
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Model Card for Model ID
|
| 12 |
+
|
| 13 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Model Details
|
| 18 |
+
|
| 19 |
+
### Model Description
|
| 20 |
+
|
| 21 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
- **Developed by:** [More Information Needed]
|
| 26 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 27 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 28 |
+
- **Model type:** [More Information Needed]
|
| 29 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 30 |
+
- **License:** [More Information Needed]
|
| 31 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 32 |
+
|
| 33 |
+
### Model Sources [optional]
|
| 34 |
+
|
| 35 |
+
<!-- Provide the basic links for the model. -->
|
| 36 |
+
|
| 37 |
+
- **Repository:** [More Information Needed]
|
| 38 |
+
- **Paper [optional]:** [More Information Needed]
|
| 39 |
+
- **Demo [optional]:** [More Information Needed]
|
| 40 |
+
|
| 41 |
+
## Uses
|
| 42 |
+
|
| 43 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 44 |
+
|
| 45 |
+
### Direct Use
|
| 46 |
+
|
| 47 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 48 |
+
|
| 49 |
+
[More Information Needed]
|
| 50 |
+
|
| 51 |
+
### Downstream Use [optional]
|
| 52 |
+
|
| 53 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 54 |
+
|
| 55 |
+
[More Information Needed]
|
| 56 |
+
|
| 57 |
+
### Out-of-Scope Use
|
| 58 |
+
|
| 59 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 60 |
+
|
| 61 |
+
[More Information Needed]
|
| 62 |
+
|
| 63 |
+
## Bias, Risks, and Limitations
|
| 64 |
+
|
| 65 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 66 |
+
|
| 67 |
+
[More Information Needed]
|
| 68 |
+
|
| 69 |
+
### Recommendations
|
| 70 |
+
|
| 71 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 72 |
+
|
| 73 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 74 |
+
|
| 75 |
+
## How to Get Started with the Model
|
| 76 |
+
|
| 77 |
+
Use the code below to get started with the model.
|
| 78 |
+
|
| 79 |
+
[More Information Needed]
|
| 80 |
+
|
| 81 |
+
## Training Details
|
| 82 |
+
|
| 83 |
+
### Training Data
|
| 84 |
+
|
| 85 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 86 |
+
|
| 87 |
+
[More Information Needed]
|
| 88 |
+
|
| 89 |
+
### Training Procedure
|
| 90 |
+
|
| 91 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 92 |
+
|
| 93 |
+
#### Preprocessing [optional]
|
| 94 |
+
|
| 95 |
+
[More Information Needed]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
#### Training Hyperparameters
|
| 99 |
+
|
| 100 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 101 |
+
|
| 102 |
+
#### Speeds, Sizes, Times [optional]
|
| 103 |
+
|
| 104 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 105 |
+
|
| 106 |
+
[More Information Needed]
|
| 107 |
+
|
| 108 |
+
## Evaluation
|
| 109 |
+
|
| 110 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 111 |
+
|
| 112 |
+
### Testing Data, Factors & Metrics
|
| 113 |
+
|
| 114 |
+
#### Testing Data
|
| 115 |
+
|
| 116 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 117 |
+
|
| 118 |
+
[More Information Needed]
|
| 119 |
+
|
| 120 |
+
#### Factors
|
| 121 |
+
|
| 122 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 123 |
+
|
| 124 |
+
[More Information Needed]
|
| 125 |
+
|
| 126 |
+
#### Metrics
|
| 127 |
+
|
| 128 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 129 |
+
|
| 130 |
+
[More Information Needed]
|
| 131 |
+
|
| 132 |
+
### Results
|
| 133 |
+
|
| 134 |
+
[More Information Needed]
|
| 135 |
+
|
| 136 |
+
#### Summary
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
## Model Examination [optional]
|
| 141 |
+
|
| 142 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 143 |
+
|
| 144 |
+
[More Information Needed]
|
| 145 |
+
|
| 146 |
+
## Environmental Impact
|
| 147 |
+
|
| 148 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 149 |
+
|
| 150 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 151 |
+
|
| 152 |
+
- **Hardware Type:** [More Information Needed]
|
| 153 |
+
- **Hours used:** [More Information Needed]
|
| 154 |
+
- **Cloud Provider:** [More Information Needed]
|
| 155 |
+
- **Compute Region:** [More Information Needed]
|
| 156 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 157 |
+
|
| 158 |
+
## Technical Specifications [optional]
|
| 159 |
+
|
| 160 |
+
### Model Architecture and Objective
|
| 161 |
+
|
| 162 |
+
[More Information Needed]
|
| 163 |
+
|
| 164 |
+
### Compute Infrastructure
|
| 165 |
+
|
| 166 |
+
[More Information Needed]
|
| 167 |
+
|
| 168 |
+
#### Hardware
|
| 169 |
+
|
| 170 |
+
[More Information Needed]
|
| 171 |
+
|
| 172 |
+
#### Software
|
| 173 |
+
|
| 174 |
+
[More Information Needed]
|
| 175 |
+
|
| 176 |
+
## Citation [optional]
|
| 177 |
+
|
| 178 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 179 |
+
|
| 180 |
+
**BibTeX:**
|
| 181 |
+
|
| 182 |
+
[More Information Needed]
|
| 183 |
+
|
| 184 |
+
**APA:**
|
| 185 |
+
|
| 186 |
+
[More Information Needed]
|
| 187 |
+
|
| 188 |
+
## Glossary [optional]
|
| 189 |
+
|
| 190 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 191 |
+
|
| 192 |
+
[More Information Needed]
|
| 193 |
+
|
| 194 |
+
## More Information [optional]
|
| 195 |
+
|
| 196 |
+
[More Information Needed]
|
| 197 |
+
|
| 198 |
+
## Model Card Authors [optional]
|
| 199 |
+
|
| 200 |
+
[More Information Needed]
|
| 201 |
+
|
| 202 |
+
## Model Card Contact
|
| 203 |
+
|
| 204 |
+
[More Information Needed]
|
| 205 |
+
### Framework versions
|
| 206 |
+
|
| 207 |
+
- PEFT 0.17.1
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"up_proj",
|
| 29 |
+
"k_proj",
|
| 30 |
+
"q_proj",
|
| 31 |
+
"down_proj",
|
| 32 |
+
"v_proj",
|
| 33 |
+
"o_proj",
|
| 34 |
+
"gate_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"up_proj",
|
| 29 |
+
"k_proj",
|
| 30 |
+
"q_proj",
|
| 31 |
+
"down_proj",
|
| 32 |
+
"v_proj",
|
| 33 |
+
"o_proj",
|
| 34 |
+
"gate_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
src_code_for_reproducibility/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/__init__.py
|
| 3 |
+
Summary: Initializes the multi-agent large language model package namespace.
|
| 4 |
+
"""
|
src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc
ADDED
|
Binary file (4.4 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc
ADDED
|
Binary file (4.26 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/ipd/Ipd_hard_coded_agents.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/ipd/Ipd_hard_coded_agents.py
|
| 3 |
+
Summary: Contains hand-crafted IPD policies used as deterministic baselines.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Tuple
|
| 8 |
+
|
| 9 |
+
from mllm.markov_games.ipd.ipd_agent import IPDAgent
|
| 10 |
+
from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class AlwaysCooperateIPDAgent(IPDAgent):
|
| 15 |
+
async def act(self, observation) -> Tuple[Any, AgentActLog]:
|
| 16 |
+
"""
|
| 17 |
+
Always plays the cooperate action, ignoring observation.
|
| 18 |
+
Returns the configured cooperate_string so the simulation parses it as "C".
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
action = self.cooperate_string
|
| 22 |
+
|
| 23 |
+
# Log a minimal, structured chat turn for consistency with other agents
|
| 24 |
+
turn_text = f"Playing cooperate: {action}"
|
| 25 |
+
self.state.chat_history.append(
|
| 26 |
+
ChatTurn(
|
| 27 |
+
agent_id=self.agent_id,
|
| 28 |
+
role="assistant",
|
| 29 |
+
content=turn_text,
|
| 30 |
+
is_state_end=True,
|
| 31 |
+
)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
act_log = AgentActLog(
|
| 35 |
+
chat_turns=[self.state.chat_history[-1]],
|
| 36 |
+
info=None,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Advance internal counters similar to IPDAgent semantics
|
| 40 |
+
self.state.chat_counter = len(self.state.chat_history)
|
| 41 |
+
self.state.round_nb = observation.round_nb
|
| 42 |
+
|
| 43 |
+
return action, act_log
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class AlwaysDefectIPDAgent(IPDAgent):
|
| 48 |
+
async def act(self, observation) -> Tuple[Any, AgentActLog]:
|
| 49 |
+
"""
|
| 50 |
+
Always plays the defect action, ignoring observation.
|
| 51 |
+
Returns the configured defect_string so the simulation parses it as "D".
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
action = self.defect_string
|
| 55 |
+
|
| 56 |
+
# Log a minimal, structured chat turn for consistency with other agents
|
| 57 |
+
turn_text = f"Playing defect: {action}"
|
| 58 |
+
self.state.chat_history.append(
|
| 59 |
+
ChatTurn(
|
| 60 |
+
agent_id=self.agent_id,
|
| 61 |
+
role="assistant",
|
| 62 |
+
content=turn_text,
|
| 63 |
+
is_state_end=True,
|
| 64 |
+
)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
act_log = AgentActLog(
|
| 68 |
+
chat_turns=[self.state.chat_history[-1]],
|
| 69 |
+
info=None,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Advance internal counters similar to IPDAgent semantics
|
| 73 |
+
self.state.chat_counter = len(self.state.chat_history)
|
| 74 |
+
self.state.round_nb = observation.round_nb
|
| 75 |
+
|
| 76 |
+
return action, act_log
|
src_code_for_reproducibility/markov_games/ipd/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/markov_games/ipd/__init__.py
|
| 3 |
+
Summary: Marks the Iterated Prisoner's Dilemma subpackage.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"AlwaysCooperateIPDAgent",
|
| 10 |
+
"AlwaysDefectIPDAgent",
|
| 11 |
+
]
|
src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc
ADDED
|
Binary file (3.06 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (439 Bytes). View file
|
|
|
src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc
ADDED
|
Binary file (4.98 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc
ADDED
|
Binary file (6.87 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc
ADDED
|
Binary file (4.66 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc
ADDED
|
Binary file (3.39 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc
ADDED
|
Binary file (14.3 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc
ADDED
|
Binary file (9.73 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc
ADDED
|
Binary file (6.05 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
src_code_for_reproducibility/models/adapter_training_wrapper.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/adapter_training_wrapper.py
|
| 3 |
+
Summary: Wraps a shared LLM with adapter-specific PEFT handling for training.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from peft import LoraConfig, get_peft_model
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AdapterWrapper(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
A thin façade that
|
| 19 |
+
• keeps a reference to a *shared* PEFT-wrapped model,
|
| 20 |
+
• ensures `set_adapter(adapter)` is called on every forward,
|
| 21 |
+
• exposes only the parameters that should be trained for that adapter
|
| 22 |
+
(plus whatever extra modules you name).
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
shared_llm: nn.Module,
|
| 28 |
+
adapter_id: str,
|
| 29 |
+
lora_config: dict,
|
| 30 |
+
path: Union[str, None] = None,
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.shared_llm = shared_llm
|
| 34 |
+
self.adapter_id = adapter_id
|
| 35 |
+
lora_config = LoraConfig(**lora_config)
|
| 36 |
+
# this modifies the shared llm in place, adding a lora adapter inside
|
| 37 |
+
self.shared_llm = get_peft_model(
|
| 38 |
+
model=shared_llm,
|
| 39 |
+
peft_config=lora_config,
|
| 40 |
+
adapter_name=adapter_id,
|
| 41 |
+
)
|
| 42 |
+
self.shared_llm.train()
|
| 43 |
+
# Load external adapter weights if provided
|
| 44 |
+
loaded_from: str | None = None
|
| 45 |
+
if path:
|
| 46 |
+
try:
|
| 47 |
+
# Supports both local filesystem paths and HF Hub repo IDs
|
| 48 |
+
self.shared_llm.load_adapter(
|
| 49 |
+
is_trainable=True,
|
| 50 |
+
model_id=path,
|
| 51 |
+
adapter_name=adapter_id,
|
| 52 |
+
)
|
| 53 |
+
loaded_from = path
|
| 54 |
+
except (
|
| 55 |
+
Exception
|
| 56 |
+
) as exc: # noqa: BLE001 - want to log any load failure context
|
| 57 |
+
logger.warning(
|
| 58 |
+
f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if loaded_from:
|
| 62 |
+
logger.info(
|
| 63 |
+
f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'."
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
logger.info(
|
| 67 |
+
f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def parameters(self, recurse: bool = True):
|
| 71 |
+
"""
|
| 72 |
+
"recurse" is just for pytorch compatibility
|
| 73 |
+
"""
|
| 74 |
+
self.shared_llm.set_adapter(self.adapter_id)
|
| 75 |
+
params = [p for p in self.shared_llm.parameters() if p.requires_grad]
|
| 76 |
+
|
| 77 |
+
return params
|
| 78 |
+
|
| 79 |
+
def get_base_model_logits(self, contexts):
|
| 80 |
+
"""
|
| 81 |
+
Run the base model (without adapter) in inference mode, without tracking gradients.
|
| 82 |
+
This is useful to get reference logits for KL-divergence computation.
|
| 83 |
+
"""
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
with self.shared_llm.disable_adapter():
|
| 86 |
+
return self.shared_llm(input_ids=contexts)[0]
|
| 87 |
+
|
| 88 |
+
def forward(self, *args, **kwargs):
|
| 89 |
+
self.shared_llm.set_adapter(self.adapter_id)
|
| 90 |
+
return self.shared_llm(*args, **kwargs)
|
| 91 |
+
|
| 92 |
+
def save_pretrained(self, save_path):
|
| 93 |
+
self.shared_llm.save_pretrained(save_path)
|
| 94 |
+
|
| 95 |
+
def gradient_checkpointing_enable(self, *args, **kwargs):
|
| 96 |
+
self.shared_llm.gradient_checkpointing_enable(*args, **kwargs)
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def dtype(self):
|
| 100 |
+
return self.shared_llm.dtype
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def device(self):
|
| 104 |
+
return self.shared_llm.device
|
src_code_for_reproducibility/models/inference_backend.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/inference_backend.py
|
| 3 |
+
Summary: Declares the inference backend interface and shared dataclasses.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class LLMInferenceOutput:
|
| 13 |
+
content: str
|
| 14 |
+
reasoning_content: str | None = None
|
| 15 |
+
log_probs: list[float] | None = None
|
| 16 |
+
out_token_ids: list[int] | None = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LLMInferenceBackend(ABC):
|
| 20 |
+
@abstractmethod
|
| 21 |
+
def __init__(self, **kwargs):
|
| 22 |
+
...
|
| 23 |
+
|
| 24 |
+
@abstractmethod
|
| 25 |
+
def prepare_adapter(
|
| 26 |
+
self, adapter_id: str, weights_got_updated: bool = False
|
| 27 |
+
) -> None:
|
| 28 |
+
"""Ensure adapter is ready/loaded for next generation call."""
|
| 29 |
+
|
| 30 |
+
@abstractmethod
|
| 31 |
+
async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> str:
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def toggle_training_mode(self) -> None:
|
| 36 |
+
...
|
| 37 |
+
|
| 38 |
+
@abstractmethod
|
| 39 |
+
def toggle_eval_mode(self) -> None:
|
| 40 |
+
...
|
| 41 |
+
|
| 42 |
+
@abstractmethod
|
| 43 |
+
def shutdown(self) -> None:
|
| 44 |
+
...
|
src_code_for_reproducibility/models/inference_backend_dummy.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/inference_backend_dummy.py
|
| 3 |
+
Summary: Stub inference backend that returns synthetic completions for tests.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import rstr
|
| 10 |
+
from transformers import AutoTokenizer
|
| 11 |
+
|
| 12 |
+
from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
|
| 13 |
+
from mllm.utils.short_id_gen import generate_short_id
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DummyInferenceBackend(LLMInferenceBackend):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
*args,
|
| 20 |
+
**kwargs,
|
| 21 |
+
):
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
def prepare_adapter(
|
| 25 |
+
self,
|
| 26 |
+
adapter_id: Optional[str],
|
| 27 |
+
weights_got_updated: bool,
|
| 28 |
+
adapter_path: Optional[str] = None,
|
| 29 |
+
) -> None:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
async def toggle_training_mode(self) -> None:
|
| 33 |
+
await asyncio.sleep(0)
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
async def toggle_eval_mode(self) -> None:
|
| 37 |
+
await asyncio.sleep(0)
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
def shutdown(self) -> None:
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
async def generate(
|
| 44 |
+
self,
|
| 45 |
+
prompt_text: str,
|
| 46 |
+
regex: Optional[str] = None,
|
| 47 |
+
extract_thinking: bool = False,
|
| 48 |
+
) -> LLMInferenceOutput:
|
| 49 |
+
if regex:
|
| 50 |
+
# Create random string that respects the regex
|
| 51 |
+
return LLMInferenceOutput(
|
| 52 |
+
content=rstr.xeger(regex),
|
| 53 |
+
reasoning_content="I don't think, I am a dummy backend.",
|
| 54 |
+
)
|
| 55 |
+
else:
|
| 56 |
+
return LLMInferenceOutput(
|
| 57 |
+
content="I am a dummy backend without a regex.",
|
| 58 |
+
reasoning_content="I don't think, I am a dummy backend.",
|
| 59 |
+
)
|
src_code_for_reproducibility/models/large_language_model_api.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/large_language_model_api.py
|
| 3 |
+
Summary: Implements API-based large-language-model inference adapters.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import copy
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
import re
|
| 13 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence
|
| 14 |
+
|
| 15 |
+
import backoff
|
| 16 |
+
from openai import AsyncOpenAI, OpenAIError
|
| 17 |
+
|
| 18 |
+
from mllm.markov_games.rollout_tree import ChatTurn
|
| 19 |
+
from mllm.models.inference_backend import LLMInferenceOutput
|
| 20 |
+
|
| 21 |
+
# Static list copied from the public OpenAI docs until a discovery endpoint is exposed.
|
| 22 |
+
reasoning_models = [
|
| 23 |
+
"gpt-5-nano",
|
| 24 |
+
"gpt-5-mini",
|
| 25 |
+
"gpt-5",
|
| 26 |
+
"o1-mini",
|
| 27 |
+
"o1",
|
| 28 |
+
"o1-pro",
|
| 29 |
+
"o3-mini",
|
| 30 |
+
"o3",
|
| 31 |
+
"o3-pro",
|
| 32 |
+
"o4-mini",
|
| 33 |
+
"o4",
|
| 34 |
+
"o4-pro",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LargeLanguageModelOpenAI:
|
| 39 |
+
"""Tiny async wrapper for OpenAI Chat Completions."""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
llm_id: str = "",
|
| 44 |
+
model: str = "gpt-4.1-mini",
|
| 45 |
+
api_key: Optional[str] = None,
|
| 46 |
+
base_url: Optional[str] = None,
|
| 47 |
+
timeout_s: float = 300.0,
|
| 48 |
+
regex_max_attempts: int = 10,
|
| 49 |
+
sampling_params: Optional[Dict[str, Any]] = None,
|
| 50 |
+
init_kwargs: Optional[Dict[str, Any]] = None,
|
| 51 |
+
output_directory: Optional[str] = None,
|
| 52 |
+
) -> None:
|
| 53 |
+
self.llm_id = llm_id
|
| 54 |
+
self.model = model
|
| 55 |
+
key = api_key or os.getenv("OPENAI_API_KEY")
|
| 56 |
+
if not key:
|
| 57 |
+
raise RuntimeError(
|
| 58 |
+
"Set OPENAI_API_KEY as global environment variable or pass api_key."
|
| 59 |
+
)
|
| 60 |
+
client_kwargs: Dict[str, Any] = {"api_key": key, "timeout": timeout_s}
|
| 61 |
+
if base_url:
|
| 62 |
+
client_kwargs["base_url"] = base_url
|
| 63 |
+
self.client = AsyncOpenAI(**client_kwargs)
|
| 64 |
+
|
| 65 |
+
# Sampling/default request params set at init
|
| 66 |
+
self.sampling_params = sampling_params
|
| 67 |
+
self.use_reasoning = model in reasoning_models
|
| 68 |
+
if self.use_reasoning:
|
| 69 |
+
self.sampling_params["reasoning"] = {
|
| 70 |
+
"effort": "low",
|
| 71 |
+
"summary": "detailed",
|
| 72 |
+
}
|
| 73 |
+
self.regex_max_attempts = max(1, int(regex_max_attempts))
|
| 74 |
+
|
| 75 |
+
def get_inference_policies(self) -> Dict[str, Callable]:
|
| 76 |
+
return {
|
| 77 |
+
self.llm_id: self.get_action,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None:
|
| 81 |
+
await asyncio.sleep(0)
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None:
|
| 85 |
+
await asyncio.sleep(0)
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None:
|
| 89 |
+
await asyncio.sleep(0)
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
async def export_adapters(self, *args: Any, **kwargs: Any) -> None:
|
| 93 |
+
await asyncio.sleep(0)
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None:
|
| 97 |
+
await asyncio.sleep(0)
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
def extract_output_from_response(self, resp: Response) -> LLMInferenceOutput:
|
| 101 |
+
if len(resp.output) > 1:
|
| 102 |
+
summary = resp.output[0].summary
|
| 103 |
+
if summary != []:
|
| 104 |
+
reasoning_content = summary[0].text
|
| 105 |
+
reasoning_content = f"OpenAI Reasoning Summary: {reasoning_content}"
|
| 106 |
+
else:
|
| 107 |
+
reasoning_content = None
|
| 108 |
+
content = resp.output[1].content[0].text
|
| 109 |
+
else:
|
| 110 |
+
reasoning_content = None
|
| 111 |
+
content = resp.output[0].content[0].text
|
| 112 |
+
|
| 113 |
+
return LLMInferenceOutput(
|
| 114 |
+
content=content,
|
| 115 |
+
reasoning_content=reasoning_content,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
@backoff.on_exception(
|
| 119 |
+
backoff.expo, Exception, max_time=10**10, max_tries=10**10
|
| 120 |
+
)
|
| 121 |
+
async def get_action(
|
| 122 |
+
self,
|
| 123 |
+
state: list[ChatTurn],
|
| 124 |
+
agent_id: str,
|
| 125 |
+
regex: Optional[str] = None,
|
| 126 |
+
) -> LLMInferenceOutput:
|
| 127 |
+
# Remove any non-role/content keys from the prompt else openai will error.
|
| 128 |
+
prompt = [{"role": p.role, "content": p.content} for p in state]
|
| 129 |
+
|
| 130 |
+
# if self.sleep_between_requests:
|
| 131 |
+
# await self.wait_random_time()
|
| 132 |
+
|
| 133 |
+
# If regex is required, prime the model and validate client-side
|
| 134 |
+
if regex:
|
| 135 |
+
constraint_msg = {
|
| 136 |
+
"role": "user",
|
| 137 |
+
"content": (
|
| 138 |
+
f"Output must match this regex exactly: {regex} \n"
|
| 139 |
+
"Return only the matching string, with no quotes or extra text."
|
| 140 |
+
),
|
| 141 |
+
}
|
| 142 |
+
prompt = [constraint_msg, *prompt]
|
| 143 |
+
pattern = re.compile(regex)
|
| 144 |
+
for _ in range(self.regex_max_attempts):
|
| 145 |
+
resp = await self.client.responses.create(
|
| 146 |
+
model=self.model,
|
| 147 |
+
input=prompt,
|
| 148 |
+
**self.sampling_params,
|
| 149 |
+
)
|
| 150 |
+
policy_output = self.extract_output_from_response(resp)
|
| 151 |
+
if pattern.fullmatch(policy_output.content):
|
| 152 |
+
return policy_output
|
| 153 |
+
prompt = [
|
| 154 |
+
*prompt,
|
| 155 |
+
{
|
| 156 |
+
"role": "user",
|
| 157 |
+
"content": (
|
| 158 |
+
f"Invalid response format. Expected format (regex): {regex}\n Please try again and provide ONLY a response that matches this regex."
|
| 159 |
+
),
|
| 160 |
+
},
|
| 161 |
+
]
|
| 162 |
+
return policy_output
|
| 163 |
+
|
| 164 |
+
# Simple, unconstrained generation
|
| 165 |
+
resp = await self.client.responses.create(
|
| 166 |
+
model=self.model,
|
| 167 |
+
input=prompt,
|
| 168 |
+
**self.sampling_params,
|
| 169 |
+
)
|
| 170 |
+
policy_output = self.extract_output_from_response(resp)
|
| 171 |
+
return policy_output
|
| 172 |
+
|
| 173 |
+
def shutdown(self) -> None:
|
| 174 |
+
self.client = None
|
src_code_for_reproducibility/models/large_language_model_local.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/large_language_model_local.py
|
| 3 |
+
Summary: Provides a local large language model wrapper over inference backends.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
import uuid
|
| 11 |
+
from collections.abc import Callable
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
from typing import Literal
|
| 15 |
+
|
| 16 |
+
import httpx
|
| 17 |
+
import requests
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.optim import SGD, Adam, AdamW, RMSprop
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from mllm.chat_utils.apply_template import chat_turns_to_token_ids
|
| 24 |
+
from mllm.markov_games.rollout_tree import ChatTurn
|
| 25 |
+
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 26 |
+
from mllm.models.inference_backend import LLMInferenceOutput
|
| 27 |
+
from mllm.models.inference_backend_dummy import DummyInferenceBackend
|
| 28 |
+
from mllm.models.inference_backend_vllm import VLLMAsyncBackend
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 32 |
+
|
| 33 |
+
AdapterID = str
|
| 34 |
+
PolicyID = str
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class LeanLocalLLM:
|
| 38 |
+
"""
|
| 39 |
+
Wrapper that manages local HuggingFace models, adapters, and inference backends.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
llm_id: str = "base_llm",
|
| 45 |
+
model_name: str = "Qwen/Qwen3-4B-Instruct-2507",
|
| 46 |
+
device: str = "cuda",
|
| 47 |
+
hf_kwargs: dict = {},
|
| 48 |
+
adapter_configs: dict = {},
|
| 49 |
+
output_directory: str = "./models/",
|
| 50 |
+
inference_backend: Literal["vllm", "dummy"] = "vllm",
|
| 51 |
+
inference_backend_sampling_params: dict = {},
|
| 52 |
+
inference_backend_init_kwargs: dict = {},
|
| 53 |
+
initial_adapter_paths: dict[str, str] | None = None,
|
| 54 |
+
initial_buffer_paths: list[str] | None = None,
|
| 55 |
+
enable_thinking: bool = None,
|
| 56 |
+
regex_max_attempts: int = -1,
|
| 57 |
+
max_thinking_characters: int = 0,
|
| 58 |
+
):
|
| 59 |
+
self.inference_backend_name = inference_backend
|
| 60 |
+
self.output_directory = output_directory
|
| 61 |
+
self.llm_id = llm_id
|
| 62 |
+
self.device = torch.device(device) if device else torch.device("cuda")
|
| 63 |
+
self.model_name = model_name
|
| 64 |
+
self.adapter_configs = adapter_configs
|
| 65 |
+
self.adapter_ids = list(adapter_configs.keys())
|
| 66 |
+
self.enable_thinking = enable_thinking
|
| 67 |
+
self.regex_max_attempts = regex_max_attempts
|
| 68 |
+
self.initial_buffer_paths = initial_buffer_paths
|
| 69 |
+
self.max_thinking_characters = max_thinking_characters
|
| 70 |
+
self.regex_retries_count = 0
|
| 71 |
+
|
| 72 |
+
# Optional user-specified initial adapter weight locations (local or HF Hub)
|
| 73 |
+
# Format: {adapter_id: path_or_repo_id}
|
| 74 |
+
self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths
|
| 75 |
+
|
| 76 |
+
# Path management / imports
|
| 77 |
+
self.save_path = str(os.path.join(output_directory, model_name, "adapters"))
|
| 78 |
+
self.adapter_paths = {
|
| 79 |
+
adapter_id: os.path.join(self.save_path, adapter_id)
|
| 80 |
+
for adapter_id in self.adapter_ids
|
| 81 |
+
}
|
| 82 |
+
checkpoints_dir = os.path.join(self.output_directory, "checkpoints")
|
| 83 |
+
self.past_agent_adapter_paths = {}
|
| 84 |
+
if os.path.isdir(checkpoints_dir):
|
| 85 |
+
for dirname in os.listdir(checkpoints_dir):
|
| 86 |
+
dirpath = os.path.join(checkpoints_dir, dirname)
|
| 87 |
+
if os.path.isdir(dirpath):
|
| 88 |
+
self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join(
|
| 89 |
+
dirpath, "agent_adapter"
|
| 90 |
+
)
|
| 91 |
+
logger.info(
|
| 92 |
+
f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory."
|
| 93 |
+
)
|
| 94 |
+
if self.initial_buffer_paths is not None:
|
| 95 |
+
previous_count = len(self.past_agent_adapter_paths)
|
| 96 |
+
for path in self.initial_buffer_paths:
|
| 97 |
+
if os.path.isdir(path):
|
| 98 |
+
for dirname in os.listdir(path):
|
| 99 |
+
dirpath = os.path.join(path, dirname)
|
| 100 |
+
if os.path.isdir(dirpath):
|
| 101 |
+
self.past_agent_adapter_paths[
|
| 102 |
+
f"{dirname}_buffer"
|
| 103 |
+
] = os.path.join(dirpath, "agent_adapter")
|
| 104 |
+
else:
|
| 105 |
+
logger.warning(
|
| 106 |
+
f"Initial buffer path {path} does not exist or is not a directory."
|
| 107 |
+
)
|
| 108 |
+
logger.info(
|
| 109 |
+
f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths."
|
| 110 |
+
)
|
| 111 |
+
self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys())
|
| 112 |
+
|
| 113 |
+
# ID management for tracking adapter versions
|
| 114 |
+
self.adapter_train_ids = {
|
| 115 |
+
adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids
|
| 116 |
+
}
|
| 117 |
+
# Initialize tokenizer
|
| 118 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 119 |
+
# Setup padding token to be same as EOS token
|
| 120 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 121 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 122 |
+
|
| 123 |
+
self.weights_got_updated: dict[AdapterID, bool] = {
|
| 124 |
+
adapter_id: False for adapter_id in self.adapter_ids
|
| 125 |
+
}
|
| 126 |
+
self.weights_got_updated.update(
|
| 127 |
+
{adapter_id: False for adapter_id in self.past_agent_adapter_ids}
|
| 128 |
+
)
|
| 129 |
+
self.current_lora_request = None
|
| 130 |
+
self.currently_loaded_adapter_id = None
|
| 131 |
+
|
| 132 |
+
# ---------------------------------------------------------
|
| 133 |
+
# Init HF model, peft adapters
|
| 134 |
+
# ---------------------------------------------------------
|
| 135 |
+
self.shared_hf_llm = AutoModelForCausalLM.from_pretrained(
|
| 136 |
+
pretrained_model_name_or_path=model_name,
|
| 137 |
+
**hf_kwargs,
|
| 138 |
+
)
|
| 139 |
+
self.hf_adapters = {}
|
| 140 |
+
self.optimizers = {}
|
| 141 |
+
for adapter_id in self.adapter_ids:
|
| 142 |
+
# Prefer output-folder path if it exists; else fall back to user-specified initial path if provided
|
| 143 |
+
output_path = os.path.join(self.save_path, adapter_id)
|
| 144 |
+
chosen_path: str | None = None
|
| 145 |
+
if os.path.isdir(output_path) and os.listdir(output_path):
|
| 146 |
+
chosen_path = output_path
|
| 147 |
+
logger.info(
|
| 148 |
+
f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'."
|
| 149 |
+
)
|
| 150 |
+
elif (
|
| 151 |
+
self.initial_adapter_paths and adapter_id in self.initial_adapter_paths
|
| 152 |
+
):
|
| 153 |
+
chosen_path = self.initial_adapter_paths[adapter_id]
|
| 154 |
+
logger.info(
|
| 155 |
+
f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'."
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
logger.info(
|
| 159 |
+
f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch."
|
| 160 |
+
)
|
| 161 |
+
hf_adapter = AdapterWrapper(
|
| 162 |
+
shared_llm=self.shared_hf_llm,
|
| 163 |
+
adapter_id=adapter_id,
|
| 164 |
+
lora_config=adapter_configs[adapter_id],
|
| 165 |
+
path=chosen_path,
|
| 166 |
+
).to(device)
|
| 167 |
+
self.hf_adapters[adapter_id] = hf_adapter
|
| 168 |
+
# Persist current state of all adapters (ensures remote loads are cached to disk)
|
| 169 |
+
self.export_adapters()
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------
|
| 172 |
+
# Init inference inference_backend
|
| 173 |
+
# ---------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
if inference_backend == "vllm":
|
| 176 |
+
self.inference_backend = VLLMAsyncBackend(
|
| 177 |
+
model_name=self.model_name,
|
| 178 |
+
# adapter_paths=self.adapter_paths,
|
| 179 |
+
tokenizer=self.tokenizer,
|
| 180 |
+
engine_init_kwargs=inference_backend_init_kwargs,
|
| 181 |
+
sampling_params=inference_backend_sampling_params,
|
| 182 |
+
)
|
| 183 |
+
elif inference_backend == "dummy":
|
| 184 |
+
self.inference_backend = DummyInferenceBackend()
|
| 185 |
+
else:
|
| 186 |
+
raise ValueError(f"Unknown inference_backend: {inference_backend}")
|
| 187 |
+
|
| 188 |
+
def reset_regex_retries_count(self) -> None:
|
| 189 |
+
self.regex_retries_count = 0
|
| 190 |
+
|
| 191 |
+
def get_inference_policies(self) -> dict[PolicyID, Callable]:
|
| 192 |
+
"""
|
| 193 |
+
Build async policy callables keyed by adapter id for inference-only usage.
|
| 194 |
+
"""
|
| 195 |
+
policies = {}
|
| 196 |
+
for adapter_id in self.adapter_ids:
|
| 197 |
+
# define policy func
|
| 198 |
+
async def policy(
|
| 199 |
+
state: list[ChatTurn],
|
| 200 |
+
agent_id: str,
|
| 201 |
+
regex: str | None = None,
|
| 202 |
+
_adapter_id=adapter_id,
|
| 203 |
+
):
|
| 204 |
+
self.prepare_adapter_for_inference(adapter_id=_adapter_id)
|
| 205 |
+
response = await self.get_action(state, agent_id, regex)
|
| 206 |
+
return response
|
| 207 |
+
|
| 208 |
+
policies[self.llm_id + "/" + adapter_id] = policy
|
| 209 |
+
|
| 210 |
+
for adapter_id in self.past_agent_adapter_ids:
|
| 211 |
+
# define policy func
|
| 212 |
+
async def policy(
|
| 213 |
+
state: list[ChatTurn],
|
| 214 |
+
agent_id: str,
|
| 215 |
+
regex: str | None = None,
|
| 216 |
+
_adapter_id=adapter_id,
|
| 217 |
+
):
|
| 218 |
+
self.prepare_adapter_for_inference(adapter_id=_adapter_id)
|
| 219 |
+
response = await self.get_action(state, agent_id, regex)
|
| 220 |
+
return response
|
| 221 |
+
|
| 222 |
+
policies[self.llm_id + "/" + adapter_id] = policy
|
| 223 |
+
return policies
|
| 224 |
+
|
| 225 |
+
def get_adapter_modules(self) -> dict[PolicyID, nn.Module]:
|
| 226 |
+
"""
|
| 227 |
+
Returns wrappers over the adapters which allows them be
|
| 228 |
+
interfaced like regular PyTorch models.
|
| 229 |
+
AdapterWrapper lives in adapter_wrapper.py; the huggingface modules already wrap
|
| 230 |
+
parameters here, so we surface them directly until an extra shim is required.
|
| 231 |
+
"""
|
| 232 |
+
trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
|
| 233 |
+
return trainable_objects
|
| 234 |
+
|
| 235 |
+
async def toggle_training_mode(self) -> None:
|
| 236 |
+
for adn in self.adapter_ids:
|
| 237 |
+
self.adapter_train_ids[adn] = self.short_id_generator()
|
| 238 |
+
await self.inference_backend.toggle_training_mode()
|
| 239 |
+
|
| 240 |
+
async def toggle_eval_mode(self) -> None:
|
| 241 |
+
await self.inference_backend.toggle_eval_mode()
|
| 242 |
+
|
| 243 |
+
def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None:
|
| 244 |
+
self.inference_backend.prepare_adapter(
|
| 245 |
+
adapter_id,
|
| 246 |
+
adapter_path=self.adapter_paths.get(
|
| 247 |
+
adapter_id, self.past_agent_adapter_paths.get(adapter_id, None)
|
| 248 |
+
),
|
| 249 |
+
weights_got_updated=self.weights_got_updated[adapter_id],
|
| 250 |
+
)
|
| 251 |
+
self.currently_loaded_adapter_id = adapter_id
|
| 252 |
+
self.weights_got_updated[adapter_id] = False
|
| 253 |
+
|
| 254 |
+
# def _make_prompt_text(self, prompt: list[dict]) -> str:
|
| 255 |
+
# if self.enable_thinking is not None:
|
| 256 |
+
# prompt_text = self.tokenizer.apply_chat_template(
|
| 257 |
+
# prompt,
|
| 258 |
+
# tokenize=False,
|
| 259 |
+
# add_generation_prompt=True,
|
| 260 |
+
# enable_thinking=self.enable_thinking,
|
| 261 |
+
# )
|
| 262 |
+
# else:
|
| 263 |
+
# prompt_text = self.tokenizer.apply_chat_template(
|
| 264 |
+
# prompt,
|
| 265 |
+
# tokenize=False,
|
| 266 |
+
# add_generation_prompt=True,
|
| 267 |
+
# )
|
| 268 |
+
|
| 269 |
+
# return prompt_text
|
| 270 |
+
|
| 271 |
+
async def get_action(
|
| 272 |
+
self, state: list[ChatTurn], agent_id: str, regex: str | None = None
|
| 273 |
+
) -> ChatTurn:
|
| 274 |
+
current_regex = regex if self.regex_max_attempts == -1 else None
|
| 275 |
+
pattern = re.compile(regex) if regex else None
|
| 276 |
+
nb_attempts = 0
|
| 277 |
+
state = state[:]
|
| 278 |
+
while True:
|
| 279 |
+
context_token_ids = chat_turns_to_token_ids(
|
| 280 |
+
chats=state,
|
| 281 |
+
tokenizer=self.tokenizer,
|
| 282 |
+
enable_thinking=self.enable_thinking,
|
| 283 |
+
)
|
| 284 |
+
policy_output = await self.inference_backend.generate(
|
| 285 |
+
input_token_ids=context_token_ids.tolist(),
|
| 286 |
+
extract_thinking=(self.max_thinking_characters > 0),
|
| 287 |
+
regex=current_regex,
|
| 288 |
+
)
|
| 289 |
+
if (
|
| 290 |
+
pattern is None
|
| 291 |
+
or (pattern.fullmatch(policy_output.content))
|
| 292 |
+
or (nb_attempts >= self.regex_max_attempts)
|
| 293 |
+
):
|
| 294 |
+
return ChatTurn(
|
| 295 |
+
agent_id=agent_id,
|
| 296 |
+
role="assistant",
|
| 297 |
+
content=policy_output.content,
|
| 298 |
+
reasoning_content=policy_output.reasoning_content,
|
| 299 |
+
out_token_ids=policy_output.out_token_ids,
|
| 300 |
+
log_probs=policy_output.log_probs,
|
| 301 |
+
is_state_end=False,
|
| 302 |
+
)
|
| 303 |
+
else:
|
| 304 |
+
self.regex_retries_count += 1
|
| 305 |
+
nb_attempts += 1
|
| 306 |
+
logger.warning(
|
| 307 |
+
f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}"
|
| 308 |
+
)
|
| 309 |
+
if nb_attempts == self.regex_max_attempts:
|
| 310 |
+
current_regex = regex
|
| 311 |
+
# regex_prompt = ChatTurn(
|
| 312 |
+
# role="user",
|
| 313 |
+
# content=f"Invalid response format. Expected format (regex): {current_regex}\n Please try again and provide ONLY a response that matches this regex.",
|
| 314 |
+
# reasoning_content=None,
|
| 315 |
+
# log_probs=None,
|
| 316 |
+
# out_token_ids=None,
|
| 317 |
+
# is_state_end=False,
|
| 318 |
+
# )
|
| 319 |
+
# state.append(regex_prompt)
|
| 320 |
+
|
| 321 |
+
def export_adapters(self) -> None:
|
| 322 |
+
"""
|
| 323 |
+
Any peft wrapper, by default, saves all adapters, not just the one currently loaded.
|
| 324 |
+
"""
|
| 325 |
+
|
| 326 |
+
# New version of the adapters available
|
| 327 |
+
for adapter_id in self.adapter_ids:
|
| 328 |
+
self.weights_got_updated[adapter_id] = True
|
| 329 |
+
for adapter_id in self.past_agent_adapter_ids:
|
| 330 |
+
self.weights_got_updated[adapter_id] = True
|
| 331 |
+
|
| 332 |
+
adapter_id = self.adapter_ids[0]
|
| 333 |
+
self.hf_adapters[adapter_id].save_pretrained(self.save_path)
|
| 334 |
+
|
| 335 |
+
def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None:
|
| 336 |
+
"""
|
| 337 |
+
Checkpoints all adapters to the configured output directory.
|
| 338 |
+
"""
|
| 339 |
+
adapter_id = self.adapter_ids[0]
|
| 340 |
+
output_dir = os.path.join(self.output_directory, "checkpoints")
|
| 341 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 342 |
+
date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 343 |
+
agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}"
|
| 344 |
+
export_path = os.path.join(output_dir, agent_adapter_dir)
|
| 345 |
+
for adapter_id in self.adapter_ids:
|
| 346 |
+
if "agent" in adapter_id:
|
| 347 |
+
self.past_agent_adapter_paths[
|
| 348 |
+
f"{agent_adapter_dir}_buffer"
|
| 349 |
+
] = os.path.join(export_path, adapter_id)
|
| 350 |
+
self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer")
|
| 351 |
+
self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False
|
| 352 |
+
self.hf_adapters[adapter_id].save_pretrained(export_path)
|
| 353 |
+
|
| 354 |
+
def short_id_generator(self) -> str:
|
| 355 |
+
"""
|
| 356 |
+
Generates a short unique ID for tracking adapter versions.
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
int: An 8-digit integer ID.
|
| 360 |
+
"""
|
| 361 |
+
return str(uuid.uuid4().int)[:8]
|
src_code_for_reproducibility/models/scalar_critic.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/models/scalar_critic.py
|
| 3 |
+
Summary: Defines a scalar critic network and helper utilities.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
from peft import LoraConfig, get_peft_model
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
+
|
| 12 |
+
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ScalarCritic(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
A causal-LM critic_adapter + a scalar value head:
|
| 18 |
+
V_φ(s) = wᵀ h_last + b
|
| 19 |
+
Only LoRA adapters (inside critic_adapter) and the value head are trainable.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, critic_adapter: AdapterWrapper):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.critic_adapter = critic_adapter
|
| 25 |
+
hidden_size = self.critic_adapter.shared_llm.config.hidden_size
|
| 26 |
+
self.value_head = nn.Linear(hidden_size, 1).to(
|
| 27 |
+
dtype=critic_adapter.dtype, device=critic_adapter.device
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def forward(self, input_ids, attention_mask=None, **kwargs):
|
| 31 |
+
# AdapterWrapper activates its own adapter internally
|
| 32 |
+
outputs = self.critic_adapter(
|
| 33 |
+
input_ids=input_ids,
|
| 34 |
+
attention_mask=attention_mask,
|
| 35 |
+
output_hidden_states=True,
|
| 36 |
+
**kwargs,
|
| 37 |
+
)
|
| 38 |
+
h_last = outputs.hidden_states[-1] # (B, S, H)
|
| 39 |
+
values = self.value_head(h_last).squeeze(-1) # (B, S)
|
| 40 |
+
return values
|
| 41 |
+
|
| 42 |
+
def parameters(self, recurse: bool = True):
|
| 43 |
+
"""Iterator over *trainable* parameters for this critic."""
|
| 44 |
+
# 1) LoRA params for *this* adapter
|
| 45 |
+
for p in self.critic_adapter.parameters():
|
| 46 |
+
yield p
|
| 47 |
+
# 2) scalar head
|
| 48 |
+
yield from self.value_head.parameters()
|
| 49 |
+
|
| 50 |
+
def gradient_checkpointing_enable(self, *args, **kwargs):
|
| 51 |
+
self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def dtype(self):
|
| 55 |
+
return self.critic_adapter.dtype
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def device(self):
|
| 59 |
+
return self.critic_adapter.device
|
src_code_for_reproducibility/training/tally_rollout.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/training/tally_rollout.py
|
| 3 |
+
Summary: Serializes rollout data into tallies for downstream processing.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import AutoTokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RolloutTallyItem:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
crn_ids: list[str],
|
| 21 |
+
rollout_ids: list[str],
|
| 22 |
+
agent_ids: list[str],
|
| 23 |
+
metric_matrix: torch.Tensor,
|
| 24 |
+
):
|
| 25 |
+
"""Lightweight data container that keeps rollout-aligned metric matrices."""
|
| 26 |
+
if isinstance(crn_ids, torch.Tensor):
|
| 27 |
+
crn_ids = crn_ids.detach().cpu().numpy()
|
| 28 |
+
if isinstance(rollout_ids, torch.Tensor):
|
| 29 |
+
rollout_ids = rollout_ids.detach().cpu().numpy()
|
| 30 |
+
if isinstance(agent_ids, torch.Tensor):
|
| 31 |
+
agent_ids = agent_ids.detach().cpu().numpy()
|
| 32 |
+
self.crn_ids = crn_ids
|
| 33 |
+
self.rollout_ids = rollout_ids
|
| 34 |
+
self.agent_ids = agent_ids
|
| 35 |
+
metric_matrix = metric_matrix.detach().cpu()
|
| 36 |
+
assert (
|
| 37 |
+
0 < metric_matrix.ndim <= 2
|
| 38 |
+
), "Metric matrix must have less than or equal to 2 dimensions"
|
| 39 |
+
if metric_matrix.ndim == 1:
|
| 40 |
+
metric_matrix = metric_matrix.reshape(1, -1)
|
| 41 |
+
# Convert to float32 if tensor is in BFloat16 format (not supported by numpy)
|
| 42 |
+
if metric_matrix.dtype == torch.bfloat16:
|
| 43 |
+
metric_matrix = metric_matrix.float()
|
| 44 |
+
self.metric_matrix = metric_matrix.numpy()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class RolloutTally:
|
| 48 |
+
"""
|
| 49 |
+
Tally is a utility class for collecting and storing training metrics.
|
| 50 |
+
It supports adding metrics at specified paths and saving them to disk.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(self):
|
| 54 |
+
"""
|
| 55 |
+
Initializes the RolloutTally object.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
tokenizer (AutoTokenizer): Tokenizer for converting token IDs to strings.
|
| 59 |
+
max_context_length (int, optional): Maximum context length for contextualized metrics. Defaults to 30.
|
| 60 |
+
"""
|
| 61 |
+
# Array-preserving structure (leaf lists hold numpy arrays / scalars)
|
| 62 |
+
self.metrics = {}
|
| 63 |
+
# Global ordered list of sample identifiers (crn_id, rollout_id) added in the order samples are processed
|
| 64 |
+
|
| 65 |
+
def reset(self):
|
| 66 |
+
"""Reset the tally to an empty dict."""
|
| 67 |
+
self.metrics = {}
|
| 68 |
+
|
| 69 |
+
def get_from_nested_dict(self, dictio: dict, path: str):
|
| 70 |
+
"""Retrieve a nested entry, creating intermediate dicts as needed."""
|
| 71 |
+
assert isinstance(path, list), "Path must be list."
|
| 72 |
+
for sp in path[:-1]:
|
| 73 |
+
dictio = dictio.setdefault(sp, {})
|
| 74 |
+
return dictio.get(path[-1], None)
|
| 75 |
+
|
| 76 |
+
def set_at_path(self, dictio: dict, path: str, value):
|
| 77 |
+
"""Store ``value`` at ``path``; helper used by ``add_metric``."""
|
| 78 |
+
for sp in path[:-1]:
|
| 79 |
+
dictio = dictio.setdefault(sp, {})
|
| 80 |
+
dictio[path[-1]] = value
|
| 81 |
+
|
| 82 |
+
def add_metric(self, path: list[str], rollout_tally_item: RolloutTallyItem):
|
| 83 |
+
"""
|
| 84 |
+
Adds a metric to the base tally at the specified path.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
path (list): List of keys representing the path in the base tally.
|
| 88 |
+
rollout_tally_item (RolloutTallyItem): The rollout tally item to add.
|
| 89 |
+
"""
|
| 90 |
+
rollout_tally_item = deepcopy(rollout_tally_item)
|
| 91 |
+
|
| 92 |
+
# Update array-preserving tally
|
| 93 |
+
array_list = self.get_from_nested_dict(dictio=self.metrics, path=path)
|
| 94 |
+
if array_list is None:
|
| 95 |
+
self.set_at_path(dictio=self.metrics, path=path, value=[rollout_tally_item])
|
| 96 |
+
else:
|
| 97 |
+
array_list.append(rollout_tally_item)
|
| 98 |
+
|
| 99 |
+
def save(self, identifier: str, folder: str):
|
| 100 |
+
"""Persist the tally as a pickle (metrics only) under ``folder``."""
|
| 101 |
+
os.makedirs(name=folder, exist_ok=True)
|
| 102 |
+
|
| 103 |
+
from datetime import datetime
|
| 104 |
+
|
| 105 |
+
now = datetime.now()
|
| 106 |
+
|
| 107 |
+
# Pickle only (fastest, exact structure with numpy/scalars at leaves)
|
| 108 |
+
try:
|
| 109 |
+
import pickle
|
| 110 |
+
|
| 111 |
+
pkl_path = os.path.join(folder, f"{identifier}.rt_tally.pkl")
|
| 112 |
+
payload = {"metrics": self.metrics}
|
| 113 |
+
with open(pkl_path, "wb") as f:
|
| 114 |
+
pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
|
| 115 |
+
except Exception:
|
| 116 |
+
pass
|
src_code_for_reproducibility/training/tally_tokenwise.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/training/tally_tokenwise.py
|
| 3 |
+
Summary: Converts token-level tallies into per-token statistics.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import AutoTokenizer
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ContextualizedTokenwiseTally:
|
| 17 |
+
"""
|
| 18 |
+
Collect, store, and save token-level metrics per rollout.
|
| 19 |
+
|
| 20 |
+
- One DataFrame per rollout_id in `paths`
|
| 21 |
+
- Index = timestep (int)
|
| 22 |
+
- Columns are added incrementally via `add_contexts()` and `add_data()`
|
| 23 |
+
- Cells may contain scalars, strings, or lists (dtype=object)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
tokenizer: AutoTokenizer,
|
| 29 |
+
paths: List[str],
|
| 30 |
+
max_context_length: int = 30,
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Args:
|
| 34 |
+
tokenizer: HuggingFace tokenizer used to convert tids -> tokens
|
| 35 |
+
paths: rollout identifiers (parallel to batch dimension)
|
| 36 |
+
max_context_length: truncate context token lists to this length
|
| 37 |
+
"""
|
| 38 |
+
self.tokenizer = tokenizer
|
| 39 |
+
self.paths = paths
|
| 40 |
+
self.max_context_length = max_context_length
|
| 41 |
+
self.tally: Dict[str, pd.DataFrame] = {path: pd.DataFrame() for path in paths}
|
| 42 |
+
|
| 43 |
+
# set later by setters
|
| 44 |
+
self.contexts: torch.Tensor | None = None
|
| 45 |
+
self.action_mask: torch.Tensor | None = None
|
| 46 |
+
self.range: Tuple[int, int] | None = None
|
| 47 |
+
|
| 48 |
+
# --------- Utilities ---------
|
| 49 |
+
|
| 50 |
+
def tids_to_str(self, tids: List[int]) -> List[str]:
|
| 51 |
+
"""Convert a list of token IDs to a list of token strings."""
|
| 52 |
+
return self.tokenizer.convert_ids_to_tokens(tids)
|
| 53 |
+
|
| 54 |
+
def _ensure_ready(self):
|
| 55 |
+
"""Validate that action mask and range are configured prior to writes."""
|
| 56 |
+
assert self.action_mask is not None, "call set_action_mask(mask) first"
|
| 57 |
+
assert self.range is not None, "call set_range((start, end)) first"
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def _sanitize_filename(name: Any) -> str:
|
| 61 |
+
"""Make a safe filename from any rollout_id."""
|
| 62 |
+
s = str(name)
|
| 63 |
+
bad = {os.sep, " ", ":", "|", "<", ">", '"', "'"}
|
| 64 |
+
if os.altsep is not None:
|
| 65 |
+
bad.add(os.altsep)
|
| 66 |
+
for ch in bad:
|
| 67 |
+
s = s.replace(ch, "_")
|
| 68 |
+
return s
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
def _pad_left(seq: List[Any], length: int, pad_val: Any = "") -> List[Any]:
|
| 72 |
+
"""Left-pad a sequence to `length` with `pad_val`."""
|
| 73 |
+
if len(seq) >= length:
|
| 74 |
+
return seq[-length:]
|
| 75 |
+
return [pad_val] * (length - len(seq)) + list(seq)
|
| 76 |
+
|
| 77 |
+
# --------- Setters ---------
|
| 78 |
+
|
| 79 |
+
def set_action_mask(self, action_mask: torch.Tensor):
|
| 80 |
+
"""Register the (B, S) mask indicating which tokens correspond to actions."""
|
| 81 |
+
self.action_mask = action_mask
|
| 82 |
+
|
| 83 |
+
def set_range(self, range: Tuple[int, int]):
|
| 84 |
+
"""Record which subset of ``paths`` the current mini-batch corresponds to."""
|
| 85 |
+
self.range = range
|
| 86 |
+
|
| 87 |
+
# --------- Column builders ---------
|
| 88 |
+
|
| 89 |
+
def add_contexts(self, contexts: torch.Tensor):
|
| 90 |
+
"""
|
| 91 |
+
Add a single 'context' column (list[str]) for valid steps.
|
| 92 |
+
|
| 93 |
+
Expects `contexts` with shape (B, S): token id at each timestep.
|
| 94 |
+
For each valid timestep t, we use the last N tokens up to and including t:
|
| 95 |
+
window = contexts[i, max(0, t - N + 1) : t + 1]
|
| 96 |
+
The list is left-padded with "" to always be length N.
|
| 97 |
+
"""
|
| 98 |
+
self._ensure_ready()
|
| 99 |
+
|
| 100 |
+
current_paths = self.paths[self.range[0] : self.range[1]]
|
| 101 |
+
B, S = contexts.shape
|
| 102 |
+
N = self.max_context_length
|
| 103 |
+
|
| 104 |
+
# to CPU ints once
|
| 105 |
+
contexts_cpu = contexts.detach().to("cpu")
|
| 106 |
+
|
| 107 |
+
for i in range(B):
|
| 108 |
+
rollout_id = current_paths[i]
|
| 109 |
+
df = self.tally.get(rollout_id, pd.DataFrame())
|
| 110 |
+
|
| 111 |
+
valid_idx = torch.nonzero(
|
| 112 |
+
self.action_mask[i].bool(), as_tuple=False
|
| 113 |
+
).squeeze(-1)
|
| 114 |
+
if valid_idx.numel() == 0:
|
| 115 |
+
self.tally[rollout_id] = df
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
idx_list = valid_idx.tolist()
|
| 119 |
+
|
| 120 |
+
# ensure index contains valid steps
|
| 121 |
+
if df.empty:
|
| 122 |
+
df = pd.DataFrame(index=idx_list)
|
| 123 |
+
else:
|
| 124 |
+
new_index = sorted(set(df.index.tolist()) | set(idx_list))
|
| 125 |
+
if list(df.index) != new_index:
|
| 126 |
+
df = df.reindex(new_index)
|
| 127 |
+
|
| 128 |
+
# build context windows
|
| 129 |
+
ctx_token_lists = []
|
| 130 |
+
for t in idx_list:
|
| 131 |
+
start = max(0, t - N + 1)
|
| 132 |
+
window_ids = contexts_cpu[i, start : t + 1].tolist()
|
| 133 |
+
window_toks = self.tids_to_str([int(x) for x in window_ids])
|
| 134 |
+
if len(window_toks) < N:
|
| 135 |
+
window_toks = [""] * (N - len(window_toks)) + window_toks
|
| 136 |
+
else:
|
| 137 |
+
window_toks = window_toks[-N:]
|
| 138 |
+
ctx_token_lists.append(window_toks)
|
| 139 |
+
|
| 140 |
+
# single 'context' column
|
| 141 |
+
if "context" not in df.columns:
|
| 142 |
+
df["context"] = pd.Series(index=df.index, dtype=object)
|
| 143 |
+
df.loc[idx_list, "context"] = pd.Series(
|
| 144 |
+
ctx_token_lists, index=idx_list, dtype=object
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
self.tally[rollout_id] = df
|
| 148 |
+
|
| 149 |
+
def add_data(
|
| 150 |
+
self,
|
| 151 |
+
metric_id: str,
|
| 152 |
+
metrics: torch.Tensor,
|
| 153 |
+
to_tids: bool = False,
|
| 154 |
+
):
|
| 155 |
+
"""
|
| 156 |
+
Add a metric column for valid steps.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
metric_id: column name
|
| 160 |
+
metrics: shape (B, S) for scalars/ids or (B, S, K) for top-k vectors
|
| 161 |
+
to_tids: if True, treat ints/lists of ints as tids and convert to tokens
|
| 162 |
+
"""
|
| 163 |
+
self._ensure_ready()
|
| 164 |
+
current_paths = self.paths[self.range[0] : self.range[1]]
|
| 165 |
+
|
| 166 |
+
if metrics.dim() == 2:
|
| 167 |
+
B, S = metrics.shape
|
| 168 |
+
elif metrics.dim() == 3:
|
| 169 |
+
B, S, _ = metrics.shape
|
| 170 |
+
else:
|
| 171 |
+
raise ValueError("metrics must be (B, S) or (B, S, K)")
|
| 172 |
+
|
| 173 |
+
for i in range(B):
|
| 174 |
+
rollout_id = current_paths[i]
|
| 175 |
+
df = self.tally.get(rollout_id, pd.DataFrame())
|
| 176 |
+
|
| 177 |
+
valid_idx = torch.nonzero(
|
| 178 |
+
self.action_mask[i].bool(), as_tuple=False
|
| 179 |
+
).squeeze(-1)
|
| 180 |
+
if valid_idx.numel() == 0:
|
| 181 |
+
self.tally[rollout_id] = df
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
idx_list = valid_idx.detach().cpu().tolist()
|
| 185 |
+
|
| 186 |
+
# Ensure index contains valid steps
|
| 187 |
+
if df.empty:
|
| 188 |
+
df = pd.DataFrame(index=idx_list)
|
| 189 |
+
else:
|
| 190 |
+
new_index = sorted(set(df.index.tolist()) | set(idx_list))
|
| 191 |
+
if list(df.index) != new_index:
|
| 192 |
+
df = df.reindex(new_index)
|
| 193 |
+
|
| 194 |
+
# Slice metrics at valid steps
|
| 195 |
+
m_valid = metrics[i][valid_idx]
|
| 196 |
+
|
| 197 |
+
# -> pure python lists (1D list or list-of-lists)
|
| 198 |
+
values = m_valid.detach().cpu().tolist()
|
| 199 |
+
|
| 200 |
+
# optional tids -> tokens
|
| 201 |
+
if to_tids:
|
| 202 |
+
|
| 203 |
+
def _to_tokish(x):
|
| 204 |
+
if isinstance(x, list):
|
| 205 |
+
return self.tids_to_str([int(v) for v in x])
|
| 206 |
+
else:
|
| 207 |
+
return self.tids_to_str([int(x)])[0]
|
| 208 |
+
|
| 209 |
+
values = [_to_tokish(v) for v in values]
|
| 210 |
+
|
| 211 |
+
# Ensure column exists with object dtype, then assign via aligned Series
|
| 212 |
+
if metric_id not in df.columns:
|
| 213 |
+
df[metric_id] = pd.Series(index=df.index, dtype=object)
|
| 214 |
+
|
| 215 |
+
if isinstance(values, np.ndarray):
|
| 216 |
+
values = values.tolist()
|
| 217 |
+
|
| 218 |
+
if len(values) != len(idx_list):
|
| 219 |
+
raise ValueError(
|
| 220 |
+
f"Length mismatch for '{metric_id}': values={len(values)} vs idx_list={len(idx_list)}"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
df.loc[idx_list, metric_id] = pd.Series(
|
| 224 |
+
values, index=idx_list, dtype=object
|
| 225 |
+
)
|
| 226 |
+
self.tally[rollout_id] = df
|
| 227 |
+
|
| 228 |
+
# --------- Saving ---------
|
| 229 |
+
|
| 230 |
+
def save(self, path: str):
|
| 231 |
+
"""
|
| 232 |
+
Write a manifest JSON and one CSV per rollout.
|
| 233 |
+
|
| 234 |
+
- Manifest includes metadata only (safe to JSON).
|
| 235 |
+
- Each rollout CSV is written with index label 'timestep'.
|
| 236 |
+
- Only a single 'context' column (list[str]).
|
| 237 |
+
"""
|
| 238 |
+
if not self.tally or all(df.empty for df in self.tally.values()):
|
| 239 |
+
return
|
| 240 |
+
|
| 241 |
+
os.makedirs(path, exist_ok=True)
|
| 242 |
+
from datetime import datetime
|
| 243 |
+
|
| 244 |
+
now = datetime.now()
|
| 245 |
+
|
| 246 |
+
manifest = {
|
| 247 |
+
"created_at": f"{now:%Y-%m-%d %H:%M:%S}",
|
| 248 |
+
"max_context_length": self.max_context_length,
|
| 249 |
+
"num_rollouts": len(self.tally),
|
| 250 |
+
"rollouts": [],
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
for rid, df in self.tally.items():
|
| 254 |
+
rid_str = str(rid)
|
| 255 |
+
safe_name = self._sanitize_filename(rid_str)
|
| 256 |
+
csv_path = os.path.join(path, f"{safe_name}_tokenwise.csv")
|
| 257 |
+
|
| 258 |
+
# Put 'context' first, then the rest
|
| 259 |
+
cols = ["context"] + [c for c in df.columns if c != "context"]
|
| 260 |
+
try:
|
| 261 |
+
df[cols].to_csv(csv_path, index=True, index_label="timestep")
|
| 262 |
+
except Exception as e:
|
| 263 |
+
continue
|
| 264 |
+
|
| 265 |
+
manifest["rollouts"].append(
|
| 266 |
+
{
|
| 267 |
+
"rollout_id": rid_str,
|
| 268 |
+
"csv": csv_path,
|
| 269 |
+
"num_rows": int(df.shape[0]),
|
| 270 |
+
"columns": cols,
|
| 271 |
+
}
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
manifest_path = os.path.join(
|
| 275 |
+
path, f"tokenwise_manifest_{now:%Y-%m-%d___%H-%M-%S}.json"
|
| 276 |
+
)
|
| 277 |
+
with open(manifest_path, "w") as fp:
|
| 278 |
+
json.dump(manifest, fp, indent=2)
|
src_code_for_reproducibility/training/trainer_ad_align.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/training/trainer_ad_align.py
|
| 3 |
+
Summary: Trainer specialized for the advantage-alignment objective.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import logging
|
| 8 |
+
import sys
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 14 |
+
|
| 15 |
+
from mllm.markov_games.rollout_tree import (
|
| 16 |
+
ChatTurn,
|
| 17 |
+
RolloutTreeBranchNode,
|
| 18 |
+
RolloutTreeRootNode,
|
| 19 |
+
)
|
| 20 |
+
from mllm.training.credit_methods import (
|
| 21 |
+
get_advantage_alignment_credits,
|
| 22 |
+
get_discounted_state_visitation_credits,
|
| 23 |
+
)
|
| 24 |
+
from mllm.training.tally_metrics import Tally
|
| 25 |
+
from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
|
| 26 |
+
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
|
| 27 |
+
from mllm.training.tokenize_chats import process_training_chat
|
| 28 |
+
from mllm.training.trainer_common import BaseTrainer
|
| 29 |
+
from mllm.training.training_data_utils import (
|
| 30 |
+
AdvantagePacket,
|
| 31 |
+
TrainingBatch,
|
| 32 |
+
TrainingChatTurn,
|
| 33 |
+
TrajectoryBatch,
|
| 34 |
+
get_main_chat_list_and_rewards,
|
| 35 |
+
get_tokenwise_credits,
|
| 36 |
+
)
|
| 37 |
+
from mllm.utils.resource_context import resource_logger_context
|
| 38 |
+
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 41 |
+
|
| 42 |
+
RolloutId = int
|
| 43 |
+
AgentId = str
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class AdAlignTrainingData:
|
| 48 |
+
"""Holds tensorized rollouts plus precomputed advantages for one agent."""
|
| 49 |
+
|
| 50 |
+
agent_id: str
|
| 51 |
+
main_data: TrajectoryBatch
|
| 52 |
+
# list-of-tensors: per rollout advantages with length jT
|
| 53 |
+
main_advantages: list[torch.FloatTensor] | None = None
|
| 54 |
+
# list-of-tensors: per rollout matrix (jT, A)
|
| 55 |
+
alternative_advantages: list[torch.FloatTensor] | None = None
|
| 56 |
+
advantage_alignment_credits: list[torch.FloatTensor] | None = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_alternative_chat_histories(
|
| 60 |
+
agent_id: str, root: RolloutTreeRootNode
|
| 61 |
+
) -> list[list[TrainingChatTurn], list[torch.FloatTensor]]:
|
| 62 |
+
"""
|
| 63 |
+
Traverse every unilateral branch under ``root`` and collect chat/reward histories.
|
| 64 |
+
|
| 65 |
+
Returns
|
| 66 |
+
-------
|
| 67 |
+
alternative_chats:
|
| 68 |
+
Flattened list of chat turns for each branch (ordered by branch depth).
|
| 69 |
+
alternative_rewards:
|
| 70 |
+
Matching list of reward tensors aligned with the chat history.
|
| 71 |
+
"""
|
| 72 |
+
current_node = root.child
|
| 73 |
+
branches = current_node.branches
|
| 74 |
+
pre_branch_chat = []
|
| 75 |
+
pre_branch_rewards = []
|
| 76 |
+
alternative_rewards = []
|
| 77 |
+
alternative_chats = []
|
| 78 |
+
while current_node is not None:
|
| 79 |
+
assert isinstance(
|
| 80 |
+
current_node, RolloutTreeBranchNode
|
| 81 |
+
), "Current node should be a branch node."
|
| 82 |
+
main_node = current_node.main_child
|
| 83 |
+
branches = current_node.branches
|
| 84 |
+
current_node = main_node.child
|
| 85 |
+
|
| 86 |
+
# Get the `A` alternative trajectories
|
| 87 |
+
alternative_nodes = branches[agent_id]
|
| 88 |
+
for alt_node in alternative_nodes:
|
| 89 |
+
post_branch_chat, post_branch_rewards = get_main_chat_list_and_rewards(
|
| 90 |
+
agent_id=agent_id, root=alt_node
|
| 91 |
+
)
|
| 92 |
+
branch_chat = pre_branch_chat + post_branch_chat
|
| 93 |
+
alternative_chats.append(branch_chat)
|
| 94 |
+
alternative_rewards.append(
|
| 95 |
+
torch.cat([torch.tensor(pre_branch_rewards), post_branch_rewards])
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
chat_turns: list[ChatTurn] = main_node.step_log.action_logs[agent_id].chat_turns
|
| 99 |
+
chat_turns: list[TrainingChatTurn] = [
|
| 100 |
+
TrainingChatTurn(time_step=main_node.time_step, **turn.model_dump())
|
| 101 |
+
for turn in chat_turns
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
pre_branch_chat.extend(chat_turns)
|
| 105 |
+
pre_branch_rewards.append(
|
| 106 |
+
main_node.step_log.simulation_step_log.rewards[agent_id]
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return alternative_chats, alternative_rewards
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class TrainerAdAlign(BaseTrainer):
|
| 113 |
+
"""
|
| 114 |
+
Extends the reinforce trainer to support Advantage Alignment.
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
ad_align_beta: float,
|
| 120 |
+
ad_align_gamma: float,
|
| 121 |
+
ad_align_exclude_k_equals_t: bool,
|
| 122 |
+
ad_align_use_sign: bool,
|
| 123 |
+
ad_align_clipping: float,
|
| 124 |
+
ad_align_force_coop_first_step: bool,
|
| 125 |
+
use_old_ad_align: bool,
|
| 126 |
+
use_time_regularization: bool,
|
| 127 |
+
rloo_branch: bool,
|
| 128 |
+
reuse_baseline: bool,
|
| 129 |
+
ad_align_beta_anneal_step: int = -1,
|
| 130 |
+
ad_align_beta_anneal_rate: float = 0.5,
|
| 131 |
+
min_ad_align_beta: float = 0.1,
|
| 132 |
+
mean_normalize_ad_align: bool = False,
|
| 133 |
+
whiten_adalign_advantages: bool = False,
|
| 134 |
+
whiten_adalign_advantages_time_step_wise: bool = False,
|
| 135 |
+
ad_align_discount_t: bool = False,
|
| 136 |
+
*args,
|
| 137 |
+
**kwargs,
|
| 138 |
+
):
|
| 139 |
+
"""
|
| 140 |
+
Initialize the advantage alignment trainer.
|
| 141 |
+
Args:
|
| 142 |
+
ad_align_beta: Beta parameter for the advantage alignment.
|
| 143 |
+
ad_align_gamma: Gamma parameter for the advantage alignment.
|
| 144 |
+
ad_align_exclude_k_equals_t: Whether to include k = t in the advantage alignment.
|
| 145 |
+
ad_align_use_sign: Whether to use sign in the advantage alignment.
|
| 146 |
+
ad_align_clipping: Clipping value for the advantage alignment.
|
| 147 |
+
ad_align_force_coop_first_step: Whether to force coop on the first step of the advantage alignment.
|
| 148 |
+
"""
|
| 149 |
+
super().__init__(*args, **kwargs)
|
| 150 |
+
self.ad_align_beta = ad_align_beta
|
| 151 |
+
self.ad_align_gamma = ad_align_gamma
|
| 152 |
+
self.ad_align_exclude_k_equals_t = ad_align_exclude_k_equals_t
|
| 153 |
+
self.ad_align_use_sign = ad_align_use_sign
|
| 154 |
+
self.ad_align_clipping = ad_align_clipping
|
| 155 |
+
self.ad_align_force_coop_first_step = ad_align_force_coop_first_step
|
| 156 |
+
self.use_old_ad_align = use_old_ad_align
|
| 157 |
+
self.use_time_regularization = use_time_regularization
|
| 158 |
+
self.rloo_branch = rloo_branch
|
| 159 |
+
self.reuse_baseline = reuse_baseline
|
| 160 |
+
self.ad_align_beta_anneal_step = ad_align_beta_anneal_step
|
| 161 |
+
self.ad_align_beta_anneal_rate = ad_align_beta_anneal_rate
|
| 162 |
+
self.min_ad_align_beta = min_ad_align_beta
|
| 163 |
+
self.past_ad_align_step = -1
|
| 164 |
+
self.mean_normalize_ad_align = mean_normalize_ad_align
|
| 165 |
+
self.whiten_adalign_advantages = whiten_adalign_advantages
|
| 166 |
+
self.whiten_adalign_advantages_time_step_wise = (
|
| 167 |
+
whiten_adalign_advantages_time_step_wise
|
| 168 |
+
)
|
| 169 |
+
self.ad_align_discount_t = ad_align_discount_t
|
| 170 |
+
self.training_data: dict[AgentId, AdAlignTrainingData] = {}
|
| 171 |
+
self.debug_path_list: list[str] = []
|
| 172 |
+
|
| 173 |
+
def set_agent_trajectory_data(
|
| 174 |
+
self, agent_id: str, roots: list[RolloutTreeRootNode]
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
Materialize main and alternative trajectory tensors used by the advantage-alignment trainer.
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
B = len(roots) # Number of rollouts
|
| 181 |
+
|
| 182 |
+
# For main rollouts
|
| 183 |
+
batch_rollout_ids = []
|
| 184 |
+
batch_crn_ids = []
|
| 185 |
+
batch_input_ids = []
|
| 186 |
+
batch_action_mask = []
|
| 187 |
+
batch_entropy_mask = []
|
| 188 |
+
batch_timesteps = []
|
| 189 |
+
batch_state_ends_mask = []
|
| 190 |
+
batch_engine_log_probs = []
|
| 191 |
+
batch_rewards = []
|
| 192 |
+
|
| 193 |
+
# For alternative actions rollouts
|
| 194 |
+
batch_branching_time_steps = []
|
| 195 |
+
alternative_batch_input_ids = []
|
| 196 |
+
alternative_batch_action_mask = []
|
| 197 |
+
alternative_batch_entropy_mask = []
|
| 198 |
+
alternative_batch_timesteps = []
|
| 199 |
+
alternative_batch_state_ends_mask = []
|
| 200 |
+
alternative_batch_engine_log_probs = []
|
| 201 |
+
alternative_batch_rewards = []
|
| 202 |
+
jT_list = []
|
| 203 |
+
|
| 204 |
+
try:
|
| 205 |
+
A = len(roots[0].child.branches[agent_id]) # Number of alternative actions
|
| 206 |
+
except:
|
| 207 |
+
A = 0
|
| 208 |
+
|
| 209 |
+
for root in roots:
|
| 210 |
+
rollout_id = root.id
|
| 211 |
+
self.debug_path_list.append(
|
| 212 |
+
"mgid:" + str(rollout_id) + "_agent_id:" + agent_id
|
| 213 |
+
)
|
| 214 |
+
# Get main trajectory
|
| 215 |
+
batch_rollout_ids.append(rollout_id)
|
| 216 |
+
batch_crn_ids.append(root.crn_id)
|
| 217 |
+
main_chat, main_rewards = get_main_chat_list_and_rewards(
|
| 218 |
+
agent_id=agent_id, root=root
|
| 219 |
+
)
|
| 220 |
+
(
|
| 221 |
+
input_ids,
|
| 222 |
+
action_mask,
|
| 223 |
+
entropy_mask,
|
| 224 |
+
timesteps,
|
| 225 |
+
state_ends_mask,
|
| 226 |
+
engine_log_probs,
|
| 227 |
+
) = process_training_chat(
|
| 228 |
+
tokenizer=self.tokenizer,
|
| 229 |
+
chat_history=main_chat,
|
| 230 |
+
entropy_mask_regex=self.entropy_mask_regex,
|
| 231 |
+
exploration_prompts_to_remove=self.exploration_prompts_to_remove,
|
| 232 |
+
)
|
| 233 |
+
batch_input_ids.append(input_ids)
|
| 234 |
+
batch_action_mask.append(action_mask)
|
| 235 |
+
batch_entropy_mask.append(entropy_mask)
|
| 236 |
+
batch_timesteps.append(timesteps)
|
| 237 |
+
batch_state_ends_mask.append(state_ends_mask)
|
| 238 |
+
batch_engine_log_probs.append(engine_log_probs)
|
| 239 |
+
batch_rewards.append(main_rewards)
|
| 240 |
+
jT = (
|
| 241 |
+
main_rewards.numel()
|
| 242 |
+
) # Number of timesteps inferred from reward tensor length.
|
| 243 |
+
jT_list.append(jT)
|
| 244 |
+
if A > 0:
|
| 245 |
+
# We get the branching time steps for each of the `jT` time steps in the main trajectory.
|
| 246 |
+
branching_time_steps = [bt for item in range(jT) for bt in A * [item]]
|
| 247 |
+
batch_branching_time_steps.extend(branching_time_steps)
|
| 248 |
+
|
| 249 |
+
# Get all of the (jT*A) alternative trajectories in the tree
|
| 250 |
+
# (jT is the number of time steps in the main trajectory, A is the number of alternative actions)
|
| 251 |
+
alternative_chats, alternative_rewards = get_alternative_chat_histories(
|
| 252 |
+
agent_id=agent_id, root=root
|
| 253 |
+
)
|
| 254 |
+
assert (
|
| 255 |
+
len(alternative_chats) == A * jT
|
| 256 |
+
), "Incorrect number of alternative trajectories."
|
| 257 |
+
|
| 258 |
+
for chat, rewards in zip(alternative_chats, alternative_rewards):
|
| 259 |
+
(
|
| 260 |
+
input_ids,
|
| 261 |
+
action_mask,
|
| 262 |
+
entropy_mask,
|
| 263 |
+
timesteps,
|
| 264 |
+
state_ends_mask,
|
| 265 |
+
engine_log_probs,
|
| 266 |
+
) = process_training_chat(
|
| 267 |
+
tokenizer=self.tokenizer,
|
| 268 |
+
chat_history=chat,
|
| 269 |
+
entropy_mask_regex=self.entropy_mask_regex,
|
| 270 |
+
exploration_prompts_to_remove=self.exploration_prompts_to_remove,
|
| 271 |
+
)
|
| 272 |
+
alternative_batch_input_ids.append(input_ids)
|
| 273 |
+
alternative_batch_action_mask.append(action_mask)
|
| 274 |
+
alternative_batch_entropy_mask.append(entropy_mask)
|
| 275 |
+
alternative_batch_timesteps.append(timesteps)
|
| 276 |
+
alternative_batch_state_ends_mask.append(state_ends_mask)
|
| 277 |
+
alternative_batch_engine_log_probs.append(engine_log_probs)
|
| 278 |
+
alternative_batch_rewards.append(rewards)
|
| 279 |
+
|
| 280 |
+
jT_list = torch.Tensor(jT_list)
|
| 281 |
+
|
| 282 |
+
# Assert that number of alternative actions is constant
|
| 283 |
+
# assert len(set(nb_alternative_actions)) == 1, "Number of alternative actions must be constant"
|
| 284 |
+
# A = nb_alternative_actions[0]
|
| 285 |
+
|
| 286 |
+
trajectory_batch = TrajectoryBatch(
|
| 287 |
+
rollout_ids=torch.tensor(batch_rollout_ids, dtype=torch.int32), # (B,)
|
| 288 |
+
crn_ids=torch.tensor(batch_crn_ids, dtype=torch.int32),
|
| 289 |
+
agent_ids=[agent_id] * len(batch_rollout_ids),
|
| 290 |
+
batch_input_ids=batch_input_ids,
|
| 291 |
+
batch_action_mask=batch_action_mask,
|
| 292 |
+
batch_entropy_mask=batch_entropy_mask,
|
| 293 |
+
batch_timesteps=batch_timesteps,
|
| 294 |
+
batch_state_ends_mask=batch_state_ends_mask,
|
| 295 |
+
batch_engine_log_probs=batch_engine_log_probs,
|
| 296 |
+
batch_rewards=batch_rewards,
|
| 297 |
+
)
|
| 298 |
+
# Get Advantages & Train Critic
|
| 299 |
+
with resource_logger_context(
|
| 300 |
+
logger, "Get advantages with critic gradient accumulation"
|
| 301 |
+
):
|
| 302 |
+
self.batch_advantages: torch.FloatTensor = (
|
| 303 |
+
self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
|
| 304 |
+
) # (B, jT)
|
| 305 |
+
|
| 306 |
+
if A > 0:
|
| 307 |
+
# Here, `A` is the number of alternative actions / trajectories taken at each time step.
|
| 308 |
+
# For each of the `B` rollout perspectives, at each of its jT (`j` is for jagged, since each main rollout may be of a different length) steps, we take A alternate trajectories (from different actions).
|
| 309 |
+
# Therefore, we have ∑jT * A trajectories to process. If each of the main trajectories have T steps, we will have `B*T*A` to process.
|
| 310 |
+
with resource_logger_context(logger, "Create alternative trajectory batch"):
|
| 311 |
+
sum_jT = int(torch.sum(jT_list).item())
|
| 312 |
+
jT_list = (
|
| 313 |
+
jT_list.int().tolist()
|
| 314 |
+
) # (jT,) # (we only want the advantages where we branched out)
|
| 315 |
+
alternative_trajectory_batch = TrajectoryBatch(
|
| 316 |
+
rollout_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
|
| 317 |
+
crn_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
|
| 318 |
+
agent_ids=[agent_id] * (A * sum_jT),
|
| 319 |
+
batch_input_ids=alternative_batch_input_ids,
|
| 320 |
+
batch_action_mask=alternative_batch_action_mask,
|
| 321 |
+
batch_entropy_mask=alternative_batch_entropy_mask,
|
| 322 |
+
batch_timesteps=alternative_batch_timesteps,
|
| 323 |
+
batch_state_ends_mask=alternative_batch_state_ends_mask,
|
| 324 |
+
batch_engine_log_probs=alternative_batch_engine_log_probs,
|
| 325 |
+
batch_rewards=alternative_batch_rewards,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# Get alternative advantages
|
| 329 |
+
# BAAs stands for batch alternative advantages
|
| 330 |
+
# (torch nested tensors have very little api support, so we have to do some odd manual work here)
|
| 331 |
+
with resource_logger_context(
|
| 332 |
+
logger, "Compute alternative advantage estimates"
|
| 333 |
+
):
|
| 334 |
+
BAAs_list = self.get_advantages_with_critic_gradient_accumulation(
|
| 335 |
+
alternative_trajectory_batch
|
| 336 |
+
) # list length (∑jT * A), each (jT',)
|
| 337 |
+
# Pad alternative advantages to (∑jT*A, P)
|
| 338 |
+
|
| 339 |
+
BAAs_padded = pad_sequence(
|
| 340 |
+
BAAs_list, batch_first=True, padding_value=0.0
|
| 341 |
+
)
|
| 342 |
+
branch_idx = torch.tensor(
|
| 343 |
+
batch_branching_time_steps,
|
| 344 |
+
device=BAAs_padded.device,
|
| 345 |
+
dtype=torch.long,
|
| 346 |
+
)
|
| 347 |
+
gathered = BAAs_padded.gather(
|
| 348 |
+
dim=1, index=branch_idx.unsqueeze(1)
|
| 349 |
+
).squeeze(1)
|
| 350 |
+
# Reshape and split per rollout, then transpose to (jT_i, A)
|
| 351 |
+
gathered = gathered.view(A, sum_jT) # (A, ∑jT)
|
| 352 |
+
blocks = list(
|
| 353 |
+
torch.split(gathered, jT_list, dim=1)
|
| 354 |
+
) # len B, shapes (A, jT_i)
|
| 355 |
+
BAAs = [
|
| 356 |
+
blk.transpose(0, 1).contiguous() for blk in blocks
|
| 357 |
+
] # list of (jT_i, A)
|
| 358 |
+
if self.ad_align_beta_anneal_step > 0:
|
| 359 |
+
max_rollout_id = torch.max(trajectory_batch.rollout_ids) + 1
|
| 360 |
+
if (
|
| 361 |
+
max_rollout_id % self.ad_align_beta_anneal_step == 0
|
| 362 |
+
and self.past_ad_align_step != max_rollout_id
|
| 363 |
+
):
|
| 364 |
+
self.ad_align_beta = max(
|
| 365 |
+
self.ad_align_beta * self.ad_align_beta_anneal_rate,
|
| 366 |
+
self.min_ad_align_beta,
|
| 367 |
+
)
|
| 368 |
+
logger.info(f"Annealing ad_align_beta to {self.ad_align_beta}")
|
| 369 |
+
self.past_ad_align_step = max_rollout_id
|
| 370 |
+
self.training_data[agent_id] = AdAlignTrainingData(
|
| 371 |
+
agent_id=agent_id,
|
| 372 |
+
main_data=trajectory_batch,
|
| 373 |
+
main_advantages=self.batch_advantages,
|
| 374 |
+
alternative_advantages=BAAs if A > 0 else None,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
def share_advantage_data(self) -> list[AdvantagePacket]:
|
| 378 |
+
"""
|
| 379 |
+
Share the advantage alignment data with other agents.
|
| 380 |
+
Returns:
|
| 381 |
+
AdvantagePacket: The advantage packet containing the agent's advantages.
|
| 382 |
+
"""
|
| 383 |
+
logger.info(f"Sharing advantage alignment data.")
|
| 384 |
+
advantage_packets = []
|
| 385 |
+
for _, agent_data in self.training_data.items():
|
| 386 |
+
advantage_packets.append(
|
| 387 |
+
AdvantagePacket(
|
| 388 |
+
agent_id=agent_data.agent_id,
|
| 389 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 390 |
+
main_advantages=agent_data.main_advantages,
|
| 391 |
+
)
|
| 392 |
+
)
|
| 393 |
+
return advantage_packets
|
| 394 |
+
|
| 395 |
+
def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
|
| 396 |
+
"""
|
| 397 |
+
Receive advantage packets from other players.
|
| 398 |
+
These contain the advantages of the other players' rollouts estimated by them.
|
| 399 |
+
"""
|
| 400 |
+
logger.info(f"Receiving advantage packets.")
|
| 401 |
+
|
| 402 |
+
assert (
|
| 403 |
+
len(advantage_packets) > 0
|
| 404 |
+
), "At least one advantage packet must be provided."
|
| 405 |
+
|
| 406 |
+
for agent_id, agent_data in self.training_data.items():
|
| 407 |
+
coagent_advantage_packets = [
|
| 408 |
+
packet for packet in advantage_packets if packet.agent_id != agent_id
|
| 409 |
+
]
|
| 410 |
+
agent_rollout_ids = agent_data.main_data.rollout_ids
|
| 411 |
+
agent_advantages = agent_data.main_advantages
|
| 412 |
+
co_agent_advantages = []
|
| 413 |
+
for rollout_id in agent_rollout_ids:
|
| 414 |
+
for co_agent_packet in coagent_advantage_packets:
|
| 415 |
+
if rollout_id in co_agent_packet.rollout_ids:
|
| 416 |
+
index = torch.where(rollout_id == co_agent_packet.rollout_ids)[
|
| 417 |
+
0
|
| 418 |
+
].item()
|
| 419 |
+
co_agent_advantages.append(
|
| 420 |
+
co_agent_packet.main_advantages[index]
|
| 421 |
+
)
|
| 422 |
+
# assumes that its two player game, with one co-agent
|
| 423 |
+
break
|
| 424 |
+
assert len(co_agent_advantages) == len(agent_advantages)
|
| 425 |
+
B = len(agent_advantages)
|
| 426 |
+
assert all(
|
| 427 |
+
a.shape[0] == b.shape[0]
|
| 428 |
+
for a, b in zip(co_agent_advantages, agent_advantages)
|
| 429 |
+
), "Number of advantages must match for advantage alignment."
|
| 430 |
+
|
| 431 |
+
# Get padded tensors (advantage alignment is invariant to padding)
|
| 432 |
+
lengths = torch.tensor(
|
| 433 |
+
[len(t) for t in agent_advantages],
|
| 434 |
+
device=self.device,
|
| 435 |
+
dtype=torch.long,
|
| 436 |
+
)
|
| 437 |
+
padded_main_advantages = pad_sequence(
|
| 438 |
+
agent_advantages, batch_first=True, padding_value=0.0
|
| 439 |
+
)
|
| 440 |
+
if agent_data.alternative_advantages:
|
| 441 |
+
padded_alternative_advantages = pad_sequence(
|
| 442 |
+
agent_data.alternative_advantages,
|
| 443 |
+
batch_first=True,
|
| 444 |
+
padding_value=0.0,
|
| 445 |
+
) # (B, P, A)
|
| 446 |
+
else:
|
| 447 |
+
padded_alternative_advantages = None
|
| 448 |
+
padded_co_agent_advantages = pad_sequence(
|
| 449 |
+
co_agent_advantages, batch_first=True, padding_value=0.0
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Create training batch data
|
| 453 |
+
credits, sub_tensors = get_advantage_alignment_credits(
|
| 454 |
+
a1=padded_main_advantages,
|
| 455 |
+
a1_alternative=padded_alternative_advantages,
|
| 456 |
+
a2=padded_co_agent_advantages,
|
| 457 |
+
beta=self.ad_align_beta,
|
| 458 |
+
gamma=self.ad_align_gamma,
|
| 459 |
+
exclude_k_equals_t=self.ad_align_exclude_k_equals_t,
|
| 460 |
+
use_sign=self.ad_align_use_sign,
|
| 461 |
+
clipping=self.ad_align_clipping,
|
| 462 |
+
force_coop_first_step=self.ad_align_force_coop_first_step,
|
| 463 |
+
use_old_ad_align=self.use_old_ad_align,
|
| 464 |
+
use_time_regularization=self.use_time_regularization,
|
| 465 |
+
rloo_branch=self.rloo_branch,
|
| 466 |
+
reuse_baseline=self.reuse_baseline,
|
| 467 |
+
mean_normalize_ad_align=self.mean_normalize_ad_align,
|
| 468 |
+
whiten_adalign_advantages=self.whiten_adalign_advantages,
|
| 469 |
+
whiten_adalign_advantages_time_step_wise=self.whiten_adalign_advantages_time_step_wise,
|
| 470 |
+
discount_t=self.ad_align_discount_t,
|
| 471 |
+
)
|
| 472 |
+
for key, value in sub_tensors.items():
|
| 473 |
+
self.rollout_tally.add_metric(
|
| 474 |
+
path=[key],
|
| 475 |
+
rollout_tally_item=RolloutTallyItem(
|
| 476 |
+
crn_ids=agent_data.main_data.crn_ids,
|
| 477 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 478 |
+
agent_ids=agent_data.main_data.agent_ids,
|
| 479 |
+
metric_matrix=value,
|
| 480 |
+
),
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if not self.skip_discounted_state_visitation:
|
| 484 |
+
credits = get_discounted_state_visitation_credits(
|
| 485 |
+
credits,
|
| 486 |
+
self.discount_factor,
|
| 487 |
+
)
|
| 488 |
+
self.rollout_tally.add_metric(
|
| 489 |
+
path=["discounted_state_visitation_credits"],
|
| 490 |
+
rollout_tally_item=RolloutTallyItem(
|
| 491 |
+
crn_ids=agent_data.main_data.crn_ids,
|
| 492 |
+
rollout_ids=agent_data.main_data.rollout_ids,
|
| 493 |
+
agent_ids=agent_data.main_data.agent_ids,
|
| 494 |
+
metric_matrix=sub_tensors[
|
| 495 |
+
"discounted_state_visitation_credits"
|
| 496 |
+
],
|
| 497 |
+
),
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
# Slice back to jagged
|
| 501 |
+
advantage_alignment_credits = [credits[i, : lengths[i]] for i in range(B)]
|
| 502 |
+
# Replace stored training data for this agent by the concrete trajectory batch
|
| 503 |
+
# and attach the computed credits for policy gradient.
|
| 504 |
+
self.training_data[agent_id] = agent_data.main_data
|
| 505 |
+
self.training_data[agent_id].batch_credits = advantage_alignment_credits
|
src_code_for_reproducibility/training/trainer_common.py
ADDED
|
@@ -0,0 +1,1032 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: mllm/training/trainer_common.py
|
| 3 |
+
Summary: Shared trainer utilities, base classes, and gradient helpers.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import pickle
|
| 9 |
+
import sys
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from typing import Callable, Literal, Union
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from accelerate import Accelerator
|
| 17 |
+
from pandas._libs.tslibs.offsets import CBMonthBegin
|
| 18 |
+
from peft import LoraConfig
|
| 19 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 21 |
+
|
| 22 |
+
from mllm.markov_games.rollout_tree import *
|
| 23 |
+
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
|
| 24 |
+
from mllm.training.annealing_methods import sigmoid_annealing
|
| 25 |
+
from mllm.training.credit_methods import (
|
| 26 |
+
get_discounted_returns,
|
| 27 |
+
get_generalized_advantage_estimates,
|
| 28 |
+
get_rloo_credits,
|
| 29 |
+
whiten_advantages,
|
| 30 |
+
whiten_advantages_time_step_wise,
|
| 31 |
+
)
|
| 32 |
+
from mllm.training.tally_metrics import Tally
|
| 33 |
+
from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
|
| 34 |
+
from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
|
| 35 |
+
from mllm.training.tokenize_chats import *
|
| 36 |
+
from mllm.training.tokenize_chats import process_training_chat
|
| 37 |
+
from mllm.training.training_data_utils import *
|
| 38 |
+
from mllm.training.training_data_utils import (
|
| 39 |
+
TrainingBatch,
|
| 40 |
+
TrajectoryBatch,
|
| 41 |
+
get_tokenwise_credits,
|
| 42 |
+
)
|
| 43 |
+
from mllm.utils.resource_context import resource_logger_context
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class TrainerAnnealingState:
|
| 51 |
+
annealing_step_counter: int = 0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BaseTrainer(ABC):
|
| 55 |
+
"""
|
| 56 |
+
Shared scaffolding for policy-gradient trainers (optimizer wiring, logging, etc.).
|
| 57 |
+
|
| 58 |
+
Subclasses implement `set_agent_trajectory_data` / `share_advantage_data`
|
| 59 |
+
to plug in algorithm-specific behavior.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
policy: AutoModelForCausalLM,
|
| 65 |
+
policy_optimizer: torch.optim.Optimizer,
|
| 66 |
+
critic: Union[AutoModelForCausalLM, None],
|
| 67 |
+
critic_optimizer: Union[torch.optim.Optimizer, None],
|
| 68 |
+
tokenizer: AutoTokenizer,
|
| 69 |
+
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 70 |
+
critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None],
|
| 71 |
+
######################################################################
|
| 72 |
+
entropy_coeff: float,
|
| 73 |
+
entropy_topk: int,
|
| 74 |
+
entropy_mask_regex: Union[str, None],
|
| 75 |
+
kl_coeff: float,
|
| 76 |
+
gradient_clipping: Union[float, None],
|
| 77 |
+
restrict_tokens: Union[list[str], None],
|
| 78 |
+
mini_batch_size: int,
|
| 79 |
+
use_gradient_checkpointing: bool,
|
| 80 |
+
temperature: float,
|
| 81 |
+
device: str,
|
| 82 |
+
whiten_advantages: bool,
|
| 83 |
+
whiten_advantages_time_step_wise: bool,
|
| 84 |
+
use_gae: bool,
|
| 85 |
+
use_gae_lambda_annealing: bool,
|
| 86 |
+
gae_lambda_annealing_limit: float,
|
| 87 |
+
gae_lambda_annealing_method: Literal["sigmoid_annealing"],
|
| 88 |
+
gae_lambda_annealing_method_params: dict,
|
| 89 |
+
pg_loss_normalization: Literal["batch", "nb_tokens"],
|
| 90 |
+
use_rloo: bool,
|
| 91 |
+
skip_discounted_state_visitation: bool,
|
| 92 |
+
discount_factor: float,
|
| 93 |
+
enable_tokenwise_logging: bool,
|
| 94 |
+
save_path: str,
|
| 95 |
+
reward_normalizing_constant: float = 1.0,
|
| 96 |
+
critic_loss_type: Literal["mse", "huber"] = "huber",
|
| 97 |
+
exploration_prompts_to_remove: list[str] = [],
|
| 98 |
+
filter_higher_refprob_tokens_kl: bool = False,
|
| 99 |
+
truncated_importance_sampling_ratio_cap: float = 0.0,
|
| 100 |
+
importance_sampling_strategy: Literal[
|
| 101 |
+
"per_token", "per_sequence"
|
| 102 |
+
] = "per_token",
|
| 103 |
+
no_rloo_grouping: bool = False,
|
| 104 |
+
):
|
| 105 |
+
"""
|
| 106 |
+
Initialize the REINFORCE trainer with reward shaping for multi-agent or single-agent training.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
model (AutoModelForCausalLM): The main policy model.
|
| 110 |
+
tokenizer (AutoTokenizer): Tokenizer for the model.
|
| 111 |
+
optimizer (torch.optim.Optimizer): Optimizer for the policy model.
|
| 112 |
+
lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Learning rate scheduler for the policy model.
|
| 113 |
+
critic (AutoModelForCausalLM or None): Critic model for value estimation (optional).
|
| 114 |
+
critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic model (optional).
|
| 115 |
+
critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic (optional).
|
| 116 |
+
config (RtConfig): Configuration object for training.
|
| 117 |
+
"""
|
| 118 |
+
self.tokenizer = tokenizer
|
| 119 |
+
# self.tokenizer.padding_side = "left" # needed for flash attention
|
| 120 |
+
if self.tokenizer.pad_token_id is None:
|
| 121 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 122 |
+
self.lr_scheduler = lr_scheduler
|
| 123 |
+
self.accelerator = Accelerator()
|
| 124 |
+
(
|
| 125 |
+
self.policy,
|
| 126 |
+
self.policy_optimizer,
|
| 127 |
+
self.critic,
|
| 128 |
+
self.critic_optimizer,
|
| 129 |
+
) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer)
|
| 130 |
+
|
| 131 |
+
self.critic_lr_scheduler = critic_lr_scheduler
|
| 132 |
+
self.tally = Tally()
|
| 133 |
+
|
| 134 |
+
if use_gradient_checkpointing == True:
|
| 135 |
+
self.policy.gradient_checkpointing_enable(dict(use_reentrant=False))
|
| 136 |
+
if critic is not None:
|
| 137 |
+
self.critic.gradient_checkpointing_enable(dict(use_reentrant=False))
|
| 138 |
+
|
| 139 |
+
self.save_path = save_path
|
| 140 |
+
|
| 141 |
+
# Load trainer state if it exists
|
| 142 |
+
self.trainer_annealing_state_path = os.path.join(
|
| 143 |
+
self.save_path, "trainer_annealing_state.pkl"
|
| 144 |
+
)
|
| 145 |
+
if os.path.exists(self.trainer_annealing_state_path):
|
| 146 |
+
logger.info(
|
| 147 |
+
f"Loading trainer state from {self.trainer_annealing_state_path}"
|
| 148 |
+
)
|
| 149 |
+
self.trainer_annealing_state = pickle.load(
|
| 150 |
+
open(self.trainer_annealing_state_path, "rb")
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
self.trainer_annealing_state = TrainerAnnealingState()
|
| 154 |
+
|
| 155 |
+
# Load policy optimizer state if it exists
|
| 156 |
+
self.policy_optimizer_path = os.path.join(
|
| 157 |
+
self.save_path, "policy_optimizer_state.pt"
|
| 158 |
+
)
|
| 159 |
+
if os.path.exists(self.policy_optimizer_path):
|
| 160 |
+
logger.info(
|
| 161 |
+
f"Loading policy optimizer state from {self.policy_optimizer_path}"
|
| 162 |
+
)
|
| 163 |
+
self.policy_optimizer.load_state_dict(
|
| 164 |
+
torch.load(self.policy_optimizer_path)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Load critic optimizer state if it exists
|
| 168 |
+
self.critic_optimizer_path = os.path.join(
|
| 169 |
+
self.save_path, "critic_optimizer_state.pt"
|
| 170 |
+
)
|
| 171 |
+
if (
|
| 172 |
+
os.path.exists(self.critic_optimizer_path)
|
| 173 |
+
and self.critic_optimizer is not None
|
| 174 |
+
):
|
| 175 |
+
logger.info(
|
| 176 |
+
f"Loading critic optimizer state from {self.critic_optimizer_path}"
|
| 177 |
+
)
|
| 178 |
+
self.critic_optimizer.load_state_dict(
|
| 179 |
+
torch.load(self.critic_optimizer_path)
|
| 180 |
+
)
|
| 181 |
+
self.device = self.accelerator.device
|
| 182 |
+
self.entropy_coeff = entropy_coeff
|
| 183 |
+
self.entropy_topk = entropy_topk
|
| 184 |
+
self.entropy_mask_regex = entropy_mask_regex
|
| 185 |
+
self.kl_coeff = kl_coeff
|
| 186 |
+
self.gradient_clipping = gradient_clipping
|
| 187 |
+
self.restrict_tokens = restrict_tokens
|
| 188 |
+
self.mini_batch_size = mini_batch_size
|
| 189 |
+
self.use_gradient_checkpointing = use_gradient_checkpointing
|
| 190 |
+
self.temperature = temperature
|
| 191 |
+
self.use_gae = use_gae
|
| 192 |
+
self.whiten_advantages = whiten_advantages
|
| 193 |
+
self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise
|
| 194 |
+
self.use_rloo = use_rloo
|
| 195 |
+
self.skip_discounted_state_visitation = skip_discounted_state_visitation
|
| 196 |
+
self.use_gae_lambda_annealing = use_gae_lambda_annealing
|
| 197 |
+
self.gae_lambda_annealing_limit = gae_lambda_annealing_limit
|
| 198 |
+
if use_gae_lambda_annealing:
|
| 199 |
+
self.gae_lambda_annealing_method: Callable[
|
| 200 |
+
[int], float
|
| 201 |
+
] = lambda step: eval(gae_lambda_annealing_method)(
|
| 202 |
+
step=step, **gae_lambda_annealing_method_params
|
| 203 |
+
)
|
| 204 |
+
self.discount_factor = discount_factor
|
| 205 |
+
self.enable_tokenwise_logging = enable_tokenwise_logging
|
| 206 |
+
self.reward_normalizing_constant = reward_normalizing_constant
|
| 207 |
+
self.pg_loss_normalization = pg_loss_normalization
|
| 208 |
+
self.critic_loss_type = critic_loss_type
|
| 209 |
+
self.exploration_prompts_to_remove = exploration_prompts_to_remove
|
| 210 |
+
# Common containers used by all trainers
|
| 211 |
+
self.training_data: dict = {}
|
| 212 |
+
self.debug_path_list: list[str] = []
|
| 213 |
+
self.policy_gradient_data = None
|
| 214 |
+
self.tally = Tally()
|
| 215 |
+
self.rollout_tally = RolloutTally()
|
| 216 |
+
self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None
|
| 217 |
+
self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl
|
| 218 |
+
self.truncated_importance_sampling_ratio_cap = (
|
| 219 |
+
truncated_importance_sampling_ratio_cap
|
| 220 |
+
)
|
| 221 |
+
self.importance_sampling_strategy = importance_sampling_strategy
|
| 222 |
+
self.no_rloo_grouping = no_rloo_grouping
|
| 223 |
+
|
| 224 |
+
def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor:
|
| 225 |
+
"""
|
| 226 |
+
Masks logits so that only allowed tokens (as specified in config.restrict_tokens)
|
| 227 |
+
and the EOS token are active.
|
| 228 |
+
All other logits are set to -inf, effectively removing them from the softmax.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
logits (torch.Tensor): The logits tensor of shape (B, S, V).
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
torch.Tensor: The masked logits tensor.
|
| 235 |
+
"""
|
| 236 |
+
# Gradients flow only through the kept logits; masking is recomputed per batch for clarity.
|
| 237 |
+
|
| 238 |
+
if self.restrict_tokens is not None:
|
| 239 |
+
allowed_token_ids = []
|
| 240 |
+
for token in self.restrict_tokens:
|
| 241 |
+
token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"]
|
| 242 |
+
allowed_token_ids.append(token_ids[0])
|
| 243 |
+
allowed_token_ids.append(
|
| 244 |
+
self.tokenizer.eos_token_id
|
| 245 |
+
) # This token should always be active
|
| 246 |
+
allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device)
|
| 247 |
+
# Mask log_probs and probs to only allowed tokens
|
| 248 |
+
mask = torch.zeros_like(logits).bool() # (B, S, V)
|
| 249 |
+
mask[..., allowed_token_ids] = True
|
| 250 |
+
logits = torch.where(
|
| 251 |
+
mask,
|
| 252 |
+
logits,
|
| 253 |
+
torch.tensor(-float("inf"), device=logits.device),
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
return logits
|
| 257 |
+
|
| 258 |
+
def apply_reinforce_step(
|
| 259 |
+
self,
|
| 260 |
+
training_batch: TrainingBatch,
|
| 261 |
+
) -> None:
|
| 262 |
+
"""
|
| 263 |
+
Applies a single REINFORCE policy gradient step using the provided batch of rollouts.
|
| 264 |
+
Handles batching, loss computation (including entropy and KL regularization), gradient accumulation, and optimizer step.
|
| 265 |
+
Optionally logs various metrics and statistics.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
paths (list[str]): List of game complete file paths for each rollout.
|
| 269 |
+
contexts (list[torch.Tensor]): List of context tensors for each rollout.
|
| 270 |
+
credits (list[torch.Tensor]): List of credit tensors (rewards/advantages) for each rollout.
|
| 271 |
+
action_masks (list[torch.Tensor]): List of action mask tensors for each rollout.
|
| 272 |
+
"""
|
| 273 |
+
with resource_logger_context(logger, "Apply reinforce step"):
|
| 274 |
+
self.policy.train()
|
| 275 |
+
mb_size = self.mini_batch_size
|
| 276 |
+
nb_rollouts = len(training_batch)
|
| 277 |
+
|
| 278 |
+
# Initialize running mean logs
|
| 279 |
+
running_mean_logs = {
|
| 280 |
+
"rl_objective": 0.0,
|
| 281 |
+
"policy_gradient_loss": 0.0,
|
| 282 |
+
"policy_gradient_norm": 0.0,
|
| 283 |
+
"log_probs": 0.0,
|
| 284 |
+
"credits": 0.0,
|
| 285 |
+
"entropy": 0.0,
|
| 286 |
+
"engine_log_probs_diff_clampfrac": 0.0,
|
| 287 |
+
"tis_imp_ratio": 0.0,
|
| 288 |
+
"ref_log_probs_diff_clampfrac": 0.0,
|
| 289 |
+
"higher_refprob_frac": 0.0,
|
| 290 |
+
"tis_imp_ratio_clampfrac": 0.0,
|
| 291 |
+
}
|
| 292 |
+
if self.entropy_coeff != 0.0:
|
| 293 |
+
running_mean_logs["entropy"] = 0.0
|
| 294 |
+
if self.kl_coeff != 0.0:
|
| 295 |
+
running_mean_logs["kl_divergence"] = 0.0
|
| 296 |
+
|
| 297 |
+
# Get total number of tokens generated
|
| 298 |
+
total_tokens_generated = 0
|
| 299 |
+
for att_mask in training_batch.batch_action_mask:
|
| 300 |
+
total_tokens_generated += att_mask.sum()
|
| 301 |
+
|
| 302 |
+
# Obtain loss normalization
|
| 303 |
+
if self.pg_loss_normalization == "nb_tokens":
|
| 304 |
+
normalization_factor = total_tokens_generated
|
| 305 |
+
elif self.pg_loss_normalization == "batch":
|
| 306 |
+
normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int)
|
| 307 |
+
else:
|
| 308 |
+
raise ValueError(
|
| 309 |
+
f"Invalid pg_loss_normalization: {self.pg_loss_normalization}"
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# Gradient accumulation for each mini-batch
|
| 313 |
+
for mb in range(0, nb_rollouts, mb_size):
|
| 314 |
+
logger.info(f"Processing mini-batch {mb} of {nb_rollouts}")
|
| 315 |
+
loss = 0.0
|
| 316 |
+
training_mb = training_batch[mb : mb + mb_size]
|
| 317 |
+
training_mb = training_mb.get_padded_tensors()
|
| 318 |
+
training_mb.to(self.device)
|
| 319 |
+
(
|
| 320 |
+
tokens_mb,
|
| 321 |
+
action_mask_mb,
|
| 322 |
+
entropy_mask_mb,
|
| 323 |
+
credits_mb,
|
| 324 |
+
engine_log_probs_mb,
|
| 325 |
+
timesteps_mb,
|
| 326 |
+
) = (
|
| 327 |
+
training_mb.batch_input_ids,
|
| 328 |
+
training_mb.batch_action_mask,
|
| 329 |
+
training_mb.batch_entropy_mask,
|
| 330 |
+
training_mb.batch_credits,
|
| 331 |
+
training_mb.batch_engine_log_probs,
|
| 332 |
+
training_mb.batch_timesteps,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Next token prediction
|
| 336 |
+
contexts_mb = tokens_mb[:, :-1]
|
| 337 |
+
shifted_contexts_mb = tokens_mb[:, 1:]
|
| 338 |
+
action_mask_mb = action_mask_mb[:, 1:]
|
| 339 |
+
entropy_mask_mb = entropy_mask_mb[:, 1:]
|
| 340 |
+
credits_mb = credits_mb[:, 1:]
|
| 341 |
+
engine_log_probs_mb = engine_log_probs_mb[:, 1:]
|
| 342 |
+
timesteps_mb = timesteps_mb[:, 1:]
|
| 343 |
+
|
| 344 |
+
if self.enable_tokenwise_logging:
|
| 345 |
+
self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb)
|
| 346 |
+
self.tokenwise_tally.set_range(range=(mb, mb + mb_size))
|
| 347 |
+
self.tokenwise_tally.add_contexts(contexts=contexts_mb)
|
| 348 |
+
self.tokenwise_tally.add_data(
|
| 349 |
+
metric_id="next_token",
|
| 350 |
+
metrics=shifted_contexts_mb,
|
| 351 |
+
to_tids=True,
|
| 352 |
+
)
|
| 353 |
+
self.tokenwise_tally.add_data(
|
| 354 |
+
metric_id="entropy_mask",
|
| 355 |
+
metrics=entropy_mask_mb,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
if self.enable_tokenwise_logging:
|
| 359 |
+
self.tokenwise_tally.add_data(
|
| 360 |
+
metric_id="next_token_credit", metrics=credits_mb
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
# Forward pass + cast to FP-32 for higher prec. Causal LM attention masks are implicit;
|
| 364 |
+
# wire up a custom mask here only if the policy deviates from standard autoregressive behavior.
|
| 365 |
+
logits = self.policy(input_ids=contexts_mb)[0] # (B, S, V)
|
| 366 |
+
|
| 367 |
+
# Mask non-restricted tokens
|
| 368 |
+
if self.restrict_tokens is not None:
|
| 369 |
+
logits = self.mask_non_restricted_token_logits(logits)
|
| 370 |
+
|
| 371 |
+
logits /= self.temperature # (B, S, V)
|
| 372 |
+
|
| 373 |
+
# Compute new log probabilities
|
| 374 |
+
log_probs = F.log_softmax(logits, dim=-1) # (B, S, V)
|
| 375 |
+
|
| 376 |
+
# Get log probabilities of actions taken during rollouts
|
| 377 |
+
action_log_probs = log_probs.gather(
|
| 378 |
+
dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
|
| 379 |
+
).squeeze(
|
| 380 |
+
-1
|
| 381 |
+
) # (B, S)
|
| 382 |
+
if self.pg_loss_normalization == "batch":
|
| 383 |
+
den_running_mean = action_mask_mb.sum() * normalization_factor
|
| 384 |
+
else:
|
| 385 |
+
den_running_mean = normalization_factor
|
| 386 |
+
running_mean_logs["log_probs"] += (
|
| 387 |
+
action_log_probs * action_mask_mb
|
| 388 |
+
).sum().item() / den_running_mean
|
| 389 |
+
running_mean_logs["credits"] += (
|
| 390 |
+
credits_mb * action_mask_mb
|
| 391 |
+
).sum().item() / den_running_mean
|
| 392 |
+
|
| 393 |
+
if self.enable_tokenwise_logging:
|
| 394 |
+
self.tokenwise_tally.add_data(
|
| 395 |
+
metric_id="next_token_log_prob",
|
| 396 |
+
metrics=action_log_probs,
|
| 397 |
+
)
|
| 398 |
+
self.tokenwise_tally.add_data(
|
| 399 |
+
metric_id="engine_next_token_log_prob",
|
| 400 |
+
metrics=engine_log_probs_mb,
|
| 401 |
+
)
|
| 402 |
+
self.tokenwise_tally.add_data(
|
| 403 |
+
metric_id="next_token_prob",
|
| 404 |
+
metrics=torch.exp(action_log_probs),
|
| 405 |
+
)
|
| 406 |
+
top_k_indices = torch.topk(logits, k=5, dim=-1).indices
|
| 407 |
+
self.tokenwise_tally.add_data(
|
| 408 |
+
metric_id=f"top_{5}_tids",
|
| 409 |
+
metrics=top_k_indices,
|
| 410 |
+
to_tids=True,
|
| 411 |
+
)
|
| 412 |
+
self.tokenwise_tally.add_data(
|
| 413 |
+
metric_id=f"top_{5}_probs",
|
| 414 |
+
metrics=torch.exp(log_probs).gather(
|
| 415 |
+
dim=-1, index=top_k_indices
|
| 416 |
+
),
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
rewarded_action_log_probs = (
|
| 420 |
+
action_mask_mb * credits_mb * action_log_probs
|
| 421 |
+
)
|
| 422 |
+
# (B, S)
|
| 423 |
+
INVALID_LOGPROB = 1.0
|
| 424 |
+
CLAMP_VALUE = 40.0
|
| 425 |
+
masked_action_log_probs = torch.masked_fill(
|
| 426 |
+
action_log_probs, ~action_mask_mb, INVALID_LOGPROB
|
| 427 |
+
)
|
| 428 |
+
masked_engine_log_probs = torch.masked_fill(
|
| 429 |
+
engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB
|
| 430 |
+
)
|
| 431 |
+
with torch.no_grad():
|
| 432 |
+
action_engine_log_probs_diff = (
|
| 433 |
+
masked_action_log_probs - masked_engine_log_probs
|
| 434 |
+
).clamp(-CLAMP_VALUE, CLAMP_VALUE)
|
| 435 |
+
running_mean_logs["engine_log_probs_diff_clampfrac"] += (
|
| 436 |
+
action_engine_log_probs_diff.abs()
|
| 437 |
+
.eq(CLAMP_VALUE)
|
| 438 |
+
.float()
|
| 439 |
+
.sum()
|
| 440 |
+
.item()
|
| 441 |
+
/ den_running_mean
|
| 442 |
+
)
|
| 443 |
+
if self.importance_sampling_strategy == "per_sequence":
|
| 444 |
+
tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff)
|
| 445 |
+
for mb_idx in range(action_engine_log_probs_diff.shape[0]):
|
| 446 |
+
valid_token_mask = action_mask_mb[mb_idx]
|
| 447 |
+
timestep_ids = timesteps_mb[mb_idx][valid_token_mask]
|
| 448 |
+
timestep_logprob_diffs = action_engine_log_probs_diff[mb_idx][
|
| 449 |
+
valid_token_mask
|
| 450 |
+
]
|
| 451 |
+
max_timestep = int(timestep_ids.max().item()) + 1
|
| 452 |
+
timestep_sums = torch.zeros(
|
| 453 |
+
max_timestep,
|
| 454 |
+
device=action_engine_log_probs_diff.device,
|
| 455 |
+
dtype=action_engine_log_probs_diff.dtype,
|
| 456 |
+
)
|
| 457 |
+
timestep_sums.scatter_add_(
|
| 458 |
+
0, timestep_ids, timestep_logprob_diffs
|
| 459 |
+
)
|
| 460 |
+
timestep_ratios = torch.exp(timestep_sums)
|
| 461 |
+
tis_imp_ratio[
|
| 462 |
+
mb_idx, valid_token_mask
|
| 463 |
+
] = timestep_ratios.gather(0, timestep_ids)
|
| 464 |
+
else:
|
| 465 |
+
tis_imp_ratio = torch.exp(action_engine_log_probs_diff)
|
| 466 |
+
running_mean_logs["tis_imp_ratio"] += (
|
| 467 |
+
tis_imp_ratio * action_mask_mb
|
| 468 |
+
).sum().item() / den_running_mean
|
| 469 |
+
if self.truncated_importance_sampling_ratio_cap > 0.0:
|
| 470 |
+
tis_imp_ratio = torch.clamp(
|
| 471 |
+
tis_imp_ratio, max=self.truncated_importance_sampling_ratio_cap
|
| 472 |
+
)
|
| 473 |
+
running_mean_logs["tis_imp_ratio_clampfrac"] += (
|
| 474 |
+
tis_imp_ratio.eq(self.truncated_importance_sampling_ratio_cap)
|
| 475 |
+
.float()
|
| 476 |
+
.sum()
|
| 477 |
+
.item()
|
| 478 |
+
) / den_running_mean
|
| 479 |
+
rewarded_action_log_probs = (
|
| 480 |
+
rewarded_action_log_probs * tis_imp_ratio
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if self.enable_tokenwise_logging:
|
| 484 |
+
self.tokenwise_tally.add_data(
|
| 485 |
+
metric_id="next_token_clogπ",
|
| 486 |
+
metrics=rewarded_action_log_probs,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# Add value term to loss
|
| 490 |
+
if self.pg_loss_normalization == "batch":
|
| 491 |
+
nb_act_tokens = action_mask_mb.sum()
|
| 492 |
+
mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens
|
| 493 |
+
else:
|
| 494 |
+
mb_value = -rewarded_action_log_probs.sum()
|
| 495 |
+
|
| 496 |
+
loss += mb_value
|
| 497 |
+
running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean
|
| 498 |
+
|
| 499 |
+
# -------------------------------------------------
|
| 500 |
+
# Entropy Regularization
|
| 501 |
+
# -------------------------------------------------
|
| 502 |
+
# Only apply entropy on distribution defined over most probable tokens
|
| 503 |
+
if self.entropy_topk is not None:
|
| 504 |
+
top_k_indices = torch.topk(
|
| 505 |
+
logits, k=self.entropy_topk, dim=-1
|
| 506 |
+
).indices
|
| 507 |
+
entropy_logits = logits.gather(dim=-1, index=top_k_indices)
|
| 508 |
+
else:
|
| 509 |
+
entropy_logits = logits
|
| 510 |
+
|
| 511 |
+
token_entropy_terms = -F.softmax(
|
| 512 |
+
entropy_logits, dim=-1
|
| 513 |
+
) * F.log_softmax(
|
| 514 |
+
entropy_logits, dim=-1
|
| 515 |
+
) # (B, S, T)
|
| 516 |
+
token_entropy_terms *= (
|
| 517 |
+
action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None]
|
| 518 |
+
) # only get loss on specific action tokens
|
| 519 |
+
|
| 520 |
+
mb_entropy = token_entropy_terms.sum(dim=-1)
|
| 521 |
+
|
| 522 |
+
if self.enable_tokenwise_logging:
|
| 523 |
+
self.tokenwise_tally.add_data(
|
| 524 |
+
metric_id="entropy",
|
| 525 |
+
metrics=mb_entropy,
|
| 526 |
+
)
|
| 527 |
+
if self.pg_loss_normalization == "batch":
|
| 528 |
+
nb_act_tokens = action_mask_mb.sum()
|
| 529 |
+
mb_entropy = -mb_entropy.sum() / nb_act_tokens
|
| 530 |
+
else:
|
| 531 |
+
mb_entropy = -mb_entropy.sum()
|
| 532 |
+
running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean
|
| 533 |
+
if self.entropy_coeff != 0.0:
|
| 534 |
+
mb_entropy *= self.entropy_coeff
|
| 535 |
+
loss += mb_entropy
|
| 536 |
+
|
| 537 |
+
# -------------------------------------------------
|
| 538 |
+
# KL-DIVERGENCE
|
| 539 |
+
# -------------------------------------------------
|
| 540 |
+
if self.kl_coeff != 0.0:
|
| 541 |
+
ref_model_logits = self.policy.get_base_model_logits(contexts_mb)
|
| 542 |
+
ref_model_logits = ref_model_logits / self.temperature
|
| 543 |
+
# (B, S, V)
|
| 544 |
+
ref_model_logits = self.mask_non_restricted_token_logits(
|
| 545 |
+
logits=ref_model_logits
|
| 546 |
+
)
|
| 547 |
+
# (B, S, V)
|
| 548 |
+
ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1)
|
| 549 |
+
# (B, S, V)
|
| 550 |
+
ref_model_action_log_probs = ref_model_log_probs.gather(
|
| 551 |
+
dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
|
| 552 |
+
).squeeze(
|
| 553 |
+
-1
|
| 554 |
+
) # (B,S)
|
| 555 |
+
# Approximating KL Divergence (see refs in docstring)
|
| 556 |
+
# Ref 1: http://joschu.net/blog/kl-approx.html
|
| 557 |
+
# Ref 2: https://github.dev/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1332
|
| 558 |
+
masked_ref_model_action_log_probs = torch.masked_fill(
|
| 559 |
+
ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB
|
| 560 |
+
)
|
| 561 |
+
action_log_probs_diff = (
|
| 562 |
+
masked_ref_model_action_log_probs - masked_action_log_probs
|
| 563 |
+
).clamp(-CLAMP_VALUE, CLAMP_VALUE)
|
| 564 |
+
running_mean_logs["ref_log_probs_diff_clampfrac"] += (
|
| 565 |
+
action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item()
|
| 566 |
+
/ den_running_mean
|
| 567 |
+
)
|
| 568 |
+
if self.filter_higher_refprob_tokens_kl:
|
| 569 |
+
higher_refprob_tokens_mask = action_log_probs_diff > 0.0
|
| 570 |
+
running_mean_logs["higher_refprob_frac"] += (
|
| 571 |
+
higher_refprob_tokens_mask.sum().item() / den_running_mean
|
| 572 |
+
)
|
| 573 |
+
action_log_probs_diff = action_log_probs_diff * (
|
| 574 |
+
~higher_refprob_tokens_mask
|
| 575 |
+
)
|
| 576 |
+
kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff
|
| 577 |
+
kl_div *= action_mask_mb # We only care about KLD of action tokens
|
| 578 |
+
if self.truncated_importance_sampling_ratio_cap > 0.0:
|
| 579 |
+
kl_div = kl_div * tis_imp_ratio
|
| 580 |
+
kl_div *= self.kl_coeff
|
| 581 |
+
if self.enable_tokenwise_logging:
|
| 582 |
+
self.tokenwise_tally.add_data(
|
| 583 |
+
metric_id="ref_model_next_token_log_prob",
|
| 584 |
+
metrics=ref_model_action_log_probs,
|
| 585 |
+
)
|
| 586 |
+
self.tokenwise_tally.add_data(
|
| 587 |
+
metric_id="kl_divergence",
|
| 588 |
+
metrics=kl_div,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
if self.pg_loss_normalization == "batch":
|
| 592 |
+
nb_act_tokens = action_mask_mb.sum()
|
| 593 |
+
mb_kl = kl_div.sum() / nb_act_tokens
|
| 594 |
+
else:
|
| 595 |
+
mb_kl = kl_div.sum()
|
| 596 |
+
running_mean_logs["kl_divergence"] += (
|
| 597 |
+
mb_kl.item() / den_running_mean
|
| 598 |
+
)
|
| 599 |
+
loss += mb_kl
|
| 600 |
+
|
| 601 |
+
# Accumulate gradient
|
| 602 |
+
running_mean_logs["policy_gradient_loss"] += (
|
| 603 |
+
loss.item() / den_running_mean
|
| 604 |
+
)
|
| 605 |
+
loss /= normalization_factor
|
| 606 |
+
self.accelerator.backward(loss)
|
| 607 |
+
|
| 608 |
+
# ensure gpu memory is freed
|
| 609 |
+
del training_mb
|
| 610 |
+
del log_probs
|
| 611 |
+
del logits
|
| 612 |
+
del loss
|
| 613 |
+
del action_log_probs
|
| 614 |
+
del rewarded_action_log_probs
|
| 615 |
+
|
| 616 |
+
logger.info(
|
| 617 |
+
f"Accumulated the policy gradient loss for {total_tokens_generated} tokens."
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
# Clip gradients and take step
|
| 621 |
+
if self.gradient_clipping is not None:
|
| 622 |
+
grad_norm = self.accelerator.clip_grad_norm_(
|
| 623 |
+
self.policy.parameters(), self.gradient_clipping
|
| 624 |
+
)
|
| 625 |
+
running_mean_logs["policy_gradient_norm"] += grad_norm.item()
|
| 626 |
+
|
| 627 |
+
# Take step
|
| 628 |
+
self.policy_optimizer.step()
|
| 629 |
+
self.policy_optimizer.zero_grad()
|
| 630 |
+
|
| 631 |
+
# Store logs
|
| 632 |
+
for key, value in running_mean_logs.items():
|
| 633 |
+
self.tally.add_metric(path=key, metric=value)
|
| 634 |
+
|
| 635 |
+
# Clear accelerator state so we do not accumulate references between optimizer steps.
|
| 636 |
+
self.accelerator.clear(self.policy, self.policy_optimizer)
|
| 637 |
+
import gc
|
| 638 |
+
|
| 639 |
+
gc.collect()
|
| 640 |
+
torch.cuda.empty_cache()
|
| 641 |
+
return running_mean_logs
|
| 642 |
+
|
| 643 |
+
def get_advantages_with_critic_gradient_accumulation(
|
| 644 |
+
self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0
|
| 645 |
+
) -> torch.FloatTensor:
|
| 646 |
+
"""
|
| 647 |
+
Compute (and optionally whiten) advantages while training the critic in mini-batches.
|
| 648 |
+
Uses GAE if enabled, otherwise uses Monte Carlo returns.
|
| 649 |
+
Optionally trains the critic if GAE is used.
|
| 650 |
+
Returns:
|
| 651 |
+
advantages: NestedFloatTensors
|
| 652 |
+
"""
|
| 653 |
+
|
| 654 |
+
mb_size = self.mini_batch_size
|
| 655 |
+
batch_size = trajectories.rollout_ids.shape[0]
|
| 656 |
+
agent_id = trajectories.agent_ids[0]
|
| 657 |
+
batch_rewards = trajectories.batch_rewards
|
| 658 |
+
|
| 659 |
+
######################################
|
| 660 |
+
# use critic for advantage estimation
|
| 661 |
+
######################################
|
| 662 |
+
if self.use_gae:
|
| 663 |
+
if "buffer" in agent_id:
|
| 664 |
+
self.critic.eval()
|
| 665 |
+
training = False
|
| 666 |
+
else:
|
| 667 |
+
self.critic.train()
|
| 668 |
+
training = True
|
| 669 |
+
advantages = []
|
| 670 |
+
# critic_loss_scaling_factor comes learning single critic for two agents
|
| 671 |
+
normalization_factor = (
|
| 672 |
+
np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor
|
| 673 |
+
)
|
| 674 |
+
# For each minibatch
|
| 675 |
+
for mb in range(0, batch_size, mb_size):
|
| 676 |
+
trajectory_mb = trajectories[mb : mb + mb_size]
|
| 677 |
+
trajectory_mb.to(self.device)
|
| 678 |
+
rewards_mb = trajectory_mb.batch_rewards
|
| 679 |
+
(
|
| 680 |
+
tokens_mb,
|
| 681 |
+
state_ends_mask_mb,
|
| 682 |
+
timestep_counts,
|
| 683 |
+
) = trajectory_mb.get_padded_tensors_for_critic()
|
| 684 |
+
# critic causal attention up to end flags
|
| 685 |
+
if training:
|
| 686 |
+
vals_estimate_full = self.critic(tokens_mb)
|
| 687 |
+
else:
|
| 688 |
+
with torch.no_grad():
|
| 689 |
+
vals_estimate_full = self.critic(tokens_mb)
|
| 690 |
+
|
| 691 |
+
# if vals_estimate_full.dim() == 3:
|
| 692 |
+
# vals_estimate_full = vals_estimate_full.squeeze(-1)
|
| 693 |
+
|
| 694 |
+
# Select only positions where states end, per sample → list of (jT,)
|
| 695 |
+
B = tokens_mb.shape[0]
|
| 696 |
+
vals_list = [
|
| 697 |
+
vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B)
|
| 698 |
+
]
|
| 699 |
+
|
| 700 |
+
# Pad to (B, max_jT) = (B, S)
|
| 701 |
+
vals_estimate_mb = pad_sequence(
|
| 702 |
+
vals_list, batch_first=True, padding_value=0.0
|
| 703 |
+
)
|
| 704 |
+
dtype = vals_estimate_mb.dtype
|
| 705 |
+
rewards_mb = pad_sequence(
|
| 706 |
+
rewards_mb, batch_first=True, padding_value=0.0
|
| 707 |
+
).to(
|
| 708 |
+
dtype=dtype
|
| 709 |
+
) # (B, S)
|
| 710 |
+
self.rollout_tally.add_metric(
|
| 711 |
+
path=["batch_rewards"],
|
| 712 |
+
rollout_tally_item=RolloutTallyItem(
|
| 713 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 714 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 715 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 716 |
+
metric_matrix=rewards_mb,
|
| 717 |
+
),
|
| 718 |
+
)
|
| 719 |
+
if self.reward_normalizing_constant != 1.0:
|
| 720 |
+
rewards_mb /= self.reward_normalizing_constant
|
| 721 |
+
|
| 722 |
+
det_vals_estimate_mb = vals_estimate_mb.detach() # (B, max_jT)
|
| 723 |
+
self.rollout_tally.add_metric(
|
| 724 |
+
path=["mb_value_estimates_critic"],
|
| 725 |
+
rollout_tally_item=RolloutTallyItem(
|
| 726 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 727 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 728 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 729 |
+
metric_matrix=det_vals_estimate_mb,
|
| 730 |
+
),
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# Append a 0 value to the end of the value estimates
|
| 734 |
+
if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]:
|
| 735 |
+
Bsize = det_vals_estimate_mb.shape[0]
|
| 736 |
+
device = det_vals_estimate_mb.device
|
| 737 |
+
dtype = det_vals_estimate_mb.dtype
|
| 738 |
+
det_vals_estimate_mb = torch.cat(
|
| 739 |
+
[
|
| 740 |
+
det_vals_estimate_mb,
|
| 741 |
+
torch.zeros((Bsize, 1), device=device, dtype=dtype),
|
| 742 |
+
],
|
| 743 |
+
dim=1,
|
| 744 |
+
) # (B, max_jT+1)
|
| 745 |
+
else:
|
| 746 |
+
raise ValueError(
|
| 747 |
+
"Incompatible shapes for value estimates and rewards."
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# Get annealed lambda
|
| 751 |
+
if self.use_gae_lambda_annealing:
|
| 752 |
+
annealing_constant = self.gae_lambda_annealing_method(
|
| 753 |
+
step=self.trainer_annealing_state.annealing_step_counter
|
| 754 |
+
)
|
| 755 |
+
annealed_lambda = (
|
| 756 |
+
self.gae_lambda_annealing_limit * annealing_constant
|
| 757 |
+
)
|
| 758 |
+
self.tally.add_metric(
|
| 759 |
+
path="annealed_lambda", metric=annealed_lambda
|
| 760 |
+
)
|
| 761 |
+
else:
|
| 762 |
+
annealed_lambda = self.gae_lambda_annealing_limit
|
| 763 |
+
|
| 764 |
+
# Get GAE advantages
|
| 765 |
+
gae_advantages = get_generalized_advantage_estimates(
|
| 766 |
+
rewards=rewards_mb,
|
| 767 |
+
value_estimates=det_vals_estimate_mb,
|
| 768 |
+
discount_factor=self.discount_factor,
|
| 769 |
+
lambda_coef=annealed_lambda,
|
| 770 |
+
) # (B, max_jT)
|
| 771 |
+
self.rollout_tally.add_metric(
|
| 772 |
+
path=["mb_gae_advantages"],
|
| 773 |
+
rollout_tally_item=RolloutTallyItem(
|
| 774 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 775 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 776 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 777 |
+
metric_matrix=gae_advantages,
|
| 778 |
+
),
|
| 779 |
+
)
|
| 780 |
+
if training:
|
| 781 |
+
targets = (
|
| 782 |
+
gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1]
|
| 783 |
+
) # (B, max_jT) # A(s, a, b) + V(s) = Q(s, a, b)
|
| 784 |
+
self.rollout_tally.add_metric(
|
| 785 |
+
path=["mb_targets_critic"],
|
| 786 |
+
rollout_tally_item=RolloutTallyItem(
|
| 787 |
+
crn_ids=trajectory_mb.crn_ids,
|
| 788 |
+
rollout_ids=trajectory_mb.rollout_ids,
|
| 789 |
+
agent_ids=trajectory_mb.agent_ids,
|
| 790 |
+
metric_matrix=targets,
|
| 791 |
+
),
|
| 792 |
+
)
|
| 793 |
+
if self.critic_loss_type == "mse":
|
| 794 |
+
loss = F.mse_loss(
|
| 795 |
+
input=vals_estimate_mb,
|
| 796 |
+
target=targets,
|
| 797 |
+
)
|
| 798 |
+
elif self.critic_loss_type == "huber":
|
| 799 |
+
loss = F.huber_loss(
|
| 800 |
+
input=vals_estimate_mb,
|
| 801 |
+
target=targets,
|
| 802 |
+
)
|
| 803 |
+
self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item())
|
| 804 |
+
# Accumulate gradient
|
| 805 |
+
loss /= normalization_factor
|
| 806 |
+
self.accelerator.backward(loss)
|
| 807 |
+
del loss
|
| 808 |
+
del targets
|
| 809 |
+
del vals_estimate_mb
|
| 810 |
+
del trajectory_mb
|
| 811 |
+
del vals_estimate_full
|
| 812 |
+
|
| 813 |
+
# Get jagged back using timestep_counts
|
| 814 |
+
advantages.extend(
|
| 815 |
+
[gae_advantages[i, : timestep_counts[i]] for i in range(B)]
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
######################################
|
| 819 |
+
# use exclusively Monte Carlo returns & rloo for advantage estimation
|
| 820 |
+
######################################
|
| 821 |
+
else:
|
| 822 |
+
lengths = [len(c) for c in batch_rewards]
|
| 823 |
+
padded_rewards = pad_sequence(
|
| 824 |
+
batch_rewards, batch_first=True, padding_value=0.0
|
| 825 |
+
)
|
| 826 |
+
self.rollout_tally.add_metric(
|
| 827 |
+
path=["mb_rewards"],
|
| 828 |
+
rollout_tally_item=RolloutTallyItem(
|
| 829 |
+
crn_ids=trajectories.crn_ids,
|
| 830 |
+
rollout_ids=trajectories.rollout_ids,
|
| 831 |
+
agent_ids=trajectories.agent_ids,
|
| 832 |
+
metric_matrix=padded_rewards,
|
| 833 |
+
),
|
| 834 |
+
)
|
| 835 |
+
if self.reward_normalizing_constant != 1.0:
|
| 836 |
+
padded_rewards /= self.reward_normalizing_constant
|
| 837 |
+
padded_advantages = get_discounted_returns(
|
| 838 |
+
rewards=padded_rewards,
|
| 839 |
+
discount_factor=self.discount_factor,
|
| 840 |
+
) # no baseline for now
|
| 841 |
+
if self.use_rloo:
|
| 842 |
+
is_grouped_by_rng = (
|
| 843 |
+
trajectories.crn_ids.unique().shape[0]
|
| 844 |
+
!= trajectories.crn_ids.shape[0]
|
| 845 |
+
)
|
| 846 |
+
if is_grouped_by_rng and not self.no_rloo_grouping:
|
| 847 |
+
for crn_id in trajectories.crn_ids.unique():
|
| 848 |
+
rng_mask = trajectories.crn_ids == crn_id
|
| 849 |
+
rng_advantages = padded_advantages[rng_mask]
|
| 850 |
+
rng_advantages, _ = get_rloo_credits(credits=rng_advantages)
|
| 851 |
+
padded_advantages[rng_mask] = rng_advantages
|
| 852 |
+
else:
|
| 853 |
+
padded_advantages, _ = get_rloo_credits(credits=padded_advantages)
|
| 854 |
+
self.rollout_tally.add_metric(
|
| 855 |
+
path=["mb_rloo_advantages"],
|
| 856 |
+
rollout_tally_item=RolloutTallyItem(
|
| 857 |
+
crn_ids=trajectories.crn_ids,
|
| 858 |
+
rollout_ids=trajectories.rollout_ids,
|
| 859 |
+
agent_ids=trajectories.agent_ids,
|
| 860 |
+
metric_matrix=padded_advantages,
|
| 861 |
+
),
|
| 862 |
+
)
|
| 863 |
+
advantages = [
|
| 864 |
+
padded_advantages[i, : lengths[i]]
|
| 865 |
+
for i in range(padded_advantages.shape[0])
|
| 866 |
+
]
|
| 867 |
+
|
| 868 |
+
if self.whiten_advantages_time_step_wise or self.whiten_advantages:
|
| 869 |
+
lengths = [len(c) for c in advantages]
|
| 870 |
+
padded_advantages = pad_sequence(
|
| 871 |
+
advantages, batch_first=True, padding_value=0.0
|
| 872 |
+
)
|
| 873 |
+
if self.whiten_advantages_time_step_wise:
|
| 874 |
+
whitened_padded_advantages = whiten_advantages_time_step_wise(
|
| 875 |
+
padded_advantages
|
| 876 |
+
)
|
| 877 |
+
path = ["mb_whitened_advantages_time_step_wise"]
|
| 878 |
+
elif self.whiten_advantages:
|
| 879 |
+
whitened_padded_advantages = whiten_advantages(padded_advantages)
|
| 880 |
+
path = ["mb_whitened_advantages"]
|
| 881 |
+
self.rollout_tally.add_metric(
|
| 882 |
+
path=path,
|
| 883 |
+
rollout_tally_item=RolloutTallyItem(
|
| 884 |
+
crn_ids=trajectories.crn_ids,
|
| 885 |
+
rollout_ids=trajectories.rollout_ids,
|
| 886 |
+
agent_ids=trajectories.agent_ids,
|
| 887 |
+
metric_matrix=whitened_padded_advantages,
|
| 888 |
+
),
|
| 889 |
+
)
|
| 890 |
+
advantages = [
|
| 891 |
+
whitened_padded_advantages[i, : lengths[i]]
|
| 892 |
+
for i in range(whitened_padded_advantages.shape[0])
|
| 893 |
+
]
|
| 894 |
+
|
| 895 |
+
self.trainer_annealing_state.annealing_step_counter += 1
|
| 896 |
+
|
| 897 |
+
return advantages
|
| 898 |
+
|
| 899 |
+
@abstractmethod
|
| 900 |
+
def set_agent_trajectory_data(
|
| 901 |
+
self, agent_id: str, roots: list[RolloutTreeRootNode]
|
| 902 |
+
) -> None:
|
| 903 |
+
"""
|
| 904 |
+
Populate self.training_data for a single agent using the provided rollout trees.
|
| 905 |
+
"""
|
| 906 |
+
pass
|
| 907 |
+
|
| 908 |
+
def set_trajectory_data(
|
| 909 |
+
self, roots: list[RolloutTreeRootNode], agent_ids: list[str]
|
| 910 |
+
) -> None:
|
| 911 |
+
"""
|
| 912 |
+
Convenience wrapper to ingest trajectory data for every training agent.
|
| 913 |
+
"""
|
| 914 |
+
for agent_id in agent_ids:
|
| 915 |
+
self.set_agent_trajectory_data(agent_id, roots)
|
| 916 |
+
|
| 917 |
+
@abstractmethod
|
| 918 |
+
def share_advantage_data(self) -> list[AdvantagePacket]:
|
| 919 |
+
pass
|
| 920 |
+
|
| 921 |
+
@abstractmethod
|
| 922 |
+
def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
|
| 923 |
+
pass
|
| 924 |
+
|
| 925 |
+
def set_policy_gradient_data(self, agent_ids: list[str]) -> None:
|
| 926 |
+
"""
|
| 927 |
+
Reset and rebuild the policy-gradient minibatches before iterating through agents.
|
| 928 |
+
"""
|
| 929 |
+
self.policy_gradient_data = None
|
| 930 |
+
for agent_id in agent_ids:
|
| 931 |
+
assert "buffer" not in agent_id, "Buffer agents do not train policy"
|
| 932 |
+
trajectory_batch = self.training_data[agent_id]
|
| 933 |
+
tokenwise_batch_credits = get_tokenwise_credits(
|
| 934 |
+
batch_timesteps=trajectory_batch.batch_timesteps,
|
| 935 |
+
batch_credits=trajectory_batch.batch_credits,
|
| 936 |
+
)
|
| 937 |
+
policy_gradient_data = TrainingBatch(
|
| 938 |
+
rollout_ids=trajectory_batch.rollout_ids,
|
| 939 |
+
batch_input_ids=trajectory_batch.batch_input_ids,
|
| 940 |
+
batch_action_mask=trajectory_batch.batch_action_mask,
|
| 941 |
+
batch_entropy_mask=trajectory_batch.batch_entropy_mask,
|
| 942 |
+
batch_credits=tokenwise_batch_credits,
|
| 943 |
+
batch_engine_log_probs=trajectory_batch.batch_engine_log_probs,
|
| 944 |
+
batch_timesteps=trajectory_batch.batch_timesteps,
|
| 945 |
+
)
|
| 946 |
+
if self.policy_gradient_data is None:
|
| 947 |
+
self.policy_gradient_data = policy_gradient_data
|
| 948 |
+
else:
|
| 949 |
+
self.policy_gradient_data.append(policy_gradient_data)
|
| 950 |
+
|
| 951 |
+
self.training_data = {}
|
| 952 |
+
self.tokenwise_tally = ContextualizedTokenwiseTally(
|
| 953 |
+
tokenizer=self.tokenizer,
|
| 954 |
+
paths=self.debug_path_list,
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
def train(self) -> None:
|
| 958 |
+
"""
|
| 959 |
+
Entry point for policy updates: prepare batches, compute gradients, and update parameters.
|
| 960 |
+
"""
|
| 961 |
+
assert self.policy_gradient_data is not None, "Policy gradient data is not set"
|
| 962 |
+
if self.critic_optimizer is not None:
|
| 963 |
+
if self.gradient_clipping is not None:
|
| 964 |
+
grad_norm = self.accelerator.clip_grad_norm_(
|
| 965 |
+
self.critic.parameters(), self.gradient_clipping
|
| 966 |
+
)
|
| 967 |
+
self.tally.add_metric(
|
| 968 |
+
path="gradient_norm_critic", metric=grad_norm.item()
|
| 969 |
+
)
|
| 970 |
+
# Take step
|
| 971 |
+
self.critic_optimizer.step()
|
| 972 |
+
self.critic_optimizer.zero_grad()
|
| 973 |
+
self.accelerator.clear(self.critic, self.critic_optimizer)
|
| 974 |
+
import gc
|
| 975 |
+
|
| 976 |
+
gc.collect()
|
| 977 |
+
torch.cuda.empty_cache()
|
| 978 |
+
running_mean_logs = self.apply_reinforce_step(
|
| 979 |
+
training_batch=self.policy_gradient_data
|
| 980 |
+
)
|
| 981 |
+
return running_mean_logs
|
| 982 |
+
|
| 983 |
+
def export_training_tally(self, identifier: str, folder: str) -> None:
|
| 984 |
+
"""
|
| 985 |
+
Saves and resets the collected training metrics using the tally object.
|
| 986 |
+
"""
|
| 987 |
+
os.makedirs(folder, exist_ok=True)
|
| 988 |
+
self.tally.save(identifier=identifier, folder=folder)
|
| 989 |
+
self.tokenwise_tally.save(
|
| 990 |
+
path=os.path.join(folder, f"{identifier}_tokenwise.csv")
|
| 991 |
+
)
|
| 992 |
+
self.rollout_tally.save(identifier=identifier, folder=folder)
|
| 993 |
+
self.tally.reset()
|
| 994 |
+
self.tokenwise_tally = None
|
| 995 |
+
self.rollout_tally.reset()
|
| 996 |
+
self.debug_path_list = []
|
| 997 |
+
|
| 998 |
+
def export_optimizer_states(self) -> None:
|
| 999 |
+
"""
|
| 1000 |
+
Saves the optimizer states for both the main model and critic (if it exists).
|
| 1001 |
+
"""
|
| 1002 |
+
try:
|
| 1003 |
+
os.makedirs(self.save_path, exist_ok=True)
|
| 1004 |
+
|
| 1005 |
+
torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path)
|
| 1006 |
+
logger.info(f"Saved main optimizer state to {self.policy_optimizer_path}")
|
| 1007 |
+
|
| 1008 |
+
if self.critic_optimizer is not None:
|
| 1009 |
+
torch.save(
|
| 1010 |
+
self.critic_optimizer.state_dict(), self.critic_optimizer_path
|
| 1011 |
+
)
|
| 1012 |
+
logger.info(
|
| 1013 |
+
f"Saved critic optimizer state to {self.critic_optimizer_path}"
|
| 1014 |
+
)
|
| 1015 |
+
except Exception as e:
|
| 1016 |
+
logger.error(f"Error saving optimizer states: {str(e)}")
|
| 1017 |
+
raise
|
| 1018 |
+
|
| 1019 |
+
def export_trainer_annealing_state(self) -> None:
|
| 1020 |
+
"""
|
| 1021 |
+
Saves the trainer state.
|
| 1022 |
+
"""
|
| 1023 |
+
with open(self.trainer_annealing_state_path, "wb") as f:
|
| 1024 |
+
pickle.dump(self.trainer_annealing_state, f)
|
| 1025 |
+
logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}")
|
| 1026 |
+
|
| 1027 |
+
def export_trainer_states(self) -> None:
|
| 1028 |
+
"""
|
| 1029 |
+
Saves the trainer states.
|
| 1030 |
+
"""
|
| 1031 |
+
self.export_optimizer_states()
|
| 1032 |
+
self.export_trainer_annealing_state()
|
src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (257 Bytes). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc
ADDED
|
Binary file (669 Bytes). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/gather_training_stats.cpython-312.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc
ADDED
|
Binary file (560 Bytes). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/rollout_tree_chat_htmls.cpython-312.pyc
ADDED
|
Binary file (60.2 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc
ADDED
|
Binary file (2.38 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc
ADDED
|
Binary file (722 Bytes). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc
ADDED
|
Binary file (7.76 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc
ADDED
|
Binary file (6.66 kB). View file
|
|
|