Muqeeth commited on
Commit
0caadff
·
verified ·
1 Parent(s): 1f4f273

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .hydra/config.yaml +178 -0
  2. .hydra/hydra.yaml +154 -0
  3. .hydra/overrides.yaml +1 -0
  4. run.log +0 -0
  5. seed_123/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
  6. seed_123/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
  7. seed_123/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
  8. src_code_for_reproducibility/__init__.py +0 -0
  9. src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
  10. src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc +0 -0
  11. src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc +0 -0
  12. src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc +0 -0
  13. src_code_for_reproducibility/docs/source/contributing.rst +0 -0
  14. src_code_for_reproducibility/docs/source/environments/dond.rst +410 -0
  15. src_code_for_reproducibility/docs/source/environments/ipd.rst +411 -0
  16. src_code_for_reproducibility/docs/source/marl_standard.rst +141 -0
  17. src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst +7 -0
  18. src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst +7 -0
  19. src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst +7 -0
  20. src_code_for_reproducibility/docs/source/src.generation.run_games.rst +7 -0
  21. src_code_for_reproducibility/docs/source/src.models.local_llm.rst +7 -0
  22. src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst +7 -0
  23. src_code_for_reproducibility/docs/source/src.models.rst +20 -0
  24. src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst +7 -0
  25. src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst +7 -0
  26. src_code_for_reproducibility/docs/source/src.training.train_main.rst +7 -0
  27. src_code_for_reproducibility/docs/source/src.utils.export_ppo_training_set.rst +7 -0
  28. src_code_for_reproducibility/docs/source/src.utils.extra_stats.rst +7 -0
  29. src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst +7 -0
  30. src_code_for_reproducibility/docs/source/src.utils.parallel_shuffle.rst +7 -0
  31. src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc +0 -0
  32. src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc +0 -0
  33. src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-312.pyc +0 -0
  34. src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc +0 -0
  35. src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py +259 -0
  36. src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging_for_training.py +0 -0
  37. src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc +0 -0
  38. src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc +0 -0
  39. src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc +0 -0
  40. src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc +0 -0
  41. src_code_for_reproducibility/markov_games/negotiation/tas_agent.py +108 -0
  42. src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc +0 -0
  43. src_code_for_reproducibility/training/credit_methods.py +295 -0
  44. src_code_for_reproducibility/training/tally_tokenwise.py +276 -0
  45. src_code_for_reproducibility/training/tokenize_chats.py +128 -0
  46. src_code_for_reproducibility/training/trainer_ad_align.py +492 -0
  47. src_code_for_reproducibility/training/trainer_common.py +1054 -0
  48. src_code_for_reproducibility/training/trainer_independent.py +155 -0
  49. src_code_for_reproducibility/training/trainer_sum_rewards.py +127 -0
  50. src_code_for_reproducibility/training/training_data_utils.py +394 -0
.hydra/config.yaml ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ experiment:
2
+ wandb_enabled: true
3
+ nb_epochs: 3000
4
+ nb_matches_per_iteration: 64
5
+ reinit_matches_each_it: true
6
+ checkpoint_every_n_iterations: 10
7
+ start_epoch: 0
8
+ resume_experiment: true
9
+ base_seed: 123
10
+ seed_group_size: 8
11
+ train: true
12
+ stat_methods_for_live_wandb: mllm.markov_games.negotiation.negotiation_statistics
13
+ name: no_press_10_1_ties_ad_align_nocurrtimestep_seed123
14
+ agent_buffer: true
15
+ keep_agent_buffer_count: ${lora_count}
16
+ agent_buffer_recent_k: -1
17
+ logging:
18
+ wandb:
19
+ enabled: false
20
+ project: llm-negotiation
21
+ entity: null
22
+ mode: online
23
+ name: null
24
+ group: null
25
+ tags: []
26
+ notes: null
27
+ temperature: 1.0
28
+ markov_games:
29
+ runner_method_name: LinearRunner
30
+ runner_kwargs: {}
31
+ group_by_round: true
32
+ simulation_class_name: NoPressSimulation
33
+ simulation_init_args:
34
+ nb_of_rounds: 10
35
+ quota_messages_per_agent_per_round: 0
36
+ game_type: 10-1-ties
37
+ atleast_one_conflict: true
38
+ item_types:
39
+ - hats
40
+ - books
41
+ - balls
42
+ agents:
43
+ 0:
44
+ agent_id: ${agent_0_id}
45
+ agent_name: Alice
46
+ agent_class_name: NoPressAgent
47
+ policy_id: base_llm/agent_adapter
48
+ init_kwargs:
49
+ goal: Maximize your total points over the whole game.
50
+ 1:
51
+ agent_id: ${agent_1_id}
52
+ agent_name: Bob
53
+ agent_class_name: NoPressAgent
54
+ policy_id: base_llm/agent_adapter
55
+ init_kwargs:
56
+ goal: Maximize your total points over the whole game.
57
+ models:
58
+ base_llm:
59
+ class: LeanLocalLLM
60
+ init_args:
61
+ llm_id: base_llm
62
+ model_name: Qwen/Qwen2.5-7B-Instruct
63
+ inference_backend: vllm
64
+ hf_kwargs:
65
+ device_map: auto
66
+ torch_dtype: bfloat16
67
+ max_memory:
68
+ 0: 20GiB
69
+ attn_implementation: flash_attention_2
70
+ inference_backend_init_kwargs:
71
+ enable_lora: true
72
+ seed: ${experiment.base_seed}
73
+ enable_prefix_caching: true
74
+ max_model_len: 10000.0
75
+ gpu_memory_utilization: 0.5
76
+ dtype: bfloat16
77
+ trust_remote_code: true
78
+ max_lora_rank: 32
79
+ enforce_eager: false
80
+ max_loras: ${lora_count}
81
+ max_cpu_loras: ${lora_count}
82
+ enable_sleep_mode: true
83
+ inference_backend_sampling_params:
84
+ temperature: ${temperature}
85
+ top_p: 1.0
86
+ max_tokens: 400
87
+ top_k: -1
88
+ logprobs: 0
89
+ adapter_configs:
90
+ agent_adapter:
91
+ task_type: CAUSAL_LM
92
+ r: 32
93
+ lora_alpha: 64
94
+ lora_dropout: 0.0
95
+ target_modules: all-linear
96
+ critic_adapter:
97
+ task_type: CAUSAL_LM
98
+ r: 32
99
+ lora_alpha: 64
100
+ lora_dropout: 0.0
101
+ target_modules: all-linear
102
+ enable_thinking: null
103
+ regex_max_attempts: 3
104
+ critics:
105
+ agent_critic:
106
+ module_pointer:
107
+ - base_llm
108
+ - critic_adapter
109
+ optimizers:
110
+ agent_optimizer:
111
+ module_pointer:
112
+ - base_llm
113
+ - agent_adapter
114
+ optimizer_class_name: torch.optim.Adam
115
+ init_args:
116
+ lr: 3.0e-06
117
+ weight_decay: 0.0
118
+ critic_optimizer:
119
+ module_pointer: agent_critic
120
+ optimizer_class_name: torch.optim.Adam
121
+ init_args:
122
+ lr: 3.0e-06
123
+ weight_decay: 0.0
124
+ trainers:
125
+ agent_trainer:
126
+ class: TrainerAdAlign
127
+ module_pointers:
128
+ policy:
129
+ - base_llm
130
+ - agent_adapter
131
+ policy_optimizer: agent_optimizer
132
+ critic: agent_critic
133
+ critic_optimizer: critic_optimizer
134
+ kwargs:
135
+ entropy_coeff: 0.0
136
+ entropy_topk: null
137
+ entropy_mask_regex: null
138
+ kl_coeff: 0.001
139
+ gradient_clipping: 1.0
140
+ restrict_tokens: null
141
+ mini_batch_size: 1
142
+ use_gradient_checkpointing: false
143
+ temperature: ${temperature}
144
+ device: cuda:0
145
+ use_gae: false
146
+ whiten_advantages: false
147
+ whiten_advantages_time_step_wise: false
148
+ skip_discounted_state_visitation: true
149
+ use_gae_lambda_annealing: false
150
+ gae_lambda_annealing_method: None
151
+ gae_lambda_annealing_method_params: None
152
+ gae_lambda_annealing_limit: 0.95
153
+ discount_factor: 0.9
154
+ use_rloo: true
155
+ enable_tokenwise_logging: false
156
+ pg_loss_normalization: nb_tokens
157
+ truncated_importance_sampling_ratio_cap: 2.0
158
+ reward_normalizing_constant: 100.0
159
+ ad_align_force_coop_first_step: false
160
+ ad_align_clipping: null
161
+ ad_align_gamma: 0.9
162
+ ad_align_exclude_k_equals_t: true
163
+ ad_align_use_sign: false
164
+ ad_align_beta: 1.0
165
+ use_old_ad_align: true
166
+ use_time_regularization: false
167
+ rloo_branch: false
168
+ reuse_baseline: false
169
+ train_on_which_data:
170
+ agent_trainer: ${agent_ids}
171
+ lora_count: 30
172
+ common_agent_kwargs:
173
+ goal: Maximize your total points over the whole game.
174
+ agent_0_id: Alice
175
+ agent_1_id: Bob
176
+ agent_ids:
177
+ - Alice
178
+ - Bob
.hydra/hydra.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task: []
115
+ job:
116
+ name: run
117
+ chdir: false
118
+ override_dirname: ''
119
+ id: ???
120
+ num: ???
121
+ config_name: no_press_10_1_ties_ad_align_nocurrtimestep_seed123.yaml
122
+ env_set: {}
123
+ env_copy: []
124
+ config:
125
+ override_dirname:
126
+ kv_sep: '='
127
+ item_sep: ','
128
+ exclude_keys: []
129
+ runtime:
130
+ version: 1.3.2
131
+ version_base: '1.1'
132
+ cwd: /scratch/m/muqeeth/llm_negotiation
133
+ config_sources:
134
+ - path: hydra.conf
135
+ schema: pkg
136
+ provider: hydra
137
+ - path: /scratch/m/muqeeth/llm_negotiation/configs
138
+ schema: file
139
+ provider: main
140
+ - path: ''
141
+ schema: structured
142
+ provider: schema
143
+ output_dir: /scratch/m/muqeeth/llm_negotiation/2025_11/no_press_10_1_ties_ad_align_nocurrtimestep_seed123
144
+ choices:
145
+ hydra/env: default
146
+ hydra/callbacks: null
147
+ hydra/job_logging: default
148
+ hydra/hydra_logging: default
149
+ hydra/hydra_help: default
150
+ hydra/help: default
151
+ hydra/sweeper: basic
152
+ hydra/launcher: basic
153
+ hydra/output: default
154
+ verbose: false
.hydra/overrides.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ []
run.log ADDED
The diff for this file is too large to render. See raw diff
 
seed_123/Qwen/Qwen2.5-7B-Instruct/adapters/README.md ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: Qwen/Qwen2.5-7B-Instruct
3
+ library_name: peft
4
+ pipeline_tag: text-generation
5
+ tags:
6
+ - base_model:adapter:Qwen/Qwen2.5-7B-Instruct
7
+ - lora
8
+ - transformers
9
+ ---
10
+
11
+ # Model Card for Model ID
12
+
13
+ <!-- Provide a quick summary of what the model is/does. -->
14
+
15
+
16
+
17
+ ## Model Details
18
+
19
+ ### Model Description
20
+
21
+ <!-- Provide a longer summary of what this model is. -->
22
+
23
+
24
+
25
+ - **Developed by:** [More Information Needed]
26
+ - **Funded by [optional]:** [More Information Needed]
27
+ - **Shared by [optional]:** [More Information Needed]
28
+ - **Model type:** [More Information Needed]
29
+ - **Language(s) (NLP):** [More Information Needed]
30
+ - **License:** [More Information Needed]
31
+ - **Finetuned from model [optional]:** [More Information Needed]
32
+
33
+ ### Model Sources [optional]
34
+
35
+ <!-- Provide the basic links for the model. -->
36
+
37
+ - **Repository:** [More Information Needed]
38
+ - **Paper [optional]:** [More Information Needed]
39
+ - **Demo [optional]:** [More Information Needed]
40
+
41
+ ## Uses
42
+
43
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
44
+
45
+ ### Direct Use
46
+
47
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
48
+
49
+ [More Information Needed]
50
+
51
+ ### Downstream Use [optional]
52
+
53
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
54
+
55
+ [More Information Needed]
56
+
57
+ ### Out-of-Scope Use
58
+
59
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
60
+
61
+ [More Information Needed]
62
+
63
+ ## Bias, Risks, and Limitations
64
+
65
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
66
+
67
+ [More Information Needed]
68
+
69
+ ### Recommendations
70
+
71
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
72
+
73
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
74
+
75
+ ## How to Get Started with the Model
76
+
77
+ Use the code below to get started with the model.
78
+
79
+ [More Information Needed]
80
+
81
+ ## Training Details
82
+
83
+ ### Training Data
84
+
85
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
86
+
87
+ [More Information Needed]
88
+
89
+ ### Training Procedure
90
+
91
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
92
+
93
+ #### Preprocessing [optional]
94
+
95
+ [More Information Needed]
96
+
97
+
98
+ #### Training Hyperparameters
99
+
100
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
101
+
102
+ #### Speeds, Sizes, Times [optional]
103
+
104
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
105
+
106
+ [More Information Needed]
107
+
108
+ ## Evaluation
109
+
110
+ <!-- This section describes the evaluation protocols and provides the results. -->
111
+
112
+ ### Testing Data, Factors & Metrics
113
+
114
+ #### Testing Data
115
+
116
+ <!-- This should link to a Dataset Card if possible. -->
117
+
118
+ [More Information Needed]
119
+
120
+ #### Factors
121
+
122
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
123
+
124
+ [More Information Needed]
125
+
126
+ #### Metrics
127
+
128
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
129
+
130
+ [More Information Needed]
131
+
132
+ ### Results
133
+
134
+ [More Information Needed]
135
+
136
+ #### Summary
137
+
138
+
139
+
140
+ ## Model Examination [optional]
141
+
142
+ <!-- Relevant interpretability work for the model goes here -->
143
+
144
+ [More Information Needed]
145
+
146
+ ## Environmental Impact
147
+
148
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
149
+
150
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
151
+
152
+ - **Hardware Type:** [More Information Needed]
153
+ - **Hours used:** [More Information Needed]
154
+ - **Cloud Provider:** [More Information Needed]
155
+ - **Compute Region:** [More Information Needed]
156
+ - **Carbon Emitted:** [More Information Needed]
157
+
158
+ ## Technical Specifications [optional]
159
+
160
+ ### Model Architecture and Objective
161
+
162
+ [More Information Needed]
163
+
164
+ ### Compute Infrastructure
165
+
166
+ [More Information Needed]
167
+
168
+ #### Hardware
169
+
170
+ [More Information Needed]
171
+
172
+ #### Software
173
+
174
+ [More Information Needed]
175
+
176
+ ## Citation [optional]
177
+
178
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
179
+
180
+ **BibTeX:**
181
+
182
+ [More Information Needed]
183
+
184
+ **APA:**
185
+
186
+ [More Information Needed]
187
+
188
+ ## Glossary [optional]
189
+
190
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
191
+
192
+ [More Information Needed]
193
+
194
+ ## More Information [optional]
195
+
196
+ [More Information Needed]
197
+
198
+ ## Model Card Authors [optional]
199
+
200
+ [More Information Needed]
201
+
202
+ ## Model Card Contact
203
+
204
+ [More Information Needed]
205
+ ### Framework versions
206
+
207
+ - PEFT 0.17.1
seed_123/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "o_proj",
29
+ "v_proj",
30
+ "up_proj",
31
+ "q_proj",
32
+ "k_proj",
33
+ "gate_proj",
34
+ "down_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
seed_123/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "o_proj",
29
+ "v_proj",
30
+ "up_proj",
31
+ "q_proj",
32
+ "k_proj",
33
+ "gate_proj",
34
+ "down_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
src_code_for_reproducibility/__init__.py ADDED
File without changes
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (148 Bytes). View file
 
src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc ADDED
Binary file (3.64 kB). View file
 
src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc ADDED
Binary file (1.32 kB). View file
 
src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc ADDED
Binary file (3.61 kB). View file
 
