Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .hydra/hydra.yaml +154 -0
- .hydra/overrides.yaml +1 -0
- run.log +0 -0
- seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
- seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
- seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
- src_code_for_reproducibility/__init__.py +0 -0
- src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc +0 -0
- src_code_for_reproducibility/docs/Makefile +19 -0
- src_code_for_reproducibility/docs/make.bat +35 -0
- src_code_for_reproducibility/docs/source/environments/diplomacy.rst +459 -0
- src_code_for_reproducibility/docs/source/installation.rst +10 -0
- src_code_for_reproducibility/docs/source/media/runbatch.png +0 -0
- src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst +7 -0
- src_code_for_reproducibility/docs/source/src.environments.rst +25 -0
- src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst +7 -0
- src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst +7 -0
- src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst +7 -0
- src_code_for_reproducibility/docs/source/src.generation.rst +15 -0
- src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.rst +20 -0
- src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst +7 -0
- src_code_for_reproducibility/docs/source/src.models.vllm_worker_wrap.rst +7 -0
- src_code_for_reproducibility/docs/source/src.training.ppo_train.rst +7 -0
- src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst +7 -0
- src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst +7 -0
- src_code_for_reproducibility/docs/source/src.training.rst +19 -0
- src_code_for_reproducibility/docs/source/src.training.train_main.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.common_imports.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.parallel_shuffle.rst +7 -0
- src_code_for_reproducibility/docs/source/src.utils.rst +24 -0
- src_code_for_reproducibility/markov_games/__init__.py +0 -0
- src_code_for_reproducibility/markov_games/agent.py +76 -0
- src_code_for_reproducibility/markov_games/alternative_actions_runner.py +138 -0
- src_code_for_reproducibility/markov_games/group_timesteps.py +150 -0
- src_code_for_reproducibility/markov_games/linear_runner.py +30 -0
- src_code_for_reproducibility/markov_games/markov_game.py +208 -0
- src_code_for_reproducibility/markov_games/mg_utils.py +89 -0
- src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/rollout_tree.py +86 -0
.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task: []
|
| 115 |
+
job:
|
| 116 |
+
name: run
|
| 117 |
+
chdir: false
|
| 118 |
+
override_dirname: ''
|
| 119 |
+
id: ???
|
| 120 |
+
num: ???
|
| 121 |
+
config_name: no_press_10_1_ties_ad_align_nocurrtimestep_seed9999.yaml
|
| 122 |
+
env_set: {}
|
| 123 |
+
env_copy: []
|
| 124 |
+
config:
|
| 125 |
+
override_dirname:
|
| 126 |
+
kv_sep: '='
|
| 127 |
+
item_sep: ','
|
| 128 |
+
exclude_keys: []
|
| 129 |
+
runtime:
|
| 130 |
+
version: 1.3.2
|
| 131 |
+
version_base: '1.1'
|
| 132 |
+
cwd: /scratch/m/muqeeth/llm_negotiation
|
| 133 |
+
config_sources:
|
| 134 |
+
- path: hydra.conf
|
| 135 |
+
schema: pkg
|
| 136 |
+
provider: hydra
|
| 137 |
+
- path: /scratch/m/muqeeth/llm_negotiation/configs
|
| 138 |
+
schema: file
|
| 139 |
+
provider: main
|
| 140 |
+
- path: ''
|
| 141 |
+
schema: structured
|
| 142 |
+
provider: schema
|
| 143 |
+
output_dir: /scratch/m/muqeeth/llm_negotiation/2025_11/no_press_10_1_ties_ad_align_nocurrtimestep_seed9999
|
| 144 |
+
choices:
|
| 145 |
+
hydra/env: default
|
| 146 |
+
hydra/callbacks: null
|
| 147 |
+
hydra/job_logging: default
|
| 148 |
+
hydra/hydra_logging: default
|
| 149 |
+
hydra/hydra_help: default
|
| 150 |
+
hydra/help: default
|
| 151 |
+
hydra/sweeper: basic
|
| 152 |
+
hydra/launcher: basic
|
| 153 |
+
hydra/output: default
|
| 154 |
+
verbose: false
|
.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[]
|
run.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/README.md
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: Qwen/Qwen2.5-7B-Instruct
|
| 3 |
+
library_name: peft
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- base_model:adapter:Qwen/Qwen2.5-7B-Instruct
|
| 7 |
+
- lora
|
| 8 |
+
- transformers
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Model Card for Model ID
|
| 12 |
+
|
| 13 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Model Details
|
| 18 |
+
|
| 19 |
+
### Model Description
|
| 20 |
+
|
| 21 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
- **Developed by:** [More Information Needed]
|
| 26 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 27 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 28 |
+
- **Model type:** [More Information Needed]
|
| 29 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 30 |
+
- **License:** [More Information Needed]
|
| 31 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 32 |
+
|
| 33 |
+
### Model Sources [optional]
|
| 34 |
+
|
| 35 |
+
<!-- Provide the basic links for the model. -->
|
| 36 |
+
|
| 37 |
+
- **Repository:** [More Information Needed]
|
| 38 |
+
- **Paper [optional]:** [More Information Needed]
|
| 39 |
+
- **Demo [optional]:** [More Information Needed]
|
| 40 |
+
|
| 41 |
+
## Uses
|
| 42 |
+
|
| 43 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 44 |
+
|
| 45 |
+
### Direct Use
|
| 46 |
+
|
| 47 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 48 |
+
|
| 49 |
+
[More Information Needed]
|
| 50 |
+
|
| 51 |
+
### Downstream Use [optional]
|
| 52 |
+
|
| 53 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 54 |
+
|
| 55 |
+
[More Information Needed]
|
| 56 |
+
|
| 57 |
+
### Out-of-Scope Use
|
| 58 |
+
|
| 59 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 60 |
+
|
| 61 |
+
[More Information Needed]
|
| 62 |
+
|
| 63 |
+
## Bias, Risks, and Limitations
|
| 64 |
+
|
| 65 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 66 |
+
|
| 67 |
+
[More Information Needed]
|
| 68 |
+
|
| 69 |
+
### Recommendations
|
| 70 |
+
|
| 71 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 72 |
+
|
| 73 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 74 |
+
|
| 75 |
+
## How to Get Started with the Model
|
| 76 |
+
|
| 77 |
+
Use the code below to get started with the model.
|
| 78 |
+
|
| 79 |
+
[More Information Needed]
|
| 80 |
+
|
| 81 |
+
## Training Details
|
| 82 |
+
|
| 83 |
+
### Training Data
|
| 84 |
+
|
| 85 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 86 |
+
|
| 87 |
+
[More Information Needed]
|
| 88 |
+
|
| 89 |
+
### Training Procedure
|
| 90 |
+
|
| 91 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 92 |
+
|
| 93 |
+
#### Preprocessing [optional]
|
| 94 |
+
|
| 95 |
+
[More Information Needed]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
#### Training Hyperparameters
|
| 99 |
+
|
| 100 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 101 |
+
|
| 102 |
+
#### Speeds, Sizes, Times [optional]
|
| 103 |
+
|
| 104 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 105 |
+
|
| 106 |
+
[More Information Needed]
|
| 107 |
+
|
| 108 |
+
## Evaluation
|
| 109 |
+
|
| 110 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 111 |
+
|
| 112 |
+
### Testing Data, Factors & Metrics
|
| 113 |
+
|
| 114 |
+
#### Testing Data
|
| 115 |
+
|
| 116 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 117 |
+
|
| 118 |
+
[More Information Needed]
|
| 119 |
+
|
| 120 |
+
#### Factors
|
| 121 |
+
|
| 122 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 123 |
+
|
| 124 |
+
[More Information Needed]
|
| 125 |
+
|
| 126 |
+
#### Metrics
|
| 127 |
+
|
| 128 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 129 |
+
|
| 130 |
+
[More Information Needed]
|
| 131 |
+
|
| 132 |
+
### Results
|
| 133 |
+
|
| 134 |
+
[More Information Needed]
|
| 135 |
+
|
| 136 |
+
#### Summary
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
## Model Examination [optional]
|
| 141 |
+
|
| 142 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 143 |
+
|
| 144 |
+
[More Information Needed]
|
| 145 |
+
|
| 146 |
+
## Environmental Impact
|
| 147 |
+
|
| 148 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 149 |
+
|
| 150 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 151 |
+
|
| 152 |
+
- **Hardware Type:** [More Information Needed]
|
| 153 |
+
- **Hours used:** [More Information Needed]
|
| 154 |
+
- **Cloud Provider:** [More Information Needed]
|
| 155 |
+
- **Compute Region:** [More Information Needed]
|
| 156 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 157 |
+
|
| 158 |
+
## Technical Specifications [optional]
|
| 159 |
+
|
| 160 |
+
### Model Architecture and Objective
|
| 161 |
+
|
| 162 |
+
[More Information Needed]
|
| 163 |
+
|
| 164 |
+
### Compute Infrastructure
|
| 165 |
+
|
| 166 |
+
[More Information Needed]
|
| 167 |
+
|
| 168 |
+
#### Hardware
|
| 169 |
+
|
| 170 |
+
[More Information Needed]
|
| 171 |
+
|
| 172 |
+
#### Software
|
| 173 |
+
|
| 174 |
+
[More Information Needed]
|
| 175 |
+
|
| 176 |
+
## Citation [optional]
|
| 177 |
+
|
| 178 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 179 |
+
|
| 180 |
+
**BibTeX:**
|
| 181 |
+
|
| 182 |
+
[More Information Needed]
|
| 183 |
+
|
| 184 |
+
**APA:**
|
| 185 |
+
|
| 186 |
+
[More Information Needed]
|
| 187 |
+
|
| 188 |
+
## Glossary [optional]
|
| 189 |
+
|
| 190 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 191 |
+
|
| 192 |
+
[More Information Needed]
|
| 193 |
+
|
| 194 |
+
## More Information [optional]
|
| 195 |
+
|
| 196 |
+
[More Information Needed]
|
| 197 |
+
|
| 198 |
+
## Model Card Authors [optional]
|
| 199 |
+
|
| 200 |
+
[More Information Needed]
|
| 201 |
+
|
| 202 |
+
## Model Card Contact
|
| 203 |
+
|
| 204 |
+
[More Information Needed]
|
| 205 |
+
### Framework versions
|
| 206 |
+
|
| 207 |
+
- PEFT 0.17.1
|
seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"v_proj",
|
| 29 |
+
"k_proj",
|
| 30 |
+
"gate_proj",
|
| 31 |
+
"q_proj",
|
| 32 |
+
"up_proj",
|
| 33 |
+
"down_proj",
|
| 34 |
+
"o_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
seed_9999/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"v_proj",
|
| 29 |
+
"k_proj",
|
| 30 |
+
"gate_proj",
|
| 31 |
+
"q_proj",
|
| 32 |
+
"up_proj",
|
| 33 |
+
"down_proj",
|
| 34 |
+
"o_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
src_code_for_reproducibility/__init__.py
ADDED
|
File without changes
|
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (148 Bytes). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
src_code_for_reproducibility/docs/Makefile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal makefile for Sphinx documentation
|
| 2 |
+
|
| 3 |
+
# You can set these variables from the command line, and also
|
| 4 |
+
# from the environment for the first two.
|
| 5 |
+
SPHINXOPTS ?=
|
| 6 |
+
SPHINXBUILD ?= sphinx-build
|
| 7 |
+
SOURCEDIR = source
|
| 8 |
+
BUILDDIR = build
|
| 9 |
+
|
| 10 |
+
# Put it first so that "make" without argument is like "make help".
|
| 11 |
+
help:
|
| 12 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
|
| 13 |
+
|
| 14 |
+
.PHONY: help Makefile
|
| 15 |
+
|
| 16 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
| 17 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
| 18 |
+
%: Makefile
|
| 19 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
|
src_code_for_reproducibility/docs/make.bat
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@ECHO OFF
|
| 2 |
+
|
| 3 |
+
pushd %~dp0
|
| 4 |
+
|
| 5 |
+
REM Command file for Sphinx documentation
|
| 6 |
+
|
| 7 |
+
if "%SPHINXBUILD%" == "" (
|
| 8 |
+
set SPHINXBUILD=sphinx-build
|
| 9 |
+
)
|
| 10 |
+
set SOURCEDIR=source
|
| 11 |
+
set BUILDDIR=build
|
| 12 |
+
|
| 13 |
+
%SPHINXBUILD% >NUL 2>NUL
|
| 14 |
+
if errorlevel 9009 (
|
| 15 |
+
echo.
|
| 16 |
+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
| 17 |
+
echo.installed, then set the SPHINXBUILD environment variable to point
|
| 18 |
+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
| 19 |
+
echo.may add the Sphinx directory to PATH.
|
| 20 |
+
echo.
|
| 21 |
+
echo.If you don't have Sphinx installed, grab it from
|
| 22 |
+
echo.https://www.sphinx-doc.org/
|
| 23 |
+
exit /b 1
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
if "%1" == "" goto help
|
| 27 |
+
|
| 28 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 29 |
+
goto end
|
| 30 |
+
|
| 31 |
+
:help
|
| 32 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 33 |
+
|
| 34 |
+
:end
|
| 35 |
+
popd
|
src_code_for_reproducibility/docs/source/environments/diplomacy.rst
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Diplomacy
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
|
| 6 |
+
based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
|
| 13 |
+
and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
|
| 14 |
+
of movement phases, retreat phases, and build phases.
|
| 15 |
+
|
| 16 |
+
Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
|
| 17 |
+
to be used with LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Game Board and Powers
|
| 23 |
+
|
| 24 |
+
Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
|
| 25 |
+
|
| 26 |
+
- England (blue)
|
| 27 |
+
- France (light blue)
|
| 28 |
+
- Germany (black)
|
| 29 |
+
- Italy (green)
|
| 30 |
+
- Austria-Hungary (red)
|
| 31 |
+
- Russia (white)
|
| 32 |
+
- Turkey (yellow)
|
| 33 |
+
|
| 34 |
+
Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
|
| 35 |
+
|
| 36 |
+
### Units and Movement
|
| 37 |
+
|
| 38 |
+
There are two types of units in Diplomacy:
|
| 39 |
+
- **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
|
| 40 |
+
- **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
|
| 41 |
+
|
| 42 |
+
During movement phases, each unit can execute one of these orders:
|
| 43 |
+
- **Hold**: The unit remains in its current province (e.g., "A PAR H")
|
| 44 |
+
- Format: [Unit Type] [Province] H
|
| 45 |
+
- Example: "A PAR H" means "Army in Paris holds its position"
|
| 46 |
+
|
| 47 |
+
- **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
|
| 48 |
+
- Format: [Unit Type] [Current Province] - [Destination Province]
|
| 49 |
+
- Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
|
| 50 |
+
- Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
|
| 51 |
+
|
| 52 |
+
- **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
|
| 53 |
+
- Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
|
| 54 |
+
- Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
|
| 55 |
+
- Example: "A PAR S A MAR - BUR" means "Army in Paris supports the Army in Marseille's move to Burgundy"
|
| 56 |
+
- Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
|
| 57 |
+
|
| 58 |
+
- **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
|
| 59 |
+
- Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
|
| 60 |
+
- Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
|
| 61 |
+
|
| 62 |
+
All orders are executed simultaneously, and conflicts are resolved based on strength (number of supporting units).
|
| 63 |
+
|
| 64 |
+
### Common Province Abbreviations
|
| 65 |
+
|
| 66 |
+
Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
|
| 67 |
+
- **PAR**: Paris
|
| 68 |
+
- **LON**: London
|
| 69 |
+
- **BER**: Berlin
|
| 70 |
+
- **MUN**: Munich
|
| 71 |
+
- **BUR**: Burgundy
|
| 72 |
+
- **MAR**: Marseilles
|
| 73 |
+
- **BRE**: Brest
|
| 74 |
+
- **ENG**: English Channel
|
| 75 |
+
- **NTH**: North Sea
|
| 76 |
+
- **VIE**: Vienna
|
| 77 |
+
- **ROM**: Rome
|
| 78 |
+
- **VEN**: Venice
|
| 79 |
+
- **MOW**: Moscow
|
| 80 |
+
- **CON**: Constantinople
|
| 81 |
+
|
| 82 |
+
### Example: Movement and Conflicts
|
| 83 |
+
|
| 84 |
+
For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
|
| 85 |
+
|
| 86 |
+
### Turn Structure
|
| 87 |
+
|
| 88 |
+
A game year consists of five phases:
|
| 89 |
+
1. **Spring Movement**: All powers submit orders for their units
|
| 90 |
+
2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
|
| 91 |
+
3. **Fall Movement**: Another round of movement orders
|
| 92 |
+
4. **Fall Retreat**: Retreat orders for dislodged units
|
| 93 |
+
5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
|
| 94 |
+
|
| 95 |
+
### Supply Centers and Building
|
| 96 |
+
|
| 97 |
+
Supply centers (marked on the map) are key to victory. When a power occupies a supply center during a Fall turn, they gain control of it. During the Winter Adjustment phase:
|
| 98 |
+
- If you control more supply centers than you have units, you can build new units in your home supply centers
|
| 99 |
+
- If you control fewer supply centers than you have units, you must remove excess units
|
| 100 |
+
|
| 101 |
+
### Example: Building and Removing Units
|
| 102 |
+
|
| 103 |
+
If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
|
| 104 |
+
|
| 105 |
+
### Negotiation
|
| 106 |
+
|
| 107 |
+
A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
|
| 108 |
+
|
| 109 |
+
### Example: Alliance and Betrayal
|
| 110 |
+
|
| 111 |
+
England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
|
| 112 |
+
|
| 113 |
+
### Victory Conditions
|
| 114 |
+
|
| 115 |
+
The game ends when one power controls 18 or more supply centers (majority of the 34 total centers), or when players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
|
| 116 |
+
|
| 117 |
+
DiplomacyEnv
|
| 118 |
+
------------
|
| 119 |
+
|
| 120 |
+
The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
|
| 121 |
+
Negotiation Environment standard.
|
| 122 |
+
|
| 123 |
+
.. code-block:: python
|
| 124 |
+
|
| 125 |
+
class DiplomacyEnv:
|
| 126 |
+
"""
|
| 127 |
+
Multi-Agent Negotiation Environment for Diplomacy, adapting Deepmind's implementation
|
| 128 |
+
to the MarlEnvironment standard.
|
| 129 |
+
"""
|
| 130 |
+
def __init__(self,
|
| 131 |
+
initial_state: Optional[DiplomacyState] = None,
|
| 132 |
+
max_turns: int = 100,
|
| 133 |
+
points_per_supply_centre: bool = True,
|
| 134 |
+
forced_draw_probability: float = 0.0,
|
| 135 |
+
min_years_forced_draw: int = 35):
|
| 136 |
+
"""Initialize the Diplomacy environment.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
initial_state: Initial DiplomacyState (optional)
|
| 140 |
+
max_turns: Maximum number of turns in the game
|
| 141 |
+
points_per_supply_centre: Whether to award points per supply center in case of a draw
|
| 142 |
+
forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
|
| 143 |
+
min_years_forced_draw: Minimum years before considering a forced draw
|
| 144 |
+
"""
|
| 145 |
+
# ...
|
| 146 |
+
|
| 147 |
+
def reset(self):
|
| 148 |
+
"""Reset the environment to an initial state and return the initial observation.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 152 |
+
Each observation contains:
|
| 153 |
+
- board_state: Current state of the board
|
| 154 |
+
- current_season: Current season in the game
|
| 155 |
+
- player_index: Index of the player's power
|
| 156 |
+
- possible_actions: List of possible actions in DeepMind's format
|
| 157 |
+
- human_readable_actions: List of human-readable action descriptions
|
| 158 |
+
- supply_centers: List of supply centers owned by the player
|
| 159 |
+
- units: List of units owned by the player
|
| 160 |
+
- year: Current year in the game
|
| 161 |
+
"""
|
| 162 |
+
# ...
|
| 163 |
+
|
| 164 |
+
def step(self, actions):
|
| 165 |
+
"""Take a step in the environment using the provided actions.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions.
|
| 169 |
+
Actions can be:
|
| 170 |
+
- List of integer actions in DeepMind's format
|
| 171 |
+
- List of string actions in text format (e.g., "A MUN - BER")
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 175 |
+
Each observation has the same structure as in reset().
|
| 176 |
+
done (bool): Whether the episode has ended.
|
| 177 |
+
info (dict): Additional information about the environment, including:
|
| 178 |
+
- turn: Current turn number
|
| 179 |
+
- returns: Game returns if the game is done, otherwise None
|
| 180 |
+
- waiting_for: List of agents that still need to provide actions (if not all actions are provided)
|
| 181 |
+
"""
|
| 182 |
+
# ...
|
| 183 |
+
|
| 184 |
+
def get_log_info(self):
|
| 185 |
+
"""Get additional information about the environment for logging.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
log_info (dict): Information about the environment required to log the game, including:
|
| 189 |
+
- power_names: List of power names
|
| 190 |
+
- game_history: History of the game
|
| 191 |
+
- current_turn: Current turn number
|
| 192 |
+
- current_season: Current season name
|
| 193 |
+
- supply_centers: Dictionary mapping power names to supply center counts
|
| 194 |
+
"""
|
| 195 |
+
# ...
|
| 196 |
+
|
| 197 |
+
def render(self):
|
| 198 |
+
"""Render the current state of the environment.
|
| 199 |
+
|
| 200 |
+
Displays a visualization of the current game state.
|
| 201 |
+
"""
|
| 202 |
+
# ...
|
| 203 |
+
|
| 204 |
+
def close(self):
|
| 205 |
+
"""Perform any necessary cleanup."""
|
| 206 |
+
# ...
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
Key Implementation Details
|
| 210 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 211 |
+
|
| 212 |
+
The ``DiplomacyEnv`` class implements several key features:
|
| 213 |
+
|
| 214 |
+
1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
|
| 215 |
+
|
| 216 |
+
2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
|
| 217 |
+
|
| 218 |
+
3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
|
| 219 |
+
|
| 220 |
+
4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
|
| 221 |
+
|
| 222 |
+
5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
|
| 223 |
+
|
| 224 |
+
Observation Structure
|
| 225 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 226 |
+
|
| 227 |
+
Each agent receives an observation dictionary with the following structure:
|
| 228 |
+
|
| 229 |
+
.. code-block:: python
|
| 230 |
+
|
| 231 |
+
{
|
| 232 |
+
"board_state": np.ndarray, # Board state representation
|
| 233 |
+
"current_season": int, # Season index (0-4)
|
| 234 |
+
"player_index": int, # Index of the player's power (0-6)
|
| 235 |
+
"possible_actions": [int], # List of possible actions in DeepMind's format
|
| 236 |
+
"human_readable_actions": [str], # List of human-readable action descriptions
|
| 237 |
+
"supply_centers": [str], # List of supply centers owned by the player
|
| 238 |
+
"units": [dict], # List of units owned by the player
|
| 239 |
+
"year": int # Current year in the game
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
Action Structure
|
| 243 |
+
~~~~~~~~~~~~~~~
|
| 244 |
+
|
| 245 |
+
Actions can be provided in two formats:
|
| 246 |
+
|
| 247 |
+
1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
|
| 248 |
+
|
| 249 |
+
2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
|
| 250 |
+
|
| 251 |
+
The environment will convert text actions to the internal format as needed.
|
| 252 |
+
|
| 253 |
+
DiplomacyAgent
|
| 254 |
+
--------------
|
| 255 |
+
|
| 256 |
+
The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
|
| 257 |
+
|
| 258 |
+
.. code-block:: python
|
| 259 |
+
|
| 260 |
+
class DiplomacyAgent:
|
| 261 |
+
"""
|
| 262 |
+
Agent handler for Diplomacy, implementing the AgentState interface
|
| 263 |
+
for the multi-agent negotiation standard.
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def __init__(self,
|
| 267 |
+
power_name: str,
|
| 268 |
+
use_text_interface: bool = True,
|
| 269 |
+
system_prompt: Optional[str] = None):
|
| 270 |
+
"""Initialize the Diplomacy agent handler.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
power_name: Name of the power this agent controls
|
| 274 |
+
use_text_interface: Whether to use text-based interface (vs. structured)
|
| 275 |
+
system_prompt: Optional system prompt to use for the LLM
|
| 276 |
+
"""
|
| 277 |
+
# ...
|
| 278 |
+
|
| 279 |
+
def step(self, observation_from_env, policy_output=None):
|
| 280 |
+
"""Update the agent state based on the observation and action.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
observation_from_env: The observation from the environment, with structure:
|
| 284 |
+
- board_state: Current state of the board
|
| 285 |
+
- current_season: Current season in the game
|
| 286 |
+
- player_index: Index of the player's power
|
| 287 |
+
- possible_actions: List of possible actions
|
| 288 |
+
- human_readable_actions: List of human-readable action descriptions
|
| 289 |
+
- supply_centers: List of supply centers owned by the player
|
| 290 |
+
- units: List of units owned by the player
|
| 291 |
+
- year: Current year in the game
|
| 292 |
+
|
| 293 |
+
policy_output: The output of the policy (LLM response), or None for initial prompt
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
policy_id (str): The policy identifier ("llm_policy")
|
| 297 |
+
policy_input (dict): The input to the policy, with structure:
|
| 298 |
+
- messages: List of conversation messages in the format:
|
| 299 |
+
[{"role": "system", "content": "..."},
|
| 300 |
+
{"role": "user", "content": "..."}]
|
| 301 |
+
action: The official action to be sent to the environment, or None if not ready
|
| 302 |
+
done (bool): Whether the LLM action is ready to be sent to the environment
|
| 303 |
+
info (dict): Additional information about the agent:
|
| 304 |
+
- valid_action: Whether the extracted action is valid
|
| 305 |
+
"""
|
| 306 |
+
# ...
|
| 307 |
+
|
| 308 |
+
def get_log_info(self):
|
| 309 |
+
"""Get information about the agent required to log a trajectory.
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
log_info (dict): Information about the agent required to log a trajectory:
|
| 313 |
+
- power_name: Name of the power this agent controls
|
| 314 |
+
- conversation_history: List of conversation messages
|
| 315 |
+
- current_action: The current action, if any
|
| 316 |
+
"""
|
| 317 |
+
# ...
|
| 318 |
+
|
| 319 |
+
def render(self):
|
| 320 |
+
"""Render the current state of the agent.
|
| 321 |
+
|
| 322 |
+
Displays the agent's current state, including conversation history.
|
| 323 |
+
"""
|
| 324 |
+
# ...
|
| 325 |
+
|
| 326 |
+
def close(self):
|
| 327 |
+
"""Perform any necessary cleanup."""
|
| 328 |
+
# ...
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
Key Implementation Details
|
| 332 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 333 |
+
|
| 334 |
+
The ``DiplomacyAgent`` class implements several key features:
|
| 335 |
+
|
| 336 |
+
1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
|
| 337 |
+
|
| 338 |
+
2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
|
| 339 |
+
|
| 340 |
+
3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
|
| 341 |
+
|
| 342 |
+
4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
|
| 343 |
+
|
| 344 |
+
5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
|
| 345 |
+
|
| 346 |
+
Prompt Structure
|
| 347 |
+
~~~~~~~~~~~~~~~
|
| 348 |
+
|
| 349 |
+
The agent generates prompts that include:
|
| 350 |
+
|
| 351 |
+
1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
|
| 352 |
+
|
| 353 |
+
2. **Game State Description**: A text description of the current game state, including:
|
| 354 |
+
- Current year and season
|
| 355 |
+
- Supply centers owned
|
| 356 |
+
- Units controlled
|
| 357 |
+
- Possible actions
|
| 358 |
+
|
| 359 |
+
3. **Action Request**: Instructions on how to format actions.
|
| 360 |
+
|
| 361 |
+
Example system prompt:
|
| 362 |
+
|
| 363 |
+
.. code-block:: text
|
| 364 |
+
|
| 365 |
+
You are playing the role of FRANCE in a game of Diplomacy.
|
| 366 |
+
Your goal is to control as many supply centers as possible.
|
| 367 |
+
You can negotiate with other players and form alliances, but remember that
|
| 368 |
+
these alliances are not binding. When you need to submit orders for your units,
|
| 369 |
+
write them in the correct format, with each order on a new line.
|
| 370 |
+
|
| 371 |
+
Example game state description:
|
| 372 |
+
|
| 373 |
+
.. code-block:: text
|
| 374 |
+
|
| 375 |
+
Year: 1901, Season: SPRING_MOVES
|
| 376 |
+
You are playing as FRANCE.
|
| 377 |
+
You currently control 3 supply centers: PAR, MAR, BRE.
|
| 378 |
+
Your units are: A PAR, A MAR, F BRE.
|
| 379 |
+
|
| 380 |
+
Please provide orders for your units. Here are your possible actions:
|
| 381 |
+
A PAR - BUR
|
| 382 |
+
A PAR - GAS
|
| 383 |
+
A PAR - PIC
|
| 384 |
+
A PAR H
|
| 385 |
+
...
|
| 386 |
+
|
| 387 |
+
Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
|
| 388 |
+
|
| 389 |
+
Running Diplomacy Games
|
| 390 |
+
----------------------
|
| 391 |
+
|
| 392 |
+
To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
|
| 393 |
+
|
| 394 |
+
.. code-block:: python
|
| 395 |
+
|
| 396 |
+
from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
|
| 397 |
+
from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
|
| 398 |
+
from mllm.run_matches import run_batched_matches
|
| 399 |
+
|
| 400 |
+
# Create environment and agent handlers
|
| 401 |
+
env = DiplomacyEnv(max_turns=30)
|
| 402 |
+
|
| 403 |
+
agent_handlers = {
|
| 404 |
+
"AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
|
| 405 |
+
"ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
|
| 406 |
+
"FRANCE": DiplomacyAgent(power_name="FRANCE"),
|
| 407 |
+
"GERMANY": DiplomacyAgent(power_name="GERMANY"),
|
| 408 |
+
"ITALY": DiplomacyAgent(power_name="ITALY"),
|
| 409 |
+
"RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
|
| 410 |
+
"TURKEY": DiplomacyAgent(power_name="TURKEY")
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
# Define policy mapping (mapping from policy IDs to actual policy functions)
|
| 414 |
+
policy_mapping = {
|
| 415 |
+
"llm_policy": my_llm_policy_function
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
# Run the game
|
| 419 |
+
game_results = run_batched_matches(
|
| 420 |
+
envs=[env],
|
| 421 |
+
agent_handlers_per_env=[agent_handlers],
|
| 422 |
+
policy_mapping=policy_mapping,
|
| 423 |
+
max_parallel_matches=1
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Process results
|
| 427 |
+
for result in game_results:
|
| 428 |
+
print(f"Game finished. Winner: {result['winner']}")
|
| 429 |
+
print(f"Supply centers: {result['supply_centers']}")
|
| 430 |
+
|
| 431 |
+
This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
|
| 432 |
+
|
| 433 |
+
Limitations and Considerations
|
| 434 |
+
-----------------------------
|
| 435 |
+
|
| 436 |
+
1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
|
| 437 |
+
|
| 438 |
+
2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
|
| 439 |
+
|
| 440 |
+
3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
|
| 441 |
+
|
| 442 |
+
4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
|
| 443 |
+
|
| 444 |
+
5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
|
| 445 |
+
|
| 446 |
+
Advanced Usage
|
| 447 |
+
------------
|
| 448 |
+
|
| 449 |
+
For advanced usage, you can customize:
|
| 450 |
+
|
| 451 |
+
1. **System Prompts**: Modify agent behavior by providing custom system prompts.
|
| 452 |
+
|
| 453 |
+
2. **Observation Processing**: Extend the observation processing to include additional information.
|
| 454 |
+
|
| 455 |
+
3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
|
| 456 |
+
|
| 457 |
+
4. **Visualization**: Add custom visualization methods to the environment's render function.
|
| 458 |
+
|
| 459 |
+
5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
|
src_code_for_reproducibility/docs/source/installation.rst
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Installation
|
| 2 |
+
===========
|
| 3 |
+
|
| 4 |
+
To install the package, run:
|
| 5 |
+
|
| 6 |
+
.. code-block:: bash
|
| 7 |
+
|
| 8 |
+
git clone https://github.com/yourusername/llm_negotiation.git
|
| 9 |
+
cd llm_negotiation
|
| 10 |
+
pip install -e .
|
src_code_for_reproducibility/docs/source/media/runbatch.png
ADDED
|
src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.dond.dond\_agent module
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.dond.dond_agent
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.dond.dond\_game module
|
| 2 |
+
=======================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.dond.dond_game
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.dond.dond\_log\_funcs module
|
| 2 |
+
=============================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.dond.dond_log_funcs
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.dond.dond_return_funcs.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.dond.dond\_return\_funcs module
|
| 2 |
+
================================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.dond.dond_return_funcs
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.dond.dond\_statistics\_funcs module
|
| 2 |
+
====================================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.dond.dond_statistics_funcs
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.environment\_imports module
|
| 2 |
+
============================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.environment_imports
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.ipd.ipd\_game module
|
| 2 |
+
=====================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.ipd.ipd_game
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.ipd.ipd\_log\_funcs module
|
| 2 |
+
===========================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.ipd.ipd_log_funcs
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments.ipd.ipd\_statistics\_funcs module
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments.ipd.ipd_statistics_funcs
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.environments.rst
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.environments package
|
| 2 |
+
========================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.environments
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
| 8 |
+
|
| 9 |
+
Subpackages
|
| 10 |
+
-----------
|
| 11 |
+
|
| 12 |
+
.. toctree::
|
| 13 |
+
:maxdepth: 4
|
| 14 |
+
|
| 15 |
+
src.environments.dond
|
| 16 |
+
src.environments.ipd
|
| 17 |
+
|
| 18 |
+
Submodules
|
| 19 |
+
----------
|
| 20 |
+
|
| 21 |
+
.. toctree::
|
| 22 |
+
:maxdepth: 4
|
| 23 |
+
|
| 24 |
+
src.environments.env_imports
|
| 25 |
+
src.environments.environment_imports
|
src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.experiments.dond\_run\_train module
|
| 2 |
+
=======================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.experiments.dond_run_train
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.experiments.generate\_and\_train module
|
| 2 |
+
===========================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.experiments.generate_and_train
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.experiments.last\_completion module
|
| 2 |
+
=======================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.experiments.last_completion
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.generation.rst
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.generation package
|
| 2 |
+
======================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.generation
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
| 8 |
+
|
| 9 |
+
Submodules
|
| 10 |
+
----------
|
| 11 |
+
|
| 12 |
+
.. toctree::
|
| 13 |
+
:maxdepth: 4
|
| 14 |
+
|
| 15 |
+
src.generation.run_games
|
src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.dummy\_hf\_agent module
|
| 2 |
+
==================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.dummy_llm_agent
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.rst
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models package
|
| 2 |
+
==================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
| 8 |
+
|
| 9 |
+
Submodules
|
| 10 |
+
----------
|
| 11 |
+
|
| 12 |
+
.. toctree::
|
| 13 |
+
:maxdepth: 4
|
| 14 |
+
|
| 15 |
+
src.models.dummy_local_llm
|
| 16 |
+
src.models.local_llm
|
| 17 |
+
src.models.new_local_llm
|
| 18 |
+
src.models.server_llm
|
| 19 |
+
src.models.updatable_worker
|
| 20 |
+
src.models.vllm_worker_wrap
|
src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.updatable\_worker module
|
| 2 |
+
===================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.updatable_worker
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.models.vllm_worker_wrap.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.models.vllm\_worker\_wrap module
|
| 2 |
+
====================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.models.vllm_worker_wrap
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.training.ppo_train.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training.ppo\_train module
|
| 2 |
+
==============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training.ppo_train
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training.reinforce\_training module
|
| 2 |
+
=======================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training.reinforce_training
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training.rl\_convs\_processing module
|
| 2 |
+
=========================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training.rl_convs_processing
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.training.rst
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training package
|
| 2 |
+
====================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
| 8 |
+
|
| 9 |
+
Submodules
|
| 10 |
+
----------
|
| 11 |
+
|
| 12 |
+
.. toctree::
|
| 13 |
+
:maxdepth: 4
|
| 14 |
+
|
| 15 |
+
src.training.ppo_train
|
| 16 |
+
src.training.ppo_train_value_head
|
| 17 |
+
src.training.reinforce_training
|
| 18 |
+
src.training.rl_convs_processing
|
| 19 |
+
src.training.train_main
|
src_code_for_reproducibility/docs/source/src.training.train_main.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.training.train\_main module
|
| 2 |
+
===============================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.training.train_main
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.common_imports.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.common\_imports module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.common_imports
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.log_statistics.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.log\_statistics module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.log_statistics
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.parallel_shuffle.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.parallel\_shuffle module
|
| 2 |
+
==================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.parallel_shuffle
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/docs/source/src.utils.rst
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils package
|
| 2 |
+
=================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
| 8 |
+
|
| 9 |
+
Submodules
|
| 10 |
+
----------
|
| 11 |
+
|
| 12 |
+
.. toctree::
|
| 13 |
+
:maxdepth: 4
|
| 14 |
+
|
| 15 |
+
src.utils.common_imports
|
| 16 |
+
src.utils.export_ppo_training_set
|
| 17 |
+
src.utils.extra_stats
|
| 18 |
+
src.utils.inherit_args
|
| 19 |
+
src.utils.log_gpu_usage
|
| 20 |
+
src.utils.log_statistics
|
| 21 |
+
src.utils.model_to_cpu
|
| 22 |
+
src.utils.parallel_shuffle
|
| 23 |
+
src.utils.quick_stats
|
| 24 |
+
src.utils.update_start_epoch
|
src_code_for_reproducibility/markov_games/__init__.py
ADDED
|
File without changes
|
src_code_for_reproducibility/markov_games/agent.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In simple RL paradise, where the action dimensions are constant and well defined,
|
| 3 |
+
Agent classes are not necessary. But in MARL, with LLM's, there isn't always
|
| 4 |
+
a direct path from policy to action. For instance, from the observation of the environment,
|
| 5 |
+
a prompt must be created. Then, the outputs of the policy might be incorrect, so a second
|
| 6 |
+
request to the LLM must be sent before the action is well defined. This is why this Agent class exists.
|
| 7 |
+
It acts as a mini environment, bridging the gap between the core simulation and
|
| 8 |
+
the LLM policies.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from collections.abc import Callable
|
| 13 |
+
from typing import Any, Tuple
|
| 14 |
+
|
| 15 |
+
from numpy.random import default_rng
|
| 16 |
+
|
| 17 |
+
from mllm.markov_games.rollout_tree import AgentActLog
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Agent(ABC):
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
seed: int,
|
| 25 |
+
agent_id: str,
|
| 26 |
+
agent_name: str,
|
| 27 |
+
agent_policy: Callable[[list[dict]], str],
|
| 28 |
+
*args,
|
| 29 |
+
**kwargs,
|
| 30 |
+
):
|
| 31 |
+
"""
|
| 32 |
+
Initialize the agent state.
|
| 33 |
+
"""
|
| 34 |
+
self.seed = seed
|
| 35 |
+
self.agent_id = agent_id
|
| 36 |
+
self.agent_name = agent_name
|
| 37 |
+
self.policy = policy
|
| 38 |
+
self.rng = default_rng(self.seed)
|
| 39 |
+
raise NotImplementedError
|
| 40 |
+
|
| 41 |
+
async def act(self, observation) -> Tuple[Any, AgentActLog]:
|
| 42 |
+
"""
|
| 43 |
+
Query (possibly multiple times) a policy (or possibly a pool of policies) to
|
| 44 |
+
obtain the action of the agent.
|
| 45 |
+
|
| 46 |
+
Example:
|
| 47 |
+
action = None
|
| 48 |
+
prompt = self.observation_to_prompt(observation)
|
| 49 |
+
while not self.valid(action):
|
| 50 |
+
output = await self.policy.generate(prompt)
|
| 51 |
+
action = self.policy_output_to_action(output)
|
| 52 |
+
return action
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
action
|
| 56 |
+
step_info
|
| 57 |
+
"""
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
|
| 60 |
+
def get_safe_copy(self):
|
| 61 |
+
"""
|
| 62 |
+
Return copy of the agent object that is decorrelated from the original object.
|
| 63 |
+
"""
|
| 64 |
+
raise NotImplementedError
|
| 65 |
+
|
| 66 |
+
def reset(self):
|
| 67 |
+
raise NotImplementedError
|
| 68 |
+
|
| 69 |
+
def render(self):
|
| 70 |
+
raise NotImplementedError
|
| 71 |
+
|
| 72 |
+
def close(self):
|
| 73 |
+
raise NotImplementedError
|
| 74 |
+
|
| 75 |
+
def get_agent_info(self):
|
| 76 |
+
raise NotImplementedError
|
src_code_for_reproducibility/markov_games/alternative_actions_runner.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import copy
|
| 3 |
+
import json
|
| 4 |
+
import os.path
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
|
| 7 |
+
from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame
|
| 8 |
+
from mllm.markov_games.rollout_tree import (
|
| 9 |
+
AgentActLog,
|
| 10 |
+
RolloutTreeBranchNode,
|
| 11 |
+
RolloutTreeNode,
|
| 12 |
+
RolloutTreeRootNode,
|
| 13 |
+
StepLog,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
AgentId = str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
async def run_with_unilateral_alt_action(
|
| 21 |
+
markov_game: MarkovGame,
|
| 22 |
+
agent_id: AgentId,
|
| 23 |
+
time_step: int,
|
| 24 |
+
branch_node: RolloutTreeBranchNode,
|
| 25 |
+
max_depth: int,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
This function is used to generate a new branch for a given agent.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
# Generate alternative action and take a step
|
| 32 |
+
await markov_game.set_action_of_agent(agent_id)
|
| 33 |
+
terminated: bool = markov_game.take_simulation_step()
|
| 34 |
+
step_log = markov_game.get_step_log()
|
| 35 |
+
first_alternative_node = RolloutTreeNode(
|
| 36 |
+
step_log=step_log,
|
| 37 |
+
time_step=time_step,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Generate rest of trajectory up to max depth
|
| 41 |
+
time_step += 1
|
| 42 |
+
counter = 1
|
| 43 |
+
previous_node = first_alternative_node
|
| 44 |
+
while not terminated and counter <= max_depth:
|
| 45 |
+
terminated, step_log = await markov_game.step()
|
| 46 |
+
current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
|
| 47 |
+
previous_node.child = current_node
|
| 48 |
+
previous_node = current_node
|
| 49 |
+
counter += 1
|
| 50 |
+
time_step += 1
|
| 51 |
+
|
| 52 |
+
if branch_node.branches == None:
|
| 53 |
+
branch_node.branches = {agent_id: [first_alternative_node]}
|
| 54 |
+
else:
|
| 55 |
+
agent_branches = branch_node.branches.get(agent_id, [])
|
| 56 |
+
agent_branches.append(first_alternative_node)
|
| 57 |
+
branch_node.branches[agent_id] = agent_branches
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
async def AlternativeActionsRunner(
|
| 61 |
+
markov_game: MarkovGame,
|
| 62 |
+
output_folder: str,
|
| 63 |
+
nb_alternative_actions: int,
|
| 64 |
+
max_depth: int,
|
| 65 |
+
branch_only_on_new_round: bool = False,
|
| 66 |
+
):
|
| 67 |
+
"""
|
| 68 |
+
This method generates a trajectory with partially completed branches,
|
| 69 |
+
where the branching comes from taking unilateraly different actions.
|
| 70 |
+
The resulting data is used to estimate the updated advantage alignment policy gradient terms.
|
| 71 |
+
Let k := nb_sub_steps. Then the number of steps generated is O(Tk), where T is
|
| 72 |
+
the maximum trajectory length.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
tasks = []
|
| 76 |
+
time_step = 0
|
| 77 |
+
terminated = False
|
| 78 |
+
root = RolloutTreeRootNode(
|
| 79 |
+
id=markov_game.get_id(),
|
| 80 |
+
crn_id=markov_game.get_crn_id()
|
| 81 |
+
)
|
| 82 |
+
previous_node = root
|
| 83 |
+
|
| 84 |
+
while not terminated:
|
| 85 |
+
mg_before_action = markov_game.get_safe_copy()
|
| 86 |
+
|
| 87 |
+
# Get safe copies for main branch
|
| 88 |
+
agent_action_safe_copies: dict[
|
| 89 |
+
AgentId, AgentAndActionSafeCopy
|
| 90 |
+
] = await markov_game.get_actions_of_agents_without_side_effects()
|
| 91 |
+
|
| 92 |
+
markov_game.set_actions_of_agents_manually(agent_action_safe_copies)
|
| 93 |
+
terminated = markov_game.take_simulation_step()
|
| 94 |
+
main_node = RolloutTreeNode(
|
| 95 |
+
step_log=markov_game.get_step_log(), time_step=time_step
|
| 96 |
+
)
|
| 97 |
+
branch_node = RolloutTreeBranchNode(main_child=main_node)
|
| 98 |
+
previous_node.child = branch_node
|
| 99 |
+
previous_node = main_node
|
| 100 |
+
|
| 101 |
+
# Get alternative branches by generating new unilateral actions
|
| 102 |
+
for agent_id in markov_game.agent_ids:
|
| 103 |
+
for _ in range(nb_alternative_actions):
|
| 104 |
+
# Get safe copies for branches
|
| 105 |
+
branch_agent_action_safe_copies: dict[
|
| 106 |
+
AgentId, AgentAndActionSafeCopy
|
| 107 |
+
] = {
|
| 108 |
+
agent_id: AgentAndActionSafeCopy(
|
| 109 |
+
action=copy.deepcopy(agent_action_safe_copy.action),
|
| 110 |
+
action_info=copy.deepcopy(agent_action_safe_copy.action_info),
|
| 111 |
+
agent_after_action=agent_action_safe_copy.agent_after_action.get_safe_copy(),
|
| 112 |
+
)
|
| 113 |
+
for agent_id, agent_action_safe_copy in agent_action_safe_copies.items()
|
| 114 |
+
}
|
| 115 |
+
mg_branch: MarkovGame = mg_before_action.get_safe_copy()
|
| 116 |
+
other_agent_id = [id for id in mg_branch.agent_ids if id != agent_id][0]
|
| 117 |
+
mg_branch.set_action_and_agent_after_action_manually(
|
| 118 |
+
agent_id=other_agent_id,
|
| 119 |
+
agent_action_safe_copy=branch_agent_action_safe_copies[
|
| 120 |
+
other_agent_id
|
| 121 |
+
],
|
| 122 |
+
)
|
| 123 |
+
task = asyncio.create_task(
|
| 124 |
+
run_with_unilateral_alt_action(
|
| 125 |
+
markov_game=mg_branch,
|
| 126 |
+
time_step=time_step,
|
| 127 |
+
agent_id=agent_id,
|
| 128 |
+
branch_node=branch_node,
|
| 129 |
+
max_depth=max_depth,
|
| 130 |
+
)
|
| 131 |
+
)
|
| 132 |
+
tasks.append(task)
|
| 133 |
+
time_step += 1
|
| 134 |
+
|
| 135 |
+
# wait for all branches to complete
|
| 136 |
+
await asyncio.gather(*tasks)
|
| 137 |
+
|
| 138 |
+
return root
|
src_code_for_reproducibility/markov_games/group_timesteps.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains the logic for grouping time steps.
|
| 3 |
+
"""
|
| 4 |
+
import copy
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 8 |
+
from mllm.markov_games.rollout_tree import (
|
| 9 |
+
AgentActLog,
|
| 10 |
+
RolloutTreeBranchNode,
|
| 11 |
+
RolloutTreeNode,
|
| 12 |
+
RolloutTreeRootNode,
|
| 13 |
+
StepLog,
|
| 14 |
+
)
|
| 15 |
+
from mllm.markov_games.simulation import SimulationStepLog
|
| 16 |
+
|
| 17 |
+
AgentId = str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def group_time_steps(
|
| 21 |
+
rollout_tree: RolloutTreeRootNode,
|
| 22 |
+
accumulation_stop_condition: Callable[[StepLog], bool],
|
| 23 |
+
) -> RolloutTreeRootNode:
|
| 24 |
+
"""
|
| 25 |
+
During generation, we create rollout trees according to the real time steps.
|
| 26 |
+
However, during training, we might want to treat groups of time steps as a single time step.
|
| 27 |
+
As a concrete example, take Trust-and-Split. At each round, say we have X time steps of communication and then one time step for the split.
|
| 28 |
+
Then the communication actions will not get any reward, and the split action will get the reward. During REINFORCE training, with discounting, this
|
| 29 |
+
can cause training instability. We could instead treat every action in the round as being part of a single action, and give it the reward of the split action.
|
| 30 |
+
This method helps to do this sort of grouping.
|
| 31 |
+
It accumulates actions until the accumulation_stop_condition is met, and then creates a new node with the accumulated actions.
|
| 32 |
+
It then recursively calls itself on the child node.
|
| 33 |
+
Details:
|
| 34 |
+
- The reward for the group is the reward of the last time step in the group.
|
| 35 |
+
- The simulation log for the group is the simulation log of the last time step in the group.
|
| 36 |
+
- The state end for the group becomes the first state end in the group.
|
| 37 |
+
- The agent info for the group is the agent info of the last time step in the group.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def group_step_logs(step_logs: list[StepLog]) -> StepLog:
|
| 41 |
+
"""
|
| 42 |
+
Concatenate per-agent chat turns across steps; keep only the first is_state_end.
|
| 43 |
+
"""
|
| 44 |
+
last_sim_log = step_logs[-1].simulation_step_log
|
| 45 |
+
agent_ids = {aid for s in step_logs for aid in s.action_logs.keys()}
|
| 46 |
+
grouped_logs: dict[AgentId, AgentActLog] = {}
|
| 47 |
+
for aid in agent_ids:
|
| 48 |
+
turns = []
|
| 49 |
+
for s in step_logs:
|
| 50 |
+
act = s.action_logs.get(aid)
|
| 51 |
+
if act and act.chat_turns:
|
| 52 |
+
turns.extend(copy.deepcopy(act.chat_turns))
|
| 53 |
+
disable_is_state_end = False
|
| 54 |
+
# Only the first state_end should be True, the rest should be False
|
| 55 |
+
for t in turns:
|
| 56 |
+
if t.is_state_end:
|
| 57 |
+
if disable_is_state_end:
|
| 58 |
+
t.is_state_end = False
|
| 59 |
+
else:
|
| 60 |
+
disable_is_state_end = True
|
| 61 |
+
continue
|
| 62 |
+
grouped_logs[aid] = AgentActLog(
|
| 63 |
+
chat_turns=turns, info=step_logs[-1].action_logs[aid].info
|
| 64 |
+
)
|
| 65 |
+
return StepLog(action_logs=grouped_logs, simulation_step_log=last_sim_log)
|
| 66 |
+
|
| 67 |
+
def group_time_steps_rec(
|
| 68 |
+
current_node: RolloutTreeNode | RolloutTreeBranchNode,
|
| 69 |
+
group_time_step: int,
|
| 70 |
+
accumulation_step_logs: list[StepLog],
|
| 71 |
+
) -> RolloutTreeNode | RolloutTreeBranchNode:
|
| 72 |
+
"""
|
| 73 |
+
Groups time steps. Recursion is used to handle branches.
|
| 74 |
+
"""
|
| 75 |
+
assert isinstance(current_node, RolloutTreeNode) or isinstance(
|
| 76 |
+
current_node, RolloutTreeBranchNode
|
| 77 |
+
), "Current node must be a tree node or a branch node. Is of type: " + str(
|
| 78 |
+
type(current_node)
|
| 79 |
+
)
|
| 80 |
+
first_group_node = None
|
| 81 |
+
current_group_node = None
|
| 82 |
+
while current_node is not None:
|
| 83 |
+
if isinstance(current_node, RolloutTreeBranchNode):
|
| 84 |
+
raise Exception(
|
| 85 |
+
"Grouping timesteps by round is not supported for branching trajectories yet."
|
| 86 |
+
)
|
| 87 |
+
# Special recursive case for branches
|
| 88 |
+
# if isinstance(current_node, RolloutTreeBranchNode):
|
| 89 |
+
# branches = {}
|
| 90 |
+
# for agent_id, branch_nodes in current_node.branches.items():
|
| 91 |
+
# branch_group_nodes = []
|
| 92 |
+
# for branch_node in branch_nodes:
|
| 93 |
+
# branch_group_node = group_time_steps_rec(
|
| 94 |
+
# current_node=branch_node,
|
| 95 |
+
# group_time_step=group_time_step,
|
| 96 |
+
# accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
|
| 97 |
+
# branch_group_nodes.append(branch_group_node)
|
| 98 |
+
# branches[agent_id] = branch_group_nodes
|
| 99 |
+
|
| 100 |
+
# main_child_group_node = group_time_steps_rec(
|
| 101 |
+
# current_node=current_node.main_child,
|
| 102 |
+
# group_time_step=group_time_step,
|
| 103 |
+
# accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
|
| 104 |
+
|
| 105 |
+
# return RolloutTreeBranchNode(main_child=main_child_group_node, branches=branches)
|
| 106 |
+
|
| 107 |
+
# Accumulate
|
| 108 |
+
accumulation_step_logs.append(current_node.step_log)
|
| 109 |
+
if accumulation_stop_condition(current_node.step_log):
|
| 110 |
+
grouped_step_logs = group_step_logs(accumulation_step_logs)
|
| 111 |
+
accumulation_step_logs = []
|
| 112 |
+
new_group_node = RolloutTreeNode(
|
| 113 |
+
step_log=grouped_step_logs, time_step=group_time_step, child=None
|
| 114 |
+
)
|
| 115 |
+
if first_group_node == None:
|
| 116 |
+
first_group_node = new_group_node
|
| 117 |
+
group_time_step += 1
|
| 118 |
+
if current_group_node is not None:
|
| 119 |
+
current_group_node.child = new_group_node
|
| 120 |
+
current_group_node = new_group_node
|
| 121 |
+
current_node = current_node.child
|
| 122 |
+
return first_group_node
|
| 123 |
+
|
| 124 |
+
node = group_time_steps_rec(
|
| 125 |
+
current_node=rollout_tree.child, group_time_step=0, accumulation_step_logs=[]
|
| 126 |
+
)
|
| 127 |
+
return RolloutTreeRootNode(
|
| 128 |
+
id=rollout_tree.id,
|
| 129 |
+
crn_id=rollout_tree.crn_id,
|
| 130 |
+
child=node,
|
| 131 |
+
agent_ids=rollout_tree.agent_ids,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def stop_when_round_ends(step_log: StepLog) -> bool:
|
| 136 |
+
"""
|
| 137 |
+
Simplest stop condition. Will return True if step log is the last time step of a round.
|
| 138 |
+
This will throw an error if this information is not available in the simulation info.
|
| 139 |
+
"""
|
| 140 |
+
assert (
|
| 141 |
+
"is_last_timestep_in_round" in step_log.simulation_step_log.info.keys()
|
| 142 |
+
), "To group by round, is_last_timestep_in_round must be set in the info of your simulation step log at each time step."
|
| 143 |
+
return step_log.simulation_step_log.info["is_last_timestep_in_round"]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def group_by_round(rollout_tree: RolloutTreeRootNode) -> RolloutTreeRootNode:
|
| 147 |
+
"""
|
| 148 |
+
Groups time steps by round.
|
| 149 |
+
"""
|
| 150 |
+
return group_time_steps(rollout_tree, stop_when_round_ends)
|
src_code_for_reproducibility/markov_games/linear_runner.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import os.path
|
| 4 |
+
|
| 5 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 6 |
+
from mllm.markov_games.rollout_tree import RolloutTreeNode, RolloutTreeRootNode
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
async def LinearRunner(
|
| 10 |
+
markov_game: MarkovGame, output_folder: str
|
| 11 |
+
) -> RolloutTreeRootNode:
|
| 12 |
+
"""
|
| 13 |
+
This method generates a trajectory without branching.
|
| 14 |
+
"""
|
| 15 |
+
time_step = 0
|
| 16 |
+
terminated = False
|
| 17 |
+
root = RolloutTreeRootNode(
|
| 18 |
+
id=markov_game.get_id(),
|
| 19 |
+
crn_id=markov_game.get_crn_id(),
|
| 20 |
+
agent_ids=markov_game.get_agent_ids(),
|
| 21 |
+
)
|
| 22 |
+
previous_node = root
|
| 23 |
+
while not terminated:
|
| 24 |
+
terminated, step_log = await markov_game.step()
|
| 25 |
+
current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
|
| 26 |
+
previous_node.child = current_node
|
| 27 |
+
previous_node = current_node
|
| 28 |
+
time_step += 1
|
| 29 |
+
|
| 30 |
+
return root
|
src_code_for_reproducibility/markov_games/markov_game.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
|
| 3 |
+
In a MarkovGame step,
|
| 4 |
+
1) each agent takes an action,
|
| 5 |
+
2) the state transitions with respect to these actions,
|
| 6 |
+
3) all relevant data of the step is appended to the historical data list
|
| 7 |
+
|
| 8 |
+
In order to perform 3), the agents and the simulation are expected, at each time step,
|
| 9 |
+
to return a log of the state transition (from their perspective).
|
| 10 |
+
For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
|
| 11 |
+
A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
|
| 12 |
+
The approach we use here centralizes the data gathering aspect,
|
| 13 |
+
making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
|
| 14 |
+
only log information for step transitions occuring after the branching out.
|
| 15 |
+
"""
|
| 16 |
+
import asyncio
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Any, List, Literal, Optional, Tuple
|
| 22 |
+
|
| 23 |
+
from transformers.models.idefics2 import Idefics2Config
|
| 24 |
+
|
| 25 |
+
from mllm.markov_games.agent import Agent
|
| 26 |
+
from mllm.markov_games.rollout_tree import AgentActLog, StepLog
|
| 27 |
+
from mllm.markov_games.simulation import Simulation
|
| 28 |
+
|
| 29 |
+
AgentId = str
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class AgentAndActionSafeCopy:
|
| 34 |
+
action: Any
|
| 35 |
+
action_info: AgentActLog
|
| 36 |
+
agent_after_action: type[Agent]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MarkovGame(object):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
id: int,
|
| 43 |
+
agents: dict[AgentId, type[Agent]],
|
| 44 |
+
simulation: type[Simulation],
|
| 45 |
+
crn_id: int,
|
| 46 |
+
):
|
| 47 |
+
"""
|
| 48 |
+
Args:
|
| 49 |
+
agents:
|
| 50 |
+
output_path:
|
| 51 |
+
Path where the step infos are saved.
|
| 52 |
+
simulation:
|
| 53 |
+
Simulation object. Example: IPDSimulation
|
| 54 |
+
"""
|
| 55 |
+
self.agents = agents
|
| 56 |
+
self.agent_ids = self.agents.keys()
|
| 57 |
+
self.simulation = simulation
|
| 58 |
+
self.simulation_step_log = None
|
| 59 |
+
self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids}
|
| 60 |
+
self.actions = {}
|
| 61 |
+
self.id = id
|
| 62 |
+
self.crn_id = crn_id
|
| 63 |
+
|
| 64 |
+
def get_id(self) -> str:
|
| 65 |
+
return self.id
|
| 66 |
+
|
| 67 |
+
def get_crn_id(self) -> int:
|
| 68 |
+
return self.crn_id
|
| 69 |
+
|
| 70 |
+
def get_agent_ids(self) -> List[AgentId]:
|
| 71 |
+
return list(self.agent_ids)
|
| 72 |
+
|
| 73 |
+
async def get_action_of_agent_without_side_effects(
|
| 74 |
+
self, agent_id: AgentId
|
| 75 |
+
) -> Tuple[Any, AgentActLog]:
|
| 76 |
+
"""
|
| 77 |
+
Safe function to get an action of an agent without modifying the agent or the simulation.
|
| 78 |
+
"""
|
| 79 |
+
agent = self.agents[agent_id]
|
| 80 |
+
agent_before_action = agent.get_safe_copy()
|
| 81 |
+
obs = self.simulation.get_obs_agent(agent_id)
|
| 82 |
+
action, action_info = await agent.act(observation=obs)
|
| 83 |
+
self.agents[agent_id] = agent_before_action
|
| 84 |
+
agent_after_action = agent.get_safe_copy()
|
| 85 |
+
return AgentAndActionSafeCopy(action, action_info, agent_after_action)
|
| 86 |
+
|
| 87 |
+
async def get_actions_of_agents_without_side_effects(
|
| 88 |
+
self,
|
| 89 |
+
) -> dict[AgentId, AgentAndActionSafeCopy]:
|
| 90 |
+
"""
|
| 91 |
+
Safe function to get an action of an agent without modifying the agent or the simulation.
|
| 92 |
+
"""
|
| 93 |
+
tasks = []
|
| 94 |
+
for agent_id in self.agent_ids:
|
| 95 |
+
task = asyncio.create_task(
|
| 96 |
+
self.get_action_of_agent_without_side_effects(agent_id)
|
| 97 |
+
)
|
| 98 |
+
tasks.append(task)
|
| 99 |
+
agent_and_action_safe_copies: list[
|
| 100 |
+
AgentAndActionSafeCopy
|
| 101 |
+
] = await asyncio.gather(*tasks)
|
| 102 |
+
return {
|
| 103 |
+
agent_id: agent_and_action_safe_copy
|
| 104 |
+
for agent_id, agent_and_action_safe_copy in zip(
|
| 105 |
+
self.agent_ids, agent_and_action_safe_copies
|
| 106 |
+
)
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
def set_action_and_agent_after_action_manually(
|
| 110 |
+
self,
|
| 111 |
+
agent_id: AgentId,
|
| 112 |
+
agent_action_safe_copy: AgentAndActionSafeCopy,
|
| 113 |
+
):
|
| 114 |
+
"""
|
| 115 |
+
Set the action and the agent after action manually.
|
| 116 |
+
"""
|
| 117 |
+
self.actions[agent_id] = agent_action_safe_copy.action
|
| 118 |
+
self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info
|
| 119 |
+
self.agents[agent_id] = agent_action_safe_copy.agent_after_action
|
| 120 |
+
|
| 121 |
+
def set_actions_of_agents_manually(
|
| 122 |
+
self, actions: dict[AgentId, AgentAndActionSafeCopy]
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Set the actions of agents manually.
|
| 126 |
+
"""
|
| 127 |
+
for agent_id, agent_action_safe_copy in actions.items():
|
| 128 |
+
self.set_action_and_agent_after_action_manually(
|
| 129 |
+
agent_id, agent_action_safe_copy
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
async def set_action_of_agent(self, agent_id: AgentId):
|
| 133 |
+
"""
|
| 134 |
+
TOWRITE
|
| 135 |
+
"""
|
| 136 |
+
agent = self.agents[agent_id]
|
| 137 |
+
obs = self.simulation.get_obs_agent(agent_id)
|
| 138 |
+
action, action_info = await agent.act(observation=obs)
|
| 139 |
+
self.actions[agent_id] = action
|
| 140 |
+
self.agent_step_logs[agent_id] = action_info
|
| 141 |
+
|
| 142 |
+
async def set_actions(self):
|
| 143 |
+
"""
|
| 144 |
+
TOWRITE
|
| 145 |
+
"""
|
| 146 |
+
# background_tasks = set()
|
| 147 |
+
tasks = []
|
| 148 |
+
for agent_id in self.agent_ids:
|
| 149 |
+
task = asyncio.create_task(self.set_action_of_agent(agent_id))
|
| 150 |
+
tasks.append(task)
|
| 151 |
+
await asyncio.gather(*tasks)
|
| 152 |
+
|
| 153 |
+
def take_simulation_step(self):
|
| 154 |
+
"""
|
| 155 |
+
TOWRITE
|
| 156 |
+
"""
|
| 157 |
+
terminated, self.simulation_step_log = self.simulation.step(self.actions)
|
| 158 |
+
return terminated
|
| 159 |
+
|
| 160 |
+
def get_step_log(self) -> StepLog:
|
| 161 |
+
"""
|
| 162 |
+
TOWRITE
|
| 163 |
+
TODO: assert actions and simulation have taken step
|
| 164 |
+
"""
|
| 165 |
+
step_log = StepLog(
|
| 166 |
+
simulation_step_log=self.simulation_step_log,
|
| 167 |
+
action_logs=self.agent_step_logs,
|
| 168 |
+
)
|
| 169 |
+
return step_log
|
| 170 |
+
|
| 171 |
+
async def step(self) -> Tuple[bool, StepLog]:
|
| 172 |
+
"""
|
| 173 |
+
TOWRITE
|
| 174 |
+
"""
|
| 175 |
+
await self.set_actions()
|
| 176 |
+
terminated = self.take_simulation_step()
|
| 177 |
+
step_log = self.get_step_log()
|
| 178 |
+
return terminated, step_log
|
| 179 |
+
|
| 180 |
+
def get_safe_copy(self):
|
| 181 |
+
"""
|
| 182 |
+
TOWRITE
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
new_markov_game = copy.copy(self)
|
| 186 |
+
new_simulation = self.simulation.get_safe_copy()
|
| 187 |
+
new_agents = {
|
| 188 |
+
agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items()
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
# Reassign copied components
|
| 192 |
+
new_markov_game.simulation = new_simulation
|
| 193 |
+
new_markov_game.agents = new_agents
|
| 194 |
+
|
| 195 |
+
# IMPORTANT: ensure agent_ids references the new agents dict, not the original
|
| 196 |
+
new_markov_game.agent_ids = new_markov_game.agents.keys()
|
| 197 |
+
|
| 198 |
+
# Deep-copy step data to avoid correlation
|
| 199 |
+
new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log)
|
| 200 |
+
new_markov_game.actions = copy.deepcopy(self.actions)
|
| 201 |
+
# Rebuild logs to align exactly with new agent ids
|
| 202 |
+
old_agent_step_logs = copy.deepcopy(self.agent_step_logs)
|
| 203 |
+
new_markov_game.agent_step_logs = {
|
| 204 |
+
agent_id: old_agent_step_logs.get(agent_id)
|
| 205 |
+
for agent_id in new_markov_game.agent_ids
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
return new_markov_game
|
src_code_for_reproducibility/markov_games/mg_utils.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import copy
|
| 3 |
+
from collections.abc import Callable
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
from mllm.markov_games.ipd.ipd_agent import IPDAgent
|
| 7 |
+
from mllm.markov_games.ipd.ipd_simulation import IPD
|
| 8 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 9 |
+
from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
|
| 10 |
+
from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation
|
| 11 |
+
from mllm.markov_games.negotiation.nego_hard_coded_policies import (
|
| 12 |
+
HardCodedNegoGreedyPolicy,
|
| 13 |
+
HardCodedNegoWelfareMaximizingPolicy,
|
| 14 |
+
)
|
| 15 |
+
from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
|
| 16 |
+
from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
|
| 17 |
+
from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
|
| 18 |
+
from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
|
| 19 |
+
from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
|
| 20 |
+
from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
|
| 21 |
+
from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent
|
| 22 |
+
from mllm.markov_games.negotiation.tas_simple_simulation import (
|
| 23 |
+
TrustAndSplitSimpleSimulation,
|
| 24 |
+
)
|
| 25 |
+
from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
|
| 26 |
+
from mllm.markov_games.rollout_tree import (
|
| 27 |
+
AgentActLog,
|
| 28 |
+
RolloutTreeBranchNode,
|
| 29 |
+
RolloutTreeNode,
|
| 30 |
+
RolloutTreeRootNode,
|
| 31 |
+
StepLog,
|
| 32 |
+
)
|
| 33 |
+
from mllm.markov_games.simulation import SimulationStepLog
|
| 34 |
+
|
| 35 |
+
AgentId = str
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class AgentConfig:
|
| 40 |
+
agent_id: str
|
| 41 |
+
agent_name: str
|
| 42 |
+
agent_class_name: str
|
| 43 |
+
policy_id: str
|
| 44 |
+
init_kwargs: dict
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class MarkovGameConfig:
|
| 49 |
+
id: int
|
| 50 |
+
seed: int
|
| 51 |
+
simulation_class_name: str
|
| 52 |
+
simulation_init_args: dict
|
| 53 |
+
agent_configs: list[AgentConfig]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def init_markov_game_components(
|
| 57 |
+
config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
TOWRITE
|
| 61 |
+
"""
|
| 62 |
+
agents = {}
|
| 63 |
+
agent_names = []
|
| 64 |
+
for agent_config in config.agent_configs:
|
| 65 |
+
agent_id = agent_config.agent_id
|
| 66 |
+
agent_name = agent_config.agent_name
|
| 67 |
+
agent_class = eval(agent_config.agent_class_name)
|
| 68 |
+
agent = agent_class(
|
| 69 |
+
seed=config.seed,
|
| 70 |
+
agent_id=agent_id,
|
| 71 |
+
agent_name=agent_name,
|
| 72 |
+
policy=policies[agent_config.policy_id],
|
| 73 |
+
**agent_config.init_kwargs,
|
| 74 |
+
)
|
| 75 |
+
agents[agent_id] = agent
|
| 76 |
+
agent_names.append(agent_name)
|
| 77 |
+
simulation = eval(config.simulation_class_name)(
|
| 78 |
+
seed=config.seed,
|
| 79 |
+
agent_ids=list(agents.keys()),
|
| 80 |
+
agent_names=agent_names,
|
| 81 |
+
**config.simulation_init_args,
|
| 82 |
+
)
|
| 83 |
+
markov_game = MarkovGame(
|
| 84 |
+
id=config.id,
|
| 85 |
+
crn_id=config.seed,
|
| 86 |
+
agents=agents,
|
| 87 |
+
simulation=simulation,
|
| 88 |
+
)
|
| 89 |
+
return markov_game
|
src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc
ADDED
|
Binary file (3.23 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/rollout_tree.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO: add parent to nodes so that some verification can be done. For instance, to ensure that node reward keys match the parent node.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, List, Literal, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import jsonschema
|
| 13 |
+
from pydantic import BaseModel, Field, model_validator
|
| 14 |
+
|
| 15 |
+
from mllm.chat_utils.chat_turn import ChatTurn
|
| 16 |
+
|
| 17 |
+
AgentId = str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SimulationStepLog(BaseModel):
|
| 21 |
+
rewards: dict[AgentId, float]
|
| 22 |
+
info: Any = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AgentActLog(BaseModel):
|
| 26 |
+
chat_turns: list[ChatTurn] | None
|
| 27 |
+
info: Any = None
|
| 28 |
+
|
| 29 |
+
@model_validator(mode="after")
|
| 30 |
+
def _exactly_one_state_end(self):
|
| 31 |
+
"""
|
| 32 |
+
This method is used to enforce that for each AgentActLog, there is exactly one ChatTurn which is a state end.
|
| 33 |
+
"""
|
| 34 |
+
if self.chat_turns != []:
|
| 35 |
+
n = sum(1 for t in self.chat_turns if t.is_state_end)
|
| 36 |
+
if n != 1:
|
| 37 |
+
raise ValueError(
|
| 38 |
+
f"AgentActLog must have exactly one ChatTurn with is_state_end=True; got {self.chat_turns}."
|
| 39 |
+
)
|
| 40 |
+
return self
|
| 41 |
+
else:
|
| 42 |
+
return self
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class StepLog(BaseModel):
|
| 46 |
+
action_logs: dict[AgentId, AgentActLog]
|
| 47 |
+
simulation_step_log: SimulationStepLog
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# BranchType = Literal["unilateral_deviation", "common_deviation"] # might not be necessary
|
| 51 |
+
# class BranchNodeInfo(BaseModel):
|
| 52 |
+
# branch_id: str
|
| 53 |
+
# branch_for: AgentId
|
| 54 |
+
# branch_type: BranchType
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RolloutTreeNode(BaseModel):
|
| 58 |
+
step_log: StepLog
|
| 59 |
+
time_step: int
|
| 60 |
+
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class RolloutTreeBranchNode(BaseModel):
|
| 64 |
+
"""
|
| 65 |
+
First item of the tuple indicates which agent "called" for an alternative branch.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
main_child: RolloutTreeNode
|
| 69 |
+
branches: dict[AgentId, list[RolloutTreeNode]] | None = None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RolloutTreeRootNode(BaseModel):
|
| 73 |
+
id: int
|
| 74 |
+
crn_id: int # ID of the rng used to generate this rollout tree
|
| 75 |
+
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
| 76 |
+
agent_ids: List[AgentId] = Field(min_length=1)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# class RolloutTreeLeafNode(BaseModel):
|
| 80 |
+
# step_log: StepLog
|
| 81 |
+
# time_step: int
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Necessary for self-referential stuff in pydantic
|
| 85 |
+
RolloutTreeBranchNode.model_rebuild()
|
| 86 |
+
RolloutTreeNode.model_rebuild()
|