Muqeeth commited on
Commit
a8721a2
·
verified ·
1 Parent(s): b613242

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .hydra/config.yaml +183 -0
  2. .hydra/hydra.yaml +154 -0
  3. .hydra/overrides.yaml +1 -0
  4. run.log +0 -0
  5. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
  6. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
  7. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
  8. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json +42 -0
  9. src_code_for_reproducibility/__init__.py +0 -0
  10. src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
  11. src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc +0 -0
  12. src_code_for_reproducibility/chat_utils/chat_turn.py +27 -0
  13. src_code_for_reproducibility/chat_utils/template_specific.py +109 -0
  14. src_code_for_reproducibility/docs/Makefile +19 -0
  15. src_code_for_reproducibility/docs/generate_docs.py +249 -0
  16. src_code_for_reproducibility/docs/make.bat +35 -0
  17. src_code_for_reproducibility/docs/source/environments/diplomacy.rst +459 -0
  18. src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst +7 -0
  19. src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst +7 -0
  20. src_code_for_reproducibility/markov_games/__init__.py +0 -0
  21. src_code_for_reproducibility/markov_games/agent.py +76 -0
  22. src_code_for_reproducibility/markov_games/alternative_actions_runner.py +138 -0
  23. src_code_for_reproducibility/markov_games/group_timesteps.py +150 -0
  24. src_code_for_reproducibility/markov_games/linear_runner.py +30 -0
  25. src_code_for_reproducibility/markov_games/markov_game.py +208 -0
  26. src_code_for_reproducibility/markov_games/mg_utils.py +89 -0
  27. src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc +0 -0
  28. src_code_for_reproducibility/markov_games/rollout_tree.py +86 -0
  29. src_code_for_reproducibility/markov_games/run_markov_games.py +24 -0
  30. src_code_for_reproducibility/markov_games/simulation.py +87 -0
  31. src_code_for_reproducibility/markov_games/statistics_runner.py +405 -0
  32. src_code_for_reproducibility/markov_games/vine_ppo.py +10 -0
  33. src_code_for_reproducibility/models/__init__.py +0 -0
  34. src_code_for_reproducibility/models/adapter_training_wrapper.py +98 -0
  35. src_code_for_reproducibility/models/human_policy.py +255 -0
  36. src_code_for_reproducibility/models/inference_backend.py +39 -0
  37. src_code_for_reproducibility/models/inference_backend_dummy.py +54 -0
  38. src_code_for_reproducibility/models/inference_backend_sglang.py +86 -0
  39. src_code_for_reproducibility/models/inference_backend_sglang_local_server.py +127 -0
  40. src_code_for_reproducibility/models/inference_backend_vllm.py +118 -0
  41. src_code_for_reproducibility/models/inference_backend_vllm_local_server.py +160 -0
  42. src_code_for_reproducibility/models/large_language_model_api.py +171 -0
  43. src_code_for_reproducibility/models/large_language_model_local.py +384 -0
  44. src_code_for_reproducibility/models/scalar_critic.py +54 -0
  45. src_code_for_reproducibility/training/README.md +20 -0
  46. src_code_for_reproducibility/training/__init__.py +0 -0
  47. src_code_for_reproducibility/training/annealing_methods.py +6 -0
  48. src_code_for_reproducibility/training/credit_methods.py +304 -0
  49. src_code_for_reproducibility/training/tally_metrics.py +55 -0
  50. src_code_for_reproducibility/training/tally_rollout.py +137 -0
.hydra/config.yaml ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ experiment:
2
+ wandb_enabled: true
3
+ nb_epochs: 3000
4
+ nb_matches_per_iteration: 64
5
+ reinit_matches_each_it: true
6
+ checkpoint_every_n_iterations: 50
7
+ start_epoch: 0
8
+ resume_experiment: true
9
+ base_seed: 0
10
+ seed_group_size: 8
11
+ train: true
12
+ stat_methods_for_live_wandb: mllm.markov_games.negotiation.negotiation_statistics
13
+ name: naive_vs_fixed_ad_align_seed9999
14
+ agent_buffer: false
15
+ keep_agent_buffer_count: ${lora_count}
16
+ agent_buffer_recent_k: -1
17
+ description: Trust-and-Split Rock Paper Scissors negotiation game
18
+ logging:
19
+ wandb:
20
+ enabled: false
21
+ project: llm-negotiation
22
+ entity: null
23
+ mode: online
24
+ name: null
25
+ group: null
26
+ tags: []
27
+ notes: null
28
+ temperature: 1.0
29
+ markov_games:
30
+ runner_method_name: LinearRunner
31
+ runner_kwargs: {}
32
+ group_by_round: true
33
+ simulation_class_name: TrustAndSplitRPSSimulation
34
+ simulation_init_args:
35
+ nb_of_rounds: 10
36
+ quota_messages_per_agent_per_round: 1
37
+ alternating_hands: false
38
+ agents:
39
+ 0:
40
+ agent_id: ${agent_0_id}
41
+ agent_name: Alice
42
+ agent_class_name: TrustAndSplitRPSAgent
43
+ policy_id: base_llm/agent_adapter
44
+ init_kwargs:
45
+ goal: Maximize your total points over the whole game.
46
+ num_message_chars: 500
47
+ message_start_end_format: true
48
+ proposal_start_end_format: true
49
+ 1:
50
+ agent_id: ${agent_1_id}
51
+ agent_name: Bob
52
+ agent_class_name: TrustAndSplitRPSAgent
53
+ policy_id: base_llm/fixed_ad_align_adapter
54
+ init_kwargs:
55
+ goal: Maximize your total points over the whole game.
56
+ num_message_chars: 500
57
+ message_start_end_format: true
58
+ proposal_start_end_format: true
59
+ models:
60
+ base_llm:
61
+ class: LeanLocalLLM
62
+ init_args:
63
+ llm_id: base_llm
64
+ model_name: Qwen/Qwen2.5-7B-Instruct
65
+ inference_backend: vllm
66
+ hf_kwargs:
67
+ device_map: auto
68
+ torch_dtype: bfloat16
69
+ max_memory:
70
+ 0: 20GiB
71
+ attn_implementation: flash_attention_2
72
+ inference_backend_init_kwargs:
73
+ enable_lora: true
74
+ seed: ${experiment.base_seed}
75
+ enable_prefix_caching: true
76
+ max_model_len: 10000.0
77
+ gpu_memory_utilization: 0.5
78
+ dtype: bfloat16
79
+ trust_remote_code: true
80
+ max_lora_rank: 32
81
+ enforce_eager: false
82
+ max_loras: ${lora_count}
83
+ max_cpu_loras: ${lora_count}
84
+ enable_sleep_mode: true
85
+ inference_backend_sampling_params:
86
+ temperature: ${temperature}
87
+ top_p: 1.0
88
+ max_tokens: 400
89
+ top_k: -1
90
+ logprobs: 0
91
+ adapter_configs:
92
+ agent_adapter:
93
+ task_type: CAUSAL_LM
94
+ r: 32
95
+ lora_alpha: 64
96
+ lora_dropout: 0.0
97
+ target_modules: all-linear
98
+ critic_adapter:
99
+ task_type: CAUSAL_LM
100
+ r: 32
101
+ lora_alpha: 64
102
+ lora_dropout: 0.0
103
+ target_modules: all-linear
104
+ fixed_ad_align_adapter:
105
+ task_type: CAUSAL_LM
106
+ r: 32
107
+ lora_alpha: 64
108
+ lora_dropout: 0.0
109
+ target_modules: all-linear
110
+ enable_thinking: null
111
+ regex_max_attempts: 1
112
+ initial_adapter_paths:
113
+ fixed_ad_align_adapter: ${fixed_ad_align_adapter_path}
114
+ critics:
115
+ agent_critic:
116
+ module_pointer:
117
+ - base_llm
118
+ - critic_adapter
119
+ optimizers:
120
+ agent_optimizer:
121
+ module_pointer:
122
+ - base_llm
123
+ - agent_adapter
124
+ optimizer_class_name: torch.optim.Adam
125
+ init_args:
126
+ lr: 3.0e-06
127
+ weight_decay: 0.0
128
+ critic_optimizer:
129
+ module_pointer: agent_critic
130
+ optimizer_class_name: torch.optim.Adam
131
+ init_args:
132
+ lr: 3.0e-06
133
+ weight_decay: 0.0
134
+ trainers:
135
+ agent_trainer:
136
+ class: TrainerNaive
137
+ module_pointers:
138
+ policy:
139
+ - base_llm
140
+ - agent_adapter
141
+ policy_optimizer: agent_optimizer
142
+ critic: agent_critic
143
+ critic_optimizer: critic_optimizer
144
+ kwargs:
145
+ entropy_coeff: 0.0
146
+ entropy_topk: null
147
+ entropy_mask_regex: null
148
+ kl_coeff: 0.001
149
+ gradient_clipping: 1.0
150
+ restrict_tokens: null
151
+ mini_batch_size: 1
152
+ use_gradient_checkpointing: true
153
+ temperature: ${temperature}
154
+ device: cuda:0
155
+ use_gae: false
156
+ whiten_advantages: false
157
+ whiten_advantages_time_step_wise: false
158
+ skip_discounted_state_visitation: true
159
+ use_gae_lambda_annealing: false
160
+ gae_lambda_annealing_method: None
161
+ gae_lambda_annealing_method_params: None
162
+ gae_lambda_annealing_limit: 0.95
163
+ discount_factor: 0.96
164
+ use_rloo: true
165
+ enable_tokenwise_logging: false
166
+ pg_loss_normalization: nb_tokens
167
+ truncated_importance_sampling_ratio_cap: 2.0
168
+ reward_normalizing_constant: 100.0
169
+ train_on_which_data:
170
+ agent_trainer:
171
+ - Alice
172
+ lora_count: 30
173
+ common_agent_kwargs:
174
+ goal: Maximize your total points over the whole game.
175
+ num_message_chars: 500
176
+ message_start_end_format: true
177
+ proposal_start_end_format: true
178
+ agent_0_id: Alice
179
+ agent_1_id: Bob
180
+ agent_ids:
181
+ - Alice
182
+ - Bob
183
+ fixed_ad_align_adapter_path: /home/muqeeth/scratch/llm_negotiation/2025_11/tas_rps_startend_ad_align_nocurrtimestep_seed9999_beta2/seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter
.hydra/hydra.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task: []
115
+ job:
116
+ name: run
117
+ chdir: false
118
+ override_dirname: ''
119
+ id: ???
120
+ num: ???
121
+ config_name: naive_vs_fixed_ad_align_seed9999.yaml
122
+ env_set: {}
123
+ env_copy: []
124
+ config:
125
+ override_dirname:
126
+ kv_sep: '='
127
+ item_sep: ','
128
+ exclude_keys: []
129
+ runtime:
130
+ version: 1.3.2
131
+ version_base: '1.1'
132
+ cwd: /scratch/muqeeth/llm_negotiation
133
+ config_sources:
134
+ - path: hydra.conf
135
+ schema: pkg
136
+ provider: hydra
137
+ - path: /scratch/muqeeth/llm_negotiation/configs
138
+ schema: file
139
+ provider: main
140
+ - path: ''
141
+ schema: structured
142
+ provider: schema
143
+ output_dir: /scratch/muqeeth/llm_negotiation/2025_11/naive_vs_fixed_ad_align_seed9999
144
+ choices:
145
+ hydra/env: default
146
+ hydra/callbacks: null
147
+ hydra/job_logging: default
148
+ hydra/hydra_logging: default
149
+ hydra/hydra_help: default
150
+ hydra/help: default
151
+ hydra/sweeper: basic
152
+ hydra/launcher: basic
153
+ hydra/output: default
154
+ verbose: false
.hydra/overrides.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ []
run.log ADDED
The diff for this file is too large to render. See raw diff
 
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: Qwen/Qwen2.5-7B-Instruct
3
+ library_name: peft
4
+ pipeline_tag: text-generation
5
+ tags:
6
+ - base_model:adapter:Qwen/Qwen2.5-7B-Instruct
7
+ - lora
8
+ - transformers
9
+ ---
10
+
11
+ # Model Card for Model ID
12
+
13
+ <!-- Provide a quick summary of what the model is/does. -->
14
+
15
+
16
+
17
+ ## Model Details
18
+
19
+ ### Model Description
20
+
21
+ <!-- Provide a longer summary of what this model is. -->
22
+
23
+
24
+
25
+ - **Developed by:** [More Information Needed]
26
+ - **Funded by [optional]:** [More Information Needed]
27
+ - **Shared by [optional]:** [More Information Needed]
28
+ - **Model type:** [More Information Needed]
29
+ - **Language(s) (NLP):** [More Information Needed]
30
+ - **License:** [More Information Needed]
31
+ - **Finetuned from model [optional]:** [More Information Needed]
32
+
33
+ ### Model Sources [optional]
34
+
35
+ <!-- Provide the basic links for the model. -->
36
+
37
+ - **Repository:** [More Information Needed]
38
+ - **Paper [optional]:** [More Information Needed]
39
+ - **Demo [optional]:** [More Information Needed]
40
+
41
+ ## Uses
42
+
43
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
44
+
45
+ ### Direct Use
46
+
47
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
48
+
49
+ [More Information Needed]
50
+
51
+ ### Downstream Use [optional]
52
+
53
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
54
+
55
+ [More Information Needed]
56
+
57
+ ### Out-of-Scope Use
58
+
59
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
60
+
61
+ [More Information Needed]
62
+
63
+ ## Bias, Risks, and Limitations
64
+
65
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
66
+
67
+ [More Information Needed]
68
+
69
+ ### Recommendations
70
+
71
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
72
+
73
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
74
+
75
+ ## How to Get Started with the Model
76
+
77
+ Use the code below to get started with the model.
78
+
79
+ [More Information Needed]
80
+
81
+ ## Training Details
82
+
83
+ ### Training Data
84
+
85
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
86
+
87
+ [More Information Needed]
88
+
89
+ ### Training Procedure
90
+
91
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
92
+
93
+ #### Preprocessing [optional]
94
+
95
+ [More Information Needed]
96
+
97
+
98
+ #### Training Hyperparameters
99
+
100
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
101
+
102
+ #### Speeds, Sizes, Times [optional]
103
+
104
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
105
+
106
+ [More Information Needed]
107
+
108
+ ## Evaluation
109
+
110
+ <!-- This section describes the evaluation protocols and provides the results. -->
111
+
112
+ ### Testing Data, Factors & Metrics
113
+
114
+ #### Testing Data
115
+
116
+ <!-- This should link to a Dataset Card if possible. -->
117
+
118
+ [More Information Needed]
119
+
120
+ #### Factors
121
+
122
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
123
+
124
+ [More Information Needed]
125
+
126
+ #### Metrics
127
+
128
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
129
+
130
+ [More Information Needed]
131
+
132
+ ### Results
133
+
134
+ [More Information Needed]
135
+
136
+ #### Summary
137
+
138
+
139
+
140
+ ## Model Examination [optional]
141
+
142
+ <!-- Relevant interpretability work for the model goes here -->
143
+
144
+ [More Information Needed]
145
+
146
+ ## Environmental Impact
147
+
148
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
149
+
150
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
151
+
152
+ - **Hardware Type:** [More Information Needed]
153
+ - **Hours used:** [More Information Needed]
154
+ - **Cloud Provider:** [More Information Needed]
155
+ - **Compute Region:** [More Information Needed]
156
+ - **Carbon Emitted:** [More Information Needed]
157
+
158
+ ## Technical Specifications [optional]
159
+
160
+ ### Model Architecture and Objective
161
+
162
+ [More Information Needed]
163
+
164
+ ### Compute Infrastructure
165
+
166
+ [More Information Needed]
167
+
168
+ #### Hardware
169
+
170
+ [More Information Needed]
171
+
172
+ #### Software
173
+
174
+ [More Information Needed]
175
+
176
+ ## Citation [optional]
177
+
178
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
179
+
180
+ **BibTeX:**
181
+
182
+ [More Information Needed]
183
+
184
+ **APA:**
185
+
186
+ [More Information Needed]
187
+
188
+ ## Glossary [optional]
189
+
190
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
191
+
192
+ [More Information Needed]
193
+
194
+ ## More Information [optional]
195
+
196
+ [More Information Needed]
197
+
198
+ ## Model Card Authors [optional]
199
+
200
+ [More Information Needed]
201
+
202
+ ## Model Card Contact
203
+
204
+ [More Information Needed]
205
+ ### Framework versions
206
+
207
+ - PEFT 0.17.1
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "k_proj",
29
+ "gate_proj",
30
+ "up_proj",
31
+ "q_proj",
32
+ "o_proj",
33
+ "down_proj",
34
+ "v_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "k_proj",
29
+ "gate_proj",
30
+ "up_proj",
31
+ "q_proj",
32
+ "o_proj",
33
+ "down_proj",
34
+ "v_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "k_proj",
29
+ "gate_proj",
30
+ "up_proj",
31
+ "q_proj",
32
+ "o_proj",
33
+ "down_proj",
34
+ "v_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
src_code_for_reproducibility/__init__.py ADDED
File without changes
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes). View file
 
src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc ADDED
Binary file (3.92 kB). View file
 
