diff --git a/.hydra/hydra.yaml b/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1b9b0d2decfea47ad79cba78803ec9f9b5f0cec --- /dev/null +++ b/.hydra/hydra.yaml @@ -0,0 +1,154 @@ +hydra: + run: + dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? 
+ hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: [] + job: + name: run + chdir: false + override_dirname: '' + id: ??? + num: ??? + config_name: naive_vs_fixed_ad_align_seed0.yaml + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.1' + cwd: /scratch/muqeeth/llm_negotiation + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /scratch/muqeeth/llm_negotiation/configs + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch/muqeeth/llm_negotiation/2025_11/naive_vs_fixed_ad_align_seed0 + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: false diff --git a/.hydra/overrides.yaml b/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe51488c7066f6687ef680d6bfaa4f7768ef205c --- /dev/null +++ b/.hydra/overrides.yaml @@ -0,0 +1 @@ +[] diff --git 
a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md new file mode 100644 index 0000000000000000000000000000000000000000..952935e8a936512044016a9bc1f922b109c88143 --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md @@ -0,0 +1,207 @@ +--- +base_model: Qwen/Qwen2.5-7B-Instruct +library_name: peft +pipeline_tag: text-generation +tags: +- base_model:adapter:Qwen/Qwen2.5-7B-Instruct +- lora +- transformers +--- + +# Model Card for Model ID + + + + + +## Model Details + +### Model Description + + + + + +- **Developed by:** [More Information Needed] +- **Funded by [optional]:** [More Information Needed] +- **Shared by [optional]:** [More Information Needed] +- **Model type:** [More Information Needed] +- **Language(s) (NLP):** [More Information Needed] +- **License:** [More Information Needed] +- **Finetuned from model [optional]:** [More Information Needed] + +### Model Sources [optional] + + + +- **Repository:** [More Information Needed] +- **Paper [optional]:** [More Information Needed] +- **Demo [optional]:** [More Information Needed] + +## Uses + + + +### Direct Use + + + +[More Information Needed] + +### Downstream Use [optional] + + + +[More Information Needed] + +### Out-of-Scope Use + + + +[More Information Needed] + +## Bias, Risks, and Limitations + + + +[More Information Needed] + +### Recommendations + + + +Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations. + +## How to Get Started with the Model + +Use the code below to get started with the model. 
+ +[More Information Needed] + +## Training Details + +### Training Data + + + +[More Information Needed] + +### Training Procedure + + + +#### Preprocessing [optional] + +[More Information Needed] + + +#### Training Hyperparameters + +- **Training regime:** [More Information Needed] + +#### Speeds, Sizes, Times [optional] + + + +[More Information Needed] + +## Evaluation + + + +### Testing Data, Factors & Metrics + +#### Testing Data + + + +[More Information Needed] + +#### Factors + + + +[More Information Needed] + +#### Metrics + + + +[More Information Needed] + +### Results + +[More Information Needed] + +#### Summary + + + +## Model Examination [optional] + + + +[More Information Needed] + +## Environmental Impact + + + +Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). + +- **Hardware Type:** [More Information Needed] +- **Hours used:** [More Information Needed] +- **Cloud Provider:** [More Information Needed] +- **Compute Region:** [More Information Needed] +- **Carbon Emitted:** [More Information Needed] + +## Technical Specifications [optional] + +### Model Architecture and Objective + +[More Information Needed] + +### Compute Infrastructure + +[More Information Needed] + +#### Hardware + +[More Information Needed] + +#### Software + +[More Information Needed] + +## Citation [optional] + + + +**BibTeX:** + +[More Information Needed] + +**APA:** + +[More Information Needed] + +## Glossary [optional] + + + +[More Information Needed] + +## More Information [optional] + +[More Information Needed] + +## Model Card Authors [optional] + +[More Information Needed] + +## Model Card Contact + +[More Information Needed] +### Framework versions + +- PEFT 0.17.1 \ No newline at end of file diff --git a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json 
b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..c7cbf5a09004d8058a89e172c77ed02ebe72d98f --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json @@ -0,0 +1,42 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "down_proj", + "k_proj", + "up_proj", + "v_proj", + "gate_proj", + "q_proj", + "o_proj" + ], + "target_parameters": null, + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..c7cbf5a09004d8058a89e172c77ed02ebe72d98f --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json @@ -0,0 +1,42 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + 
"layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "down_proj", + "k_proj", + "up_proj", + "v_proj", + "gate_proj", + "q_proj", + "o_proj" + ], + "target_parameters": null, + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..c7cbf5a09004d8058a89e172c77ed02ebe72d98f --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/fixed_ad_align_adapter/adapter_config.json @@ -0,0 +1,42 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "down_proj", + "k_proj", + "up_proj", + "v_proj", + "gate_proj", + "q_proj", + "o_proj" + ], + "target_parameters": null, + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git 
a/src_code_for_reproducibility/__init__.py b/src_code_for_reproducibility/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src_code_for_reproducibility/docs/source/conf.py b/src_code_for_reproducibility/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7512678928b6b7580c812cd62d1c22df9945ba --- /dev/null +++ b/src_code_for_reproducibility/docs/source/conf.py @@ -0,0 +1,48 @@ +# Configuration file for the Sphinx documentation builder. +import os +import sys +sys.path.insert(0, os.path.abspath('../..')) + +# -- Project information ----------------------------------------------------- +project = 'llm_negotiation' +copyright = '2023, Your Name' +author = 'Your Name' + +# -- General configuration --------------------------------------------------- +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx.ext.napoleon', + 'sphinx.ext.autosummary', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinxcontrib.mermaid', + 'sphinx_rtd_theme', +] + +templates_path = ['_templates'] +exclude_patterns = [] + +# -- Options for HTML output ------------------------------------------------- +html_theme = 'sphinx_rtd_theme' +html_static_path = ['_static'] + +# -- Napoleon settings ------------------------------------------------------- +napoleon_google_docstring = True +napoleon_numpy_docstring = False +napoleon_include_init_with_doc = True +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = False +napoleon_type_aliases = None +napoleon_attr_annotations = True + +# -- Path setup -------------------------------------------------------------- +# Make sure the 
project's modules can be found by Sphinx +sys.path.insert(0, os.path.abspath('../../src')) \ No newline at end of file diff --git a/src_code_for_reproducibility/docs/source/index.rst b/src_code_for_reproducibility/docs/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..cdc1b79539342a9c95ca0cdd9219bce74a7b2c8a --- /dev/null +++ b/src_code_for_reproducibility/docs/source/index.rst @@ -0,0 +1,22 @@ +Welcome to LLM Negotiation's documentation! +=========================================== +This library is a collection of tools for training and evaluating LLM-based agents in multi-agent environments. It is designed to be easy to use and extend. + +.. toctree:: + :maxdepth: 3 + :caption: Contents: + + installation + marl_standard + environments + launch + usage + modules + contributing + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` \ No newline at end of file diff --git a/src_code_for_reproducibility/docs/source/installation.rst b/src_code_for_reproducibility/docs/source/installation.rst new file mode 100644 index 0000000000000000000000000000000000000000..b148f25d92fd8308e9695f7c17c2b91fb0c9a2c6 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/installation.rst @@ -0,0 +1,10 @@ +Installation +============ + +To install the package, run: + +.. code-block:: bash + + git clone https://github.com/yourusername/llm_negotiation.git + cd llm_negotiation + pip install -e .
\ No newline at end of file diff --git a/src_code_for_reproducibility/docs/source/marl_standard.rst b/src_code_for_reproducibility/docs/source/marl_standard.rst new file mode 100644 index 0000000000000000000000000000000000000000..b5ea5529892c611b34255645ec68537a236754cf --- /dev/null +++ b/src_code_for_reproducibility/docs/source/marl_standard.rst @@ -0,0 +1,141 @@ +========================================================== +Abstract Standard for Multi-Agent Negotiation Environments +========================================================== + +Multi-Agent Negotiation Environments require more features than gymnasium environments in order to be used as interfaces in general game running code. +The fundamental differences between gymnasium environments and Multi-Agent Negotiation Environments are: + +1. Response from the LLM is a text action, not a discrete action. Therefore, appropriate parsing of the text is required. The model may need to be run multiple times to get the full action. + This is why we introduce the `AgentHandler` class, which is responsible for parsing the LLM's response. +2. The environment needs to be able to handle multi-agent interactions. + This is why we introduce the `NegotiationEnvironment` class, which is responsible for handling the multi-agent interactions. +3. MARL environments are complex to describe. In different contexts, the same environment may be described differently. Therefore, both the environment and the agent handlers are + responsible for describing a particular trajectory. This information is given by the `get_log_info` method. +4. There might be a lot of overlap between the neural networks used by each agent. For instance, the same model may be used for all agents. This motivates a requirement for a + policy identifier for each agent. + +Taking inspiration from the `gymnasium <https://gymnasium.farama.org/>`_ library, we introduce a new standard for Multi-Agent Negotiation Environments. + +Our standard is based on the following features: + +Environments are of the form: + +..
code-block:: python + + class MarlEnvironment(): + + def __init__(self): + """Initialize the environment.""" + pass + + def reset(self): + """Reset the environment to an initial state and return the initial observation. + Returns: + observation (dict): A dictionary where keys are agent identifiers and values are observations. + """ + # (...) + return observation + + def step(self, actions): + """Take a step in the environment using the provided actions. + + Args: + actions (dict): A dictionary where keys are agent identifiers and values are actions. + + Returns: + observations (dict): A dictionary where keys are agent identifiers and values are observations. + done (bool): Whether the episode has ended. + info (dict): Additional information about the environment. + """ + # (...) + return observations, done, info + + def get_log_info(self): + """Get additional information about the environment. This information is used to log the game. + Returns: + log_info (dict): Information about the environment required to log the game. + """ + # (...) + return log_info + + def render(self): + """Render the current state of the environment.""" + pass + + def close(self): + """Perform any necessary cleanup.""" + pass + + + class AgentState(): + + def __init__(self): + """Initialize the agent state.""" + pass + + def step(self, observation_from_env, policy_output=None): + """Update the agent state based on the observation and action. + The action is the output of the LLM. + + Args: + observation_from_env (dict): The observation of the environment. + policy_output : The output of the policy. + + Returns: + policy_id (str): The policy identifier. + policy_input (dict): The input to the policy. + action : The official action to be sent to the environment. + done (bool): Whether the LLM action is ready to be sent to the environment. + info (dict): Additional information about the agent.
+ """ + # (...) + return policy_id, policy_input, action, done, info + + def get_log_info(self): + """Get information about the agent required to log a trajectory. + Returns: + log_info (dict): Information about the agent required to log a trajectory. + """ + # (...) + return log_info + + def render(self): + """Render the current state of the environment.""" + pass + + def close(self): + """Perform any necessary cleanup.""" + pass + + +Implicitely, the keys of the `observations` in the `step` method of the `MarlEnvironment` interface represent the set of agents from which an action is expected at the current step. The next step should only expect actions from the agents in the `observations` dictionary. + +As you can see, both classes have a `get_log_info` method. This method is used to log the game. It returns a dictionary with keys being the agent identifiers and values being the information to log. The reason we need this is because the environment and the agent handler may need to log different information. It makes it easier to log from the perspective of each agent. The core environment class should not need to know about the details of the agent handler. + + + +Running Environments in Parallel +-------------------------------- +This standard allows the use of the `run_batched_matches` function (TODO: link) to run environments in an efficient way. The core idea is to batch the policy calls for all agents in the environment. + +.. note:: + The ``run_batched_matches`` function allows you to run multiple negotiation games, or "matches," in parallel. + After each environment is initialized, the function continuously loops over all active matches and checks which agents + are still pending actions. Each agent's logic can require multiple calls to the policy (e.g., an LLM) before an action + becomes "ready" to be sent to the environment. (For instance, an agent might need multiple policy calls before having a string which can be parsed into a valid action.) 
While an agent is waiting for a policy output, these calls for all agents across all matches are grouped together by unique policy identifier and processed in batch for efficiency. This is the core functionality of the ``run_batched_matches`` function. + + Only once all actions from the required agents at a given step for an environment are ready does the function make a single ``env.step(...)`` call; this ensures + every match moves forward in lockstep for all its active agents. As soon as an environment signals it is done, the function + retrieves logged information from both the environment and the agent states before removing this match from the active set. + + If there are more matches waiting to be processed, they are then started one by one to maintain the specified degree of parallelism. + This batching approach provides an efficient mechanism to handle multi-agent or multi-policy environments, ensuring minimal + overhead and a clear, unified flow for stepping through matches. + +Here is a diagram that shows how the `run_batched_matches` function works at a high level: + +.. image:: media/runbatch.png + :alt: Alternate text for the image + :width: 1000px diff --git a/src_code_for_reproducibility/docs/source/modules.rst b/src_code_for_reproducibility/docs/source/modules.rst new file mode 100644 index 0000000000000000000000000000000000000000..e9ff8ac1a89c7bd18e69d633121f5a4022ac6fdf --- /dev/null +++ b/src_code_for_reproducibility/docs/source/modules.rst @@ -0,0 +1,7 @@ +src +=== + +.. 
toctree:: + :maxdepth: 4 + + src diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst new file mode 100644 index 0000000000000000000000000000000000000000..8fab765a9c7e749bd446533fdddb5fa5b55e6635 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_agent.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_agent module +======================================== + +.. automodule:: src.environments.dond.dond_agent + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst new file mode 100644 index 0000000000000000000000000000000000000000..d0e595aad169a5a8456f83afe5029e7475d7c9e7 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_game module +======================================= + +.. automodule:: src.environments.dond.dond_game + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst new file mode 100644 index 0000000000000000000000000000000000000000..bab97f1009eb2d5c4e387ac6a83982a51e33c9e3 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_agent module +========================================= + +.. 
automodule:: src.environments.dond.dond_agent + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..4c4d3764f9333b8a2975069160f40952f324a1a8 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_statistics_funcs.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_statistics\_funcs module +==================================================== + +.. automodule:: src.environments.dond.dond_statistics_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..cf31d696a3ed580e24f3c5dffd6f7a2851d16320 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_training\_data\_funcs module +======================================================== + +.. automodule:: src.environments.dond.dond_training_data_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.env_imports.rst b/src_code_for_reproducibility/docs/source/src.environments.env_imports.rst new file mode 100644 index 0000000000000000000000000000000000000000..4354ba27eee9f0e0fa3f4f0e5d9131c256a4be57 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.env_imports.rst @@ -0,0 +1,7 @@ +src.environments.env\_imports module +==================================== + +.. 
automodule:: src.environments.env_imports + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst new file mode 100644 index 0000000000000000000000000000000000000000..ede471ef9675c780410189fcf63df0c1a05496d0 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_game module +===================================== + +.. automodule:: src.environments.ipd.ipd_game + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..edec187f4876cdf653ae4f91035f43bc877a7d40 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_log\_funcs module +=========================================== + +.. automodule:: src.environments.ipd.ipd_log_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..5f54afac07c4d477067ef4c2bf5d883b236cf5fc --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_statistics\_funcs module +================================================== + +.. 
automodule:: src.environments.ipd.ipd_statistics_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_training_data_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_training_data_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..8e4cecea10e25644ff677416823069cd65b500c5 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_training_data_funcs.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_training\_data\_funcs module +====================================================== + +.. automodule:: src.environments.ipd.ipd_training_data_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.rst b/src_code_for_reproducibility/docs/source/src.environments.rst new file mode 100644 index 0000000000000000000000000000000000000000..221ed1c07ebea145cd23bc06c6474d34b1d8a33e --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.rst @@ -0,0 +1,25 @@ +src.environments package +======================== + +.. automodule:: src.environments + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + src.environments.dond + src.environments.ipd + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.environments.env_imports + src.environments.environment_imports diff --git a/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst b/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst new file mode 100644 index 0000000000000000000000000000000000000000..68e0f5da020aee80cc8895ca650a6067317f4bcd --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst @@ -0,0 +1,7 @@ +src.experiments.arithmetic\_test module +======================================= + +.. 
automodule:: src.experiments.arithmetic_test + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.rst b/src_code_for_reproducibility/docs/source/src.rst new file mode 100644 index 0000000000000000000000000000000000000000..d2dcff7e18f979e933a27467f4893f3ad5372a88 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.rst @@ -0,0 +1,28 @@ +src package +=========== + +.. automodule:: src + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + src.environments + src.experiments + src.generation + src.models + src.training + src.utils + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.run diff --git a/src_code_for_reproducibility/docs/source/src.utils.rst b/src_code_for_reproducibility/docs/source/src.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..4f5cb352cc9ec645c968d0ae99798d47c018c750 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.utils.rst @@ -0,0 +1,24 @@ +src.utils package +================= + +.. automodule:: src.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. 
toctree:: + :maxdepth: 4 + + src.utils.common_imports + src.utils.export_ppo_training_set + src.utils.extra_stats + src.utils.inherit_args + src.utils.log_gpu_usage + src.utils.log_statistics + src.utils.model_to_cpu + src.utils.parallel_shuffle + src.utils.quick_stats + src.utils.update_start_epoch diff --git a/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c7a88e6f07a6e33667afa5f45af17ff3e1101f1 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..755b1bbc0ffb2dcd02d0227e0172e1a4be1c2411 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/run_markov_games.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be84edf7bf77be552ce53cfe5c1b32014b7d7a03 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py new file mode 100644 index 0000000000000000000000000000000000000000..9b72612c43f2535d353b0157ce72a9b79c23cbb3 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_env.py @@ -0,0 +1,230 @@ +from typing import Dict, List, Tuple, Optional, Any +from diplomacy import Game +import random + +class 
class DiplomacyEnv:
    """Multi-Agent Reinforcement Learning environment for Diplomacy.

    Wraps the ``diplomacy`` game engine behind a MARL-style
    reset/step/render/close interface where agents are keyed by power name.
    """

    def __init__(self, random_seed=None, map_name="standard", game_id=None,
                 rules=None, max_steps=50):
        """Initialize the Diplomacy environment.

        Args:
            random_seed: Optional seed applied to the ``random`` module on
                ``reset()`` so rollouts are reproducible.
            map_name: The name of the map to use (default: "standard").
            game_id: Optional game ID forwarded to the engine.
            rules: Optional list of rules to apply to the game.
            max_steps: Maximum number of steps before forcing game end
                (default: 50).
        """
        self.random_seed = random_seed
        self.map_name = map_name
        self.game_id = game_id
        self.rules = rules or []
        self.game = None
        self.active_powers = []
        self.render_mode = None
        self.max_steps = max_steps
        self.current_steps = 0

    def reset(self):
        """Reset the environment to an initial state.

        Returns:
            dict: Mapping from power name to that power's initial observation.
        """
        # Apply the seed here so every episode is reproducible.
        # (Previously ``random_seed`` was stored but never used.)
        if self.random_seed is not None:
            random.seed(self.random_seed)

        # Initialize a new game.
        self.game = Game(game_id=self.game_id, map_name=self.map_name)

        # Apply any configured engine rules.
        for rule in self.rules:
            self.game.add_rule(rule)

        # Powers still in play (not eliminated).
        self.active_powers = [name for name, power in self.game.powers.items()
                              if not power.is_eliminated()]

        self.current_steps = 0

        return {power_name: self._create_observation(power_name)
                for power_name in self.active_powers}

    def step(self, actions):
        """Advance the game using the provided per-power actions.

        Args:
            actions: Mapping from power name to an action dict with optional
                ``"orders"`` (list of order strings) and ``"wait"`` (bool)
                entries.

        Returns:
            observations (dict): Power name -> observation for active powers.
            done (bool): Whether the episode has ended.
            info (dict): Additional information about the environment.
        """
        self.current_steps += 1

        # Submit orders and wait flags for every still-active power.
        for power_name, action in actions.items():
            if power_name in self.active_powers:
                orders = action.get("orders", [])
                wait = action.get("wait", True)

                if orders:
                    self.game.set_orders(power_name, orders)

                self.game.set_wait(power_name, wait)

        # The engine processes the phase only once no active power waits.
        if self.game.does_not_wait():
            self.game.process()

        # Refresh the active-power list after (possible) processing.
        self.active_powers = [name for name, power in self.game.powers.items()
                              if not power.is_eliminated()]

        observations = {power_name: self._create_observation(power_name)
                        for power_name in self.active_powers}

        # Episode ends when the engine says so or when the step cap is hit.
        done = self.game.is_game_done or self.current_steps >= self.max_steps

        info = {
            "phase": self.game.get_current_phase(),
            "active_powers": self.active_powers,
            "centers": self.game.get_centers(),
            "units": self.game.get_units(),
            "current_steps": self.current_steps,
            "max_steps_reached": self.current_steps >= self.max_steps,
        }

        return observations, done, info

    def _create_observation(self, power_name):
        """Create the observation dict for a specific power.

        Args:
            power_name: The name of the power.

        Returns:
            dict: Phase, board state, and this power's orderable/possible orders.
        """
        return {
            "phase": self.game.get_current_phase(),
            "units": self.game.get_units(),
            "centers": self.game.get_centers(),
            "orderable_locations": self.game.get_orderable_locations(power_name),
            "order_status": self.game.get_order_status(power_name),
            "possible_orders": self._get_possible_orders_for_power(power_name),
        }

    def _get_possible_orders_for_power(self, power_name):
        """Get all possible orders for a power's units.

        Args:
            power_name: The name of the power.

        Returns:
            dict: Mapping from location to the possible orders there.
        """
        all_possible_orders = self.game.get_all_possible_orders()

        # ``unit[2:]`` drops the first two characters of each unit string —
        # presumably the "A "/"F " type prefix, leaving the location.
        # TODO(review): confirm unit-string format against the engine.
        power_units = self.game.get_units(power_name)
        power_unit_locations = [unit[2:] for unit in power_units]

        # Retreat phase: dislodged units also need orders.
        if self.game.phase_type == 'R':
            power = self.game.get_power(power_name)
            power_unit_locations.extend([unit[2:] for unit in power.retreats])

        # Adjustment phase: builds or removals.
        elif self.game.phase_type == 'A':
            power = self.game.get_power(power_name)
            if len(power.centers) > len(power.units):
                # More centers than units: builds possible.
                # NOTE(review): ``_build_sites`` is a private engine API.
                buildable_sites = self.game._build_sites(power)
                power_unit_locations.extend(buildable_sites)
            elif len(power.units) > len(power.centers):
                # More units than centers: all units are removal candidates,
                # already covered by power_unit_locations.
                pass

        # Keep only orders at this power's locations.  NOTE(review): the
        # 3-char prefix comparison may not match coastal variants such as
        # "STP/NC" — confirm against the engine's location naming.
        power_possible_orders = {}
        for loc, orders in all_possible_orders.items():
            if loc[:3] in power_unit_locations:
                power_possible_orders[loc] = orders

        return power_possible_orders

    def get_log_info(self):
        """Get environment information required to log the game.

        Returns:
            dict: Game/board/power summary, or {} before the first reset().
        """
        if not self.game:
            return {}

        return {
            "game_id": self.game.game_id,
            "phase": self.game.get_current_phase(),
            "map_name": self.game.map_name,
            "centers": self.game.get_centers(),
            "units": self.game.get_units(),
            "powers": {name: {
                "units": power.units,
                "centers": power.centers,
                "is_eliminated": power.is_eliminated(),
                "order_status": self.game.get_order_status(name),
            } for name, power in self.game.powers.items()},
            "orders": self.game.get_orders(),
            "active_powers": self.active_powers,
            "is_game_done": self.game.is_game_done,
            "outcome": self.game.outcome if self.game.is_game_done else None,
        }

    def render(self, mode='human'):
        """Render the current state of the environment.

        Args:
            mode: The rendering mode ('human' prints to stdout, 'svg'
                returns the engine's SVG rendering).

        Returns:
            The SVG string for mode='svg', otherwise None.
        """
        self.render_mode = mode
        if self.game:
            if mode == 'human':
                # Print a plain-text summary of the board.
                print(f"Game: {self.game.game_id}")
                print(f"Phase: {self.game.get_current_phase()}")
                print(f"Active Powers: {self.active_powers}")
                print("Supply Centers:")
                for power_name, centers in self.game.get_centers().items():
                    print(f"  {power_name}: {centers}")
                print("Units:")
                for power_name, units in self.game.get_units().items():
                    print(f"  {power_name}: {units}")
                return None
            elif mode == 'svg':
                return self.game.render(output_format='svg')
        return None

    def close(self):
        """Release the underlying game object."""
        self.game = None
import os
import json

from utils.common_imports import *


def _next_file_index(directory, prefix):
    """Return the next free 1-based index for files named ``<prefix><n>.<ext>``."""
    numbers = [
        int(name.split("_")[-1].split(".")[0])
        for name in os.listdir(directory)
        if name.startswith(prefix)
    ]
    return max(numbers, default=0) + 1


def diplomacy_log_match(
        path,
        agents_log_info,
        env_log_info,
        metrics_func=None,
        metrics_func_args=None
    ):
    """
    Logs the Diplomacy game data and generates HTML visualizations using
    the get_log_info methods.

    Args:
        path (str): Base path to save the data.
        agents_log_info (list): List of agent information dictionaries
            containing the get_log_info results (each must have "power_name").
        env_log_info (dict): Environment information from its get_log_info method.
        metrics_func (str, optional): Name of the globally visible function
            used to calculate metrics from ``(agent_log, env_log_info)``.
        metrics_func_args (dict, optional): Extra keyword arguments for the
            metrics function.
    """
    os.makedirs(path, exist_ok=True)

    # Save the environment log info.
    env_log_path = os.path.join(path, "env_log.json")
    with open(env_log_path, "w") as f:
        json.dump(env_log_info, f, indent=4, default=_json_serialize)

    # Process each agent's log info.
    for agent_log in agents_log_info:
        power_name = agent_log["power_name"]

        power_path = os.path.join(path, power_name)
        raw_data_path = os.path.join(power_path, "raw_data")
        statistics_path = os.path.join(power_path, "statistics")
        os.makedirs(raw_data_path, exist_ok=True)
        os.makedirs(statistics_path, exist_ok=True)

        # Save the raw agent log under the next free numbered name.
        raw_file = os.path.join(
            raw_data_path, f"log_{_next_file_index(raw_data_path, 'log_')}.json"
        )
        with open(raw_file, "w") as f:
            json.dump(agent_log, f, indent=4, default=_json_serialize)

        # Log metrics if a metrics function is provided.
        if metrics_func:
            metrics_file = os.path.join(
                statistics_path,
                f"metrics_{_next_file_index(statistics_path, 'metrics_')}.json",
            )
            # BUGFIX: previously passed the undefined name ``info``; the
            # environment log is what the metrics function receives.  Also
            # guard against ``metrics_func_args`` being None.
            metrics = globals()[metrics_func](
                agent_log, env_log_info, **(metrics_func_args or {})
            )
            with open(metrics_file, "w") as f:
                json.dump(metrics, f, indent=4, default=_json_serialize)

    # Generate and save the HTML visualization.
    html_content = generate_diplomacy_html(agents_log_info, env_log_info)
    html_path = os.path.join(path, "html")
    os.makedirs(html_path, exist_ok=True)
    html_file = os.path.join(
        html_path,
        f"game_summary_{_next_file_index(html_path, 'game_summary_')}.html",
    )
    with open(html_file, "w") as f:
        f.write(html_content)


def _escape_html(text):
    """Escape the characters that are significant in HTML text content.

    ``&`` must be replaced first so already-escaped entities are not mangled.
    """
    return (
        str(text)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )


def generate_diplomacy_html(agent_infos, env_info):
    """
    Generate HTML visualization for a Diplomacy game.

    Args:
        agent_infos (list): List of agent information dictionaries from get_log_info.
        env_info (dict): Environment information from get_log_info.

    Returns:
        str: HTML content for the game visualization.
    """
    game_id = env_info.get("game_id", "Unknown")
    phase = env_info.get("phase", "Unknown")
    map_name = env_info.get("map_name", "standard")
    is_game_done = env_info.get("is_game_done", False)
    outcome = env_info.get("outcome", [])
    centers = env_info.get("centers", {})
    units = env_info.get("units", {})

    # Build the document as a list of fragments, joined once at the end.
    parts = [
        "<!DOCTYPE html>",
        "<html><head><meta charset='utf-8'>",
        f"<title>Diplomacy Game {_escape_html(game_id)}</title></head><body>",
        "<h1>Game Information</h1>",
        "<h2>Game Details</h2>",
        f"<p><b>Game ID:</b> {_escape_html(game_id)}</p>",
        f"<p><b>Phase:</b> {_escape_html(phase)}</p>",
        f"<p><b>Map:</b> {_escape_html(map_name)}</p>",
        f"<p><b>Status:</b> {'Completed' if is_game_done else 'Active'}</p>",
        "<h2>Supply Centers</h2>",
    ]
    for power, power_centers in centers.items():
        parts.append(f"<p>{_escape_html(power)}: {len(power_centers)}</p>")

    # Add outcome if the game is done.
    if is_game_done and outcome:
        # Assumes the first outcome entry is the final phase and the rest are
        # the winning powers.  TODO(review): confirm against the engine.
        winners = outcome[1:] if len(outcome) > 1 else ["Draw"]
        parts.append("<h2>Game Outcome</h2>")
        parts.append(f"<p><b>Winners:</b> {_escape_html(', '.join(winners))}</p>")

    # One section per power: units, final orders, message history.
    for agent_log in agent_infos:
        power_name = agent_log["power_name"]
        parts.append(f"<h1>{_escape_html(power_name)}</h1>")

        parts.append("<h2>Units</h2><ul>")
        for unit in units.get(power_name, []):
            parts.append(f"<li>{_escape_html(unit)}</li>")
        parts.append("</ul>")

        parts.append("<h2>Final Orders</h2><ul>")
        for order in agent_log.get("orders", []):
            parts.append(f"<li>{_escape_html(order)}</li>")
        parts.append("</ul>")

        parts.append("<h2>Messages</h2>")
        for message in agent_log.get("message_history", []):
            if isinstance(message, dict):
                # System prompts are not part of the visible dialogue.
                if message.get("role") == "system":
                    continue

                role = message.get("role", "unknown")
                role_display = (
                    "Environment" if role == "user" else f"LLM ({power_name})"
                )
                # Escape HTML characters, then turn newlines into breaks.
                content = _escape_html(message.get("content", "")).replace(
                    "\n", "<br>"
                )
                parts.append(
                    f"<p><b>{_escape_html(role_display)}:</b> {content}</p>"
                )
            elif isinstance(message, str):
                # Simple string messages (used in some implementations).
                parts.append(f"<p>{_escape_html(message)}</p>")

    parts.append("</body></html>")
    return "\n".join(parts)


def _json_serialize(obj):
    """
    json.dump fallback for objects json cannot serialize natively.

    The engine's ``OrderResult`` instances are the known offenders; anything
    else unexpected is likewise converted to ``str`` rather than raising.
    (The original special-cased OrderResult but both branches returned
    ``str(obj)``, so the branch was redundant.)
    """
    return str(obj)
+ """ + + action = self.cooperate_string + + # Log a minimal, structured chat turn for consistency with other agents + turn_text = f"Playing cooperate: {action}" + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="assistant", + content=turn_text, + is_state_end=True, + ) + ) + + act_log = AgentActLog( + chat_turns=[self.state.chat_history[-1]], + info=None, + ) + + # Advance internal counters similar to IPDAgent semantics + self.state.chat_counter = len(self.state.chat_history) + self.state.round_nb = observation.round_nb + + return action, act_log + + +@dataclass +class AlwaysDefectIPDAgent(IPDAgent): + async def act(self, observation) -> Tuple[Any, AgentActLog]: + """ + Always plays the defect action, ignoring observation. + Returns the configured defect_string so the simulation parses it as "D". + """ + + action = self.defect_string + + # Log a minimal, structured chat turn for consistency with other agents + turn_text = f"Playing defect: {action}" + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="assistant", + content=turn_text, + is_state_end=True, + ) + ) + + act_log = AgentActLog( + chat_turns=[self.state.chat_history[-1]], + info=None, + ) + + # Advance internal counters similar to IPDAgent semantics + self.state.chat_counter = len(self.state.chat_history) + self.state.round_nb = observation.round_nb + + return action, act_log + diff --git a/src_code_for_reproducibility/markov_games/ipd/__init__.py b/src_code_for_reproducibility/markov_games/ipd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f2388f6380fd3a54a2c80d1f1f77ae1d1fd4c8 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/ipd/__init__.py @@ -0,0 +1,7 @@ +from .Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent + +__all__ = [ + "AlwaysCooperateIPDAgent", + "AlwaysDefectIPDAgent", +] + diff --git 
a/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd4631ea1c691af05899c36322cfb2ffb4bba9a Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b76920f80b27aa0c326141cd8fdf1166fbb9258b Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342211af722d2aef07232358e0adfcc674718e16 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a41b797fd13e8f06b9b4d308b7ab53832604ffc4 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py b/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6f64b542dcc9ee7e114e617bde9cc1181ea301 --- /dev/null +++ 
import copy
import json
import random
import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

from mllm.markov_games.agent import Agent
from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn


@dataclass
class IPDAgentState:
    """Mutable per-episode state of an IPD agent."""

    # Number of failed policy attempts so far.
    nb_retries: int
    # Index of the last round this agent acted in.
    round_nb: int
    # Number of chat turns already reported by a previous act() log.
    chat_counter: int
    # Full conversation so far (user prompts and policy replies).
    chat_history: List[ChatTurn]


@dataclass
class IPDAgent(Agent):
    """LLM-policy-backed agent for the Iterated Prisoner's Dilemma.

    On each act() call it extends the chat history with the relevant prompt,
    queries the policy for a regex-constrained move, and returns the raw
    action string plus a log of the newly added chat turns.
    """

    seed: int
    agent_id: str
    agent_name: str
    policy: Callable[[List[Dict]], str]
    intro_prompt: str  # Introduction prompt explaining the game rules
    goal_prompt: str  # Prompt explaining the agent's goal
    strategy_prompt: str  # Prompt suggesting a strategy to the agent
    max_errors: int  # Maximum number of errors allowed before default action
    allow_reasoning: bool  # Whether to allow reasoning in the response
    max_reasoning_chars: int  # Maximum number of characters for reasoning
    cooperate_string: str  # string parsed as playing cooperate by simulation
    defect_string: str  # string parsed as playing defect by simulation

    def __post_init__(self):
        # Fresh mutable state at construction time.
        self.state = IPDAgentState(
            nb_retries=0, round_nb=0, chat_counter=0, chat_history=[]
        )

    async def act(self, observation) -> Tuple[Any, AgentActLog]:
        """Choose a move for the current round via the LLM policy.

        Args:
            observation: Observation carrying ``round_nb`` and, after round 0,
                ``last_coagent_move``.

        Returns:
            The raw action string produced by the policy, and an
            ``AgentActLog`` holding only the chat turns added by this call.
        """
        round_nb = observation.round_nb

        # First-ever call: open the conversation with the game rules.
        if round_nb == 0 and self.state.chat_counter == 0:
            self.state.chat_history.append(
                ChatTurn(
                    agent_id=self.agent_id,
                    role="user",
                    content=self.intro_prompt,
                    is_state_end=True,
                )
            )

        # A new round has started: report the co-agent's previous move.
        if round_nb > self.state.round_nb:
            coagent_action = observation.last_coagent_move
            user_message = f"Last round, the other agent played {coagent_action}."
            self.state.chat_history.append(
                ChatTurn(
                    agent_id=self.agent_id,
                    role="user",
                    content=user_message,
                    is_state_end=True,
                )
            )

        # Query the policy; the regex constrains the reply to a legal move.
        output_chat_turn: ChatTurn = await self.policy(
            state=self.state.chat_history,
            agent_id=self.agent_id,
            regex=f"({self.cooperate_string}|{self.defect_string})",
        )
        self.state.chat_history.append(output_chat_turn)
        action = output_chat_turn.content

        # Log only the turns added since the previous act() call.
        agent_step_log = AgentActLog(
            chat_turns=self.state.chat_history[self.state.chat_counter :], info=None
        )
        self.state.chat_counter = len(self.state.chat_history)
        self.state.round_nb = round_nb

        return action, agent_step_log

    def get_safe_copy(self):
        """
        Return a safe copy of the agent: shallow copy with its own
        deep-copied state, so mutating the copy leaves the original intact.
        """
        agent_copy = copy.copy(self)
        agent_copy.state = copy.deepcopy(self.state)
        return agent_copy

    def reset(self):
        """Reset the agent's mutable state to a fresh ``IPDAgentState``.

        BUGFIX: previously called ``IPDAgentState()`` without its required
        fields (raising TypeError before ever reaching the subsequent
        ``raise NotImplementedError``); now mirrors ``__post_init__``.
        """
        self.state = IPDAgentState(
            nb_retries=0, round_nb=0, chat_counter=0, chat_history=[]
        )

    def render(self):
        """No-op: this agent has no visual representation."""
        pass

    def close(self):
        """No-op: this agent holds no external resources."""
        pass

    def get_agent_info(self):
        """No-op placeholder; returns None."""
        pass
+ """ + + round_nb: int + last_coagent_move: str | None + + +class IPD(Simulation): + """ + Iterated Prisoner's Dilemma simulation following the standard. + + In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D). + The payoffs are as follows: + - If both cooperate: Both receive the "reward" (usually 3 points) + - If both defect: Both receive the "punishment" (usually 1 point) + - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points) + and the cooperator receives the "sucker" payoff (usually 0 points) + + The game is played for a specified number of rounds. + """ + + def __init__( + self, + agent_ids: List[str], + agent_names: List[str], + seed: int, + rounds_per_game: int, + reward: float, # Both cooperate + punishment: float, # Both defect + temptation: float, # Defector's reward when other cooperates + sucker: float, # Cooperator's reward when other defects + cooperate_actions: List[str], + defect_actions: List[str], + ): + self.agent_ids = agent_ids + self.agent_names = agent_names + self.seed = seed + self.rounds_per_game = rounds_per_game + self.reward = reward + self.punishment = punishment + self.temptation = temptation + self.sucker = sucker + self.cooperate_actions = cooperate_actions + self.defect_actions = defect_actions + self.state = IPDState() + + def step(self, actions: Dict[str, str]) -> Tuple[bool, SimulationStepLog]: + """ + Take a step in the environment using the provided actions. + Here, the observations are just the states of the game. + + Args: + actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D'). + + Returns: + observations (dict): A dictionary where keys are agent identifiers and values are observations. + done (bool): Whether the episode has ended. + info (dict): Additional information about the environment. 
+ """ + + # Calculate rewards using payoff matrix + agent0_action = actions[self.agent_ids[0]] + agent1_action = actions[self.agent_ids[1]] + + # Normalize actions to standard cooperate/defect/gibberish format + def normalize_action(action): + if action in self.cooperate_actions: + return "C" + elif action in self.defect_actions: + return "D" + else: + return "D" + + norm_action0 = normalize_action(agent0_action) + norm_action1 = normalize_action(agent1_action) + + payoffs = { + ("C", "C"): [self.reward, self.reward], + ("C", "D"): [self.sucker, self.temptation], + ("D", "C"): [self.temptation, self.sucker], + ("D", "D"): [self.punishment, self.punishment], + } + + round_rewards = { + self.agent_ids[0]: payoffs[(norm_action0, norm_action1)][0], + self.agent_ids[1]: payoffs[(norm_action0, norm_action1)][1], + } + + # Update game state + self.state.round_nb += 1 + self.state.last_moves = copy.deepcopy(actions) + done = self.state.round_nb >= self.rounds_per_game + step_log = SimulationStepLog( + rewards=round_rewards, + info={ + "actions": { + self.agent_ids[0]: norm_action0, + self.agent_ids[1]: norm_action1, + } + }, + ) + + return done, step_log + + def get_obs(self): + """Returns all agent observations in dict + Returns: + observations + """ + observations = {} + for agent_id in self.agent_ids: + observations[agent_id] = self.get_obs_agent(agent_id) + return observations + + def get_obs_agent(self, agent_id): + """Returns observation for agent_id""" + if self.state.last_moves != None: + other_id = get_coagent_id(self.agent_ids, agent_id) + last_coagent_move = self.state.last_moves[other_id] + else: + last_coagent_move = None + obs = IPDObs(round_nb=self.state.round_nb, last_coagent_move=last_coagent_move) + return obs + + def reset(self): + """Returns initial observations and states""" + self.state = IPDState() + return self.get_obs() + + def get_safe_copy(self): + """ + Return a safe copy of the simulation. 
+ """ + simulation_copy = copy.copy(self) + simulation_copy.state = copy.deepcopy(self.state) + return simulation_copy diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py b/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..8740fda6bc2550c92aef27ed9fbe7bc945be42ca --- /dev/null +++ b/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Dict, Callable, List, Tuple + +from mllm.markov_games.rollout_tree import SimulationStepLog + + +def avg_reward(sl: SimulationStepLog) -> List[Tuple[str, float]]: + for aid in sl.rewards.keys(): + if "buffer" in str(aid) and "live" not in str(aid): + return None + # One value per agent at each step + rewards_dict = {f"reward-{aid}": float(v) for aid, v in (sl.rewards or {}).items()} + return [(key, value) for key, value in rewards_dict.items() if value is not None] + +stat_functs: list[Callable[[SimulationStepLog], List[Tuple[str, float]]]] = [ + avg_reward, +] \ No newline at end of file diff --git a/src_code_for_reproducibility/markov_games/negotiation/README.md b/src_code_for_reproducibility/markov_games/negotiation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c8ebadee705971c5331924ed1b9d53c7e5f69770 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/README.md @@ -0,0 +1,40 @@ +## Negotiation Games: core mechanics and variants + +This family of games feature two agents who, in each round, may briefly communicate and then simultaneously propose how to split a fixed resource (most commonly 10 coins). Rewards are the amount kept multiplied by an agent’s per-unit value. The starting speaker alternates deterministically across rounds. 
+ +Communication is optional and variant-dependent: some settings encourage rich messaging to share private information, while others remove messaging entirely to focus on allocation behavior. + +Proportional splitting is used when the two proposals exceed the available total: allocations are scaled proportionally rather than discarded. This preserves a useful learning signal even when agents over-claim. + +### Variants (in increasing difficulty) + +- No‑Press Split + - Single item type (coins) + - No communication; agents go straight to making split proposals, with the starting player alternating deterministically. + - Motivation: mirrors no‑communication setups (e.g., Advantage Alignment) while keeping the split decision nontrivial. + - Deterministic Mode: values are fixed and public: one agent values coins at 10, the other at 1 (alternates each round). + - Stochastic Mode: values are random and uncorrelated. + +- Trust-and-Split RPS (TAS-RPS) + - Single item type (coins) + - Each round, a rock–paper–scissors hand draw creates a strong asymmetry: the winner’s per-coin value is 10, the loser’s is 1. + - Each agent initially sees only their own hand and must communicate to coordinate an optimal split. + - Motivation: enforce large value disparity so one’s own value reveals little about the other’s (avoiding ceiling effects) and incentivize meaningful communication. + +- Trust-and-Split (TAS) + - Single item type (coins); each round, each agent’s per-coin value is independently sampled in a broad range (e.g., 1–20). + - Each agent observes only their own value; they may use short messages to share and negotiate. + - Motivation: a simple blend that tests whether agents learn to exchange private information and coordinate proportional, value-aware splits. + +- Deal-or-No-Deal (DOND) + - Introduced in [Deal or No Deal? 
End-to-End Learning for Negotiation Dialogues](https://arxiv.org/pdf/1706.05125) + - Multiple item types (typically "books", "hats" and "balls") with limited stocks; each agent has its own per-type values. + - A deal pays out only if both proposals exactly agree and respect the stock; otherwise no deal (zero reward) that round. + - Motivation: a known benchmark closer to real-world bargaining, where both parties must explicitly agree. + + + + + + + diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce815f643a8b1900a43c4f4aef2d526537700de7 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fe2554c39730af2a459983df421ba98da31ffe9 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63823004458193b44378e01e1385a8a71d7fb487 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc 
b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..885a652a7b17e8f1bfabd33e4d4ec0c9e435bd14 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7281edd1342dbd96b0d8db1fe0c254ad44e180e Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbc94fa4abf96d0f708fa7784e52f23f63008eac Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b13dbfa24f63cd641f0f98306b862327083e3b84 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4809c83abfe2fe1ca41508d6fcc081e9f584aab5 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d411e2237230583ef9935d9ffd9e61197c1890e Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1092f7bfe2281d3c0708e5195bcc7305179f298c Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ce7e0ea40f3cb79879f15a5f8785e43abc9518f Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbda526f1d3a02b46ba18f0b1f3eec117378953d Binary files /dev/null and 
def random_partition_integer(rng, total: int, parts: int) -> List[int]:
    """Split *total* into *parts* non-negative integers that sum to *total*.

    Uses the "stars and bars" cut-point construction: draw ``parts - 1``
    cut positions uniformly in ``[0, total]`` and return the gaps between
    consecutive cuts.  Individual entries may be zero.

    Args:
        rng: a ``numpy.random.Generator`` (must provide ``integers``).
        total: the sum to partition.
        parts: how many parts to produce.

    Returns:
        A list of ``parts`` non-negative ints summing to ``total``.
        Degenerate cases: ``[]`` when ``parts <= 0``; all zeros when
        ``total <= 0``.
    """
    if parts <= 0:
        return []
    if total <= 0:
        # Nothing to distribute: every part is zero.
        return [0] * parts
    cuts = sorted(rng.integers(0, total + 1, size=parts - 1).tolist())
    vals = []
    prev = 0
    # Append the sentinel `total` so the last gap closes the partition.
    for c in cuts + [total]:
        vals.append(c - prev)
        prev = c
    return vals
parts = random_partition_integer(self.rng, total_items, len(self.item_types)) + # allow zeros per type + return {t: int(c) for t, c in zip(self.item_types, parts)} + + def _sample_values_pair(self) -> Dict[AgentId, Dict[str, int]]: + # Each agent has integer non-negative values that sum to 10 + # Each item type valued by at least one agent + # Some item type valued by both agents + while True: + vals_a = random_partition_integer(self.rng, 10, len(self.item_types)) + vals_b = random_partition_integer(self.rng, 10, len(self.item_types)) + a = {t: int(v) for t, v in zip(self.item_types, vals_a)} + b = {t: int(v) for t, v in zip(self.item_types, vals_b)} + # each item valued by at least one + ok1 = all((a[t] > 0) or (b[t] > 0) for t in self.item_types) + # some item valued by both + ok2 = any((a[t] > 0) and (b[t] > 0) for t in self.item_types) + if ok1 and ok2: + return {self.agent_ids[0]: a, self.agent_ids[1]: b} + + def _is_valid_allocation(self, allocation: Dict[str, int], stock: Dict[str, int]) -> bool: + for t in self.item_types: + v = allocation.get(t) + if v is None: + return False + if not isinstance(v, int): + return False + if v < 0 or v > int(stock.get(t, 0)): + return False + return True + + def set_new_round_of_variant(self): + # Keep same values, resample stock + self.state.quantities = self._sample_stock() + + def get_info_of_variant(self, state: NegotiationState, actions: Dict[AgentId, Any]) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + 'splits': copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + """ + Returns the rewards for each agent. + """ + split_a = splits[self.agent_ids[0]].items_given_to_self + split_b = splits[self.agent_ids[1]].items_given_to_self + rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0} + for t in self.item_types: + # If not complementary, return 0! 
+ if not split_a[t] + split_b[t] == self.state.quantities[t]: + return {self.agent_ids[0]: 0, self.agent_ids[1]: 0} + rewards[self.agent_ids[0]] += split_a[t] * self.state.values[self.agent_ids[0]][t] + rewards[self.agent_ids[1]] += split_b[t] * self.state.values[self.agent_ids[1]][t] + return rewards + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + obs = DealNoDealObs( + round_nb=self.state.round_nb, + last_message=self.state.last_message, + current_agent=self.state.current_agent, + quantities=copy.deepcopy(self.state.quantities), + value=0.0, # unused in DOND + other_agent_split=None, # not meaningful until split + split_phase=self.state.split_phase, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + my_values=copy.deepcopy(self.state.values[agent_id]), + item_types=list(self.item_types), + previous_values_coagent=copy.deepcopy(self.state.values.get(other_id, {})), + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + stock = self._sample_stock() + values = self._sample_values_pair() + self.state = DealNoDealState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=stock, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + item_types=list(self.item_types), + ) + return self.get_obs() + + diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py b/src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5c191e15ef6b0abada72b1b6ba3a4c59421fdf --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/nego_hard_coded_policies.py @@ -0,0 +1,64 @@ +import asyncio +from typing import 
class HardCodedNegoWelfareMaximizingPolicy(NoPressAgent):
    async def act(self, observation: NoPressObs) -> Tuple[Any, AgentActLog]:
        """Welfare-maximizing baseline.

        Gives each item type entirely to the agent that values it more; when
        both agents value a type equally, the quantity is split in half.
        (The original docstring claimed ties also go to "the agent who values
        them more", which contradicted the code -- the half-split below is
        what actually happens.)
        """
        quantities = observation.quantities
        my_values = observation.value
        other_values = observation.other_value

        items_given_to_self = {}
        for item, qty in quantities.items():
            my_v = float(my_values.get(item, 0))
            other_v = float(other_values.get(item, 0))
            if my_v == other_v:
                # NOTE(review): true division -- an odd quantity yields a
                # fractional allocation (e.g. 5 -> 2.5); confirm downstream
                # reward code accepts non-integer amounts.
                items_given_to_self[item] = int(qty) / 2
            else:
                items_given_to_self[item] = int(qty if my_v > other_v else 0)

        action = Split(items_given_to_self=items_given_to_self)
        act_log = AgentActLog(
            chat_turns=[
                ChatTurn(
                    agent_id=self.agent_id,
                    role="assistant",
                    content="Using welfare-maximizing split (all to higher-value agent).",
                    is_state_end=True,
                )
            ],
            info=None,
        )
        return action, act_log


class HardCodedNegoGreedyPolicy(NoPressAgent):
    async def act(self, observation: NoPressObs) -> Tuple[Any, AgentActLog]:
        """Greedy baseline: always claims the full quantity of every item."""
        quantities = observation.quantities
        items_given_to_self = {item: int(qty) for item, qty in quantities.items()}

        action = Split(items_given_to_self=items_given_to_self)
        act_log = AgentActLog(
            chat_turns=[
                ChatTurn(
                    agent_id=self.agent_id,
                    role="assistant",
                    content="Using greedy split (keep all items).",
                    is_state_end=True,
                )
            ],
            info=None,
        )
        return action, act_log
    def get_message_regex(self, observation: NoPressObs) -> str:
        """Regex for valid chat messages: empty only (no-press = no communication)."""
        return r"^$"  # No messages allowed

    def get_split_regex(self, observation: NoPressObs) -> str:
        """Build the regex a proposal reply must match.

        Counts are 0-10 (`10|[0-9]`), each followed by an item name; item
        names are accepted in singular or plural form, case-insensitively.
        NOTE(review): the `(?P(...))` groups appear to have lost their
        `<name>` part (e.g. `(?P<num>...)`) during extraction -- as written
        these are not valid named groups; confirm against the original file.
        """
        items = list(observation.quantities.keys())
        # Accept both singular and plural forms
        item_pattern = "|".join(
            [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items]
        )
        regex = rf"(?i)Proposal:\s*((?:\s*(?P(10|[0-9]))\s*(?P{item_pattern})\s*,?)+)"
        return regex

    def get_split_action(self, policy_output: str, observation: NoPressObs) -> Split:
        """Parse the model's reply into a Split.

        Items not mentioned in the reply default to 0; the regex itself caps
        each count at 10.
        """
        items = list(observation.quantities.keys())
        import re as _re  # local alias; keeps the module's top-level imports untouched

        split_regex = self.get_split_regex(observation)
        items_given_to_self = {item: 0 for item in items}
        m = _re.match(split_regex, policy_output.strip())
        if m:
            # Find all (number, item) pairs
            item_pattern = "|".join(
                [
                    f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?"
                    for item in items
                ]
            )
            inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})"

            def normalize_item_name(item_str):
                # Map a singular/plural capture back to the canonical item key.
                # NOTE(review): falls through to an implicit None when nothing
                # matches -- unreachable while captures come from item_pattern,
                # but brittle if the pattern ever changes.
                for orig in items:
                    if item_str.lower() == orig.lower():
                        return orig
                    if orig.endswith("s") and item_str.lower() == orig[:-1].lower():
                        return orig
                    if (
                        not orig.endswith("s")
                        and item_str.lower() == orig.lower() + "s"
                    ):
                        return orig

            for num, item in _re.findall(inner_regex, m.group(1)):
                items_given_to_self[normalize_item_name(item)] = int(num)
        return Split(items_given_to_self=items_given_to_self)
    def get_message_regex(self, observation: TrustAndSplitObs) -> str:
        """Regex constraining a chat message to the configured character budget.

        NOTE(review): the surrounding prompt strings suggest messages were
        wrapped in XML-ish tags that got stripped during extraction (the
        prompt says "in ..."); confirm against the original file.
        """
        return rf"[\s\S]{{0,{self.num_message_chars}}}"

    # def get_message_regex(self, observation: TrustAndSplitObs) -> str:
    #     return rf"(?s).{{0,{self.num_message_chars}}}"

    def get_split_regex(self, observation: TrustAndSplitObs) -> str:
        """Build the regex a proposal reply must match (counts 0-10 per item).

        NOTE(review): the `(?P(...))` groups appear to have lost their
        `<name>` part during extraction -- as written these are not valid
        named groups; confirm against the original file.
        """
        items = list(observation.quantities.keys())
        # Accept both singular and plural forms
        item_pattern = "|".join(
            [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items]
        )
        regex = rf"(?i) ?((?:\s*(?P(10|[0-9]))\s*(?P{item_pattern})\s*,?)+) ?"
        return regex

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitObs
    ) -> Split:
        """Parse the model's reply into a Split; unmentioned items default to 0."""
        items = list(observation.quantities.keys())
        import re as _re  # local alias; keeps the module's top-level imports untouched

        split_regex = self.get_split_regex(observation)
        items_given_to_self = {item: 0 for item in items}
        m = _re.match(split_regex, policy_output.strip())
        if m:
            # Find all (number, item) pairs
            item_pattern = "|".join(
                [
                    f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?"
                    for item in items
                ]
            )
            inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})"

            def normalize_item_name(item_str):
                # Map a singular/plural capture back to the canonical item key.
                # NOTE(review): implicit None fall-through when nothing matches
                # -- unreachable while captures come from item_pattern.
                for orig in items:
                    if item_str.lower() == orig.lower():
                        return orig
                    if orig.endswith("s") and item_str.lower() == orig[:-1].lower():
                        return orig
                    if (
                        not orig.endswith("s")
                        and item_str.lower() == orig.lower() + "s"
                    ):
                        return orig

            for num, item in _re.findall(inner_regex, m.group(1)):
                items_given_to_self[normalize_item_name(item)] = int(num)
        return Split(items_given_to_self=items_given_to_self)
    def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str:
        """Regex constraining a chat message to the configured character budget.

        NOTE(review): the `<>` literals throughout this file look like
        start/end tags whose names were stripped during extraction
        (cf. "Respond with <> x <>"); confirm against the original file.
        """
        if self.message_start_end_format:
            return (
                rf"<>[\s\S]{{0,{self.num_message_chars}}}<>"
            )
        else:
            return rf"[\s\S]{{0,{self.num_message_chars}}}"

    def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str:
        """Regex for a coin proposal: an integer 0-10, optionally tag-wrapped."""
        if self.proposal_start_end_format:
            return r"<> ?(10|[0-9]) ?<>"
        else:
            return r" ?(10|[0-9]) ?"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitRPSObs
    ) -> Split:
        """Parse the proposal text into a Split over the single "coins" item.

        NOTE(review): the int(policy_output) fallback raises ValueError on
        non-numeric output -- presumably upstream guarantees the reply matched
        get_split_regex; confirm.
        """
        import re as _re  # local alias; keeps the module's top-level imports untouched

        if self.proposal_start_end_format:
            m = _re.search(
                r"<> ?(10|[0-9]) ?<>", policy_output
            )
        else:
            m = _re.search(
                r" ?(10|[0-9]) ?", policy_output
            )
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
    def get_message_regex(self, observation: TrustAndSplitSimpleObs) -> str:
        """Regex constraining a chat message to the configured character budget.

        NOTE(review): the `<>` literals look like start/end tags whose names
        were stripped during extraction; confirm against the original file.
        """
        if self.message_start_end_format:
            return (
                rf"<>[\s\S]{{0,{self.num_message_chars}}}<>"
            )
        else:
            return rf"[\s\S]{{0,{self.num_message_chars}}}"

    def get_split_regex(self, observation: TrustAndSplitSimpleObs) -> str:
        """Regex for a coin proposal: an integer 0-10, optionally tag-wrapped."""
        if self.proposal_start_end_format:
            return r"<> ?(10|[0-9]) ?<>"
        else:
            return r" ?(10|[0-9]) ?"

    def get_split_action(
        self, policy_output: str, observation: TrustAndSplitSimpleObs
    ) -> Split:
        """Parse the proposal text into a Split over the single "coins" item.

        NOTE(review): the int(policy_output) fallback raises ValueError on
        non-numeric output -- presumably upstream guarantees the reply matched
        get_split_regex; confirm.
        """
        import re as _re  # local alias; keeps the module's top-level imports untouched

        if self.proposal_start_end_format:
            m = _re.search(
                r"<> ?(10|[0-9]) ?<>", policy_output
            )
        else:
            m = _re.search(
                r" ?(10|[0-9]) ?", policy_output
            )
        coins_int = int(m.group(1)) if m else int(policy_output)
        return Split(items_given_to_self={"coins": coins_int})
    def _sample_values(self) -> Dict[AgentId, dict]:
        """Sample both agents' per-coin values for the round.

        "10-1-exclusive": one agent gets 10, the other 1 (mutually exclusive).
        "1-to-10": each agent draws independently -- uniform over 1..10, or a
        bimodal beta(alpha, beta) draw mapped onto 1..10.

        Rejection-samples until the two values differ: equal values would break
        the higher/lower summary computed in get_obs_agent (which raises on
        equality).
        """
        values = {}
        while True:
            if self.game_type == "10-1-exclusive":
                v = int(self.rng.choice([1, 10]))
                values[self.agent_ids[0]] = v
                values[self.agent_ids[1]] = 10 if v == 1 else 1
            elif self.game_type == "1-to-10":
                for aid in self.agent_ids:
                    if self.dist_type == "uniform":
                        # integers() upper bound is exclusive -> 1..10
                        values[aid] = int(self.rng.integers(1, 11))
                    elif self.dist_type == "bimodal":
                        alpha, beta = self.beta_dist_alpha, self.beta_dist_beta
                        values[aid] = int(round(self.rng.beta(alpha, beta) * 9) + 1)
            if len(set(values.values())) != 1:
                break
        return values

    def _sample_quantities(self) -> Dict[str, int]:
        """The stock is fixed in this variant: always 10 coins."""
        return {"coins": 10}

    def set_new_round_of_variant(self):
        """Begin a new round: fresh quantities and values, back to chat phase."""
        self.state.quantities = self._sample_quantities()
        self.state.values = self._sample_values()
        self.state.split_phase = False

    def get_info_of_variant(
        self, state: NegotiationState, actions: Dict[AgentId, Any]
    ) -> Dict[str, Any]:
        """Logging payload; deep copies so later state mutation cannot corrupt logs."""
        return {
            "quantities": copy.deepcopy(state.quantities),
            "values": copy.deepcopy(state.values),
            # "previous_values": copy.deepcopy(state.previous_values),
            "splits": copy.deepcopy(state.splits),
        }

    def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]:
        """Delegate to the shared TAS reward rule (proportional over-claim scaling)."""
        return compute_tas_style_rewards(
            self.agent_ids, self.state.values, splits, self.state.quantities
        )

    def get_obs(self):
        """Observations for all agents, keyed by agent id."""
        return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids}

    def get_obs_agent(self, agent_id):
        """Build this agent's observation, including a last-round summary.

        The co-agent's previous value is exposed only as a relative
        "higher"/"lower" string, never as the raw number.
        """
        other_id = self._other(agent_id)
        last_value_coagent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(other_id)
        )
        # NOTE(review): assumes previous_points, when set, contains both agent
        # ids -- round(None, 1) would raise otherwise; confirm invariant.
        last_points_coagent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(other_id), 1)
        )
        last_value_agent = (
            None
            if self.state.previous_values is None
            else self.state.previous_values.get(agent_id)
        )
        last_points_agent = (
            None
            if self.state.previous_points is None
            else round(self.state.previous_points.get(agent_id), 1)
        )
        last_split_coagent = None
        last_split_agent = None
        if self.state.previous_splits is not None:
            last_split_coagent = self.state.previous_splits[
                other_id
            ].items_given_to_self["coins"]
            last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[
                "coins"
            ]
        if last_value_agent is None or last_value_coagent is None:
            last_value_str_coagent = None
        else:
            if last_value_coagent > last_value_agent:
                last_value_str_coagent = "higher"
            elif last_value_coagent < last_value_agent:
                last_value_str_coagent = "lower"
            else:
                # _sample_values rejection-samples equal values, so this is
                # unreachable unless the sampler's invariant is violated.
                raise ValueError("Should not be equal values")

        obs = TrustAndSplitSimpleObs(
            round_nb=self.state.round_nb,
            last_message=self.state.last_message,
            quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round,
            current_agent=self.state.current_agent,
            other_agent=self.agent_id_to_name[other_id],
            quantities=self.state.quantities,
            item_types=self.item_types,
            value=self.state.values[agent_id],
            split_phase=self.state.split_phase,
            last_split_agent=last_split_agent,
            last_value_agent=last_value_agent,
            last_points_agent=last_points_agent,
            last_split_coagent=last_split_coagent,
            last_value_coagent=last_value_coagent,
            last_points_coagent=last_points_coagent,
            last_quantities=self.state.previous_quantities,
            last_value_str_coagent=last_value_str_coagent,
        )
        return obs

    def reset(self):
        """Reset to round 0 with fresh quantities and values; return initial obs."""
        start_agent = self.agent_ids[self._starting_agent_index]
        quantities = self._sample_quantities()
        values = self._sample_values()
        self.state = TrustAndSplitSimpleState(
            round_nb=0,
            last_message="",
            current_agent=start_agent,
            quantities=quantities,
            values=values,
            previous_values=None,
            splits={aid: None for aid in self.agent_ids},
            nb_messages_sent={aid: 0 for aid in self.agent_ids},
            split_phase=False,
            previous_splits=None,
            previous_points=None,
            previous_quantities=None,
        )
        return self.get_obs()