src_code_for_reproducibility/docs/source/contributing.rst ADDED
File without changes
src_code_for_reproducibility/docs/source/environments/dond.rst ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Deal or No Deal
3
+ =================
4
+
5
+ The Deal or No Deal (DoND) environment provides a multi-agent negotiation interface where players trade
6
+ items with different values. This document describes the API for interacting with the DoND environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ Deal or No Deal is a negotiation game where two agents must agree on how to divide a set of items,
13
+ each of which has different values to each agent. The agents engage in a back-and-forth dialogue to
14
+ determine an allocation of the items, with each trying to maximize their own total value.
15
+
16
+ Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used
17
+ with LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ ### Basic Structure
23
+
24
+ The core mechanics of Deal or No Deal are:
25
+
26
+ 1. Two agents negotiate over a set of items (e.g., books, balls, hats)
27
+ 2. Each item has:
28
+ - A specific quantity (how many of each item is available)
29
+ - A value for each agent (which may differ between agents)
30
+ 3. Agents take turns sending messages to negotiate how to split the items
31
+ 4. Once an agreement is reached, agents finalize the deal
32
+ 5. Points are awarded based on the value of items each agent receives
33
+
34
+ ### Detailed Gameplay
35
+
36
+ #### Setup Phase
37
+
38
+ The game begins with:
39
+ - A set of items (e.g., "book", "hat", "ball")
40
+ - Each item has a quantity (e.g., 6 books, 2 hats, 4 balls)
41
+ - Each agent has private values for each item (e.g., books might be worth 5 points to one agent but only 2 points to the other)
42
+ - Agents are assigned roles (starting negotiator and responding negotiator)
43
+
44
+ #### Negotiation Phase
45
+
46
+ 1. Agents take turns sending free-form text messages to each other
47
+ 2. Messages can include offers, counter-offers, questions, or strategic communication
48
+ 3. There is a maximum number of messages permitted (preventing endless negotiations)
49
+ 4. Either agent can propose to finalize an agreement at any time
50
+
51
+ For example:
52
+ - Agent 1: "I propose I get all the books and you get all the hats and balls."
53
+ - Agent 2: "That doesn't work for me. How about you get 3 books and I get 3 books, all the hats, and all the balls?"
54
+ - Agent 1: "Let me counter-offer: I get 4 books and 2 balls, you get 2 books, all hats, and 2 balls."
55
+
56
+ #### Finalization Phase
57
+
58
+ 1. When an agent wants to finalize a deal, they must specify the exact allocation:
59
+ - How many of each item they receive
60
+ - How many of each item the other agent receives
61
+ 2. The other agent must then either agree (by submitting the same allocation) or reject the finalization
62
+ 3. If both agents submit matching finalizations, the deal is executed
63
+ 4. If finalizations don't match, no agreement is reached, and both agents receive 0 points
64
+
65
+ #### Scoring
66
+
67
+ 1. Each agent's score is calculated based on the value of items they receive
68
+ 2. The formula is: Sum(quantity_of_item_i × value_of_item_i_to_agent)
69
+ 3. If no agreement is reached, both agents receive 0 points
70
+
71
+ ### Example Game
72
+
73
+ Let's walk through a simple example:
74
+
75
+ **Setup:**
76
+ - Items: Books (4), Hats (2), Balls (6)
77
+ - Agent 1 values: Books=5, Hats=1, Balls=2
78
+ - Agent 2 values: Books=3, Hats=6, Balls=1
79
+
80
+ **Negotiation (simplified):**
81
+ 1. Agent 1: "I would like all the books and balls. You can have the hats."
82
+ 2. Agent 2: "That doesn't work for me. Books are valuable. I propose I get all the hats and 2 books, you get 2 books and all the balls."
83
+ 3. Agent 1: "How about I get 3 books and all the balls, and you get 1 book and all the hats?"
84
+ 4. Agent 2: "I accept your proposal."
85
+
86
+ **Finalization:**
87
+ - Agent 1 submits: Agent 1 gets (Books: 3, Hats: 0, Balls: 6), Agent 2 gets (Books: 1, Hats: 2, Balls: 0)
88
+ - Agent 2 submits the same allocation, confirming agreement
89
+
90
+ **Scoring:**
91
+ - Agent 1 score: (3 books × 5) + (0 hats × 1) + (6 balls × 2) = 15 + 0 + 12 = 27 points
92
+ - Agent 2 score: (1 book × 3) + (2 hats × 6) + (0 balls × 1) = 3 + 12 + 0 = 15 points
93
+
94
+ ### Game Variations
95
+
96
+ The DoND environment supports several variations through configuration parameters:
97
+
98
+ #### Different Value Distributions
99
+
100
+ The environment offers multiple ways to assign values to items:
101
+
102
+ 1. **Standard Random Setup (dond_random_setup)**:
103
+ - Items have even-numbered quantities
104
+ - Each agent receives distinct random values for each item
105
+ - Values are drawn from a uniform distribution
106
+
107
+ 2. **Independent Random Values (independent_random_vals)**:
108
+ - Item quantities can be any number in the specified range
109
+ - Values for each agent are drawn independently
110
+ - Creates more varied negotiation scenarios
111
+
112
+ 3. **Bicameral Value Distribution (bicameral_vals_assignator)**:
113
+ - Creates a "high value" and "low value" distribution for each item
114
+ - Each agent values approximately half the items highly and half lowly
115
+ - Values are drawn from normal distributions with different means
116
+ - Creates scenarios with clear trade opportunities
117
+
118
+ #### Visibility Options
119
+
120
+ 1. **Finalization Visibility**:
121
+ - When enabled, both agents can see each other's finalization proposals
122
+ - When disabled, finalization proposals remain private until both are submitted
123
+
124
+ 2. **Other Values Visibility**:
125
+ - When enabled, agents can see each other's value functions
126
+ - When disabled, agents only know their own values
127
+ - Creates information asymmetry and richer negotiation dynamics
128
+
129
+ #### Game Modes
130
+
131
+ 1. **Cooperative Mode ("coop")**:
132
+ - Agents are encouraged to find mutually beneficial solutions
133
+ - Success is measured by the sum of both agents' scores
134
+
135
+ 2. **Competitive Mode ("comp")**:
136
+ - Agents aim to maximize their individual scores
137
+ - Creates more adversarial negotiations
138
+
139
+ #### Round Structure
140
+
141
+ 1. **Single Round**:
142
+ - One negotiation session between the same agents
143
+ - Simple evaluation of negotiation skills
144
+
145
+ 2. **Multiple Rounds**:
146
+ - Agents negotiate multiple times with different item setups
147
+ - Allows for learning and adaptation over time
148
+ - Roles can be swapped between rounds
149
+
150
+ DondEnv
151
+ ------------
152
+
153
+ The ``DondEnv`` class provides an interface to the Deal or No Deal environment that follows the Multi-Agent
154
+ Negotiation Environment standard.
155
+
156
+ .. code-block:: python
157
+
158
+ class DondEnv:
159
+ """
160
+ Multi-Agent Negotiation Environment for Deal or No Deal.
161
+ """
162
+ def __init__(
163
+ self,
164
+ agents,
165
+ mode="coop",
166
+ max_messages=None,
167
+ min_messages=None,
168
+ max_chars_per_message=None,
169
+ rounds_per_game=1,
170
+ random_setup_func=None,
171
+ random_setup_kwargs=None,
172
+ role_assignator_func=None,
173
+ role_assignator_func_kwargs=None,
174
+ finalization_visibility=False,
175
+ other_values_visibility=False,
176
+ random_seed=None
177
+ ):
178
+ """Initialize the Deal or No Deal environment.
179
+
180
+ Args:
181
+ agents: List of agent IDs participating in the game
182
+ mode: Game mode ("coop" or "comp")
183
+ max_messages: Maximum number of messages per agent per round
184
+ min_messages: Minimum number of messages per agent per round
185
+ max_chars_per_message: Maximum characters per message
186
+ rounds_per_game: Number of negotiation rounds to play
187
+ random_setup_func: Function to generate item quantities and values
188
+ random_setup_kwargs: Arguments for the random setup function
189
+ role_assignator_func: Function to assign roles to agents
190
+ role_assignator_func_kwargs: Arguments for the role assignator
191
+ finalization_visibility: Whether agents can see each other's finalizations
192
+ other_values_visibility: Whether agents can see each other's values
193
+ random_seed: Seed for reproducibility
194
+ """
195
+ # ...
196
+
197
+ def reset(self):
198
+ """Reset the environment to an initial state and return the initial observation.
199
+
200
+ Returns:
201
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
202
+ """
203
+ # ...
204
+
205
+ def step(self, actions):
206
+ """Take a step in the environment using the provided actions.
207
+
208
+ Args:
209
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
210
+ Actions can be messages or finalization proposals.
211
+
212
+ Returns:
213
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
214
+ done (bool): Whether the episode has ended.
215
+ info (dict): Additional information about the environment.
216
+ """
217
+ # ...
218
+
219
+ def get_state(self):
220
+ """Retrieve the current state of the game.
221
+
222
+ Returns:
223
+ state (dict): The current state of the game, including items, quantities, values, etc.
224
+ """
225
+ # ...
226
+
227
+ Key Implementation Details
228
+ ~~~~~~~~~~~~~~~~~~~~~~~~~
229
+
230
+ The ``DondEnv`` class implements several key features:
231
+
232
+ 1. **Multi-Agent Support**: The environment tracks two agents and manages their alternating messages.
233
+
234
+ 2. **Turn-Based Dialogue**: The environment enforces turn structure and limits on message count.
235
+
236
+ 3. **Finalization Processing**: The environment validates and processes finalization proposals.
237
+
238
+ 4. **Random Setup**: The environment supports multiple methods of generating negotiation scenarios.
239
+
240
+ 5. **Round Management**: The environment can handle multiple rounds with different setups.
241
+
242
+ Observation Structure
243
+ ~~~~~~~~~~~~~~~~~~~~
244
+
245
+ Each agent receives an observation (state) dictionary with rich information about the game:
246
+
247
+ .. code-block:: python
248
+
249
+ {
250
+ "mode": str, # Game mode ("coop" or "comp")
251
+ "role_values": dict, # Value mappings for each role
252
+ "role_props": dict, # Properties for each role
253
+ "agent_to_role": dict, # Mapping from agent IDs to roles
254
+ "is_new_round": bool, # Whether this is the start of a new round
255
+ "is_new_game": bool, # Whether this is the start of a new game
256
+ "game_over": bool, # Whether the game is over
257
+ "items": list, # List of item names
258
+ "quantities": dict, # Quantities of each item
259
+ "has_finalized": bool, # Whether finalization has been proposed
260
+ "last_message": dict, # The last message sent
261
+ "messages_remaining": dict, # Number of messages each agent can still send
262
+ # And various history tracking fields
263
+ }
264
+
265
+ Action Structure
266
+ ~~~~~~~~~~~~~~~
267
+
268
+ Actions can be:
269
+
270
+ 1. **Text Messages**: Free-form text for negotiation.
271
+ 2. **Finalization Proposals**: Structured data specifying the exact allocation of items.
272
+
273
+ Example finalization format:
274
+
275
+ .. code-block:: python
276
+
277
+ {
278
+ "type": "finalize",
279
+ "allocation": {
280
+ "agent1": {"book": 3, "hat": 0, "ball": 6},
281
+ "agent2": {"book": 1, "hat": 2, "ball": 0}
282
+ }
283
+ }
284
+
285
+ Value Setup Functions
286
+ --------------------
287
+
288
+ The DoND environment provides several functions for setting up item values:
289
+
290
+ .. code-block:: python
291
+
292
+ def dond_random_setup(items, min_quant, max_quant, min_val, max_val, random_seed=None):
293
+ """
294
+ Generates items, even-numbered quantities and distinct random values for each category for both agents.
295
+
296
+ Args:
297
+ items (list): List of items.
298
+ min_quant (int): Minimum quantity per item.
299
+ max_quant (int): Maximum quantity per item.
300
+ min_val (int): Minimum value per item.
301
+ max_val (int): Maximum value per item.
302
+ random_seed (int, optional): Seed for random generation.
303
+
304
+ Returns:
305
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
306
+ """
307
+ # ...
308
+
309
+ def independent_random_vals(items, min_quant, max_quant, min_val, max_val, random_seed=None):
310
+ """
311
+ Generates random quantities and independent random values for both agents.
312
+
313
+ Args:
314
+ Similar to dond_random_setup
315
+
316
+ Returns:
317
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
318
+ """
319
+ # ...
320
+
321
+ def bicameral_vals_assignator(items, min_quant, max_quant, low_val_mean, low_val_std, high_val_mean, high_val_std, random_seed=None):
322
+ """
323
+ Generates values with a bicameral distribution - each agent values half the items highly.
324
+
325
+ Args:
326
+ items (list): List of items.
327
+ min_quant, max_quant: Range for quantities
328
+ low_val_mean, low_val_std: Mean and standard deviation for the "low value" distribution
329
+ high_val_mean, high_val_std: Mean and standard deviation for the "high value" distribution
330
+ random_seed: Seed for reproducibility
331
+
332
+ Returns:
333
+ tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
334
+ """
335
+ # ...
336
+
337
+ Running DoND Games
338
+ ----------------------
339
+
340
+ To run Deal or No Deal games with LLM agents, you can use the following structure:
341
+
342
+ .. code-block:: python
343
+
344
+ from mllm.environments.dond.dond_game import DondEnv
345
+ from mllm.environments.dond.dond_agent import DondAgent
346
+ from src.run_matches import run_batched_matches
347
+
348
+ # Create environment
349
+ env = DondEnv(
350
+ agents=["agent1", "agent2"],
351
+ mode="coop",
352
+ max_messages=10,
353
+ rounds_per_game=1,
354
+ random_setup_func="dond_random_setup",
355
+ random_setup_kwargs={
356
+ "items": ["book", "hat", "ball"],
357
+ "min_quant": 2,
358
+ "max_quant": 8,
359
+ "min_val": 1,
360
+ "max_val": 10
361
+ },
362
+ finalization_visibility=False
363
+ )
364
+
365
+ # Create agent handlers (implementation details would vary)
366
+ agent_handlers = {
367
+ "agent1": DondAgent(agent_id="agent1"),
368
+ "agent2": DondAgent(agent_id="agent2")
369
+ }
370
+
371
+ # Define policy mapping
372
+ policy_mapping = {
373
+ "llm_policy": my_llm_policy_function
374
+ }
375
+
376
+ # Run the game
377
+ game_results = run_batched_matches(
378
+ envs=[env],
379
+ agent_handlers_per_env=[agent_handlers],
380
+ policy_mapping=policy_mapping,
381
+ max_parallel_matches=1
382
+ )
383
+
384
+ Limitations and Considerations
385
+ -----------------------------
386
+
387
+ 1. **Negotiation Complexity**: The open-ended nature of negotiations can be challenging for some LLM agents.
388
+
389
+ 2. **Parsing Challenges**: Extracting structured finalization proposals from free-form text requires robust parsing.
390
+
391
+ 3. **Optimization Opportunities**: Different agents may employ different negotiation strategies to optimize outcomes.
392
+
393
+ 4. **Fairness Evaluation**: The environment allows research into questions of fair division and Pareto optimality.
394
+
395
+ 5. **Strategic Deception**: Agents might strategically misrepresent their true values, adding complexity to negotiations.
396
+
397
+ Advanced Usage
398
+ ------------
399
+
400
+ For advanced usage, you can:
401
+
402
+ 1. **Custom Value Functions**: Create more complex distributions of item values for specific research questions.
403
+
404
+ 2. **Novel Negotiation Scenarios**: Design item sets and values to test specific negotiation skills.
405
+
406
+ 3. **Curriculum Learning**: Create progressively more difficult negotiation scenarios.
407
+
408
+ 4. **Communication Analysis**: Analyze the language and strategies used in successful negotiations.
409
+
410
+ 5. **Multi-Round Dynamics**: Study how agents adapt their strategies over multiple rounds.
src_code_for_reproducibility/docs/source/environments/ipd.rst ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Iterated Prisoner's Dilemma
3
+ =================
4
+
5
+ The Iterated Prisoner's Dilemma environment provides a classic game theory setting for studying cooperation
6
+ and competition between agents. This document describes the API for interacting with the IPD environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ The Prisoner's Dilemma is a fundamental problem in game theory that demonstrates why two rational individuals might not
13
+ cooperate, even when it appears in their best interest to do so. In the iterated version, the same two players
14
+ repeatedly face the same dilemma, allowing for the development of trust or retaliation based on previous interactions.
15
+
16
+ Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used with
17
+ LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ ### Basic Premise
23
+
24
+ The scenario behind the Prisoner's Dilemma is as follows:
25
+
26
+ Two criminals are arrested and imprisoned. Each prisoner is in solitary confinement with no means of communicating with
27
+ the other. The prosecutors lack sufficient evidence to convict the pair on the principal charge, but they have enough
28
+ to convict both on a lesser charge. Simultaneously, the prosecutors offer each prisoner a bargain:
29
+
30
+ - If both prisoners betray each other, each serves 2 years in prison (the "punishment" payoff)
31
+ - If one betrays the other while the other remains silent, the betrayer goes free (the "temptation" payoff) while the
32
+ silent accomplice serves 3 years (the "sucker" payoff)
33
+ - If both remain silent, each serves only 1 year in prison (the "reward" payoff)
34
+
35
+ ### Game Mechanics
36
+
37
+ In our implementation, the choices are simplified to:
38
+ - **C**: Cooperate (remain silent)
39
+ - **D**: Defect (betray the other prisoner)
40
+
41
+ Each round, both players simultaneously choose either C or D, and receive points based on the combination of their choices:
42
+
43
+ - Both choose C: Both receive the "reward" payoff (3 points by default)
44
+ - Both choose D: Both receive the "punishment" payoff (1 point by default)
45
+ - One chooses C, one chooses D: The defector receives the "temptation" payoff (5 points by default), while the cooperator
46
+ receives the "sucker" payoff (0 points by default)
47
+
48
+ ### Example: Single Round
49
+
50
+ Let's see how a single round plays out:
51
+
52
+ 1. Alice and Bob simultaneously make their choices
53
+ 2. If Alice chooses C and Bob chooses C:
54
+ - Alice receives 3 points
55
+ - Bob receives 3 points
56
+ 3. If Alice chooses C and Bob chooses D:
57
+ - Alice receives 0 points
58
+ - Bob receives 5 points
59
+ 4. If Alice chooses D and Bob chooses C:
60
+ - Alice receives 5 points
61
+ - Bob receives 0 points
62
+ 5. If Alice chooses D and Bob chooses D:
63
+ - Alice receives 1 point
64
+ - Bob receives 1 point
65
+
66
+ ### Iterated Game Structure
67
+
68
+ The iterated version repeats this basic game for a fixed number of rounds. The key features are:
69
+
70
+ 1. Players know the total number of rounds in advance
71
+ 2. After each round, players learn what choice the other player made
72
+ 3. Players maintain a cumulative score across all rounds
73
+ 4. Players can adjust their strategy based on the history of previous interactions
74
+
75
+ ### Game Variations
76
+
77
+ The IPD environment supports several variations through configuration parameters:
78
+
79
+ #### Different Payoff Matrices
80
+
81
+ The standard payoff values can be modified to create different incentive structures:
82
+ - **Traditional PD**: reward=3, punishment=1, temptation=5, sucker=0
83
+ - **Weak Temptation**: reward=3, punishment=1, temptation=4, sucker=0 (reduces the incentive to defect)
84
+ - **Harsh Punishment**: reward=3, punishment=0, temptation=5, sucker=0 (increases the cost of mutual defection)
85
+ - **Generous**: reward=4, punishment=2, temptation=5, sucker=1 (cushions the blow of being betrayed)
86
+
87
+ #### Game Length Variations
88
+
89
+ The number of rounds can significantly impact strategy:
90
+ - **Short Games** (5-10 rounds): Incentivizes more defection, especially near the end
91
+ - **Medium Games** (20-50 rounds): Allows for the development of tit-for-tat and forgiveness strategies
92
+ - **Long Games** (100+ rounds): Favors steady cooperation with occasional "probing" defections
93
+
94
+ ### Common Strategies
95
+
96
+ While not enforced by the environment, several well-known strategies can emerge:
97
+ - **Always Cooperate**: Always choose C
98
+ - **Always Defect**: Always choose D
99
+ - **Tit for Tat**: Start with C, then copy what the opponent did in the previous round
100
+ - **Forgiving Tit for Tat**: Like Tit for Tat, but occasionally cooperate even after being defected against
101
+ - **Grudger**: Cooperate until the opponent defects once, then always defect
102
+ - **Random**: Choose randomly between C and D
103
+
104
+ IPDEnv
105
+ ------
106
+
107
+ The ``IPDEnv`` class provides an interface to the Iterated Prisoner's Dilemma environment that follows the
108
+ Multi-Agent Negotiation Environment standard.
109
+
110
+ .. code-block:: python
111
+
112
+ class IPDEnv:
113
+ """
114
+ Iterated Prisoner's Dilemma environment following the MarlEnvironment standard.
115
+
116
+ In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
117
+ The payoffs are as follows:
118
+ - If both cooperate: Both receive the "reward" (usually 3 points)
119
+ - If both defect: Both receive the "punishment" (usually 1 point)
120
+ - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
121
+ and the cooperator receives the "sucker" payoff (usually 0 points)
122
+
123
+ The game is played for a specified number of rounds.
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ rounds_per_game: int = 10,
129
+ reward: float = 3.0, # Both cooperate
130
+ punishment: float = 1.0, # Both defect
131
+ temptation: float = 5.0, # Defector's reward when other cooperates
132
+ sucker: float = 0.0, # Cooperator's reward when other defects
133
+ random_seed: Optional[int] = None,
134
+ ):
135
+ """
136
+ Initialize the Iterated Prisoner's Dilemma environment.
137
+
138
+ Args:
139
+ rounds_per_game: Number of rounds to play
140
+ reward: Payoff when both agents cooperate
141
+ punishment: Payoff when both agents defect
142
+ temptation: Payoff for defecting when other agent cooperates
143
+ sucker: Payoff for cooperating when other agent defects
144
+ seed: Random seed for reproducibility
145
+ """
146
+ # ...
147
+
148
+ def reset(self) -> Dict[str, Dict[str, Any]]:
149
+ """
150
+ Reset the environment to an initial state and return the initial observation.
151
+
152
+ Returns:
153
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
154
+ """
155
+ # ...
156
+
157
+ def step(self, actions: Dict[str, str]) -> Tuple[Dict[str, Dict[str, Any]], bool, Dict[str, Any]]:
158
+ """
159
+ Take a step in the environment using the provided actions.
160
+
161
+ Args:
162
+ actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
163
+
164
+ Returns:
165
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
166
+ done (bool): Whether the episode has ended.
167
+ info (dict): Additional information about the environment.
168
+ """
169
+ # ...
170
+
171
+ Key Implementation Details
172
+ ~~~~~~~~~~~~~~~~~~~~~~~~~
173
+
174
+ The ``IPDEnv`` class implements several key features:
175
+
176
+ 1. **Two-Agent Support**: The environment tracks two agents ("alice" and "bob") and manages their interactions.
177
+
178
+ 2. **Round-Based Play**: The environment enforces turn structure and tracks game history.
179
+
180
+ 3. **Payoff Matrix**: The environment calculates rewards based on the standard prisoner's dilemma payoff matrix.
181
+
182
+ 4. **Observation Generation**: The environment generates detailed observations for each agent, including action history and rewards.
183
+
184
+ 5. **Game Termination**: The environment tracks game termination after the specified number of rounds.
185
+
186
+ Observation Structure
187
+ ~~~~~~~~~~~~~~~~~~~~
188
+
189
+ Each agent receives an observation dictionary with the following structure:
190
+
191
+ .. code-block:: python
192
+
193
+ {
194
+ "current_round": int, # Current round number (0-indexed)
195
+ "rounds_per_game": int, # Total number of rounds in the game
196
+ "history": List[Dict], # Complete game history so far
197
+ "last_round_actions": Dict[str, str], # Actions from the previous round (if any)
198
+ "last_round_reward": float, # Reward received in the previous round (if any)
199
+ "total_reward": float, # Cumulative reward so far
200
+ "payoff_matrix": Dict[str, float], # The game's payoff matrix values
201
+ }
202
+
203
+ Action Structure
204
+ ~~~~~~~~~~~~~~~
205
+
206
+ Actions are simple strings:
207
+
208
+ 1. ``"C"`` for Cooperate
209
+ 2. ``"D"`` for Defect
210
+
211
+ IPDAgent
212
+ --------------
213
+
214
+ The ``IPDAgent`` class implements the agent handler interface for the Iterated Prisoner's Dilemma, processing observations from the environment and generating actions through an LLM.
215
+
216
+ .. code-block:: python
217
+
218
+ class IPDAgent:
219
+ """
220
+ Agent handler for Iterated Prisoner's Dilemma, implementing the AgentState interface
221
+ for the multi-agent negotiation standard.
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ agent_id: str,
227
+ policy_id: str = "llm_policy",
228
+ system_prompt: Optional[str] = None,
229
+ max_errors: int = 3,
230
+ opponent_id: Optional[str] = None,
231
+ ):
232
+ """
233
+ Initialize the IPD agent handler.
234
+
235
+ Args:
236
+ agent_id: Identifier for this agent ("alice" or "bob")
237
+ policy_id: Identifier for the policy this agent uses
238
+ system_prompt: Optional custom system prompt for the LLM
239
+ max_errors: Maximum number of parsing errors before defaulting to cooperate
240
+ opponent_id: Optional identifier of the opponent (inferred if not provided)
241
+ """
242
+ # ...
243
+
244
+ def step(self, observation_from_env: Dict[str, Any], policy_output: str = None) -> Tuple[str, Dict[str, Any], str, bool, Dict[str, Any]]:
245
+ """
246
+ Update the agent state based on the observation and process the policy output.
247
+
248
+ Args:
249
+ observation_from_env: The observation from the environment
250
+ policy_output: The output from the policy (LLM response)
251
+
252
+ Returns:
253
+ policy_id: The policy identifier
254
+ policy_input: The input to the policy
255
+ action: The action to be sent to the environment
256
+ done: Whether the action is ready to be sent to the environment
257
+ info: Additional information about the agent
258
+ """
259
+ # ...
260
+
261
+ Key Implementation Details
262
+ ~~~~~~~~~~~~~~~~~~~~~~~~~
263
+
264
+ The ``IPDAgent`` class implements several key features:
265
+
266
+ 1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses.
267
+
268
+ 2. **Action Extraction**: The agent parses the LLM's output to extract valid actions (C or D).
269
+
270
+ 3. **Error Handling**: The agent provides helpful error messages when parsing fails and defaults to cooperation after multiple failures.
271
+
272
+ 4. **History Tracking**: The agent maintains and provides the complete game history in its prompts.
273
+
274
+ 5. **Strategy Explanation**: The agent can extract and log the reasoning behind an LLM's decisions.
275
+
276
+ Prompt Structure
277
+ ~~~~~~~~~~~~~~~
278
+
279
+ The agent generates prompts that include:
280
+
281
+ 1. **System Prompt**: Instructions and context for the LLM, explaining its role and the rules of the Prisoner's Dilemma.
282
+
283
+ 2. **Game State Description**: A text description of the current game state, including:
284
+ - Current round number
285
+ - History of previous rounds (if any)
286
+ - Cumulative score
287
+
288
+ 3. **Action Request**: Instructions on how to format the response, requiring an explicit action tag.
289
+
290
+ Example system prompt:
291
+
292
+ .. code-block:: text
293
+
294
+ You are playing as Alice in an Iterated Prisoner's Dilemma game against Bob.
295
+ In each round, you must choose to either Cooperate (C) or Defect (D).
296
+
297
+ The payoffs are:
298
+ - If both players Cooperate: You each get 3 points
299
+ - If both players Defect: You each get 1 point
300
+ - If you Cooperate and Bob Defects: You get 0 points, Bob gets 5 points
301
+ - If you Defect and Bob Cooperates: You get 5 points, Bob gets 0 points
302
+
303
+ Your goal is to maximize your total points across all rounds.
304
+ The game will last for exactly 10 rounds, and both players know this.
305
+
306
+ Example game state prompt:
307
+
308
+ .. code-block:: text
309
+
310
+ Current round: 3/10
311
+
312
+ History:
313
+ Round 1: You chose C, Bob chose C. You earned 3 points.
314
+ Round 2: You chose C, Bob chose D. You earned 0 points.
315
+
316
+ Your total score so far: 3 points
317
+
318
+ What is your choice for round 3?
319
+ Please respond with <action>C</action> to cooperate or <action>D</action> to defect,
320
+ and explain your reasoning.
321
+
322
+ Running IPD Games
323
+ ----------------------
324
+
325
+ To run Iterated Prisoner's Dilemma games with LLM agents, you can use the following code structure:
326
+
327
+ .. code-block:: python
328
+
329
+ from mllm.environments.ipd.ipd_game import IPDEnv
330
+ from mllm.environments.ipd.ipd_agent import IPDAgent
331
+ from mllm.run_matches import run_batched_matches
332
+
333
+ # Create environment
334
+ env = IPDEnv(
335
+ rounds_per_game=10,
336
+ reward=3.0,
337
+ punishment=1.0,
338
+ temptation=5.0,
339
+ sucker=0.0
340
+ )
341
+
342
+ # Create agent handlers
343
+ agent_handlers = {
344
+ "alice": IPDAgent(agent_id="alice"),
345
+ "bob": IPDAgent(agent_id="bob")
346
+ }
347
+
348
+ # Define policy mapping
349
+ policy_mapping = {
350
+ "llm_policy": my_llm_policy_function
351
+ }
352
+
353
+ # Run the game
354
+ game_results = run_batched_matches(
355
+ envs=[env],
356
+ agent_handlers_per_env=[agent_handlers],
357
+ policy_mapping=policy_mapping,
358
+ max_parallel_matches=1
359
+ )
360
+
361
+ # Process results
362
+ for result in game_results:
363
+ print(f"Game finished. Scores: {result['total_rewards']}")
364
+
365
+ Statistics and Analysis
366
+ ----------------------
367
+
368
+ The IPD environment includes utility functions for analyzing game outcomes:
369
+
370
+ 1. **Cooperation Rates**: Percentage of rounds where each agent cooperated.
371
+ 2. **Mutual Cooperation/Defection**: Percentage of rounds where both agents made the same choice.
372
+ 3. **Score Distribution**: Analysis of how points were accumulated over the game.
373
+
374
+ These statistics can be calculated using the ``gather_ipd_statistics`` function:
375
+
376
+ .. code-block:: python
377
+
378
+ from mllm.environments.ipd.ipd_statistics_funcs import gather_ipd_statistics
379
+
380
+ stats = gather_ipd_statistics(match_info, env_info)
381
+ print(f"Cooperation rates: {stats['cooperation_rate']}")
382
+ print(f"Mutual cooperation rate: {stats['mutual_cooperation_rate']}")
383
+ print(f"Mutual defection rate: {stats['mutual_defection_rate']}")
384
+
385
+ Limitations and Considerations
386
+ -----------------------------
387
+
388
+ 1. **Determinism**: The environment is deterministic, with randomness only in initialization if a seed is provided.
389
+
390
+ 2. **Limited Player Count**: The IPD environment only supports exactly two players.
391
+
392
+ 3. **Perfect Information**: Both players have perfect information about the game history.
393
+
394
+ 4. **Simultaneous Actions**: Both players act simultaneously, which requires adaptations for some LLM interfaces.
395
+
396
+ 5. **Fixed Game Length**: The total number of rounds is fixed and known to both players from the start.
397
+
398
+ Advanced Usage
399
+ ------------
400
+
401
+ For advanced usage, you can customize:
402
+
403
+ 1. **Payoff Matrix**: Modify reward values to create different incentive structures.
404
+
405
+ 2. **System Prompts**: Customize the LLM's understanding of the game and potential strategies.
406
+
407
+ 3. **Error Handling**: Adjust how the agent responds to invalid LLM outputs.
408
+
409
+ 4. **Analysis**: Create custom statistics gathering for specific research questions.
410
+
411
+ 5. **Integration**: Connect the IPD environment to other negotiation frameworks or tournament systems.
src_code_for_reproducibility/docs/source/marl_standard.rst ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Abstract Standard for Multi-Agent Negotiation Environments
3
+ =================
4
+
5
+ Multi-Agent Negotiation Environments require more features than gymnasium environments in order to be used as interfaces in general game running code.
6
+ The two fundamental differences between gymnasium environments and Multi-Agent Negotiation Environments are:
7
+
8
+ 1. Response from the LLM is a text action, not a discrete action. Therefore, appropriate parsing of the text is required. The model may need to be run multiple times to get the full action.
9
+ This is why we introduce the `AgentHandler` class, which is responsible for parsing the LLM's response.
10
+ 2. The environment needs to be able to handle multi-agent interactions.
11
+ This is why we introduce the `NegotiationEnvironment` class, which is responsible for handling the multi-agent interactions.
12
+ 3. MARL environments are complex to describe. In different contexts, the same environment may be described differently. Therefore, both the environement and the agent handlers are
13
+ responsible for describing a particular trajectory. This information is given by the `get_log_info` method.
14
+ 4. There might be a lot of overlap between the neural networks used by each agent. For instance, the same model may be used for all agents. This motivates a requirement for a
15
+ policy identifier for each agent.
16
+
17
+ Taking inspiration from the `gymnasium <https://gymnasium.farama.org/>`_ library, we introduce a new standard for Multi-Agent Negotiation Environments.
18
+
19
+ Our standard is based on the following features:
20
+
21
+ Environments are of the form:
22
+
23
+ .. code-block:: python
24
+
25
+ class MarlEnvironment():
26
+
27
+ def __init__(self):
28
+ """Initialize the environment."""
29
+ pass
30
+
31
+ def reset(self):
32
+ """Reset the environment to an initial state and return the initial observation.
33
+ Returns:
34
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
35
+ """
36
+ # (...)
37
+ return observation
38
+
39
+ def step(self, actions):
40
+ """Take a step in the environment using the provided actions.
41
+
42
+ Args:
43
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
44
+
45
+ Returns:
46
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
47
+ reward (dict): A dictionary where keys are agent identifiers and values are rewards.
48
+ done (bool): Whether the episode has ended.
49
+ info (dict): Additional information about the environment.
50
+ """
51
+ # (...)
52
+ return observations, done, info
53
+
54
+ def get_log_info(self):
55
+ """Get additional information about the environment. This information is used to log the game.
56
+ Returns:
57
+ log_info (dict): Information about the environment required to log the game.
58
+ """
59
+ # (...)
60
+ return log_info
61
+
62
+ def render(self):
63
+ """Render the current state of the environment."""
64
+ pass
65
+
66
+ def close(self):
67
+ """Perform any necessary cleanup."""
68
+ pass
69
+
70
+
71
+ class AgentState():
72
+
73
+ def __init__(self):
74
+ """Initialize the agent state."""
75
+ pass
76
+
77
+ def step(self, observation_from_env, policy_output=None):
78
+ """Update the agent state based on the observation and action.
79
+ The action is the output of the LLM.
80
+ """
81
+
82
+ Args:
83
+ observation_from_env (dict): The observation of the environment.
84
+ policy_output : The output of the policy.
85
+
86
+ Returns:
87
+ policy_id (str): The policy identifier.
88
+ policy_input (dict): The input to the policy.
89
+ action : The official action to be sent to the environment.
90
+ done (bool): Whether the LLM action is ready to be sent to the environment.
91
+ info (dict): Additional information about the agent.
92
+ """
93
+ # (...)
94
+ return policy_id, policy_input, action, done, info
95
+
96
+ def get_log_info(self):
97
+ """Get information about the agent required to log a trajectory.
98
+ Returns:
99
+ log_info (dict): Information about the agent required to log a trajectory.
100
+ """
101
+ # (...)
102
+ return log_info
103
+
104
+ def render(self):
105
+ """Render the current state of the environment."""
106
+ pass
107
+
108
+ def close(self):
109
+ """Perform any necessary cleanup."""
110
+ pass
111
+
112
+
113
+ Implicitely, the keys of the `observations` in the `step` method of the `MarlEnvironment` interface represent the set of agents from which an action is expected at the current step. The next step should only expect actions from the agents in the `observations` dictionary.
114
+
115
+ As you can see, both classes have a `get_log_info` method. This method is used to log the game. It returns a dictionary with keys being the agent identifiers and values being the information to log. The reason we need this is because the environment and the agent handler may need to log different information. It makes it easier to log from the perspective of each agent. The core environment class should not need to know about the details of the agent handler.
116
+
117
+
118
+
119
+ Running Environments in Parallel
120
+ --------------------------------
121
+ This standard allows the use of the `run_batched_matches` function (TODO: link) to run environments in an efficient way. The core idea is to batch the policy calls for all agents in the environment.
122
+
123
+ .. note::
124
+ The ``run_batched_matches`` function allows you to run multiple negotiation games, or "matches," in parallel.
125
+ After each environment is initialized, the function continuously loops over all active matches and checks which agents
126
+ are still pending actions. Each agent's logic can require multiple calls to the policy (e.g., an LLM) before an action
127
+ becomes "ready" to be sent to the environment. (For instance, an agent might need multiple policy calls before having a string which can be parsed into a valid action.) While an agent is waiting for a policy output, these calls for all agents across all matches are grouped together by unique policy identifier and processed in batch for efficiency. This is the core functionality of the ``run_batched_matches`` function.
128
+
129
+ Only once all actions from the required agents at a given step for an environment are ready does the function make a single ``env.step(...)`` call; this ensures
130
+ every match moves forward in lockstep for all its active agents. As soon as an environment signals it is done, the function
131
+ retrieves logged information from both the environment and the agent states before removing this match from the active set.
132
+
133
+ If there are more matches waiting to be processed, they are then started one by one to maintain the specified degree of parallelism.
134
+ This batching approach provides an efficient mechanism to handle multi-agent or multi-policy environments, ensuring minimal
135
+ overhead and a clear, unified flow for stepping through matches.
136
+
137
+ Here is a diagram that shows how the `run_batched_matches` function works at a high level:
138
+
139
+ .. image:: media/runbatch.png
140
+ :alt: Alternate text for the image
141
+ :width: 1000px
src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_game module
2
+ =======================================
3
+
4
+ .. automodule:: src.environments.dond.dond_game
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.ipd.ipd\_statistics\_funcs module
2
+ ==================================================
3
+
4
+ .. automodule:: src.environments.ipd.ipd_statistics_funcs
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.experiments.generate\_and\_train module
2
+ ===========================================
3
+
4
+ .. automodule:: src.experiments.generate_and_train
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.generation.run_games.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.generation.run\_games module
2
+ ================================
3
+
4
+ .. automodule:: src.generation.run_games
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.local_llm.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.local\_llm module
2
+ ============================
3
+
4
+ .. automodule:: src.models.local_llm
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.new_local_llm.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.new\_local\_llm module
2
+ =================================
3
+
4
+ .. automodule:: src.models.new_local_llm
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.rst ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src.models package
2
+ ==================
3
+
4
+ .. automodule:: src.models
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+
9
+ Submodules
10
+ ----------
11
+
12
+ .. toctree::
13
+ :maxdepth: 4
14
+
15
+ src.models.dummy_local_llm
16
+ src.models.local_llm
17
+ src.models.new_local_llm
18
+ src.models.server_llm
19
+ src.models.updatable_worker
20
+ src.models.vllm_worker_wrap
src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.training.reinforce\_training module
2
+ =======================================
3
+
4
+ .. automodule:: src.training.reinforce_training
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.training.rl\_convs\_processing module
2
+ =========================================
3
+
4
+ .. automodule:: src.training.rl_convs_processing
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.training.train_main.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.training.train\_main module
2
+ ===============================
3
+
4
+ .. automodule:: src.training.train_main
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.export_ppo_training_set.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.export\_ppo\_training\_set module
2
+ ===========================================
3
+
4
+ .. automodule:: src.utils.export_ppo_training_set
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.extra_stats.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.extra\_stats module
2
+ =============================
3
+
4
+ .. automodule:: src.utils.extra_stats
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.log\_gpu\_usage module
2
+ ================================
3
+
4
+ .. automodule:: src.utils.log_gpu_usage
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.utils.parallel_shuffle.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.utils.parallel\_shuffle module
2
+ ==================================
3
+
4
+ .. automodule:: src.utils.parallel_shuffle
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes). View file
 
src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc ADDED
Binary file (6.17 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-312.pyc ADDED
Binary file (3.98 kB). View file
 
src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc ADDED
Binary file (3.67 kB). View file
 
src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, Optional, Any
2
+ import copy
3
+
4
+ class DiplomacyAgent:
5
+ """Agent handler for Diplomacy game that follows the MARL standard.
6
+
7
+ This class is responsible for parsing LLM output into valid Diplomacy orders,
8
+ managing the agent state, and providing information for logging.
9
+ """
10
+
11
+ def __init__(self, policy_id: str, power_name: str, random_valid_move=False):
12
+ """Initialize the agent handler for a power in the Diplomacy game.
13
+
14
+ Args:
15
+ power_name: The name of the power this agent controls (e.g., 'FRANCE', 'ENGLAND')
16
+ policy_id: The identifier for the policy this agent uses
17
+ random_valid_move: If True, will select random valid moves instead of using LLM (default: False)
18
+ """
19
+ self.policy_id = policy_id
20
+ self.power_name = power_name
21
+ self.orders = []
22
+ self.wait = True
23
+ self.processing_state = "WAITING_FOR_ORDERS"
24
+ self.parsed_orders = []
25
+ self.order_status = {}
26
+ self.message_history = []
27
+ self.random_valid_move = random_valid_move
28
+
29
+ def step(self, observation_from_env, policy_output=None):
30
+ """Update the agent state based on the observation and LLM output.
31
+
32
+ Args:
33
+ observation_from_env: The observation from the environment
34
+ policy_output: The output from the LLM
35
+
36
+ Returns:
37
+ policy_id: The policy identifier
38
+ policy_input: The input to the policy
39
+ action: The official action to be sent to the environment
40
+ done: Whether the LLM action is ready to be sent to the environment
41
+ info: Additional information about the agent
42
+ """
43
+ info = {}
44
+
45
+ # If random_valid_move is enabled, select random valid moves
46
+ if self.random_valid_move:
47
+ valid_orders = self._select_random_valid_moves(observation_from_env)
48
+ self.orders = valid_orders
49
+ self.wait = False
50
+ action = {
51
+ "orders": valid_orders,
52
+ "wait": False
53
+ }
54
+ return self.policy_id, {}, action, True, info
55
+
56
+ # If no policy output, this is the initial step - prepare prompt
57
+ if policy_output is None:
58
+ # Create initial prompt for the LLM
59
+ phase = observation_from_env.get('phase', '')
60
+ units = observation_from_env.get('units', {}).get(self.power_name, [])
61
+ centers = observation_from_env.get('centers', {}).get(self.power_name, [])
62
+ orderable_locations = observation_from_env.get('orderable_locations', {})
63
+
64
+ prompt = self._create_prompt(phase, units, centers, orderable_locations)
65
+
66
+ return self.policy_id, {"prompt": prompt}, None, False, info
67
+
68
+ # Process the LLM output to extract orders
69
+ success, parsed_orders = self._parse_llm_output(policy_output)
70
+ self.parsed_orders = parsed_orders
71
+
72
+ if not success:
73
+ # Need more information from LLM
74
+ clarification_prompt = self._create_clarification_prompt(policy_output, parsed_orders)
75
+ return self.policy_id, {"prompt": clarification_prompt}, None, False, info
76
+
77
+ # Validate if the orders are valid for the current phase
78
+ valid_orders = self._validate_orders(parsed_orders, observation_from_env)
79
+
80
+ if valid_orders:
81
+ # Orders are valid, prepare action for environment
82
+ self.orders = valid_orders
83
+ self.wait = False
84
+ action = {
85
+ "orders": valid_orders,
86
+ "wait": False
87
+ }
88
+ return self.policy_id, {}, action, True, info
89
+ else:
90
+ # Orders are invalid, ask for new ones
91
+ error_prompt = self._create_error_prompt(parsed_orders, observation_from_env)
92
+ return self.policy_id, {"prompt": error_prompt}, None, False, info
93
+
94
+ def _create_prompt(self, phase, units, centers, orderable_locations):
95
+ """Create the initial prompt for the LLM.
96
+
97
+ Args:
98
+ phase: The current game phase
99
+ units: List of units controlled by this power
100
+ centers: List of supply centers controlled by this power
101
+ orderable_locations: List of locations where orders can be issued
102
+
103
+ Returns:
104
+ A prompt string for the LLM
105
+ """
106
+ prompt = f"You are playing as {self.power_name} in Diplomacy. The current phase is {phase}.\n\n"
107
+ prompt += f"Your units: {', '.join(units)}\n"
108
+ prompt += f"Your supply centers: {', '.join(centers)}\n"
109
+ prompt += f"Locations you can order: {', '.join(orderable_locations)}\n\n"
110
+
111
+ if phase.endswith('M'): # Movement phase
112
+ prompt += "Please provide orders for your units in the form:\n"
113
+ prompt += "- A LON H (hold)\n"
114
+ prompt += "- F NTH - NWY (move)\n"
115
+ prompt += "- A WAL S F LON (support)\n"
116
+ prompt += "- F NWG C A NWY - EDI (convoy)\n"
117
+ elif phase.endswith('R'): # Retreat phase
118
+ prompt += "Please provide retreat orders for your dislodged units:\n"
119
+ prompt += "- A PAR R MAR (retreat to MAR)\n"
120
+ prompt += "- A PAR D (disband)\n"
121
+ elif phase.endswith('A'): # Adjustment phase
122
+ if len(units) < len(centers):
123
+ prompt += "You can build units. Please provide build orders:\n"
124
+ prompt += "- A PAR B (build army in PAR)\n"
125
+ prompt += "- F BRE B (build fleet in BRE)\n"
126
+ prompt += "- WAIVE (waive a build)\n"
127
+ elif len(units) > len(centers):
128
+ prompt += "You must remove units. Please provide disbandment orders:\n"
129
+ prompt += "- A PAR D (disband army in PAR)\n"
130
+ prompt += "- F BRE D (disband fleet in BRE)\n"
131
+
132
+ prompt += "\nProvide your orders as a list, one per line."
133
+ return prompt
134
+
135
+ def _parse_llm_output(self, llm_output):
136
+ """Parse the LLM output to extract orders.
137
+
138
+ Args:
139
+ llm_output: The raw output from the LLM
140
+
141
+ Returns:
142
+ success: Whether parsing was successful
143
+ parsed_orders: List of parsed orders
144
+ """
145
+ # Simple parsing for now - extract lines that look like orders
146
+ lines = llm_output.strip().split('\n')
147
+ orders = []
148
+
149
+ for line in lines:
150
+ # Remove list markers, hyphens, etc.
151
+ line = line.strip('- *•').strip()
152
+
153
+ # Skip empty lines and lines that don't look like orders
154
+ if not line or line.startswith('I ') or line.startswith('Let\'s'):
155
+ continue
156
+
157
+ # Check if it looks like a Diplomacy order
158
+ if (' H' in line or ' -' in line or ' S ' in line or ' C ' in line or
159
+ ' R ' in line or ' D' in line or ' B' in line or line == 'WAIVE'):
160
+ orders.append(line)
161
+
162
+ return len(orders) > 0, orders
163
+
164
+ def _validate_orders(self, orders, observation):
165
+ """Validate if the orders are valid for the current phase.
166
+
167
+ Args:
168
+ orders: List of orders to validate
169
+ observation: Current observation from the environment
170
+
171
+ Returns:
172
+ List of valid orders or None if invalid
173
+ """
174
+ # For simplicity, we'll assume all parsed orders are valid
175
+ # In a real implementation, we would use the game's validation logic
176
+ return orders
177
+
178
+ def _create_clarification_prompt(self, previous_output, parsed_orders):
179
+ """Create a prompt asking for clarification when orders couldn't be parsed.
180
+
181
+ Args:
182
+ previous_output: The previous LLM output
183
+ parsed_orders: Any orders that were successfully parsed
184
+
185
+ Returns:
186
+ A prompt string for the LLM
187
+ """
188
+ prompt = f"I couldn't fully understand your orders for {self.power_name}. "
189
+
190
+ if parsed_orders:
191
+ prompt += f"I understood these orders:\n"
192
+ for order in parsed_orders:
193
+ prompt += f"- {order}\n"
194
+
195
+ prompt += "\nPlease provide clear, valid Diplomacy orders in the format:\n"
196
+ prompt += "- A LON H\n- F NTH - NWY\n- etc.\n"
197
+ return prompt
198
+
199
+ def _create_error_prompt(self, invalid_orders, observation):
200
+ """Create a prompt when orders are invalid.
201
+
202
+ Args:
203
+ invalid_orders: The invalid orders
204
+ observation: Current observation from the environment
205
+
206
+ Returns:
207
+ A prompt string for the LLM
208
+ """
209
+ prompt = f"The following orders for {self.power_name} are invalid:\n"
210
+ for order in invalid_orders:
211
+ prompt += f"- {order}\n"
212
+
213
+ prompt += "\nPlease provide valid orders for your units."
214
+ return prompt
215
+
216
+ def get_log_info(self):
217
+ """Get information about the agent required to log a trajectory.
218
+
219
+ Returns:
220
+ log_info: Information about the agent required to log a trajectory.
221
+ """
222
+ return {
223
+ "power_name": self.power_name,
224
+ "orders": self.orders,
225
+ "wait": self.wait,
226
+ "parsing_state": self.processing_state,
227
+ "message_history": self.message_history
228
+ }
229
+
230
+ def render(self):
231
+ """Render the current state of the agent."""
232
+ print(f"Power: {self.power_name}")
233
+ print(f"Orders: {self.orders}")
234
+ print(f"Wait: {self.wait}")
235
+
236
+ def close(self):
237
+ """Perform any necessary cleanup."""
238
+ pass
239
+
240
+ def _select_random_valid_moves(self, observation):
241
+ """Select random valid moves for all units.
242
+
243
+ Args:
244
+ observation: Current observation from the environment
245
+
246
+ Returns:
247
+ List of valid orders
248
+ """
249
+ import random
250
+
251
+ possible_orders = observation.get('possible_orders', {})
252
+ valid_orders = []
253
+
254
+ # For each location with possible orders, select one randomly
255
+ for location, orders in possible_orders.items():
256
+ if orders: # If there are any possible orders for this location
257
+ valid_orders.append(random.choice(orders))
258
+
259
+ return valid_orders
src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging_for_training.py ADDED
File without changes
src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (310 Bytes). View file
 
src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc ADDED
Binary file (9.06 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc ADDED
Binary file (4.89 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc ADDED
Binary file (9.03 kB). View file
 
src_code_for_reproducibility/markov_games/negotiation/tas_agent.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mllm.markov_games.negotiation.nego_agent import NegotiationAgent
2
+ from mllm.markov_games.negotiation.nego_simulation import Split
3
+ from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs
4
+
5
+
6
+ class TrustAndSplitAgent(NegotiationAgent):
7
+ def __init__(self, num_message_chars, *args, **kwargs):
8
+ self.num_message_chars = num_message_chars
9
+ super().__init__(*args, **kwargs)
10
+ self.intro_prompt = (
11
+ "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n"
12
+ "Setup:\n"
13
+ "1. The game has multiple independent rounds.\n"
14
+ "2. In each round, there are multiple items to split between the two agents.\n"
15
+ "3. Both agents are assigned a per-item value between 1 and 20 (inclusive) in each round.\n"
16
+ "4. You can only observe your own per-item values.\n"
17
+ "5. Because assignments are random, both agents are equally likely to have same expected per-item value.\n"
18
+ "\n"
19
+ "Protocol:\n"
20
+ "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n"
21
+ "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the item.\n"
22
+ " - Use this chat to communicate your private per-item value to make informed proposals.\n"
23
+ "3. After the chat, both agents simultaneously propose the amount of each item they will keep.\n"
24
+ "4. If the total sum of proposals is less than or equal to the item quantity, both agents receive their proposed amounts.\n"
25
+ "5. If the total sum of proposals exceeds the item quantity, they are allocated proportionally.\n"
26
+ "6. Your points for the round = (amount you receive per item) x (your per-item value for that round), added across all items.\n"
27
+ "7. Points are accumulated across rounds.\n"
28
+ "Your goal: {goal}\n"
29
+ )
30
+ self.new_round_prompt = (
31
+ "A New Round Begins\n"
32
+ "The items to split are {quantities}.\n"
33
+ "Your per-item values are {value}."
34
+ )
35
+ self.last_round_prompt = (
36
+ "Last Round Summary:\n"
37
+ " - Items to split: {last_quantities}\n"
38
+ " - Your per-item values: {last_value_agent}\n"
39
+ " - {other_agent}'s per-item values: {last_value_coagent}\n"
40
+ " - You proposed: {last_split_agent}\n"
41
+ " - You earned: {last_points_agent} points\n"
42
+ " - {other_agent} proposed: {last_split_coagent}\n"
43
+ " - {other_agent} earned: {last_points_coagent} points\n"
44
+ " - Round Complete.\n"
45
+ )
46
+ self.send_split_prompt = (
47
+ "Message quota is finished for this round.\n"
48
+ "{other_agent} has finalized their proposal.\n"
49
+ "Submit your finalization now\n"
50
+ "Respond with {proposal_style2}"
51
+ )
52
+ # self.wait_for_message_prompt = "Wait for {other_agent} to send a message..."
53
+ self.wait_for_message_prompt = ""
54
+ self.last_message_prompt = "{other_agent} said: {last_message}"
55
+ # self.send_message_prompt = (
56
+ # f"Send your message now (max {self.num_message_chars} chars)."
57
+ # )
58
+ self.send_message_prompt = f"Send your message now in <message>...</message> (<={self.num_message_chars} chars)."
59
+
60
+ def get_message_regex(self, observation: TrustAndSplitObs) -> str:
61
+ return rf"<message>[\s\S]{{0,{self.num_message_chars}}}</message>"
62
+
63
+ # def get_message_regex(self, observation: TrustAndSplitObs) -> str:
64
+ # return rf"(?s).{{0,{self.num_message_chars}}}"
65
+
66
+ def get_split_regex(self, observation: TrustAndSplitObs) -> str:
67
+ items = list(observation.quantities.keys())
68
+ # Accept both singular and plural forms
69
+ item_pattern = "|".join(
70
+ [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items]
71
+ )
72
+ regex = rf"(?i)<items_to_self> ?((?:\s*(?P<num>(10|[0-9]))\s*(?P<item>{item_pattern})\s*,?)+) ?</items_to_self>"
73
+ return regex
74
+
75
+ def get_split_action(
76
+ self, policy_output: str, observation: TrustAndSplitObs
77
+ ) -> Split:
78
+ items = list(observation.quantities.keys())
79
+ import re as _re
80
+
81
+ split_regex = self.get_split_regex(observation)
82
+ items_given_to_self = {item: 0 for item in items}
83
+ m = _re.match(split_regex, policy_output.strip())
84
+ if m:
85
+ # Find all (number, item) pairs
86
+ item_pattern = "|".join(
87
+ [
88
+ f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?"
89
+ for item in items
90
+ ]
91
+ )
92
+ inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})"
93
+
94
+ def normalize_item_name(item_str):
95
+ for orig in items:
96
+ if item_str.lower() == orig.lower():
97
+ return orig
98
+ if orig.endswith("s") and item_str.lower() == orig[:-1].lower():
99
+ return orig
100
+ if (
101
+ not orig.endswith("s")
102
+ and item_str.lower() == orig.lower() + "s"
103
+ ):
104
+ return orig
105
+
106
+ for num, item in _re.findall(inner_regex, m.group(1)):
107
+ items_given_to_self[normalize_item_name(item)] = int(num)
108
+ return Split(items_given_to_self=items_given_to_self)
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc ADDED
Binary file (3.21 kB). View file
 
src_code_for_reproducibility/training/credit_methods.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
5
+ """
6
+ Whitens the advantages.
7
+ """
8
+ whitened_advantages = (advantages - torch.mean(advantages)) / (
9
+ torch.std(advantages) + 1e-9
10
+ )
11
+ return whitened_advantages
12
+
13
+
14
+ def whiten_advantages_time_step_wise(
15
+ advantages: torch.Tensor, # (B, T)
16
+ ) -> torch.Tensor:
17
+ """
18
+ Whitens the advantages.
19
+ """
20
+ assert advantages.dim() == 2, "Wrong dimensions."
21
+ whitened_advantages_time_step_wise = (
22
+ advantages - advantages.mean(dim=0, keepdim=True)
23
+ ) / (advantages.std(dim=0, keepdim=True) + 1e-9)
24
+ return whitened_advantages_time_step_wise
25
+
26
+
27
+ def get_discounted_state_visitation_credits(
28
+ credits: torch.Tensor, discount_factor: float # (B, T)
29
+ ) -> torch.Tensor:
30
+ """
31
+ Computes discounted state visitation credits for a sequence of credits.
32
+ """
33
+ return credits * (
34
+ discount_factor ** torch.arange(credits.shape[1], device=credits.device)
35
+ )
36
+
37
+
38
+ def get_discounted_returns(
39
+ rewards: torch.Tensor, # (B, T)
40
+ discount_factor: float,
41
+ ) -> torch.Tensor:
42
+ """
43
+ Computes Monte Carlo discounted returns for a sequence of rewards.
44
+
45
+ Args:
46
+ rewards (torch.Tensor): Array of rewards for each timestep.
47
+
48
+ Returns:
49
+ torch.Tensor: Array of discounted returns.
50
+ """
51
+ assert rewards.dim() == 2, "Wrong dimensions."
52
+ B, T = rewards.shape
53
+ discounted_returns = torch.zeros_like(rewards)
54
+ accumulator = torch.zeros(B, device=rewards.device, dtype=rewards.dtype)
55
+ for t in reversed(range(T)):
56
+ accumulator = rewards[:, t] + discount_factor * accumulator
57
+ discounted_returns[:, t] = accumulator
58
+ return discounted_returns
59
+
60
+
61
+ def get_rloo_credits(credits: torch.Tensor): # (B, S)
62
+ assert credits.dim() == 2, "Wrong dimensions."
63
+ rloo_baselines = torch.zeros_like(credits)
64
+ n = credits.shape[0]
65
+ if n == 1:
66
+ return credits, rloo_baselines
67
+ rloo_baselines = (torch.sum(credits, dim=0, keepdim=True) - credits) / (n - 1)
68
+ rloo_credits = credits - rloo_baselines
69
+ return rloo_credits, rloo_baselines
70
+
71
+
72
+ def get_generalized_advantage_estimates(
73
+ rewards: torch.Tensor, # (B, T)
74
+ value_estimates: torch.Tensor, # (B, T+1)
75
+ discount_factor: float,
76
+ lambda_coef: float,
77
+ ) -> torch.Tensor:
78
+ """
79
+ Computes Generalized Advantage Estimates (GAE) for a sequence of rewards and value estimates.
80
+ See https://arxiv.org/pdf/1506.02438 for details.
81
+
82
+
83
+ Returns:
84
+ torch.Tensor: Array of GAE values.
85
+ """
86
+ assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions."
87
+
88
+ assert (
89
+ rewards.shape[0] == value_estimates.shape[0]
90
+ ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
91
+ assert (
92
+ rewards.shape[1] == value_estimates.shape[1] - 1
93
+ ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
94
+
95
+ T = rewards.shape[1]
96
+ tds = rewards + discount_factor * value_estimates[:, 1:] - value_estimates[:, :-1]
97
+ gaes = torch.zeros_like(tds)
98
+ acc = 0.0
99
+ for t in reversed(range(T)):
100
+ acc = tds[:, t] + lambda_coef * discount_factor * acc
101
+ gaes[:, t] = acc
102
+ return gaes
103
+
104
+
105
+ def get_advantage_alignment_weights(
106
+ advantages: torch.Tensor, # (B, T)
107
+ exclude_k_equals_t: bool,
108
+ gamma: float,
109
+ ) -> torch.Tensor:
110
+ """
111
+ The advantage alignment credit is calculated as
112
+
113
+ \[
114
+ A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot
115
+ \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right)
116
+ A^2(s_t, a_t, b_t)
117
+ \]
118
+
119
+ Here, the weights are defined as \( \beta \cdot
120
+ \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \)
121
+ """
122
+ T = advantages.shape[1]
123
+ discounted_advantages = advantages * (
124
+ gamma * torch.ones((1, T), device=advantages.device)
125
+ ) ** (-torch.arange(0, T, 1, device=advantages.device))
126
+ if exclude_k_equals_t:
127
+ sub = torch.eye(T, device=advantages.device)
128
+ else:
129
+ sub = torch.zeros((T, T), device=advantages.device)
130
+
131
+ # Identity is for \( k < t \), remove for \( k \leq t \)
132
+ ad_align_weights = discounted_advantages @ (
133
+ torch.triu(torch.ones((T, T), device=advantages.device)) - sub
134
+ )
135
+ t_discounts = (gamma * torch.ones((1, T), device=advantages.device)) ** (
136
+ torch.arange(0, T, 1, device=advantages.device)
137
+ )
138
+ ad_align_weights = t_discounts * ad_align_weights
139
+ return ad_align_weights
140
+
141
+
142
+ def get_advantage_alignment_credits(
143
+ a1: torch.Tensor, # (B, S)
144
+ a1_alternative: torch.Tensor, # (B, S, A)
145
+ a2: torch.Tensor, # (B, S)
146
+ exclude_k_equals_t: bool,
147
+ beta: float,
148
+ gamma: float = 1.0,
149
+ use_old_ad_align: bool = False,
150
+ use_sign: bool = False,
151
+ clipping: float | None = None,
152
+ use_time_regularization: bool = False,
153
+ force_coop_first_step: bool = False,
154
+ use_variance_regularization: bool = False,
155
+ rloo_branch: bool = False,
156
+ reuse_baseline: bool = False,
157
+ mean_normalize_ad_align: bool = False,
158
+ whiten_adalign_advantages: bool = False,
159
+ whiten_adalign_advantages_time_step_wise: bool = False,
160
+ ) -> torch.Tensor:
161
+ """
162
+ Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662.
163
+
164
+ Recall that the advantage opponent shaping term of the AdAlign policy gradient is:
165
+ \[
166
+ \beta \mathbb{E}_{\substack{
167
+ \tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\
168
+ a_t' \sim \pi^1(\cdot \mid s_t)
169
+ }}
170
+ \left[\sum_{t=0}^\infty \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right]
171
+ \]
172
+
173
+ This method computes the following:
174
+ \[
175
+ Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right]
176
+ \]
177
+
178
+ Args:
179
+ a1: Advantages of the main trajectories for the current agent.
180
+ a1_alternative: Advantages of the alternative trajectories for the current agent.
181
+ a2: Advantages of the main trajectories for the other agent.
182
+ discount_factor: Discount factor for the advantage alignment.
183
+ beta: Beta parameter for the advantage alignment.
184
+ gamma: Gamma parameter for the advantage alignment.
185
+ use_sign_in_ad_align: Whether to use sign in the advantage alignment.
186
+
187
+ Returns:
188
+ torch.Tensor: The advantage alignment credits.
189
+ """
190
+
191
+ assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)"
192
+ if a1_alternative is not None:
193
+ assert (
194
+ a1_alternative.dim() == 3
195
+ ), "Alternative advantages must be of shape (B, S, A)"
196
+ B, T, A = a1_alternative.shape
197
+ else:
198
+ B, T = a1.shape
199
+ assert a1.shape == a2.shape, "Not the same shape"
200
+
201
+ sub_tensors = {}
202
+
203
+ if use_old_ad_align:
204
+ ad_align_weights = get_advantage_alignment_weights(
205
+ advantages=a1, exclude_k_equals_t=exclude_k_equals_t, gamma=gamma
206
+ )
207
+ sub_tensors["ad_align_weights_prev"] = ad_align_weights
208
+ if exclude_k_equals_t:
209
+ ad_align_weights = gamma * ad_align_weights
210
+ else:
211
+ assert a1_alternative is not None, "Alternative advantages must be provided"
212
+ if rloo_branch:
213
+ a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2)
214
+ a1_alternative = a1_alternative.mean(dim=2)
215
+ # print(f"a1_alternative: {a1_alternative}, a1: {a1}\n")
216
+ a1, baseline = get_rloo_credits(a1)
217
+ if reuse_baseline:
218
+ a1_alternative = a1_alternative - baseline
219
+ else:
220
+ a1_alternative, _ = get_rloo_credits(a1_alternative)
221
+ assert a1.shape == a1_alternative.shape, "Not the same shape"
222
+ ad_align_weights = get_advantage_alignment_weights(
223
+ advantages=a1_alternative,
224
+ exclude_k_equals_t=exclude_k_equals_t,
225
+ gamma=gamma,
226
+ )
227
+ sub_tensors["ad_align_weights"] = ad_align_weights
228
+
229
+ # Use sign
230
+ if use_sign:
231
+ assert beta == 1.0, "beta should be 1.0 when using sign"
232
+ positive_signs = ad_align_weights > 0
233
+ negative_signs = ad_align_weights < 0
234
+ ad_align_weights[positive_signs] = 1
235
+ ad_align_weights[negative_signs] = -1
236
+ sub_tensors["ad_align_weights_sign"] = ad_align_weights
237
+ # (rest are 0)
238
+
239
+ ###################
240
+ # Process weights
241
+ ###################
242
+
243
+ # Use clipping
244
+ if clipping not in [0.0, None]:
245
+ upper_mask = ad_align_weights > 1
246
+ lower_mask = ad_align_weights < -1
247
+
248
+ ad_align_weights = torch.clip(
249
+ ad_align_weights,
250
+ -clipping,
251
+ clipping,
252
+ )
253
+ clipping_ratio = (
254
+ torch.sum(upper_mask) + torch.sum(lower_mask)
255
+ ) / upper_mask.size
256
+ sub_tensors["clipped_ad_align_weights"] = ad_align_weights
257
+
258
+ # 1/1+t Regularization
259
+ if use_time_regularization:
260
+ t_values = torch.arange(1, T + 1).to(ad_align_weights.device)
261
+ ad_align_weights = ad_align_weights / t_values
262
+ sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights
263
+
264
+ # Use coop on t=0
265
+ if force_coop_first_step:
266
+ ad_align_weights[:, 0] = 1
267
+ sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights
268
+ # # Normalize alignment terms (across same time step)
269
+ # if use_variance_regularization_in_ad_align:
270
+ # # TODO: verify
271
+ # reg_coef = torch.std(a1[:, -1]) / (torch.std(opp_shaping_terms[:, -1]) + 1e-9)
272
+ # opp_shaping_terms *= reg_coef
273
+
274
+ ####################################
275
+ # Compose elements together
276
+ ####################################
277
+
278
+ opp_shaping_terms = beta * ad_align_weights * a2
279
+ sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms
280
+
281
+ credits = a1 + opp_shaping_terms
282
+ if mean_normalize_ad_align:
283
+ credits = credits - credits.mean(dim=0)
284
+ sub_tensors["mean_normalized_ad_align_credits"] = credits
285
+ if whiten_adalign_advantages:
286
+ credits = (credits - credits.mean()) / (credits.std() + 1e-9)
287
+ sub_tensors["whitened_ad_align_credits"] = credits
288
+ if whiten_adalign_advantages_time_step_wise:
289
+ credits = (credits - credits.mean(dim=0, keepdim=True)) / (
290
+ credits.std(dim=0, keepdim=True) + 1e-9
291
+ )
292
+ sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits
293
+ sub_tensors["final_ad_align_credits"] = credits
294
+
295
+ return credits, sub_tensors
src_code_for_reproducibility/training/tally_tokenwise.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class ContextualizedTokenwiseTally:
12
+ """
13
+ Collect, store, and save token-level metrics per rollout.
14
+
15
+ - One DataFrame per rollout_id in `paths`
16
+ - Index = timestep (int)
17
+ - Columns are added incrementally via `add_contexts()` and `add_data()`
18
+ - Cells may contain scalars, strings, or lists (dtype=object)
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ tokenizer: AutoTokenizer,
24
+ paths: List[str],
25
+ max_context_length: int = 30,
26
+ ):
27
+ """
28
+ Args:
29
+ tokenizer: HuggingFace tokenizer used to convert tids -> tokens
30
+ paths: rollout identifiers (parallel to batch dimension)
31
+ max_context_length: truncate context token lists to this length
32
+ """
33
+ self.tokenizer = tokenizer
34
+ self.paths = paths
35
+ self.max_context_length = max_context_length
36
+ self.tally: Dict[str, pd.DataFrame] = {path: pd.DataFrame() for path in paths}
37
+
38
+ # set later by setters
39
+ self.contexts: torch.Tensor | None = None
40
+ self.action_mask: torch.Tensor | None = None
41
+ self.range: Tuple[int, int] | None = None
42
+
43
+ # --------- Utilities ---------
44
+
45
+ def tids_to_str(self, tids: List[int]) -> List[str]:
46
+ """Convert a list of token IDs to a list of token strings."""
47
+ return self.tokenizer.convert_ids_to_tokens(tids)
48
+
49
+ def _ensure_ready(self):
50
+ assert self.action_mask is not None, "call set_action_mask(mask) first"
51
+ assert self.range is not None, "call set_range((start, end)) first"
52
+
53
+ @staticmethod
54
+ def _sanitize_filename(name: Any) -> str:
55
+ """Make a safe filename from any rollout_id."""
56
+ s = str(name)
57
+ bad = {os.sep, " ", ":", "|", "<", ">", '"', "'"}
58
+ if os.altsep is not None:
59
+ bad.add(os.altsep)
60
+ for ch in bad:
61
+ s = s.replace(ch, "_")
62
+ return s
63
+
64
+ @staticmethod
65
+ def _pad_left(seq: List[Any], length: int, pad_val: Any = "") -> List[Any]:
66
+ """Left-pad a sequence to `length` with `pad_val`."""
67
+ if len(seq) >= length:
68
+ return seq[-length:]
69
+ return [pad_val] * (length - len(seq)) + list(seq)
70
+
71
+ # --------- Setters ---------
72
+
73
+ def set_action_mask(self, action_mask: torch.Tensor):
74
+ """
75
+ action_mask: (B, S) bool or 0/1 indicating valid steps
76
+ """
77
+ self.action_mask = action_mask
78
+
79
+ def set_range(self, range: Tuple[int, int]):
80
+ """
81
+ range: slice (start, end) into self.paths for current batch
82
+ """
83
+ self.range = range
84
+
85
+ # --------- Column builders ---------
86
+
87
+ def add_contexts(self, contexts: torch.Tensor):
88
+ """
89
+ Add a single 'context' column (list[str]) for valid steps.
90
+
91
+ Expects `contexts` with shape (B, S): token id at each timestep.
92
+ For each valid timestep t, we use the last N tokens up to and including t:
93
+ window = contexts[i, max(0, t - N + 1) : t + 1]
94
+ The list is left-padded with "" to always be length N.
95
+ """
96
+ self._ensure_ready()
97
+
98
+ current_paths = self.paths[self.range[0] : self.range[1]]
99
+ B, S = contexts.shape
100
+ N = self.max_context_length
101
+
102
+ # to CPU ints once
103
+ contexts_cpu = contexts.detach().to("cpu")
104
+
105
+ for i in range(B):
106
+ rollout_id = current_paths[i]
107
+ df = self.tally.get(rollout_id, pd.DataFrame())
108
+
109
+ valid_idx = torch.nonzero(
110
+ self.action_mask[i].bool(), as_tuple=False
111
+ ).squeeze(-1)
112
+ if valid_idx.numel() == 0:
113
+ self.tally[rollout_id] = df
114
+ continue
115
+
116
+ idx_list = valid_idx.tolist()
117
+
118
+ # ensure index contains valid steps
119
+ if df.empty:
120
+ df = pd.DataFrame(index=idx_list)
121
+ else:
122
+ new_index = sorted(set(df.index.tolist()) | set(idx_list))
123
+ if list(df.index) != new_index:
124
+ df = df.reindex(new_index)
125
+
126
+ # build context windows
127
+ ctx_token_lists = []
128
+ for t in idx_list:
129
+ start = max(0, t - N + 1)
130
+ window_ids = contexts_cpu[i, start : t + 1].tolist()
131
+ window_toks = self.tids_to_str([int(x) for x in window_ids])
132
+ if len(window_toks) < N:
133
+ window_toks = [""] * (N - len(window_toks)) + window_toks
134
+ else:
135
+ window_toks = window_toks[-N:]
136
+ ctx_token_lists.append(window_toks)
137
+
138
+ # single 'context' column
139
+ if "context" not in df.columns:
140
+ df["context"] = pd.Series(index=df.index, dtype=object)
141
+ df.loc[idx_list, "context"] = pd.Series(
142
+ ctx_token_lists, index=idx_list, dtype=object
143
+ )
144
+
145
+ self.tally[rollout_id] = df
146
+
147
+ def add_data(
148
+ self,
149
+ metric_id: str,
150
+ metrics: torch.Tensor,
151
+ to_tids: bool = False,
152
+ ):
153
+ """
154
+ Add a metric column for valid steps.
155
+
156
+ Args:
157
+ metric_id: column name
158
+ metrics: shape (B, S) for scalars/ids or (B, S, K) for top-k vectors
159
+ to_tids: if True, treat ints/lists of ints as tids and convert to tokens
160
+ """
161
+ self._ensure_ready()
162
+ current_paths = self.paths[self.range[0] : self.range[1]]
163
+
164
+ if metrics.dim() == 2:
165
+ B, S = metrics.shape
166
+ elif metrics.dim() == 3:
167
+ B, S, _ = metrics.shape
168
+ else:
169
+ raise ValueError("metrics must be (B, S) or (B, S, K)")
170
+
171
+ for i in range(B):
172
+ rollout_id = current_paths[i]
173
+ df = self.tally.get(rollout_id, pd.DataFrame())
174
+
175
+ valid_idx = torch.nonzero(
176
+ self.action_mask[i].bool(), as_tuple=False
177
+ ).squeeze(-1)
178
+ if valid_idx.numel() == 0:
179
+ self.tally[rollout_id] = df
180
+ continue
181
+
182
+ idx_list = valid_idx.detach().cpu().tolist()
183
+
184
+ # Ensure index contains valid steps
185
+ if df.empty:
186
+ df = pd.DataFrame(index=idx_list)
187
+ else:
188
+ new_index = sorted(set(df.index.tolist()) | set(idx_list))
189
+ if list(df.index) != new_index:
190
+ df = df.reindex(new_index)
191
+
192
+ # Slice metrics at valid steps
193
+ m_valid = metrics[i][valid_idx]
194
+
195
+ # -> pure python lists (1D list or list-of-lists)
196
+ values = m_valid.detach().cpu().tolist()
197
+
198
+ # optional tids -> tokens
199
+ if to_tids:
200
+
201
+ def _to_tokish(x):
202
+ if isinstance(x, list):
203
+ return self.tids_to_str([int(v) for v in x])
204
+ else:
205
+ return self.tids_to_str([int(x)])[0]
206
+
207
+ values = [_to_tokish(v) for v in values]
208
+
209
+ # Ensure column exists with object dtype, then assign via aligned Series
210
+ if metric_id not in df.columns:
211
+ df[metric_id] = pd.Series(index=df.index, dtype=object)
212
+
213
+ if isinstance(values, np.ndarray):
214
+ values = values.tolist()
215
+
216
+ if len(values) != len(idx_list):
217
+ raise ValueError(
218
+ f"Length mismatch for '{metric_id}': values={len(values)} vs idx_list={len(idx_list)}"
219
+ )
220
+
221
+ df.loc[idx_list, metric_id] = pd.Series(
222
+ values, index=idx_list, dtype=object
223
+ )
224
+ self.tally[rollout_id] = df
225
+
226
+ # --------- Saving ---------
227
+
228
+ def save(self, path: str):
229
+ """
230
+ Write a manifest JSON and one CSV per rollout.
231
+
232
+ - Manifest includes metadata only (safe to JSON).
233
+ - Each rollout CSV is written with index label 'timestep'.
234
+ - Only a single 'context' column (list[str]).
235
+ """
236
+ if not self.tally or all(df.empty for df in self.tally.values()):
237
+ return
238
+
239
+ os.makedirs(path, exist_ok=True)
240
+ from datetime import datetime
241
+
242
+ now = datetime.now()
243
+
244
+ manifest = {
245
+ "created_at": f"{now:%Y-%m-%d %H:%M:%S}",
246
+ "max_context_length": self.max_context_length,
247
+ "num_rollouts": len(self.tally),
248
+ "rollouts": [],
249
+ }
250
+
251
+ for rid, df in self.tally.items():
252
+ rid_str = str(rid)
253
+ safe_name = self._sanitize_filename(rid_str)
254
+ csv_path = os.path.join(path, f"{safe_name}_tokenwise.csv")
255
+
256
+ # Put 'context' first, then the rest
257
+ cols = ["context"] + [c for c in df.columns if c != "context"]
258
+ try:
259
+ df[cols].to_csv(csv_path, index=True, index_label="timestep")
260
+ except Exception as e:
261
+ continue
262
+
263
+ manifest["rollouts"].append(
264
+ {
265
+ "rollout_id": rid_str,
266
+ "csv": csv_path,
267
+ "num_rows": int(df.shape[0]),
268
+ "columns": cols,
269
+ }
270
+ )
271
+
272
+ manifest_path = os.path.join(
273
+ path, f"tokenwise_manifest_{now:%Y-%m-%d___%H-%M-%S}.json"
274
+ )
275
+ with open(manifest_path, "w") as fp:
276
+ json.dump(manifest, fp, indent=2)
src_code_for_reproducibility/training/tokenize_chats.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+
4
+ import regex
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+
8
+ from mllm.training.training_data_utils import TrainingChatTurn, TrajectoryBatch
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.addHandler(logging.StreamHandler(sys.stdout))
12
+
13
+
14
+ # def get_chat_dicts(chat: list[TrainingChatTurn]) -> list[dict]:
15
+ # chat_dicts = [chat_turn.dict() for chat_turn in chat]
16
+ # return chat_dicts
17
+
18
+
19
+ def process_training_chat(
20
+ tokenizer: AutoTokenizer,
21
+ chat_history: list[TrainingChatTurn],
22
+ entropy_mask_regex: str | None = None,
23
+ exploration_prompts_to_remove: list[str] = [],
24
+ use_engine_out_token_ids: bool = False,
25
+ ) -> tuple[torch.IntTensor, torch.BoolTensor, torch.IntTensor, torch.BoolTensor]:
26
+ """Tokenize a single training chat and build aligned per-token masks.
27
+
28
+ Given an ordered list of `TrainingChatTurn`, this function tokenizes each
29
+ turn independently using the tokenizer's chat template, then concatenates
30
+ all resulting token sequences. It also constructs three parallel 1D masks
31
+ that align with the concatenated tokens:
32
+
33
+ - input_ids: token ids for the entire chat, turn by turn
34
+ - action_mask: True for tokens that belong to assistant turns (i.e., model
35
+ actions), False for tokens from other roles
36
+ - timesteps: per-token time step copied from the originating turn's
37
+ `time_step`
38
+ - state_ends_mask: True for the last token of any turn where
39
+ `is_state_end` is True, otherwise False
40
+
41
+ Important details:
42
+ - Each turn is passed as a single-message list to
43
+ `tokenizer.apply_chat_template` and flattened; the per-turn outputs are
44
+ then concatenated in the original order.
45
+ - Turn boundaries are not explicitly encoded beyond what the chat template
46
+ inserts; masks provide alignment for learning signals and state endings.
47
+ - No truncation or padding is performed here; downstream code should handle
48
+ batching/padding as needed.
49
+ - Note on dtypes: `input_ids` will be a LongTensor (int64). `action_mask`
50
+ and `state_ends_mask` are BoolTensors. `timesteps` is currently created
51
+ as a float tensor; adjust the implementation if integer dtype is
52
+ required downstream.
53
+
54
+ Args:
55
+ tokenizer: A Hugging Face tokenizer supporting `apply_chat_template`.
56
+ chat_history: Ordered list of `TrainingChatTurn` forming one dialogue.
57
+
58
+ Returns:
59
+ A tuple of four 1D tensors, all of equal length N (the total number of
60
+ tokens across all turns), in the following order:
61
+ - input_ids (LongTensor)
62
+ - action_mask (BoolTensor)
63
+ - timesteps (FloatTensor as implemented; see note above)
64
+ - state_ends_mask (BoolTensor)
65
+ """
66
+ state_ends_mask = []
67
+ input_ids = []
68
+ action_mask = []
69
+ timesteps = []
70
+ entropy_mask = []
71
+ engine_log_probs = []
72
+ for train_chat_turn in chat_history:
73
+ is_state_end = train_chat_turn.is_state_end
74
+ time_step = train_chat_turn.time_step
75
+ is_action = train_chat_turn.role == "assistant"
76
+
77
+ # Remove exploration prompts from training data
78
+ for exploration_prompt in exploration_prompts_to_remove:
79
+ if exploration_prompt in train_chat_turn.content:
80
+ train_chat_turn.content = train_chat_turn.content.replace(
81
+ exploration_prompt, ""
82
+ )
83
+
84
+ chat_turn = {
85
+ "role": train_chat_turn.role,
86
+ "content": train_chat_turn.content,
87
+ }
88
+ if entropy_mask_regex is not None:
89
+ is_entropy_mask_true = (
90
+ regex.search(entropy_mask_regex, train_chat_turn.content) is not None
91
+ )
92
+ else:
93
+ is_entropy_mask_true = True
94
+ if is_action:
95
+ chat_turn_ids = train_chat_turn.out_token_ids
96
+ nb_chat_turns_ids = chat_turn_ids.numel()
97
+ action_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
98
+ engine_log_probs.append(train_chat_turn.log_probs)
99
+ else:
100
+ chat_turn_ids = train_chat_turn.chat_template_token_ids
101
+ nb_chat_turns_ids = chat_turn_ids.numel()
102
+ action_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
103
+ engine_log_probs.append(torch.zeros(nb_chat_turns_ids, dtype=torch.float))
104
+ nb_chat_turns_ids = chat_turn_ids.numel()
105
+ state_ends_mask.append(torch.zeros(nb_chat_turns_ids, dtype=torch.bool))
106
+ if is_state_end:
107
+ state_ends_mask[-1][-1] = True # last token is state end
108
+ input_ids.append(chat_turn_ids)
109
+ entropy_mask.append(torch.ones(nb_chat_turns_ids, dtype=torch.bool))
110
+ if not is_entropy_mask_true:
111
+ entropy_mask[-1] = entropy_mask[-1] * False
112
+ timesteps.append(torch.ones(nb_chat_turns_ids) * time_step)
113
+ input_ids = torch.cat(input_ids)
114
+ action_mask = torch.cat(action_mask)
115
+ entropy_mask = torch.cat(entropy_mask)
116
+ timesteps = torch.cat(timesteps)
117
+ timesteps = timesteps.to(torch.long)
118
+ state_ends_mask = torch.cat(state_ends_mask)
119
+ engine_log_probs = torch.cat(engine_log_probs)
120
+
121
+ return (
122
+ input_ids,
123
+ action_mask,
124
+ entropy_mask,
125
+ timesteps,
126
+ state_ends_mask,
127
+ engine_log_probs,
128
+ )
src_code_for_reproducibility/training/trainer_ad_align.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from typing import Tuple
6
+
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from mllm.markov_games.rollout_tree import (
11
+ ChatTurn,
12
+ RolloutTreeBranchNode,
13
+ RolloutTreeRootNode,
14
+ )
15
+ from mllm.training.credit_methods import (
16
+ get_advantage_alignment_credits,
17
+ get_discounted_state_visitation_credits,
18
+ )
19
+ from mllm.training.tally_metrics import Tally
20
+ from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
21
+ from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
22
+ from mllm.training.tokenize_chats import process_training_chat
23
+ from mllm.training.trainer_common import BaseTrainer
24
+ from mllm.training.training_data_utils import (
25
+ AdvantagePacket,
26
+ TrainingBatch,
27
+ TrainingChatTurn,
28
+ TrajectoryBatch,
29
+ get_main_chat_list_and_rewards,
30
+ get_tokenwise_credits,
31
+ )
32
+ from mllm.utils.resource_context import resource_logger_context
33
+
34
+ logger = logging.getLogger(__name__)
35
+ logger.addHandler(logging.StreamHandler(sys.stdout))
36
+
37
+ RolloutId = int
38
+ AgentId = str
39
+
40
+
41
+ @dataclass
42
+ class AdAlignTrainingData:
43
+ agent_id: str
44
+ main_data: TrajectoryBatch
45
+ # list-of-tensors: per rollout advantages with length jT
46
+ main_advantages: list[torch.FloatTensor] | None = None
47
+ # list-of-tensors: per rollout matrix (jT, A)
48
+ alternative_advantages: list[torch.FloatTensor] | None = None
49
+ advantage_alignment_credits: list[torch.FloatTensor] | None = None
50
+
51
+
52
+ def get_alternative_chat_histories(
53
+ agent_id: str, root: RolloutTreeRootNode
54
+ ) -> list[list[TrainingChatTurn], list[torch.FloatTensor]]:
55
+ """
56
+ args:
57
+ agent_id: The agent we want to get the chat history for.
58
+ root: The root of the rollout tree.
59
+ returns:
60
+ alternative_chats: list[list[TrainingChatTurn]] (jT*A, jS')
61
+ alternative_rewards: list[torch.FloatTensor] (jT*A, jT')
62
+ """
63
+ current_node = root.child
64
+ branches = current_node.branches
65
+ pre_branch_chat = []
66
+ pre_branch_rewards = []
67
+ alternative_rewards = []
68
+ alternative_chats = []
69
+ while current_node is not None:
70
+ assert isinstance(
71
+ current_node, RolloutTreeBranchNode
72
+ ), "Current node should be a branch node."
73
+ main_node = current_node.main_child
74
+ branches = current_node.branches
75
+ current_node = main_node.child
76
+
77
+ # Get the `A` alternative trajectories
78
+ alternative_nodes = branches[agent_id]
79
+ for alt_node in alternative_nodes:
80
+ post_branch_chat, post_branch_rewards = get_main_chat_list_and_rewards(
81
+ agent_id=agent_id, root=alt_node
82
+ )
83
+ branch_chat = pre_branch_chat + post_branch_chat
84
+ alternative_chats.append(branch_chat)
85
+ alternative_rewards.append(
86
+ torch.cat([torch.tensor(pre_branch_rewards), post_branch_rewards])
87
+ )
88
+
89
+ chat_turns: list[ChatTurn] = main_node.step_log.action_logs[agent_id].chat_turns
90
+ chat_turns: list[TrainingChatTurn] = [
91
+ TrainingChatTurn(time_step=main_node.time_step, **turn.model_dump())
92
+ for turn in chat_turns
93
+ ]
94
+
95
+ pre_branch_chat.extend(chat_turns)
96
+ pre_branch_rewards.append(
97
+ main_node.step_log.simulation_step_log.rewards[agent_id]
98
+ )
99
+
100
+ return alternative_chats, alternative_rewards
101
+
102
+
103
+ class TrainerAdAlign(BaseTrainer):
104
+ """
105
+ Extends the reinforce trainer to support Advantage Alignment.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ ad_align_beta: float,
111
+ ad_align_gamma: float,
112
+ ad_align_exclude_k_equals_t: bool,
113
+ ad_align_use_sign: bool,
114
+ ad_align_clipping: float,
115
+ ad_align_force_coop_first_step: bool,
116
+ use_old_ad_align: bool,
117
+ use_time_regularization: bool,
118
+ rloo_branch: bool,
119
+ reuse_baseline: bool,
120
+ ad_align_beta_anneal_step: int = -1,
121
+ ad_align_beta_anneal_rate: float = 0.5,
122
+ min_ad_align_beta: float = 0.1,
123
+ mean_normalize_ad_align: bool = False,
124
+ whiten_adalign_advantages: bool = False,
125
+ whiten_adalign_advantages_time_step_wise: bool = False,
126
+ *args,
127
+ **kwargs,
128
+ ):
129
+ """
130
+ Initialize the advantage alignment trainer.
131
+ Args:
132
+ ad_align_beta: Beta parameter for the advantage alignment.
133
+ ad_align_gamma: Gamma parameter for the advantage alignment.
134
+ ad_align_exclude_k_equals_t: Whether to include k = t in the advantage alignment.
135
+ ad_align_use_sign: Whether to use sign in the advantage alignment.
136
+ ad_align_clipping: Clipping value for the advantage alignment.
137
+ ad_align_force_coop_first_step: Whether to force coop on the first step of the advantage alignment.
138
+ """
139
+ super().__init__(*args, **kwargs)
140
+ self.ad_align_beta = ad_align_beta
141
+ self.ad_align_gamma = ad_align_gamma
142
+ self.ad_align_exclude_k_equals_t = ad_align_exclude_k_equals_t
143
+ self.ad_align_use_sign = ad_align_use_sign
144
+ self.ad_align_clipping = ad_align_clipping
145
+ self.ad_align_force_coop_first_step = ad_align_force_coop_first_step
146
+ self.use_old_ad_align = use_old_ad_align
147
+ self.use_time_regularization = use_time_regularization
148
+ self.rloo_branch = rloo_branch
149
+ self.reuse_baseline = reuse_baseline
150
+ self.ad_align_beta_anneal_step = ad_align_beta_anneal_step
151
+ self.ad_align_beta_anneal_rate = ad_align_beta_anneal_rate
152
+ self.min_ad_align_beta = min_ad_align_beta
153
+ self.past_ad_align_step = -1
154
+ self.mean_normalize_ad_align = mean_normalize_ad_align
155
+ self.whiten_adalign_advantages = whiten_adalign_advantages
156
+ self.whiten_adalign_advantages_time_step_wise = (
157
+ whiten_adalign_advantages_time_step_wise
158
+ )
159
+ self.training_data: dict[AgentId, AdAlignTrainingData] = {}
160
+ self.debug_path_list: list[str] = []
161
+
162
+ def set_agent_trajectory_data(
163
+ self, agent_id: str, roots: list[RolloutTreeRootNode]
164
+ ):
165
+ """
166
+ TOWRITE
167
+ Set the advantage alignment data for the trainer.
168
+ """
169
+
170
+ B = len(roots) # Number of rollouts
171
+
172
+ # For main rollouts
173
+ batch_rollout_ids = []
174
+ batch_crn_ids = []
175
+ batch_input_ids = []
176
+ batch_action_mask = []
177
+ batch_entropy_mask = []
178
+ batch_timesteps = []
179
+ batch_state_ends_mask = []
180
+ batch_engine_log_probs = []
181
+ batch_rewards = []
182
+
183
+ # For alternative actions rollouts
184
+ batch_branching_time_steps = []
185
+ alternative_batch_input_ids = []
186
+ alternative_batch_action_mask = []
187
+ alternative_batch_entropy_mask = []
188
+ alternative_batch_timesteps = []
189
+ alternative_batch_state_ends_mask = []
190
+ alternative_batch_engine_log_probs = []
191
+ alternative_batch_rewards = []
192
+ jT_list = []
193
+
194
+ try:
195
+ A = len(roots[0].child.branches[agent_id]) # Number of alternative actions
196
+ except:
197
+ A = 0
198
+
199
+ for root in roots:
200
+ rollout_id = root.id
201
+ self.debug_path_list.append(
202
+ "mgid:" + str(rollout_id) + "_agent_id:" + agent_id
203
+ )
204
+ # Get main trajectory
205
+ batch_rollout_ids.append(rollout_id)
206
+ batch_crn_ids.append(root.crn_id)
207
+ main_chat, main_rewards = get_main_chat_list_and_rewards(
208
+ agent_id=agent_id, root=root
209
+ )
210
+ (
211
+ input_ids,
212
+ action_mask,
213
+ entropy_mask,
214
+ timesteps,
215
+ state_ends_mask,
216
+ engine_log_probs,
217
+ ) = process_training_chat(
218
+ tokenizer=self.tokenizer,
219
+ chat_history=main_chat,
220
+ entropy_mask_regex=self.entropy_mask_regex,
221
+ exploration_prompts_to_remove=self.exploration_prompts_to_remove,
222
+ )
223
+ batch_input_ids.append(input_ids)
224
+ batch_action_mask.append(action_mask)
225
+ batch_entropy_mask.append(entropy_mask)
226
+ batch_timesteps.append(timesteps)
227
+ batch_state_ends_mask.append(state_ends_mask)
228
+ batch_engine_log_probs.append(engine_log_probs)
229
+ batch_rewards.append(main_rewards)
230
+ jT = main_rewards.numel() # TODO: better than this
231
+ jT_list.append(jT)
232
+ if A > 0:
233
+ # We get the branching time steps for each of the `jT` time steps in the main trajectory.
234
+ branching_time_steps = [bt for item in range(jT) for bt in A * [item]]
235
+ batch_branching_time_steps.extend(branching_time_steps)
236
+
237
+ # Get all of the (jT*A) alternative trajectories in the tree
238
+ # (jT is the number of time steps in the main trajectory, A is the number of alternative actions)
239
+ alternative_chats, alternative_rewards = get_alternative_chat_histories(
240
+ agent_id=agent_id, root=root
241
+ )
242
+ assert (
243
+ len(alternative_chats) == A * jT
244
+ ), "Incorrect number of alternative trajectories."
245
+
246
+ for chat, rewards in zip(alternative_chats, alternative_rewards):
247
+ (
248
+ input_ids,
249
+ action_mask,
250
+ entropy_mask,
251
+ timesteps,
252
+ state_ends_mask,
253
+ engine_log_probs,
254
+ ) = process_training_chat(
255
+ tokenizer=self.tokenizer,
256
+ chat_history=chat,
257
+ entropy_mask_regex=self.entropy_mask_regex,
258
+ exploration_prompts_to_remove=self.exploration_prompts_to_remove,
259
+ )
260
+ alternative_batch_input_ids.append(input_ids)
261
+ alternative_batch_action_mask.append(action_mask)
262
+ alternative_batch_entropy_mask.append(entropy_mask)
263
+ alternative_batch_timesteps.append(timesteps)
264
+ alternative_batch_state_ends_mask.append(state_ends_mask)
265
+ alternative_batch_engine_log_probs.append(engine_log_probs)
266
+ alternative_batch_rewards.append(rewards)
267
+
268
+ jT_list = torch.Tensor(jT_list)
269
+
270
+ # Assert that number of alternative actions is constant
271
+ # assert len(set(nb_alternative_actions)) == 1, "Number of alternative actions must be constant"
272
+ # A = nb_alternative_actions[0]
273
+
274
+ trajectory_batch = TrajectoryBatch(
275
+ rollout_ids=torch.tensor(batch_rollout_ids, dtype=torch.int32), # (B,)
276
+ crn_ids=torch.tensor(batch_crn_ids, dtype=torch.int32),
277
+ agent_ids=[agent_id] * len(batch_rollout_ids),
278
+ batch_input_ids=batch_input_ids,
279
+ batch_action_mask=batch_action_mask,
280
+ batch_entropy_mask=batch_entropy_mask,
281
+ batch_timesteps=batch_timesteps,
282
+ batch_state_ends_mask=batch_state_ends_mask,
283
+ batch_engine_log_probs=batch_engine_log_probs,
284
+ batch_rewards=batch_rewards,
285
+ )
286
+ # Get Advantages & Train Critic
287
+ with resource_logger_context(
288
+ logger, "Get advantages with critic gradient accumulation"
289
+ ):
290
+ self.batch_advantages: torch.FloatTensor = (
291
+ self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
292
+ ) # (B, jT)
293
+
294
+ if A > 0:
295
+ # Here, `A` is the number of alternative actions / trajectories taken at each time step.
296
+ # For each of the `B` rollout perspectives, at each of its jT (`j` is for jagged, since each main rollout may be of a different length) steps, we take A alternate trajectories (from different actions).
297
+ # Therefore, we have ∑jT * A trajectories to process. If each of the main trajectories have T steps, we will have `B*T*A` to process.
298
+ with resource_logger_context(logger, "Create alternative trajectory batch"):
299
+ sum_jT = int(torch.sum(jT_list).item())
300
+ jT_list = (
301
+ jT_list.int().tolist()
302
+ ) # (jT,) # (we only want the advantages where we branched out)
303
+ alternative_trajectory_batch = TrajectoryBatch(
304
+ rollout_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
305
+ crn_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
306
+ agent_ids=[agent_id] * (A * sum_jT),
307
+ batch_input_ids=alternative_batch_input_ids,
308
+ batch_action_mask=alternative_batch_action_mask,
309
+ batch_entropy_mask=alternative_batch_entropy_mask,
310
+ batch_timesteps=alternative_batch_timesteps,
311
+ batch_state_ends_mask=alternative_batch_state_ends_mask,
312
+ batch_engine_log_probs=alternative_batch_engine_log_probs,
313
+ batch_rewards=alternative_batch_rewards,
314
+ )
315
+
316
+ # Get alternative advantages
317
+ # BAAs stands for batch alternative advantages
318
+ # (torch nested tensors have very little api support, so we have to do some odd manual work here)
319
+ with resource_logger_context(
320
+ logger, "Compute alternative advantage estimates"
321
+ ):
322
+ BAAs_list = self.get_advantages_with_critic_gradient_accumulation(
323
+ alternative_trajectory_batch
324
+ ) # list length (∑jT * A), each (jT',)
325
+ # Pad alternative advantages to (∑jT*A, P)
326
+
327
+ BAAs_padded = pad_sequence(
328
+ BAAs_list, batch_first=True, padding_value=0.0
329
+ )
330
+ branch_idx = torch.tensor(
331
+ batch_branching_time_steps,
332
+ device=BAAs_padded.device,
333
+ dtype=torch.long,
334
+ )
335
+ gathered = BAAs_padded.gather(
336
+ dim=1, index=branch_idx.unsqueeze(1)
337
+ ).squeeze(1)
338
+ # Reshape and split per rollout, then transpose to (jT_i, A)
339
+ gathered = gathered.view(A, sum_jT) # (A, ∑jT)
340
+ blocks = list(
341
+ torch.split(gathered, jT_list, dim=1)
342
+ ) # len B, shapes (A, jT_i)
343
+ BAAs = [
344
+ blk.transpose(0, 1).contiguous() for blk in blocks
345
+ ] # list of (jT_i, A)
346
+ if self.ad_align_beta_anneal_step > 0:
347
+ max_rollout_id = torch.max(trajectory_batch.rollout_ids) + 1
348
+ if (
349
+ max_rollout_id % self.ad_align_beta_anneal_step == 0
350
+ and self.past_ad_align_step != max_rollout_id
351
+ ):
352
+ self.ad_align_beta = max(
353
+ self.ad_align_beta * self.ad_align_beta_anneal_rate,
354
+ self.min_ad_align_beta,
355
+ )
356
+ logger.info(f"Annealing ad_align_beta to {self.ad_align_beta}")
357
+ self.past_ad_align_step = max_rollout_id
358
+ self.training_data[agent_id] = AdAlignTrainingData(
359
+ agent_id=agent_id,
360
+ main_data=trajectory_batch,
361
+ main_advantages=self.batch_advantages,
362
+ alternative_advantages=BAAs if A > 0 else None,
363
+ )
364
+
365
+ def share_advantage_data(self) -> list[AdvantagePacket]:
366
+ """
367
+ Share the advantage alignment data with other agents.
368
+ Returns:
369
+ AdvantagePacket: The advantage packet containing the agent's advantages.
370
+ """
371
+ logger.info(f"Sharing advantage alignment data.")
372
+ advantage_packets = []
373
+ for _, agent_data in self.training_data.items():
374
+ advantage_packets.append(
375
+ AdvantagePacket(
376
+ agent_id=agent_data.agent_id,
377
+ rollout_ids=agent_data.main_data.rollout_ids,
378
+ main_advantages=agent_data.main_advantages,
379
+ )
380
+ )
381
+ return advantage_packets
382
+
383
+ def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
384
+ """
385
+ Receive advantage packets from other players.
386
+ These contain the advantages of the other players' rollouts estimated by them.
387
+ """
388
+ logger.info(f"Receiving advantage packets.")
389
+
390
+ assert (
391
+ len(advantage_packets) > 0
392
+ ), "At least one advantage packet must be provided."
393
+
394
+ for agent_id, agent_data in self.training_data.items():
395
+ coagent_advantage_packets = [
396
+ packet for packet in advantage_packets if packet.agent_id != agent_id
397
+ ]
398
+ agent_rollout_ids = agent_data.main_data.rollout_ids
399
+ agent_advantages = agent_data.main_advantages
400
+ co_agent_advantages = []
401
+ for rollout_id in agent_rollout_ids:
402
+ for co_agent_packet in coagent_advantage_packets:
403
+ if rollout_id in co_agent_packet.rollout_ids:
404
+ index = torch.where(rollout_id == co_agent_packet.rollout_ids)[
405
+ 0
406
+ ].item()
407
+ co_agent_advantages.append(
408
+ co_agent_packet.main_advantages[index]
409
+ )
410
+ # assumes that its two player game, with one co-agent
411
+ break
412
+ assert len(co_agent_advantages) == len(agent_advantages)
413
+ B = len(agent_advantages)
414
+ assert all(
415
+ a.shape[0] == b.shape[0]
416
+ for a, b in zip(co_agent_advantages, agent_advantages)
417
+ ), "Number of advantages must match for advantage alignment."
418
+
419
+ # Get padded tensors (advantage alignment is invariant to padding)
420
+ lengths = torch.tensor(
421
+ [len(t) for t in agent_advantages],
422
+ device=self.device,
423
+ dtype=torch.long,
424
+ )
425
+ padded_main_advantages = pad_sequence(
426
+ agent_advantages, batch_first=True, padding_value=0.0
427
+ )
428
+ if agent_data.alternative_advantages:
429
+ padded_alternative_advantages = pad_sequence(
430
+ agent_data.alternative_advantages,
431
+ batch_first=True,
432
+ padding_value=0.0,
433
+ ) # (B, P, A)
434
+ else:
435
+ padded_alternative_advantages = None
436
+ padded_co_agent_advantages = pad_sequence(
437
+ co_agent_advantages, batch_first=True, padding_value=0.0
438
+ )
439
+
440
+ # Create training batch data
441
+ credits, sub_tensors = get_advantage_alignment_credits(
442
+ a1=padded_main_advantages,
443
+ a1_alternative=padded_alternative_advantages,
444
+ a2=padded_co_agent_advantages,
445
+ beta=self.ad_align_beta,
446
+ gamma=self.ad_align_gamma,
447
+ exclude_k_equals_t=self.ad_align_exclude_k_equals_t,
448
+ use_sign=self.ad_align_use_sign,
449
+ clipping=self.ad_align_clipping,
450
+ force_coop_first_step=self.ad_align_force_coop_first_step,
451
+ use_old_ad_align=self.use_old_ad_align,
452
+ use_time_regularization=self.use_time_regularization,
453
+ rloo_branch=self.rloo_branch,
454
+ reuse_baseline=self.reuse_baseline,
455
+ mean_normalize_ad_align=self.mean_normalize_ad_align,
456
+ whiten_adalign_advantages=self.whiten_adalign_advantages,
457
+ whiten_adalign_advantages_time_step_wise=self.whiten_adalign_advantages_time_step_wise,
458
+ )
459
+ for key, value in sub_tensors.items():
460
+ self.rollout_tally.add_metric(
461
+ path=[key],
462
+ rollout_tally_item=RolloutTallyItem(
463
+ crn_ids=agent_data.main_data.crn_ids,
464
+ rollout_ids=agent_data.main_data.rollout_ids,
465
+ agent_ids=agent_data.main_data.agent_ids,
466
+ metric_matrix=value,
467
+ ),
468
+ )
469
+
470
+ if not self.skip_discounted_state_visitation:
471
+ credits = get_discounted_state_visitation_credits(
472
+ credits,
473
+ self.discount_factor,
474
+ )
475
+ self.rollout_tally.add_metric(
476
+ path=["discounted_state_visitation_credits"],
477
+ rollout_tally_item=RolloutTallyItem(
478
+ crn_ids=agent_data.main_data.crn_ids,
479
+ rollout_ids=agent_data.main_data.rollout_ids,
480
+ agent_ids=agent_data.main_data.agent_ids,
481
+ metric_matrix=sub_tensors[
482
+ "discounted_state_visitation_credits"
483
+ ],
484
+ ),
485
+ )
486
+
487
+ # Slice back to jagged
488
+ advantage_alignment_credits = [credits[i, : lengths[i]] for i in range(B)]
489
+ # Replace stored training data for this agent by the concrete trajectory batch
490
+ # and attach the computed credits for policy gradient.
491
+ self.training_data[agent_id] = agent_data.main_data
492
+ self.training_data[agent_id].batch_credits = advantage_alignment_credits
src_code_for_reproducibility/training/trainer_common.py ADDED
@@ -0,0 +1,1054 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: Add coefficients for losses (depend on total number of tokens or batch)
3
+ TODO: adapt reinforce step for torch.compile
4
+ TODO: add lr schedulers support
5
+ """
6
+ import logging
7
+ import os
8
+ import pickle
9
+ import sys
10
+ from abc import ABC, abstractmethod
11
+ from typing import Callable, Literal, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from accelerate import Accelerator
17
+ from pandas._libs.tslibs.offsets import CBMonthBegin
18
+ from peft import LoraConfig
19
+ from torch.nn.utils.rnn import pad_sequence
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer
21
+
22
+ from mllm.markov_games.rollout_tree import *
23
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode
24
+ from mllm.training.annealing_methods import sigmoid_annealing
25
+ from mllm.training.credit_methods import (
26
+ get_discounted_returns,
27
+ get_generalized_advantage_estimates,
28
+ get_rloo_credits,
29
+ whiten_advantages,
30
+ whiten_advantages_time_step_wise,
31
+ )
32
+ from mllm.training.tally_metrics import Tally
33
+ from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
34
+ from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
35
+ from mllm.training.tokenize_chats import *
36
+ from mllm.training.tokenize_chats import process_training_chat
37
+ from mllm.training.training_data_utils import *
38
+ from mllm.training.training_data_utils import (
39
+ TrainingBatch,
40
+ TrajectoryBatch,
41
+ get_tokenwise_credits,
42
+ )
43
+ from mllm.utils.resource_context import resource_logger_context
44
+
45
+ logger = logging.getLogger(__name__)
46
+ logger.addHandler(logging.StreamHandler(sys.stdout))
47
+
48
+
49
+ @dataclass
50
+ class TrainerAnnealingState:
51
+ annealing_step_counter: int = 0
52
+
53
+
54
+ class BaseTrainer(ABC):
55
+ """
56
+ Trainer
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ policy: AutoModelForCausalLM,
62
+ policy_optimizer: torch.optim.Optimizer,
63
+ critic: Union[AutoModelForCausalLM, None],
64
+ critic_optimizer: Union[torch.optim.Optimizer, None],
65
+ tokenizer: AutoTokenizer,
66
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
67
+ critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None],
68
+ ######################################################################
69
+ entropy_coeff: float,
70
+ entropy_topk: int,
71
+ entropy_mask_regex: Union[str, None],
72
+ kl_coeff: float,
73
+ gradient_clipping: Union[float, None],
74
+ restrict_tokens: Union[list[str], None],
75
+ mini_batch_size: int,
76
+ use_gradient_checkpointing: bool,
77
+ temperature: float,
78
+ device: str,
79
+ whiten_advantages: bool,
80
+ whiten_advantages_time_step_wise: bool,
81
+ use_gae: bool,
82
+ use_gae_lambda_annealing: bool,
83
+ gae_lambda_annealing_limit: float,
84
+ gae_lambda_annealing_method: Literal["sigmoid_annealing"],
85
+ gae_lambda_annealing_method_params: dict,
86
+ pg_loss_normalization: Literal["batch", "nb_tokens"],
87
+ use_rloo: bool,
88
+ skip_discounted_state_visitation: bool,
89
+ discount_factor: float,
90
+ enable_tokenwise_logging: bool,
91
+ save_path: str,
92
+ reward_normalizing_constant: float = 1.0,
93
+ critic_loss_type: Literal["mse", "huber"] = "huber",
94
+ exploration_prompts_to_remove: list[str] = [],
95
+ filter_higher_refprob_tokens_kl: bool = False,
96
+ truncated_importance_sampling_ratio_cap: float = 0.0,
97
+ importance_sampling_strategy: Literal[
98
+ "per_token", "per_sequence"
99
+ ] = "per_token",
100
+ ):
101
+ """
102
+ Initialize the REINFORCE trainer with reward shaping for multi-agent or single-agent training.
103
+
104
+ Args:
105
+ model (AutoModelForCausalLM): The main policy model.
106
+ tokenizer (AutoTokenizer): Tokenizer for the model.
107
+ optimizer (torch.optim.Optimizer): Optimizer for the policy model.
108
+ lr_scheduler (torch.optim.lr_scheduler.LRScheduler): Learning rate scheduler for the policy model.
109
+ critic (AutoModelForCausalLM or None): Critic model for value estimation (optional).
110
+ critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic model (optional).
111
+ critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic (optional).
112
+ config (RtConfig): Configuration object for training.
113
+ """
114
+ self.tokenizer = tokenizer
115
+ # self.tokenizer.padding_side = "left" # needed for flash attention
116
+ if self.tokenizer.pad_token_id is None:
117
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
118
+ self.lr_scheduler = lr_scheduler
119
+ self.accelerator = Accelerator()
120
+ (
121
+ self.policy,
122
+ self.policy_optimizer,
123
+ self.critic,
124
+ self.critic_optimizer,
125
+ ) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer)
126
+
127
+ self.critic_lr_scheduler = critic_lr_scheduler
128
+ self.tally = Tally()
129
+
130
+ if use_gradient_checkpointing == True:
131
+ self.policy.gradient_checkpointing_enable(dict(use_reentrant=False))
132
+ if critic is not None:
133
+ self.critic.gradient_checkpointing_enable(dict(use_reentrant=False))
134
+
135
+ self.save_path = save_path
136
+
137
+ # Load trainer state if it exists
138
+ self.trainer_annealing_state_path = os.path.join(
139
+ self.save_path, "trainer_annealing_state.pkl"
140
+ )
141
+ if os.path.exists(self.trainer_annealing_state_path):
142
+ logger.info(
143
+ f"Loading trainer state from {self.trainer_annealing_state_path}"
144
+ )
145
+ self.trainer_annealing_state = pickle.load(
146
+ open(self.trainer_annealing_state_path, "rb")
147
+ )
148
+ else:
149
+ self.trainer_annealing_state = TrainerAnnealingState()
150
+
151
+ # Load policy optimizer state if it exists
152
+ self.policy_optimizer_path = os.path.join(
153
+ self.save_path, "policy_optimizer_state.pt"
154
+ )
155
+ if os.path.exists(self.policy_optimizer_path):
156
+ logger.info(
157
+ f"Loading policy optimizer state from {self.policy_optimizer_path}"
158
+ )
159
+ self.policy_optimizer.load_state_dict(
160
+ torch.load(self.policy_optimizer_path)
161
+ )
162
+
163
+ # Load critic optimizer state if it exists
164
+ self.critic_optimizer_path = os.path.join(
165
+ self.save_path, "critic_optimizer_state.pt"
166
+ )
167
+ if (
168
+ os.path.exists(self.critic_optimizer_path)
169
+ and self.critic_optimizer is not None
170
+ ):
171
+ logger.info(
172
+ f"Loading critic optimizer state from {self.critic_optimizer_path}"
173
+ )
174
+ self.critic_optimizer.load_state_dict(
175
+ torch.load(self.critic_optimizer_path)
176
+ )
177
+ self.device = self.accelerator.device
178
+ self.entropy_coeff = entropy_coeff
179
+ self.entropy_topk = entropy_topk
180
+ self.entropy_mask_regex = entropy_mask_regex
181
+ self.kl_coeff = kl_coeff
182
+ self.gradient_clipping = gradient_clipping
183
+ self.restrict_tokens = restrict_tokens
184
+ self.mini_batch_size = mini_batch_size
185
+ self.use_gradient_checkpointing = use_gradient_checkpointing
186
+ self.temperature = temperature
187
+ self.use_gae = use_gae
188
+ self.whiten_advantages = whiten_advantages
189
+ self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise
190
+ self.use_rloo = use_rloo
191
+ self.skip_discounted_state_visitation = skip_discounted_state_visitation
192
+ self.use_gae_lambda_annealing = use_gae_lambda_annealing
193
+ self.gae_lambda_annealing_limit = gae_lambda_annealing_limit
194
+ if use_gae_lambda_annealing:
195
+ self.gae_lambda_annealing_method: Callable[
196
+ [int], float
197
+ ] = lambda step: eval(gae_lambda_annealing_method)(
198
+ step=step, **gae_lambda_annealing_method_params
199
+ )
200
+ self.discount_factor = discount_factor
201
+ self.enable_tokenwise_logging = enable_tokenwise_logging
202
+ self.reward_normalizing_constant = reward_normalizing_constant
203
+ self.pg_loss_normalization = pg_loss_normalization
204
+ self.critic_loss_type = critic_loss_type
205
+ self.exploration_prompts_to_remove = exploration_prompts_to_remove
206
+ # Common containers used by all trainers
207
+ self.training_data: dict = {}
208
+ self.debug_path_list: list[str] = []
209
+ self.policy_gradient_data = None
210
+ self.tally = Tally()
211
+ self.rollout_tally = RolloutTally()
212
+ self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None
213
+ self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl
214
+ self.truncated_importance_sampling_ratio_cap = (
215
+ truncated_importance_sampling_ratio_cap
216
+ )
217
+ self.importance_sampling_strategy = importance_sampling_strategy
218
+
219
+ def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor:
220
+ """
221
+ Masks logits so that only allowed tokens (as specified in config.restrict_tokens)
222
+ and the EOS token are active.
223
+ All other logits are set to -inf, effectively removing them from the softmax.
224
+
225
+ Args:
226
+ logits (torch.Tensor): The logits tensor of shape (B, S, V).
227
+
228
+ Returns:
229
+ torch.Tensor: The masked logits tensor.
230
+ """
231
+ # TODO: verify. Not sure what we do here is differentiable
232
+ # also, we recompute for nothing
233
+
234
+ if self.restrict_tokens is not None:
235
+ allowed_token_ids = []
236
+ for token in self.restrict_tokens:
237
+ token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"]
238
+ allowed_token_ids.append(token_ids[0])
239
+ allowed_token_ids.append(
240
+ self.tokenizer.eos_token_id
241
+ ) # This token should always be active
242
+ allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device)
243
+ # Mask log_probs and probs to only allowed tokens
244
+ mask = torch.zeros_like(logits).bool() # (B, S, V)
245
+ mask[..., allowed_token_ids] = True
246
+ logits = torch.where(
247
+ mask,
248
+ logits,
249
+ torch.tensor(-float("inf"), device=logits.device),
250
+ )
251
+
252
+ return logits
253
+
254
+ # def get_gradient_magnitude(self, loss_term: torch.Tensor) -> float:
255
+ # """
256
+ # Computes the L2 norm of the gradients of the given loss term with respect to the model parameters.
257
+
258
+ # Args:
259
+ # loss_term (torch.Tensor): The loss tensor to compute gradients for.
260
+
261
+ # Returns:
262
+ # float: The L2 norm of the gradients, or 0.0 if no gradients are present.
263
+ # """
264
+ # with torch.no_grad():
265
+ # grads = torch.autograd.grad(
266
+ # loss_term,
267
+ # [p for p in self.policy.parameters() if p.requires_grad],
268
+ # retain_graph=True,
269
+ # allow_unused=True,
270
+ # )
271
+ # grads = [g for g in grads if g is not None]
272
+ # if not grads:
273
+ # return torch.tensor(0.0, device=loss_term.device)
274
+ # return torch.norm(torch.stack([g.norm(2) for g in grads])).item()
275
+
276
+ def apply_reinforce_step(
277
+ self,
278
+ training_batch: TrainingBatch,
279
+ ) -> None:
280
+ """
281
+ Applies a single REINFORCE policy gradient step using the provided batch of rollouts.
282
+ Handles batching, loss computation (including entropy and KL regularization), gradient accumulation, and optimizer step.
283
+ Optionally logs various metrics and statistics.
284
+
285
+ Args:
286
+ paths (list[str]): List of game complete file paths for each rollout.
287
+ contexts (list[torch.Tensor]): List of context tensors for each rollout.
288
+ credits (list[torch.Tensor]): List of credit tensors (rewards/advantages) for each rollout.
289
+ action_masks (list[torch.Tensor]): List of action mask tensors for each rollout.
290
+ """
291
+ with resource_logger_context(logger, "Apply reinforce step"):
292
+ self.policy.train()
293
+ mb_size = self.mini_batch_size
294
+ nb_rollouts = len(training_batch)
295
+
296
+ # Initialize running mean logs
297
+ running_mean_logs = {
298
+ "rl_objective": 0.0,
299
+ "policy_gradient_loss": 0.0,
300
+ "policy_gradient_norm": 0.0,
301
+ "log_probs": 0.0,
302
+ "credits": 0.0,
303
+ "entropy": 0.0,
304
+ "engine_log_probs_diff_clampfrac": 0.0,
305
+ "tis_imp_ratio": 0.0,
306
+ "ref_log_probs_diff_clampfrac": 0.0,
307
+ "higher_refprob_frac": 0.0,
308
+ "tis_imp_ratio_clampfrac": 0.0,
309
+ }
310
+ if self.entropy_coeff != 0.0:
311
+ running_mean_logs["entropy"] = 0.0
312
+ if self.kl_coeff != 0.0:
313
+ running_mean_logs["kl_divergence"] = 0.0
314
+
315
+ # Get total number of tokens generated
316
+ total_tokens_generated = 0
317
+ for att_mask in training_batch.batch_action_mask:
318
+ total_tokens_generated += att_mask.sum()
319
+
320
+ # Obtain loss normalization
321
+ if self.pg_loss_normalization == "nb_tokens":
322
+ normalization_factor = total_tokens_generated
323
+ elif self.pg_loss_normalization == "batch":
324
+ normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int)
325
+ else:
326
+ raise ValueError(
327
+ f"Invalid pg_loss_normalization: {self.pg_loss_normalization}"
328
+ )
329
+
330
+ # Gradient accumulation for each mini-batch
331
+ for mb in range(0, nb_rollouts, mb_size):
332
+ logger.info(f"Processing mini-batch {mb} of {nb_rollouts}")
333
+ loss = 0.0
334
+ training_mb = training_batch[mb : mb + mb_size]
335
+ training_mb = training_mb.get_padded_tensors()
336
+ training_mb.to(self.device)
337
+ (
338
+ tokens_mb,
339
+ action_mask_mb,
340
+ entropy_mask_mb,
341
+ credits_mb,
342
+ engine_log_probs_mb,
343
+ timesteps_mb,
344
+ ) = (
345
+ training_mb.batch_input_ids,
346
+ training_mb.batch_action_mask,
347
+ training_mb.batch_entropy_mask,
348
+ training_mb.batch_credits,
349
+ training_mb.batch_engine_log_probs,
350
+ training_mb.batch_timesteps,
351
+ )
352
+
353
+ # Next token prediction
354
+ contexts_mb = tokens_mb[:, :-1]
355
+ shifted_contexts_mb = tokens_mb[:, 1:]
356
+ action_mask_mb = action_mask_mb[:, 1:]
357
+ entropy_mask_mb = entropy_mask_mb[:, 1:]
358
+ credits_mb = credits_mb[:, 1:]
359
+ engine_log_probs_mb = engine_log_probs_mb[:, 1:]
360
+ timesteps_mb = timesteps_mb[:, 1:]
361
+
362
+ if self.enable_tokenwise_logging:
363
+ self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb)
364
+ self.tokenwise_tally.set_range(range=(mb, mb + mb_size))
365
+ self.tokenwise_tally.add_contexts(contexts=contexts_mb)
366
+ self.tokenwise_tally.add_data(
367
+ metric_id="next_token",
368
+ metrics=shifted_contexts_mb,
369
+ to_tids=True,
370
+ )
371
+ self.tokenwise_tally.add_data(
372
+ metric_id="entropy_mask",
373
+ metrics=entropy_mask_mb,
374
+ )
375
+
376
+ if self.enable_tokenwise_logging:
377
+ self.tokenwise_tally.add_data(
378
+ metric_id="next_token_credit", metrics=credits_mb
379
+ )
380
+
381
+ # Forward pass + cast to FP-32 for higher prec.
382
+ # TODO: create attention mask if not relying on default (assume causal llm)
383
+ logits = self.policy(input_ids=contexts_mb)[0] # (B, S, V)
384
+
385
+ # Mask non-restricted tokens
386
+ if self.restrict_tokens is not None:
387
+ logits = self.mask_non_restricted_token_logits(logits)
388
+
389
+ logits /= self.temperature # (B, S, V)
390
+
391
+ # Compute new log probabilities
392
+ log_probs = F.log_softmax(logits, dim=-1) # (B, S, V)
393
+
394
+ # Get log probabilities of actions taken during rollouts
395
+ action_log_probs = log_probs.gather(
396
+ dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
397
+ ).squeeze(
398
+ -1
399
+ ) # (B, S)
400
+ if self.pg_loss_normalization == "batch":
401
+ den_running_mean = action_mask_mb.sum() * normalization_factor
402
+ else:
403
+ den_running_mean = normalization_factor
404
+ running_mean_logs["log_probs"] += (
405
+ action_log_probs * action_mask_mb
406
+ ).sum().item() / den_running_mean
407
+ running_mean_logs["credits"] += (
408
+ credits_mb * action_mask_mb
409
+ ).sum().item() / den_running_mean
410
+
411
+ if self.enable_tokenwise_logging:
412
+ self.tokenwise_tally.add_data(
413
+ metric_id="next_token_log_prob",
414
+ metrics=action_log_probs,
415
+ )
416
+ self.tokenwise_tally.add_data(
417
+ metric_id="engine_next_token_log_prob",
418
+ metrics=engine_log_probs_mb,
419
+ )
420
+ self.tokenwise_tally.add_data(
421
+ metric_id="next_token_prob",
422
+ metrics=torch.exp(action_log_probs),
423
+ )
424
+ top_k_indices = torch.topk(logits, k=5, dim=-1).indices
425
+ self.tokenwise_tally.add_data(
426
+ metric_id=f"top_{5}_tids",
427
+ metrics=top_k_indices,
428
+ to_tids=True,
429
+ )
430
+ self.tokenwise_tally.add_data(
431
+ metric_id=f"top_{5}_probs",
432
+ metrics=torch.exp(log_probs).gather(
433
+ dim=-1, index=top_k_indices
434
+ ),
435
+ )
436
+
437
+ rewarded_action_log_probs = (
438
+ action_mask_mb * credits_mb * action_log_probs
439
+ )
440
+ # (B, S)
441
+ INVALID_LOGPROB = 1.0
442
+ CLAMP_VALUE = 40.0
443
+ masked_action_log_probs = torch.masked_fill(
444
+ action_log_probs, ~action_mask_mb, INVALID_LOGPROB
445
+ )
446
+ masked_engine_log_probs = torch.masked_fill(
447
+ engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB
448
+ )
449
+ with torch.no_grad():
450
+ action_engine_log_probs_diff = (
451
+ masked_action_log_probs - masked_engine_log_probs
452
+ ).clamp(-CLAMP_VALUE, CLAMP_VALUE)
453
+ running_mean_logs["engine_log_probs_diff_clampfrac"] += (
454
+ action_engine_log_probs_diff.abs()
455
+ .eq(CLAMP_VALUE)
456
+ .float()
457
+ .sum()
458
+ .item()
459
+ / den_running_mean
460
+ )
461
+ if self.importance_sampling_strategy == "per_sequence":
462
+ tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff)
463
+ for mb_idx in range(action_engine_log_probs_diff.shape[0]):
464
+ valid_token_mask = action_mask_mb[mb_idx]
465
+ timestep_ids = timesteps_mb[mb_idx][valid_token_mask]
466
+ timestep_logprob_diffs = action_engine_log_probs_diff[mb_idx][
467
+ valid_token_mask
468
+ ]
469
+ max_timestep = int(timestep_ids.max().item()) + 1
470
+ timestep_sums = torch.zeros(
471
+ max_timestep,
472
+ device=action_engine_log_probs_diff.device,
473
+ dtype=action_engine_log_probs_diff.dtype,
474
+ )
475
+ timestep_sums.scatter_add_(
476
+ 0, timestep_ids, timestep_logprob_diffs
477
+ )
478
+ timestep_ratios = torch.exp(timestep_sums)
479
+ tis_imp_ratio[
480
+ mb_idx, valid_token_mask
481
+ ] = timestep_ratios.gather(0, timestep_ids)
482
+ else:
483
+ tis_imp_ratio = torch.exp(action_engine_log_probs_diff)
484
+ running_mean_logs["tis_imp_ratio"] += (
485
+ tis_imp_ratio * action_mask_mb
486
+ ).sum().item() / den_running_mean
487
+ if self.truncated_importance_sampling_ratio_cap > 0.0:
488
+ tis_imp_ratio = torch.clamp(
489
+ tis_imp_ratio, max=self.truncated_importance_sampling_ratio_cap
490
+ )
491
+ running_mean_logs["tis_imp_ratio_clampfrac"] += (
492
+ tis_imp_ratio.eq(self.truncated_importance_sampling_ratio_cap)
493
+ .float()
494
+ .sum()
495
+ .item()
496
+ ) / den_running_mean
497
+ rewarded_action_log_probs = (
498
+ rewarded_action_log_probs * tis_imp_ratio
499
+ )
500
+
501
+ if self.enable_tokenwise_logging:
502
+ self.tokenwise_tally.add_data(
503
+ metric_id="next_token_clogπ",
504
+ metrics=rewarded_action_log_probs,
505
+ )
506
+
507
+ # Add value term to loss
508
+ if self.pg_loss_normalization == "batch":
509
+ nb_act_tokens = action_mask_mb.sum()
510
+ mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens
511
+ else:
512
+ mb_value = -rewarded_action_log_probs.sum()
513
+
514
+ loss += mb_value
515
+ running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean
516
+
517
+ # -------------------------------------------------
518
+ # Entropy Regularization
519
+ # -------------------------------------------------
520
+ # Only apply entropy on distribution defined over most probable tokens
521
+ if self.entropy_topk is not None:
522
+ top_k_indices = torch.topk(
523
+ logits, k=self.entropy_topk, dim=-1
524
+ ).indices
525
+ entropy_logits = logits.gather(dim=-1, index=top_k_indices)
526
+ else:
527
+ entropy_logits = logits
528
+
529
+ token_entropy_terms = -F.softmax(
530
+ entropy_logits, dim=-1
531
+ ) * F.log_softmax(
532
+ entropy_logits, dim=-1
533
+ ) # (B, S, T)
534
+ token_entropy_terms *= (
535
+ action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None]
536
+ ) # only get loss on specific action tokens
537
+
538
+ mb_entropy = token_entropy_terms.sum(dim=-1)
539
+
540
+ if self.enable_tokenwise_logging:
541
+ self.tokenwise_tally.add_data(
542
+ metric_id="entropy",
543
+ metrics=mb_entropy,
544
+ )
545
+ if self.pg_loss_normalization == "batch":
546
+ nb_act_tokens = action_mask_mb.sum()
547
+ mb_entropy = -mb_entropy.sum() / nb_act_tokens
548
+ else:
549
+ mb_entropy = -mb_entropy.sum()
550
+ running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean
551
+ if self.entropy_coeff != 0.0:
552
+ mb_entropy *= self.entropy_coeff
553
+ loss += mb_entropy
554
+
555
+ # -------------------------------------------------
556
+ # KL-DIVERGENCE
557
+ # -------------------------------------------------
558
+ if self.kl_coeff != 0.0:
559
+ ref_model_logits = self.policy.get_base_model_logits(contexts_mb)
560
+ ref_model_logits = ref_model_logits / self.temperature
561
+ # (B, S, V)
562
+ ref_model_logits = self.mask_non_restricted_token_logits(
563
+ logits=ref_model_logits
564
+ )
565
+ # (B, S, V)
566
+ ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1)
567
+ # (B, S, V)
568
+ ref_model_action_log_probs = ref_model_log_probs.gather(
569
+ dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
570
+ ).squeeze(
571
+ -1
572
+ ) # (B,S)
573
+ # Approximating KL Divergence (see refs in docstring)
574
+ # Ref 1: http://joschu.net/blog/kl-approx.html
575
+ # Ref 2: https://github.dev/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1332
576
+ masked_ref_model_action_log_probs = torch.masked_fill(
577
+ ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB
578
+ )
579
+ action_log_probs_diff = (
580
+ masked_ref_model_action_log_probs - masked_action_log_probs
581
+ ).clamp(-CLAMP_VALUE, CLAMP_VALUE)
582
+ running_mean_logs["ref_log_probs_diff_clampfrac"] += (
583
+ action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item()
584
+ / den_running_mean
585
+ )
586
+ if self.filter_higher_refprob_tokens_kl:
587
+ higher_refprob_tokens_mask = action_log_probs_diff > 0.0
588
+ running_mean_logs["higher_refprob_frac"] += (
589
+ higher_refprob_tokens_mask.sum().item() / den_running_mean
590
+ )
591
+ action_log_probs_diff = action_log_probs_diff * (
592
+ ~higher_refprob_tokens_mask
593
+ )
594
+ kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff
595
+ kl_div *= action_mask_mb # We only care about KLD of action tokens
596
+ if self.truncated_importance_sampling_ratio_cap > 0.0:
597
+ kl_div = kl_div * tis_imp_ratio
598
+ kl_div *= self.kl_coeff
599
+ if self.enable_tokenwise_logging:
600
+ self.tokenwise_tally.add_data(
601
+ metric_id="ref_model_next_token_log_prob",
602
+ metrics=ref_model_action_log_probs,
603
+ )
604
+ self.tokenwise_tally.add_data(
605
+ metric_id="kl_divergence",
606
+ metrics=kl_div,
607
+ )
608
+
609
+ if self.pg_loss_normalization == "batch":
610
+ nb_act_tokens = action_mask_mb.sum()
611
+ mb_kl = kl_div.sum() / nb_act_tokens
612
+ else:
613
+ mb_kl = kl_div.sum()
614
+ running_mean_logs["kl_divergence"] += (
615
+ mb_kl.item() / den_running_mean
616
+ )
617
+ loss += mb_kl
618
+
619
+ # Accumulate gradient
620
+ running_mean_logs["policy_gradient_loss"] += (
621
+ loss.item() / den_running_mean
622
+ )
623
+ loss /= normalization_factor
624
+ self.accelerator.backward(loss)
625
+
626
+ # ensure gpu memory is freed
627
+ del training_mb
628
+ del log_probs
629
+ del logits
630
+ del loss
631
+ del action_log_probs
632
+ del rewarded_action_log_probs
633
+
634
+ logger.info(
635
+ f"Accumulated the policy gradient loss for {total_tokens_generated} tokens."
636
+ )
637
+
638
+ # Clip gradients and take step
639
+ if self.gradient_clipping is not None:
640
+ grad_norm = self.accelerator.clip_grad_norm_(
641
+ self.policy.parameters(), self.gradient_clipping
642
+ )
643
+ running_mean_logs["policy_gradient_norm"] += grad_norm.item()
644
+
645
+ # Take step
646
+ self.policy_optimizer.step()
647
+ self.policy_optimizer.zero_grad()
648
+
649
+ # Store logs
650
+ for key, value in running_mean_logs.items():
651
+ self.tally.add_metric(path=key, metric=value)
652
+
653
+ # Clear
654
+ # TODO: verify
655
+ self.accelerator.clear(self.policy, self.policy_optimizer)
656
+ import gc
657
+
658
+ gc.collect()
659
+ torch.cuda.empty_cache()
660
+ return running_mean_logs
661
+
662
+ def get_advantages_with_critic_gradient_accumulation(
663
+ self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0
664
+ ) -> torch.FloatTensor:
665
+ """
666
+ TOWRITE
667
+ Uses GAE if enabled, otherwise uses Monte Carlo returns.
668
+ Optionally trains the critic if GAE is used.
669
+ Returns:
670
+ advantages: NestedFloatTensors
671
+ """
672
+
673
+ mb_size = self.mini_batch_size
674
+ batch_size = trajectories.rollout_ids.shape[0]
675
+ agent_id = trajectories.agent_ids[0]
676
+ batch_rewards = trajectories.batch_rewards
677
+
678
+ ######################################
679
+ # use critic for advantage estimation
680
+ ######################################
681
+ if self.use_gae:
682
+ if "buffer" in agent_id:
683
+ self.critic.eval()
684
+ training = False
685
+ else:
686
+ self.critic.train()
687
+ training = True
688
+ advantages = []
689
+ # critic_loss_scaling_factor comes learning single critic for two agents
690
+ normalization_factor = (
691
+ np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor
692
+ )
693
+ # For each minibatch
694
+ for mb in range(0, batch_size, mb_size):
695
+ trajectory_mb = trajectories[mb : mb + mb_size]
696
+ trajectory_mb.to(self.device)
697
+ rewards_mb = trajectory_mb.batch_rewards
698
+ (
699
+ tokens_mb,
700
+ state_ends_mask_mb,
701
+ timestep_counts,
702
+ ) = trajectory_mb.get_padded_tensors_for_critic()
703
+ # critic causal attention up to end flags
704
+ if training:
705
+ vals_estimate_full = self.critic(tokens_mb)
706
+ else:
707
+ with torch.no_grad():
708
+ vals_estimate_full = self.critic(tokens_mb)
709
+
710
+ # if vals_estimate_full.dim() == 3:
711
+ # vals_estimate_full = vals_estimate_full.squeeze(-1)
712
+
713
+ # Select only positions where states end, per sample → list of (jT,)
714
+ B = tokens_mb.shape[0]
715
+ vals_list = [
716
+ vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B)
717
+ ]
718
+
719
+ # Pad to (B, max_jT) = (B, S)
720
+ vals_estimate_mb = pad_sequence(
721
+ vals_list, batch_first=True, padding_value=0.0
722
+ )
723
+ dtype = vals_estimate_mb.dtype
724
+ rewards_mb = pad_sequence(
725
+ rewards_mb, batch_first=True, padding_value=0.0
726
+ ).to(
727
+ dtype=dtype
728
+ ) # (B, S)
729
+ self.rollout_tally.add_metric(
730
+ path=["batch_rewards"],
731
+ rollout_tally_item=RolloutTallyItem(
732
+ crn_ids=trajectory_mb.crn_ids,
733
+ rollout_ids=trajectory_mb.rollout_ids,
734
+ agent_ids=trajectory_mb.agent_ids,
735
+ metric_matrix=rewards_mb,
736
+ ),
737
+ )
738
+ if self.reward_normalizing_constant != 1.0:
739
+ rewards_mb /= self.reward_normalizing_constant
740
+
741
+ det_vals_estimate_mb = vals_estimate_mb.detach() # (B, max_jT)
742
+ self.rollout_tally.add_metric(
743
+ path=["mb_value_estimates_critic"],
744
+ rollout_tally_item=RolloutTallyItem(
745
+ crn_ids=trajectory_mb.crn_ids,
746
+ rollout_ids=trajectory_mb.rollout_ids,
747
+ agent_ids=trajectory_mb.agent_ids,
748
+ metric_matrix=det_vals_estimate_mb,
749
+ ),
750
+ )
751
+
752
+ # Append a 0 value to the end of the value estimates
753
+ if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]:
754
+ Bsize = det_vals_estimate_mb.shape[0]
755
+ device = det_vals_estimate_mb.device
756
+ dtype = det_vals_estimate_mb.dtype
757
+ det_vals_estimate_mb = torch.cat(
758
+ [
759
+ det_vals_estimate_mb,
760
+ torch.zeros((Bsize, 1), device=device, dtype=dtype),
761
+ ],
762
+ dim=1,
763
+ ) # (B, max_jT+1)
764
+ else:
765
+ raise ValueError(
766
+ "Incompatible shapes for value estimates and rewards."
767
+ )
768
+
769
+ # Get annealed lambda
770
+ if self.use_gae_lambda_annealing:
771
+ annealing_constant = self.gae_lambda_annealing_method(
772
+ step=self.trainer_annealing_state.annealing_step_counter
773
+ )
774
+ annealed_lambda = (
775
+ self.gae_lambda_annealing_limit * annealing_constant
776
+ )
777
+ self.tally.add_metric(
778
+ path="annealed_lambda", metric=annealed_lambda
779
+ )
780
+ else:
781
+ annealed_lambda = self.gae_lambda_annealing_limit
782
+
783
+ # Get GAE advantages
784
+ gae_advantages = get_generalized_advantage_estimates(
785
+ rewards=rewards_mb,
786
+ value_estimates=det_vals_estimate_mb,
787
+ discount_factor=self.discount_factor,
788
+ lambda_coef=annealed_lambda,
789
+ ) # (B, max_jT)
790
+ self.rollout_tally.add_metric(
791
+ path=["mb_gae_advantages"],
792
+ rollout_tally_item=RolloutTallyItem(
793
+ crn_ids=trajectory_mb.crn_ids,
794
+ rollout_ids=trajectory_mb.rollout_ids,
795
+ agent_ids=trajectory_mb.agent_ids,
796
+ metric_matrix=gae_advantages,
797
+ ),
798
+ )
799
+ if training:
800
+ targets = (
801
+ gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1]
802
+ ) # (B, max_jT) # A(s, a, b) + V(s) = Q(s, a, b)
803
+ self.rollout_tally.add_metric(
804
+ path=["mb_targets_critic"],
805
+ rollout_tally_item=RolloutTallyItem(
806
+ crn_ids=trajectory_mb.crn_ids,
807
+ rollout_ids=trajectory_mb.rollout_ids,
808
+ agent_ids=trajectory_mb.agent_ids,
809
+ metric_matrix=targets,
810
+ ),
811
+ )
812
+ if self.critic_loss_type == "mse":
813
+ loss = F.mse_loss(
814
+ input=vals_estimate_mb,
815
+ target=targets,
816
+ )
817
+ elif self.critic_loss_type == "huber":
818
+ loss = F.huber_loss(
819
+ input=vals_estimate_mb,
820
+ target=targets,
821
+ )
822
+ self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item())
823
+ # Accumulate gradient
824
+ loss /= normalization_factor
825
+ self.accelerator.backward(loss)
826
+ del loss
827
+ del targets
828
+ del vals_estimate_mb
829
+ del trajectory_mb
830
+ del vals_estimate_full
831
+
832
+ # Get jagged back using timestep_counts
833
+ advantages.extend(
834
+ [gae_advantages[i, : timestep_counts[i]] for i in range(B)]
835
+ )
836
+
837
+ ######################################
838
+ # use exclusively Monte Carlo returns & rloo for advantage estimation
839
+ ######################################
840
+ else:
841
+ lengths = [len(c) for c in batch_rewards]
842
+ padded_rewards = pad_sequence(
843
+ batch_rewards, batch_first=True, padding_value=0.0
844
+ )
845
+ self.rollout_tally.add_metric(
846
+ path=["mb_rewards"],
847
+ rollout_tally_item=RolloutTallyItem(
848
+ crn_ids=trajectories.crn_ids,
849
+ rollout_ids=trajectories.rollout_ids,
850
+ agent_ids=trajectories.agent_ids,
851
+ metric_matrix=padded_rewards,
852
+ ),
853
+ )
854
+ if self.reward_normalizing_constant != 1.0:
855
+ padded_rewards /= self.reward_normalizing_constant
856
+ padded_advantages = get_discounted_returns(
857
+ rewards=padded_rewards,
858
+ discount_factor=self.discount_factor,
859
+ ) # no baseline for now
860
+ if self.use_rloo:
861
+ is_grouped_by_rng = (
862
+ trajectories.crn_ids.unique().shape[0]
863
+ != trajectories.crn_ids.shape[0]
864
+ )
865
+ if is_grouped_by_rng:
866
+ for crn_id in trajectories.crn_ids.unique():
867
+ rng_mask = trajectories.crn_ids == crn_id
868
+ rng_advantages = padded_advantages[rng_mask]
869
+ rng_advantages, _ = get_rloo_credits(credits=rng_advantages)
870
+ padded_advantages[rng_mask] = rng_advantages
871
+ else:
872
+ padded_advantages, _ = get_rloo_credits(credits=padded_advantages)
873
+ self.rollout_tally.add_metric(
874
+ path=["mb_rloo_advantages"],
875
+ rollout_tally_item=RolloutTallyItem(
876
+ crn_ids=trajectories.crn_ids,
877
+ rollout_ids=trajectories.rollout_ids,
878
+ agent_ids=trajectories.agent_ids,
879
+ metric_matrix=padded_advantages,
880
+ ),
881
+ )
882
+ advantages = [
883
+ padded_advantages[i, : lengths[i]]
884
+ for i in range(padded_advantages.shape[0])
885
+ ]
886
+
887
+ if self.whiten_advantages_time_step_wise or self.whiten_advantages:
888
+ lengths = [len(c) for c in advantages]
889
+ padded_advantages = pad_sequence(
890
+ advantages, batch_first=True, padding_value=0.0
891
+ )
892
+ if self.whiten_advantages_time_step_wise:
893
+ whitened_padded_advantages = whiten_advantages_time_step_wise(
894
+ padded_advantages
895
+ )
896
+ path = ["mb_whitened_advantages_time_step_wise"]
897
+ elif self.whiten_advantages:
898
+ whitened_padded_advantages = whiten_advantages(padded_advantages)
899
+ path = ["mb_whitened_advantages"]
900
+ self.rollout_tally.add_metric(
901
+ path=path,
902
+ rollout_tally_item=RolloutTallyItem(
903
+ crn_ids=trajectories.crn_ids,
904
+ rollout_ids=trajectories.rollout_ids,
905
+ agent_ids=trajectories.agent_ids,
906
+ metric_matrix=whitened_padded_advantages,
907
+ ),
908
+ )
909
+ advantages = [
910
+ whitened_padded_advantages[i, : lengths[i]]
911
+ for i in range(whitened_padded_advantages.shape[0])
912
+ ]
913
+
914
+ self.trainer_annealing_state.annealing_step_counter += 1
915
+
916
+ return advantages
917
+
918
+ @abstractmethod
919
+ def set_agent_trajectory_data(
920
+ self, agent_id: str, roots: list[RolloutTreeRootNode]
921
+ ) -> None:
922
+ """
923
+ TOWRITE
924
+ """
925
+ pass
926
+
927
+ def set_trajectory_data(
928
+ self, roots: list[RolloutTreeRootNode], agent_ids: list[str]
929
+ ) -> None:
930
+ """
931
+ TOWRITE
932
+ """
933
+ for agent_id in agent_ids:
934
+ self.set_agent_trajectory_data(agent_id, roots)
935
+
936
+ @abstractmethod
937
+ def share_advantage_data(self) -> list[AdvantagePacket]:
938
+ pass
939
+
940
+ @abstractmethod
941
+ def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
942
+ pass
943
+
944
+ def set_policy_gradient_data(self, agent_ids: list[str]) -> None:
945
+ """
946
+ Already set earlier # TODO: make it separate and clean
947
+ """
948
+ self.policy_gradient_data = None
949
+ # for agent_id, trajectory_batch in self.training_data.items():
950
+ # if "buffer" in agent_id:
951
+ # continue
952
+ for agent_id in agent_ids:
953
+ assert "buffer" not in agent_id, "Buffer agents do not train policy"
954
+ trajectory_batch = self.training_data[agent_id]
955
+ tokenwise_batch_credits = get_tokenwise_credits(
956
+ batch_timesteps=trajectory_batch.batch_timesteps,
957
+ batch_credits=trajectory_batch.batch_credits,
958
+ )
959
+ policy_gradient_data = TrainingBatch(
960
+ rollout_ids=trajectory_batch.rollout_ids,
961
+ batch_input_ids=trajectory_batch.batch_input_ids,
962
+ batch_action_mask=trajectory_batch.batch_action_mask,
963
+ batch_entropy_mask=trajectory_batch.batch_entropy_mask,
964
+ batch_credits=tokenwise_batch_credits,
965
+ batch_engine_log_probs=trajectory_batch.batch_engine_log_probs,
966
+ batch_timesteps=trajectory_batch.batch_timesteps,
967
+ )
968
+ if self.policy_gradient_data is None:
969
+ self.policy_gradient_data = policy_gradient_data
970
+ else:
971
+ self.policy_gradient_data.append(policy_gradient_data)
972
+
973
+ self.training_data = {}
974
+ self.tokenwise_tally = ContextualizedTokenwiseTally(
975
+ tokenizer=self.tokenizer,
976
+ paths=self.debug_path_list,
977
+ )
978
+
979
+ def train(self) -> None:
980
+ """
981
+ TOWRITE
982
+ """
983
+ assert self.policy_gradient_data is not None, "Policy gradient data is not set"
984
+ if self.critic_optimizer is not None:
985
+ if self.gradient_clipping is not None:
986
+ grad_norm = self.accelerator.clip_grad_norm_(
987
+ self.critic.parameters(), self.gradient_clipping
988
+ )
989
+ self.tally.add_metric(
990
+ path="gradient_norm_critic", metric=grad_norm.item()
991
+ )
992
+ # Take step
993
+ self.critic_optimizer.step()
994
+ self.critic_optimizer.zero_grad()
995
+ self.accelerator.clear(self.critic, self.critic_optimizer)
996
+ import gc
997
+
998
+ gc.collect()
999
+ torch.cuda.empty_cache()
1000
+ running_mean_logs = self.apply_reinforce_step(
1001
+ training_batch=self.policy_gradient_data
1002
+ )
1003
+ return running_mean_logs
1004
+
1005
+ def export_training_tally(self, identifier: str, folder: str) -> None:
1006
+ """
1007
+ Saves and resets the collected training metrics using the tally object.
1008
+ """
1009
+ os.makedirs(folder, exist_ok=True)
1010
+ self.tally.save(identifier=identifier, folder=folder)
1011
+ self.tokenwise_tally.save(
1012
+ path=os.path.join(folder, f"{identifier}_tokenwise.csv")
1013
+ )
1014
+ self.rollout_tally.save(identifier=identifier, folder=folder)
1015
+ self.tally.reset()
1016
+ self.tokenwise_tally = None
1017
+ self.rollout_tally.reset()
1018
+ self.debug_path_list = []
1019
+
1020
+ def export_optimizer_states(self) -> None:
1021
+ """
1022
+ Saves the optimizer states for both the main model and critic (if it exists).
1023
+ """
1024
+ try:
1025
+ os.makedirs(self.save_path, exist_ok=True)
1026
+
1027
+ torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path)
1028
+ logger.info(f"Saved main optimizer state to {self.policy_optimizer_path}")
1029
+
1030
+ if self.critic_optimizer is not None:
1031
+ torch.save(
1032
+ self.critic_optimizer.state_dict(), self.critic_optimizer_path
1033
+ )
1034
+ logger.info(
1035
+ f"Saved critic optimizer state to {self.critic_optimizer_path}"
1036
+ )
1037
+ except Exception as e:
1038
+ logger.error(f"Error saving optimizer states: {str(e)}")
1039
+ raise
1040
+
1041
+ def export_trainer_annealing_state(self) -> None:
1042
+ """
1043
+ Saves the trainer state.
1044
+ """
1045
+ with open(self.trainer_annealing_state_path, "wb") as f:
1046
+ pickle.dump(self.trainer_annealing_state, f)
1047
+ logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}")
1048
+
1049
+ def export_trainer_states(self) -> None:
1050
+ """
1051
+ Saves the trainer states.
1052
+ """
1053
+ self.export_optimizer_states()
1054
+ self.export_trainer_annealing_state()
src_code_for_reproducibility/training/trainer_independent.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ """
4
+ import logging
5
+ import os
6
+ import sys
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from accelerate import Accelerator
12
+ from pandas._libs.tslibs.offsets import CBMonthBegin
13
+ from peft import LoraConfig
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+
17
+ from mllm.markov_games.rollout_tree import *
18
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode
19
+ from mllm.training.credit_methods import (
20
+ get_discounted_returns,
21
+ get_discounted_state_visitation_credits,
22
+ get_generalized_advantage_estimates,
23
+ get_rloo_credits,
24
+ )
25
+ from mllm.training.tally_metrics import Tally
26
+ from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
27
+ from mllm.training.tokenize_chats import *
28
+ from mllm.training.tokenize_chats import process_training_chat
29
+ from mllm.training.trainer_common import BaseTrainer
30
+ from mllm.training.training_data_utils import *
31
+ from mllm.training.training_data_utils import (
32
+ TrainingBatch,
33
+ TrajectoryBatch,
34
+ get_tokenwise_credits,
35
+ )
36
+ from mllm.utils.resource_context import resource_logger_context
37
+
38
+ logger = logging.getLogger(__name__)
39
+ logger.addHandler(logging.StreamHandler(sys.stdout))
40
+
41
+
42
+ @dataclass
43
+ class TrainingData:
44
+ agent_id: str
45
+ main_data: TrajectoryBatch
46
+ # list-of-tensors: per rollout advantages with length jT
47
+ main_advantages: list[torch.FloatTensor] | None = None
48
+
49
+
50
+ class TrainerNaive(BaseTrainer):
51
+ def set_agent_trajectory_data(
52
+ self, agent_id: str, roots: list[RolloutTreeRootNode]
53
+ ) -> None:
54
+ """
55
+ TOWRITE
56
+ """
57
+ # TODO: append to current batch data instead, else we will only train for one agent!
58
+ self.policy_gradient_data = None
59
+
60
+ # Tensorize Chats
61
+ rollout_ids = []
62
+ crn_ids = [] # common random number id
63
+ batch_input_ids = []
64
+ batch_action_mask = []
65
+ batch_entropy_mask = []
66
+ batch_timesteps = []
67
+ batch_state_ends_mask = []
68
+ batch_engine_log_probs = []
69
+ batch_rewards = []
70
+ for root in roots:
71
+ rollout_id = root.id
72
+ self.debug_path_list.append(
73
+ "mgid:" + str(rollout_id) + "_agent_id:" + agent_id
74
+ )
75
+ rollout_ids.append(rollout_id)
76
+ crn_ids.append(root.crn_id)
77
+ chat, rewards = get_main_chat_list_and_rewards(agent_id=agent_id, root=root)
78
+ (
79
+ input_ids,
80
+ action_mask,
81
+ entropy_mask,
82
+ timesteps,
83
+ state_ends_mask,
84
+ engine_log_probs,
85
+ ) = process_training_chat(
86
+ tokenizer=self.tokenizer,
87
+ chat_history=chat,
88
+ entropy_mask_regex=self.entropy_mask_regex,
89
+ exploration_prompts_to_remove=self.exploration_prompts_to_remove,
90
+ )
91
+ batch_input_ids.append(input_ids)
92
+ batch_action_mask.append(action_mask)
93
+ batch_entropy_mask.append(entropy_mask)
94
+ batch_timesteps.append(timesteps)
95
+ batch_state_ends_mask.append(state_ends_mask)
96
+ batch_engine_log_probs.append(engine_log_probs)
97
+ batch_rewards.append(rewards)
98
+
99
+ trajectory_batch = TrajectoryBatch(
100
+ rollout_ids=torch.tensor(rollout_ids, dtype=torch.int32),
101
+ crn_ids=torch.tensor(crn_ids, dtype=torch.int32),
102
+ agent_ids=[agent_id] * len(rollout_ids),
103
+ batch_input_ids=batch_input_ids,
104
+ batch_action_mask=batch_action_mask,
105
+ batch_entropy_mask=batch_entropy_mask,
106
+ batch_timesteps=batch_timesteps,
107
+ batch_state_ends_mask=batch_state_ends_mask,
108
+ batch_rewards=batch_rewards,
109
+ batch_engine_log_probs=batch_engine_log_probs,
110
+ )
111
+
112
+ # Get Advantages
113
+ batch_advantages: torch.FloatTensor = (
114
+ self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
115
+ )
116
+
117
+ # Discount state visitation (the mathematically correct way)
118
+ if not self.skip_discounted_state_visitation:
119
+ for i in range(len(batch_advantages)):
120
+ batch_advantages[i] = get_discounted_state_visitation_credits(
121
+ batch_advantages[i].unsqueeze(0),
122
+ self.discount_factor,
123
+ ).squeeze(0)
124
+
125
+ self.training_data[agent_id] = TrainingData(
126
+ agent_id=agent_id,
127
+ main_data=trajectory_batch,
128
+ main_advantages=batch_advantages,
129
+ )
130
+
131
+ def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
132
+ """
133
+ This trainer ignores the advantages of the other trainers.
134
+ """
135
+ for agent_id, agent_data in self.training_data.items():
136
+ self.training_data[agent_id] = agent_data.main_data
137
+ self.training_data[agent_id].batch_credits = agent_data.main_advantages
138
+
139
+ def share_advantage_data(self) -> list[AdvantagePacket]:
140
+ """
141
+ Share the advantage data with other agents.
142
+ Returns:
143
+ AdvantagePacket: The advantage packet containing the agent's advantages.
144
+ """
145
+ logger.info(f"Sharing advantage data.")
146
+ advantage_packets = []
147
+ for agent_id, agent_data in self.training_data.items():
148
+ advantage_packets.append(
149
+ AdvantagePacket(
150
+ agent_id=agent_id,
151
+ rollout_ids=agent_data.main_data.rollout_ids,
152
+ main_advantages=agent_data.main_advantages,
153
+ )
154
+ )
155
+ return advantage_packets
src_code_for_reproducibility/training/trainer_sum_rewards.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ """
4
+ import logging
5
+ import os
6
+ import sys
7
+ from typing import Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from accelerate import Accelerator
12
+ from pandas._libs.tslibs.offsets import CBMonthBegin
13
+ from peft import LoraConfig
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+
17
+ from mllm.markov_games.rollout_tree import *
18
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode
19
+ from mllm.training.credit_methods import (
20
+ get_discounted_returns,
21
+ get_discounted_state_visitation_credits,
22
+ get_generalized_advantage_estimates,
23
+ get_rloo_credits,
24
+ )
25
+ from mllm.training.tally_metrics import Tally
26
+ from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem
27
+ from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally
28
+ from mllm.training.tokenize_chats import *
29
+ from mllm.training.tokenize_chats import process_training_chat
30
+ from mllm.training.trainer_common import BaseTrainer
31
+ from mllm.training.trainer_independent import TrainerNaive, TrainingData
32
+ from mllm.training.training_data_utils import *
33
+ from mllm.training.training_data_utils import (
34
+ AdvantagePacket,
35
+ TrainingBatch,
36
+ TrajectoryBatch,
37
+ get_tokenwise_credits,
38
+ )
39
+ from mllm.utils.resource_context import resource_logger_context
40
+
41
+ logger = logging.getLogger(__name__)
42
+ logger.addHandler(logging.StreamHandler(sys.stdout))
43
+
44
+
45
+ class TrainerSumRewards(TrainerNaive):
46
+ def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]):
47
+ """
48
+ Sums the advantages of the other trainers
49
+ """
50
+ logger.info(f"Receiving advantage packets.")
51
+
52
+ assert (
53
+ len(advantage_packets) > 0
54
+ ), "At least one advantage packet must be provided."
55
+
56
+ for agent_id, agent_data in self.training_data.items():
57
+ coagent_advantage_packets = [
58
+ packet for packet in advantage_packets if packet.agent_id != agent_id
59
+ ]
60
+ agent_rollout_ids = agent_data.main_data.rollout_ids
61
+ agent_advantages = agent_data.main_advantages
62
+ co_agent_advantages = []
63
+ for rollout_id in agent_rollout_ids:
64
+ for co_agent_packet in coagent_advantage_packets:
65
+ if rollout_id in co_agent_packet.rollout_ids:
66
+ index = torch.where(rollout_id == co_agent_packet.rollout_ids)[
67
+ 0
68
+ ].item()
69
+ co_agent_advantages.append(
70
+ co_agent_packet.main_advantages[index]
71
+ )
72
+ # assumes that its two player game, with one co-agent
73
+ break
74
+ assert len(co_agent_advantages) == len(agent_advantages)
75
+ B = len(agent_advantages)
76
+ assert all(
77
+ a.shape[0] == b.shape[0]
78
+ for a, b in zip(co_agent_advantages, agent_advantages)
79
+ ), "Number of advantages must match in order to sum them up."
80
+
81
+ # Get padded tensors (advantage alignment is invariant to padding)
82
+ lengths = torch.tensor(
83
+ [len(t) for t in agent_advantages],
84
+ device=self.device,
85
+ dtype=torch.long,
86
+ )
87
+ padded_main_advantages = pad_sequence(
88
+ agent_advantages, batch_first=True, padding_value=0.0
89
+ )
90
+
91
+ padded_co_agent_advantages = pad_sequence(
92
+ co_agent_advantages, batch_first=True, padding_value=0.0
93
+ )
94
+
95
+ # Create training batch data
96
+ sum_of_ad_credits = padded_main_advantages + padded_co_agent_advantages
97
+ self.rollout_tally.add_metric(
98
+ path=["sum_of_ad_credits"],
99
+ rollout_tally_item=RolloutTallyItem(
100
+ crn_ids=agent_data.main_data.crn_ids,
101
+ rollout_ids=agent_data.main_data.rollout_ids,
102
+ agent_ids=agent_data.main_data.agent_ids,
103
+ metric_matrix=sum_of_ad_credits,
104
+ ),
105
+ )
106
+
107
+ if not self.skip_discounted_state_visitation:
108
+ sum_of_ad_credits = get_discounted_state_visitation_credits(
109
+ sum_of_ad_credits,
110
+ self.discount_factor,
111
+ )
112
+ self.rollout_tally.add_metric(
113
+ path=["discounted_state_visitation_credits"],
114
+ rollout_tally_item=RolloutTallyItem(
115
+ crn_ids=agent_data.main_data.crn_ids,
116
+ rollout_ids=agent_data.main_data.rollout_ids,
117
+ agent_ids=agent_data.main_data.agent_ids,
118
+ metric_matrix=sub_tensors[
119
+ "discounted_state_visitation_credits"
120
+ ],
121
+ ),
122
+ )
123
+
124
+ # Slice back to jagged and convert to tokenwise credits
125
+ sum_of_ad_credits = [sum_of_ad_credits[i, : lengths[i]] for i in range(B)]
126
+ self.training_data[agent_id] = agent_data.main_data
127
+ self.training_data[agent_id].batch_credits = sum_of_ad_credits
src_code_for_reproducibility/training/training_data_utils.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal, Optional, Tuple
3
+
4
+ import torch
5
+ from torch.nn.utils.rnn import pad_sequence
6
+
7
+ from mllm.markov_games.rollout_tree import (
8
+ ChatTurn,
9
+ RolloutTreeBranchNode,
10
+ RolloutTreeNode,
11
+ RolloutTreeRootNode,
12
+ )
13
+
14
+
15
+ @dataclass
16
+ class AdvantagePacket:
17
+ agent_id: str
18
+ rollout_ids: torch.IntTensor # (B,)
19
+ # list-of-tensors
20
+ main_advantages: list[torch.FloatTensor]
21
+
22
+
23
+ class TrainingChatTurn:
24
+ # TODO: simplify by making this a child of ChatTurn
25
+ """
26
+ This class contains the chat turns for a single agent.
27
+ It is like ChatTurn, but with the time step added.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ time_step: int,
33
+ role: str,
34
+ agent_id: str,
35
+ content: str,
36
+ chat_template_token_ids: list[int],
37
+ reasoning_content: str,
38
+ is_state_end: bool,
39
+ out_token_ids: Optional[list[int]] = None,
40
+ log_probs: Optional[list[float]] = None,
41
+ ) -> None:
42
+ self.time_step = time_step
43
+ self.role = role
44
+ self.agent_id = agent_id
45
+ self.content = content
46
+ self.chat_template_token_ids = chat_template_token_ids
47
+ self.reasoning_content = reasoning_content
48
+ self.is_state_end = is_state_end
49
+ self.out_token_ids = out_token_ids
50
+ self.log_probs = log_probs
51
+
52
+ def dict(self):
53
+ return {
54
+ "time_step": self.time_step,
55
+ "role": self.role,
56
+ "agent_id": self.agent_id,
57
+ "content": self.content,
58
+ "chat_template_token_ids": self.chat_template_token_ids,
59
+ "reasoning_content": self.reasoning_content,
60
+ "is_state_end": self.is_state_end,
61
+ "out_token_ids": self.out_token_ids,
62
+ "log_probs": self.log_probs,
63
+ }
64
+
65
+
66
+ def get_main_chat_list_and_rewards(
67
+ agent_id: str, root: RolloutTreeRootNode | RolloutTreeNode
68
+ ) -> Tuple[list[TrainingChatTurn], torch.FloatTensor]:
69
+ """
70
+ This method traverses a rollout tree and returns a the list of ChatTurn
71
+ for an agent. If it encounters a branch node, it follows the main path.
72
+ """
73
+ # TODO; extend for all trees, not just linear
74
+ if isinstance(root, RolloutTreeRootNode):
75
+ current_node = root.child
76
+ else:
77
+ current_node = root
78
+
79
+ chat = []
80
+ rewards = []
81
+ while current_node is not None:
82
+ if isinstance(current_node, RolloutTreeBranchNode):
83
+ current_node = current_node.main_child
84
+ reward: float = current_node.step_log.simulation_step_log.rewards[agent_id]
85
+ rewards.append(reward)
86
+ chat_turns: list[TrainingChatTurn] = current_node.step_log.action_logs[
87
+ agent_id
88
+ ].chat_turns
89
+ chat_turns = [
90
+ TrainingChatTurn(time_step=current_node.time_step, **turn.model_dump())
91
+ for turn in chat_turns
92
+ ]
93
+ chat.extend(chat_turns)
94
+ current_node = current_node.child
95
+ return chat, torch.FloatTensor(rewards)
96
+
97
+
98
+ def get_tokenwise_credits(
99
+ # B := batch size, S := number of tokens / seq. length, T := number of states. `j` stands for jagged (see pytorch nested tensors.)
100
+ batch_timesteps: torch.IntTensor | torch.Tensor, # (B, jS),
101
+ batch_credits: torch.FloatTensor | torch.Tensor, # (B, jT)
102
+ ) -> torch.FloatTensor | torch.Tensor: # (B, jS)
103
+ """
104
+ TOWRITE
105
+ """
106
+ # TODO vectorize this code
107
+ batch_token_credits = []
108
+ for credits, timesteps in zip(batch_credits, batch_timesteps):
109
+ token_credits = torch.zeros_like(
110
+ timesteps,
111
+ dtype=credits.dtype,
112
+ device=timesteps.device,
113
+ )
114
+ for idx, credit in enumerate(credits):
115
+ token_credits[timesteps == idx] = credit
116
+ batch_token_credits.append(token_credits)
117
+ return batch_token_credits
118
+
119
+
120
+ @dataclass
121
+ class TrajectoryBatch:
122
+ """
123
+ Tensorized batch of trajectories using list-of-tensors for jagged dimensions.
124
+ """
125
+
126
+ # B := batch size, S := number of tokens / seq. length, T := number of states.
127
+ rollout_ids: torch.IntTensor # (B,)
128
+ crn_ids: torch.IntTensor # (B,)
129
+ agent_ids: list[str] # (B,)
130
+ batch_input_ids: list[torch.LongTensor] # List[(jS,)]
131
+ batch_action_mask: list[torch.BoolTensor] # List[(jS,)]
132
+ batch_entropy_mask: list[torch.BoolTensor] # List[(jS,)]
133
+ batch_timesteps: list[torch.IntTensor] # List[(jS,)]
134
+ batch_state_ends_mask: list[torch.BoolTensor] # List[(jS,)]
135
+ batch_engine_log_probs: Optional[list[torch.FloatTensor]] # List[(jS,)]
136
+ batch_rewards: list[torch.FloatTensor] # List[(jT,)]
137
+ batch_credits: Optional[list[torch.FloatTensor]] = None # List[(jS,)]
138
+
139
+ def __post_init__(self):
140
+ """
141
+ Validate per-sample consistency.
142
+ """
143
+ B = self.rollout_ids.shape[0]
144
+ assert (
145
+ self.crn_ids.shape[0] == B
146
+ ), "RNG IDs must have length equal to batch size."
147
+ assert (
148
+ len(self.agent_ids) == B
149
+ ), "agent_ids must have length equal to batch size."
150
+ assert (
151
+ len(self.batch_input_ids)
152
+ == len(self.batch_action_mask)
153
+ == len(self.batch_entropy_mask)
154
+ == len(self.batch_timesteps)
155
+ == len(self.batch_state_ends_mask)
156
+ == len(self.batch_engine_log_probs)
157
+ == len(self.batch_rewards)
158
+ == B
159
+ ), "Jagged lists must all have length equal to batch size."
160
+
161
+ for b in range(B):
162
+ nb_rewards = int(self.batch_rewards[b].shape[0])
163
+ nb_timesteps = int(torch.max(self.batch_timesteps[b]).item()) + 1
164
+ assert (
165
+ nb_rewards == nb_timesteps
166
+ ), "Number of rewards and timesteps mismatch."
167
+ assert (
168
+ self.batch_input_ids[b].shape[0]
169
+ == self.batch_action_mask[b].shape[0]
170
+ == self.batch_entropy_mask[b].shape[0]
171
+ == self.batch_engine_log_probs[b].shape[0]
172
+ == self.batch_timesteps[b].shape[0]
173
+ ), "Tensors must have the same shape along the jagged dimension."
174
+ assert (
175
+ int(self.batch_state_ends_mask[b].sum())
176
+ == self.batch_rewards[b].shape[0]
177
+ ), "Number of rewards must match number of state ends."
178
+
179
+ """
180
+ Entries:
181
+ Here, we ignore the batch dimension.
182
+ input_ids:
183
+ All of the tokens of both the user and the assistant, flattened.
184
+ action_mask:
185
+ Set to true on the tokens of the assistant (tokens generated by the model).
186
+ timesteps:
187
+ Therefore, max(timesteps) = Ns - 1.
188
+ state_ends_idx:
189
+ Indices of the tokens at which state descriptions end.
190
+ rewards:
191
+ rewards[t] := R_t(s_t, a_t)
192
+ Example:
193
+ position: "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14"
194
+ input_ids: "U U U a a a U a U a a a U U U" (U := User, a := Assistant)
195
+ action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x"
196
+ timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2"
197
+ state_ends_dx: [2, 6, 14]
198
+ rewards: [r0, r1, r2]
199
+ """
200
+
201
+ def __getitem__(self, key) -> "TrajectoryBatch":
202
+ if isinstance(key, slice):
203
+ return TrajectoryBatch(
204
+ rollout_ids=self.rollout_ids.__getitem__(key),
205
+ crn_ids=self.crn_ids.__getitem__(key),
206
+ agent_ids=self.agent_ids[key],
207
+ batch_input_ids=self.batch_input_ids[key],
208
+ batch_action_mask=self.batch_action_mask[key],
209
+ batch_entropy_mask=self.batch_entropy_mask[key],
210
+ batch_timesteps=self.batch_timesteps[key],
211
+ batch_state_ends_mask=self.batch_state_ends_mask[key],
212
+ batch_engine_log_probs=self.batch_engine_log_probs[key],
213
+ batch_rewards=self.batch_rewards[key],
214
+ batch_credits=self.batch_credits[key] if self.batch_credits else None,
215
+ )
216
+
217
+ def __len__(self):
218
+ return len(self.batch_input_ids)
219
+
220
+ def to(self, device):
221
+ self.rollout_ids = self.rollout_ids.to(device)
222
+ self.crn_ids = self.crn_ids.to(device)
223
+ self.batch_input_ids = [t.to(device) for t in self.batch_input_ids]
224
+ self.batch_action_mask = [t.to(device) for t in self.batch_action_mask]
225
+ self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask]
226
+ self.batch_timesteps = [t.to(device) for t in self.batch_timesteps]
227
+ self.batch_state_ends_mask = [t.to(device) for t in self.batch_state_ends_mask]
228
+ self.batch_engine_log_probs = [
229
+ t.to(device) for t in self.batch_engine_log_probs
230
+ ]
231
+ self.batch_rewards = [t.to(device) for t in self.batch_rewards]
232
+ self.batch_credits = (
233
+ [t.to(device) for t in self.batch_credits] if self.batch_credits else None
234
+ )
235
+
236
+ def get_padded_tensors_for_critic(self):
237
+ """
238
+ Returns:
239
+ padded_batch_input_ids: (B, P)
240
+ padded_batch_state_ends_mask: (B, P)
241
+ timestep_counts: (B,) tensor of ints indicating number of states per sample
242
+ """
243
+ padded_batch_input_ids = pad_sequence(
244
+ self.batch_input_ids, batch_first=True, padding_value=0
245
+ )
246
+ padded_batch_state_ends_mask = pad_sequence(
247
+ self.batch_state_ends_mask, batch_first=True, padding_value=0
248
+ ).bool()
249
+ # number of states equals number of True in state_ends_mask
250
+ timestep_counts = torch.tensor(
251
+ [int(mask.sum().item()) for mask in self.batch_state_ends_mask],
252
+ device=padded_batch_input_ids.device,
253
+ dtype=torch.long,
254
+ )
255
+ return padded_batch_input_ids, padded_batch_state_ends_mask, timestep_counts
256
+
257
+
258
+ timestep = int
259
+
260
+
261
+ @dataclass
262
+ class PaddedTensorTrainingBatch:
263
+ batch_input_ids: torch.LongTensor | torch.Tensor
264
+ batch_action_mask: torch.BoolTensor | torch.Tensor
265
+ batch_entropy_mask: Optional[torch.BoolTensor | torch.Tensor]
266
+ batch_credits: torch.FloatTensor | torch.Tensor
267
+ batch_engine_log_probs: torch.FloatTensor | torch.Tensor
268
+ batch_timesteps: torch.IntTensor | torch.Tensor
269
+
270
+ def __len__(self):
271
+ return self.batch_input_ids.shape[0]
272
+
273
+ def to(self, device):
274
+ self.batch_input_ids = self.batch_input_ids.to(device)
275
+ self.batch_action_mask = self.batch_action_mask.to(device)
276
+ self.batch_entropy_mask = self.batch_entropy_mask.to(device)
277
+ self.batch_credits = self.batch_credits.to(device)
278
+ self.batch_engine_log_probs = self.batch_engine_log_probs.to(device)
279
+ self.batch_timesteps = self.batch_timesteps.to(device)
280
+
281
+
282
+ @dataclass
283
+ class TrainingBatch:
284
+ rollout_ids: torch.IntTensor | torch.Tensor # (B,)
285
+ batch_input_ids: list[torch.LongTensor] # List[(jS,)]
286
+ batch_action_mask: list[torch.BoolTensor] # List[(jS,)]
287
+ batch_entropy_mask: Optional[list[torch.BoolTensor]] # List[(jS,)]
288
+ batch_credits: list[torch.FloatTensor] # List[(jS,)]
289
+ batch_engine_log_probs: list[torch.FloatTensor] # List[(jS,)]
290
+ batch_timesteps: list[torch.IntTensor] # List[(jS,)]
291
+
292
+ def __post_init__(self):
293
+ # Put everything in the right device
294
+ # self.rollout_ids = self.rollout_ids.to("cuda" if torch.cuda.is_available() else "cpu")
295
+ # self.batch_input_ids = self.batch_input_ids.to("cuda" if torch.cuda.is_available() else "cpu")
296
+ # self.batch_action_mask = self.batch_action_mask.to("cuda" if torch.cuda.is_available() else "cpu")
297
+ # self.batch_credits = self.batch_credits.to("cuda" if torch.cuda.is_available() else "cpu")
298
+ # Ensure batch dimension is present
299
+ assert (
300
+ len(self.batch_input_ids)
301
+ == len(self.batch_action_mask)
302
+ == len(self.batch_entropy_mask)
303
+ == len(self.batch_credits)
304
+ == len(self.batch_engine_log_probs)
305
+ == len(self.batch_timesteps)
306
+ == self.rollout_ids.shape[0]
307
+ ), "Jagged lists must all have length equal to batch size."
308
+ for inp, mask, cred, engine_log_prob, timestep in zip(
309
+ self.batch_input_ids,
310
+ self.batch_action_mask,
311
+ self.batch_credits,
312
+ self.batch_engine_log_probs,
313
+ self.batch_timesteps,
314
+ ):
315
+ assert (
316
+ inp.shape[0]
317
+ == mask.shape[0]
318
+ == cred.shape[0]
319
+ == engine_log_prob.shape[0]
320
+ == timestep.shape[0]
321
+ ), "Tensors must have the same shapes along the jagged dimension."
322
+
323
+ def __getitem__(self, key) -> "TrainingBatch":
324
+ if isinstance(key, slice):
325
+ return TrainingBatch(
326
+ rollout_ids=self.rollout_ids.__getitem__(key),
327
+ batch_input_ids=self.batch_input_ids[key],
328
+ batch_action_mask=self.batch_action_mask[key],
329
+ batch_entropy_mask=self.batch_entropy_mask[key],
330
+ batch_credits=self.batch_credits[key],
331
+ batch_engine_log_probs=self.batch_engine_log_probs[key],
332
+ batch_timesteps=self.batch_timesteps[key],
333
+ )
334
+
335
+ def __len__(self):
336
+ return len(self.batch_input_ids)
337
+
338
+ def to(self, device):
339
+ self.rollout_ids = self.rollout_ids.to(device)
340
+ self.batch_input_ids = [t.to(device) for t in self.batch_input_ids]
341
+ self.batch_action_mask = [t.to(device) for t in self.batch_action_mask]
342
+ self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask]
343
+ self.batch_credits = [t.to(device) for t in self.batch_credits]
344
+ self.batch_engine_log_probs = [
345
+ t.to(device) for t in self.batch_engine_log_probs
346
+ ]
347
+ self.batch_timesteps = [t.to(device) for t in self.batch_timesteps]
348
+
349
+ def get_padded_tensors(self, padding: float = 0.0):
350
+ """
351
+ TOWRITE
352
+ Always pad to the right.
353
+ """
354
+ padded_batch_input_ids = pad_sequence(
355
+ self.batch_input_ids, batch_first=True, padding_value=int(padding)
356
+ )
357
+ padded_batch_action_mask = pad_sequence(
358
+ [m.to(dtype=torch.bool) for m in self.batch_action_mask],
359
+ batch_first=True,
360
+ padding_value=False,
361
+ )
362
+ padded_batch_entropy_mask = pad_sequence(
363
+ self.batch_entropy_mask, batch_first=True, padding_value=False
364
+ )
365
+ padded_batch_credits = pad_sequence(
366
+ self.batch_credits, batch_first=True, padding_value=float(padding)
367
+ )
368
+ padded_batch_engine_log_probs = pad_sequence(
369
+ self.batch_engine_log_probs, batch_first=True, padding_value=float(padding)
370
+ )
371
+ padded_batch_timesteps = pad_sequence(
372
+ self.batch_timesteps, batch_first=True, padding_value=0
373
+ )
374
+
375
+ return PaddedTensorTrainingBatch(
376
+ padded_batch_input_ids,
377
+ padded_batch_action_mask,
378
+ padded_batch_entropy_mask,
379
+ padded_batch_credits,
380
+ padded_batch_engine_log_probs,
381
+ padded_batch_timesteps,
382
+ )
383
+
384
+ def append(self, other: "TrainingBatch"):
385
+ self.rollout_ids = torch.cat([self.rollout_ids, other.rollout_ids])
386
+ self.batch_input_ids.extend(other.batch_input_ids)
387
+ self.batch_action_mask.extend(other.batch_action_mask)
388
+ self.batch_entropy_mask.extend(other.batch_entropy_mask)
389
+ self.batch_credits.extend(other.batch_credits)
390
+ self.batch_engine_log_probs.extend(other.batch_engine_log_probs)
391
+ self.batch_timesteps.extend(other.batch_timesteps)
392
+
393
+
394
+ timestep = int