Muqeeth commited on
Commit
d476f99
·
verified ·
1 Parent(s): e799c61

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .hydra/config.yaml +183 -0
  2. .hydra/hydra.yaml +154 -0
  3. run.log +0 -0
  4. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
  5. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
  6. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
  7. seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json +42 -0
  8. src_code_for_reproducibility/__init__.py +0 -0
  9. src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
  10. src_code_for_reproducibility/chat_utils/apply_template.py +84 -0
  11. src_code_for_reproducibility/chat_utils/chat_turn.py +27 -0
  12. src_code_for_reproducibility/chat_utils/template_specific.py +109 -0
  13. src_code_for_reproducibility/docs/Makefile +19 -0
  14. src_code_for_reproducibility/docs/generate_docs.py +249 -0
  15. src_code_for_reproducibility/docs/make.bat +35 -0
  16. src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst +7 -0
  17. src_code_for_reproducibility/docs/source/src.experiments.rst +17 -0
  18. src_code_for_reproducibility/docs/source/src.generation.rst +15 -0
  19. src_code_for_reproducibility/docs/source/src.models.rst +20 -0
  20. src_code_for_reproducibility/docs/source/src.training.ppo_train_value_head.rst +7 -0
  21. src_code_for_reproducibility/markov_games/__init__.py +0 -0
  22. src_code_for_reproducibility/markov_games/agent.py +76 -0
  23. src_code_for_reproducibility/markov_games/alternative_actions_runner.py +138 -0
  24. src_code_for_reproducibility/markov_games/group_timesteps.py +150 -0
  25. src_code_for_reproducibility/markov_games/linear_runner.py +30 -0
  26. src_code_for_reproducibility/markov_games/markov_game.py +208 -0
  27. src_code_for_reproducibility/markov_games/mg_utils.py +89 -0
  28. src_code_for_reproducibility/markov_games/rollout_tree.py +86 -0
  29. src_code_for_reproducibility/markov_games/run_markov_games.py +24 -0
  30. src_code_for_reproducibility/markov_games/simulation.py +87 -0
  31. src_code_for_reproducibility/markov_games/statistics_runner.py +405 -0
  32. src_code_for_reproducibility/markov_games/vine_ppo.py +10 -0
  33. src_code_for_reproducibility/models/__init__.py +0 -0
  34. src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc +0 -0
  35. src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc +0 -0
  36. src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc +0 -0
  37. src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc +0 -0
  38. src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc +0 -0
  39. src_code_for_reproducibility/models/adapter_training_wrapper.py +98 -0
  40. src_code_for_reproducibility/models/human_policy.py +255 -0
  41. src_code_for_reproducibility/models/inference_backend.py +39 -0
  42. src_code_for_reproducibility/models/inference_backend_dummy.py +54 -0
  43. src_code_for_reproducibility/models/inference_backend_sglang.py +86 -0
  44. src_code_for_reproducibility/models/inference_backend_sglang_local_server.py +127 -0
  45. src_code_for_reproducibility/models/inference_backend_vllm.py +118 -0
  46. src_code_for_reproducibility/models/inference_backend_vllm_local_server.py +160 -0
  47. src_code_for_reproducibility/models/large_language_model_api.py +171 -0
  48. src_code_for_reproducibility/models/large_language_model_local.py +384 -0
  49. src_code_for_reproducibility/models/scalar_critic.py +54 -0
  50. src_code_for_reproducibility/training/README.md +20 -0
.hydra/config.yaml ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ experiment:
2
+ wandb_enabled: true
3
+ nb_epochs: 3000
4
+ nb_matches_per_iteration: 64
5
+ reinit_matches_each_it: true
6
+ checkpoint_every_n_iterations: 50
7
+ start_epoch: 0
8
+ resume_experiment: true
9
+ base_seed: 0
10
+ seed_group_size: 8
11
+ train: true
12
+ stat_methods_for_live_wandb: mllm.markov_games.negotiation.negotiation_statistics
13
+ name: naive_vs_fixed_ad_align_seed4321
14
+ agent_buffer: false
15
+ keep_agent_buffer_count: ${lora_count}
16
+ agent_buffer_recent_k: -1
17
+ description: Trust-and-Split Rock Paper Scissors negotiation game
18
+ logging:
19
+ wandb:
20
+ enabled: false
21
+ project: llm-negotiation
22
+ entity: null
23
+ mode: online
24
+ name: null
25
+ group: null
26
+ tags: []
27
+ notes: null
28
+ temperature: 1.0
29
+ markov_games:
30
+ runner_method_name: LinearRunner
31
+ runner_kwargs: {}
32
+ group_by_round: true
33
+ simulation_class_name: TrustAndSplitRPSSimulation
34
+ simulation_init_args:
35
+ nb_of_rounds: 10
36
+ quota_messages_per_agent_per_round: 1
37
+ alternating_hands: false
38
+ agents:
39
+ 0:
40
+ agent_id: ${agent_0_id}
41
+ agent_name: Alice
42
+ agent_class_name: TrustAndSplitRPSAgent
43
+ policy_id: base_llm/agent_adapter
44
+ init_kwargs:
45
+ goal: Maximize your total points over the whole game.
46
+ num_message_chars: 500
47
+ message_start_end_format: true
48
+ proposal_start_end_format: true
49
+ 1:
50
+ agent_id: ${agent_1_id}
51
+ agent_name: Bob
52
+ agent_class_name: TrustAndSplitRPSAgent
53
+ policy_id: base_llm/fixed_ad_align_adapter
54
+ init_kwargs:
55
+ goal: Maximize your total points over the whole game.
56
+ num_message_chars: 500
57
+ message_start_end_format: true
58
+ proposal_start_end_format: true
59
+ models:
60
+ base_llm:
61
+ class: LeanLocalLLM
62
+ init_args:
63
+ llm_id: base_llm
64
+ model_name: Qwen/Qwen2.5-7B-Instruct
65
+ inference_backend: vllm
66
+ hf_kwargs:
67
+ device_map: auto
68
+ torch_dtype: bfloat16
69
+ max_memory:
70
+ 0: 20GiB
71
+ attn_implementation: flash_attention_2
72
+ inference_backend_init_kwargs:
73
+ enable_lora: true
74
+ seed: ${experiment.base_seed}
75
+ enable_prefix_caching: true
76
+ max_model_len: 10000.0
77
+ gpu_memory_utilization: 0.5
78
+ dtype: bfloat16
79
+ trust_remote_code: true
80
+ max_lora_rank: 32
81
+ enforce_eager: false
82
+ max_loras: ${lora_count}
83
+ max_cpu_loras: ${lora_count}
84
+ enable_sleep_mode: true
85
+ inference_backend_sampling_params:
86
+ temperature: ${temperature}
87
+ top_p: 1.0
88
+ max_tokens: 400
89
+ top_k: -1
90
+ logprobs: 0
91
+ adapter_configs:
92
+ agent_adapter:
93
+ task_type: CAUSAL_LM
94
+ r: 32
95
+ lora_alpha: 64
96
+ lora_dropout: 0.0
97
+ target_modules: all-linear
98
+ critic_adapter:
99
+ task_type: CAUSAL_LM
100
+ r: 32
101
+ lora_alpha: 64
102
+ lora_dropout: 0.0
103
+ target_modules: all-linear
104
+ fixed_ad_align_adapter:
105
+ task_type: CAUSAL_LM
106
+ r: 32
107
+ lora_alpha: 64
108
+ lora_dropout: 0.0
109
+ target_modules: all-linear
110
+ enable_thinking: null
111
+ regex_max_attempts: 1
112
+ initial_adapter_paths:
113
+ fixed_ad_align_adapter: ${fixed_ad_align_adapter_path}
114
+ critics:
115
+ agent_critic:
116
+ module_pointer:
117
+ - base_llm
118
+ - critic_adapter
119
+ optimizers:
120
+ agent_optimizer:
121
+ module_pointer:
122
+ - base_llm
123
+ - agent_adapter
124
+ optimizer_class_name: torch.optim.Adam
125
+ init_args:
126
+ lr: 3.0e-06
127
+ weight_decay: 0.0
128
+ critic_optimizer:
129
+ module_pointer: agent_critic
130
+ optimizer_class_name: torch.optim.Adam
131
+ init_args:
132
+ lr: 3.0e-06
133
+ weight_decay: 0.0
134
+ trainers:
135
+ agent_trainer:
136
+ class: TrainerNaive
137
+ module_pointers:
138
+ policy:
139
+ - base_llm
140
+ - agent_adapter
141
+ policy_optimizer: agent_optimizer
142
+ critic: agent_critic
143
+ critic_optimizer: critic_optimizer
144
+ kwargs:
145
+ entropy_coeff: 0.0
146
+ entropy_topk: null
147
+ entropy_mask_regex: null
148
+ kl_coeff: 0.001
149
+ gradient_clipping: 1.0
150
+ restrict_tokens: null
151
+ mini_batch_size: 1
152
+ use_gradient_checkpointing: true
153
+ temperature: ${temperature}
154
+ device: cuda:0
155
+ use_gae: false
156
+ whiten_advantages: false
157
+ whiten_advantages_time_step_wise: false
158
+ skip_discounted_state_visitation: true
159
+ use_gae_lambda_annealing: false
160
+ gae_lambda_annealing_method: None
161
+ gae_lambda_annealing_method_params: None
162
+ gae_lambda_annealing_limit: 0.95
163
+ discount_factor: 0.96
164
+ use_rloo: true
165
+ enable_tokenwise_logging: false
166
+ pg_loss_normalization: nb_tokens
167
+ truncated_importance_sampling_ratio_cap: 2.0
168
+ reward_normalizing_constant: 100.0
169
+ train_on_which_data:
170
+ agent_trainer:
171
+ - Alice
172
+ lora_count: 30
173
+ common_agent_kwargs:
174
+ goal: Maximize your total points over the whole game.
175
+ num_message_chars: 500
176
+ message_start_end_format: true
177
+ proposal_start_end_format: true
178
+ agent_0_id: Alice
179
+ agent_1_id: Bob
180
+ agent_ids:
181
+ - Alice
182
+ - Bob
183
+ fixed_ad_align_adapter_path: /home/muqeeth/scratch/llm_negotiation/2025_11/tas_rps_startend_ad_align_nocurrtimestep_seed4321_beta2/seed_4321/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter
.hydra/hydra.yaml ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
4
+ sweep:
5
+ dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
6
+ subdir: ${hydra.job.num}
7
+ launcher:
8
+ _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
9
+ sweeper:
10
+ _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
11
+ max_batch_size: null
12
+ params: null
13
+ help:
14
+ app_name: ${hydra.job.name}
15
+ header: '${hydra.help.app_name} is powered by Hydra.
16
+
17
+ '
18
+ footer: 'Powered by Hydra (https://hydra.cc)
19
+
20
+ Use --hydra-help to view Hydra specific help
21
+
22
+ '
23
+ template: '${hydra.help.header}
24
+
25
+ == Configuration groups ==
26
+
27
+ Compose your configuration from those groups (group=option)
28
+
29
+
30
+ $APP_CONFIG_GROUPS
31
+
32
+
33
+ == Config ==
34
+
35
+ Override anything in the config (foo.bar=value)
36
+
37
+
38
+ $CONFIG
39
+
40
+
41
+ ${hydra.help.footer}
42
+
43
+ '
44
+ hydra_help:
45
+ template: 'Hydra (${hydra.runtime.version})
46
+
47
+ See https://hydra.cc for more info.
48
+
49
+
50
+ == Flags ==
51
+
52
+ $FLAGS_HELP
53
+
54
+
55
+ == Configuration groups ==
56
+
57
+ Compose your configuration from those groups (For example, append hydra/job_logging=disabled
58
+ to command line)
59
+
60
+
61
+ $HYDRA_CONFIG_GROUPS
62
+
63
+
64
+ Use ''--cfg hydra'' to Show the Hydra config.
65
+
66
+ '
67
+ hydra_help: ???
68
+ hydra_logging:
69
+ version: 1
70
+ formatters:
71
+ simple:
72
+ format: '[%(asctime)s][HYDRA] %(message)s'
73
+ handlers:
74
+ console:
75
+ class: logging.StreamHandler
76
+ formatter: simple
77
+ stream: ext://sys.stdout
78
+ root:
79
+ level: INFO
80
+ handlers:
81
+ - console
82
+ loggers:
83
+ logging_example:
84
+ level: DEBUG
85
+ disable_existing_loggers: false
86
+ job_logging:
87
+ version: 1
88
+ formatters:
89
+ simple:
90
+ format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
91
+ handlers:
92
+ console:
93
+ class: logging.StreamHandler
94
+ formatter: simple
95
+ stream: ext://sys.stdout
96
+ file:
97
+ class: logging.FileHandler
98
+ formatter: simple
99
+ filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
100
+ root:
101
+ level: INFO
102
+ handlers:
103
+ - console
104
+ - file
105
+ disable_existing_loggers: false
106
+ env: {}
107
+ mode: RUN
108
+ searchpath: []
109
+ callbacks: {}
110
+ output_subdir: .hydra
111
+ overrides:
112
+ hydra:
113
+ - hydra.mode=RUN
114
+ task: []
115
+ job:
116
+ name: run
117
+ chdir: false
118
+ override_dirname: ''
119
+ id: ???
120
+ num: ???
121
+ config_name: naive_vs_fixed_ad_align_seed4321.yaml
122
+ env_set: {}
123
+ env_copy: []
124
+ config:
125
+ override_dirname:
126
+ kv_sep: '='
127
+ item_sep: ','
128
+ exclude_keys: []
129
+ runtime:
130
+ version: 1.3.2
131
+ version_base: '1.1'
132
+ cwd: /scratch/muqeeth/llm_negotiation
133
+ config_sources:
134
+ - path: hydra.conf
135
+ schema: pkg
136
+ provider: hydra
137
+ - path: /scratch/muqeeth/llm_negotiation/configs
138
+ schema: file
139
+ provider: main
140
+ - path: ''
141
+ schema: structured
142
+ provider: schema
143
+ output_dir: /scratch/muqeeth/llm_negotiation/2025_11/naive_vs_fixed_ad_align_seed4321
144
+ choices:
145
+ hydra/env: default
146
+ hydra/callbacks: null
147
+ hydra/job_logging: default
148
+ hydra/hydra_logging: default
149
+ hydra/hydra_help: default
150
+ hydra/help: default
151
+ hydra/sweeper: basic
152
+ hydra/launcher: basic
153
+ hydra/output: default
154
+ verbose: false
run.log ADDED
The diff for this file is too large to render. See raw diff
 
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: Qwen/Qwen2.5-7B-Instruct
3
+ library_name: peft
4
+ pipeline_tag: text-generation
5
+ tags:
6
+ - base_model:adapter:Qwen/Qwen2.5-7B-Instruct
7
+ - lora
8
+ - transformers
9
+ ---
10
+
11
+ # Model Card for Model ID
12
+
13
+ <!-- Provide a quick summary of what the model is/does. -->
14
+
15
+
16
+
17
+ ## Model Details
18
+
19
+ ### Model Description
20
+
21
+ <!-- Provide a longer summary of what this model is. -->
22
+
23
+
24
+
25
+ - **Developed by:** [More Information Needed]
26
+ - **Funded by [optional]:** [More Information Needed]
27
+ - **Shared by [optional]:** [More Information Needed]
28
+ - **Model type:** [More Information Needed]
29
+ - **Language(s) (NLP):** [More Information Needed]
30
+ - **License:** [More Information Needed]
31
+ - **Finetuned from model [optional]:** [More Information Needed]
32
+
33
+ ### Model Sources [optional]
34
+
35
+ <!-- Provide the basic links for the model. -->
36
+
37
+ - **Repository:** [More Information Needed]
38
+ - **Paper [optional]:** [More Information Needed]
39
+ - **Demo [optional]:** [More Information Needed]
40
+
41
+ ## Uses
42
+
43
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
44
+
45
+ ### Direct Use
46
+
47
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
48
+
49
+ [More Information Needed]
50
+
51
+ ### Downstream Use [optional]
52
+
53
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
54
+
55
+ [More Information Needed]
56
+
57
+ ### Out-of-Scope Use
58
+
59
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
60
+
61
+ [More Information Needed]
62
+
63
+ ## Bias, Risks, and Limitations
64
+
65
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
66
+
67
+ [More Information Needed]
68
+
69
+ ### Recommendations
70
+
71
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
72
+
73
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
74
+
75
+ ## How to Get Started with the Model
76
+
77
+ Use the code below to get started with the model.
78
+
79
+ [More Information Needed]
80
+
81
+ ## Training Details
82
+
83
+ ### Training Data
84
+
85
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
86
+
87
+ [More Information Needed]
88
+
89
+ ### Training Procedure
90
+
91
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
92
+
93
+ #### Preprocessing [optional]
94
+
95
+ [More Information Needed]
96
+
97
+
98
+ #### Training Hyperparameters
99
+
100
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
101
+
102
+ #### Speeds, Sizes, Times [optional]
103
+
104
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
105
+
106
+ [More Information Needed]
107
+
108
+ ## Evaluation
109
+
110
+ <!-- This section describes the evaluation protocols and provides the results. -->
111
+
112
+ ### Testing Data, Factors & Metrics
113
+
114
+ #### Testing Data
115
+
116
+ <!-- This should link to a Dataset Card if possible. -->
117
+
118
+ [More Information Needed]
119
+
120
+ #### Factors
121
+
122
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
123
+
124
+ [More Information Needed]
125
+
126
+ #### Metrics
127
+
128
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
129
+
130
+ [More Information Needed]
131
+
132
+ ### Results
133
+
134
+ [More Information Needed]
135
+
136
+ #### Summary
137
+
138
+
139
+
140
+ ## Model Examination [optional]
141
+
142
+ <!-- Relevant interpretability work for the model goes here -->
143
+
144
+ [More Information Needed]
145
+
146
+ ## Environmental Impact
147
+
148
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
149
+
150
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
151
+
152
+ - **Hardware Type:** [More Information Needed]
153
+ - **Hours used:** [More Information Needed]
154
+ - **Cloud Provider:** [More Information Needed]
155
+ - **Compute Region:** [More Information Needed]
156
+ - **Carbon Emitted:** [More Information Needed]
157
+
158
+ ## Technical Specifications [optional]
159
+
160
+ ### Model Architecture and Objective
161
+
162
+ [More Information Needed]
163
+
164
+ ### Compute Infrastructure
165
+
166
+ [More Information Needed]
167
+
168
+ #### Hardware
169
+
170
+ [More Information Needed]
171
+
172
+ #### Software
173
+
174
+ [More Information Needed]
175
+
176
+ ## Citation [optional]
177
+
178
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
179
+
180
+ **BibTeX:**
181
+
182
+ [More Information Needed]
183
+
184
+ **APA:**
185
+
186
+ [More Information Needed]
187
+
188
+ ## Glossary [optional]
189
+
190
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
191
+
192
+ [More Information Needed]
193
+
194
+ ## More Information [optional]
195
+
196
+ [More Information Needed]
197
+
198
+ ## Model Card Authors [optional]
199
+
200
+ [More Information Needed]
201
+
202
+ ## Model Card Contact
203
+
204
+ [More Information Needed]
205
+ ### Framework versions
206
+
207
+ - PEFT 0.17.1
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "o_proj",
29
+ "gate_proj",
30
+ "up_proj",
31
+ "v_proj",
32
+ "k_proj",
33
+ "q_proj",
34
+ "down_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "o_proj",
29
+ "gate_proj",
30
+ "up_proj",
31
+ "v_proj",
32
+ "k_proj",
33
+ "q_proj",
34
+ "down_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 64,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "qalora_group_size": 16,
24
+ "r": 32,
25
+ "rank_pattern": {},
26
+ "revision": null,
27
+ "target_modules": [
28
+ "o_proj",
29
+ "gate_proj",
30
+ "up_proj",
31
+ "v_proj",
32
+ "k_proj",
33
+ "q_proj",
34
+ "down_proj"
35
+ ],
36
+ "target_parameters": null,
37
+ "task_type": "CAUSAL_LM",
38
+ "trainable_token_indices": null,
39
+ "use_dora": false,
40
+ "use_qalora": false,
41
+ "use_rslora": false
42
+ }
src_code_for_reproducibility/__init__.py ADDED
File without changes
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (146 Bytes). View file
 