src_code_for_reproducibility/chat_utils/chat_turn.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, List, Literal, Optional, Tuple
7
+
8
+ import jsonschema
9
+ import torch
10
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
11
+
12
+ AgentId = str
13
+
14
+
15
class ChatTurn(BaseModel):
    """A single turn of a conversation, plus optional token-level artifacts.

    Besides the plain-text message, a turn can carry the token ids and
    log-probs produced when the turn was generated, which downstream
    training code consumes.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)  # needed for torch tensors

    # Either "user" or "assistant"; enforced by the regex pattern.
    role: str = Field(pattern="^(user|assistant)$")
    agent_id: AgentId  # ID of the agent with which the chat occurred
    content: str
    # Optional separate reasoning text (e.g. extracted <think> content).
    reasoning_content: str | None = None
    chat_template_token_ids: torch.LongTensor | None = None  # Token ids of chat template format. For example, token ids of "<assistant>{content}</assistant>""
    out_token_ids: torch.LongTensor | None = (
        None  # tokens generated from inference engine
    )
    # Per-token log probabilities reported by the inference engine, if any.
    log_probs: torch.FloatTensor | None = None
    is_state_end: bool = False  # indicates whether this chat turn marks the end of a state in the trajectory
src_code_for_reproducibility/chat_utils/template_specific.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import huggingface_hub
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+
5
# Custom Jinja chat templates for the model families used in this project.
# The templates intentionally differ from the stock HF templates: they expose
# an `add_system_prompt` flag and keep assistant turns in a uniform format.
# NOTE(review): template string contents are consumed verbatim by the
# tokenizer's chat templating — do not reformat them.
custom_llama3_template = """
{%- if add_system_prompt %}
{{- '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{%- endif %}
"""

# Token ids of a single "\n", used as the postfix appended after assistant
# turns for each tokenizer family.
# NOTE(review): these AutoTokenizer.from_pretrained calls run at import time
# and may download tokenizer files from the HF Hub — importing this module
# therefore requires network access or a warm cache; confirm acceptable.
qwen2_assistant_postfix = (
    AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    .encode("\n", return_tensors="pt")
    .flatten()
)
qwen3_assistant_postfix = (
    AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
    .encode("\n", return_tensors="pt")
    .flatten()
)
gemma3_assistant_postfix = (
    AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
    .encode("\n", return_tensors="pt")
    .flatten()
)

# Qwen2.5-style template. Assistant turns strip any "<think>...</think>"
# block out of `content` into `reasoning_content`; reasoning is re-emitted
# only for turns after the last user query (loop.index0 > ns.last_query_index).
custom_qwen2_template = """
{%- if add_system_prompt %}
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if reasoning_content %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"""

# Qwen3-style template: no system prompt handling; when thinking is
# explicitly disabled, an empty <think></think> block is pre-inserted into
# the generation prompt so the model skips reasoning.
custom_qwen3_template = """
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
"""

# Gemma3-style template: the "assistant" role is renamed to "model" per
# Gemma's turn format; bos_token is emitted only with add_system_prompt.
custom_gemma3_template = """
{%- if add_system_prompt %}
{{- bos_token -}}
{%- endif %}
{%- for message in messages -%}
{%- if message['role'] == 'assistant' -%}
{%- set role = 'model' -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ '<start_of_turn>model\n' }}
{%- endif -%}
"""
src_code_for_reproducibility/docs/Makefile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Minimal makefile for Sphinx documentation

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(SPHINXFLAGS) is a pass-through for extra
# command-line options (this project's variant of the stock $(O) shortcut
# for $(SPHINXOPTS)).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
src_code_for_reproducibility/docs/generate_docs.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to automatically generate Sphinx documentation for all modules and build the HTML website.
4
+ """
5
+ import importlib.util
6
+ import os
7
+ import subprocess
8
+ import sys
9
+
10
+
11
+ def check_and_install_dependencies():
12
+ """Check for required dependencies and install them if missing."""
13
+ required_packages = [
14
+ "sphinx",
15
+ "sphinx-rtd-theme",
16
+ "sphinxcontrib-napoleon",
17
+ "sphinxcontrib-mermaid",
18
+ "sphinx-autodoc-typehints",
19
+ ]
20
+
21
+ missing_packages = []
22
+
23
+ for package in required_packages:
24
+ # Convert package name to module name (replace - with _)
25
+ module_name = package.replace("-", "_")
26
+
27
+ # Check if the package is installed
28
+ if importlib.util.find_spec(module_name) is None:
29
+ missing_packages.append(package)
30
+
31
+ # Install missing packages
32
+ if missing_packages:
33
+ print(f"Installing missing dependencies: {', '.join(missing_packages)}")
34
+ subprocess.check_call(
35
+ [sys.executable, "-m", "pip", "install"] + missing_packages
36
+ )
37
+ print("Dependencies installed successfully")
38
+ else:
39
+ print("All required dependencies are already installed")
40
+
41
+
42
def create_makefile(docs_dir):
    """Create a minimal Sphinx Makefile in ``docs_dir`` if it doesn't exist.

    An existing Makefile is left untouched so local customizations survive
    repeated runs of this script.

    Args:
        docs_dir: Directory in which the Makefile should be written.
    """
    makefile_path = os.path.join(docs_dir, "Makefile")

    if os.path.exists(makefile_path):
        print(f"Makefile already exists at {makefile_path}")
        return

    print(f"Creating Makefile at {makefile_path}")

    # NOTE(review): make requires recipe lines to begin with a literal TAB
    # character — confirm the indentation inside this literal is tabs, not
    # spaces, or the generated Makefile will not run.
    makefile_content = """# Minimal makefile for Sphinx documentation

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
"""

    with open(makefile_path, "w") as f:
        f.write(makefile_content)

    print("Makefile created successfully")
77
+
78
+
79
def create_make_bat(docs_dir):
    """Create a Windows ``make.bat`` in ``docs_dir`` if it doesn't exist.

    Mirrors :func:`create_makefile` for Windows shells; an existing file is
    left untouched.

    Args:
        docs_dir: Directory in which ``make.bat`` should be written.
    """
    make_bat_path = os.path.join(docs_dir, "make.bat")

    if os.path.exists(make_bat_path):
        print(f"make.bat already exists at {make_bat_path}")
        return

    print(f"Creating make.bat at {make_bat_path}")

    # Standard Sphinx quickstart make.bat: probes for sphinx-build
    # (errorlevel 9009 == "command not found") and forwards the target.
    make_bat_content = """@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
"""

    with open(make_bat_path, "w") as f:
        f.write(make_bat_content)

    print("make.bat created successfully")
130
+
131
+
132
def _run_command(cmd, cwd=None):
    """Run *cmd*, echo its stdout/stderr, and return the CompletedProcess.

    Args:
        cmd: Command and arguments as a list.
        cwd: Optional working directory for the child process.
    """
    print(f"Running command: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd)
    print("STDOUT:")
    print(result.stdout)
    print("STDERR:")
    print(result.stderr)
    return result


def main():
    """Generate API .rst files with sphinx-apidoc, then build the HTML docs.

    Steps:
      1. Ensure the Sphinx tool-chain is installed.
      2. Run ``sphinx-apidoc`` over the project's ``src`` tree.
      3. Ensure Makefile/make.bat exist, build the HTML, and open it in a
         browser (best effort).
    """
    print("=== Checking dependencies ===")
    check_and_install_dependencies()

    # Resolve all paths relative to this script so the tool can be invoked
    # from any working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    source_dir = os.path.join(project_root, "src")
    docs_source_dir = os.path.join(script_dir, "source")

    print(f"Script directory: {script_dir}")
    print(f"Project root: {project_root}")
    print(f"Source directory: {source_dir}")
    print(f"Docs source directory: {docs_source_dir}")

    if not os.path.exists(source_dir):
        print(f"Error: Source directory {source_dir} does not exist!")
        sys.exit(1)

    if not os.path.exists(docs_source_dir):
        print(f"Creating docs source directory: {docs_source_dir}")
        os.makedirs(docs_source_dir)

    # Step 1: run sphinx-apidoc to generate .rst files for all modules.
    print("\n=== Generating API documentation ===")
    result = _run_command(
        [
            "sphinx-apidoc",
            "-f",  # Force overwriting of existing files
            "-e",  # Put module documentation before submodule documentation
            "-M",  # Put module documentation before subpackage documentation
            "-o",
            docs_source_dir,  # Output directory
            source_dir,  # Source code directory
        ]
    )
    if result.returncode != 0:
        print(f"Error: sphinx-apidoc failed with return code {result.returncode}")
        sys.exit(1)

    print("\nFiles in docs/source directory:")
    for file_name in sorted(os.listdir(docs_source_dir)):
        print(f"  {file_name}")

    print("\nDocumentation source files generated successfully!")

    # Step 2: create Makefile and make.bat if they don't exist.
    create_makefile(script_dir)
    create_make_bat(script_dir)

    # Step 3: build the HTML documentation.  The build runs in the docs
    # directory via ``cwd=`` rather than mutating the process-wide working
    # directory with os.chdir().  On Windows the batch file is addressed by
    # absolute path because CreateProcess does not resolve a bare relative
    # name against the child's cwd.
    print("\n=== Building HTML documentation ===")
    if os.name == "nt":  # Windows
        build_cmd = [os.path.join(script_dir, "make.bat"), "html"]
    else:  # Unix/Linux/Mac
        build_cmd = ["make", "html"]

    build_result = _run_command(build_cmd, cwd=script_dir)
    if build_result.returncode != 0:
        print(f"Error: HTML build failed with return code {build_result.returncode}")
        sys.exit(1)

    # Point the user at the built entry page.
    index_path = os.path.join(script_dir, "build", "html", "index.html")

    if os.path.exists(index_path):
        # Fixed: this message was previously an f-string with no placeholders.
        print("\nHTML documentation built successfully!")
        print(f"You can view it by opening: {index_path}")

        # Best effort: failure to open a browser must never fail the build.
        try:
            import webbrowser

            print("\nAttempting to open documentation in your default browser...")
            webbrowser.open(f"file://{index_path}")
        except Exception as e:
            print(f"Could not open browser automatically: {e}")
    else:
        print(f"\nWarning: HTML index file not found at {index_path}")


if __name__ == "__main__":
    main()
src_code_for_reproducibility/docs/make.bat ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Fall back to the sphinx-build found on PATH unless the caller
REM overrides it via the SPHINXBUILD environment variable.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

REM Probe that sphinx-build is runnable; errorlevel 9009 is Windows'
REM "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

REM No build target given: show Sphinx's help instead.
if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
src_code_for_reproducibility/docs/source/environments/diplomacy.rst ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ =================
2
+ Diplomacy
3
+ =================
4
+
5
+ The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
6
+ based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
7
+ and its associated agent handler.
8
+
9
+ Overview
10
+ --------
11
+
12
+ Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
13
+ and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
14
+ of movement phases, retreat phases, and build phases.
15
+
16
+ Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
17
+ to be used with LLM agents through a text-based interface.
18
+
19
+ Game Rules
20
+ ----------
21
+
22
+ ### Game Board and Powers
23
+
24
+ Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
25
+
26
+ - England (blue)
27
+ - France (light blue)
28
+ - Germany (black)
29
+ - Italy (green)
30
+ - Austria-Hungary (red)
31
+ - Russia (white)
32
+ - Turkey (yellow)
33
+
34
+ Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
35
+
36
+ ### Units and Movement
37
+
38
+ There are two types of units in Diplomacy:
39
+ - **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
40
+ - **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
41
+
42
+ During movement phases, each unit can execute one of these orders:
43
+ - **Hold**: The unit remains in its current province (e.g., "A PAR H")
44
+ - Format: [Unit Type] [Province] H
45
+ - Example: "A PAR H" means "Army in Paris holds its position"
46
+
47
+ - **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
48
+ - Format: [Unit Type] [Current Province] - [Destination Province]
49
+ - Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
50
+ - Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
51
+
52
+ - **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
53
+ - Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
54
+ - Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
55
+ - Example: "A PAR S A MAR - BUR" means "Army in Paris supports the Army in Marseille's move to Burgundy"
56
+ - Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
57
+
58
+ - **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
59
+ - Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
60
+ - Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
61
+
62
+ All orders are executed simultaneously, and conflicts are resolved based on strength (number of supporting units).
63
+
64
+ ### Common Province Abbreviations
65
+
66
+ Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
67
+ - **PAR**: Paris
68
+ - **LON**: London
69
+ - **BER**: Berlin
70
+ - **MUN**: Munich
71
+ - **BUR**: Burgundy
72
+ - **MAR**: Marseilles
73
+ - **BRE**: Brest
74
+ - **ENG**: English Channel
75
+ - **NTH**: North Sea
76
+ - **VIE**: Vienna
77
+ - **ROM**: Rome
78
+ - **VEN**: Venice
79
+ - **MOW**: Moscow
80
+ - **CON**: Constantinople
81
+
82
+ ### Example: Movement and Conflicts
83
+
84
+ For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
85
+
86
+ ### Turn Structure
87
+
88
+ A game year consists of five phases:
89
+ 1. **Spring Movement**: All powers submit orders for their units
90
+ 2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
91
+ 3. **Fall Movement**: Another round of movement orders
92
+ 4. **Fall Retreat**: Retreat orders for dislodged units
93
+ 5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
94
+
95
+ ### Supply Centers and Building
96
+
97
+ Supply centers (marked on the map) are key to victory. When a power occupies a supply center during a Fall turn, they gain control of it. During the Winter Adjustment phase:
98
+ - If you control more supply centers than you have units, you can build new units in your home supply centers
99
+ - If you control fewer supply centers than you have units, you must remove excess units
100
+
101
+ ### Example: Building and Removing Units
102
+
103
+ If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
104
+
105
+ ### Negotiation
106
+
107
+ A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
108
+
109
+ ### Example: Alliance and Betrayal
110
+
111
+ England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
112
+
113
+ ### Victory Conditions
114
+
115
+ The game ends when one power controls 18 or more supply centers (majority of the 34 total centers), or when players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
116
+
117
+ DiplomacyEnv
118
+ ------------
119
+
120
+ The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
121
+ Negotiation Environment standard.
122
+
123
+ .. code-block:: python
124
+
125
+ class DiplomacyEnv:
126
+ """
127
+ Multi-Agent Negotiation Environment for Diplomacy, adapting Deepmind's implementation
128
+ to the MarlEnvironment standard.
129
+ """
130
+ def __init__(self,
131
+ initial_state: Optional[DiplomacyState] = None,
132
+ max_turns: int = 100,
133
+ points_per_supply_centre: bool = True,
134
+ forced_draw_probability: float = 0.0,
135
+ min_years_forced_draw: int = 35):
136
+ """Initialize the Diplomacy environment.
137
+
138
+ Args:
139
+ initial_state: Initial DiplomacyState (optional)
140
+ max_turns: Maximum number of turns in the game
141
+ points_per_supply_centre: Whether to award points per supply center in case of a draw
142
+ forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
143
+ min_years_forced_draw: Minimum years before considering a forced draw
144
+ """
145
+ # ...
146
+
147
+ def reset(self):
148
+ """Reset the environment to an initial state and return the initial observation.
149
+
150
+ Returns:
151
+ observation (dict): A dictionary where keys are agent identifiers and values are observations.
152
+ Each observation contains:
153
+ - board_state: Current state of the board
154
+ - current_season: Current season in the game
155
+ - player_index: Index of the player's power
156
+ - possible_actions: List of possible actions in DeepMind's format
157
+ - human_readable_actions: List of human-readable action descriptions
158
+ - supply_centers: List of supply centers owned by the player
159
+ - units: List of units owned by the player
160
+ - year: Current year in the game
161
+ """
162
+ # ...
163
+
164
+ def step(self, actions):
165
+ """Take a step in the environment using the provided actions.
166
+
167
+ Args:
168
+ actions (dict): A dictionary where keys are agent identifiers and values are actions.
169
+ Actions can be:
170
+ - List of integer actions in DeepMind's format
171
+ - List of string actions in text format (e.g., "A MUN - BER")
172
+
173
+ Returns:
174
+ observations (dict): A dictionary where keys are agent identifiers and values are observations.
175
+ Each observation has the same structure as in reset().
176
+ done (bool): Whether the episode has ended.
177
+ info (dict): Additional information about the environment, including:
178
+ - turn: Current turn number
179
+ - returns: Game returns if the game is done, otherwise None
180
+ - waiting_for: List of agents that still need to provide actions (if not all actions are provided)
181
+ """
182
+ # ...
183
+
184
+ def get_log_info(self):
185
+ """Get additional information about the environment for logging.
186
+
187
+ Returns:
188
+ log_info (dict): Information about the environment required to log the game, including:
189
+ - power_names: List of power names
190
+ - game_history: History of the game
191
+ - current_turn: Current turn number
192
+ - current_season: Current season name
193
+ - supply_centers: Dictionary mapping power names to supply center counts
194
+ """
195
+ # ...
196
+
197
+ def render(self):
198
+ """Render the current state of the environment.
199
+
200
+ Displays a visualization of the current game state.
201
+ """
202
+ # ...
203
+
204
+ def close(self):
205
+ """Perform any necessary cleanup."""
206
+ # ...
207
+
208
+
209
+ Key Implementation Details
210
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
211
+
212
+ The ``DiplomacyEnv`` class implements several key features:
213
+
214
+ 1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
215
+
216
+ 2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
217
+
218
+ 3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
219
+
220
+ 4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
221
+
222
+ 5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
223
+
224
+ Observation Structure
225
+ ~~~~~~~~~~~~~~~~~~~~~
226
+
227
+ Each agent receives an observation dictionary with the following structure:
228
+
229
+ .. code-block:: python
230
+
231
+ {
232
+ "board_state": np.ndarray, # Board state representation
233
+ "current_season": int, # Season index (0-4)
234
+ "player_index": int, # Index of the player's power (0-6)
235
+ "possible_actions": [int], # List of possible actions in DeepMind's format
236
+ "human_readable_actions": [str], # List of human-readable action descriptions
237
+ "supply_centers": [str], # List of supply centers owned by the player
238
+ "units": [dict], # List of units owned by the player
239
+ "year": int # Current year in the game
240
+ }
241
+
242
+ Action Structure
243
+ ~~~~~~~~~~~~~~~~
244
+
245
+ Actions can be provided in two formats:
246
+
247
+ 1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
248
+
249
+ 2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
250
+
251
+ The environment will convert text actions to the internal format as needed.
252
+
253
+ DiplomacyAgent
254
+ --------------
255
+
256
+ The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
257
+
258
+ .. code-block:: python
259
+
260
+ class DiplomacyAgent:
261
+ """
262
+ Agent handler for Diplomacy, implementing the AgentState interface
263
+ for the multi-agent negotiation standard.
264
+ """
265
+
266
+ def __init__(self,
267
+ power_name: str,
268
+ use_text_interface: bool = True,
269
+ system_prompt: Optional[str] = None):
270
+ """Initialize the Diplomacy agent handler.
271
+
272
+ Args:
273
+ power_name: Name of the power this agent controls
274
+ use_text_interface: Whether to use text-based interface (vs. structured)
275
+ system_prompt: Optional system prompt to use for the LLM
276
+ """
277
+ # ...
278
+
279
+ def step(self, observation_from_env, policy_output=None):
280
+ """Update the agent state based on the observation and action.
281
+
282
+ Args:
283
+ observation_from_env: The observation from the environment, with structure:
284
+ - board_state: Current state of the board
285
+ - current_season: Current season in the game
286
+ - player_index: Index of the player's power
287
+ - possible_actions: List of possible actions
288
+ - human_readable_actions: List of human-readable action descriptions
289
+ - supply_centers: List of supply centers owned by the player
290
+ - units: List of units owned by the player
291
+ - year: Current year in the game
292
+
293
+ policy_output: The output of the policy (LLM response), or None for initial prompt
294
+
295
+ Returns:
296
+ policy_id (str): The policy identifier ("llm_policy")
297
+ policy_input (dict): The input to the policy, with structure:
298
+ - messages: List of conversation messages in the format:
299
+ [{"role": "system", "content": "..."},
300
+ {"role": "user", "content": "..."}]
301
+ action: The official action to be sent to the environment, or None if not ready
302
+ done (bool): Whether the LLM action is ready to be sent to the environment
303
+ info (dict): Additional information about the agent:
304
+ - valid_action: Whether the extracted action is valid
305
+ """
306
+ # ...
307
+
308
+ def get_log_info(self):
309
+ """Get information about the agent required to log a trajectory.
310
+
311
+ Returns:
312
+ log_info (dict): Information about the agent required to log a trajectory:
313
+ - power_name: Name of the power this agent controls
314
+ - conversation_history: List of conversation messages
315
+ - current_action: The current action, if any
316
+ """
317
+ # ...
318
+
319
+ def render(self):
320
+ """Render the current state of the agent.
321
+
322
+ Displays the agent's current state, including conversation history.
323
+ """
324
+ # ...
325
+
326
+ def close(self):
327
+ """Perform any necessary cleanup."""
328
+ # ...
329
+
330
+
331
+ Key Implementation Details
332
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
333
+
334
+ The ``DiplomacyAgent`` class implements several key features:
335
+
336
+ 1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
337
+
338
+ 2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
339
+
340
+ 3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
341
+
342
+ 4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
343
+
344
+ 5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
345
+
346
+ Prompt Structure
347
+ ~~~~~~~~~~~~~~~
348
+
349
+ The agent generates prompts that include:
350
+
351
+ 1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
352
+
353
+ 2. **Game State Description**: A text description of the current game state, including:
354
+ - Current year and season
355
+ - Supply centers owned
356
+ - Units controlled
357
+ - Possible actions
358
+
359
+ 3. **Action Request**: Instructions on how to format actions.
360
+
361
+ Example system prompt:
362
+
363
+ .. code-block:: text
364
+
365
+ You are playing the role of FRANCE in a game of Diplomacy.
366
+ Your goal is to control as many supply centers as possible.
367
+ You can negotiate with other players and form alliances, but remember that
368
+ these alliances are not binding. When you need to submit orders for your units,
369
+ write them in the correct format, with each order on a new line.
370
+
371
+ Example game state description:
372
+
373
+ .. code-block:: text
374
+
375
+ Year: 1901, Season: SPRING_MOVES
376
+ You are playing as FRANCE.
377
+ You currently control 3 supply centers: PAR, MAR, BRE.
378
+ Your units are: A PAR, A MAR, F BRE.
379
+
380
+ Please provide orders for your units. Here are your possible actions:
381
+ A PAR - BUR
382
+ A PAR - GAS
383
+ A PAR - PIC
384
+ A PAR H
385
+ ...
386
+
387
+ Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
388
+
389
+ Running Diplomacy Games
390
+ -----------------------
391
+
392
+ To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
393
+
394
+ .. code-block:: python
395
+
396
+ from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
397
+ from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
398
+ from mllm.run_matches import run_batched_matches
399
+
400
+ # Create environment and agent handlers
401
+ env = DiplomacyEnv(max_turns=30)
402
+
403
+ agent_handlers = {
404
+ "AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
405
+ "ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
406
+ "FRANCE": DiplomacyAgent(power_name="FRANCE"),
407
+ "GERMANY": DiplomacyAgent(power_name="GERMANY"),
408
+ "ITALY": DiplomacyAgent(power_name="ITALY"),
409
+ "RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
410
+ "TURKEY": DiplomacyAgent(power_name="TURKEY")
411
+ }
412
+
413
+ # Define policy mapping (mapping from policy IDs to actual policy functions)
414
+ policy_mapping = {
415
+ "llm_policy": my_llm_policy_function
416
+ }
417
+
418
+ # Run the game
419
+ game_results = run_batched_matches(
420
+ envs=[env],
421
+ agent_handlers_per_env=[agent_handlers],
422
+ policy_mapping=policy_mapping,
423
+ max_parallel_matches=1
424
+ )
425
+
426
+ # Process results
427
+ for result in game_results:
428
+ print(f"Game finished. Winner: {result['winner']}")
429
+ print(f"Supply centers: {result['supply_centers']}")
430
+
431
+ This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
432
+
433
+ Limitations and Considerations
434
+ ------------------------------
435
+
436
+ 1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
437
+
438
+ 2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
439
+
440
+ 3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
441
+
442
+ 4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
443
+
444
+ 5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
445
+
446
+ Advanced Usage
447
+ ------------
448
+
449
+ For advanced usage, you can customize:
450
+
451
+ 1. **System Prompts**: Modify agent behavior by providing custom system prompts.
452
+
453
+ 2. **Observation Processing**: Extend the observation processing to include additional information.
454
+
455
+ 3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
456
+
457
+ 4. **Visualization**: Add custom visualization methods to the environment's render function.
458
+
459
+ 5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.experiments.dond\_run\_train module
2
+ =======================================
3
+
4
+ .. automodule:: src.experiments.dond_run_train
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.models.dummy\_hf\_agent module
2
+ ==================================
3
+
4
+ .. automodule:: src.models.dummy_hf_agent
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/markov_games/__init__.py ADDED
File without changes
src_code_for_reproducibility/markov_games/agent.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In simple RL paradise, where the action dimensions are constant and well defined,
3
+ Agent classes are not necessary. But in MARL, with LLM's, there isn't always
4
+ a direct path from policy to action. For instance, from the observation of the environment,
5
+ a prompt must be created. Then, the outputs of the policy might be incorrect, so a second
6
+ request to the LLM must be sent before the action is well defined. This is why this Agent class exists.
7
+ It acts as a mini environment, bridging the gap between the core simulation and
8
+ the LLM policies.
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from collections.abc import Callable
13
+ from typing import Any, Tuple
14
+
15
+ from numpy.random import default_rng
16
+
17
+ from mllm.markov_games.rollout_tree import AgentActLog
18
+
19
+
20
class Agent(ABC):
    """
    Abstract bridge between the core simulation and an LLM policy.

    In MARL with LLMs there is no direct path from policy to action: an
    observation must be turned into a prompt, and malformed policy outputs may
    require follow-up requests before a well-defined action exists. Subclasses
    implement that glue, acting as a mini environment between the simulation
    and the LLM policies.
    """

    @abstractmethod
    def __init__(
        self,
        seed: int,
        agent_id: str,
        agent_name: str,
        agent_policy: Callable[[list[dict]], str],
        *args,
        **kwargs,
    ):
        """
        Initialize the agent state.

        Args:
            seed: Seed for this agent's private RNG stream.
            agent_id: Unique identifier of the agent within the game.
            agent_name: Human-readable name of the agent.
            agent_policy: Callable mapping a chat (list of message dicts) to a
                text completion.
        """
        self.seed = seed
        self.agent_id = agent_id
        self.agent_name = agent_name
        # Bug fix: the original assigned the undefined name `policy`
        # (NameError at runtime); the parameter is named `agent_policy`.
        self.policy = agent_policy
        self.rng = default_rng(self.seed)
        raise NotImplementedError

    async def act(self, observation) -> Tuple[Any, AgentActLog]:
        """
        Query (possibly multiple times) a policy (or possibly a pool of policies) to
        obtain the action of the agent.

        Example:
            action = None
            prompt = self.observation_to_prompt(observation)
            while not self.valid(action):
                output = await self.policy.generate(prompt)
                action = self.policy_output_to_action(output)
            return action

        Returns:
            action
            step_info
        """
        raise NotImplementedError

    def get_safe_copy(self):
        """
        Return copy of the agent object that is decorrelated from the original object.
        """
        raise NotImplementedError

    def reset(self):
        """Reset the agent's internal state for a new trajectory."""
        raise NotImplementedError

    def render(self):
        """Render the agent's state (optional for subclasses)."""
        raise NotImplementedError

    def close(self):
        """Release any resources held by the agent."""
        raise NotImplementedError

    def get_agent_info(self):
        """Return agent-specific metadata (optional for subclasses)."""
        raise NotImplementedError
src_code_for_reproducibility/markov_games/alternative_actions_runner.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import copy
3
+ import json
4
+ import os.path
5
+ from typing import Any, Tuple
6
+
7
+ from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame
8
+ from mllm.markov_games.rollout_tree import (
9
+ AgentActLog,
10
+ RolloutTreeBranchNode,
11
+ RolloutTreeNode,
12
+ RolloutTreeRootNode,
13
+ StepLog,
14
+ )
15
+
16
+ AgentId = str
17
+
18
+
19
+
20
async def run_with_unilateral_alt_action(
    markov_game: MarkovGame,
    agent_id: AgentId,
    time_step: int,
    branch_node: RolloutTreeBranchNode,
    max_depth: int,
):
    """
    Generate a new branch for `agent_id` by sampling a fresh (unilateral)
    action for it, then rolling the game forward for up to `max_depth` steps.
    The resulting chain of nodes is appended to `branch_node.branches`.

    Args:
        markov_game: Game copy whose other agents' actions are already set.
        agent_id: Agent whose alternative action starts the branch.
        time_step: Time step at which the branch begins.
        branch_node: Branch node the new chain is attached to.
        max_depth: Maximum number of steps to roll out after branching.
    """

    # Generate alternative action and take a step
    await markov_game.set_action_of_agent(agent_id)
    terminated: bool = markov_game.take_simulation_step()
    step_log = markov_game.get_step_log()
    first_alternative_node = RolloutTreeNode(
        step_log=step_log,
        time_step=time_step,
    )

    # Generate rest of trajectory up to max depth
    time_step += 1
    counter = 1
    previous_node = first_alternative_node
    while not terminated and counter <= max_depth:
        terminated, step_log = await markov_game.step()
        current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
        previous_node.child = current_node
        previous_node = current_node
        counter += 1
        time_step += 1

    # Attach the branch (idiom fix: `is None` instead of `== None`).
    if branch_node.branches is None:
        branch_node.branches = {agent_id: [first_alternative_node]}
    else:
        agent_branches = branch_node.branches.get(agent_id, [])
        agent_branches.append(first_alternative_node)
        branch_node.branches[agent_id] = agent_branches
58
+
59
+
60
async def AlternativeActionsRunner(
    markov_game: MarkovGame,
    output_folder: str,
    nb_alternative_actions: int,
    max_depth: int,
    branch_only_on_new_round: bool = False,
):
    """
    Generate a trajectory with partially completed branches, where the
    branching comes from taking unilaterally different actions. The resulting
    data is used to estimate the updated advantage alignment policy gradient
    terms. Let k := nb_alternative_actions. Then the number of steps generated
    is O(Tk), where T is the maximum trajectory length.

    Args:
        markov_game: Game to roll out.
        output_folder: Kept for runner-interface compatibility (unused here).
        nb_alternative_actions: Number of alternative branches per agent per step.
        max_depth: Maximum rollout depth of each alternative branch.
        branch_only_on_new_round: Accepted for forward compatibility; not
            implemented yet.

    Returns:
        The root of the generated rollout tree.
    """

    tasks = []
    time_step = 0
    terminated = False
    # Bug fix: `agent_ids` is a required field of RolloutTreeRootNode
    # (Field(min_length=1) with no default); omitting it raised a pydantic
    # ValidationError.
    root = RolloutTreeRootNode(
        id=markov_game.get_id(),
        crn_id=markov_game.get_crn_id(),
        agent_ids=markov_game.get_agent_ids(),
    )
    previous_node = root

    while not terminated:
        mg_before_action = markov_game.get_safe_copy()

        # Get safe copies for main branch
        agent_action_safe_copies: dict[
            AgentId, AgentAndActionSafeCopy
        ] = await markov_game.get_actions_of_agents_without_side_effects()

        markov_game.set_actions_of_agents_manually(agent_action_safe_copies)
        terminated = markov_game.take_simulation_step()
        main_node = RolloutTreeNode(
            step_log=markov_game.get_step_log(), time_step=time_step
        )
        branch_node = RolloutTreeBranchNode(main_child=main_node)
        previous_node.child = branch_node
        previous_node = main_node

        # Get alternative branches by generating new unilateral actions
        for agent_id in markov_game.agent_ids:
            for _ in range(nb_alternative_actions):
                # Deep-copy the safe copies so branches stay decorrelated.
                # Comprehension variable renamed to `aid` so it no longer
                # shadows the outer `agent_id` loop variable.
                branch_agent_action_safe_copies: dict[
                    AgentId, AgentAndActionSafeCopy
                ] = {
                    aid: AgentAndActionSafeCopy(
                        action=copy.deepcopy(safe_copy.action),
                        action_info=copy.deepcopy(safe_copy.action_info),
                        agent_after_action=safe_copy.agent_after_action.get_safe_copy(),
                    )
                    for aid, safe_copy in agent_action_safe_copies.items()
                }
                mg_branch: MarkovGame = mg_before_action.get_safe_copy()
                # Two-player assumption: the single other agent keeps its
                # original action while `agent_id` re-samples in the branch.
                other_agent_id = next(
                    aid for aid in mg_branch.agent_ids if aid != agent_id
                )
                mg_branch.set_action_and_agent_after_action_manually(
                    agent_id=other_agent_id,
                    agent_action_safe_copy=branch_agent_action_safe_copies[
                        other_agent_id
                    ],
                )
                task = asyncio.create_task(
                    run_with_unilateral_alt_action(
                        markov_game=mg_branch,
                        time_step=time_step,
                        agent_id=agent_id,
                        branch_node=branch_node,
                        max_depth=max_depth,
                    )
                )
                tasks.append(task)
        time_step += 1

    # wait for all branches to complete
    await asyncio.gather(*tasks)

    return root
src_code_for_reproducibility/markov_games/group_timesteps.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains the logic for grouping time steps.
3
+ """
4
+ import copy
5
+ from typing import Callable
6
+
7
+ from mllm.markov_games.markov_game import MarkovGame
8
+ from mllm.markov_games.rollout_tree import (
9
+ AgentActLog,
10
+ RolloutTreeBranchNode,
11
+ RolloutTreeNode,
12
+ RolloutTreeRootNode,
13
+ StepLog,
14
+ )
15
+ from mllm.markov_games.simulation import SimulationStepLog
16
+
17
+ AgentId = str
18
+
19
+
20
def group_time_steps(
    rollout_tree: RolloutTreeRootNode,
    accumulation_stop_condition: Callable[[StepLog], bool],
) -> RolloutTreeRootNode:
    """
    Regroup consecutive time steps of a rollout tree into single steps.

    During generation, rollout trees follow the real time steps. During
    training we may want to treat groups of time steps as a single time step.
    Concrete example (Trust-and-Split): each round has X communication steps
    (no reward) followed by one split step (the reward), which can destabilize
    discounted REINFORCE; grouping lets the whole round share the split's
    reward. Steps are accumulated until `accumulation_stop_condition` fires,
    then one grouped node is emitted per accumulated segment.

    Details:
        - The reward for the group is the reward of the last time step in the group.
        - The simulation log for the group is that of the last time step.
        - Only the first `is_state_end` chat turn in the group stays True.
        - The agent info for the group is that of the last time step.

    Raises:
        Exception: if the tree contains branch nodes (not supported yet).
    """

    def group_step_logs(step_logs: list[StepLog]) -> StepLog:
        """
        Concatenate per-agent chat turns across steps; keep only the first is_state_end.
        """
        last_sim_log = step_logs[-1].simulation_step_log
        agent_ids = {aid for s in step_logs for aid in s.action_logs.keys()}
        grouped_logs: dict[AgentId, AgentActLog] = {}
        for aid in agent_ids:
            turns = []
            for s in step_logs:
                act = s.action_logs.get(aid)
                if act and act.chat_turns:
                    turns.extend(copy.deepcopy(act.chat_turns))
            # Only the first state_end should be True, the rest should be False
            disable_is_state_end = False
            for t in turns:
                if t.is_state_end:
                    if disable_is_state_end:
                        t.is_state_end = False
                    else:
                        disable_is_state_end = True
            grouped_logs[aid] = AgentActLog(
                chat_turns=turns, info=step_logs[-1].action_logs[aid].info
            )
        return StepLog(action_logs=grouped_logs, simulation_step_log=last_sim_log)

    def group_time_steps_rec(
        current_node: RolloutTreeNode | RolloutTreeBranchNode,
        group_time_step: int,
        accumulation_step_logs: list[StepLog],
    ) -> RolloutTreeNode | RolloutTreeBranchNode:
        """
        Groups time steps. Recursion is used to handle branches.
        """
        assert isinstance(
            current_node, (RolloutTreeNode, RolloutTreeBranchNode)
        ), "Current node must be a tree node or a branch node. Is of type: " + str(
            type(current_node)
        )
        first_group_node = None
        current_group_node = None
        while current_node is not None:
            if isinstance(current_node, RolloutTreeBranchNode):
                # TODO: support branching trajectories (recurse into
                # main_child and every branch chain).
                raise Exception(
                    "Grouping timesteps by round is not supported for branching trajectories yet."
                )

            # Accumulate
            accumulation_step_logs.append(current_node.step_log)
            if accumulation_stop_condition(current_node.step_log):
                grouped_step_logs = group_step_logs(accumulation_step_logs)
                accumulation_step_logs = []
                new_group_node = RolloutTreeNode(
                    step_log=grouped_step_logs, time_step=group_time_step, child=None
                )
                # Idiom fix: `is None` instead of `== None`.
                if first_group_node is None:
                    first_group_node = new_group_node
                group_time_step += 1
                if current_group_node is not None:
                    current_group_node.child = new_group_node
                current_group_node = new_group_node
            current_node = current_node.child
        return first_group_node

    node = group_time_steps_rec(
        current_node=rollout_tree.child, group_time_step=0, accumulation_step_logs=[]
    )
    return RolloutTreeRootNode(
        id=rollout_tree.id,
        crn_id=rollout_tree.crn_id,
        child=node,
        agent_ids=rollout_tree.agent_ids,
    )
133
+
134
+
135
def stop_when_round_ends(step_log: StepLog) -> bool:
    """
    Simplest stop condition: True iff `step_log` is the last time step of a
    round, as flagged by the simulation in its step-log info.

    Raises:
        AssertionError: if `is_last_timestep_in_round` is not present in the
            simulation step log's info.
    """
    # Idiom fix: membership test directly on the dict (no redundant `.keys()`).
    assert (
        "is_last_timestep_in_round" in step_log.simulation_step_log.info
    ), "To group by round, is_last_timestep_in_round must be set in the info of your simulation step log at each time step."
    return step_log.simulation_step_log.info["is_last_timestep_in_round"]
144
+
145
+
146
def group_by_round(rollout_tree: RolloutTreeRootNode) -> RolloutTreeRootNode:
    """
    Collapse each round of the rollout tree into a single grouped time step.
    """
    stop_condition = stop_when_round_ends
    return group_time_steps(rollout_tree, stop_condition)
src_code_for_reproducibility/markov_games/linear_runner.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import os.path
4
+
5
+ from mllm.markov_games.markov_game import MarkovGame
6
+ from mllm.markov_games.rollout_tree import RolloutTreeNode, RolloutTreeRootNode
7
+
8
+
9
async def LinearRunner(
    markov_game: MarkovGame, output_folder: str
) -> RolloutTreeRootNode:
    """
    Roll out a single, unbranched trajectory and return its rollout tree root.
    """
    root = RolloutTreeRootNode(
        id=markov_game.get_id(),
        crn_id=markov_game.get_crn_id(),
        agent_ids=markov_game.get_agent_ids(),
    )
    tail = root
    step_index = 0
    done = False
    while not done:
        done, step_log = await markov_game.step()
        node = RolloutTreeNode(step_log=step_log, time_step=step_index)
        tail.child = node
        tail = node
        step_index += 1

    return root
src_code_for_reproducibility/markov_games/markov_game.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
3
+ In a MarkovGame step,
4
+ 1) each agent takes an action,
5
+ 2) the state transitions with respect to these actions,
6
+ 3) all relevant data of the step is appended to the historical data list
7
+
8
+ In order to perform 3), the agents and the simulation are expected, at each time step,
9
+ to return a log of the state transition (from their perspective).
10
+ For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
11
+ A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
12
+ The approach we use here centralizes the data gathering aspect,
13
+ making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
14
+ only log information for step transitions occurring after the branching out.
15
+ """
16
+ import asyncio
17
+ import copy
18
+ import json
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import Any, List, Literal, Optional, Tuple
22
+
23
+ from transformers.models.idefics2 import Idefics2Config
24
+
25
+ from mllm.markov_games.agent import Agent
26
+ from mllm.markov_games.rollout_tree import AgentActLog, StepLog
27
+ from mllm.markov_games.simulation import Simulation
28
+
29
+ AgentId = str
30
+
31
+
32
@dataclass
class AgentAndActionSafeCopy:
    """Decorrelated snapshot of an agent's chosen action and post-action state."""

    action: Any  # clean action to feed to the simulation
    action_info: AgentActLog  # log of how the action was produced
    # Annotation fix: this field holds an Agent instance (a safe copy made
    # after acting), not a class object.
    agent_after_action: Agent
37
+
38
+
39
class MarkovGame(object):
    """
    Couple a Simulation with the Agents acting in it.

    One `step()`:
      1) every agent produces an action,
      2) the simulation transitions on those actions,
      3) the step's logs are gathered into a StepLog.
    """

    def __init__(
        self,
        id: int,
        agents: dict[AgentId, Agent],
        simulation: Simulation,
        crn_id: int,
    ):
        """
        Args:
            id: Identifier of this game instance.
            agents: Mapping from agent id to the Agent playing under that id.
            simulation: Simulation object. Example: IPDSimulation
            crn_id: ID of the common-random-numbers stream used by this game.
        """
        self.agents = agents
        # NOTE: dict_keys is a live view — it tracks later reassignments of
        # entries in self.agents.
        self.agent_ids = self.agents.keys()
        self.simulation = simulation
        self.simulation_step_log = None  # set by take_simulation_step()
        self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids}
        self.actions = {}  # latest action chosen per agent id
        self.id = id
        self.crn_id = crn_id

    def get_id(self) -> int:
        # Annotation fix: `id` is declared as int (was annotated `-> str`).
        return self.id

    def get_crn_id(self) -> int:
        return self.crn_id

    def get_agent_ids(self) -> List[AgentId]:
        return list(self.agent_ids)

    async def get_action_of_agent_without_side_effects(
        self, agent_id: AgentId
    ) -> AgentAndActionSafeCopy:
        """
        Safe function to get an action of an agent without modifying the agent or the simulation.

        Annotation fix: the method returns an AgentAndActionSafeCopy (was
        annotated `Tuple[Any, AgentActLog]`).
        """
        agent = self.agents[agent_id]
        # Snapshot the agent, let the live agent act, then restore the
        # snapshot so this call leaves the game state untouched.
        agent_before_action = agent.get_safe_copy()
        obs = self.simulation.get_obs_agent(agent_id)
        action, action_info = await agent.act(observation=obs)
        self.agents[agent_id] = agent_before_action
        agent_after_action = agent.get_safe_copy()
        return AgentAndActionSafeCopy(action, action_info, agent_after_action)

    async def get_actions_of_agents_without_side_effects(
        self,
    ) -> dict[AgentId, AgentAndActionSafeCopy]:
        """
        Safe function to get an action of an agent without modifying the agent or the simulation.
        """
        tasks = []
        for agent_id in self.agent_ids:
            task = asyncio.create_task(
                self.get_action_of_agent_without_side_effects(agent_id)
            )
            tasks.append(task)
        # gather() preserves task order, so zipping back with agent_ids is safe.
        agent_and_action_safe_copies: list[
            AgentAndActionSafeCopy
        ] = await asyncio.gather(*tasks)
        return {
            agent_id: agent_and_action_safe_copy
            for agent_id, agent_and_action_safe_copy in zip(
                self.agent_ids, agent_and_action_safe_copies
            )
        }

    def set_action_and_agent_after_action_manually(
        self,
        agent_id: AgentId,
        agent_action_safe_copy: AgentAndActionSafeCopy,
    ):
        """
        Set the action and the agent after action manually.
        """
        self.actions[agent_id] = agent_action_safe_copy.action
        self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info
        self.agents[agent_id] = agent_action_safe_copy.agent_after_action

    def set_actions_of_agents_manually(
        self, actions: dict[AgentId, AgentAndActionSafeCopy]
    ):
        """
        Set the actions of agents manually.
        """
        for agent_id, agent_action_safe_copy in actions.items():
            self.set_action_and_agent_after_action_manually(
                agent_id, agent_action_safe_copy
            )

    async def set_action_of_agent(self, agent_id: AgentId):
        """
        Ask one agent to act on its current observation; store the action and
        its log on this game (mutates the agent).
        """
        agent = self.agents[agent_id]
        obs = self.simulation.get_obs_agent(agent_id)
        action, action_info = await agent.act(observation=obs)
        self.actions[agent_id] = action
        self.agent_step_logs[agent_id] = action_info

    async def set_actions(self):
        """
        Ask all agents to act concurrently (see set_action_of_agent).
        """
        # background_tasks = set()
        tasks = []
        for agent_id in self.agent_ids:
            task = asyncio.create_task(self.set_action_of_agent(agent_id))
            tasks.append(task)
        await asyncio.gather(*tasks)

    def take_simulation_step(self):
        """
        Advance the simulation with the currently stored actions.

        Returns:
            True when the game has terminated.
        """
        terminated, self.simulation_step_log = self.simulation.step(self.actions)
        return terminated

    def get_step_log(self) -> StepLog:
        """
        Bundle the latest per-agent action logs and simulation log.
        TODO: assert actions and simulation have taken step
        """
        step_log = StepLog(
            simulation_step_log=self.simulation_step_log,
            action_logs=self.agent_step_logs,
        )
        return step_log

    async def step(self) -> Tuple[bool, StepLog]:
        """
        Full game step: collect all actions, transition, and log.

        Returns:
            (terminated, step_log)
        """
        await self.set_actions()
        terminated = self.take_simulation_step()
        step_log = self.get_step_log()
        return terminated, step_log

    def get_safe_copy(self):
        """
        Return a copy of this game that is decorrelated from the original:
        simulation, agents, and step data are all copied.
        """

        new_markov_game = copy.copy(self)
        new_simulation = self.simulation.get_safe_copy()
        new_agents = {
            agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items()
        }

        # Reassign copied components
        new_markov_game.simulation = new_simulation
        new_markov_game.agents = new_agents

        # IMPORTANT: ensure agent_ids references the new agents dict, not the original
        new_markov_game.agent_ids = new_markov_game.agents.keys()

        # Deep-copy step data to avoid correlation
        new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log)
        new_markov_game.actions = copy.deepcopy(self.actions)
        # Rebuild logs to align exactly with new agent ids
        old_agent_step_logs = copy.deepcopy(self.agent_step_logs)
        new_markov_game.agent_step_logs = {
            agent_id: old_agent_step_logs.get(agent_id)
            for agent_id in new_markov_game.agent_ids
        }

        return new_markov_game