self.agent_ids: + values[aid][item] = int(self.rng.choice([1, 10])) + elif self.game_type == "1-to-20": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.integers(1, 21)) + agent_values = [sum(v.values()) for v in values.values()] + if self.atleast_one_conflict: + has_conflict = False + for item in item_types: + agent_values_for_item = [ + values[aid][item] for aid in self.agent_ids + ] + if ( + len(set(agent_values_for_item)) > 1 + ): # Different values for this item + has_conflict = True + break + if not has_conflict: + continue + if len(set(agent_values)) == 1 or not self.same_round_value: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {item.lower(): 10 for item in self.item_types} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = False + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + # "previous_values": copy.deepcopy(state.previous_values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + 
last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = self.state.previous_splits[ + other_id + ].items_given_to_self + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self + obs = TrustAndSplitObs( + round_nb=self.state.round_nb, + last_message=self.state.last_message, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + last_quantities=self.state.previous_quantities, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = TrustAndSplitState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..632ac87b9f56eb9bafde1439f0fb8e11d82c8e3a Binary files /dev/null and 
def set_at_path(dictio: dict, path: list[str], value):
    """Write *value* into *dictio* at the nested key sequence *path*.

    Missing intermediate dictionaries are created on demand (mkdir -p style);
    existing intermediate values are descended into as-is.
    """
    node = dictio
    for key in path[:-1]:
        # setdefault returns the existing child dict, or installs a new one.
        node = node.setdefault(key, {})
    node[path[-1]] = value
def get_metric_paths(data: list[dict]):
    """Return the key path to every leaf of ``data[0]``, depth-first.

    Only the first record is inspected; all records are assumed to share the
    same (possibly nested) schema.

    Args:
        data: list of per-iteration metric dicts.

    Returns:
        list[list[str]]: one key path per leaf value, in depth-first order.
        Returns [] for empty input (previously this raised IndexError).
    """
    if not data:
        return []
    paths: list[list[str]] = []

    # Depth-first walk: a dict value recurses, anything else is a leaf.
    # (The original used a mutable default argument for the prefix — benign
    # here because the closure is re-created per call, but a known trap.)
    def _walk(node: dict, prefix: list[str]) -> None:
        for key, value in node.items():
            if isinstance(value, dict):
                _walk(value, prefix + [key])
            else:
                paths.append(prefix + [key])

    _walk(data[0], [])
    return paths
def get_single_metric_vector(data, metric_path, iterations=None):
    """Concatenate the numeric values found at *metric_path* across records.

    Args:
        data: list of per-iteration metric dicts.
        metric_path: key path into each dict; a bare string is treated as a
            single-element path.
        iterations: number of leading records to include; None means all.
            (Previously this argument was accepted but silently ignored —
            every record was always used.)

    Returns:
        np.ndarray: 1-D float array of all numeric values found; empty array
        when no record held a numeric value at the path.
    """
    if isinstance(metric_path, str):
        metric_path = [metric_path]
    if iterations is None:
        iterations = len(data)
    vecs = []
    for record in data[:iterations]:
        raw = get_from_nested_dict(record, metric_path)
        arr = to_1d_numeric(raw)
        # Skip records where the metric is absent or non-numeric.
        if arr is not None:
            vecs.append(arr)
    return np.concatenate(vecs) if vecs else np.empty(0, dtype=float)
+ """ + with open(pkl_path, "rb") as f: + payload = pickle.load(f) + # Backward compatibility: older tallies stored the dict directly + if isinstance(payload, dict) and "array_tally" in payload: + array_tally = payload.get("array_tally", {}) + else: + array_tally = payload + + os.makedirs(outdir, exist_ok=True) + trainer_id = os.path.basename(pkl_path).replace(".rt_tally.pkl", "") + for path_list, rollout_tally_items in get_leaf_items(array_tally): + # Create file and initiate writer + path_part = ".".join(_sanitize_filename_part(p) for p in path_list) + filename = f"{trainer_id}__{path_part}.render.csv" + out_path = os.path.join(outdir, filename) + + # Write metric rows to CSV + with open(out_path, "w", newline="") as f: + writer = csv.writer(f) + + # Write header row - need to determine metric column count from first rollout_tally_item + first_item = rollout_tally_items[0] + metric_cols = ( + first_item.metric_matrix.shape[1] + if first_item.metric_matrix.ndim > 1 + else 1 + ) + header = ["agent_id", "crn_id", "rollout_id"] + [ + f"t_{i}" for i in range(metric_cols) + ] + writer.writerow(header) + + for rollout_tally_item in rollout_tally_items: + crn_ids = rollout_tally_item.crn_ids + rollout_ids = rollout_tally_item.rollout_ids + agent_ids = rollout_tally_item.agent_ids + metric_matrix = rollout_tally_item.metric_matrix + for i in range(metric_matrix.shape[0]): + row_vals = metric_matrix[i].reshape(-1) + # Convert row_vals to a list to avoid numpy concatenation issues + row_vals = ( + row_vals.tolist() + if hasattr(row_vals, "tolist") + else list(row_vals) + ) + row_prefix = [ + agent_ids[i], + crn_ids[i], + rollout_ids[i], + ] + writer.writerow(row_prefix + row_vals) + + +def tally_to_stat_pack(tally: Dict[str, Any]): + stat_pack = StatPack() + if "array_tally" in tally: + tally = tally["array_tally"] + + # backward compatibility: will remove later, flatten keys in tally + def get_from_nested_dict(dictio: dict, path: list[str]): + for sp in path[:-1]: + dictio 
def get_stochastic_game_lengths(
    max_length,
    nb_games,
    continuation_prob,
    same_length_batch=False,
    rng=None,
):
    """Generate stochastic game lengths from a truncated geometric distribution.

    Args:
        max_length (int): maximum length a game can have; longer draws are
            clipped to this value.
        nb_games (int): number of games to generate lengths for.
        continuation_prob (float): probability of the game continuing after
            each round; 1 means every game runs for exactly ``max_length``.
        same_length_batch (bool): if True, a single length is drawn and shared
            by the whole batch.
        rng (numpy.random.Generator | None): optional generator for
            reproducible draws; None falls back to the global ``np.random``
            state (the previous, only, behavior).

    Returns:
        list[int]: one length per game, each in [1, max_length].
        (The previous docstring said "Array"; the function always returned a
        list.)
    """
    if continuation_prob == 1:
        return [max_length] * nb_games
    geometric = np.random.geometric if rng is None else rng.geometric
    if same_length_batch:
        shared_length = geometric(1 - continuation_prob, 1)
        game_lengths = np.repeat(shared_length, nb_games)
    else:
        game_lengths = geometric(1 - continuation_prob, nb_games)

    # Truncate: the geometric tail is unbounded, the game length is not.
    return np.minimum(game_lengths, max_length).tolist()
+ """ + with open(output_path, "w") as f: + f.write(model.source_code) diff --git a/src_code_for_reproducibility/utils/resource_context.py b/src_code_for_reproducibility/utils/resource_context.py new file mode 100644 index 0000000000000000000000000000000000000000..43a3a55d0ca0d4a69eadd0c57650a5afd2ae4831 --- /dev/null +++ b/src_code_for_reproducibility/utils/resource_context.py @@ -0,0 +1,78 @@ +import logging +import time +from contextlib import contextmanager + +import torch + + +def vram_usage(): + output = "" + for i in range(torch.cuda.device_count()): + gpu_memory_allocated = torch.cuda.memory_allocated(i) / ( + 1024**3 + ) # Convert bytes to GB + gpu_memory_reserved = torch.cuda.memory_reserved(i) / ( + 1024**3 + ) # Convert bytes to GB + output += f"GPU {i}: Memory Allocated: {gpu_memory_allocated:.2f} GB, Memory Reserved: {gpu_memory_reserved:.2f} GB" + return output + + +def ram_usage(): + import psutil + + process = psutil.Process() + memory_info = process.memory_info() + ram_used = memory_info.rss / (1024**3) # Convert bytes to GB + return f"RAM Usage: {ram_used:.2f} GB" + + +@contextmanager +def resource_logger_context(logger: logging.Logger, task_description: str): + """ + Context manager to log the resource usage of the current task. + Args: + logger: The logger to use to log the resource usage. + task_description: The description of the task to log. 
+ Returns: + None + """ + try: + initial_time = time.time() + # Assume CUDA is available and use device 0 only + total_mem_bytes = torch.cuda.get_device_properties(0).total_memory + initial_total_bytes = ( + torch.cuda.memory_allocated(0) + torch.cuda.memory_reserved(0) + ) + torch.cuda.reset_peak_memory_stats(0) + yield None + finally: + final_time = time.time() + # Ensure kernels within the block are accounted for + torch.cuda.synchronize() + + # Compute metrics + final_allocated_bytes = torch.cuda.memory_allocated(0) + final_reserved_bytes = torch.cuda.memory_reserved(0) + final_total_bytes = final_allocated_bytes + final_reserved_bytes + + delta_vram_percent_total = ( + 100 * (final_total_bytes - initial_total_bytes) / total_mem_bytes + if total_mem_bytes + else 0.0 + ) + current_percent_vram_taken = ( + 100 * final_total_bytes / total_mem_bytes if total_mem_bytes else 0.0 + ) + block_peak_percent = ( + 100 * torch.cuda.max_memory_allocated(0) / total_mem_bytes + if total_mem_bytes + else 0.0 + ) + delta_time_str = time.strftime( + '%H:%M:%S', time.gmtime(final_time - initial_time) + ) + + logger.info( + f"For task: {task_description}, ΔVRAM % (total): {delta_vram_percent_total:.2f}%, Current % of VRAM taken: {current_percent_vram_taken:.2f}%, Block Peak % of device VRAM: {block_peak_percent:.2f}%, ΔTime: {delta_time_str}" + ) diff --git a/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py b/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py new file mode 100644 index 0000000000000000000000000000000000000000..821d8007c0bea0fd00e4acb04e9165f1eb6b9b3f --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py @@ -0,0 +1,1587 @@ +from pathlib import Path +from typing import List + +from mllm.utils.rollout_tree_gather_utils import * + + +def html_from_chat_turns(chat_turns: List[ChatTurnLog]) -> str: + """ + Render chat turns as a single, wrapping sequence of messages in time order. 
+ Keep badge and message bubble styles, include time on every badge and + include rewards on assistant badges. Each message is individually + hide/show by click; when hidden, only the badge remains and "(...)" is + shown inline (not inside a bubble). + """ + import html + import re as _re + + # Prepare ordering: sort by (time_step, original_index) to keep stable order within same step + indexed_turns = list(enumerate(chat_turns)) + indexed_turns.sort(key=lambda t: (t[1].time_step, t[0])) + + # Get unique agent IDs and sort alphabetically for consistent assignment + # Agent with alphabetically lower name gets agent-0 (left, green) + # Agent with alphabetically higher name gets agent-1 (right, orange) + unique_agent_ids = sorted(set(turn.agent_id for turn in chat_turns if turn.role == "assistant")) + agent_id_to_index = {aid: idx for idx, aid in enumerate(unique_agent_ids)} + + # CSS styles (simplified layout; no time-step or agent-column backgrounds) + css = """ + + """ + + # HTML structure + html_parts = [ + "", + "", + "", + "", + "Chat Turns", + css, + "", + "", + "", + '
', + '
', + '
', + '', + '', + '', + '', + '', + '900px', + '', + '', + '', + '', + '', + '', + '', + '', + 'px', + '', + '', + '', + f'', + f'', + '|', + f'', + f'', + '', + '', + "
", + "
", + ] + + # Add Chat View + import html as _html_mod + html_parts.append('
') + + # Helper function to add context annotation areas + def add_context_area(position: str, time_step: int): + context_key = f"round-context-{position}-{time_step}" + placeholder = f"Add context {position} round {time_step}..." + color_buttons = "" + # Add default/reset color button first + color_buttons += ( + f'
' + ) + for color_name, color_value in [ + ('red', '#d32f2f'), + ('orange', '#f57c00'), + ('yellow', '#f9a825'), + ('green', '#388e3c'), + ('blue', '#1976d2'), + ('purple', '#7b1fa2'), + ('gray', '#666666'), + ]: + color_buttons += ( + f'
' + ) + + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f'
' + ) + + # Helper function to add split agent context boxes + def add_split_agent_contexts(position: str, time_step: int): + color_buttons = "" + # Add default/reset color button first + color_buttons += ( + f'
' + ) + for color_name, color_value in [ + ('red', '#d32f2f'), + ('orange', '#f57c00'), + ('yellow', '#f9a825'), + ('green', '#388e3c'), + ('blue', '#1976d2'), + ('purple', '#7b1fa2'), + ('gray', '#666666'), + ]: + color_buttons += ( + f'
' + ) + + html_parts.append('
') + + # Agent 0 box + agent0_key = f"agent-context-0-{position}-{time_step}" + agent0_placeholder = f"..." + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f'
' + ) + + # Agent 1 box + agent1_key = f"agent-context-1-{position}-{time_step}" + agent1_placeholder = f"..." + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f'
' + ) + + html_parts.append('
') # split-agent-context + + last_time_step_chat = None + for original_index, turn in indexed_turns: + # Use agent index for CSS class (agent-0 or agent-1) instead of agent ID + agent_index = agent_id_to_index.get(turn.agent_id, 0) + agent_class = f"agent-{agent_index}" + role_class = f"role-{turn.role}" + + # Add time step divider and beginning context + if last_time_step_chat is None or turn.time_step != last_time_step_chat: + # Add end contexts for previous round (only regular context, not prompt summary) + if last_time_step_chat is not None: + add_context_area("end", last_time_step_chat) + + html_parts.append( + f'
' + f'⏱ Round {turn.time_step + 1}' + f'
' + ) + + # Add beginning contexts for new round (both context and prompt summary) + add_context_area("beginning", turn.time_step) + add_split_agent_contexts("beginning", turn.time_step) + + last_time_step_chat = turn.time_step + + # Build chat message with merge controls + html_parts.append(f'
') + + # Add merge control button + html_parts.append( + f'' + ) + + html_parts.append('
') + + # Header with agent name and reward (always show reward) + if turn.role == "assistant": + name = _html_mod.escape(turn.agent_id) + raw_val = turn.reward + if isinstance(raw_val, (int, float)): + reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".") + if len(reward_val) > 8: + reward_val = reward_val[:8] + "…" + else: + reward_val = str(raw_val) + header_html = ( + f'
' + f'🤖 {name}' + f'⚑ {reward_val}' + f'
' + ) + else: + name = _html_mod.escape(turn.agent_id) + header_html = f'
Prompt of {name}
' + + html_parts.append(header_html) + + # Reasoning content if present + if turn.reasoning_content: + _raw_reasoning = turn.reasoning_content.replace("\r\n", "\n") + _raw_reasoning = _re.sub(r"^\s*\n+", "", _raw_reasoning) + esc_reasoning = _html_mod.escape(_raw_reasoning) + html_parts.append( + f'' + ) + + # Message bubble + esc_content = _html_mod.escape(turn.content) + html_parts.append(f'
{esc_content}
') + + html_parts.append('
') # chat-message-content + html_parts.append('
') # chat-message + + # Add end contexts for the last round (only regular context, not prompt summary) + if last_time_step_chat is not None: + add_context_area("end", last_time_step_chat) + + html_parts.append("
") # flow-chat + html_parts.extend(["", ""]) + + return "\n".join(html_parts) + + +def export_html_from_rollout_tree(path: Path, outdir: Path, main_only: bool = False): + """Process a rollout tree file and generate HTML files for each path. + Creates separate HTML files for the main path and each branch path. + The main path is saved in the root output directory, while branch paths + are saved in a 'branches' subdirectory. + + Args: + path: Path to the rollout tree JSON file + outdir: Output directory for HTML files + main_only: If True, only export the main trajectory (default: False) + """ + root = load_rollout_tree(path) + mgid = root.id + + main_path, branch_paths = get_rollout_tree_paths(root) + + outdir.mkdir(parents=True, exist_ok=True) + + # Create branches subdirectory if we have branch paths + if not main_only and branch_paths: + branches_dir = outdir / f"mgid:{mgid}_branches_html_renders" + branches_dir.mkdir(parents=True, exist_ok=True) + + # Generate HTML for the main path + chat_turns = gather_all_chat_turns_for_path(main_path) + html_content = html_from_chat_turns(chat_turns) + output_file = outdir / f"mgid:{mgid}_main_html_render.render.html" + with open(output_file, "w", encoding="utf-8") as f: + f.write(html_content) + + # Generate HTML for each branch path + for path_obj in branch_paths: + chat_turns = gather_all_chat_turns_for_path(path_obj) + + html_content = html_from_chat_turns(chat_turns) + + path_id: str = path_obj.id + output_filename = f"{path_id}_html_render.render.html" + + output_file = branches_dir / output_filename + + with open(output_file, "w", encoding="utf-8") as f: + f.write(html_content) diff --git a/src_code_for_reproducibility/utils/rollout_tree_gather_utils.py b/src_code_for_reproducibility/utils/rollout_tree_gather_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1844aa5a20c01b591865cad108f1fb1577d3ef4 --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_gather_utils.py @@ -0,0 
def load_rollout_tree(path: Path) -> RolloutTreeRootNode:
    """Deserialize a pickled rollout-tree dict and validate it into a root node.

    NOTE(review): ``pickle.load`` executes arbitrary code on load — only use
    on trusted, locally produced rollout files.
    """
    with open(path, "rb") as handle:
        raw = pickle.load(handle)
    return RolloutTreeRootNode.model_validate(raw)
+ """ + branch_paths = [] + + def collect_path_nodes(current) -> List[RolloutTreeNode]: + """Recursively collect all nodes in a path starting from current node.""" + if current is None: + return [] + + if isinstance(current, RolloutTreeNode): + return [current] + collect_path_nodes(current.child) + + elif isinstance(current, RolloutTreeBranchNode): + # For branch nodes, we only follow the main_child for path collection + if current.main_child: + return [current.main_child] + collect_path_nodes( + current.main_child.child + ) + else: + return [] + + def traverse_for_branches( + current, + main_path_prefix: List[RolloutTreeNode], + path_id: str, + current_time_step: Optional[int] = 0, + ): + """Traverse tree to collect all branch paths.""" + if current is None: + return + + if isinstance(current, RolloutTreeNode): + # Continue traversing with this node added to the main path prefix + new_prefix = main_path_prefix + [current] + traverse_for_branches(current.child, new_prefix, path_id, current.time_step) + + elif isinstance(current, RolloutTreeBranchNode): + # Collect all branch paths + if current.branches: + for agent_id, branch_node_list in current.branches.items(): + if branch_node_list: + # Start with the main path prefix, then recursively collect all nodes in this branch + branch_path_nodes = main_path_prefix.copy() + for branch_node in branch_node_list: + branch_path_nodes.extend(collect_path_nodes(branch_node)) + + # Create proper branch path ID with mgid, agent_id, and time_step + mgid_str = mgid or str(root.id) + branch_path_id = f"mgid:{mgid_str}_type:branch_agent:{agent_id}_time_step:{current_time_step}" + branch_paths.append( + RolloutNodeList(id=branch_path_id, nodes=branch_path_nodes) + ) + + # Process the main child and add to prefix + new_prefix = main_path_prefix + if current.main_child: + new_prefix = main_path_prefix + [current.main_child] + + # Continue traversing the main path + if current.main_child: + traverse_for_branches( + 
current.main_child.child, + new_prefix, + path_id, + current.main_child.time_step, + ) + + # Collect the main path nodes + main_path_nodes = collect_path_nodes(root.child) + + # Traverse to collect all branch paths + traverse_for_branches(root.child, [], "") + + # Create the main path with proper mgid format + mgid_str = mgid or str(root.id) + main_path = RolloutNodeList(id=f"mgid:{mgid_str}_type:main", nodes=main_path_nodes) + + return main_path, branch_paths + + +class ChatTurnLog(BaseModel): + time_step: int + agent_id: str + role: str + content: str + reasoning_content: Optional[str] = None + is_state_end: bool + reward: float + + +def gather_agent_chat_turns_for_path( + agent_id: str, path: RolloutNodeList +) -> List[ChatTurnLog]: + """Iterate through all chat turns for a specific agent in a path sorted by time step.""" + turns = [] + for node in path.nodes: + action_log = node.step_log.action_logs.get(agent_id, []) + if action_log: + for chat_turn in action_log.chat_turns or []: + turns.append( + ChatTurnLog( + time_step=node.time_step, + agent_id=agent_id, + role=chat_turn.role, + content=chat_turn.content, + reasoning_content=getattr(chat_turn, "reasoning_content", None), + is_state_end=chat_turn.is_state_end, + reward=node.step_log.simulation_step_log.rewards.get( + agent_id, 0 + ), + ) + ) + return turns + + +def gather_all_chat_turns_for_path(path: RolloutNodeList) -> List[ChatTurnLog]: + """Iterate through all chat turns for all agents in a path sorted by time step.""" + turns = [] + + # Collect turns from all agents, but interleave them per timestep by (user, assistant) pairs + for node in path.nodes: + # Build (user[, assistant]) pairs for each agent at this timestep + agent_ids = sorted(list(node.step_log.action_logs.keys())) + per_agent_pairs: Dict[str, List[List[ChatTurnLog]]] = {} + + for agent_id in agent_ids: + action_log = node.step_log.action_logs.get(agent_id) + pairs: List[List[ChatTurnLog]] = [] + current_pair: List[ChatTurnLog] = [] + + if 
action_log and action_log.chat_turns: + for chat_turn in action_log.chat_turns: + turn_log = ChatTurnLog( + time_step=node.time_step, + agent_id=agent_id, + role=chat_turn.role, + content=chat_turn.content, + reasoning_content=getattr(chat_turn, "reasoning_content", None), + is_state_end=chat_turn.is_state_end, + reward=node.step_log.simulation_step_log.rewards.get( + agent_id, 0 + ), + ) + + if chat_turn.role == "user": + # If a previous pair is open, close it and start a new one + if current_pair: + pairs.append(current_pair) + current_pair = [] + current_pair = [turn_log] + else: + # assistant: attach to an open user message if present; otherwise stand alone + if ( + current_pair + and len(current_pair) == 1 + and current_pair[0].role == "user" + ): + current_pair.append(turn_log) + pairs.append(current_pair) + current_pair = [] + else: + # No preceding user or already paired; treat as its own unit + pairs.append([turn_log]) + + if current_pair: + # Unpaired trailing user message + pairs.append(current_pair) + + per_agent_pairs[agent_id] = pairs + + # Interleave pairs across agents: A1, B1, A2, B2, ... 
+ index = 0 + while True: + added_any = False + for agent_id in agent_ids: + agent_pairs = per_agent_pairs.get(agent_id, []) + if index < len(agent_pairs): + for tl in agent_pairs[index]: + turns.append(tl) + added_any = True + if not added_any: + break + index += 1 + + return turns + + +def chat_turns_to_dict(chat_turns: Iterator[ChatTurnLog]) -> Iterator[Dict[str, Any]]: + """Render all chat turns for a path as structured data for JSON.""" + for chat_turn in chat_turns: + yield chat_turn.model_dump() + + +def get_all_agents(root: RolloutTreeRootNode) -> List[str]: + """list of all agent IDs that appear in the tree.""" + if root.child is None: + return [] + + # Get the first node to extract all agent IDs + first_node = root.child + if isinstance(first_node, RolloutTreeBranchNode): + first_node = first_node.main_child + + if first_node is None: + return [] + + # All agents should be present in the first node + agents = set(first_node.step_log.action_logs.keys()) + agents.update(first_node.step_log.simulation_step_log.rewards.keys()) + + return sorted(list(agents)) + + +def gather_agent_main_rewards(agent_id: str, path: RolloutNodeList) -> List[float]: + """Gather main rewards for a specific agent in a path.""" + rewards = [] + for node in path.nodes: + reward = node.step_log.simulation_step_log.rewards[agent_id] + rewards.append(reward) + return rewards + + +def gather_all_rewards(path: RolloutNodeList) -> List[Dict[AgentId, float]]: + """Gather main rewards from main trajectory in a path.""" + rewards = [] + for node in path.nodes: + rewards.append(node.step_log.simulation_step_log.rewards.copy()) + return rewards + + +def gather_simulation_stats( + path: RolloutNodeList, + filter: Callable[[SimulationStepLog], bool], + stat_func: Callable[[SimulationStepLog], Any], +) -> List[Any]: + """Gather stats from main trajectory in a path.""" + stats = [] + for node in path.nodes: + sl = node.step_log.simulation_step_log + if filter(sl): + stats.append(stat_func(sl)) + 
return stats + + +def gather_simulation_step_logs(path: RolloutNodeList) -> List[SimulationStepLog]: + """Gather simulation information from main trajectory in a path.""" + infos = [] + for node in path.nodes: + infos.append(node.step_log.simulation_step_log) + return infos + + +def export_chat_logs(path: Path, outdir: Path): + """Process a rollout tree PKL file and generate a JSONL of chat turns as dicts. + Each line contains an object with path_id and chat_turns for a single path. + """ + import json + + root = load_rollout_tree(path) + mgid = root.id + + main_path, branch_paths = get_rollout_tree_paths(root) + all_paths = [main_path] + branch_paths + + outdir.mkdir(parents=True, exist_ok=True) + output_file = outdir / f"mgid:{mgid}_plucked_chats.render.jsonl" + + with open(output_file, "w", encoding="utf-8") as f: + for path_obj in all_paths: + chat_turns = gather_all_chat_turns_for_path(path_obj) + output_obj = { + "path_id": str(path_obj.id), + "chat_turns": list(chat_turns_to_dict(iter(chat_turns))), + } + f.write(json.dumps(output_obj, ensure_ascii=False) + "\n") + + diff --git a/src_code_for_reproducibility/utils/rollout_tree_stats.py b/src_code_for_reproducibility/utils/rollout_tree_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac3cd0e34212e7fdbeba19e501b7d96a5f128e1 --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_stats.py @@ -0,0 +1,50 @@ +from typing import Any, Callable, List, Tuple + +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.markov_games.simulation import SimulationStepLog +from mllm.utils.rollout_tree_gather_utils import ( + gather_simulation_step_logs, + get_rollout_tree_paths, +) +from mllm.utils.stat_pack import StatPack + + +def get_rollout_tree_stat_tally( + rollout_tree: RolloutTreeRootNode, + metrics: List[Callable[[SimulationStepLog], List[Tuple[str, float]]]], +) -> StatPack: + stat_tally = StatPack() + # get simulation step logs + node_list = 
get_rollout_tree_paths(rollout_tree)[0] + simulation_step_logs = gather_simulation_step_logs(node_list) + for simulation_step_log in simulation_step_logs: + for metric in metrics: + metric_result = metric(simulation_step_log) + if metric_result is not None: + for key, value in metric_result: + stat_tally.add_stat(key, value) + return stat_tally + + +def get_rollout_tree_mean_stats( + rollout_tree: RolloutTreeRootNode, metrics: List[Callable[[SimulationStepLog], Any]] +) -> StatPack: + """Get the mean stats for a rollout tree.""" + stat_tally = get_rollout_tree_stat_tally(rollout_tree, metrics) + return stat_tally.mean() + + +def get_mean_rollout_tree_stats( + rollout_trees: List[RolloutTreeRootNode], + metrics: List[Callable[[SimulationStepLog], Any]], +) -> StatPack: + """Get the mean stats for a list of rollout trees.""" + # TODO complete this + stat_tallies = [ + get_rollout_tree_mean_stats(rollout_tree, metrics) + for rollout_tree in rollout_trees + ] + mean_stat_tally = StatPack() + for stat_tally in stat_tallies: + mean_stat_tally.add_stats(stat_tally) + return mean_stat_tally.mean() diff --git a/src_code_for_reproducibility/utils/stat_pack.py b/src_code_for_reproducibility/utils/stat_pack.py new file mode 100644 index 0000000000000000000000000000000000000000..46b397139a1a8a4149030a9cc33d2b3afb7b4a12 --- /dev/null +++ b/src_code_for_reproducibility/utils/stat_pack.py @@ -0,0 +1,113 @@ +import csv +import json +import os +import pickle +from collections import Counter +from copy import deepcopy +from locale import strcoll +from statistics import mean +from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict + +import matplotlib.pyplot as plt +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +plt.style.use( + "https://raw.githubusercontent.com/dereckpiche/DedeStyle/refs/heads/main/dedestyle.mplstyle" +) + +import wandb + +from . 
import wandb_utils + + +class StatPack: + def __init__(self): + self.data = {} + + def add_stat(self, key: str, value: float | int | None): + assert ( + isinstance(value, float) or isinstance(value, int) or value is None + ), f"Value {value} is not a valid type" + if key not in self.data: + self.data[key] = [] + self.data[key].append(value) + + def add_stats(self, other: "StatPack"): + for key in other.keys(): + self.add_stat(key, other[key]) + + def __getitem__(self, key: str): + return self.data[key] + + def __setitem__(self, key: str, value: Any): + self.data[key] = value + + def __contains__(self, key: str): + return key in self.data + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + def mean(self): + mean_st = StatPack() + for key in self.keys(): + if isinstance(self[key], list): + # TODO: exclude None values + non_none_values = [v for v in self[key] if v is not None] + if non_none_values: + mean_st[key] = np.mean(np.array(non_none_values)) + else: + mean_st[key] = None + return mean_st + + def store_plots(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + plt.figure(figsize=(10, 5)) + plt.plot(self[key]) + plt.title(key) + plt.savefig(os.path.join(folder, f"{key}.pdf")) + plt.close() + + def store_numpy(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + # Sanitize filename components (avoid slashes, spaces, etc.) 
+ safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_") + values = self[key] + # Convert None to NaN for numpy compatibility + arr = np.array( + [(np.nan if (v is None) else v) for v in values], dtype=float + ) + np.save(os.path.join(folder, f"{safe_key}.npy"), arr) + + def store_json(self, folder: str, filename: str = "stats.json"): + os.makedirs(folder, exist_ok=True) + with open(os.path.join(folder, filename), "w") as f: + json.dump(self.data, f, indent=4) + + def store_csv(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + with open(os.path.join(folder, f"stats.csv"), "w") as f: + writer = csv.writer(f) + writer.writerow([key] + self[key]) + + def store_pickle(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + with open(os.path.join(folder, f"stats.pkl"), "wb") as f: + pickle.dump(self[key], f) diff --git a/src_code_for_reproducibility/utils/update_start_epoch.py b/src_code_for_reproducibility/utils/update_start_epoch.py new file mode 100644 index 0000000000000000000000000000000000000000..036ddce31b12e7a6547c5099dd37962a88055643 --- /dev/null +++ b/src_code_for_reproducibility/utils/update_start_epoch.py @@ -0,0 +1,9 @@ +import os + +# During run, set hydra.run.dir=./outputs/{folder} +def update_start_epoch(cfg, output_directory): + if cfg["experiment"]["resume_experiment"]: + folders = [f for f in os.listdir(output_directory) if f.startswith("iteration_")] + iterations = [int(f.split("_")[1]) for f in folders] if folders else [0] + cfg["experiment"]["start_epoch"] = max(iterations) + return None diff --git a/src_code_for_reproducibility/utils/wandb_utils.py b/src_code_for_reproducibility/utils/wandb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5d83ed1a5304f208288f78457582b0acdb58c4 --- /dev/null +++ b/src_code_for_reproducibility/utils/wandb_utils.py @@ -0,0 +1,164 @@ +import os +from typing import Any, Dict, Optional + + 
+_WANDB_AVAILABLE = False +_WANDB_RUN = None + + +def _try_import_wandb(): + global _WANDB_AVAILABLE + if _WANDB_AVAILABLE: + return True + try: + import wandb # type: ignore + + _WANDB_AVAILABLE = True + return True + except Exception: + _WANDB_AVAILABLE = False + return False + + +def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any: + cur: Any = cfg + for key in path: + if not isinstance(cur, dict) or key not in cur: + return default + cur = cur[key] + return cur + + +def is_enabled(cfg: Dict[str, Any]) -> bool: + return bool(_safe_get(cfg, ["logging", "wandb", "enabled"], False)) + + +def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None: + """ + Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed. + """ + global _WANDB_RUN + if not is_enabled(cfg): + return + if not _try_import_wandb(): + return + + import wandb # type: ignore + + project = _safe_get(cfg, ["logging", "wandb", "project"], "llm-negotiation") + entity = _safe_get(cfg, ["logging", "wandb", "entity"], None) + mode = _safe_get(cfg, ["logging", "wandb", "mode"], "online") + tags = _safe_get(cfg, ["logging", "wandb", "tags"], []) or [] + notes = _safe_get(cfg, ["logging", "wandb", "notes"], None) + group = _safe_get(cfg, ["logging", "wandb", "group"], None) + name = _safe_get(cfg, ["logging", "wandb", "name"], run_name) + + # Ensure files are written into the hydra run directory + os.makedirs(run_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", run_dir) + + # Convert cfg to plain types for W&B config; fallback to minimal dictionary + try: + from omegaconf import OmegaConf # type: ignore + + cfg_container = OmegaConf.to_container(cfg, resolve=True) # type: ignore + except Exception: + cfg_container = cfg + + _WANDB_RUN = wandb.init( + project=project, + entity=entity, + mode=mode, + name=name, + group=group, + tags=tags, + notes=notes, + config=cfg_container, + dir=run_dir, + reinit=True, + ) + + 
+def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None: + """Log a flat dictionary of metrics to W&B if active.""" + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + try: + import wandb # type: ignore + + wandb.log(metrics if step is None else dict(metrics, step=step)) + except Exception: + pass + + +def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None: + for k, v in data.items(): + key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + _flatten(key, v, out) + else: + out[key] = v + + +def _summarize_value(value: Any) -> Dict[str, Any]: + import numpy as np # local import to avoid hard dependency during disabled mode + + if value is None: + return {"none": 1} + # Scalars + if isinstance(value, (int, float)): + return {"value": float(value)} + # Lists or arrays + try: + arr = np.asarray(value) + if arr.size == 0: + return {"size": 0} + return { + "mean": float(np.nanmean(arr)), + "min": float(np.nanmin(arr)), + "max": float(np.nanmax(arr)), + "last": float(arr.reshape(-1)[-1]), + "size": int(arr.size), + } + except Exception: + # Fallback: string repr + return {"text": str(value)} + + +def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None: + """ + Flatten and summarize Tally.array_tally and log to WandB. + Each leaf list/array is summarized with mean/min/max/last/size. 
+ """ + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + summarized: Dict[str, Any] = {} + + def walk(node: Any, path: list[str]): + if isinstance(node, dict): + for k, v in node.items(): + walk(v, path + [k]) + return + # node is a list of values accumulated over time + key = ".".join([p for p in ([prefix] if prefix else []) + path]) + try: + summary = _summarize_value(node) + for sk, sv in summary.items(): + summarized[f"{key}.{sk}"] = sv + except Exception: + summarized[f"{key}.error"] = 1 + + walk(array_tally, []) + if summarized: + log(summarized, step=step) + + +def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None: + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + flat: Dict[str, Any] = {} + _flatten(prefix, stats, flat) + if flat: + log(flat, step=step) + +