src_code_for_reproducibility/chat_utils/apply_template.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from mllm.chat_utils.chat_turn import ChatTurn
4
+ from mllm.chat_utils.template_specific import (
5
+ custom_gemma3_template,
6
+ custom_llama3_template,
7
+ custom_qwen2_template,
8
+ custom_qwen3_template,
9
+ gemma3_assistant_postfix,
10
+ qwen2_assistant_postfix,
11
+ qwen3_assistant_postfix,
12
+ )
13
+
14
+
15
+ def get_custom_chat_template(tokenizer) -> str:
16
+ """
17
+ Get the chat template for the tokenizer.
18
+ """
19
+ if "qwen2" in tokenizer.name_or_path.lower():
20
+ return custom_qwen2_template
21
+ elif "llama" in tokenizer.name_or_path.lower():
22
+ return custom_llama3_template
23
+ elif "qwen3" in tokenizer.name_or_path.lower():
24
+ return custom_qwen3_template
25
+ elif "gemma" in tokenizer.name_or_path.lower():
26
+ return custom_gemma3_template
27
+ else:
28
+ raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
29
+
30
+
31
+ def get_custom_assistant_postfix(tokenizer) -> torch.Tensor:
32
+ """
33
+ Get the custom assistant postfix for the tokenizer.
34
+ """
35
+ if "qwen2" in tokenizer.name_or_path.lower():
36
+ return qwen2_assistant_postfix
37
+ elif "qwen3" in tokenizer.name_or_path.lower():
38
+ return qwen3_assistant_postfix
39
+ elif "gemma" in tokenizer.name_or_path.lower():
40
+ return gemma3_assistant_postfix
41
+ return torch.tensor([], dtype=torch.long)
42
+
43
+
44
+ def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
45
+ """
46
+ Set the chat_template_token_ids for each chat turn.
47
+ # TODO: use engine tokens if available
48
+ """
49
+ custom_template = get_custom_chat_template(tokenizer)
50
+ custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
51
+ for i, chat in enumerate(chats):
52
+ if chat.chat_template_token_ids is None:
53
+ if chat.role == "user":
54
+ next_chat = chats[i + 1] if i + 1 < len(chats) else None
55
+ add_generation_prompt = True
56
+ if next_chat and next_chat.role == "user":
57
+ add_generation_prompt = False
58
+ encoded_chat = tokenizer.apply_chat_template(
59
+ [chat],
60
+ return_tensors="pt",
61
+ chat_template=custom_template,
62
+ add_generation_prompt=add_generation_prompt,
63
+ add_system_prompt=True if i == 0 else False,
64
+ enable_thinking=enable_thinking,
65
+ ).flatten()
66
+ previous_chat = chats[i - 1] if i > 0 else None
67
+ if previous_chat and previous_chat.role == "assistant":
68
+ encoded_chat = torch.cat([custom_assistant_postfix, encoded_chat])
69
+ elif chat.role == "assistant":
70
+ encoded_chat = chat.out_token_ids
71
+ chat.chat_template_token_ids = encoded_chat
72
+
73
+
74
+ def chat_turns_to_token_ids(
75
+ chats: list[ChatTurn], tokenizer, enable_thinking
76
+ ) -> list[int]:
77
+ """
78
+ Tokenize the chat turns and set the chat_template_token_ids for each chat turn.
79
+ """
80
+ tokenize_chats(chats=chats, tokenizer=tokenizer, enable_thinking=enable_thinking)
81
+ token_ids = []
82
+ for chat in chats:
83
+ token_ids.append(chat.chat_template_token_ids)
84
+ return torch.cat(token_ids)
src_code_for_reproducibility/chat_utils/chat_turn.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, List, Literal, Optional, Tuple
7
+
8
+ import jsonschema
9
+ import torch
10
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
11
+
12
+ AgentId = str
13
+
14
+
15
+ class ChatTurn(BaseModel):
16
+ model_config = ConfigDict(arbitrary_types_allowed=True) # needed for torch tensors
17
+
18
+ role: str = Field(pattern="^(user|assistant)$")
19
+ agent_id: AgentId # ID of the agent with which the chat occured
20
+ content: str
21
+ reasoning_content: str | None = None
22
+ chat_template_token_ids: torch.LongTensor | None = None # Token ids of chat template format. For example, token ids of "<assistant>{content}</assistant>""
23
+ out_token_ids: torch.LongTensor | None = (
24
+ None # tokens generated from inference engine
25
+ )
26
+ log_probs: torch.FloatTensor | None = None
27
+ is_state_end: bool = False # indicates whether this chat turn marks the end of a state in the trajectory
src_code_for_reproducibility/chat_utils/template_specific.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import huggingface_hub
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+
5
+ custom_llama3_template = """
6
+ {%- if add_system_prompt %}
7
+ {{- '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|>' }}
8
+ {%- endif %}
9
+ {%- for message in messages %}
10
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
11
+ {%- endfor %}
12
+
13
+ {%- if add_generation_prompt %}
14
+ {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
15
+ {%- endif %}
16
+ """
17
+
18
+ qwen2_assistant_postfix = (
19
+ AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
20
+ .encode("\n", return_tensors="pt")
21
+ .flatten()
22
+ )
23
+ qwen3_assistant_postfix = (
24
+ AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
25
+ .encode("\n", return_tensors="pt")
26
+ .flatten()
27
+ )
28
+ gemma3_assistant_postfix = (
29
+ AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
30
+ .encode("\n", return_tensors="pt")
31
+ .flatten()
32
+ )
33
+ custom_qwen2_template = """
34
+ {%- if add_system_prompt %}
35
+ {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
36
+ {%- endif %}
37
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
38
+ {%- for message in messages %}
39
+ {%- if message.content is string %}
40
+ {%- set content = message.content %}
41
+ {%- else %}
42
+ {%- set content = '' %}
43
+ {%- endif %}
44
+ {%- if (message.role == "user") %}
45
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
46
+ {%- elif message.role == "assistant" %}
47
+ {%- set reasoning_content = '' %}
48
+ {%- if message.reasoning_content is string %}
49
+ {%- set reasoning_content = message.reasoning_content %}
50
+ {%- else %}
51
+ {%- if '</think>' in content %}
52
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
53
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
54
+ {%- endif %}
55
+ {%- endif %}
56
+ {%- if loop.index0 > ns.last_query_index %}
57
+ {%- if reasoning_content %}
58
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
59
+ {%- else %}
60
+ {{- '<|im_start|>' + message.role + '\n' + content }}
61
+ {%- endif %}
62
+ {%- else %}
63
+ {{- '<|im_start|>' + message.role + '\n' + content }}
64
+ {%- endif %}
65
+ {{- '<|im_end|>\n' }}
66
+ {%- endif %}
67
+ {%- endfor %}
68
+ {%- if add_generation_prompt %}
69
+ {{- '<|im_start|>assistant\n' }}
70
+ {%- endif %}
71
+ """
72
+
73
+ custom_qwen3_template = """
74
+ {%- for message in messages %}
75
+ {%- if message.content is string %}
76
+ {%- set content = message.content %}
77
+ {%- else %}
78
+ {%- set content = '' %}
79
+ {%- endif %}
80
+ {%- if (message.role == "user") %}
81
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
82
+ {%- elif message.role == "assistant" %}
83
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
84
+ {%- endif %}
85
+ {%- endfor %}
86
+ {%- if add_generation_prompt %}
87
+ {{- '<|im_start|>assistant\n' }}
88
+ {%- if enable_thinking is defined and enable_thinking is false %}
89
+ {{- '<think>\n\n</think>\n\n' }}
90
+ {%- endif %}
91
+ {%- endif %}
92
+ """
93
+
94
+ custom_gemma3_template = """
95
+ {%- if add_system_prompt %}
96
+ {{- bos_token -}}
97
+ {%- endif %}
98
+ {%- for message in messages -%}
99
+ {%- if message['role'] == 'assistant' -%}
100
+ {%- set role = 'model' -%}
101
+ {%- else -%}
102
+ {%- set role = message['role'] -%}
103
+ {%- endif -%}
104
+ {{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
105
+ {%- endfor -%}
106
+ {%- if add_generation_prompt -%}
107
+ {{ '<start_of_turn>model\n' }}
108
+ {%- endif -%}
109
+ """
src_code_for_reproducibility/docs/Makefile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+
3
+ # You can set these variables from the command line, and also
4
+ # from the environment for the first two.
5
+ SPHINXOPTS ?=
6
+ SPHINXBUILD ?= sphinx-build
7
+ SOURCEDIR = source
8
+ BUILDDIR = build
9
+
10
+ # Put it first so that "make" without argument is like "make help".
11
+ help:
12
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
13
+
14
+ .PHONY: help Makefile
15
+
16
+ # Catch-all target: route all unknown targets to Sphinx using the new
17
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18
+ %: Makefile
19
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
src_code_for_reproducibility/docs/generate_docs.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to automatically generate Sphinx documentation for all modules and build the HTML website.
4
+ """
5
+ import importlib.util
6
+ import os
7
+ import subprocess
8
+ import sys
9
+
10
+
11
+ def check_and_install_dependencies():
12
+ """Check for required dependencies and install them if missing."""
13
+ required_packages = [
14
+ "sphinx",
15
+ "sphinx-rtd-theme",
16
+ "sphinxcontrib-napoleon",
17
+ "sphinxcontrib-mermaid",
18
+ "sphinx-autodoc-typehints",
19
+ ]
20
+
21
+ missing_packages = []
22
+
23
+ for package in required_packages:
24
+ # Convert package name to module name (replace - with _)
25
+ module_name = package.replace("-", "_")
26
+
27
+ # Check if the package is installed
28
+ if importlib.util.find_spec(module_name) is None:
29
+ missing_packages.append(package)
30
+
31
+ # Install missing packages
32
+ if missing_packages:
33
+ print(f"Installing missing dependencies: {', '.join(missing_packages)}")
34
+ subprocess.check_call(
35
+ [sys.executable, "-m", "pip", "install"] + missing_packages
36
+ )
37
+ print("Dependencies installed successfully")
38
+ else:
39
+ print("All required dependencies are already installed")
40
+
41
+
42
+ def create_makefile(docs_dir):
43
+ """Create a Makefile for Sphinx documentation if it doesn't exist."""
44
+ makefile_path = os.path.join(docs_dir, "Makefile")
45
+
46
+ if os.path.exists(makefile_path):
47
+ print(f"Makefile already exists at {makefile_path}")
48
+ return
49
+
50
+ print(f"Creating Makefile at {makefile_path}")
51
+
52
+ makefile_content = """# Minimal makefile for Sphinx documentation
53
+
54
+ # You can set these variables from the command line, and also
55
+ # from the environment for the first two.
56
+ SPHINXOPTS ?=
57
+ SPHINXBUILD ?= sphinx-build
58
+ SOURCEDIR = source
59
+ BUILDDIR = build
60
+
61
+ # Put it first so that "make" without argument is like "make help".
62
+ help:
63
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
64
+
65
+ .PHONY: help Makefile
66
+
67
+ # Catch-all target: route all unknown targets to Sphinx using the new
68
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
69
+ %: Makefile
70
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
71
+ """
72
+
73
+ with open(makefile_path, "w") as f:
74
+ f.write(makefile_content)
75
+
76
+ print("Makefile created successfully")
77
+
78
+
79
+ def create_make_bat(docs_dir):
80
+ """Create a make.bat file for Windows if it doesn't exist."""
81
+ make_bat_path = os.path.join(docs_dir, "make.bat")
82
+
83
+ if os.path.exists(make_bat_path):
84
+ print(f"make.bat already exists at {make_bat_path}")
85
+ return
86
+
87
+ print(f"Creating make.bat at {make_bat_path}")
88
+
89
+ make_bat_content = """@ECHO OFF
90
+
91
+ pushd %~dp0
92
+
93
+ REM Command file for Sphinx documentation
94
+
95
+ if "%SPHINXBUILD%" == "" (
96
+ set SPHINXBUILD=sphinx-build
97
+ )
98
+ set SOURCEDIR=source
99
+ set BUILDDIR=build
100
+
101
+ %SPHINXBUILD% >NUL 2>NUL
102
+ if errorlevel 9009 (
103
+ echo.
104
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
105
+ echo.installed, then set the SPHINXBUILD environment variable to point
106
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
107
+ echo.may add the Sphinx directory to PATH.
108
+ echo.
109
+ echo.If you don't have Sphinx installed, grab it from
110
+ echo.https://www.sphinx-doc.org/
111
+ exit /b 1
112
+ )
113
+
114
+ if "%1" == "" goto help
115
+
116
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
117
+ goto end
118
+
119
+ :help
120
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
121
+
122
+ :end
123
+ popd
124
+ """
125
+
126
+ with open(make_bat_path, "w") as f:
127
+ f.write(make_bat_content)
128
+
129
+ print("make.bat created successfully")
130
+
131
+
132
+ def main():
133
+ # Check and install required dependencies
134
+ print("=== Checking dependencies ===")
135
+ check_and_install_dependencies()
136
+
137
+ # Get the directory of this script
138
+ script_dir = os.path.dirname(os.path.abspath(__file__))
139
+
140
+ # Path to the project root
141
+ project_root = os.path.dirname(script_dir)
142
+
143
+ # Path to the source directory
144
+ source_dir = os.path.join(project_root, "src")
145
+
146
+ # Path to the docs source directory
147
+ docs_source_dir = os.path.join(script_dir, "source")
148
+
149
+ # Print paths for debugging
150
+ print(f"Script directory: {script_dir}")
151
+ print(f"Project root: {project_root}")
152
+ print(f"Source directory: {source_dir}")
153
+ print(f"Docs source directory: {docs_source_dir}")
154
+
155
+ # Make sure the source directory exists
156
+ if not os.path.exists(source_dir):
157
+ print(f"Error: Source directory {source_dir} does not exist!")
158
+ sys.exit(1)
159
+
160
+ # Make sure the docs source directory exists
161
+ if not os.path.exists(docs_source_dir):
162
+ print(f"Creating docs source directory: {docs_source_dir}")
163
+ os.makedirs(docs_source_dir)
164
+
165
+ # Step 1: Run sphinx-apidoc to generate .rst files for all modules
166
+ print("\n=== Generating API documentation ===")
167
+ cmd = [
168
+ "sphinx-apidoc",
169
+ "-f", # Force overwriting of existing files
170
+ "-e", # Put module documentation before submodule documentation
171
+ "-M", # Put module documentation before subpackage documentation
172
+ "-o",
173
+ docs_source_dir, # Output directory
174
+ source_dir, # Source code directory
175
+ ]
176
+
177
+ print(f"Running command: {' '.join(cmd)}")
178
+ result = subprocess.run(cmd, capture_output=True, text=True)
179
+
180
+ # Print the output of the command
181
+ print("STDOUT:")
182
+ print(result.stdout)
183
+
184
+ print("STDERR:")
185
+ print(result.stderr)
186
+
187
+ if result.returncode != 0:
188
+ print(f"Error: sphinx-apidoc failed with return code {result.returncode}")
189
+ sys.exit(1)
190
+
191
+ # List the files in the docs source directory
192
+ print("\nFiles in docs/source directory:")
193
+ for file in sorted(os.listdir(docs_source_dir)):
194
+ print(f" {file}")
195
+
196
+ print("\nDocumentation source files generated successfully!")
197
+
198
+ # Step 2: Create Makefile and make.bat if they don't exist
199
+ create_makefile(script_dir)
200
+ create_make_bat(script_dir)
201
+
202
+ # Step 3: Build the HTML documentation
203
+ print("\n=== Building HTML documentation ===")
204
+
205
+ # Determine the build command based on the platform
206
+ if os.name == "nt": # Windows
207
+ build_cmd = ["make.bat", "html"]
208
+ else: # Unix/Linux/Mac
209
+ build_cmd = ["make", "html"]
210
+
211
+ # Change to the docs directory to run the build command
212
+ os.chdir(script_dir)
213
+
214
+ print(f"Running command: {' '.join(build_cmd)}")
215
+ build_result = subprocess.run(build_cmd, capture_output=True, text=True)
216
+
217
+ # Print the output of the build command
218
+ print("STDOUT:")
219
+ print(build_result.stdout)
220
+
221
+ print("STDERR:")
222
+ print(build_result.stderr)
223
+
224
+ if build_result.returncode != 0:
225
+ print(f"Error: HTML build failed with return code {build_result.returncode}")
226
+ sys.exit(1)
227
+
228
+ # Get the path to the built HTML documentation
229
+ html_dir = os.path.join(script_dir, "build", "html")
230
+ index_path = os.path.join(html_dir, "index.html")
231
+
232
+ if os.path.exists(index_path):
233
+ print(f"\nHTML documentation built successfully!")
234
+ print(f"You can view it by opening: {index_path}")
235
+
236
+ # Try to open the documentation in a browser
237
+ try:
238
+ import webbrowser
239
+
240
+ print("\nAttempting to open documentation in your default browser...")
241
+ webbrowser.open(f"file://{index_path}")
242
+ except Exception as e:
243
+ print(f"Could not open browser automatically: {e}")
244
+ else:
245
+ print(f"\nWarning: HTML index file not found at {index_path}")
246
+
247
+
248
+ if __name__ == "__main__":
249
+ main()
src_code_for_reproducibility/docs/make.bat ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ %SPHINXBUILD% >NUL 2>NUL
14
+ if errorlevel 9009 (
15
+ echo.
16
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17
+ echo.installed, then set the SPHINXBUILD environment variable to point
18
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
19
+ echo.may add the Sphinx directory to PATH.
20
+ echo.
21
+ echo.If you don't have Sphinx installed, grab it from
22
+ echo.https://www.sphinx-doc.org/
23
+ exit /b 1
24
+ )
25
+
26
+ if "%1" == "" goto help
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.environments.dond.dond\_training\_data\_funcs module
2
+ ========================================================
3
+
4
+ .. automodule:: src.environments.dond.dond_training_data_funcs
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/docs/source/src.experiments.rst ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src.experiments package
2
+ =======================
3
+
4
+ .. automodule:: src.experiments
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+
9
+ Submodules
10
+ ----------
11
+
12
+ .. toctree::
13
+ :maxdepth: 4
14
+
15
+ src.experiments.arithmetic_test
16
+ src.experiments.generate_and_train
17
+ src.experiments.last_completion
src_code_for_reproducibility/docs/source/src.generation.rst ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src.generation package
2
+ ======================
3
+
4
+ .. automodule:: src.generation
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+
9
+ Submodules
10
+ ----------
11
+
12
+ .. toctree::
13
+ :maxdepth: 4
14
+
15
+ src.generation.run_games
src_code_for_reproducibility/docs/source/src.models.rst ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ src.models package
2
+ ==================
3
+
4
+ .. automodule:: src.models
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+
9
+ Submodules
10
+ ----------
11
+
12
+ .. toctree::
13
+ :maxdepth: 4
14
+
15
+ src.models.dummy_local_llm
16
+ src.models.local_llm
17
+ src.models.new_local_llm
18
+ src.models.server_llm
19
+ src.models.updatable_worker
20
+ src.models.vllm_worker_wrap
src_code_for_reproducibility/docs/source/src.training.ppo_train_value_head.rst ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ src.training.ppo\_train\_value\_head module
2
+ ===========================================
3
+
4
+ .. automodule:: src.training.ppo_train_value_head
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
src_code_for_reproducibility/markov_games/__init__.py ADDED
File without changes
src_code_for_reproducibility/markov_games/agent.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In simple RL paradise, where the action dimensions are constant and well defined,
3
+ Agent classes are not necessary. But in MARL, with LLM's, there isn't always
4
+ a direct path from policy to action. For instance, from the observation of the environment,
5
+ a prompt must be created. Then, the outputs of the policy might be incorrect, so a second
6
+ request to the LLM must be sent before the action is well defined. This is why this Agent class exists.
7
+ It acts as a mini environment, bridging the gap between the core simulation and
8
+ the LLM policies.
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from collections.abc import Callable
13
+ from typing import Any, Tuple
14
+
15
+ from numpy.random import default_rng
16
+
17
+ from mllm.markov_games.rollout_tree import AgentActLog
18
+
19
+
20
+ class Agent(ABC):
21
+ @abstractmethod
22
+ def __init__(
23
+ self,
24
+ seed: int,
25
+ agent_id: str,
26
+ agent_name: str,
27
+ agent_policy: Callable[[list[dict]], str],
28
+ *args,
29
+ **kwargs,
30
+ ):
31
+ """
32
+ Initialize the agent state.
33
+ """
34
+ self.seed = seed
35
+ self.agent_id = agent_id
36
+ self.agent_name = agent_name
37
+ self.policy = policy
38
+ self.rng = default_rng(self.seed)
39
+ raise NotImplementedError
40
+
41
+ async def act(self, observation) -> Tuple[Any, AgentActLog]:
42
+ """
43
+ Query (possibly multiple times) a policy (or possibly a pool of policies) to
44
+ obtain the action of the agent.
45
+
46
+ Example:
47
+ action = None
48
+ prompt = self.observation_to_prompt(observation)
49
+ while not self.valid(action):
50
+ output = await self.policy.generate(prompt)
51
+ action = self.policy_output_to_action(output)
52
+ return action
53
+
54
+ Returns:
55
+ action
56
+ step_info
57
+ """
58
+ raise NotImplementedError
59
+
60
+ def get_safe_copy(self):
61
+ """
62
+ Return copy of the agent object that is decorrelated from the original object.
63
+ """
64
+ raise NotImplementedError
65
+
66
+ def reset(self):
67
+ raise NotImplementedError
68
+
69
+ def render(self):
70
+ raise NotImplementedError
71
+
72
+ def close(self):
73
+ raise NotImplementedError
74
+
75
+ def get_agent_info(self):
76
+ raise NotImplementedError
src_code_for_reproducibility/markov_games/alternative_actions_runner.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import copy
3
+ import json
4
+ import os.path
5
+ from typing import Any, Tuple
6
+
7
+ from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame
8
+ from mllm.markov_games.rollout_tree import (
9
+ AgentActLog,
10
+ RolloutTreeBranchNode,
11
+ RolloutTreeNode,
12
+ RolloutTreeRootNode,
13
+ StepLog,
14
+ )
15
+
16
+ AgentId = str
17
+
18
+
19
+
20
+ async def run_with_unilateral_alt_action(
21
+ markov_game: MarkovGame,
22
+ agent_id: AgentId,
23
+ time_step: int,
24
+ branch_node: RolloutTreeBranchNode,
25
+ max_depth: int,
26
+ ):
27
+ """
28
+ This function is used to generate a new branch for a given agent.
29
+ """
30
+
31
+ # Generate alternative action and take a step
32
+ await markov_game.set_action_of_agent(agent_id)
33
+ terminated: bool = markov_game.take_simulation_step()
34
+ step_log = markov_game.get_step_log()
35
+ first_alternative_node = RolloutTreeNode(
36
+ step_log=step_log,
37
+ time_step=time_step,
38
+ )
39
+
40
+ # Generate rest of trajectory up to max depth
41
+ time_step += 1
42
+ counter = 1
43
+ previous_node = first_alternative_node
44
+ while not terminated and counter <= max_depth:
45
+ terminated, step_log = await markov_game.step()
46
+ current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
47
+ previous_node.child = current_node
48
+ previous_node = current_node
49
+ counter += 1
50
+ time_step += 1
51
+
52
+ if branch_node.branches == None:
53
+ branch_node.branches = {agent_id: [first_alternative_node]}
54
+ else:
55
+ agent_branches = branch_node.branches.get(agent_id, [])
56
+ agent_branches.append(first_alternative_node)
57
+ branch_node.branches[agent_id] = agent_branches
58
+
59
+
60
+ async def AlternativeActionsRunner(
61
+ markov_game: MarkovGame,
62
+ output_folder: str,
63
+ nb_alternative_actions: int,
64
+ max_depth: int,
65
+ branch_only_on_new_round: bool = False,
66
+ ):
67
+ """
68
+ This method generates a trajectory with partially completed branches,
69
+ where the branching comes from taking unilateraly different actions.
70
+ The resulting data is used to estimate the updated advantage alignment policy gradient terms.
71
+ Let k := nb_sub_steps. Then the number of steps generated is O(Tk), where T is
72
+ the maximum trajectory length.
73
+ """
74
+
75
+ tasks = []
76
+ time_step = 0
77
+ terminated = False
78
+ root = RolloutTreeRootNode(
79
+ id=markov_game.get_id(),
80
+ crn_id=markov_game.get_crn_id()
81
+ )
82
+ previous_node = root
83
+
84
+ while not terminated:
85
+ mg_before_action = markov_game.get_safe_copy()
86
+
87
+ # Get safe copies for main branch
88
+ agent_action_safe_copies: dict[
89
+ AgentId, AgentAndActionSafeCopy
90
+ ] = await markov_game.get_actions_of_agents_without_side_effects()
91
+
92
+ markov_game.set_actions_of_agents_manually(agent_action_safe_copies)
93
+ terminated = markov_game.take_simulation_step()
94
+ main_node = RolloutTreeNode(
95
+ step_log=markov_game.get_step_log(), time_step=time_step
96
+ )
97
+ branch_node = RolloutTreeBranchNode(main_child=main_node)
98
+ previous_node.child = branch_node
99
+ previous_node = main_node
100
+
101
+ # Get alternative branches by generating new unilateral actions
102
+ for agent_id in markov_game.agent_ids:
103
+ for _ in range(nb_alternative_actions):
104
+ # Get safe copies for branches
105
+ branch_agent_action_safe_copies: dict[
106
+ AgentId, AgentAndActionSafeCopy
107
+ ] = {
108
+ agent_id: AgentAndActionSafeCopy(
109
+ action=copy.deepcopy(agent_action_safe_copy.action),
110
+ action_info=copy.deepcopy(agent_action_safe_copy.action_info),
111
+ agent_after_action=agent_action_safe_copy.agent_after_action.get_safe_copy(),
112
+ )
113
+ for agent_id, agent_action_safe_copy in agent_action_safe_copies.items()
114
+ }
115
+ mg_branch: MarkovGame = mg_before_action.get_safe_copy()
116
+ other_agent_id = [id for id in mg_branch.agent_ids if id != agent_id][0]
117
+ mg_branch.set_action_and_agent_after_action_manually(
118
+ agent_id=other_agent_id,
119
+ agent_action_safe_copy=branch_agent_action_safe_copies[
120
+ other_agent_id
121
+ ],
122
+ )
123
+ task = asyncio.create_task(
124
+ run_with_unilateral_alt_action(
125
+ markov_game=mg_branch,
126
+ time_step=time_step,
127
+ agent_id=agent_id,
128
+ branch_node=branch_node,
129
+ max_depth=max_depth,
130
+ )
131
+ )
132
+ tasks.append(task)
133
+ time_step += 1
134
+
135
+ # wait for all branches to complete
136
+ await asyncio.gather(*tasks)
137
+
138
+ return root
src_code_for_reproducibility/markov_games/group_timesteps.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains the logic for grouping time steps.
3
+ """
4
+ import copy
5
+ from typing import Callable
6
+
7
+ from mllm.markov_games.markov_game import MarkovGame
8
+ from mllm.markov_games.rollout_tree import (
9
+ AgentActLog,
10
+ RolloutTreeBranchNode,
11
+ RolloutTreeNode,
12
+ RolloutTreeRootNode,
13
+ StepLog,
14
+ )
15
+ from mllm.markov_games.simulation import SimulationStepLog
16
+
17
+ AgentId = str
18
+
19
+
20
+ def group_time_steps(
21
+ rollout_tree: RolloutTreeRootNode,
22
+ accumulation_stop_condition: Callable[[StepLog], bool],
23
+ ) -> RolloutTreeRootNode:
24
+ """
25
+ During generation, we create rollout trees according to the real time steps.
26
+ However, during training, we might want to treat groups of time steps as a single time step.
27
+ As a concrete example, take Trust-and-Split. At each round, say we have X time steps of communication and then one time step for the split.
28
+ Then the communication actions will not get any reward, and the split action will get the reward. During REINFORCE training, with discounting, this
29
+ can cause training instability. We could instead treat every action in the round as being part of a single action, and give it the reward of the split action.
30
+ This method helps to do this sort of grouping.
31
+ It accumulates actions until the accumulation_stop_condition is met, and then creates a new node with the accumulated actions.
32
+ It then recursively calls itself on the child node.
33
+ Details:
34
+ - The reward for the group is the reward of the last time step in the group.
35
+ - The simulation log for the group is the simulation log of the last time step in the group.
36
+ - The state end for the group becomes the first state end in the group.
37
+ - The agent info for the group is the agent info of the last time step in the group.
38
+ """
39
+
40
+ def group_step_logs(step_logs: list[StepLog]) -> StepLog:
41
+ """
42
+ Concatenate per-agent chat turns across steps; keep only the first is_state_end.
43
+ """
44
+ last_sim_log = step_logs[-1].simulation_step_log
45
+ agent_ids = {aid for s in step_logs for aid in s.action_logs.keys()}
46
+ grouped_logs: dict[AgentId, AgentActLog] = {}
47
+ for aid in agent_ids:
48
+ turns = []
49
+ for s in step_logs:
50
+ act = s.action_logs.get(aid)
51
+ if act and act.chat_turns:
52
+ turns.extend(copy.deepcopy(act.chat_turns))
53
+ disable_is_state_end = False
54
+ # Only the first state_end should be True, the rest should be False
55
+ for t in turns:
56
+ if t.is_state_end:
57
+ if disable_is_state_end:
58
+ t.is_state_end = False
59
+ else:
60
+ disable_is_state_end = True
61
+ continue
62
+ grouped_logs[aid] = AgentActLog(
63
+ chat_turns=turns, info=step_logs[-1].action_logs[aid].info
64
+ )
65
+ return StepLog(action_logs=grouped_logs, simulation_step_log=last_sim_log)
66
+
67
+ def group_time_steps_rec(
68
+ current_node: RolloutTreeNode | RolloutTreeBranchNode,
69
+ group_time_step: int,
70
+ accumulation_step_logs: list[StepLog],
71
+ ) -> RolloutTreeNode | RolloutTreeBranchNode:
72
+ """
73
+ Groups time steps. Recursion is used to handle branches.
74
+ """
75
+ assert isinstance(current_node, RolloutTreeNode) or isinstance(
76
+ current_node, RolloutTreeBranchNode
77
+ ), "Current node must be a tree node or a branch node. Is of type: " + str(
78
+ type(current_node)
79
+ )
80
+ first_group_node = None
81
+ current_group_node = None
82
+ while current_node is not None:
83
+ if isinstance(current_node, RolloutTreeBranchNode):
84
+ raise Exception(
85
+ "Grouping timesteps by round is not supported for branching trajectories yet."
86
+ )
87
+ # Special recursive case for branches
88
+ # if isinstance(current_node, RolloutTreeBranchNode):
89
+ # branches = {}
90
+ # for agent_id, branch_nodes in current_node.branches.items():
91
+ # branch_group_nodes = []
92
+ # for branch_node in branch_nodes:
93
+ # branch_group_node = group_time_steps_rec(
94
+ # current_node=branch_node,
95
+ # group_time_step=group_time_step,
96
+ # accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
97
+ # branch_group_nodes.append(branch_group_node)
98
+ # branches[agent_id] = branch_group_nodes
99
+
100
+ # main_child_group_node = group_time_steps_rec(
101
+ # current_node=current_node.main_child,
102
+ # group_time_step=group_time_step,
103
+ # accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
104
+
105
+ # return RolloutTreeBranchNode(main_child=main_child_group_node, branches=branches)
106
+
107
+ # Accumulate
108
+ accumulation_step_logs.append(current_node.step_log)
109
+ if accumulation_stop_condition(current_node.step_log):
110
+ grouped_step_logs = group_step_logs(accumulation_step_logs)
111
+ accumulation_step_logs = []
112
+ new_group_node = RolloutTreeNode(
113
+ step_log=grouped_step_logs, time_step=group_time_step, child=None
114
+ )
115
+ if first_group_node == None:
116
+ first_group_node = new_group_node
117
+ group_time_step += 1
118
+ if current_group_node is not None:
119
+ current_group_node.child = new_group_node
120
+ current_group_node = new_group_node
121
+ current_node = current_node.child
122
+ return first_group_node
123
+
124
+ node = group_time_steps_rec(
125
+ current_node=rollout_tree.child, group_time_step=0, accumulation_step_logs=[]
126
+ )
127
+ return RolloutTreeRootNode(
128
+ id=rollout_tree.id,
129
+ crn_id=rollout_tree.crn_id,
130
+ child=node,
131
+ agent_ids=rollout_tree.agent_ids,
132
+ )
133
+
134
+
135
+ def stop_when_round_ends(step_log: StepLog) -> bool:
136
+ """
137
+ Simplest stop condition. Will return True if step log is the last time step of a round.
138
+ This will throw an error if this information is not available in the simulation info.
139
+ """
140
+ assert (
141
+ "is_last_timestep_in_round" in step_log.simulation_step_log.info.keys()
142
+ ), "To group by round, is_last_timestep_in_round must be set in the info of your simulation step log at each time step."
143
+ return step_log.simulation_step_log.info["is_last_timestep_in_round"]
144
+
145
+
146
+ def group_by_round(rollout_tree: RolloutTreeRootNode) -> RolloutTreeRootNode:
147
+ """
148
+ Groups time steps by round.
149
+ """
150
+ return group_time_steps(rollout_tree, stop_when_round_ends)
src_code_for_reproducibility/markov_games/linear_runner.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import os.path
4
+
5
+ from mllm.markov_games.markov_game import MarkovGame
6
+ from mllm.markov_games.rollout_tree import RolloutTreeNode, RolloutTreeRootNode
7
+
8
+
9
+ async def LinearRunner(
10
+ markov_game: MarkovGame, output_folder: str
11
+ ) -> RolloutTreeRootNode:
12
+ """
13
+ This method generates a trajectory without branching.
14
+ """
15
+ time_step = 0
16
+ terminated = False
17
+ root = RolloutTreeRootNode(
18
+ id=markov_game.get_id(),
19
+ crn_id=markov_game.get_crn_id(),
20
+ agent_ids=markov_game.get_agent_ids(),
21
+ )
22
+ previous_node = root
23
+ while not terminated:
24
+ terminated, step_log = await markov_game.step()
25
+ current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
26
+ previous_node.child = current_node
27
+ previous_node = current_node
28
+ time_step += 1
29
+
30
+ return root
src_code_for_reproducibility/markov_games/markov_game.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
3
+ In a MarkovGame step,
4
+ 1) each agent takes an action,
5
+ 2) the state transitions with respect to these actions,
6
+ 3) all relevant data of the step is appended to the historical data list
7
+
8
+ In order to perform 3), the agents and the simulation are expected, at each time step,
9
+ to return a log of the state transition (from their perspective).
10
+ For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
11
+ A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
12
+ The approach we use here centralizes the data gathering aspect,
13
+ making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
14
+ only log information for step transitions occuring after the branching out.
15
+ """
16
+ import asyncio
17
+ import copy
18
+ import json
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import Any, List, Literal, Optional, Tuple
22
+
23
+ from transformers.models.idefics2 import Idefics2Config
24
+
25
+ from mllm.markov_games.agent import Agent
26
+ from mllm.markov_games.rollout_tree import AgentActLog, StepLog
27
+ from mllm.markov_games.simulation import Simulation
28
+
29
+ AgentId = str
30
+
31
+
32
+ @dataclass
33
+ class AgentAndActionSafeCopy:
34
+ action: Any
35
+ action_info: AgentActLog
36
+ agent_after_action: type[Agent]
37
+
38
+
39
+ class MarkovGame(object):
40
+ def __init__(
41
+ self,
42
+ id: int,
43
+ agents: dict[AgentId, type[Agent]],
44
+ simulation: type[Simulation],
45
+ crn_id: int,
46
+ ):
47
+ """
48
+ Args:
49
+ agents:
50
+ output_path:
51
+ Path where the step infos are saved.
52
+ simulation:
53
+ Simulation object. Example: IPDSimulation
54
+ """
55
+ self.agents = agents
56
+ self.agent_ids = self.agents.keys()
57
+ self.simulation = simulation
58
+ self.simulation_step_log = None
59
+ self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids}
60
+ self.actions = {}
61
+ self.id = id
62
+ self.crn_id = crn_id
63
+
64
+ def get_id(self) -> str:
65
+ return self.id
66
+
67
+ def get_crn_id(self) -> int:
68
+ return self.crn_id
69
+
70
+ def get_agent_ids(self) -> List[AgentId]:
71
+ return list(self.agent_ids)
72
+
73
+ async def get_action_of_agent_without_side_effects(
74
+ self, agent_id: AgentId
75
+ ) -> Tuple[Any, AgentActLog]:
76
+ """
77
+ Safe function to get an action of an agent without modifying the agent or the simulation.
78
+ """
79
+ agent = self.agents[agent_id]
80
+ agent_before_action = agent.get_safe_copy()
81
+ obs = self.simulation.get_obs_agent(agent_id)
82
+ action, action_info = await agent.act(observation=obs)
83
+ self.agents[agent_id] = agent_before_action
84
+ agent_after_action = agent.get_safe_copy()
85
+ return AgentAndActionSafeCopy(action, action_info, agent_after_action)
86
+
87
+ async def get_actions_of_agents_without_side_effects(
88
+ self,
89
+ ) -> dict[AgentId, AgentAndActionSafeCopy]:
90
+ """
91
+ Safe function to get an action of an agent without modifying the agent or the simulation.
92
+ """
93
+ tasks = []
94
+ for agent_id in self.agent_ids:
95
+ task = asyncio.create_task(
96
+ self.get_action_of_agent_without_side_effects(agent_id)
97
+ )
98
+ tasks.append(task)
99
+ agent_and_action_safe_copies: list[
100
+ AgentAndActionSafeCopy
101
+ ] = await asyncio.gather(*tasks)
102
+ return {
103
+ agent_id: agent_and_action_safe_copy
104
+ for agent_id, agent_and_action_safe_copy in zip(
105
+ self.agent_ids, agent_and_action_safe_copies
106
+ )
107
+ }
108
+
109
+ def set_action_and_agent_after_action_manually(
110
+ self,
111
+ agent_id: AgentId,
112
+ agent_action_safe_copy: AgentAndActionSafeCopy,
113
+ ):
114
+ """
115
+ Set the action and the agent after action manually.
116
+ """
117
+ self.actions[agent_id] = agent_action_safe_copy.action
118
+ self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info
119
+ self.agents[agent_id] = agent_action_safe_copy.agent_after_action
120
+
121
+ def set_actions_of_agents_manually(
122
+ self, actions: dict[AgentId, AgentAndActionSafeCopy]
123
+ ):
124
+ """
125
+ Set the actions of agents manually.
126
+ """
127
+ for agent_id, agent_action_safe_copy in actions.items():
128
+ self.set_action_and_agent_after_action_manually(
129
+ agent_id, agent_action_safe_copy
130
+ )
131
+
132
+ async def set_action_of_agent(self, agent_id: AgentId):
133
+ """
134
+ TOWRITE
135
+ """
136
+ agent = self.agents[agent_id]
137
+ obs = self.simulation.get_obs_agent(agent_id)
138
+ action, action_info = await agent.act(observation=obs)
139
+ self.actions[agent_id] = action
140
+ self.agent_step_logs[agent_id] = action_info
141
+
142
+ async def set_actions(self):
143
+ """
144
+ TOWRITE
145
+ """
146
+ # background_tasks = set()
147
+ tasks = []
148
+ for agent_id in self.agent_ids:
149
+ task = asyncio.create_task(self.set_action_of_agent(agent_id))
150
+ tasks.append(task)
151
+ await asyncio.gather(*tasks)
152
+
153
+ def take_simulation_step(self):
154
+ """
155
+ TOWRITE
156
+ """
157
+ terminated, self.simulation_step_log = self.simulation.step(self.actions)
158
+ return terminated
159
+
160
+ def get_step_log(self) -> StepLog:
161
+ """
162
+ TOWRITE
163
+ TODO: assert actions and simulation have taken step
164
+ """
165
+ step_log = StepLog(
166
+ simulation_step_log=self.simulation_step_log,
167
+ action_logs=self.agent_step_logs,
168
+ )
169
+ return step_log
170
+
171
+ async def step(self) -> Tuple[bool, StepLog]:
172
+ """
173
+ TOWRITE
174
+ """
175
+ await self.set_actions()
176
+ terminated = self.take_simulation_step()
177
+ step_log = self.get_step_log()
178
+ return terminated, step_log
179
+
180
+ def get_safe_copy(self):
181
+ """
182
+ TOWRITE
183
+ """
184
+
185
+ new_markov_game = copy.copy(self)
186
+ new_simulation = self.simulation.get_safe_copy()
187
+ new_agents = {
188
+ agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items()
189
+ }
190
+
191
+ # Reassign copied components
192
+ new_markov_game.simulation = new_simulation
193
+ new_markov_game.agents = new_agents
194
+
195
+ # IMPORTANT: ensure agent_ids references the new agents dict, not the original
196
+ new_markov_game.agent_ids = new_markov_game.agents.keys()
197
+
198
+ # Deep-copy step data to avoid correlation
199
+ new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log)
200
+ new_markov_game.actions = copy.deepcopy(self.actions)
201
+ # Rebuild logs to align exactly with new agent ids
202
+ old_agent_step_logs = copy.deepcopy(self.agent_step_logs)
203
+ new_markov_game.agent_step_logs = {
204
+ agent_id: old_agent_step_logs.get(agent_id)
205
+ for agent_id in new_markov_game.agent_ids
206
+ }
207
+
208
+ return new_markov_game
src_code_for_reproducibility/markov_games/mg_utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import copy
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass
5
+
6
+ from mllm.markov_games.ipd.ipd_agent import IPDAgent
7
+ from mllm.markov_games.ipd.ipd_simulation import IPD
8
+ from mllm.markov_games.markov_game import MarkovGame
9
+ from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
10
+ from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation
11
+ from mllm.markov_games.negotiation.nego_hard_coded_policies import (
12
+ HardCodedNegoGreedyPolicy,
13
+ HardCodedNegoWelfareMaximizingPolicy,
14
+ )
15
+ from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
16
+ from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
17
+ from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
18
+ from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
19
+ from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
20
+ from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
21
+ from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent
22
+ from mllm.markov_games.negotiation.tas_simple_simulation import (
23
+ TrustAndSplitSimpleSimulation,
24
+ )
25
+ from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
26
+ from mllm.markov_games.rollout_tree import (
27
+ AgentActLog,
28
+ RolloutTreeBranchNode,
29
+ RolloutTreeNode,
30
+ RolloutTreeRootNode,
31
+ StepLog,
32
+ )
33
+ from mllm.markov_games.simulation import SimulationStepLog
34
+
35
+ AgentId = str
36
+
37
+
38
+ @dataclass
39
+ class AgentConfig:
40
+ agent_id: str
41
+ agent_name: str
42
+ agent_class_name: str
43
+ policy_id: str
44
+ init_kwargs: dict
45
+
46
+
47
+ @dataclass
48
+ class MarkovGameConfig:
49
+ id: int
50
+ seed: int
51
+ simulation_class_name: str
52
+ simulation_init_args: dict
53
+ agent_configs: list[AgentConfig]
54
+
55
+
56
+ def init_markov_game_components(
57
+ config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
58
+ ):
59
+ """
60
+ TOWRITE
61
+ """
62
+ agents = {}
63
+ agent_names = []
64
+ for agent_config in config.agent_configs:
65
+ agent_id = agent_config.agent_id
66
+ agent_name = agent_config.agent_name
67
+ agent_class = eval(agent_config.agent_class_name)
68
+ agent = agent_class(
69
+ seed=config.seed,
70
+ agent_id=agent_id,
71
+ agent_name=agent_name,
72
+ policy=policies[agent_config.policy_id],
73
+ **agent_config.init_kwargs,
74
+ )
75
+ agents[agent_id] = agent
76
+ agent_names.append(agent_name)
77
+ simulation = eval(config.simulation_class_name)(
78
+ seed=config.seed,
79
+ agent_ids=list(agents.keys()),
80
+ agent_names=agent_names,
81
+ **config.simulation_init_args,
82
+ )
83
+ markov_game = MarkovGame(
84
+ id=config.id,
85
+ crn_id=config.seed,
86
+ agents=agents,
87
+ simulation=simulation,
88
+ )
89
+ return markov_game
src_code_for_reproducibility/markov_games/rollout_tree.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: add parent to nodes so that some verification can be done. For instance, to ensure that node reward keys match the parent node.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import json
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, List, Literal, Optional, Tuple
11
+
12
+ import jsonschema
13
+ from pydantic import BaseModel, Field, model_validator
14
+
15
+ from mllm.chat_utils.chat_turn import ChatTurn
16
+
17
+ AgentId = str
18
+
19
+
20
+ class SimulationStepLog(BaseModel):
21
+ rewards: dict[AgentId, float]
22
+ info: Any = None
23
+
24
+
25
+ class AgentActLog(BaseModel):
26
+ chat_turns: list[ChatTurn] | None
27
+ info: Any = None
28
+
29
+ @model_validator(mode="after")
30
+ def _exactly_one_state_end(self):
31
+ """
32
+ This method is used to enforce that for each AgentActLog, there is exactly one ChatTurn which is a state end.
33
+ """
34
+ if self.chat_turns != []:
35
+ n = sum(1 for t in self.chat_turns if t.is_state_end)
36
+ if n != 1:
37
+ raise ValueError(
38
+ f"AgentActLog must have exactly one ChatTurn with is_state_end=True; got {self.chat_turns}."
39
+ )
40
+ return self
41
+ else:
42
+ return self
43
+
44
+
45
+ class StepLog(BaseModel):
46
+ action_logs: dict[AgentId, AgentActLog]
47
+ simulation_step_log: SimulationStepLog
48
+
49
+
50
+ # BranchType = Literal["unilateral_deviation", "common_deviation"] # might not be necessary
51
+ # class BranchNodeInfo(BaseModel):
52
+ # branch_id: str
53
+ # branch_for: AgentId
54
+ # branch_type: BranchType
55
+
56
+
57
+ class RolloutTreeNode(BaseModel):
58
+ step_log: StepLog
59
+ time_step: int
60
+ child: RolloutTreeNode | RolloutTreeBranchNode | None = None
61
+
62
+
63
+ class RolloutTreeBranchNode(BaseModel):
64
+ """
65
+ First item of the tuple indicates which agent "called" for an alternative branch.
66
+ """
67
+
68
+ main_child: RolloutTreeNode
69
+ branches: dict[AgentId, list[RolloutTreeNode]] | None = None
70
+
71
+
72
+ class RolloutTreeRootNode(BaseModel):
73
+ id: int
74
+ crn_id: int # ID of the rng used to generate this rollout tree
75
+ child: RolloutTreeNode | RolloutTreeBranchNode | None = None
76
+ agent_ids: List[AgentId] = Field(min_length=1)
77
+
78
+
79
+ # class RolloutTreeLeafNode(BaseModel):
80
+ # step_log: StepLog
81
+ # time_step: int
82
+
83
+
84
+ # Necessary for self-referential stuff in pydantic
85
+ RolloutTreeBranchNode.model_rebuild()
86
+ RolloutTreeNode.model_rebuild()
src_code_for_reproducibility/markov_games/run_markov_games.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass
4
+
5
+ from torch._C import ClassType
6
+
7
+ from mllm.markov_games.markov_game import MarkovGame
8
+ from mllm.markov_games.rollout_tree import RolloutTreeRootNode
9
+
10
+
11
+ async def run_markov_games(
12
+ runner: Callable[[MarkovGame], RolloutTreeRootNode],
13
+ runner_kwargs: dict,
14
+ output_folder: str,
15
+ markov_games: list[MarkovGame],
16
+ ) -> list[RolloutTreeRootNode]:
17
+ tasks = []
18
+ for mg in markov_games:
19
+ tasks.append(
20
+ asyncio.create_task(
21
+ runner(markov_game=mg, output_folder=output_folder, **runner_kwargs)
22
+ )
23
+ )
24
+ return await asyncio.gather(*tasks)
src_code_for_reproducibility/markov_games/simulation.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A Simulation is the environment of a Markov Game.
3
+ The Simulation is not responsible for properly checking / formatting the responses of LLM's.
4
+ This is the job of the `Agent` class.
5
+ Simulations expect clean actions, and are defined similarly to `gymnasium` environments, except that they are adapted for the Multi-agent setting.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Tuple
10
+
11
+ from numpy.random import default_rng
12
+
13
+ from mllm.markov_games.rollout_tree import SimulationStepLog
14
+
15
+
16
+ class Simulation(ABC):
17
+ @abstractmethod
18
+ def __init__(self, seed: int, *args, **kwargs):
19
+ self.seed = seed
20
+ self.rng = default_rng(self.seed)
21
+
22
+ @abstractmethod
23
+ def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
24
+ """
25
+ Returns terminated, info
26
+ """
27
+ raise NotImplementedError
28
+
29
+ def get_obs(self):
30
+ """Returns all agent observations in dict
31
+
32
+ Returns:
33
+ observations
34
+ """
35
+ raise NotImplementedError
36
+
37
+ def get_obs_agent(self, agent_id):
38
+ """Returns observation for agent_id"""
39
+ raise NotImplementedError
40
+
41
+ def get_obs_size(self):
42
+ """Returns the shape of the observation"""
43
+ raise NotImplementedError
44
+
45
+ def get_state(self):
46
+ raise NotImplementedError
47
+
48
+ def get_state_size(self):
49
+ """Returns the shape of the state"""
50
+ raise NotImplementedError
51
+
52
+ def get_avail_actions(self):
53
+ raise NotImplementedError
54
+
55
+ def get_avail_agent_actions(self, agent_id):
56
+ """Returns the available actions for agent_id"""
57
+ raise NotImplementedError
58
+
59
+ def get_total_actions(self):
60
+ """Returns the total number of actions an agent could ever take"""
61
+ # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
62
+ raise NotImplementedError
63
+
64
+ def get_safe_copy(self):
65
+ """
66
+ Return copy of the agent object that is decorrelated from the original object.
67
+ """
68
+ raise NotImplementedError
69
+
70
+ def reset(self):
71
+ """Returns initial observations and states"""
72
+ raise NotImplementedError
73
+
74
+ def render(self):
75
+ raise NotImplementedError
76
+
77
+ def close(self):
78
+ raise NotImplementedError
79
+
80
+ # def seed(self):
81
+ # raise NotImplementedError
82
+
83
+ def save_replay(self):
84
+ raise NotImplementedError
85
+
86
+ def get_simulation_info(self):
87
+ raise NotImplementedError
src_code_for_reproducibility/markov_games/statistics_runner.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import json
5
+ import pickle
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional
9
+
10
+ from basic_render import find_iteration_folders
11
+
12
+ from mllm.markov_games.rollout_tree import (
13
+ RolloutTreeBranchNode,
14
+ RolloutTreeNode,
15
+ RolloutTreeRootNode,
16
+ SimulationStepLog,
17
+ )
18
+
19
+
20
+ def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]:
21
+ """
22
+ Iterate the main path nodes without materializing full path lists.
23
+ """
24
+ current = root.child
25
+ while current is not None:
26
+ if isinstance(current, RolloutTreeNode):
27
+ yield current
28
+ current = current.child
29
+ elif isinstance(current, RolloutTreeBranchNode):
30
+ # Follow only the main child on the main trajectory
31
+ current = current.main_child
32
+ else:
33
+ break
34
+
35
+
36
+ def iterate_main_simulation_logs(
37
+ root: RolloutTreeRootNode,
38
+ ) -> Iterator[SimulationStepLog]:
39
+ for node in _iterate_main_nodes(root):
40
+ yield node.step_log.simulation_step_log
41
+
42
+
43
+ def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
44
+ for p in iteration_folder.rglob("*.rt.pkl"):
45
+ if p.is_file():
46
+ yield p
47
+
48
+
49
+ def load_root(path: Path) -> RolloutTreeRootNode:
50
+ with open(path, "rb") as f:
51
+ data = pickle.load(f)
52
+ return RolloutTreeRootNode.model_validate(data)
53
+
54
+
55
+ @dataclass
56
+ class StatRecord:
57
+ mgid: int
58
+ crn_id: Optional[int]
59
+ iteration: str
60
+ values: Dict[str, Any]
61
+
62
+
63
+ class StatComputer:
64
+ """
65
+ Stateful stat computer that consumes SimulationStepLog instances
66
+ and produces final aggregated values for one rollout (mgid).
67
+ """
68
+
69
+ def update(self, sl: SimulationStepLog) -> None: # pragma: no cover - interface
70
+ raise NotImplementedError
71
+
72
+ def finalize(self) -> Dict[str, Any]: # pragma: no cover - interface
73
+ raise NotImplementedError
74
+
75
+
76
+ def run_stats(
77
+ data_root: Path,
78
+ game_name: str,
79
+ make_computers: Callable[[], List[StatComputer]],
80
+ output_filename: Optional[str] = None,
81
+ output_format: str = "json", # "json" (dict of lists) or "jsonl"
82
+ ) -> Path:
83
+ """
84
+ Compute stats across all iteration_* folders under data_root.
85
+ Writes JSONL to data_root/statistics/<output_filename or f"{game_name}.stats.jsonl">.
86
+ """
87
+ data_root = Path(data_root)
88
+ outdir = data_root / "statistics"
89
+ outdir.mkdir(parents=True, exist_ok=True)
90
+ # Choose extension by format
91
+ default_name = (
92
+ f"{game_name}.stats.json"
93
+ if output_format == "json"
94
+ else f"{game_name}.stats.jsonl"
95
+ )
96
+ outfile = outdir / (
97
+ output_filename if output_filename is not None else default_name
98
+ )
99
+
100
+ # Rewrite file each run to keep it clean and small
101
+ if outfile.exists():
102
+ outfile.unlink()
103
+
104
+ iteration_folders = find_iteration_folders(str(data_root))
105
+
106
+ # If writing JSONL, stream directly; otherwise accumulate minimal records
107
+ if output_format == "jsonl":
108
+ with open(outfile, "w", encoding="utf-8") as w:
109
+ for iteration_folder in iteration_folders:
110
+ iteration_name = Path(iteration_folder).name
111
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
112
+ root = load_root(pkl_path)
113
+
114
+ computers = make_computers()
115
+ for sl in iterate_main_simulation_logs(root):
116
+ for comp in computers:
117
+ try:
118
+ comp.update(sl)
119
+ except Exception:
120
+ continue
121
+
122
+ values: Dict[str, Any] = {}
123
+ for comp in computers:
124
+ try:
125
+ values.update(comp.finalize())
126
+ except Exception:
127
+ continue
128
+
129
+ rec = {
130
+ "mgid": getattr(root, "id", None),
131
+ "crn_id": getattr(root, "crn_id", None),
132
+ "iteration": iteration_name,
133
+ "stats": values,
134
+ }
135
+ w.write(json.dumps(rec, ensure_ascii=False) + "\n")
136
+
137
+ del root
138
+ del computers
139
+ gc.collect()
140
+ else:
141
+ # Aggregate to dict-of-lists for easier plotting
142
+ records: List[Dict[str, Any]] = []
143
+ # Process in deterministic order
144
+ for iteration_folder in iteration_folders:
145
+ iteration_name = Path(iteration_folder).name
146
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
147
+ root = load_root(pkl_path)
148
+
149
+ computers = make_computers()
150
+ for sl in iterate_main_simulation_logs(root):
151
+ for comp in computers:
152
+ try:
153
+ comp.update(sl)
154
+ except Exception:
155
+ continue
156
+
157
+ values: Dict[str, Any] = {}
158
+ for comp in computers:
159
+ try:
160
+ values.update(comp.finalize())
161
+ except Exception:
162
+ continue
163
+
164
+ records.append(
165
+ {
166
+ "mgid": getattr(root, "id", None),
167
+ "crn_id": getattr(root, "crn_id", None),
168
+ "iteration": iteration_name,
169
+ "stats": values,
170
+ }
171
+ )
172
+
173
+ del root
174
+ del computers
175
+ gc.collect()
176
+
177
+ # Build dict-of-lists with nested stats preserved
178
+ # Collect all stat keys and nested agent keys where needed
179
+ mgids: List[Any] = []
180
+ crn_ids: List[Any] = []
181
+ iterations_out: List[str] = []
182
+ # stats_out is a nested structure mirroring keys but with lists
183
+ stats_out: Dict[str, Any] = {}
184
+
185
+ # First pass to collect union of keys
186
+ stat_keys: set[str] = set()
187
+ nested_agent_keys: Dict[str, set[str]] = {}
188
+ for r in records:
189
+ stats = r.get("stats", {}) or {}
190
+ for k, v in stats.items():
191
+ stat_keys.add(k)
192
+ if isinstance(v, dict):
193
+ nested = nested_agent_keys.setdefault(k, set())
194
+ for ak in v.keys():
195
+ nested.add(str(ak))
196
+
197
+ # Initialize structure
198
+ for k in stat_keys:
199
+ if k in nested_agent_keys:
200
+ stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
201
+ else:
202
+ stats_out[k] = []
203
+
204
+ # Fill lists
205
+ for r in records:
206
+ mgids.append(r.get("mgid"))
207
+ crn_ids.append(r.get("crn_id"))
208
+ iterations_out.append(r.get("iteration"))
209
+ stats = r.get("stats", {}) or {}
210
+ for k in stat_keys:
211
+ val = stats.get(k)
212
+ if isinstance(stats_out[k], dict):
213
+ # per-agent dict
214
+ agent_dict = val if isinstance(val, dict) else {}
215
+ for ak in stats_out[k].keys():
216
+ stats_out[k][ak].append(agent_dict.get(ak))
217
+ else:
218
+ stats_out[k].append(val)
219
+
220
+ with open(outfile, "w", encoding="utf-8") as w:
221
+ json.dump(
222
+ {
223
+ "mgid": mgids,
224
+ "crn_id": crn_ids,
225
+ "iteration": iterations_out,
226
+ "stats": stats_out,
227
+ },
228
+ w,
229
+ ensure_ascii=False,
230
+ )
231
+
232
+ return outfile
233
+
234
+
235
+ def run_stats_functional(
236
+ data_root: Path,
237
+ game_name: str,
238
+ metrics: Dict[str, Callable[[SimulationStepLog], Optional[Dict[str, float]]]],
239
+ output_filename: Optional[str] = None,
240
+ output_format: str = "json",
241
+ ) -> Path:
242
+ """
243
+ Functional variant where metrics is a dict of name -> f(SimulationStepLog) -> {agent_id: value}.
244
+ Aggregates per rollout by averaging over steps where a metric produced a value.
245
+ Writes a single consolidated file in data_root/statistics/.
246
+ """
247
+ data_root = Path(data_root)
248
+ outdir = data_root / "statistics"
249
+ outdir.mkdir(parents=True, exist_ok=True)
250
+ default_name = (
251
+ f"{game_name}.stats.json"
252
+ if output_format == "json"
253
+ else f"{game_name}.stats.jsonl"
254
+ )
255
+ outfile = outdir / (
256
+ output_filename if output_filename is not None else default_name
257
+ )
258
+
259
+ if outfile.exists():
260
+ outfile.unlink()
261
+
262
+ iteration_folders = find_iteration_folders(str(data_root))
263
+
264
+ def finalize_rollout(
265
+ agg: Dict[str, Dict[str, List[float]]]
266
+ ) -> Dict[str, Dict[str, float]]:
267
+ # avg per metric per agent
268
+ result: Dict[str, Dict[str, float]] = {}
269
+ for mname, agent_values in agg.items():
270
+ result[mname] = {}
271
+ for aid, vals in agent_values.items():
272
+ if not vals:
273
+ result[mname][aid] = None # keep alignment; could be None
274
+ else:
275
+ result[mname][aid] = sum(vals) / len(vals)
276
+ return result
277
+
278
+ if output_format == "jsonl":
279
+ with open(outfile, "w", encoding="utf-8") as w:
280
+ for iteration_folder in iteration_folders:
281
+ iteration_name = Path(iteration_folder).name
282
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
283
+ root = load_root(pkl_path)
284
+
285
+ # aggregator structure: metric -> agent_id -> list of values
286
+ agg: Dict[str, Dict[str, List[float]]] = {
287
+ m: {} for m in metrics.keys()
288
+ }
289
+
290
+ for sl in iterate_main_simulation_logs(root):
291
+ for mname, fn in metrics.items():
292
+ try:
293
+ vals = fn(sl)
294
+ except Exception:
295
+ vals = None
296
+ if not vals:
297
+ continue
298
+ for aid, v in vals.items():
299
+ if v is None:
300
+ continue
301
+ lst = agg[mname].setdefault(str(aid), [])
302
+ try:
303
+ lst.append(float(v))
304
+ except Exception:
305
+ continue
306
+
307
+ values = finalize_rollout(agg)
308
+ rec = {
309
+ "mgid": getattr(root, "id", None),
310
+ "crn_id": getattr(root, "crn_id", None),
311
+ "iteration": iteration_name,
312
+ "stats": values,
313
+ }
314
+ w.write(json.dumps(rec, ensure_ascii=False) + "\n")
315
+
316
+ del root
317
+ gc.collect()
318
+ else:
319
+ records: List[Dict[str, Any]] = []
320
+ for iteration_folder in iteration_folders:
321
+ iteration_name = Path(iteration_folder).name
322
+ for pkl_path in stream_rollout_files(Path(iteration_folder)):
323
+ root = load_root(pkl_path)
324
+
325
+ agg: Dict[str, Dict[str, List[float]]] = {m: {} for m in metrics.keys()}
326
+ for sl in iterate_main_simulation_logs(root):
327
+ for mname, fn in metrics.items():
328
+ try:
329
+ vals = fn(sl)
330
+ except Exception:
331
+ vals = None
332
+ if not vals:
333
+ continue
334
+ for aid, v in vals.items():
335
+ if v is None:
336
+ continue
337
+ lst = agg[mname].setdefault(str(aid), [])
338
+ try:
339
+ lst.append(float(v))
340
+ except Exception:
341
+ continue
342
+
343
+ values = finalize_rollout(agg)
344
+ records.append(
345
+ {
346
+ "mgid": getattr(root, "id", None),
347
+ "crn_id": getattr(root, "crn_id", None),
348
+ "iteration": iteration_name,
349
+ "stats": values,
350
+ }
351
+ )
352
+
353
+ del root
354
+ gc.collect()
355
+
356
+ # Build dict-of-lists output
357
+ mgids: List[Any] = []
358
+ crn_ids: List[Any] = []
359
+ iterations_out: List[str] = []
360
+ stats_out: Dict[str, Any] = {}
361
+
362
+ stat_keys: set[str] = set()
363
+ nested_agent_keys: Dict[str, set[str]] = {}
364
+ for r in records:
365
+ stats = r.get("stats", {}) or {}
366
+ for k, v in stats.items():
367
+ stat_keys.add(k)
368
+ if isinstance(v, dict):
369
+ nested = nested_agent_keys.setdefault(k, set())
370
+ for ak in v.keys():
371
+ nested.add(str(ak))
372
+
373
+ for k in stat_keys:
374
+ if k in nested_agent_keys:
375
+ stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
376
+ else:
377
+ stats_out[k] = []
378
+
379
+ for r in records:
380
+ mgids.append(r.get("mgid"))
381
+ crn_ids.append(r.get("crn_id"))
382
+ iterations_out.append(r.get("iteration"))
383
+ stats = r.get("stats", {}) or {}
384
+ for k in stat_keys:
385
+ val = stats.get(k)
386
+ if isinstance(stats_out[k], dict):
387
+ agent_dict = val if isinstance(val, dict) else {}
388
+ for ak in stats_out[k].keys():
389
+ stats_out[k][ak].append(agent_dict.get(ak))
390
+ else:
391
+ stats_out[k].append(val)
392
+
393
+ with open(outfile, "w", encoding="utf-8") as w:
394
+ json.dump(
395
+ {
396
+ "mgid": mgids,
397
+ "crn_id": crn_ids,
398
+ "iteration": iterations_out,
399
+ "stats": stats_out,
400
+ },
401
+ w,
402
+ ensure_ascii=False,
403
+ )
404
+
405
+ return outfile
src_code_for_reproducibility/markov_games/vine_ppo.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from anytree import Node, RenderTree
2
+ from anytree.exporter import DotExporter
3
+ import os.path
4
+ import asyncio
5
+ from mllm.markov_games.markov_game import MarkovGame
6
+
7
+ async def VinePPORunner(
8
+ markov_game: MarkovGame,
9
+ **kwargs):
10
+ pass
src_code_for_reproducibility/models/__init__.py ADDED
File without changes
src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (153 Bytes). View file
 