src_code_for_reproducibility/markov_games/mg_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import copy
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass
5
+
6
+ from mllm.markov_games.ipd.ipd_agent import IPDAgent
7
+ from mllm.markov_games.ipd.ipd_simulation import IPD
8
+ from mllm.markov_games.markov_game import MarkovGame
9
+ from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
10
+ from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation
11
+ from mllm.markov_games.negotiation.nego_hard_coded_policies import (
12
+ HardCodedNegoGreedyPolicy,
13
+ HardCodedNegoWelfareMaximizingPolicy,
14
+ )
15
+ from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
16
+ from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
17
+ from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
18
+ from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
19
+ from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
20
+ from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
21
+ from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent
22
+ from mllm.markov_games.negotiation.tas_simple_simulation import (
23
+ TrustAndSplitSimpleSimulation,
24
+ )
25
+ from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
26
+ from mllm.markov_games.rollout_tree import (
27
+ AgentActLog,
28
+ RolloutTreeBranchNode,
29
+ RolloutTreeNode,
30
+ RolloutTreeRootNode,
31
+ StepLog,
32
+ )
33
+ from mllm.markov_games.simulation import SimulationStepLog
34
+
35
+ AgentId = str
36
+
37
+
38
@dataclass
class AgentConfig:
    """Declarative description of one agent in a Markov game."""

    agent_id: str  # unique id of the agent within the game
    agent_name: str  # human-readable name
    agent_class_name: str  # name of the Agent subclass to instantiate
    policy_id: str  # key into the policies mapping given at build time
    init_kwargs: dict  # extra keyword arguments forwarded to the agent class
45
+
46
+
47
@dataclass
class MarkovGameConfig:
    """Declarative description of a full Markov game instance."""

    id: int  # identifier of the game instance
    seed: int  # seed shared by agents and simulation (also used as crn_id)
    simulation_class_name: str  # name of the Simulation subclass to instantiate
    simulation_init_args: dict  # extra keyword arguments for the simulation
    agent_configs: list[AgentConfig]  # one config per participating agent
54
+
55
+
56
def init_markov_game_components(
    config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
):
    """
    Build a MarkovGame from its config: instantiate every agent with its
    policy, instantiate the simulation, and wire them together.

    Args:
        config: Declarative description of the game (simulation class,
            per-agent classes and init kwargs).
        policies: Mapping from policy id to a callable that maps a chat
            (list of message dicts) to a text completion.

    Returns:
        The assembled MarkovGame.
    """
    agents = {}
    agent_names = []
    for agent_config in config.agent_configs:
        agent_id = agent_config.agent_id
        agent_name = agent_config.agent_name
        # NOTE(review): eval() on a config-provided class name executes
        # arbitrary code if the config is untrusted — consider an explicit
        # class registry instead.
        agent_class = eval(agent_config.agent_class_name)
        agent = agent_class(
            seed=config.seed,
            agent_id=agent_id,
            agent_name=agent_name,
            policy=policies[agent_config.policy_id],
            **agent_config.init_kwargs,
        )
        agents[agent_id] = agent
        agent_names.append(agent_name)
    # The simulation class name is resolved the same way as the agent classes.
    simulation = eval(config.simulation_class_name)(
        seed=config.seed,
        agent_ids=list(agents.keys()),
        agent_names=agent_names,
        **config.simulation_init_args,
    )
    markov_game = MarkovGame(
        id=config.id,
        crn_id=config.seed,
        agents=agents,
        simulation=simulation,
    )
    return markov_game
src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc ADDED
Binary file (14.1 kB). View file
 
src_code_for_reproducibility/markov_games/rollout_tree.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: add parent to nodes so that some verification can be done. For instance, to ensure that node reward keys match the parent node.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, List, Literal, Optional, Tuple
11
+
12
+ import jsonschema
13
+ from pydantic import BaseModel, Field, model_validator
14
+
15
+ from mllm.chat_utils.chat_turn import ChatTurn
16
+
17
+ AgentId = str
18
+
19
+
20
class SimulationStepLog(BaseModel):
    """Per-step output of a Simulation: rewards for each agent plus free-form info."""

    rewards: dict[AgentId, float]  # reward obtained by each agent at this step
    info: Any = None  # simulation-specific extras (e.g. round markers)
23
+
24
+
25
class AgentActLog(BaseModel):
    """
    Log of a single agent's act() call: the chat turns used to produce the
    action, plus free-form agent info.
    """

    chat_turns: list[ChatTurn] | None
    info: Any = None

    @model_validator(mode="after")
    def _exactly_one_state_end(self):
        """
        This method is used to enforce that for each AgentActLog, there is exactly one ChatTurn which is a state end.
        """
        # Bug fix: the original compared `self.chat_turns != []`, which is
        # True for None (allowed by the type) and then crashed iterating it.
        # Skip validation for both None and empty lists.
        if not self.chat_turns:
            return self
        n = sum(1 for t in self.chat_turns if t.is_state_end)
        if n != 1:
            raise ValueError(
                f"AgentActLog must have exactly one ChatTurn with is_state_end=True; got {self.chat_turns}."
            )
        return self
43
+
44
+
45
class StepLog(BaseModel):
    """One full MarkovGame step: every agent's action log plus the simulation's log."""

    action_logs: dict[AgentId, AgentActLog]  # per-agent record of how the action was produced
    simulation_step_log: SimulationStepLog  # rewards and info from the state transition
48
+
49
+
50
+ # BranchType = Literal["unilateral_deviation", "common_deviation"] # might not be necessary
51
+ # class BranchNodeInfo(BaseModel):
52
+ # branch_id: str
53
+ # branch_for: AgentId
54
+ # branch_type: BranchType
55
+
56
+
57
class RolloutTreeNode(BaseModel):
    """Regular node of a rollout tree: one time step on a trajectory."""

    step_log: StepLog  # everything logged at this time step
    time_step: int  # index of this step on its trajectory
    child: RolloutTreeNode | RolloutTreeBranchNode | None = None  # next step, if any
61
+
62
+
63
class RolloutTreeBranchNode(BaseModel):
    """
    Branch point in a rollout tree. The main trajectory continues through
    `main_child`; `branches` maps the agent id that "called" for an
    alternative branch to the alternative continuation chains.
    (Docstring fix: there is no tuple here — branches is a dict of lists.)
    """

    main_child: RolloutTreeNode
    branches: dict[AgentId, list[RolloutTreeNode]] | None = None
70
+
71
+
72
class RolloutTreeRootNode(BaseModel):
    """Root of a rollout tree; holds identifiers and the first time step."""

    id: int  # identifier of the rollout / game instance
    crn_id: int  # ID of the rng used to generate this rollout tree
    child: RolloutTreeNode | RolloutTreeBranchNode | None = None  # first time step
    agent_ids: List[AgentId] = Field(min_length=1)  # all agents present in the game
77
+
78
+
79
+ # class RolloutTreeLeafNode(BaseModel):
80
+ # step_log: StepLog
81
+ # time_step: int
82
+
83
+
84
# Necessary for self-referential stuff in pydantic: the node models reference
# each other via forward references, which must be resolved explicitly.
RolloutTreeBranchNode.model_rebuild()
RolloutTreeNode.model_rebuild()
src_code_for_reproducibility/markov_games/run_markov_games.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass
4
+
5
+ from torch._C import ClassType
6
+
7
+ from mllm.markov_games.markov_game import MarkovGame
8
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode
9
+
10
+
11
async def run_markov_games(
    runner: Callable[[MarkovGame], RolloutTreeRootNode],
    runner_kwargs: dict,
    output_folder: str,
    markov_games: list[MarkovGame],
) -> list[RolloutTreeRootNode]:
    """
    Run every Markov game concurrently with the given runner.

    Args:
        runner: Async callable that rolls out one game into a rollout tree
            (e.g. LinearRunner or AlternativeActionsRunner).
        runner_kwargs: Extra keyword arguments forwarded to the runner.
        output_folder: Folder passed through to the runner.
        markov_games: Games to roll out.

    Returns:
        One rollout-tree root per game, in input order (asyncio.gather
        preserves order).
    """
    # Idiom fix: build the task list with a comprehension instead of a
    # manual append loop.
    tasks = [
        asyncio.create_task(
            runner(markov_game=mg, output_folder=output_folder, **runner_kwargs)
        )
        for mg in markov_games
    ]
    return await asyncio.gather(*tasks)