src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc ADDED
Binary file (2.24 kB). View file
 
src_code_for_reproducibility/models/__pycache__/inference_backend_vllm.cpython-312.pyc ADDED
Binary file (4.98 kB). View file
 
src_code_for_reproducibility/models/__pycache__/large_language_model_local.cpython-312.pyc ADDED
Binary file (16.7 kB). View file
 
src_code_for_reproducibility/models/__pycache__/scalar_critic.cpython-312.pyc ADDED
Binary file (3.21 kB). View file
 
src_code_for_reproducibility/models/adapter_training_wrapper.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import logging
4
+ from typing import Union
5
+ from peft import (
6
+ LoraConfig,
7
+ get_peft_model,
8
+ )
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class AdapterWrapper(nn.Module):
14
+ """
15
+ A thin façade that
16
+ • keeps a reference to a *shared* PEFT-wrapped model,
17
+ • ensures `set_adapter(adapter)` is called on every forward,
18
+ • exposes only the parameters that should be trained for that adapter
19
+ (plus whatever extra modules you name).
20
+ """
21
+ def __init__(
22
+ self,
23
+ shared_llm: nn.Module,
24
+ adapter_id: str,
25
+ lora_config: dict,
26
+ path: Union[str, None] = None,
27
+ ):
28
+ super().__init__()
29
+ self.shared_llm = shared_llm
30
+ self.adapter_id = adapter_id
31
+ lora_config = LoraConfig(**lora_config)
32
+ # this modifies the shared llm in place, adding a lora adapter inside
33
+ self.shared_llm = get_peft_model(
34
+ model=shared_llm,
35
+ peft_config=lora_config,
36
+ adapter_name=adapter_id,
37
+ )
38
+ self.shared_llm.train()
39
+ # Load external adapter weights if provided
40
+ loaded_from: str | None = None
41
+ if path:
42
+ try:
43
+ # Supports both local filesystem paths and HF Hub repo IDs
44
+ self.shared_llm.load_adapter(
45
+ is_trainable=True,
46
+ model_id=path,
47
+ adapter_name=adapter_id,
48
+ )
49
+ loaded_from = path
50
+ except Exception as exc: # noqa: BLE001 - want to log any load failure context
51
+ logger.warning(
52
+ f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
53
+ )
54
+
55
+ if loaded_from:
56
+ logger.info(
57
+ f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'."
58
+ )
59
+ else:
60
+ logger.info(
61
+ f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)."
62
+ )
63
+
64
+ def parameters(self, recurse: bool = True):
65
+ """
66
+ "recurse" is just for pytorch compatibility
67
+ """
68
+ self.shared_llm.set_adapter(self.adapter_id)
69
+ params = [p for p in self.shared_llm.parameters() if p.requires_grad]
70
+
71
+ return params
72
+
73
+ def get_base_model_logits(self, contexts):
74
+ """
75
+ Run the base model (without adapter) in inference mode, without tracking gradients.
76
+ This is useful to get reference logits for KL-divergence computation.
77
+ """
78
+ with torch.no_grad():
79
+ with self.shared_llm.disable_adapter():
80
+ return self.shared_llm(input_ids=contexts)[0]
81
+
82
+ def forward(self, *args, **kwargs):
83
+ self.shared_llm.set_adapter(self.adapter_id)
84
+ return self.shared_llm(*args, **kwargs)
85
+
86
+ def save_pretrained(self, save_path):
87
+ self.shared_llm.save_pretrained(save_path)
88
+
89
+ def gradient_checkpointing_enable(self, *args, **kwargs):
90
+ self.shared_llm.gradient_checkpointing_enable(*args, **kwargs)
91
+
92
+ @property
93
+ def dtype(self):
94
+ return self.shared_llm.dtype
95
+
96
+ @property
97
+ def device(self):
98
+ return self.shared_llm.device
src_code_for_reproducibility/models/human_policy.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import re
4
+ import shutil
5
+ import sys
6
+ from typing import Callable, Dict, List, Optional
7
+
8
+ from mllm.markov_games.rollout_tree import ChatTurn
9
+
10
+ try:
11
+ import rstr # For generating example strings from regex
12
+ except Exception: # pragma: no cover
13
+ rstr = None
14
+
15
+
16
+ def _clear_terminal() -> None:
17
+ """
18
+ Clear the terminal screen in a cross-platform manner.
19
+ """
20
+ if sys.stdout.isatty():
21
+ os.system("cls" if os.name == "nt" else "clear")
22
+
23
+
24
+ def _terminal_width(default: int = 100) -> int:
25
+ try:
26
+ return shutil.get_terminal_size().columns
27
+ except Exception:
28
+ return default
29
+
30
+
31
+ def _horizontal_rule(char: str = "─") -> str:
32
+ width = max(20, _terminal_width() - 2)
33
+ return char * width
34
+
35
+
36
+ class _Style:
37
+ # ANSI colors (bright, readable)
38
+ RESET = "\033[0m"
39
+ BOLD = "\033[1m"
40
+ DIM = "\033[2m"
41
+ # Foreground colors
42
+ FG_BLUE = "\033[94m" # user/system headers
43
+ FG_GREEN = "\033[92m" # human response header
44
+ FG_YELLOW = "\033[93m" # notices
45
+ FG_RED = "\033[91m" # errors
46
+ FG_MAGENTA = "\033[95m" # regex
47
+ FG_CYAN = "\033[96m" # tips
48
+
49
+
50
+ def _render_chat(state) -> str:
51
+ """
52
+ Render prior messages in a compact, readable terminal format.
53
+
54
+ Expected message dict keys: {"role": str, "content": str, ...}
55
+ """
56
+ lines: List[str] = []
57
+ lines.append(_horizontal_rule())
58
+ lines.append(f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}")
59
+ lines.append(_horizontal_rule())
60
+ for chat in state:
61
+ role = chat.role
62
+ content = str(chat.content).strip()
63
+ # Map roles to display names and colors/emojis
64
+ if role == "assistant":
65
+ header = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑‍💻{_Style.RESET}"
66
+ elif role == "user":
67
+ header = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}"
68
+ else:
69
+ header = f"[{_Style.DIM}{role.upper()}{_Style.RESET}]"
70
+ lines.append(header)
71
+ # Indent content for readability
72
+ for line in content.splitlines() or [""]:
73
+ lines.append(f" {line}")
74
+ lines.append("")
75
+ lines.append(_horizontal_rule())
76
+ return "\n".join(lines)
77
+
78
+
79
+ async def _async_input(prompt_text: str) -> str:
80
+ """Non-blocking input using a background thread."""
81
+ return await asyncio.to_thread(input, prompt_text)
82
+
83
+
84
+ def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]:
85
+ """
86
+ Try to produce a short example string that matches the regex.
87
+ We attempt multiple times and pick the first <= max_len.
88
+ """
89
+ if rstr is None:
90
+ return None
91
+ try:
92
+ for _ in range(20):
93
+ candidate = rstr.xeger(regex)
94
+ if len(candidate) <= max_len:
95
+ return candidate
96
+ # Fallback to truncation (may break match, so don't return)
97
+ return None
98
+ except Exception:
99
+ return None
100
+
101
+
102
+ def _detect_input_type(regex: str | None) -> tuple[str, str, str]:
103
+ """
104
+ Detect what type of input is expected based on the regex pattern.
105
+ Returns (input_type, start_tag, end_tag)
106
+ """
107
+ if regex is None:
108
+ return "text", "", ""
109
+
110
+ if "message_start" in regex and "message_end" in regex:
111
+ return "message", "<<message_start>>", "<<message_end>>"
112
+ elif "proposal_start" in regex and "proposal_end" in regex:
113
+ return "proposal", "<<proposal_start>>", "<<proposal_end>>"
114
+ else:
115
+ return "text", "", ""
116
+
117
+
118
+ async def human_policy(state, agent_id, regex: str | None = None) -> str:
119
+ """
120
+ Async human-in-the-loop policy.
121
+
122
+ - Displays prior conversation context in the terminal.
123
+ - Prompts the user for a response.
124
+ - If a regex is provided, validates and re-prompts until it matches.
125
+ - Automatically adds formatting tags based on expected input type.
126
+
127
+ Args:
128
+ prompt: Chat history as a list of {role, content} dicts.
129
+ regex: Optional fullmatch validation pattern.
130
+
131
+ Returns:
132
+ The user's validated response string.
133
+ """
134
+ # Detect input type and formatting
135
+ input_type, start_tag, end_tag = _detect_input_type(regex)
136
+
137
+ while True:
138
+ _clear_terminal()
139
+ print(_render_chat(state))
140
+
141
+ if regex:
142
+ example = _short_regex_example(regex, max_len=30)
143
+ print(
144
+ f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}"
145
+ )
146
+ print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
147
+ if example:
148
+ print(
149
+ f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}"
150
+ )
151
+ print(_horizontal_rule("."))
152
+
153
+ # Custom prompt based on input type
154
+ if input_type == "message":
155
+ print(
156
+ f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}"
157
+ )
158
+ elif input_type == "proposal":
159
+ print(
160
+ f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}"
161
+ )
162
+ else:
163
+ print(
164
+ f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}"
165
+ )
166
+
167
+ print(
168
+ f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}"
169
+ )
170
+ else:
171
+ print(
172
+ f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}"
173
+ )
174
+
175
+ user_in = (await _async_input("> ")).rstrip("\n")
176
+
177
+ # Commands
178
+ if user_in.strip().lower() in {"/help", "/h"}:
179
+ print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}")
180
+ print(
181
+ f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help"
182
+ )
183
+ print(
184
+ f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt"
185
+ )
186
+ print(
187
+ f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)"
188
+ )
189
+ await asyncio.sleep(1.0)
190
+ continue
191
+ if user_in.strip().lower() in {"/refresh", "/r"}:
192
+ continue
193
+ if user_in.strip().lower() in {"/quit", "/q"}:
194
+ raise KeyboardInterrupt("Human aborted run from human_policy")
195
+
196
+ # Add formatting tags if needed
197
+ if start_tag and end_tag:
198
+ formatted_input = f"{start_tag}{user_in}{end_tag}"
199
+ else:
200
+ formatted_input = user_in
201
+
202
+ if regex is None:
203
+ return ChatTurn(
204
+ role="assistant", agent_id=agent_id, content=formatted_input
205
+ )
206
+
207
+ # Validate against regex (fullmatch)
208
+ try:
209
+ pattern = re.compile(regex)
210
+ except re.error as e:
211
+ # If regex is invalid, fall back to accepting any input
212
+ print(
213
+ f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {e}. Accepting input without validation."
214
+ )
215
+ await asyncio.sleep(0.5)
216
+ return ChatTurn(
217
+ role="assistant", agent_id=agent_id, content=formatted_input
218
+ )
219
+
220
+ if pattern.fullmatch(formatted_input):
221
+ return ChatTurn(
222
+ role="assistant", agent_id=agent_id, content=formatted_input
223
+ )
224
+
225
+ # Show validation error and re-prompt
226
+ print("")
227
+ print(
228
+ f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again."
229
+ )
230
+
231
+ if input_type == "message":
232
+ print(
233
+ f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
234
+ )
235
+ print(f"Just type the message content without tags.")
236
+ elif input_type == "proposal":
237
+ print(
238
+ f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
239
+ )
240
+ print(f"Just type the number without tags.")
241
+ else:
242
+ print(f"Expected (regex):")
243
+ print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
244
+
245
+ print(_horizontal_rule("."))
246
+ print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}")
247
+ await _async_input("")
248
+
249
+
250
+ def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]:
251
+ """
252
+ Expose the human policy in the same map shape used elsewhere.
253
+ """
254
+ # Type hint says Callable[[List[Dict]], str] but we intentionally return the async callable.
255
+ return {"human_policy": human_policy} # type: ignore[return-value]
src_code_for_reproducibility/models/inference_backend.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Any, Optional
4
+
5
+
6
+ @dataclass
7
+ class LLMInferenceOutput:
8
+ content: str
9
+ reasoning_content: str | None = None
10
+ log_probs: list[float] | None = None
11
+ out_token_ids: list[int] | None = None
12
+
13
+
14
+ class LLMInferenceBackend(ABC):
15
+ @abstractmethod
16
+ def __init__(self, **kwargs):
17
+ ...
18
+
19
+ @abstractmethod
20
+ def prepare_adapter(
21
+ self, adapter_id: str, weights_got_updated: bool = False
22
+ ) -> None:
23
+ """Ensure adapter is ready/loaded for next generation call."""
24
+
25
+ @abstractmethod
26
+ async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> str:
27
+ ...
28
+
29
+ @abstractmethod
30
+ def toggle_training_mode(self) -> None:
31
+ ...
32
+
33
+ @abstractmethod
34
+ def toggle_eval_mode(self) -> None:
35
+ ...
36
+
37
+ @abstractmethod
38
+ def shutdown(self) -> None:
39
+ ...
src_code_for_reproducibility/models/inference_backend_dummy.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Optional
3
+
4
+ import rstr
5
+ from transformers import AutoTokenizer
6
+
7
+ from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
8
+ from mllm.utils.short_id_gen import generate_short_id
9
+
10
+
11
+ class DummyInferenceBackend(LLMInferenceBackend):
12
+ def __init__(
13
+ self,
14
+ *args,
15
+ **kwargs,
16
+ ):
17
+ pass
18
+
19
+ def prepare_adapter(
20
+ self,
21
+ adapter_id: Optional[str],
22
+ weights_got_updated: bool,
23
+ adapter_path: Optional[str] = None,
24
+ ) -> None:
25
+ pass
26
+
27
+ async def toggle_training_mode(self) -> None:
28
+ await asyncio.sleep(0)
29
+ pass
30
+
31
+ async def toggle_eval_mode(self) -> None:
32
+ await asyncio.sleep(0)
33
+ pass
34
+
35
+ def shutdown(self) -> None:
36
+ pass
37
+
38
+ async def generate(
39
+ self,
40
+ prompt_text: str,
41
+ regex: Optional[str] = None,
42
+ extract_thinking: bool = False,
43
+ ) -> LLMInferenceOutput:
44
+ if regex:
45
+ # Create random string that respects the regex
46
+ return LLMInferenceOutput(
47
+ content=rstr.xeger(regex),
48
+ reasoning_content="I don't think, I am a dummy backend.",
49
+ )
50
+ else:
51
+ return LLMInferenceOutput(
52
+ content="I am a dummy backend without a regex.",
53
+ reasoning_content="I don't think, I am a dummy backend.",
54
+ )
src_code_for_reproducibility/models/inference_backend_sglang.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # new_backend_sglang_offline.py
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ from typing import Any, Optional
6
+
7
+ # import sglang as sgl
8
+
9
+ from mllm.models.inference_backend import LLMInferenceBackend
10
+
11
+
12
+ class SGLangOfflineBackend(LLMInferenceBackend):
13
+ def __init__(
14
+ self,
15
+ model_name: str,
16
+ tokenizer, # unused but kept for parity
17
+ adapter_paths: dict[str, str],
18
+ device: str = "cuda",
19
+ max_model_len: Optional[int] = None,
20
+ enable_lora: bool = True,
21
+ lora_target_modules: Optional[list[str] | str] = None,
22
+ max_loras_per_batch: int = 8,
23
+ engine_kwargs: dict[str, Any] = None,
24
+ ):
25
+ self.model_name = model_name
26
+ self.adapter_paths = adapter_paths
27
+ self.current_adapter: Optional[str] = None
28
+ engine_kwargs = dict(engine_kwargs or {})
29
+ # Map server-style LoRA flags to offline engine ctor
30
+ if enable_lora and adapter_paths:
31
+ engine_kwargs.setdefault("enable_lora", True)
32
+ # The offline Engine mirrors server args; pass a mapping name->path
33
+ engine_kwargs.setdefault("lora_paths", adapter_paths)
34
+ if lora_target_modules is not None:
35
+ engine_kwargs.setdefault("lora_target_modules", lora_target_modules)
36
+ engine_kwargs.setdefault("max_loras_per_batch", max_loras_per_batch)
37
+
38
+ if max_model_len is not None:
39
+ engine_kwargs.setdefault("context_length", max_model_len)
40
+
41
+ # Launch in-process engine (no HTTP server)
42
+ self.llm = sgl.Engine(model_path=model_name, **engine_kwargs) # async-ready
43
+ # SGLang supports: generate(), async_generate(), and async streaming helpers. :contentReference[oaicite:2]{index=2}
44
+
45
+ def is_ready(self) -> bool:
46
+ return True
47
+
48
+ def toggle_training_mode(self) -> None:
49
+ # No explicit KV release API offline; typically you pause usage here.
50
+ pass
51
+
52
+ def toggle_eval_mode(self) -> None:
53
+ pass
54
+
55
+ def shutdown(self) -> None:
56
+ # Engine cleans up on GC; explicit close not required.
57
+ pass
58
+
59
+ def prepare_adapter(self, adapter_id: Optional[str]) -> None:
60
+ # With offline Engine, when LoRA is enabled at init,
61
+ # you select adapter per request via the input batch mapping.
62
+ self.current_adapter = adapter_id
63
+
64
+ async def generate(
65
+ self, prompt_text: str, sampling_params: dict, adapter_id: Optional[str]
66
+ ) -> str:
67
+ # Non-streaming async (batch of 1). For batched prompts, pass a list.
68
+ params = {
69
+ "temperature": sampling_params.get("temperature", 1.0),
70
+ "top_p": sampling_params.get("top_p", 1.0),
71
+ "max_new_tokens": sampling_params.get("max_new_tokens", 128),
72
+ }
73
+ if (tk := sampling_params.get("top_k", -1)) and tk > 0:
74
+ params["top_k"] = tk
75
+ if (mn := sampling_params.get("min_new_tokens")) is not None:
76
+ params["min_new_tokens"] = mn
77
+ if (fp := sampling_params.get("frequency_penalty")) is not None:
78
+ params["frequency_penalty"] = fp
79
+
80
+ # If using multi-LoRA, SGLang lets you provide adapter names aligned to each input.
81
+ prompts = [prompt_text]
82
+ adapters = [adapter_id] if adapter_id else None # or omit for base
83
+ outs = await self.llm.async_generate(
84
+ prompts, params, adapters
85
+ ) # :contentReference[oaicite:3]{index=3}
86
+ return outs[0]["text"]
src_code_for_reproducibility/models/inference_backend_sglang_local_server.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import httpx
4
+ import requests
5
+ from sglang.utils import launch_server_cmd, wait_for_server
6
+
7
+ from mllm.models.inference_backend import LLMInferenceBackend
8
+
9
+
10
+ class HttpSGLangBackend(LLMInferenceBackend):
11
+ def __init__(self, **kwargs):
12
+ super().__init__(**kwargs)
13
+ self.port = None
14
+ self.proc = None
15
+ self.urls = {}
16
+ # track sglang adapter ids separately from your logical ids
17
+ self.sglang_names = {aid: aid for aid in self.adapter_paths.keys()}
18
+ self.needs_loading = {aid: True for aid in self.adapter_paths.keys()}
19
+
20
+ # defaults you already used:
21
+ self.mem_fraction = kwargs.get("mem_fraction_static", 0.6)
22
+ self.dtype = kwargs.get("dtype", "bfloat16")
23
+ self.extra_cli = kwargs.get("extra_cli", "")
24
+ self.disable_radix_cache = kwargs.get("disable_radix_cache", True)
25
+
26
+ def launch(self) -> None:
27
+ # find local hf cache path for server
28
+ from transformers.utils import cached_file
29
+
30
+ local_llm_path = os.path.split(cached_file(self.model_name, "config.json"))[0]
31
+
32
+ lora_str = ""
33
+ if self.adapter_paths:
34
+ lora_str = "--lora-paths " + " ".join(
35
+ f"{aid}={path}" for aid, path in self.adapter_paths.items()
36
+ )
37
+
38
+ cmd = f"""
39
+ python3 -m sglang.launch_server --model-path {local_llm_path} \
40
+ --host 0.0.0.0 {lora_str} \
41
+ {'--disable-radix-cache' if self.disable_radix_cache else ''} \
42
+ --mem-fraction-static {self.mem_fraction} --dtype {self.dtype} {self.extra_cli}
43
+ """
44
+ self.proc, self.port = launch_server_cmd(cmd)
45
+ wait_for_server(f"http://localhost:{self.port}")
46
+ base = f"http://localhost:{self.port}"
47
+ self.urls = dict(
48
+ generate=f"{base}/generate",
49
+ release=f"{base}/release_memory_occupation",
50
+ resume=f"{base}/resume_memory_occupation",
51
+ load_lora=f"{base}/load_lora_adapter",
52
+ unload_lora=f"{base}/unload_lora_adapter",
53
+ )
54
+
55
+ def is_ready(self) -> bool:
56
+ try:
57
+ requests.get(self.urls["generate"], timeout=2)
58
+ return True
59
+ except Exception:
60
+ return False
61
+
62
+ def prepare_adapter(self, adapter_id: str) -> None:
63
+ if adapter_id is None:
64
+ return
65
+ if self.needs_loading.get(adapter_id, False):
66
+ # unload old name if present
67
+ try:
68
+ requests.post(
69
+ self.urls["unload_lora"],
70
+ json={"lora_name": self.sglang_names[adapter_id]},
71
+ timeout=10,
72
+ )
73
+ except Exception:
74
+ pass
75
+ new_name = self._short_id()
76
+ self.sglang_names[adapter_id] = new_name
77
+ requests.post(
78
+ self.urls["load_lora"],
79
+ json={
80
+ "lora_name": new_name,
81
+ "lora_path": self.adapter_paths[adapter_id],
82
+ },
83
+ ).raise_for_status()
84
+ self.needs_loading[adapter_id] = False
85
+
86
+ async def generate(
87
+ self, prompt_text: str, sampling_params: dict, adapter_id: str | None
88
+ ) -> str:
89
+ lora_name = self.sglang_names.get(adapter_id) if adapter_id else None
90
+ payload = {
91
+ "text": [prompt_text],
92
+ "sampling_params": sampling_params,
93
+ }
94
+ if lora_name:
95
+ payload["lora_path"] = [lora_name]
96
+
97
+ timeout = httpx.Timeout(3600.0, connect=3600.0)
98
+ async with httpx.AsyncClient(timeout=timeout) as client:
99
+ resp = await client.post(self.urls["generate"], json=payload)
100
+ resp.raise_for_status()
101
+ return resp.json()[0]["text"]
102
+
103
+ def toggle_training_mode(self) -> None:
104
+ # free KV space while training adapters
105
+ requests.post(
106
+ self.urls["release"], json={"tags": ["kv_cache"]}
107
+ ).raise_for_status()
108
+
109
+ def toggle_eval_mode(self) -> None:
110
+ # re-allocate KV space
111
+ try:
112
+ requests.post(
113
+ self.urls["resume"], json={"tags": ["kv_cache"]}
114
+ ).raise_for_status()
115
+ except Exception:
116
+ pass
117
+
118
+ def shutdown(self) -> None:
119
+ from sglang.utils import terminate_process
120
+
121
+ if self.proc:
122
+ terminate_process(self.proc)
123
+
124
+ def _short_id(self) -> str:
125
+ import uuid
126
+
127
+ return str(uuid.uuid4().int)[:8]
src_code_for_reproducibility/models/inference_backend_vllm.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import re
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
8
+ from vllm.inputs import TokensPrompt
9
+ from vllm.lora.request import LoRARequest
10
+ from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind
11
+
12
+ from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
13
+ from mllm.utils.short_id_gen import generate_short_id
14
+
15
+
16
+ class VLLMAsyncBackend(LLMInferenceBackend):
17
+ def __init__(
18
+ self,
19
+ model_name: str,
20
+ tokenizer: AutoTokenizer,
21
+ # adapter_paths: dict[str, str],
22
+ engine_init_kwargs: dict = {},
23
+ sampling_params: dict = {},
24
+ ):
25
+ self.model_name = model_name
26
+ # self.adapter_paths = adapter_paths or {}
27
+ # self.current_adapter = None
28
+ # self.vllm_adapter_ids = {
29
+ # adapter_id: generate_short_id() for adapter_id in adapter_paths.keys()
30
+ # }
31
+ self.vllm_adapter_ids = {}
32
+ ea = dict(model=model_name, **engine_init_kwargs)
33
+ # ea["enable_lora"] = True
34
+ # ea["max_loras"] = len(self.vllm_adapter_ids)
35
+ # ea["enable_sleep_mode"] = True
36
+ self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))
37
+
38
+ self.sampling_params = sampling_params
39
+ self.tokenizer = tokenizer
40
+
41
+ def prepare_adapter(
42
+ self,
43
+ adapter_id: Optional[str],
44
+ adapter_path: Optional[str],
45
+ weights_got_updated: bool,
46
+ ) -> None:
47
+ # self.current_adapter = adapter_id
48
+ if weights_got_updated:
49
+ self.vllm_adapter_ids[adapter_id] = generate_short_id()
50
+ self.current_lora_request = LoRARequest(
51
+ adapter_id,
52
+ self.vllm_adapter_ids[adapter_id],
53
+ adapter_path,
54
+ )
55
+
56
+ async def toggle_training_mode(self) -> None:
57
+ await self.engine.sleep(level=1)
58
+
59
+ async def toggle_eval_mode(self) -> None:
60
+ await self.engine.wake_up()
61
+
62
+ def shutdown(self) -> None:
63
+ # No explicit close call; engine stops when process exits.
64
+ pass
65
+
66
+ async def generate(
67
+ self,
68
+ input_token_ids: list[int],
69
+ regex: Optional[str] = None,
70
+ extract_thinking: bool = False,
71
+ ) -> LLMInferenceOutput:
72
+ # Build SamplingParams correctly
73
+ guided = GuidedDecodingParams(regex=regex) if regex else None
74
+ sp = SamplingParams(
75
+ **self.sampling_params,
76
+ guided_decoding=guided,
77
+ output_kind=RequestOutputKind.FINAL_ONLY,
78
+ )
79
+
80
+ prompt = TokensPrompt(prompt_token_ids=input_token_ids)
81
+ request_id = f"req-{asyncio.get_running_loop().time()}"
82
+ result_generator = self.engine.generate(
83
+ prompt,
84
+ sp, # SamplingParams(...)
85
+ request_id,
86
+ lora_request=self.current_lora_request,
87
+ )
88
+
89
+ async for out in result_generator: # with FINAL_ONLY this runs once
90
+ res = out
91
+
92
+ raw_text = res.outputs[0].text
93
+ out_token_ids = res.outputs[0].token_ids
94
+ log_probs = [
95
+ logprob_dict[token_id].logprob
96
+ for token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs)
97
+ ]
98
+ log_probs = torch.tensor(log_probs)
99
+ out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
100
+ # for out_token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs):
101
+ # if logprob_dict[out_token_id].logprob < -1:
102
+ # print(f"High negative logprob {logprob_dict[out_token_id].logprob} for {logprob_dict}")
103
+ content = raw_text
104
+ reasoning_content = None
105
+
106
+ if extract_thinking:
107
+ m = re.match(
108
+ r"^\n<think>\n([\s\S]*?)</think>\n\n(.*)$", raw_text, flags=re.DOTALL
109
+ )
110
+ if m:
111
+ reasoning_content = m.group(1)
112
+ content = m.group(2)
113
+ return LLMInferenceOutput(
114
+ content=content,
115
+ reasoning_content=reasoning_content,
116
+ log_probs=log_probs,
117
+ out_token_ids=out_token_ids,
118
+ )
src_code_for_reproducibility/models/inference_backend_vllm_local_server.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import subprocess
4
+ import time
5
+
6
+ import httpx
7
+ import requests
8
+
9
+ from mllm.models.inference_backend import LLMInferenceBackend
10
+
11
+
12
+ class HttpVLLMBackend(LLMInferenceBackend):
13
+ def __init__(self, **kwargs):
14
+ super().__init__(**kwargs)
15
+ self.port = kwargs.get("port", 8000)
16
+ self.host = kwargs.get("host", "0.0.0.0")
17
+ self.proc = None
18
+ self.base_url = f"http://{self.host}:{self.port}"
19
+ # vLLM memory safety knobs
20
+ self.gpu_mem_util = kwargs.get("gpu_memory_utilization", 0.9)
21
+ self.max_model_len = kwargs.get("max_model_len", None)
22
+ self.max_num_seqs = kwargs.get("max_num_seqs", None)
23
+ self.max_batched_tokens = kwargs.get("max_num_batched_tokens", None)
24
+ self.dtype = kwargs.get("dtype", "bfloat16")
25
+ self.trust_remote_code = kwargs.get("trust_remote_code", False)
26
+ # LoRA strategy: "preload" (CLI) or "runtime" (endpoints) depending on your vLLM build
27
+ self.lora_mode = kwargs.get(
28
+ "lora_mode", "preload"
29
+ ) # "runtime" supported in newer builds
30
+ self.runtime_lora_enabled = self.lora_mode == "runtime"
31
+
32
+ # If preloading: build CLI args (adapter name -> path)
33
+ self._preload_lora_args = []
34
+ if self.adapter_paths and self.lora_mode == "preload":
35
+ # vLLM supports multiple LoRA modules via CLI in recent versions
36
+ # Example flag shapes can vary; adapt as needed for your version:
37
+ # --lora-modules adapter_id=path
38
+ for aid, pth in self.adapter_paths.items():
39
+ self._preload_lora_args += ["--lora-modules", f"{aid}={pth}"]
40
+
41
+ def launch(self):
42
+ # Build vLLM serve command
43
+ cmd = [
44
+ "python3",
45
+ "-m",
46
+ "vllm.entrypoints.openai.api_server",
47
+ "--model",
48
+ self.model_name,
49
+ "--host",
50
+ self.host,
51
+ "--port",
52
+ str(self.port),
53
+ "--dtype",
54
+ self.dtype,
55
+ "--gpu-memory-utilization",
56
+ str(self.gpu_mem_util),
57
+ ]
58
+ if self.trust_remote_code:
59
+ cmd += ["--trust-remote-code"]
60
+ if self.max_model_len:
61
+ cmd += ["--max-model-len", str(self.max_model_len)]
62
+ if self.max_num_seqs:
63
+ cmd += ["--max-num-seqs", str(self.max_num_seqs)]
64
+ if self.max_batched_tokens:
65
+ cmd += ["--max-num-batched-tokens", str(self.max_batched_tokens)]
66
+ cmd += self._preload_lora_args
67
+
68
+ self.proc = subprocess.Popen(
69
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
70
+ )
71
+ self._wait_ready()
72
+
73
+ def _wait_ready(self, timeout=120):
74
+ url = f"{self.base_url}/v1/models"
75
+ t0 = time.time()
76
+ while time.time() - t0 < timeout:
77
+ try:
78
+ r = requests.get(url, timeout=2)
79
+ if r.status_code == 200:
80
+ return
81
+ except Exception:
82
+ pass
83
+ time.sleep(1)
84
+ raise RuntimeError("vLLM server did not become ready in time")
85
+
86
+ def is_ready(self) -> bool:
87
+ try:
88
+ return (
89
+ requests.get(f"{self.base_url}/v1/models", timeout=2).status_code == 200
90
+ )
91
+ except Exception:
92
+ return False
93
+
94
+ def prepare_adapter(self, adapter_id: str) -> None:
95
+ if not adapter_id or not self.runtime_lora_enabled:
96
+ return
97
+ # Newer vLLM builds expose runtime LoRA endpoints. If yours differs,
98
+ # adjust the path/body here and keep the interface stable.
99
+ try:
100
+ requests.post(
101
+ f"{self.base_url}/v1/load_lora_adapter",
102
+ json={
103
+ "adapter_name": adapter_id,
104
+ "adapter_path": self.adapter_paths[adapter_id],
105
+ },
106
+ timeout=10,
107
+ ).raise_for_status()
108
+ except Exception as e:
109
+ # If already loaded or endpoint not present, swallow or log
110
+ pass
111
+
112
+ async def generate(
113
+ self, prompt_text: str, sampling_params: dict, adapter_id: str | None
114
+ ) -> str:
115
+ # Map your sampling params to OpenAI schema
116
+ body = {
117
+ "model": self.model_name,
118
+ "messages": [{"role": "user", "content": prompt_text}],
119
+ "temperature": sampling_params.get("temperature", 1.0),
120
+ "top_p": sampling_params.get("top_p", 1.0),
121
+ "max_tokens": sampling_params.get("max_new_tokens", 128),
122
+ }
123
+ # Optional knobs:
124
+ if sampling_params.get("top_k", -1) and sampling_params["top_k"] > 0:
125
+ # vLLM accepts top_k via extra params; put under "extra_body"
126
+ body.setdefault("extra_body", {})["top_k"] = sampling_params["top_k"]
127
+ if sampling_params.get("min_new_tokens", None) is not None:
128
+ body.setdefault("extra_body", {})["min_tokens"] = sampling_params[
129
+ "min_new_tokens"
130
+ ]
131
+ if sampling_params.get("frequency_penalty", None) is not None:
132
+ body["frequency_penalty"] = sampling_params["frequency_penalty"]
133
+
134
+ # Select LoRA adapter
135
+ if adapter_id:
136
+ if self.runtime_lora_enabled:
137
+ body.setdefault("extra_body", {})["lora_adapter"] = adapter_id
138
+ else:
139
+ # when preloaded via CLI, most builds select by name via "adapter_name"/"lora_adapter"
140
+ body.setdefault("extra_body", {})["lora_adapter"] = adapter_id
141
+
142
+ url = f"{self.base_url}/v1/chat/completions"
143
+ timeout = httpx.Timeout(3600.0, connect=3600.0)
144
+ async with httpx.AsyncClient(timeout=timeout) as client:
145
+ resp = await client.post(url, json=body)
146
+ resp.raise_for_status()
147
+ data = resp.json()
148
+ return data["choices"][0]["message"]["content"]
149
+
150
+ def toggle_training_mode(self) -> None:
151
+ # vLLM doesn’t expose an explicit KV “release” toggle via API.
152
+ # Strategy: keep inference server idle during training, or run training in a separate process.
153
+ pass
154
+
155
+ def toggle_eval_mode(self) -> None:
156
+ pass
157
+
158
+ def shutdown(self) -> None:
159
+ if self.proc:
160
+ self.proc.terminate()
src_code_for_reproducibility/models/large_language_model_api.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import copy
5
+ import os
6
+ import random
7
+ import re
8
+ from typing import Any, Callable, Dict, List, Optional, Sequence
9
+
10
+ import backoff
11
+ from openai import AsyncOpenAI, OpenAIError
12
+
13
+ from mllm.markov_games.rollout_tree import ChatTurn
14
+ from mllm.models.inference_backend import LLMInferenceOutput
15
+
16
+ # TODO: Get this automatically from OpenAI
17
+ reasoning_models = [
18
+ "gpt-5-nano",
19
+ "gpt-5-mini",
20
+ "gpt-5",
21
+ "o1-mini",
22
+ "o1",
23
+ "o1-pro",
24
+ "o3-mini",
25
+ "o3",
26
+ "o3-pro",
27
+ "o4-mini",
28
+ "o4",
29
+ "o4-pro",
30
+ ]
31
+
32
+
33
+ class LargeLanguageModelOpenAI:
34
+ """Tiny async wrapper for OpenAI Chat Completions."""
35
+
36
+ def __init__(
37
+ self,
38
+ llm_id: str = "",
39
+ model: str = "gpt-4.1-mini",
40
+ api_key: Optional[str] = None,
41
+ base_url: Optional[str] = None,
42
+ timeout_s: float = 300.0,
43
+ regex_max_attempts: int = 10,
44
+ sampling_params: Optional[Dict[str, Any]] = None,
45
+ init_kwargs: Optional[Dict[str, Any]] = None,
46
+ output_directory: Optional[str] = None,
47
+ ) -> None:
48
+ self.llm_id = llm_id
49
+ self.model = model
50
+ key = api_key or os.getenv("OPENAI_API_KEY")
51
+ if not key:
52
+ raise RuntimeError(
53
+ "Set OPENAI_API_KEY as global environment variable or pass api_key."
54
+ )
55
+ client_kwargs: Dict[str, Any] = {"api_key": key, "timeout": timeout_s}
56
+ if base_url:
57
+ client_kwargs["base_url"] = base_url
58
+ self.client = AsyncOpenAI(**client_kwargs)
59
+
60
+ # Sampling/default request params set at init
61
+ self.sampling_params = sampling_params
62
+ self.use_reasoning = model in reasoning_models
63
+ if self.use_reasoning:
64
+ self.sampling_params["reasoning"] = {
65
+ "effort": "low",
66
+ "summary": "detailed",
67
+ }
68
+ self.regex_max_attempts = max(1, int(regex_max_attempts))
69
+
70
+ def get_inference_policies(self) -> Dict[str, Callable]:
71
+ return {
72
+ self.llm_id: self.get_action,
73
+ }
74
+
75
+ async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None:
76
+ await asyncio.sleep(0)
77
+ pass
78
+
79
+ async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None:
80
+ await asyncio.sleep(0)
81
+ pass
82
+
83
+ async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None:
84
+ await asyncio.sleep(0)
85
+ pass
86
+
87
+ async def export_adapters(self, *args: Any, **kwargs: Any) -> None:
88
+ await asyncio.sleep(0)
89
+ pass
90
+
91
+ async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None:
92
+ await asyncio.sleep(0)
93
+ pass
94
+
95
+ def extract_output_from_response(self, resp: Response) -> LLMInferenceOutput:
96
+ if len(resp.output) > 1:
97
+ summary = resp.output[0].summary
98
+ if summary != []:
99
+ reasoning_content = summary[0].text
100
+ reasoning_content = f"OpenAI Reasoning Summary: {reasoning_content}"
101
+ else:
102
+ reasoning_content = None
103
+ content = resp.output[1].content[0].text
104
+ else:
105
+ reasoning_content = None
106
+ content = resp.output[0].content[0].text
107
+
108
+ return LLMInferenceOutput(
109
+ content=content,
110
+ reasoning_content=reasoning_content,
111
+ )
112
+
113
+ @backoff.on_exception(
114
+ backoff.expo, Exception, max_time=10**10, max_tries=10**10
115
+ )
116
+ async def get_action(
117
+ self,
118
+ state: list[ChatTurn],
119
+ agent_id: str,
120
+ regex: Optional[str] = None,
121
+ ) -> LLMInferenceOutput:
122
+ # Remove any non-role/content keys from the prompt else openai will error
123
+
124
+ # TODO:
125
+ prompt = [{"role": p.role, "content": p.content} for p in state]
126
+
127
+ # if self.sleep_between_requests:
128
+ # await self.wait_random_time()
129
+
130
+ # If regex is required, prime the model and validate client-side
131
+ if regex:
132
+ constraint_msg = {
133
+ "role": "user",
134
+ "content": (
135
+ f"Output must match this regex exactly: {regex} \n"
136
+ "Return only the matching string, with no quotes or extra text."
137
+ ),
138
+ }
139
+ prompt = [constraint_msg, *prompt]
140
+ pattern = re.compile(regex)
141
+ for _ in range(self.regex_max_attempts):
142
+ resp = await self.client.responses.create(
143
+ model=self.model,
144
+ input=prompt,
145
+ **self.sampling_params,
146
+ )
147
+ policy_output = self.extract_output_from_response(resp)
148
+ if pattern.fullmatch(policy_output.content):
149
+ return policy_output
150
+ prompt = [
151
+ *prompt,
152
+ {
153
+ "role": "user",
154
+ "content": (
155
+ f"Invalid response format. Expected format (regex): {regex}\n Please try again and provide ONLY a response that matches this regex."
156
+ ),
157
+ },
158
+ ]
159
+ return policy_output
160
+
161
+ # Simple, unconstrained generation
162
+ resp = await self.client.responses.create(
163
+ model=self.model,
164
+ input=prompt,
165
+ **self.sampling_params,
166
+ )
167
+ policy_output = self.extract_output_from_response(resp)
168
+ return policy_output
169
+
170
+ def shutdown(self) -> None:
171
+ self.client = None
src_code_for_reproducibility/models/large_language_model_local.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TODO: Figure out how to tweak SGlang not to go OOM when batch size is 32. See https://github.com/sgl-project/sglang/issues/6309.
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import re
8
+ import sys
9
+ import uuid
10
+ from collections.abc import Callable
11
+ from copy import deepcopy
12
+ from datetime import datetime
13
+ from typing import Literal
14
+
15
+ import httpx
16
+ import requests
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ # from sglang.utils import (
21
+ # launch_server_cmd,
22
+ # print_highlight,
23
+ # terminate_process,
24
+ # wait_for_server,
25
+ # )
26
+ from torch.optim import SGD, Adam, AdamW, RMSprop
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer
28
+ from trl import AutoModelForCausalLMWithValueHead
29
+
30
+ from mllm.chat_utils.apply_template import chat_turns_to_token_ids
31
+ from mllm.markov_games.rollout_tree import ChatTurn
32
+ from mllm.models.adapter_training_wrapper import AdapterWrapper
33
+ from mllm.models.inference_backend import LLMInferenceOutput
34
+ from mllm.models.inference_backend_dummy import DummyInferenceBackend
35
+ from mllm.models.inference_backend_sglang import SGLangOfflineBackend
36
+ from mllm.models.inference_backend_vllm import VLLMAsyncBackend
37
+
38
+ logger = logging.getLogger(__name__)
39
+ logger.addHandler(logging.StreamHandler(sys.stdout))
40
+
41
+ AdapterID = str
42
+ PolicyID = str
43
+
44
+
45
+ class LeanLocalLLM:
46
+ """
47
+ TOWRITE
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ llm_id: str = "base_llm",
53
+ model_name: str = "Qwen/Qwen3-4B-Instruct-2507",
54
+ device: str = "cuda",
55
+ hf_kwargs: dict = {},
56
+ adapter_configs: dict = {},
57
+ output_directory: str = "./models/",
58
+ inference_backend: Literal["vllm", "sglang", "dummy"] = "vllm",
59
+ inference_backend_sampling_params: dict = {},
60
+ inference_backend_init_kwargs: dict = {},
61
+ initial_adapter_paths: dict[str, str] | None = None,
62
+ initial_buffer_paths: list[str] | None = None,
63
+ enable_thinking: bool = None,
64
+ regex_max_attempts: int = -1,
65
+ max_thinking_characters: int = 0,
66
+ ):
67
+ self.inference_backend_name = inference_backend
68
+ self.output_directory = output_directory
69
+ self.llm_id = llm_id
70
+ self.device = torch.device(device) if device else torch.device("cuda")
71
+ self.model_name = model_name
72
+ self.adapter_configs = adapter_configs
73
+ self.adapter_ids = list(adapter_configs.keys())
74
+ self.enable_thinking = enable_thinking
75
+ self.regex_max_attempts = regex_max_attempts
76
+ self.initial_buffer_paths = initial_buffer_paths
77
+ self.max_thinking_characters = max_thinking_characters
78
+ self.regex_retries_count = 0
79
+
80
+ # Optional user-specified initial adapter weight locations (local or HF Hub)
81
+ # Format: {adapter_id: path_or_repo_id}
82
+ self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths
83
+
84
+ # Path management / imports
85
+ self.save_path = str(os.path.join(output_directory, model_name, "adapters"))
86
+ self.adapter_paths = {
87
+ adapter_id: os.path.join(self.save_path, adapter_id)
88
+ for adapter_id in self.adapter_ids
89
+ }
90
+ checkpoints_dir = os.path.join(self.output_directory, "checkpoints")
91
+ self.past_agent_adapter_paths = {}
92
+ if os.path.isdir(checkpoints_dir):
93
+ for dirname in os.listdir(checkpoints_dir):
94
+ dirpath = os.path.join(checkpoints_dir, dirname)
95
+ if os.path.isdir(dirpath):
96
+ self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join(
97
+ dirpath, "agent_adapter"
98
+ )
99
+ logger.info(
100
+ f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory."
101
+ )
102
+ if self.initial_buffer_paths is not None:
103
+ previous_count = len(self.past_agent_adapter_paths)
104
+ for path in self.initial_buffer_paths:
105
+ if os.path.isdir(path):
106
+ for dirname in os.listdir(path):
107
+ dirpath = os.path.join(path, dirname)
108
+ if os.path.isdir(dirpath):
109
+ self.past_agent_adapter_paths[
110
+ f"{dirname}_buffer"
111
+ ] = os.path.join(dirpath, "agent_adapter")
112
+ else:
113
+ logger.warning(
114
+ f"Initial buffer path {path} does not exist or is not a directory."
115
+ )
116
+ logger.info(
117
+ f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths."
118
+ )
119
+ self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys())
120
+
121
+ # ID management for tracking adapter versions
122
+ self.adapter_train_ids = {
123
+ adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids
124
+ }
125
+ # Initialize tokenizer
126
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
127
+ # Setup padding token to be same as EOS token
128
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
129
+ self.tokenizer.pad_token = self.tokenizer.eos_token
130
+
131
+ self.weights_got_updated: dict[AdapterID, bool] = {
132
+ adapter_id: False for adapter_id in self.adapter_ids
133
+ }
134
+ self.weights_got_updated.update(
135
+ {adapter_id: False for adapter_id in self.past_agent_adapter_ids}
136
+ )
137
+ self.current_lora_request = None
138
+ self.currently_loaded_adapter_id = None
139
+
140
+ # ---------------------------------------------------------
141
+ # Init HF model, peft adapters
142
+ # ---------------------------------------------------------
143
+ self.shared_hf_llm = AutoModelForCausalLM.from_pretrained(
144
+ pretrained_model_name_or_path=model_name,
145
+ **hf_kwargs,
146
+ )
147
+ self.hf_adapters = {}
148
+ self.optimizers = {}
149
+ for adapter_id in self.adapter_ids:
150
+ # Prefer output-folder path if it exists; else fall back to user-specified initial path if provided
151
+ output_path = os.path.join(self.save_path, adapter_id)
152
+ chosen_path: str | None = None
153
+ if os.path.isdir(output_path) and os.listdir(output_path):
154
+ chosen_path = output_path
155
+ logger.info(
156
+ f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'."
157
+ )
158
+ elif (
159
+ self.initial_adapter_paths and adapter_id in self.initial_adapter_paths
160
+ ):
161
+ chosen_path = self.initial_adapter_paths[adapter_id]
162
+ logger.info(
163
+ f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'."
164
+ )
165
+ else:
166
+ logger.info(
167
+ f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch."
168
+ )
169
+ hf_adapter = AdapterWrapper(
170
+ shared_llm=self.shared_hf_llm,
171
+ adapter_id=adapter_id,
172
+ lora_config=adapter_configs[adapter_id],
173
+ path=chosen_path,
174
+ ).to(device)
175
+ self.hf_adapters[adapter_id] = hf_adapter
176
+ # Persist current state of all adapters (ensures remote loads are cached to disk)
177
+ self.export_adapters()
178
+
179
+ # ---------------------------------------------------------
180
+ # Init inference inference_backend
181
+ # ---------------------------------------------------------
182
+
183
+ if inference_backend == "sglang":
184
+ self.inference_backend = SGLangOfflineBackend(
185
+ model_name=self.model_name,
186
+ save_path=self.save_path,
187
+ adapter_paths=self.adapter_paths,
188
+ tokenizer=self.tokenizer,
189
+ kwargs=inference_backend_init_kwargs,
190
+ )
191
+ elif inference_backend == "vllm":
192
+ self.inference_backend = VLLMAsyncBackend(
193
+ model_name=self.model_name,
194
+ # adapter_paths=self.adapter_paths,
195
+ tokenizer=self.tokenizer,
196
+ engine_init_kwargs=inference_backend_init_kwargs,
197
+ sampling_params=inference_backend_sampling_params,
198
+ )
199
+ elif inference_backend == "dummy":
200
+ self.inference_backend = DummyInferenceBackend()
201
+ else:
202
+ raise ValueError(f"Unknown inference_backend: {inference_backend}")
203
+
204
+ def reset_regex_retries_count(self) -> None:
205
+ self.regex_retries_count = 0
206
+
207
+ def get_inference_policies(self) -> dict[PolicyID, Callable]:
208
+ """
209
+ TOWRITE
210
+ """
211
+ policies = {}
212
+ for adapter_id in self.adapter_ids:
213
+ # define policy func
214
+ async def policy(
215
+ state: list[ChatTurn],
216
+ agent_id: str,
217
+ regex: str | None = None,
218
+ _adapter_id=adapter_id,
219
+ ):
220
+ self.prepare_adapter_for_inference(adapter_id=_adapter_id)
221
+ response = await self.get_action(state, agent_id, regex)
222
+ return response
223
+
224
+ policies[self.llm_id + "/" + adapter_id] = policy
225
+
226
+ for adapter_id in self.past_agent_adapter_ids:
227
+ # define policy func
228
+ async def policy(
229
+ state: list[ChatTurn],
230
+ agent_id: str,
231
+ regex: str | None = None,
232
+ _adapter_id=adapter_id,
233
+ ):
234
+ self.prepare_adapter_for_inference(adapter_id=_adapter_id)
235
+ response = await self.get_action(state, agent_id, regex)
236
+ return response
237
+
238
+ policies[self.llm_id + "/" + adapter_id] = policy
239
+ return policies
240
+
241
+ def get_adapter_modules(self) -> dict[PolicyID, nn.Module]:
242
+ """
243
+ Returns wrappers over the adapters which allows them be
244
+ interfaced like regular PyTorch models.
245
+ # TODO: create the adapter wrappers here
246
+ See adapter_wrapper.py
247
+ """
248
+ trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
249
+ return trainable_objects
250
+
251
+ async def toggle_training_mode(self) -> None:
252
+ for adn in self.adapter_ids:
253
+ self.adapter_train_ids[adn] = self.short_id_generator()
254
+ await self.inference_backend.toggle_training_mode()
255
+
256
+ async def toggle_eval_mode(self) -> None:
257
+ await self.inference_backend.toggle_eval_mode()
258
+
259
+ def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None:
260
+ self.inference_backend.prepare_adapter(
261
+ adapter_id,
262
+ adapter_path=self.adapter_paths.get(
263
+ adapter_id, self.past_agent_adapter_paths.get(adapter_id, None)
264
+ ),
265
+ weights_got_updated=self.weights_got_updated[adapter_id],
266
+ )
267
+ self.currently_loaded_adapter_id = adapter_id
268
+ self.weights_got_updated[adapter_id] = False
269
+
270
+ # def _make_prompt_text(self, prompt: list[dict]) -> str:
271
+ # if self.enable_thinking is not None:
272
+ # prompt_text = self.tokenizer.apply_chat_template(
273
+ # prompt,
274
+ # tokenize=False,
275
+ # add_generation_prompt=True,
276
+ # enable_thinking=self.enable_thinking,
277
+ # )
278
+ # else:
279
+ # prompt_text = self.tokenizer.apply_chat_template(
280
+ # prompt,
281
+ # tokenize=False,
282
+ # add_generation_prompt=True,
283
+ # )
284
+
285
+ # return prompt_text
286
+
287
+ async def get_action(
288
+ self, state: list[ChatTurn], agent_id: str, regex: str | None = None
289
+ ) -> ChatTurn:
290
+ current_regex = regex if self.regex_max_attempts == -1 else None
291
+ pattern = re.compile(regex) if regex else None
292
+ nb_attempts = 0
293
+ state = state[:]
294
+ while True:
295
+ context_token_ids = chat_turns_to_token_ids(
296
+ chats=state,
297
+ tokenizer=self.tokenizer,
298
+ enable_thinking=self.enable_thinking,
299
+ )
300
+ # print(f"context is {self.tokenizer.decode(context_token_ids)}")
301
+ policy_output = await self.inference_backend.generate(
302
+ input_token_ids=context_token_ids.tolist(),
303
+ extract_thinking=(self.max_thinking_characters > 0),
304
+ regex=current_regex,
305
+ )
306
+ # print(f"generated: {self.tokenizer.decode(policy_output.out_token_ids)}")
307
+ if (
308
+ pattern is None
309
+ or (pattern.fullmatch(policy_output.content))
310
+ or (nb_attempts >= self.regex_max_attempts)
311
+ ):
312
+ return ChatTurn(
313
+ agent_id=agent_id,
314
+ role="assistant",
315
+ content=policy_output.content,
316
+ reasoning_content=policy_output.reasoning_content,
317
+ out_token_ids=policy_output.out_token_ids,
318
+ log_probs=policy_output.log_probs,
319
+ is_state_end=False,
320
+ )
321
+ else:
322
+ self.regex_retries_count += 1
323
+ nb_attempts += 1
324
+ logger.warning(
325
+ f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}"
326
+ )
327
+ if nb_attempts == self.regex_max_attempts:
328
+ current_regex = regex
329
+ # regex_prompt = ChatTurn(
330
+ # role="user",
331
+ # content=f"Invalid response format. Expected format (regex): {current_regex}\n Please try again and provide ONLY a response that matches this regex.",
332
+ # reasoning_content=None,
333
+ # log_probs=None,
334
+ # out_token_ids=None,
335
+ # is_state_end=False,
336
+ # )
337
+ # state.append(regex_prompt)
338
+
339
+ def export_adapters(self) -> None:
340
+ """
341
+ Any peft wrapper, by default, saves all adapters, not just the one currently loaded.
342
+ """
343
+
344
+ # New version of the adapters available
345
+ for adapter_id in self.adapter_ids:
346
+ self.weights_got_updated[adapter_id] = True
347
+ for adapter_id in self.past_agent_adapter_ids:
348
+ self.weights_got_updated[adapter_id] = True
349
+
350
+ # import random
351
+ # self.save_path = self.save_path + str(random.randint(1,500))
352
+ # print(f"Save path: {self.save_path}")
353
+ # self.adapter_paths = {adapter_id:os.path.join(self.save_path, adapter_id) for adapter_id in self.adapter_ids}
354
+
355
+ adapter_id = self.adapter_ids[0]
356
+ self.hf_adapters[adapter_id].save_pretrained(self.save_path)
357
+
358
+ def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None:
359
+ """
360
+ Checkpoints all adapters to the configured output directory.
361
+ """
362
+ adapter_id = self.adapter_ids[0]
363
+ output_dir = os.path.join(self.output_directory, "checkpoints")
364
+ os.makedirs(output_dir, exist_ok=True)
365
+ date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
366
+ agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}"
367
+ export_path = os.path.join(output_dir, agent_adapter_dir)
368
+ for adapter_id in self.adapter_ids:
369
+ if "agent" in adapter_id:
370
+ self.past_agent_adapter_paths[
371
+ f"{agent_adapter_dir}_buffer"
372
+ ] = os.path.join(export_path, adapter_id)
373
+ self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer")
374
+ self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False
375
+ self.hf_adapters[adapter_id].save_pretrained(export_path)
376
+
377
+ def short_id_generator(self) -> str:
378
+ """
379
+ Generates a short unique ID for tracking adapter versions.
380
+
381
+ Returns:
382
+ int: An 8-digit integer ID.
383
+ """
384
+ return str(uuid.uuid4().int)[:8]
src_code_for_reproducibility/models/scalar_critic.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torch.nn as nn, torch.optim as optim
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import LoraConfig, get_peft_model
4
+
5
+ from mllm.models.adapter_training_wrapper import AdapterWrapper
6
+
7
+
8
+ class ScalarCritic(nn.Module):
9
+ """
10
+ A causal-LM critic_adapter + a scalar value head:
11
+ V_φ(s) = wᵀ h_last + b
12
+ Only LoRA adapters (inside critic_adapter) and the value head are trainable.
13
+ """
14
+ def __init__(self, critic_adapter: AdapterWrapper):
15
+ super().__init__()
16
+ self.critic_adapter = critic_adapter
17
+ hidden_size = self.critic_adapter.shared_llm.config.hidden_size
18
+ self.value_head = nn.Linear(hidden_size, 1).to(
19
+ dtype=critic_adapter.dtype,
20
+ device=critic_adapter.device)
21
+
22
+ def forward(self,
23
+ input_ids,
24
+ attention_mask=None,
25
+ **kwargs):
26
+ # AdapterWrapper activates its own adapter internally
27
+ outputs = self.critic_adapter(
28
+ input_ids=input_ids,
29
+ attention_mask=attention_mask,
30
+ output_hidden_states=True,
31
+ **kwargs,
32
+ )
33
+ h_last = outputs.hidden_states[-1] # (B, S, H)
34
+ values = self.value_head(h_last).squeeze(-1) # (B, S)
35
+ return values
36
+
37
+ def parameters(self, recurse: bool = True):
38
+ """Iterator over *trainable* parameters for this critic."""
39
+ # 1) LoRA params for *this* adapter
40
+ for p in self.critic_adapter.parameters():
41
+ yield p
42
+ # 2) scalar head
43
+ yield from self.value_head.parameters()
44
+
45
+ def gradient_checkpointing_enable(self, *args, **kwargs):
46
+ self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)
47
+
48
+ @property
49
+ def dtype(self):
50
+ return self.critic_adapter.dtype
51
+
52
+ @property
53
+ def device(self):
54
+ return self.critic_adapter.device
src_code_for_reproducibility/training/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Suppose we have a trajectory with 3 timesteps.
2
+ token: "0 1 2 3 4 5 6 7 8 9 . . . . ."
3
+ string: "A B C a b c A a A a b c A B C" (Capitalized = User, Lowercased = Assistant)
4
+ action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x" (F = False, T = True)
5
+ rewards: "r r r r r r R R R R R R r r r"
6
+ timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2"
7
+ state_ends: "x x ✓ x x x ✓ x x x x x x x ✓"
8
+
9
+ There must be one baseline flag per timestep!
10
+
11
+ Then, we might have
12
+
13
+ A naive way to interpret this is to think of the number of assistant messages as the number of
14
+ steps in the environment. However, this is not the case in practice. Indeed, in a
15
+ single simulation step,
16
+
17
+
18
+
19
+
20
+ A subtlety arises with credit assignment. In the multi-agent case, we might