src_code_for_reproducibility/markov_games/simulation.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A Simulation is the environment of a Markov Game.
3
+ The Simulation is not responsible for properly checking / formatting the responses of LLM's.
4
+ This is the job of the `Agent` class.
5
+ Simulations expect clean actions, and are defined similarly to `gymnasium` environments, except that they are adapted for the Multi-agent setting.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Tuple
10
+
11
+ from numpy.random import default_rng
12
+
13
+ from mllm.markov_games.rollout_tree import SimulationStepLog
14
+
15
+
16
class Simulation(ABC):
    """
    Abstract multi-agent environment of a Markov Game.

    Simulations expect clean actions (validating/formatting LLM output is the
    Agent's job) and mirror the `gymnasium` API, adapted to the multi-agent
    setting.
    """

    @abstractmethod
    def __init__(self, seed: int, *args, **kwargs):
        # seed drives this simulation's private RNG stream.
        self.seed = seed
        self.rng = default_rng(self.seed)

    @abstractmethod
    def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
        """
        Returns terminated, info
        """
        raise NotImplementedError

    def get_obs(self):
        """Returns all agent observations in dict

        Returns:
            observations
        """
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id"""
        raise NotImplementedError

    def get_obs_size(self):
        """Returns the shape of the observation"""
        raise NotImplementedError

    def get_state(self):
        """Returns the full environment state."""
        raise NotImplementedError

    def get_state_size(self):
        """Returns the shape of the state"""
        raise NotImplementedError

    def get_avail_actions(self):
        """Returns the available actions of all agents."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id"""
        raise NotImplementedError

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take"""
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        raise NotImplementedError

    def get_safe_copy(self):
        """
        Return copy of the agent object that is decorrelated from the original object.
        """
        raise NotImplementedError

    def reset(self):
        """Returns initial observations and states"""
        raise NotImplementedError

    def render(self):
        """Renders the current state (optional for subclasses)."""
        raise NotImplementedError

    def close(self):
        """Releases any resources held by the simulation."""
        raise NotImplementedError

    # def seed(self):
    #     raise NotImplementedError

    def save_replay(self):
        """Saves a replay of the episode (optional for subclasses)."""
        raise NotImplementedError

    def get_simulation_info(self):
        """Returns simulation-specific metadata (optional for subclasses)."""
        raise NotImplementedError
src_code_for_reproducibility/markov_games/statistics_runner.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import json
5
+ import pickle
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional
9
+
10
+ from basic_render import find_iteration_folders
11
+
12
+ from mllm.markov_games.rollout_tree import (
13
+ RolloutTreeBranchNode,
14
+ RolloutTreeNode,
15
+ RolloutTreeRootNode,
16
+ SimulationStepLog,
17
+ )
18
+
19
+
20
+ def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]:
21
+ """
22
+ Iterate the main path nodes without materializing full path lists.
23
+ """
24
+ current = root.child
25
+ while current is not None:
26
+ if isinstance(current, RolloutTreeNode):
27
+ yield current
28
+ current = current.child
29
+ elif isinstance(current, RolloutTreeBranchNode):
30
+ # Follow only the main child on the main trajectory
31
+ current = current.main_child
32
+ else:
33
+ break
34
+
35
+
36
def iterate_main_simulation_logs(
    root: RolloutTreeRootNode,
) -> Iterator[SimulationStepLog]:
    """Yield the simulation step log of each node on the main trajectory."""
    yield from (
        node.step_log.simulation_step_log for node in _iterate_main_nodes(root)
    )
41
+
42
+
43
def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
    """Lazily yield every rollout-tree pickle (``*.rt.pkl``) under a folder."""
    yield from (
        candidate
        for candidate in iteration_folder.rglob("*.rt.pkl")
        if candidate.is_file()
    )
47
+
48
+
49
def load_root(path: Path) -> RolloutTreeRootNode:
    """Unpickle a rollout file and validate it into a root-node model."""
    # NOTE(review): pickle.load on untrusted files is unsafe — assumed
    # these rollouts are produced locally by this project.
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    return RolloutTreeRootNode.model_validate(payload)
53
+
54
+
55
@dataclass
class StatRecord:
    """One aggregated statistics record for a single rollout."""

    mgid: int  # rollout (markov game) identifier
    crn_id: Optional[int]  # common-random-numbers id, when present
    iteration: str  # name of the iteration folder the rollout came from
    values: Dict[str, Any]  # metric name -> aggregated value
61
+
62
+
63
class StatComputer:
    """Interface for stateful per-rollout statistic aggregation.

    An implementation consumes one ``SimulationStepLog`` at a time via
    ``update`` and reports the aggregated values for the whole rollout
    (one mgid) from ``finalize``.
    """

    def update(self, sl: SimulationStepLog) -> None:  # pragma: no cover - interface
        """Fold one simulation step into the running aggregate."""
        raise NotImplementedError

    def finalize(self) -> Dict[str, Any]:  # pragma: no cover - interface
        """Return the final aggregated values for the rollout."""
        raise NotImplementedError
74
+
75
+
76
def run_stats(
    data_root: Path,
    game_name: str,
    make_computers: Callable[[], List[StatComputer]],
    output_filename: Optional[str] = None,
    output_format: str = "json",  # "json" (dict of lists) or "jsonl"
) -> Path:
    """
    Compute stats across all iteration_* folders under data_root.
    Writes JSONL to data_root/statistics/<output_filename or f"{game_name}.stats.jsonl">.

    Args:
        data_root: Experiment directory containing the iteration folders.
        game_name: Used to derive the default output file name.
        make_computers: Factory returning a *fresh* list of StatComputer
            instances; computers are stateful, so one set is built per rollout.
        output_filename: Optional explicit output file name.
        output_format: "json" writes one dict-of-lists file (easier to plot);
            "jsonl" streams one record per rollout (lower memory).

    Returns:
        Path to the written statistics file.
    """
    data_root = Path(data_root)
    outdir = data_root / "statistics"
    outdir.mkdir(parents=True, exist_ok=True)
    # Choose extension by format
    default_name = (
        f"{game_name}.stats.json"
        if output_format == "json"
        else f"{game_name}.stats.jsonl"
    )
    outfile = outdir / (
        output_filename if output_filename is not None else default_name
    )

    # Rewrite file each run to keep it clean and small
    if outfile.exists():
        outfile.unlink()

    iteration_folders = find_iteration_folders(str(data_root))

    # If writing JSONL, stream directly; otherwise accumulate minimal records
    if output_format == "jsonl":
        with open(outfile, "w", encoding="utf-8") as w:
            for iteration_folder in iteration_folders:
                iteration_name = Path(iteration_folder).name
                for pkl_path in stream_rollout_files(Path(iteration_folder)):
                    root = load_root(pkl_path)

                    # Fresh computers per rollout: they accumulate state.
                    computers = make_computers()
                    for sl in iterate_main_simulation_logs(root):
                        for comp in computers:
                            try:
                                comp.update(sl)
                            except Exception:
                                # Best-effort: one failing computer must not
                                # abort the whole statistics pass.
                                continue

                    values: Dict[str, Any] = {}
                    for comp in computers:
                        try:
                            values.update(comp.finalize())
                        except Exception:
                            continue

                    rec = {
                        "mgid": getattr(root, "id", None),
                        "crn_id": getattr(root, "crn_id", None),
                        "iteration": iteration_name,
                        "stats": values,
                    }
                    w.write(json.dumps(rec, ensure_ascii=False) + "\n")

                    # Rollout trees can be large; free them eagerly.
                    del root
                    del computers
                    gc.collect()
    else:
        # Aggregate to dict-of-lists for easier plotting
        records: List[Dict[str, Any]] = []
        # Process in deterministic order
        for iteration_folder in iteration_folders:
            iteration_name = Path(iteration_folder).name
            for pkl_path in stream_rollout_files(Path(iteration_folder)):
                root = load_root(pkl_path)

                # Fresh computers per rollout: they accumulate state.
                computers = make_computers()
                for sl in iterate_main_simulation_logs(root):
                    for comp in computers:
                        try:
                            comp.update(sl)
                        except Exception:
                            continue

                values: Dict[str, Any] = {}
                for comp in computers:
                    try:
                        values.update(comp.finalize())
                    except Exception:
                        continue

                records.append(
                    {
                        "mgid": getattr(root, "id", None),
                        "crn_id": getattr(root, "crn_id", None),
                        "iteration": iteration_name,
                        "stats": values,
                    }
                )

                del root
                del computers
                gc.collect()

        # Build dict-of-lists with nested stats preserved
        # Collect all stat keys and nested agent keys where needed
        mgids: List[Any] = []
        crn_ids: List[Any] = []
        iterations_out: List[str] = []
        # stats_out is a nested structure mirroring keys but with lists
        stats_out: Dict[str, Any] = {}

        # First pass to collect union of keys
        stat_keys: set[str] = set()
        nested_agent_keys: Dict[str, set[str]] = {}
        for r in records:
            stats = r.get("stats", {}) or {}
            for k, v in stats.items():
                stat_keys.add(k)
                if isinstance(v, dict):
                    # Dict-valued stats are treated as per-agent breakdowns.
                    nested = nested_agent_keys.setdefault(k, set())
                    for ak in v.keys():
                        nested.add(str(ak))

        # Initialize structure
        for k in stat_keys:
            if k in nested_agent_keys:
                stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
            else:
                stats_out[k] = []

        # Fill lists
        for r in records:
            mgids.append(r.get("mgid"))
            crn_ids.append(r.get("crn_id"))
            iterations_out.append(r.get("iteration"))
            stats = r.get("stats", {}) or {}
            for k in stat_keys:
                val = stats.get(k)
                if isinstance(stats_out[k], dict):
                    # per-agent dict
                    # Missing agents are padded with None to keep columns aligned.
                    agent_dict = val if isinstance(val, dict) else {}
                    for ak in stats_out[k].keys():
                        stats_out[k][ak].append(agent_dict.get(ak))
                else:
                    stats_out[k].append(val)

        with open(outfile, "w", encoding="utf-8") as w:
            json.dump(
                {
                    "mgid": mgids,
                    "crn_id": crn_ids,
                    "iteration": iterations_out,
                    "stats": stats_out,
                },
                w,
                ensure_ascii=False,
            )

    return outfile
233
+
234
+
235
def run_stats_functional(
    data_root: Path,
    game_name: str,
    metrics: Dict[str, Callable[[SimulationStepLog], Optional[Dict[str, float]]]],
    output_filename: Optional[str] = None,
    output_format: str = "json",
) -> Path:
    """
    Functional variant where metrics is a dict of name -> f(SimulationStepLog) -> {agent_id: value}.
    Aggregates per rollout by averaging over steps where a metric produced a value.
    Writes a single consolidated file in data_root/statistics/.

    Args:
        data_root: Experiment directory containing the iteration folders.
        game_name: Used to derive the default output file name.
        metrics: Mapping of metric name to a per-step function; a function
            may return None (or raise) for steps where the metric is undefined.
        output_filename: Optional explicit output file name.
        output_format: "json" (dict of lists) or "jsonl" (one record/rollout).

    Returns:
        Path to the written statistics file.
    """
    data_root = Path(data_root)
    outdir = data_root / "statistics"
    outdir.mkdir(parents=True, exist_ok=True)
    default_name = (
        f"{game_name}.stats.json"
        if output_format == "json"
        else f"{game_name}.stats.jsonl"
    )
    outfile = outdir / (
        output_filename if output_filename is not None else default_name
    )

    # Rewrite the output file on every run.
    if outfile.exists():
        outfile.unlink()

    iteration_folders = find_iteration_folders(str(data_root))

    def finalize_rollout(
        agg: Dict[str, Dict[str, List[float]]]
    ) -> Dict[str, Dict[str, float]]:
        # avg per metric per agent
        # NOTE(review): emits None when a metric never fired for an agent,
        # despite the declared float value type — downstream consumers must
        # tolerate nulls.
        result: Dict[str, Dict[str, float]] = {}
        for mname, agent_values in agg.items():
            result[mname] = {}
            for aid, vals in agent_values.items():
                if not vals:
                    result[mname][aid] = None  # keep alignment; could be None
                else:
                    result[mname][aid] = sum(vals) / len(vals)
        return result

    if output_format == "jsonl":
        with open(outfile, "w", encoding="utf-8") as w:
            for iteration_folder in iteration_folders:
                iteration_name = Path(iteration_folder).name
                for pkl_path in stream_rollout_files(Path(iteration_folder)):
                    root = load_root(pkl_path)

                    # aggregator structure: metric -> agent_id -> list of values
                    agg: Dict[str, Dict[str, List[float]]] = {
                        m: {} for m in metrics.keys()
                    }

                    for sl in iterate_main_simulation_logs(root):
                        for mname, fn in metrics.items():
                            try:
                                vals = fn(sl)
                            except Exception:
                                # Best-effort: a failing metric is skipped
                                # for this step.
                                vals = None
                            if not vals:
                                continue
                            for aid, v in vals.items():
                                if v is None:
                                    continue
                                # Agent ids are normalized to strings for JSON keys.
                                lst = agg[mname].setdefault(str(aid), [])
                                try:
                                    lst.append(float(v))
                                except Exception:
                                    continue

                    values = finalize_rollout(agg)
                    rec = {
                        "mgid": getattr(root, "id", None),
                        "crn_id": getattr(root, "crn_id", None),
                        "iteration": iteration_name,
                        "stats": values,
                    }
                    w.write(json.dumps(rec, ensure_ascii=False) + "\n")

                    # Rollout trees can be large; free them eagerly.
                    del root
                    gc.collect()
    else:
        records: List[Dict[str, Any]] = []
        for iteration_folder in iteration_folders:
            iteration_name = Path(iteration_folder).name
            for pkl_path in stream_rollout_files(Path(iteration_folder)):
                root = load_root(pkl_path)

                agg: Dict[str, Dict[str, List[float]]] = {m: {} for m in metrics.keys()}
                for sl in iterate_main_simulation_logs(root):
                    for mname, fn in metrics.items():
                        try:
                            vals = fn(sl)
                        except Exception:
                            vals = None
                        if not vals:
                            continue
                        for aid, v in vals.items():
                            if v is None:
                                continue
                            lst = agg[mname].setdefault(str(aid), [])
                            try:
                                lst.append(float(v))
                            except Exception:
                                continue

                values = finalize_rollout(agg)
                records.append(
                    {
                        "mgid": getattr(root, "id", None),
                        "crn_id": getattr(root, "crn_id", None),
                        "iteration": iteration_name,
                        "stats": values,
                    }
                )

                del root
                gc.collect()

        # Build dict-of-lists output
        mgids: List[Any] = []
        crn_ids: List[Any] = []
        iterations_out: List[str] = []
        stats_out: Dict[str, Any] = {}

        # First pass: union of metric keys and per-agent sub-keys.
        stat_keys: set[str] = set()
        nested_agent_keys: Dict[str, set[str]] = {}
        for r in records:
            stats = r.get("stats", {}) or {}
            for k, v in stats.items():
                stat_keys.add(k)
                if isinstance(v, dict):
                    nested = nested_agent_keys.setdefault(k, set())
                    for ak in v.keys():
                        nested.add(str(ak))

        for k in stat_keys:
            if k in nested_agent_keys:
                stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
            else:
                stats_out[k] = []

        # Second pass: fill the aligned per-rollout columns.
        for r in records:
            mgids.append(r.get("mgid"))
            crn_ids.append(r.get("crn_id"))
            iterations_out.append(r.get("iteration"))
            stats = r.get("stats", {}) or {}
            for k in stat_keys:
                val = stats.get(k)
                if isinstance(stats_out[k], dict):
                    # Missing agents are padded with None to keep alignment.
                    agent_dict = val if isinstance(val, dict) else {}
                    for ak in stats_out[k].keys():
                        stats_out[k][ak].append(agent_dict.get(ak))
                else:
                    stats_out[k].append(val)

        with open(outfile, "w", encoding="utf-8") as w:
            json.dump(
                {
                    "mgid": mgids,
                    "crn_id": crn_ids,
                    "iteration": iterations_out,
                    "stats": stats_out,
                },
                w,
                ensure_ascii=False,
            )

    return outfile
src_code_for_reproducibility/markov_games/vine_ppo.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from anytree import Node, RenderTree
2
+ from anytree.exporter import DotExporter
3
+ import os.path
4
+ import asyncio
5
+ from mllm.markov_games.markov_game import MarkovGame
6
+
7
async def VinePPORunner(
    markov_game: MarkovGame,
    **kwargs):
    """Placeholder for a VinePPO-style rollout runner.

    Not implemented yet: accepts a markov game plus arbitrary keyword
    options and does nothing.
    """
    pass
src_code_for_reproducibility/models/__init__.py ADDED
File without changes
src_code_for_reproducibility/models/adapter_training_wrapper.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import logging
4
+ from typing import Union
5
+ from peft import (
6
+ LoraConfig,
7
+ get_peft_model,
8
+ )
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class AdapterWrapper(nn.Module):
    """
    A thin façade that
    • keeps a reference to a *shared* PEFT-wrapped model,
    • ensures `set_adapter(adapter)` is called on every forward,
    • exposes only the parameters that should be trained for that adapter
    (plus whatever extra modules you name).
    """
    def __init__(
        self,
        shared_llm: nn.Module,
        adapter_id: str,
        lora_config: dict,
        path: Union[str, None] = None,
    ):
        """Attach a named LoRA adapter to the shared base model.

        Args:
            shared_llm: Base model shared across all adapter wrappers.
            adapter_id: Name under which the adapter is registered in peft.
            lora_config: Keyword arguments for ``peft.LoraConfig``.
            path: Optional local path or HF Hub repo id with initial
                adapter weights; load failures are logged, not raised.
        """
        super().__init__()
        self.shared_llm = shared_llm
        self.adapter_id = adapter_id
        lora_config = LoraConfig(**lora_config)
        # this modifies the shared llm in place, adding a lora adapter inside
        self.shared_llm = get_peft_model(
            model=shared_llm,
            peft_config=lora_config,
            adapter_name=adapter_id,
        )
        self.shared_llm.train()
        # Load external adapter weights if provided
        loaded_from: str | None = None
        if path:
            try:
                # Supports both local filesystem paths and HF Hub repo IDs
                self.shared_llm.load_adapter(
                    is_trainable=True,
                    model_id=path,
                    adapter_name=adapter_id,
                )
                loaded_from = path
            except Exception as exc:  # noqa: BLE001 - want to log any load failure context
                logger.warning(
                    f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
                )

        if loaded_from:
            logger.info(
                f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'."
            )
        else:
            logger.info(
                f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)."
            )

    def parameters(self, recurse: bool = True):
        """
        "recurse" is just for pytorch compatibility
        """
        # Activate this wrapper's adapter first so requires_grad flags
        # reflect the adapter we intend to train.
        self.shared_llm.set_adapter(self.adapter_id)
        params = [p for p in self.shared_llm.parameters() if p.requires_grad]

        return params

    def get_base_model_logits(self, contexts):
        """
        Run the base model (without adapter) in inference mode, without tracking gradients.
        This is useful to get reference logits for KL-divergence computation.
        """
        with torch.no_grad():
            with self.shared_llm.disable_adapter():
                # Index [0] selects the logits from the model output tuple.
                return self.shared_llm(input_ids=contexts)[0]

    def forward(self, *args, **kwargs):
        # Select this wrapper's adapter on every call: the underlying peft
        # model is shared, so another wrapper may have switched adapters.
        self.shared_llm.set_adapter(self.adapter_id)
        return self.shared_llm(*args, **kwargs)

    def save_pretrained(self, save_path):
        # Delegates to peft; saves adapter weights/config under save_path.
        self.shared_llm.save_pretrained(save_path)

    def gradient_checkpointing_enable(self, *args, **kwargs):
        # Pass-through so trainers can treat the wrapper like the model.
        self.shared_llm.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        # Mirror the underlying model's dtype for caller convenience.
        return self.shared_llm.dtype

    @property
    def device(self):
        # Mirror the underlying model's device for caller convenience.
        return self.shared_llm.device
src_code_for_reproducibility/models/human_policy.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import re
4
+ import shutil
5
+ import sys
6
+ from typing import Callable, Dict, List, Optional
7
+
8
+ from mllm.markov_games.rollout_tree import ChatTurn
9
+
10
+ try:
11
+ import rstr # For generating example strings from regex
12
+ except Exception: # pragma: no cover
13
+ rstr = None
14
+
15
+
16
+ def _clear_terminal() -> None:
17
+ """
18
+ Clear the terminal screen in a cross-platform manner.
19
+ """
20
+ if sys.stdout.isatty():
21
+ os.system("cls" if os.name == "nt" else "clear")
22
+
23
+
24
+ def _terminal_width(default: int = 100) -> int:
25
+ try:
26
+ return shutil.get_terminal_size().columns
27
+ except Exception:
28
+ return default
29
+
30
+
31
def _horizontal_rule(char: str = "─") -> str:
    """Build a full-width separator line (never narrower than 20 chars)."""
    return char * max(20, _terminal_width() - 2)
34
+
35
+
36
class _Style:
    # ANSI colors (bright, readable)
    RESET = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    DIM = "\033[2m"
    # Foreground colors
    FG_BLUE = "\033[94m"  # user/system headers
    FG_GREEN = "\033[92m"  # human response header
    FG_YELLOW = "\033[93m"  # notices
    FG_RED = "\033[91m"  # errors
    FG_MAGENTA = "\033[95m"  # regex
    FG_CYAN = "\033[96m"  # tips
48
+
49
+
50
def _render_chat(state) -> str:
    """Render the prior conversation in a compact terminal format.

    Each entry in *state* is expected to expose ``role`` and ``content``
    attributes (e.g. ChatTurn objects).
    """
    out: List[str] = []
    out.append(_horizontal_rule())
    out.append(f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}")
    out.append(_horizontal_rule())
    for turn in state:
        role = turn.role
        body = str(turn.content).strip()
        # Map roles to colored banners with emojis; unknown roles get a
        # dim uppercase tag.
        if role == "assistant":
            banner = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑‍💻{_Style.RESET}"
        elif role == "user":
            banner = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}"
        else:
            banner = f"[{_Style.DIM}{role.upper()}{_Style.RESET}]"
        out.append(banner)
        # Indent every content line; empty content still renders one line.
        for body_line in body.splitlines() or [""]:
            out.append(f"  {body_line}")
        out.append("")
    out.append(_horizontal_rule())
    return "\n".join(out)
77
+
78
+
79
async def _async_input(prompt_text: str) -> str:
    """Non-blocking input using a background thread.

    Delegates ``input`` to a worker thread so the event loop keeps
    running while waiting for the user to type.
    """
    return await asyncio.to_thread(input, prompt_text)
82
+
83
+
84
def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]:
    """Produce a short random sample string matching *regex*, or None.

    Makes up to 20 random attempts and returns the first candidate no
    longer than *max_len*. Returns None when the optional ``rstr``
    dependency is missing, when generation raises, or when every attempt
    came out too long (truncating could break the match, so we don't).
    """
    if rstr is None:
        return None
    try:
        for _attempt in range(20):
            sample = rstr.xeger(regex)
            if len(sample) <= max_len:
                return sample
    except Exception:
        return None
    return None
100
+
101
+
102
+ def _detect_input_type(regex: str | None) -> tuple[str, str, str]:
103
+ """
104
+ Detect what type of input is expected based on the regex pattern.
105
+ Returns (input_type, start_tag, end_tag)
106
+ """
107
+ if regex is None:
108
+ return "text", "", ""
109
+
110
+ if "message_start" in regex and "message_end" in regex:
111
+ return "message", "<<message_start>>", "<<message_end>>"
112
+ elif "proposal_start" in regex and "proposal_end" in regex:
113
+ return "proposal", "<<proposal_start>>", "<<proposal_end>>"
114
+ else:
115
+ return "text", "", ""
116
+
117
+
118
async def human_policy(state, agent_id, regex: str | None = None) -> str:
    """
    Async human-in-the-loop policy.

    - Displays prior conversation context in the terminal.
    - Prompts the user for a response.
    - If a regex is provided, validates and re-prompts until it matches.
    - Automatically adds formatting tags based on expected input type.

    Args:
        state: Chat history; each turn exposes ``role`` and ``content``.
        agent_id: Agent id recorded on the returned ChatTurn.
        regex: Optional fullmatch validation pattern.

    Returns:
        A ChatTurn carrying the user's validated (and tag-wrapped) response.
    """
    # Detect input type and formatting
    input_type, start_tag, end_tag = _detect_input_type(regex)

    # Loop until a valid response is produced; commands restart the loop.
    while True:
        _clear_terminal()
        print(_render_chat(state))

        if regex:
            example = _short_regex_example(regex, max_len=30)
            print(
                f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}"
            )
            print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
            if example:
                print(
                    f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}"
                )
            print(_horizontal_rule("."))

            # Custom prompt based on input type
            if input_type == "message":
                print(
                    f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}"
                )
            elif input_type == "proposal":
                print(
                    f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}"
                )
            else:
                print(
                    f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}"
                )

            print(
                f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}"
            )
        else:
            print(
                f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}"
            )

        user_in = (await _async_input("> ")).rstrip("\n")

        # Commands
        if user_in.strip().lower() in {"/help", "/h"}:
            print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}")
            print(
                f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help"
            )
            print(
                f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt"
            )
            print(
                f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)"
            )
            # Brief pause so the help stays visible before re-rendering.
            await asyncio.sleep(1.0)
            continue
        if user_in.strip().lower() in {"/refresh", "/r"}:
            continue
        if user_in.strip().lower() in {"/quit", "/q"}:
            raise KeyboardInterrupt("Human aborted run from human_policy")

        # Add formatting tags if needed
        if start_tag and end_tag:
            formatted_input = f"{start_tag}{user_in}{end_tag}"
        else:
            formatted_input = user_in

        # No regex: accept anything the user typed.
        if regex is None:
            return ChatTurn(
                role="assistant", agent_id=agent_id, content=formatted_input
            )

        # Validate against regex (fullmatch)
        try:
            pattern = re.compile(regex)
        except re.error as e:
            # If regex is invalid, fall back to accepting any input
            print(
                f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {e}. Accepting input without validation."
            )
            await asyncio.sleep(0.5)
            return ChatTurn(
                role="assistant", agent_id=agent_id, content=formatted_input
            )

        if pattern.fullmatch(formatted_input):
            return ChatTurn(
                role="assistant", agent_id=agent_id, content=formatted_input
            )

        # Show validation error and re-prompt
        print("")
        print(
            f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again."
        )

        if input_type == "message":
            print(
                f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
            )
            print(f"Just type the message content without tags.")
        elif input_type == "proposal":
            print(
                f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
            )
            print(f"Just type the number without tags.")
        else:
            print(f"Expected (regex):")
            print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")

        print(_horizontal_rule("."))
        print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}")
        await _async_input("")
248
+
249
+
250
def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]:
    """Expose the human policy in the standard policy-map shape.

    The declared value type is a sync callable for compatibility with the
    rest of the policy maps, but the returned value is the async
    ``human_policy`` coroutine function — hence the type-ignore.
    """
    return {"human_policy": human_policy}  # type: ignore[return-value]
src_code_for_reproducibility/models/inference_backend.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Any, Optional
4
+
5
+
6
@dataclass
class LLMInferenceOutput:
    """Container for the result of one LLM generation call."""

    content: str  # final generated text
    reasoning_content: str | None = None  # optional extracted "thinking" text
    log_probs: list[float] | None = None  # per-token log-probs, when requested
    out_token_ids: list[int] | None = None  # generated token ids, when requested
12
+
13
+
14
class LLMInferenceBackend(ABC):
    """Abstract interface that every LLM inference backend implements."""

    @abstractmethod
    def __init__(self, **kwargs):
        ...

    @abstractmethod
    def prepare_adapter(
        self, adapter_id: str, weights_got_updated: bool = False
    ) -> None:
        """Ensure adapter is ready/loaded for next generation call."""

    @abstractmethod
    async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> str:
        """Generate a completion for *prompt*, optionally regex-constrained."""
        ...

    @abstractmethod
    def toggle_training_mode(self) -> None:
        """Switch the backend into training mode."""
        ...

    @abstractmethod
    def toggle_eval_mode(self) -> None:
        """Switch the backend into evaluation/inference mode."""
        ...

    @abstractmethod
    def shutdown(self) -> None:
        """Tear the backend down and release its resources."""
        ...
src_code_for_reproducibility/models/inference_backend_dummy.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Optional
3
+
4
+ import rstr
5
+ from transformers import AutoTokenizer
6
+
7
+ from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
8
+ from mllm.utils.short_id_gen import generate_short_id
9
+
10
+
11
class DummyInferenceBackend(LLMInferenceBackend):
    """No-op backend for tests/dry runs: returns canned or regex-shaped text."""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Accepts and ignores all constructor options for drop-in parity.
        pass

    def prepare_adapter(
        self,
        adapter_id: Optional[str],
        weights_got_updated: bool,
        adapter_path: Optional[str] = None,
    ) -> None:
        # Nothing to prepare: there is no real model behind this backend.
        pass

    # NOTE(review): the base class declares toggle_training_mode /
    # toggle_eval_mode as sync; here they are async — callers must await
    # them. Confirm this mismatch is intentional.
    async def toggle_training_mode(self) -> None:
        await asyncio.sleep(0)
        pass

    async def toggle_eval_mode(self) -> None:
        await asyncio.sleep(0)
        pass

    def shutdown(self) -> None:
        pass

    async def generate(
        self,
        prompt_text: str,
        regex: Optional[str] = None,
        extract_thinking: bool = False,
    ) -> LLMInferenceOutput:
        """Return a fixed reply, or a random string matching *regex*."""
        if regex:
            # Create random string that respects the regex
            return LLMInferenceOutput(
                content=rstr.xeger(regex),
                reasoning_content="I don't think, I am a dummy backend.",
            )
        else:
            return LLMInferenceOutput(
                content="I am a dummy backend without a regex.",
                reasoning_content="I don't think, I am a dummy backend.",
            )
src_code_for_reproducibility/models/inference_backend_sglang.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # new_backend_sglang_offline.py
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ from typing import Any, Optional
6
+
7
+ # import sglang as sgl
8
+
9
+ from mllm.models.inference_backend import LLMInferenceBackend
10
+
11
+
12
class SGLangOfflineBackend(LLMInferenceBackend):
    """In-process (offline) SGLang engine backend with optional multi-LoRA.

    Wraps ``sglang.Engine`` directly — no HTTP server is launched.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer,  # unused but kept for parity
        adapter_paths: dict[str, str],
        device: str = "cuda",
        max_model_len: Optional[int] = None,
        enable_lora: bool = True,
        lora_target_modules: Optional[list[str] | str] = None,
        max_loras_per_batch: int = 8,
        engine_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Build the offline engine.

        Args:
            model_name: HF model path/id handed to ``sgl.Engine``.
            tokenizer: Unused; kept so all backends share a signature.
            adapter_paths: Logical adapter name -> LoRA weights path.
            device: Kept for parity (the engine picks devices itself).
            max_model_len: Optional context-length override.
            enable_lora: Whether to register the LoRA adapters.
            lora_target_modules: Optional LoRA target-module override.
            max_loras_per_batch: Cap on distinct adapters per batch.
            engine_kwargs: Extra ``sgl.Engine`` constructor options;
                explicit arguments only fill keys not already present.
        """
        # BUGFIX: the module-level `import sglang as sgl` was commented out,
        # so instantiation raised NameError. Import lazily here instead, so
        # the module stays importable without sglang installed.
        import sglang as sgl

        self.model_name = model_name
        self.adapter_paths = adapter_paths
        self.current_adapter: Optional[str] = None
        engine_kwargs = dict(engine_kwargs or {})
        # Map server-style LoRA flags to offline engine ctor
        if enable_lora and adapter_paths:
            engine_kwargs.setdefault("enable_lora", True)
            # The offline Engine mirrors server args; pass a mapping name->path
            engine_kwargs.setdefault("lora_paths", adapter_paths)
            if lora_target_modules is not None:
                engine_kwargs.setdefault("lora_target_modules", lora_target_modules)
            engine_kwargs.setdefault("max_loras_per_batch", max_loras_per_batch)

        if max_model_len is not None:
            engine_kwargs.setdefault("context_length", max_model_len)

        # Launch in-process engine (no HTTP server)
        self.llm = sgl.Engine(model_path=model_name, **engine_kwargs)  # async-ready
        # SGLang supports: generate(), async_generate(), and async streaming helpers.

    def is_ready(self) -> bool:
        """The in-process engine is usable as soon as the ctor returns."""
        return True

    def toggle_training_mode(self) -> None:
        # No explicit KV release API offline; typically you pause usage here.
        pass

    def toggle_eval_mode(self) -> None:
        pass

    def shutdown(self) -> None:
        # Engine cleans up on GC; explicit close not required.
        pass

    def prepare_adapter(self, adapter_id: Optional[str]) -> None:
        # With offline Engine, when LoRA is enabled at init,
        # you select adapter per request via the input batch mapping.
        self.current_adapter = adapter_id

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: Optional[str]
    ) -> str:
        """Generate one completion, optionally through a named LoRA adapter.

        Non-streaming async call on a batch of one; pass a list of
        prompts for batched generation.
        """
        params = {
            "temperature": sampling_params.get("temperature", 1.0),
            "top_p": sampling_params.get("top_p", 1.0),
            "max_new_tokens": sampling_params.get("max_new_tokens", 128),
        }
        # Only forward optional knobs the caller actually set.
        if (tk := sampling_params.get("top_k", -1)) and tk > 0:
            params["top_k"] = tk
        if (mn := sampling_params.get("min_new_tokens")) is not None:
            params["min_new_tokens"] = mn
        if (fp := sampling_params.get("frequency_penalty")) is not None:
            params["frequency_penalty"] = fp

        # If using multi-LoRA, SGLang lets you provide adapter names aligned
        # to each input (None selects the base model).
        prompts = [prompt_text]
        adapters = [adapter_id] if adapter_id else None  # or omit for base
        outs = await self.llm.async_generate(prompts, params, adapters)
        return outs[0]["text"]
src_code_for_reproducibility/models/inference_backend_sglang_local_server.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import httpx
4
+ import requests
5
+ from sglang.utils import launch_server_cmd, wait_for_server
6
+
7
+ from mllm.models.inference_backend import LLMInferenceBackend
8
+
9
+
10
class HttpSGLangBackend(LLMInferenceBackend):
    """Inference backend backed by a locally spawned SGLang HTTP server.

    ``launch`` starts the server subprocess via ``sglang.launch_server``;
    generation, LoRA (un)loading and KV-cache memory management all go
    through the server's HTTP endpoints recorded in ``self.urls``.
    """

    def __init__(self, **kwargs):
        # NOTE(review): assumes the base class populates self.adapter_paths
        # and self.model_name from kwargs — confirm against LLMInferenceBackend.
        super().__init__(**kwargs)
        self.port = None   # assigned by launch_server_cmd() in launch()
        self.proc = None   # handle of the spawned server process
        self.urls = {}     # endpoint name -> full URL, filled in launch()
        # track sglang adapter ids separately from your logical ids
        self.sglang_names = {aid: aid for aid in self.adapter_paths.keys()}
        self.needs_loading = {aid: True for aid in self.adapter_paths.keys()}

        # defaults you already used:
        self.mem_fraction = kwargs.get("mem_fraction_static", 0.6)
        self.dtype = kwargs.get("dtype", "bfloat16")
        self.extra_cli = kwargs.get("extra_cli", "")
        self.disable_radix_cache = kwargs.get("disable_radix_cache", True)

    def launch(self) -> None:
        """Spawn the SGLang server subprocess and record its endpoint URLs."""
        # find local hf cache path for server
        from transformers.utils import cached_file

        local_llm_path = os.path.split(cached_file(self.model_name, "config.json"))[0]

        lora_str = ""
        if self.adapter_paths:
            # Preload every known adapter at server start: --lora-paths id=path ...
            lora_str = "--lora-paths " + " ".join(
                f"{aid}={path}" for aid, path in self.adapter_paths.items()
            )

        cmd = f"""
python3 -m sglang.launch_server --model-path {local_llm_path} \
--host 0.0.0.0 {lora_str} \
{'--disable-radix-cache' if self.disable_radix_cache else ''} \
--mem-fraction-static {self.mem_fraction} --dtype {self.dtype} {self.extra_cli}
"""
        self.proc, self.port = launch_server_cmd(cmd)
        wait_for_server(f"http://localhost:{self.port}")
        base = f"http://localhost:{self.port}"
        self.urls = dict(
            generate=f"{base}/generate",
            release=f"{base}/release_memory_occupation",
            resume=f"{base}/resume_memory_occupation",
            load_lora=f"{base}/load_lora_adapter",
            unload_lora=f"{base}/unload_lora_adapter",
        )

    def is_ready(self) -> bool:
        """Best-effort health probe: True iff the generate endpoint answers."""
        try:
            requests.get(self.urls["generate"], timeout=2)
            return True
        except Exception:
            return False

    def prepare_adapter(self, adapter_id: str) -> None:
        """Ensure the adapter's latest weights are loaded on the server.

        When flagged stale in ``needs_loading``, the old server-side copy is
        unloaded and the weights are re-registered under a fresh random name,
        so the server cannot keep serving a cached copy under the old name.
        """
        if adapter_id is None:
            return
        if self.needs_loading.get(adapter_id, False):
            # unload old name if present
            try:
                requests.post(
                    self.urls["unload_lora"],
                    json={"lora_name": self.sglang_names[adapter_id]},
                    timeout=10,
                )
            except Exception:
                # Old name may simply not exist on the server yet; best-effort.
                pass
            new_name = self._short_id()
            self.sglang_names[adapter_id] = new_name
            requests.post(
                self.urls["load_lora"],
                json={
                    "lora_name": new_name,
                    "lora_path": self.adapter_paths[adapter_id],
                },
            ).raise_for_status()
            self.needs_loading[adapter_id] = False

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: str | None
    ) -> str:
        """Generate one completion for a single prompt via /generate.

        The request is sent as a batch of one; ``lora_path`` entries are
        aligned index-wise with the ``text`` batch.
        """
        lora_name = self.sglang_names.get(adapter_id) if adapter_id else None
        payload = {
            "text": [prompt_text],
            "sampling_params": sampling_params,
        }
        if lora_name:
            payload["lora_path"] = [lora_name]

        # Very generous timeout: long generations can take a while.
        timeout = httpx.Timeout(3600.0, connect=3600.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(self.urls["generate"], json=payload)
            resp.raise_for_status()
            return resp.json()[0]["text"]

    def toggle_training_mode(self) -> None:
        # free KV space while training adapters
        requests.post(
            self.urls["release"], json={"tags": ["kv_cache"]}
        ).raise_for_status()

    def toggle_eval_mode(self) -> None:
        # re-allocate KV space
        try:
            requests.post(
                self.urls["resume"], json={"tags": ["kv_cache"]}
            ).raise_for_status()
        except Exception:
            # Resume can fail if memory was never released; treat as non-fatal.
            pass

    def shutdown(self) -> None:
        """Terminate the spawned SGLang server process, if any."""
        from sglang.utils import terminate_process

        if self.proc:
            terminate_process(self.proc)

    def _short_id(self) -> str:
        """Return a random 8-character numeric id (used as a fresh LoRA name)."""
        import uuid

        return str(uuid.uuid4().int)[:8]
src_code_for_reproducibility/models/inference_backend_vllm.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import re
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
8
+ from vllm.inputs import TokensPrompt
9
+ from vllm.lora.request import LoRARequest
10
+ from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind
11
+
12
+ from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
13
+ from mllm.utils.short_id_gen import generate_short_id
14
+
15
+
16
class VLLMAsyncBackend(LLMInferenceBackend):
    """In-process async vLLM backend with optional per-request LoRA adapters.

    Wraps ``AsyncLLMEngine`` and exposes a ``generate`` that takes
    pre-tokenized prompts and returns text, per-token log-probs and token ids.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer: AutoTokenizer,
        # adapter_paths: dict[str, str],
        engine_init_kwargs: Optional[dict] = None,
        sampling_params: Optional[dict] = None,
    ):
        """
        Args:
            model_name: HF model name or local path handed to vLLM.
            tokenizer: Tokenizer matching the model (kept for callers).
            engine_init_kwargs: Extra ``AsyncEngineArgs`` fields.
            sampling_params: Default ``SamplingParams`` kwargs for generate().
        """
        self.model_name = model_name
        # Maps logical adapter_id -> short version id; a fresh id forces vLLM
        # to reload the adapter weights instead of serving a cached copy.
        self.vllm_adapter_ids: dict = {}
        # Fix: initialize so generate() before prepare_adapter() runs on the
        # base model instead of raising AttributeError.
        self.current_lora_request = None
        # Fix: mutable default args replaced by None-and-copy.
        ea = dict(model=model_name, **(engine_init_kwargs or {}))
        self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))

        self.sampling_params = dict(sampling_params or {})
        self.tokenizer = tokenizer

    def prepare_adapter(
        self,
        adapter_id: Optional[str],
        adapter_path: Optional[str],
        weights_got_updated: bool,
    ) -> None:
        """Select the LoRA adapter used by subsequent generate() calls.

        A new vLLM-side id is minted when the weights changed on disk (or the
        adapter has never been seen), so vLLM reloads them.
        """
        if adapter_id is None:
            # No adapter: fall back to the base model.
            self.current_lora_request = None
            return
        # Fix: also mint an id on first sight, avoiding a KeyError when
        # weights_got_updated is False for a never-registered adapter.
        if weights_got_updated or adapter_id not in self.vllm_adapter_ids:
            self.vllm_adapter_ids[adapter_id] = generate_short_id()
        self.current_lora_request = LoRARequest(
            adapter_id,
            self.vllm_adapter_ids[adapter_id],
            adapter_path,
        )

    async def toggle_training_mode(self) -> None:
        """Release GPU memory (KV cache) while adapters are being trained."""
        await self.engine.sleep(level=1)

    async def toggle_eval_mode(self) -> None:
        """Re-acquire GPU memory after training."""
        await self.engine.wake_up()

    def shutdown(self) -> None:
        # No explicit close call; the engine stops when the process exits.
        pass

    async def generate(
        self,
        input_token_ids: list[int],
        regex: Optional[str] = None,
        extract_thinking: bool = False,
    ) -> LLMInferenceOutput:
        """Generate a completion for a pre-tokenized prompt.

        Args:
            input_token_ids: Prompt token ids (already chat-templated).
            regex: Optional regex used for guided decoding.
            extract_thinking: If True, split a leading ``<think>...</think>``
                section into ``reasoning_content``.

        Returns:
            LLMInferenceOutput with content, optional reasoning, per-token
            log-probs and output token ids.

        Raises:
            RuntimeError: If the engine yields no output for the request.
        """
        import uuid

        guided = GuidedDecodingParams(regex=regex) if regex else None
        sp = SamplingParams(
            **self.sampling_params,
            guided_decoding=guided,
            output_kind=RequestOutputKind.FINAL_ONLY,
        )

        prompt = TokensPrompt(prompt_token_ids=input_token_ids)
        # Fix: loop.time() can return the same value for two near-simultaneous
        # requests; use a uuid so concurrent request ids never collide.
        request_id = f"req-{uuid.uuid4().hex}"
        result_generator = self.engine.generate(
            prompt,
            sp,
            request_id,
            lora_request=self.current_lora_request,
        )

        res = None
        async for out in result_generator:  # with FINAL_ONLY this yields once
            res = out
        if res is None:
            # Fix: `res` was previously unbound if the generator yielded nothing.
            raise RuntimeError(f"vLLM produced no output for request {request_id}")

        raw_text = res.outputs[0].text
        out_token_ids = res.outputs[0].token_ids
        log_probs = torch.tensor(
            [
                logprob_dict[token_id].logprob
                for token_id, logprob_dict in zip(
                    out_token_ids, res.outputs[0].logprobs
                )
            ]
        )
        out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)

        content = raw_text
        reasoning_content = None
        if extract_thinking:
            # Expected layout: "\n<think>\n<reasoning></think>\n\n<answer>"
            m = re.match(
                r"^\n<think>\n([\s\S]*?)</think>\n\n(.*)$", raw_text, flags=re.DOTALL
            )
            if m:
                reasoning_content = m.group(1)
                content = m.group(2)
        return LLMInferenceOutput(
            content=content,
            reasoning_content=reasoning_content,
            log_probs=log_probs,
            out_token_ids=out_token_ids,
        )
+ )
src_code_for_reproducibility/models/inference_backend_vllm_local_server.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import logging
import os
import subprocess
import time

import httpx
import requests

from mllm.models.inference_backend import LLMInferenceBackend
10
+
11
+
12
class HttpVLLMBackend(LLMInferenceBackend):
    """Backend that serves the model through a local vLLM OpenAI-compatible server.

    ``launch`` spawns ``vllm.entrypoints.openai.api_server`` as a subprocess;
    ``generate`` talks to ``/v1/chat/completions``. LoRA adapters are either
    preloaded at server start (``lora_mode="preload"``) or loaded at runtime
    through the server's LoRA endpoints (``lora_mode="runtime"``).
    """

    def __init__(self, **kwargs):
        # NOTE(review): assumes the base class populates self.adapter_paths
        # and self.model_name from kwargs — confirm against LLMInferenceBackend.
        super().__init__(**kwargs)
        self.port = kwargs.get("port", 8000)
        self.host = kwargs.get("host", "0.0.0.0")
        self.proc = None
        self.base_url = f"http://{self.host}:{self.port}"
        # vLLM memory safety knobs
        self.gpu_mem_util = kwargs.get("gpu_memory_utilization", 0.9)
        self.max_model_len = kwargs.get("max_model_len", None)
        self.max_num_seqs = kwargs.get("max_num_seqs", None)
        self.max_batched_tokens = kwargs.get("max_num_batched_tokens", None)
        self.dtype = kwargs.get("dtype", "bfloat16")
        self.trust_remote_code = kwargs.get("trust_remote_code", False)
        # LoRA strategy: "preload" (CLI) or "runtime" (endpoints) depending on your vLLM build
        self.lora_mode = kwargs.get(
            "lora_mode", "preload"
        )  # "runtime" supported in newer builds
        self.runtime_lora_enabled = self.lora_mode == "runtime"

        # If preloading: build CLI args (adapter name -> path)
        self._preload_lora_args = []
        if self.adapter_paths and self.lora_mode == "preload":
            # vLLM supports multiple LoRA modules via CLI in recent versions:
            # --lora-modules adapter_id=path (flag shape can vary per version)
            for aid, pth in self.adapter_paths.items():
                self._preload_lora_args += ["--lora-modules", f"{aid}={pth}"]

    def launch(self):
        """Spawn the vLLM API server and block until /v1/models answers."""
        cmd = [
            "python3",
            "-m",
            "vllm.entrypoints.openai.api_server",
            "--model",
            self.model_name,
            "--host",
            self.host,
            "--port",
            str(self.port),
            "--dtype",
            self.dtype,
            "--gpu-memory-utilization",
            str(self.gpu_mem_util),
        ]
        if self.trust_remote_code:
            cmd += ["--trust-remote-code"]
        if self.max_model_len:
            cmd += ["--max-model-len", str(self.max_model_len)]
        if self.max_num_seqs:
            cmd += ["--max-num-seqs", str(self.max_num_seqs)]
        if self.max_batched_tokens:
            cmd += ["--max-num-batched-tokens", str(self.max_batched_tokens)]
        cmd += self._preload_lora_args

        # Fix: the previous stdout=subprocess.PIPE was never drained, so the
        # server could block once the OS pipe buffer filled. Discard output
        # instead (switch to stdout=None to inherit the console for debugging).
        self.proc = subprocess.Popen(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
        )
        try:
            self._wait_ready()
        except Exception:
            # Fix: don't leave an orphaned GPU-holding server on startup failure.
            self.shutdown()
            raise

    def _wait_ready(self, timeout=120):
        """Poll /v1/models until it returns 200 or ``timeout`` seconds elapse.

        Raises:
            RuntimeError: If the server never becomes ready.
        """
        url = f"{self.base_url}/v1/models"
        t0 = time.time()
        while time.time() - t0 < timeout:
            try:
                r = requests.get(url, timeout=2)
                if r.status_code == 200:
                    return
            except Exception:
                # Server not up yet; keep polling.
                pass
            time.sleep(1)
        raise RuntimeError("vLLM server did not become ready in time")

    def is_ready(self) -> bool:
        """Best-effort health probe against /v1/models."""
        try:
            return (
                requests.get(f"{self.base_url}/v1/models", timeout=2).status_code == 200
            )
        except Exception:
            return False

    def prepare_adapter(self, adapter_id: str) -> None:
        """Runtime-load a LoRA adapter on the server (no-op in preload mode)."""
        if not adapter_id or not self.runtime_lora_enabled:
            return
        # Newer vLLM builds expose runtime LoRA endpoints. If yours differs,
        # adjust the path/body here and keep the interface stable.
        try:
            requests.post(
                f"{self.base_url}/v1/load_lora_adapter",
                json={
                    "adapter_name": adapter_id,
                    "adapter_path": self.adapter_paths[adapter_id],
                },
                timeout=10,
            ).raise_for_status()
        except Exception:
            # Best-effort: the adapter may already be loaded or the endpoint
            # may be absent on this build. Fix: log instead of swallowing
            # silently so real failures remain diagnosable.
            logging.getLogger(__name__).debug(
                "load_lora_adapter for %r failed (possibly already loaded)",
                adapter_id,
                exc_info=True,
            )

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: str | None
    ) -> str:
        """Generate one chat completion for ``prompt_text``.

        Maps the project's sampling params onto the OpenAI schema; non-standard
        knobs (top_k, min_tokens, LoRA selection) travel in ``extra_body``.
        """
        body = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt_text}],
            "temperature": sampling_params.get("temperature", 1.0),
            "top_p": sampling_params.get("top_p", 1.0),
            "max_tokens": sampling_params.get("max_new_tokens", 128),
        }
        # Optional knobs:
        if sampling_params.get("top_k", -1) and sampling_params["top_k"] > 0:
            # vLLM accepts top_k via extra params; put under "extra_body"
            body.setdefault("extra_body", {})["top_k"] = sampling_params["top_k"]
        if sampling_params.get("min_new_tokens", None) is not None:
            body.setdefault("extra_body", {})["min_tokens"] = sampling_params[
                "min_new_tokens"
            ]
        if sampling_params.get("frequency_penalty", None) is not None:
            body["frequency_penalty"] = sampling_params["frequency_penalty"]

        # Select LoRA adapter: both preload and runtime modes select by name.
        if adapter_id:
            body.setdefault("extra_body", {})["lora_adapter"] = adapter_id

        url = f"{self.base_url}/v1/chat/completions"
        # Very generous timeout: long generations can take a while.
        timeout = httpx.Timeout(3600.0, connect=3600.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, json=body)
            resp.raise_for_status()
            data = resp.json()
            return data["choices"][0]["message"]["content"]

    def toggle_training_mode(self) -> None:
        # vLLM doesn't expose an explicit KV "release" toggle via API.
        # Strategy: keep inference server idle during training, or run training in a separate process.
        pass

    def toggle_eval_mode(self) -> None:
        pass

    def shutdown(self) -> None:
        """Terminate the server subprocess and reap it (no-op if never launched)."""
        if self.proc:
            self.proc.terminate()
            try:
                # Fix: wait() so the child is reaped; escalate if it hangs.
                self.proc.wait(timeout=10)
            except Exception:
                self.proc.kill()
            self.proc = None
src_code_for_reproducibility/models/large_language_model_api.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import copy
5
+ import os
6
+ import random
7
+ import re
8
+ from typing import Any, Callable, Dict, List, Optional, Sequence
9
+
10
+ import backoff
11
+ from openai import AsyncOpenAI, OpenAIError
12
+
13
+ from mllm.markov_games.rollout_tree import ChatTurn
14
+ from mllm.models.inference_backend import LLMInferenceOutput
15
+
16
+ # TODO: Get this automatically from OpenAI
17
+ reasoning_models = [
18
+ "gpt-5-nano",
19
+ "gpt-5-mini",
20
+ "gpt-5",
21
+ "o1-mini",
22
+ "o1",
23
+ "o1-pro",
24
+ "o3-mini",
25
+ "o3",
26
+ "o3-pro",
27
+ "o4-mini",
28
+ "o4",
29
+ "o4-pro",
30
+ ]
31
+
32
+
33
class LargeLanguageModelOpenAI:
    """Tiny async wrapper for the OpenAI Responses API.

    Exposes the same policy/adapter interface as the local LLM wrappers
    (adapter/training hooks are no-ops here) so it can be dropped into the
    same rollout code.
    """

    def __init__(
        self,
        llm_id: str = "",
        model: str = "gpt-4.1-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout_s: float = 300.0,
        regex_max_attempts: int = 10,
        sampling_params: Optional[Dict[str, Any]] = None,
        init_kwargs: Optional[Dict[str, Any]] = None,
        output_directory: Optional[str] = None,
    ) -> None:
        """
        Args:
            llm_id: Logical id under which this model's policy is registered.
            model: OpenAI model name; reasoning models get a reasoning config.
            api_key: Key; falls back to the OPENAI_API_KEY env var.
            base_url: Optional alternate API endpoint.
            timeout_s: Per-request client timeout in seconds.
            regex_max_attempts: Max retries for regex-constrained outputs.
            sampling_params: Extra kwargs forwarded to responses.create().

        Raises:
            RuntimeError: If no API key is available.
        """
        self.llm_id = llm_id
        self.model = model
        key = api_key or os.getenv("OPENAI_API_KEY")
        if not key:
            raise RuntimeError(
                "Set OPENAI_API_KEY as global environment variable or pass api_key."
            )
        client_kwargs: Dict[str, Any] = {"api_key": key, "timeout": timeout_s}
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = AsyncOpenAI(**client_kwargs)

        # Fix: copy so the caller's dict is never mutated, and tolerate the
        # None default (the original crashed on `self.sampling_params[...]`
        # and `**self.sampling_params` when sampling_params was None).
        self.sampling_params: Dict[str, Any] = dict(sampling_params or {})
        self.use_reasoning = model in reasoning_models
        if self.use_reasoning:
            self.sampling_params["reasoning"] = {
                "effort": "low",
                "summary": "detailed",
            }
        self.regex_max_attempts = max(1, int(regex_max_attempts))

    def get_inference_policies(self) -> Dict[str, Callable]:
        """Return {llm_id: async policy callable}, mirroring local wrappers."""
        return {
            self.llm_id: self.get_action,
        }

    async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None:
        # No local adapters for an API model; async no-op for interface parity.
        await asyncio.sleep(0)

    async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None:
        await asyncio.sleep(0)

    async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None:
        await asyncio.sleep(0)

    async def export_adapters(self, *args: Any, **kwargs: Any) -> None:
        await asyncio.sleep(0)

    async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None:
        await asyncio.sleep(0)

    def extract_output_from_response(self, resp: Any) -> LLMInferenceOutput:
        """Convert a Responses-API object into an LLMInferenceOutput.

        Fix: the parameter was annotated with the undefined name ``Response``;
        ``Any`` keeps the signature honest without a new import.
        With reasoning models the first output item carries the reasoning
        summary and the second the answer; otherwise the first item is the answer.
        """
        if len(resp.output) > 1:
            summary = resp.output[0].summary
            if summary != []:
                reasoning_content = summary[0].text
                reasoning_content = f"OpenAI Reasoning Summary: {reasoning_content}"
            else:
                reasoning_content = None
            content = resp.output[1].content[0].text
        else:
            reasoning_content = None
            content = resp.output[0].content[0].text

        return LLMInferenceOutput(
            content=content,
            reasoning_content=reasoning_content,
        )

    @backoff.on_exception(
        backoff.expo, Exception, max_time=10**10, max_tries=10**10
    )
    async def get_action(
        self,
        state: list[ChatTurn],
        agent_id: str,
        regex: Optional[str] = None,
    ) -> LLMInferenceOutput:
        """Query the model for the next turn.

        Args:
            state: Chat history; only role/content are forwarded to OpenAI.
            agent_id: Unused here; kept for policy-interface parity.
            regex: Optional client-side output constraint. The model is primed
                with the pattern and retried up to ``regex_max_attempts`` times;
                the last attempt is returned even if it does not match.
        """
        # Keep only role/content keys — extra keys make the API error out.
        prompt = [{"role": p.role, "content": p.content} for p in state]

        # If regex is required, prime the model and validate client-side.
        if regex:
            constraint_msg = {
                "role": "user",
                "content": (
                    f"Output must match this regex exactly: {regex} \n"
                    "Return only the matching string, with no quotes or extra text."
                ),
            }
            prompt = [constraint_msg, *prompt]
            pattern = re.compile(regex)
            for _ in range(self.regex_max_attempts):
                resp = await self.client.responses.create(
                    model=self.model,
                    input=prompt,
                    **self.sampling_params,
                )
                policy_output = self.extract_output_from_response(resp)
                if pattern.fullmatch(policy_output.content):
                    return policy_output
                # Append a corrective message and retry.
                prompt = [
                    *prompt,
                    {
                        "role": "user",
                        "content": (
                            f"Invalid response format. Expected format (regex): {regex}\n Please try again and provide ONLY a response that matches this regex."
                        ),
                    },
                ]
            # regex_max_attempts >= 1, so policy_output is always bound here.
            return policy_output

        # Simple, unconstrained generation.
        resp = await self.client.responses.create(
            model=self.model,
            input=prompt,
            **self.sampling_params,
        )
        policy_output = self.extract_output_from_response(resp)
        return policy_output

    def shutdown(self) -> None:
        """Drop the client reference; the SDK cleans up on GC."""
        self.client = None
src_code_for_reproducibility/models/large_language_model_local.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: Figure out how to tweak SGlang not to go OOM when batch size is 32. See https://github.com/sgl-project/sglang/issues/6309.
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import re
8
+ import sys
9
+ import uuid
10
+ from collections.abc import Callable
11
+ from copy import deepcopy
12
+ from datetime import datetime
13
+ from typing import Literal
14
+
15
+ import httpx
16
+ import requests
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ # from sglang.utils import (
21
+ # launch_server_cmd,
22
+ # print_highlight,
23
+ # terminate_process,
24
+ # wait_for_server,
25
+ # )
26
+ from torch.optim import SGD, Adam, AdamW, RMSprop
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer
28
+ from trl import AutoModelForCausalLMWithValueHead
29
+
30
+ from mllm.chat_utils.apply_template import chat_turns_to_token_ids
31
+ from mllm.markov_games.rollout_tree import ChatTurn
32
+ from mllm.models.adapter_training_wrapper import AdapterWrapper
33
+ from mllm.models.inference_backend import LLMInferenceOutput
34
+ from mllm.models.inference_backend_dummy import DummyInferenceBackend
35
+ from mllm.models.inference_backend_sglang import SGLangOfflineBackend
36
+ from mllm.models.inference_backend_vllm import VLLMAsyncBackend
37
+
38
+ logger = logging.getLogger(__name__)
39
+ logger.addHandler(logging.StreamHandler(sys.stdout))
40
+
41
+ AdapterID = str
42
+ PolicyID = str
43
+
44
+
45
class LeanLocalLLM:
    """
    Manages one shared HuggingFace base model with multiple PEFT/LoRA adapters,
    plus an inference backend (vLLM / SGLang / dummy) that serves them.

    Responsibilities:
      * load or initialize each configured adapter (output folder first, then
        user-provided initial paths, else from scratch),
      * discover previously checkpointed "buffer" agent adapters,
      * expose one async inference policy per adapter,
      * keep the inference backend in sync with on-disk weight updates,
      * export and checkpoint adapter weights.
    """

    def __init__(
        self,
        llm_id: str = "base_llm",
        model_name: str = "Qwen/Qwen3-4B-Instruct-2507",
        device: str = "cuda",
        hf_kwargs: dict = {},
        adapter_configs: dict = {},
        output_directory: str = "./models/",
        inference_backend: Literal["vllm", "sglang", "dummy"] = "vllm",
        inference_backend_sampling_params: dict = {},
        inference_backend_init_kwargs: dict = {},
        initial_adapter_paths: dict[str, str] | None = None,
        initial_buffer_paths: list[str] | None = None,
        enable_thinking: bool = None,
        regex_max_attempts: int = -1,
        max_thinking_characters: int = 0,
    ):
        # NOTE(review): mutable default arguments ({}) are shared across calls;
        # harmless only as long as they are never mutated — confirm.
        self.inference_backend_name = inference_backend
        self.output_directory = output_directory
        self.llm_id = llm_id
        self.device = torch.device(device) if device else torch.device("cuda")
        self.model_name = model_name
        self.adapter_configs = adapter_configs
        self.adapter_ids = list(adapter_configs.keys())
        self.enable_thinking = enable_thinking
        # -1 means: pass the regex straight to guided decoding instead of
        # retrying client-side (see get_action).
        self.regex_max_attempts = regex_max_attempts
        self.initial_buffer_paths = initial_buffer_paths
        self.max_thinking_characters = max_thinking_characters
        self.regex_retries_count = 0

        # Optional user-specified initial adapter weight locations (local or HF Hub)
        # Format: {adapter_id: path_or_repo_id}
        self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths

        # Path management / imports
        self.save_path = str(os.path.join(output_directory, model_name, "adapters"))
        self.adapter_paths = {
            adapter_id: os.path.join(self.save_path, adapter_id)
            for adapter_id in self.adapter_ids
        }
        # Discover previously checkpointed agent adapters ("frozen opponents"):
        # each checkpoint dir is registered under "<dirname>_buffer".
        checkpoints_dir = os.path.join(self.output_directory, "checkpoints")
        self.past_agent_adapter_paths = {}
        if os.path.isdir(checkpoints_dir):
            for dirname in os.listdir(checkpoints_dir):
                dirpath = os.path.join(checkpoints_dir, dirname)
                if os.path.isdir(dirpath):
                    self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join(
                        dirpath, "agent_adapter"
                    )
            logger.info(
                f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory."
            )
        # Also pull buffer adapters from user-specified directories, if any.
        if self.initial_buffer_paths is not None:
            previous_count = len(self.past_agent_adapter_paths)
            for path in self.initial_buffer_paths:
                if os.path.isdir(path):
                    for dirname in os.listdir(path):
                        dirpath = os.path.join(path, dirname)
                        if os.path.isdir(dirpath):
                            self.past_agent_adapter_paths[
                                f"{dirname}_buffer"
                            ] = os.path.join(dirpath, "agent_adapter")
                else:
                    logger.warning(
                        f"Initial buffer path {path} does not exist or is not a directory."
                    )
            logger.info(
                f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths."
            )
        self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys())

        # ID management for tracking adapter versions
        self.adapter_train_ids = {
            adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids
        }
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Setup padding token to be same as EOS token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Dirty flags: True means the on-disk weights changed since the
        # inference backend last loaded this adapter.
        self.weights_got_updated: dict[AdapterID, bool] = {
            adapter_id: False for adapter_id in self.adapter_ids
        }
        self.weights_got_updated.update(
            {adapter_id: False for adapter_id in self.past_agent_adapter_ids}
        )
        self.current_lora_request = None
        self.currently_loaded_adapter_id = None

        # ---------------------------------------------------------
        # Init HF model, peft adapters
        # ---------------------------------------------------------
        self.shared_hf_llm = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_name,
            **hf_kwargs,
        )
        self.hf_adapters = {}
        self.optimizers = {}
        for adapter_id in self.adapter_ids:
            # Prefer output-folder path if it exists; else fall back to user-specified initial path if provided
            output_path = os.path.join(self.save_path, adapter_id)
            chosen_path: str | None = None
            if os.path.isdir(output_path) and os.listdir(output_path):
                chosen_path = output_path
                logger.info(
                    f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'."
                )
            elif (
                self.initial_adapter_paths and adapter_id in self.initial_adapter_paths
            ):
                chosen_path = self.initial_adapter_paths[adapter_id]
                logger.info(
                    f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'."
                )
            else:
                logger.info(
                    f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch."
                )
            hf_adapter = AdapterWrapper(
                shared_llm=self.shared_hf_llm,
                adapter_id=adapter_id,
                lora_config=adapter_configs[adapter_id],
                path=chosen_path,
            ).to(device)
            self.hf_adapters[adapter_id] = hf_adapter
        # Persist current state of all adapters (ensures remote loads are cached to disk)
        self.export_adapters()

        # ---------------------------------------------------------
        # Init inference inference_backend
        # ---------------------------------------------------------

        if inference_backend == "sglang":
            self.inference_backend = SGLangOfflineBackend(
                model_name=self.model_name,
                save_path=self.save_path,
                adapter_paths=self.adapter_paths,
                tokenizer=self.tokenizer,
                kwargs=inference_backend_init_kwargs,
            )
        elif inference_backend == "vllm":
            self.inference_backend = VLLMAsyncBackend(
                model_name=self.model_name,
                # adapter_paths=self.adapter_paths,
                tokenizer=self.tokenizer,
                engine_init_kwargs=inference_backend_init_kwargs,
                sampling_params=inference_backend_sampling_params,
            )
        elif inference_backend == "dummy":
            self.inference_backend = DummyInferenceBackend()
        else:
            raise ValueError(f"Unknown inference_backend: {inference_backend}")

    def reset_regex_retries_count(self) -> None:
        """Reset the counter of regex-mismatch retries (e.g. per iteration)."""
        self.regex_retries_count = 0

    def get_inference_policies(self) -> dict[PolicyID, Callable]:
        """
        Return one async policy per adapter (current + past buffer adapters),
        keyed by "<llm_id>/<adapter_id>". Each policy loads its adapter into
        the backend before generating.
        """
        policies = {}
        for adapter_id in self.adapter_ids:
            # define policy func
            # `_adapter_id=adapter_id` binds the loop variable at definition
            # time (avoids the late-binding closure pitfall).
            async def policy(
                state: list[ChatTurn],
                agent_id: str,
                regex: str | None = None,
                _adapter_id=adapter_id,
            ):
                self.prepare_adapter_for_inference(adapter_id=_adapter_id)
                response = await self.get_action(state, agent_id, regex)
                return response

            policies[self.llm_id + "/" + adapter_id] = policy

        for adapter_id in self.past_agent_adapter_ids:
            # define policy func
            async def policy(
                state: list[ChatTurn],
                agent_id: str,
                regex: str | None = None,
                _adapter_id=adapter_id,
            ):
                self.prepare_adapter_for_inference(adapter_id=_adapter_id)
                response = await self.get_action(state, agent_id, regex)
                return response

            policies[self.llm_id + "/" + adapter_id] = policy
        return policies

    def get_adapter_modules(self) -> dict[PolicyID, nn.Module]:
        """
        Returns wrappers over the adapters which allows them be
        interfaced like regular PyTorch models.
        # TODO: create the adapter wrappers here
        See adapter_wrapper.py
        """
        trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
        return trainable_objects

    async def toggle_training_mode(self) -> None:
        """Switch the backend to training mode and mint new adapter version ids."""
        for adn in self.adapter_ids:
            self.adapter_train_ids[adn] = self.short_id_generator()
        await self.inference_backend.toggle_training_mode()

    async def toggle_eval_mode(self) -> None:
        """Switch the inference backend back to serving mode."""
        await self.inference_backend.toggle_eval_mode()

    def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None:
        """Point the backend at ``adapter_id``, reloading weights if stale."""
        self.inference_backend.prepare_adapter(
            adapter_id,
            # Current adapters live in adapter_paths; frozen opponents in
            # past_agent_adapter_paths.
            adapter_path=self.adapter_paths.get(
                adapter_id, self.past_agent_adapter_paths.get(adapter_id, None)
            ),
            weights_got_updated=self.weights_got_updated[adapter_id],
        )
        self.currently_loaded_adapter_id = adapter_id
        self.weights_got_updated[adapter_id] = False

    # def _make_prompt_text(self, prompt: list[dict]) -> str:
    #     if self.enable_thinking is not None:
    #         prompt_text = self.tokenizer.apply_chat_template(
    #             prompt,
    #             tokenize=False,
    #             add_generation_prompt=True,
    #             enable_thinking=self.enable_thinking,
    #         )
    #     else:
    #         prompt_text = self.tokenizer.apply_chat_template(
    #             prompt,
    #             tokenize=False,
    #             add_generation_prompt=True,
    #         )

    #     return prompt_text

    async def get_action(
        self, state: list[ChatTurn], agent_id: str, regex: str | None = None
    ) -> ChatTurn:
        """Generate the next assistant turn for the given chat state.

        Regex handling: with regex_max_attempts == -1 the regex is enforced by
        the backend (guided decoding). Otherwise we sample freely, validate
        client-side, and only fall back to guided decoding after the retry
        budget is exhausted.
        """
        current_regex = regex if self.regex_max_attempts == -1 else None
        pattern = re.compile(regex) if regex else None
        nb_attempts = 0
        state = state[:]  # shallow copy so callers' list is never mutated
        while True:
            context_token_ids = chat_turns_to_token_ids(
                chats=state,
                tokenizer=self.tokenizer,
                enable_thinking=self.enable_thinking,
            )
            # print(f"context is {self.tokenizer.decode(context_token_ids)}")
            policy_output = await self.inference_backend.generate(
                input_token_ids=context_token_ids.tolist(),
                extract_thinking=(self.max_thinking_characters > 0),
                regex=current_regex,
            )
            # print(f"generated: {self.tokenizer.decode(policy_output.out_token_ids)}")
            # Accept when no regex was requested, the output matches, or the
            # retry budget is spent (best effort: return the last attempt).
            if (
                pattern is None
                or (pattern.fullmatch(policy_output.content))
                or (nb_attempts >= self.regex_max_attempts)
            ):
                return ChatTurn(
                    agent_id=agent_id,
                    role="assistant",
                    content=policy_output.content,
                    reasoning_content=policy_output.reasoning_content,
                    out_token_ids=policy_output.out_token_ids,
                    log_probs=policy_output.log_probs,
                    is_state_end=False,
                )
            else:
                self.regex_retries_count += 1
                nb_attempts += 1
                logger.warning(
                    f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}"
                )
                # Last free retry failed: enforce the regex via guided decoding.
                if nb_attempts == self.regex_max_attempts:
                    current_regex = regex
                # regex_prompt = ChatTurn(
                #     role="user",
                #     content=f"Invalid response format. Expected format (regex): {current_regex}\n Please try again and provide ONLY a response that matches this regex.",
                #     reasoning_content=None,
                #     log_probs=None,
                #     out_token_ids=None,
                #     is_state_end=False,
                # )
                # state.append(regex_prompt)

    def export_adapters(self) -> None:
        """
        Any peft wrapper, by default, saves all adapters, not just the one currently loaded.
        """

        # New version of the adapters available
        for adapter_id in self.adapter_ids:
            self.weights_got_updated[adapter_id] = True
        for adapter_id in self.past_agent_adapter_ids:
            self.weights_got_updated[adapter_id] = True

        # import random
        # self.save_path = self.save_path + str(random.randint(1,500))
        # print(f"Save path: {self.save_path}")
        # self.adapter_paths = {adapter_id:os.path.join(self.save_path, adapter_id) for adapter_id in self.adapter_ids}

        # Saving through any one wrapper persists every adapter (see docstring).
        adapter_id = self.adapter_ids[0]
        self.hf_adapters[adapter_id].save_pretrained(self.save_path)

    def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None:
        """
        Checkpoints all adapters to the configured output directory.
        """
        adapter_id = self.adapter_ids[0]
        output_dir = os.path.join(self.output_directory, "checkpoints")
        os.makedirs(output_dir, exist_ok=True)
        date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}"
        export_path = os.path.join(output_dir, agent_adapter_dir)
        for adapter_id in self.adapter_ids:
            # Agent adapters are also registered as frozen "buffer" opponents.
            if "agent" in adapter_id:
                self.past_agent_adapter_paths[
                    f"{agent_adapter_dir}_buffer"
                ] = os.path.join(export_path, adapter_id)
                self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer")
                self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False
        self.hf_adapters[adapter_id].save_pretrained(export_path)

    def short_id_generator(self) -> str:
        """
        Generates a short unique ID for tracking adapter versions.

        Returns:
            str: An 8-digit numeric string.
        """
        return str(uuid.uuid4().int)[:8]
src_code_for_reproducibility/models/scalar_critic.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torch.nn as nn, torch.optim as optim
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import LoraConfig, get_peft_model
4
+
5
+ from mllm.models.adapter_training_wrapper import AdapterWrapper
6
+
7
+
8
class ScalarCritic(nn.Module):
    """
    A causal-LM critic_adapter + a scalar value head:
    V_φ(s) = wᵀ h_last + b
    Only LoRA adapters (inside critic_adapter) and the value head are trainable.
    """
    def __init__(self, critic_adapter: AdapterWrapper):
        """
        Args:
            critic_adapter (AdapterWrapper): Wrapper around the shared LLM whose
                last hidden states feed the value head. The head is created on
                the adapter's dtype/device.
        """
        super().__init__()
        self.critic_adapter = critic_adapter
        # Backbone hidden width -> input size of the scalar value head.
        hidden_size = self.critic_adapter.shared_llm.config.hidden_size
        self.value_head = nn.Linear(hidden_size, 1).to(
            dtype=critic_adapter.dtype,
            device=critic_adapter.device)

    def forward(self,
                input_ids,
                attention_mask=None,
                **kwargs):
        """
        Compute per-token value estimates.

        Returns:
            Tensor of shape (B, S): one scalar value per input position,
            obtained by projecting the last hidden layer.
        """
        # AdapterWrapper activates its own adapter internally
        outputs = self.critic_adapter(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # required so hidden_states is populated below
            **kwargs,
        )
        h_last = outputs.hidden_states[-1]  # (B, S, H)
        values = self.value_head(h_last).squeeze(-1)  # (B, S)
        return values

    def parameters(self, recurse: bool = True):
        """Iterator over *trainable* parameters for this critic.

        NOTE: ``recurse`` is accepted for nn.Module API compatibility but is
        not used; the adapter wrapper and the value head are always iterated.
        """
        # 1) LoRA params for *this* adapter
        for p in self.critic_adapter.parameters():
            yield p
        # 2) scalar head
        yield from self.value_head.parameters()

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward gradient-checkpointing activation to the wrapped adapter."""
        self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        # Mirror the adapter's dtype so callers can treat the critic uniformly.
        return self.critic_adapter.dtype

    @property
    def device(self):
        # Mirror the adapter's device so callers can treat the critic uniformly.
        return self.critic_adapter.device
src_code_for_reproducibility/training/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Suppose we have a trajectory with 3 timesteps.
2
+ token: "0 1 2 3 4 5 6 7 8 9 . . . . ."
3
+ string: "A B C a b c A a A a b c A B C" (Capitalized = User, Lowercased = Assistant)
4
+ action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x" (x = False, ✓ = True)
5
+ rewards: "r r r r r r R R R R R R r r r"
6
+ timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2"
7
+ state_ends: "x x ✓ x x x ✓ x x x x x x x ✓"
8
+
9
+ There must be one baseline flag per timestep!
10
+
11
+ Then, we might have
12
+
13
+ A naive way to interpret this is to think of the number of assistant messages as the number of
14
+ steps in the environment. However, this is not the case in practice. Indeed, in a
15
+ single simulation step,
16
+
17
+
18
+
19
+
20
+ A subtlety arises with credit assignment. In the multi-agent case, we might
src_code_for_reproducibility/training/__init__.py ADDED
File without changes
src_code_for_reproducibility/training/annealing_methods.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def sigmoid_annealing(step: int, temperature: float) -> float:
    """
    Sigmoid-shaped annealing coefficient.

    Maps ``step`` >= 0 into [0, 1): 0 at step 0, approaching 1 as ``step``
    grows; ``temperature`` controls how quickly the curve saturates.
    """
    logistic_denominator = 1 + np.exp(-step / temperature)
    return 2 / logistic_denominator - 1
6
+
src_code_for_reproducibility/training/credit_methods.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor:
    """
    Normalize advantages to zero mean and (approximately) unit variance.

    A small epsilon keeps the division finite when all entries are equal.
    """
    mean = torch.mean(advantages)
    std = torch.std(advantages)
    return (advantages - mean) / (std + 1e-9)
12
+
13
+
14
def whiten_advantages_time_step_wise(
    advantages: torch.Tensor,  # (B, T)
) -> torch.Tensor:
    """
    Whiten advantages independently at every time step: each column is
    normalized across the batch dimension to zero mean and ~unit variance.
    """
    assert advantages.dim() == 2, "Wrong dimensions."
    column_mean = advantages.mean(dim=0, keepdim=True)
    column_std = advantages.std(dim=0, keepdim=True)
    return (advantages - column_mean) / (column_std + 1e-9)
25
+
26
+
27
def get_discounted_state_visitation_credits(
    credits: torch.Tensor, discount_factor: float  # (B, T)
) -> torch.Tensor:
    """
    Scale the credit at time step t by ``discount_factor ** t`` (column-wise).
    """
    time_indices = torch.arange(credits.shape[1], device=credits.device)
    return credits * (discount_factor**time_indices)
36
+
37
+
38
def get_discounted_returns(
    rewards: torch.Tensor,  # (B, T)
    discount_factor: float,
) -> torch.Tensor:
    """
    Monte Carlo discounted returns: G_t = r_t + discount_factor * G_{t+1}.

    Args:
        rewards (torch.Tensor): Per-timestep rewards of shape (B, T).
        discount_factor (float): Discount applied to future rewards.

    Returns:
        torch.Tensor: Discounted returns of shape (B, T).
    """
    assert rewards.dim() == 2, "Wrong dimensions."
    batch_size, horizon = rewards.shape
    returns = torch.zeros_like(rewards)
    running = torch.zeros(batch_size, device=rewards.device, dtype=rewards.dtype)
    # Sweep right-to-left so each column sees the already-discounted suffix.
    for step in range(horizon - 1, -1, -1):
        running = rewards[:, step] + discount_factor * running
        returns[:, step] = running
    return returns
59
+
60
+
61
def get_rloo_credits(credits: torch.Tensor):  # (B, S)
    """
    Leave-one-out (RLOO) baseline subtraction across the batch dimension.

    For each sample, the baseline at every position is the mean of all
    *other* samples' credits at that position. With a single sample the
    credits are returned untouched alongside an all-zero baseline.

    Returns:
        tuple: (centered credits, baselines), both of shape (B, S).
    """
    assert credits.dim() == 2, "Wrong dimensions."
    batch_size = credits.shape[0]
    if batch_size == 1:
        return credits, torch.zeros_like(credits)
    totals = torch.sum(credits, dim=0, keepdim=True)
    baselines = (totals - credits) / (batch_size - 1)
    return credits - baselines, baselines
70
+
71
+
72
def get_generalized_advantage_estimates(
    rewards: torch.Tensor,  # (B, T)
    value_estimates: torch.Tensor,  # (B, T+1)
    discount_factor: float,
    lambda_coef: float,
) -> torch.Tensor:
    """
    Generalized Advantage Estimation (GAE), https://arxiv.org/pdf/1506.02438.

    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lambda * A_{t+1}

    Returns:
        torch.Tensor: GAE values of shape (B, T).
    """
    assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions."

    assert (
        rewards.shape[0] == value_estimates.shape[0]
    ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."
    assert (
        rewards.shape[1] == value_estimates.shape[1] - 1
    ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates."

    horizon = rewards.shape[1]
    # One-step temporal-difference residuals for every timestep.
    deltas = (
        rewards + discount_factor * value_estimates[:, 1:] - value_estimates[:, :-1]
    )
    gaes = torch.zeros_like(deltas)
    trailing = 0.0
    # Backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}.
    for step in range(horizon - 1, -1, -1):
        trailing = deltas[:, step] + lambda_coef * discount_factor * trailing
        gaes[:, step] = trailing
    return gaes
103
+
104
+
105
def get_advantage_alignment_weights(
    advantages: torch.Tensor,  # (B, T)
    exclude_k_equals_t: bool,
    gamma: float,
    discount_t: bool = False,
) -> torch.Tensor:
    r"""
    The advantage alignment credit is calculated as

    \[
    A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot
    \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right)
    A^2(s_t, a_t, b_t)
    \]

    Here, the weights are defined as
    \( \beta \cdot \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right) \).

    Args:
        advantages: Per-step advantages of shape (B, T).
        exclude_k_equals_t: If True, the sum runs over k < t (strict);
            otherwise over k <= t.
        gamma: Discount applied as gamma^{t-k} to past advantages.
        discount_t: If True, additionally replace the k == t contribution by
            its gamma^t-discounted version. Defaults to False so existing
            call sites that omit the argument keep working.

    Returns:
        torch.Tensor: Alignment weights of shape (B, T).
    """
    T = advantages.shape[1]
    # gamma^{-k} * A_k: undo the per-step discount so the cumulative sum via
    # the upper-triangular matmul can be re-discounted with one gamma^t factor.
    discounted_advantages = advantages * (
        gamma * torch.ones((1, T), device=advantages.device)
    ) ** (-torch.arange(0, T, 1, device=advantages.device))
    if exclude_k_equals_t:
        sub = torch.eye(T, device=advantages.device)
    else:
        sub = torch.zeros((T, T), device=advantages.device)
    # Identity is for \( k < t \), remove for \( k \leq t \)
    ad_align_weights = discounted_advantages @ (
        torch.triu(torch.ones((T, T), device=advantages.device)) - sub
    )
    t_discounts = (gamma * torch.ones((1, T), device=advantages.device)) ** (
        torch.arange(0, T, 1, device=advantages.device)
    )
    ad_align_weights = t_discounts * ad_align_weights
    if discount_t:
        # Swap the undiscounted k == t term for its gamma^t-discounted version.
        time_discounted_advantages = advantages * (
            gamma * torch.ones((1, T), device=advantages.device)
        ) ** (torch.arange(0, T, 1, device=advantages.device))
        ad_align_weights = ad_align_weights - advantages + time_discounted_advantages
    return ad_align_weights
145
+
146
+
147
def get_advantage_alignment_credits(
    a1: torch.Tensor,  # (B, S)
    a1_alternative: torch.Tensor | None,  # (B, S, A)
    a2: torch.Tensor,  # (B, S)
    exclude_k_equals_t: bool,
    beta: float,
    gamma: float = 1.0,
    use_old_ad_align: bool = False,
    use_sign: bool = False,
    clipping: float | None = None,
    use_time_regularization: bool = False,
    force_coop_first_step: bool = False,
    use_variance_regularization: bool = False,
    rloo_branch: bool = False,
    reuse_baseline: bool = False,
    mean_normalize_ad_align: bool = False,
    whiten_adalign_advantages: bool = False,
    whiten_adalign_advantages_time_step_wise: bool = False,
    discount_t: bool = False,
) -> tuple[torch.Tensor, dict]:
    r"""
    Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662.

    Recall that the advantage opponent shaping term of the AdAlign policy gradient is:
    \[
    \beta \mathbb{E}_{\substack{
    \tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\
    a_t' \sim \pi^1(\cdot \mid s_t)
    }}
    \left[\sum_{t=0}^\infty \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right]
    \]

    This method computes the following:
    \[
    Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right]
    \]

    Args:
        a1: Advantages of the main trajectories for the current agent, (B, S).
        a1_alternative: Advantages of the alternative trajectories for the
            current agent, (B, S, A); required unless ``use_old_ad_align``.
        a2: Advantages of the main trajectories for the other agent, (B, S).
        exclude_k_equals_t: Sum over k < t (strict) instead of k <= t.
        beta: Strength of the opponent-shaping term.
        gamma: Discount used inside the alignment weights.
        use_old_ad_align: Build weights from ``a1`` itself instead of the
            alternative-action advantages.
        use_sign: Replace alignment weights by their sign (requires beta == 1).
        clipping: If nonzero, clip alignment weights to [-clipping, clipping].
        use_time_regularization: Divide weights at step t by (t + 1).
        force_coop_first_step: Force the weight at t = 0 to 1.
        use_variance_regularization: Currently unused (kept for interface
            compatibility; the corresponding code path is commented out).
        rloo_branch: Include ``a1`` itself among the alternatives before
            averaging.
        reuse_baseline: Reuse the RLOO baseline computed from ``a1`` for the
            alternatives instead of computing their own.
        mean_normalize_ad_align: Subtract the batch mean from the credits.
        whiten_adalign_advantages: Whiten the credits globally.
        whiten_adalign_advantages_time_step_wise: Whiten per time step.
        discount_t: Forwarded to ``get_advantage_alignment_weights``.

    Returns:
        tuple[torch.Tensor, dict]: The (B, S) credits and a dict of named
        intermediate tensors for logging/debugging.
    """

    assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)"
    if a1_alternative is not None:
        assert (
            a1_alternative.dim() == 3
        ), "Alternative advantages must be of shape (B, S, A)"
        B, T, A = a1_alternative.shape
    else:
        B, T = a1.shape
    assert a1.shape == a2.shape, "Not the same shape"

    sub_tensors = {}

    if use_old_ad_align:
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
        sub_tensors["ad_align_weights_prev"] = ad_align_weights
        if exclude_k_equals_t:
            ad_align_weights = gamma * ad_align_weights
    else:
        assert a1_alternative is not None, "Alternative advantages must be provided"
        if rloo_branch:
            # Fold the realized advantage in with the alternatives before averaging.
            a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2)
        a1_alternative = a1_alternative.mean(dim=2)
        a1, baseline = get_rloo_credits(a1)
        if reuse_baseline:
            a1_alternative = a1_alternative - baseline
        else:
            a1_alternative, _ = get_rloo_credits(a1_alternative)
        assert a1.shape == a1_alternative.shape, "Not the same shape"
        # BUGFIX: discount_t has no default in get_advantage_alignment_weights;
        # this call previously omitted it and raised a TypeError.
        ad_align_weights = get_advantage_alignment_weights(
            advantages=a1_alternative,
            exclude_k_equals_t=exclude_k_equals_t,
            gamma=gamma,
            discount_t=discount_t,
        )
        sub_tensors["ad_align_weights"] = ad_align_weights

    # Use sign
    if use_sign:
        assert beta == 1.0, "beta should be 1.0 when using sign"
        positive_signs = ad_align_weights > 0
        negative_signs = ad_align_weights < 0
        ad_align_weights[positive_signs] = 1
        ad_align_weights[negative_signs] = -1
        sub_tensors["ad_align_weights_sign"] = ad_align_weights
        # (rest are 0)

    ###################
    # Process weights
    ###################

    # Use clipping
    if clipping not in [0.0, None]:
        # BUGFIX: count saturation against the actual clip threshold and use
        # numel() — `.size` is a method on tensors, so the old expression
        # raised a TypeError.
        upper_mask = ad_align_weights > clipping
        lower_mask = ad_align_weights < -clipping

        ad_align_weights = torch.clip(
            ad_align_weights,
            -clipping,
            clipping,
        )
        # Fraction of weights clipped; diagnostic only, not exported yet.
        clipping_ratio = (
            torch.sum(upper_mask) + torch.sum(lower_mask)
        ) / upper_mask.numel()
        sub_tensors["clipped_ad_align_weights"] = ad_align_weights

    # 1/1+t Regularization
    if use_time_regularization:
        t_values = torch.arange(1, T + 1).to(ad_align_weights.device)
        ad_align_weights = ad_align_weights / t_values
        sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights

    # Use coop on t=0
    if force_coop_first_step:
        ad_align_weights[:, 0] = 1
        sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights
    # # Normalize alignment terms (across same time step)
    # if use_variance_regularization_in_ad_align:
    #     # TODO: verify
    #     reg_coef = torch.std(a1[:, -1]) / (torch.std(opp_shaping_terms[:, -1]) + 1e-9)
    #     opp_shaping_terms *= reg_coef

    ####################################
    # Compose elements together
    ####################################

    opp_shaping_terms = beta * ad_align_weights * a2
    sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms

    credits = a1 + opp_shaping_terms
    if mean_normalize_ad_align:
        credits = credits - credits.mean(dim=0)
        sub_tensors["mean_normalized_ad_align_credits"] = credits
    if whiten_adalign_advantages:
        credits = (credits - credits.mean()) / (credits.std() + 1e-9)
        sub_tensors["whitened_ad_align_credits"] = credits
    if whiten_adalign_advantages_time_step_wise:
        credits = (credits - credits.mean(dim=0, keepdim=True)) / (
            credits.std(dim=0, keepdim=True) + 1e-9
        )
        sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits
    sub_tensors["final_ad_align_credits"] = credits

    return credits, sub_tensors
src_code_for_reproducibility/training/tally_metrics.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from numbers import Number
3
+ from typing import Union
4
+
5
+ import wandb
6
+
7
+
8
class Tally:
    """
    Minimal scalar-first tally.
    - Keys are strings.
    - First add stores a scalar; subsequent adds upgrade to a list of scalars.
    """

    def __init__(self):
        # path -> scalar (single add) or list of scalars (repeated adds).
        self.stats = {}

    def reset(self):
        """Drop every recorded metric."""
        self.stats = {}

    def _coerce_scalar(self, value: Union[int, float]) -> Union[int, float]:
        """Unwrap tensor-like scalars via ``.item()``; reject non-numbers."""
        item_getter = getattr(value, "item", None)
        if callable(item_getter):
            try:
                value = item_getter()
            except Exception:
                # Best-effort unwrap; fall through to the Number check.
                pass
        if not isinstance(value, Number):
            raise AssertionError("Metric must be a scalar number")
        return value

    def add_metric(self, path: str, metric: Union[int, float]):
        """Record ``metric`` under ``path``, upgrading to a list on repeats."""
        metric = float(metric)
        assert isinstance(path, str), "Path must be a string."
        assert isinstance(metric, float), "Metric must be a scalar number."

        scalar = self._coerce_scalar(metric)
        current = self.stats.get(path)
        if current is None:
            self.stats[path] = scalar
        elif isinstance(current, list):
            current.append(scalar)
        else:
            self.stats[path] = [current, scalar]

    def save(self, identifier: str, folder: str):
        """Best-effort pickle of the stats dict to ``folder/identifier.tally.pkl``."""
        os.makedirs(name=folder, exist_ok=True)
        try:
            import pickle

            destination = os.path.join(folder, f"{identifier}.tally.pkl")
            with open(destination, "wb") as handle:
                pickle.dump(self.stats, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Persistence is best-effort; never abort the caller.
            pass
src_code_for_reproducibility/training/tally_rollout.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from transformers import AutoTokenizer
10
+
11
+
12
class RolloutTallyItem:
    """
    One rollout's worth of a metric: identifier lists plus a 2-D value matrix.
    """

    def __init__(self, crn_ids: list[str], rollout_ids: list[str], agent_ids: list[str], metric_matrix: torch.Tensor):
        """
        Initializes the RolloutTallyItem object.

        Args:
            crn_ids (list[str]): List of CRN IDs.
            rollout_ids (list[str]): List of rollout IDs.
            agent_ids (list[str]): List of agent IDs.
            metric_matrix (torch.Tensor): Metric matrix; 1-D tensors are
                promoted to a single-row 2-D matrix.
        """
        def _to_host(ids):
            # Tensor ids are detached and converted; plain lists pass through.
            if isinstance(ids, torch.Tensor):
                return ids.detach().cpu().numpy()
            return ids

        self.crn_ids = _to_host(crn_ids)
        self.rollout_ids = _to_host(rollout_ids)
        self.agent_ids = _to_host(agent_ids)

        matrix = metric_matrix.detach().cpu()
        assert 0 < matrix.ndim <= 2, "Metric matrix must have less than or equal to 2 dimensions"
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        # Convert to float32 first: numpy has no bfloat16 representation.
        if matrix.dtype == torch.bfloat16:
            matrix = matrix.float()
        self.metric_matrix = matrix.numpy()
40
+
41
class RolloutTally:
    """
    Tally is a utility class for collecting and storing training metrics.
    It supports adding metrics at specified paths and saving them to disk.
    """

    def __init__(self):
        """
        Initializes the RolloutTally object with an empty metric store.
        """
        # Nested dict; leaves are lists of RolloutTallyItem, preserved as-is.
        self.metrics = {}

    def reset(self):
        """
        Resets the tally to an empty dictionary.
        """
        self.metrics = {}

    def get_from_nested_dict(self, dictio: dict, path: list[str]):
        """
        Retrieves the value at a nested path in a dictionary.

        Side effect: intermediate dictionaries are created for every prefix of
        the path (via setdefault), which lets add_metric append without a
        separate existence check.

        Args:
            dictio (dict): The dictionary to search.
            path (list[str]): List of keys representing the path.

        Returns:
            Any: The value at the specified path, or None if not found.
        """
        assert isinstance(path, list), "Path must be list."
        for sp in path[:-1]:
            dictio = dictio.setdefault(sp, {})
        return dictio.get(path[-1], None)

    def set_at_path(self, dictio: dict, path: list[str], value):
        """
        Sets a value at a nested path in a dictionary, creating intermediate
        dictionaries as needed.

        Args:
            dictio (dict): The dictionary to modify.
            path (list[str]): List of keys representing the path.
            value (Any): The value to set at the specified path.
        """
        for sp in path[:-1]:
            dictio = dictio.setdefault(sp, {})
        dictio[path[-1]] = value

    def add_metric(self, path: list[str], rollout_tally_item: "RolloutTallyItem"):
        """
        Adds a metric to the tally at the specified path.

        Args:
            path (list[str]): List of keys representing the path in the tally.
            rollout_tally_item (RolloutTallyItem): The rollout tally item to add.
        """
        # Deep-copy so later mutation by the caller cannot corrupt the tally.
        rollout_tally_item = deepcopy(rollout_tally_item)

        array_list = self.get_from_nested_dict(dictio=self.metrics, path=path)
        if array_list is None:
            self.set_at_path(dictio=self.metrics, path=path, value=[rollout_tally_item])
        else:
            array_list.append(rollout_tally_item)

    def save(self, identifier: str, folder: str):
        """
        Saves the metrics to disk as a pickle file (exact structure, numpy
        arrays and scalars preserved at the leaves).

        Args:
            identifier (str): Filename stem; the file is
                ``{identifier}.rt_tally.pkl`` inside ``folder``.
            folder (str): Directory path where the metrics will be saved.
        """
        os.makedirs(name=folder, exist_ok=True)

        # Best-effort: persistence failures must never abort training.
        try:
            import pickle

            pkl_path = os.path.join(folder, f"{identifier}.rt_tally.pkl")
            payload = {"metrics": self.metrics}
            with open(pkl_path, "wb") as f:
                pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            pass