diff --git a/src_code_for_reproducibility/docs/generate_docs.py b/src_code_for_reproducibility/docs/generate_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..e644cbbf091a97420500fb47346c07be5ed141ac --- /dev/null +++ b/src_code_for_reproducibility/docs/generate_docs.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +Script to automatically generate Sphinx documentation for all modules and build the HTML website. +""" +import importlib.util +import os +import subprocess +import sys + + +def check_and_install_dependencies(): + """Check for required dependencies and install them if missing.""" + required_packages = [ + "sphinx", + "sphinx-rtd-theme", + "sphinxcontrib-napoleon", + "sphinxcontrib-mermaid", + "sphinx-autodoc-typehints", + ] + + missing_packages = [] + + for package in required_packages: + # Convert package name to module name (replace - with _) + module_name = package.replace("-", "_") + + # Check if the package is installed + if importlib.util.find_spec(module_name) is None: + missing_packages.append(package) + + # Install missing packages + if missing_packages: + print(f"Installing missing dependencies: {', '.join(missing_packages)}") + subprocess.check_call( + [sys.executable, "-m", "pip", "install"] + missing_packages + ) + print("Dependencies installed successfully") + else: + print("All required dependencies are already installed") + + +def create_makefile(docs_dir): + """Create a Makefile for Sphinx documentation if it doesn't exist.""" + makefile_path = os.path.join(docs_dir, "Makefile") + + if os.path.exists(makefile_path): + print(f"Makefile already exists at {makefile_path}") + return + + print(f"Creating Makefile at {makefile_path}") + + makefile_content = """# Minimal makefile for Sphinx documentation + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS) +""" + + with open(makefile_path, "w") as f: + f.write(makefile_content) + + print("Makefile created successfully") + + +def create_make_bat(docs_dir): + """Create a make.bat file for Windows if it doesn't exist.""" + make_bat_path = os.path.join(docs_dir, "make.bat") + + if os.path.exists(make_bat_path): + print(f"make.bat already exists at {make_bat_path}") + return + + print(f"Creating make.bat at {make_bat_path}") + + make_bat_content = """@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd +""" + + with open(make_bat_path, "w") as f: + f.write(make_bat_content) + + print("make.bat created successfully") + + +def main(): + # Check and install required dependencies + print("=== Checking dependencies ===") + check_and_install_dependencies() + + # Get the directory of this script + script_dir = os.path.dirname(os.path.abspath(__file__)) + + # Path to the project root + project_root = os.path.dirname(script_dir) + + # Path to the source directory + source_dir = os.path.join(project_root, "src") + + # Path to the docs source directory + docs_source_dir = os.path.join(script_dir, "source") + + # Print paths for debugging + print(f"Script directory: {script_dir}") + print(f"Project root: {project_root}") + print(f"Source directory: {source_dir}") + print(f"Docs source directory: {docs_source_dir}") + + # Make sure the source directory exists + if not os.path.exists(source_dir): + print(f"Error: Source directory {source_dir} does not exist!") + sys.exit(1) + + # Make sure the docs source directory exists + if not os.path.exists(docs_source_dir): + print(f"Creating docs source directory: {docs_source_dir}") + os.makedirs(docs_source_dir) + + # Step 1: Run sphinx-apidoc to generate .rst files for all modules + print("\n=== Generating API documentation ===") + cmd = [ + "sphinx-apidoc", + "-f", # Force overwriting of existing files + "-e", # Put module documentation before submodule documentation + "-M", # Put module documentation before subpackage documentation + "-o", + docs_source_dir, # Output directory + source_dir, # Source code directory + ] + + print(f"Running command: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True) + + # Print the output of the command + print("STDOUT:") + print(result.stdout) + + print("STDERR:") + print(result.stderr) + + if result.returncode != 0: + print(f"Error: sphinx-apidoc failed with return code {result.returncode}") + sys.exit(1) + + # List the files in the docs source directory + print("\nFiles in docs/source directory:") + for file in sorted(os.listdir(docs_source_dir)): + print(f" {file}") + + print("\nDocumentation source files generated successfully!") + + # Step 2: Create Makefile and make.bat if they don't exist + create_makefile(script_dir) + create_make_bat(script_dir) + + # Step 3: Build the HTML documentation + print("\n=== Building HTML documentation ===") + + # Determine the build command based on the platform + if os.name == "nt": # Windows + build_cmd = ["make.bat", "html"] + else: # Unix/Linux/Mac + build_cmd = ["make", "html"] + + # Change to the docs directory to run the build command + os.chdir(script_dir) + + print(f"Running command: {' '.join(build_cmd)}") + build_result = subprocess.run(build_cmd, capture_output=True, text=True) + + # Print the output of the build command + print("STDOUT:") + print(build_result.stdout) + + print("STDERR:") + print(build_result.stderr) + + if build_result.returncode != 0: + print(f"Error: HTML build failed with return code {build_result.returncode}") + sys.exit(1) + + # Get the path to the built HTML documentation + html_dir = os.path.join(script_dir, "build", "html") + index_path = os.path.join(html_dir, "index.html") + + if os.path.exists(index_path): + print(f"\nHTML documentation built successfully!") + print(f"You can view it by opening: {index_path}") + + # Try to open the documentation in a browser + try: + import webbrowser + + print("\nAttempting to open documentation in your default browser...") + webbrowser.open(f"file://{index_path}") + except Exception as e: + print(f"Could not open browser automatically: {e}") + else: + print(f"\nWarning: HTML index file not found at {index_path}") + + +if __name__ == "__main__": + main() diff --git a/src_code_for_reproducibility/docs/source/conf.py b/src_code_for_reproducibility/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7512678928b6b7580c812cd62d1c22df9945ba --- /dev/null +++ b/src_code_for_reproducibility/docs/source/conf.py @@ -0,0 +1,48 @@ +# Configuration file for the Sphinx documentation builder. +import os +import sys +sys.path.insert(0, os.path.abspath('../..')) + +# -- Project information ----------------------------------------------------- +project = 'llm_negotiation' +copyright = '2023, Your Name' +author = 'Your Name' + +# -- General configuration --------------------------------------------------- +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx.ext.napoleon', + 'sphinx.ext.autosummary', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinxcontrib.mermaid', + 'sphinx_rtd_theme', +] + +templates_path = ['_templates'] +exclude_patterns = [] + +# -- Options for HTML output ------------------------------------------------- +html_theme = 'sphinx_rtd_theme' +html_static_path = ['_static'] + +# -- Napoleon settings ------------------------------------------------------- +napoleon_google_docstring = True +napoleon_numpy_docstring = False +napoleon_include_init_with_doc = True +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = False +napoleon_type_aliases = None +napoleon_attr_annotations = True + +# -- Path setup -------------------------------------------------------------- +# Make sure the project's modules can be found by Sphinx +sys.path.insert(0, os.path.abspath('../../src')) \ No newline at end of file diff --git a/src_code_for_reproducibility/docs/source/environments.rst b/src_code_for_reproducibility/docs/source/environments.rst new file mode 100644 index 0000000000000000000000000000000000000000..fa2fc4fbe9c68edfb5b1726a19a0bdf133c3f879 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/environments.rst @@ -0,0 +1,35 @@ +================= +MARL Environments +================= + +This section provides detailed documentation for the multi-agent negotiation environments included in the library. + +Each environment follows the standard interface described in :doc:`../environments` but has its own unique game rules, +dynamics, and implementation details. + +.. toctree:: + :maxdepth: 2 + :caption: Available Environments: + + environments/ipd + environments/diplomacy + environments/dond + +Overview +-------- + +The library currently includes the following environments: + +1. **Iterated Prisoner's Dilemma (IPD)**: A classic game theory problem where two agents repeatedly decide whether to cooperate or defect, with different payoffs based on their joint actions. + +2. **Diplomacy**: An adaptation of the board game Diplomacy, where seven European powers compete for control of supply centers through strategic moves and alliances. + +3. **Deal or No Deal (DOND)**: A negotiation environment based on `the paper Deal or No Deal? End-to-End Learning for Negotiation Dialogues `_ in which agents negotiate over the distribution of a set of prizes. + +Each environment documentation includes: + +- Game rules and background +- Implementation details +- API reference +- Example usage +- Advanced features and customization options \ No newline at end of file diff --git a/src_code_for_reproducibility/docs/source/installation.rst b/src_code_for_reproducibility/docs/source/installation.rst new file mode 100644 index 0000000000000000000000000000000000000000..b148f25d92fd8308e9695f7c17c2b91fb0c9a2c6 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/installation.rst @@ -0,0 +1,10 @@ +Installation +=========== + +To install the package, run: + +.. code-block:: bash + + git clone https://github.com/yourusername/llm_negotiation.git + cd llm_negotiation + pip install -e . \ No newline at end of file diff --git a/src_code_for_reproducibility/docs/source/marl_standard.rst b/src_code_for_reproducibility/docs/source/marl_standard.rst new file mode 100644 index 0000000000000000000000000000000000000000..b5ea5529892c611b34255645ec68537a236754cf --- /dev/null +++ b/src_code_for_reproducibility/docs/source/marl_standard.rst @@ -0,0 +1,141 @@ +================= +Abstract Standard for Multi-Agent Negotiation Environments +================= + +Multi-Agent Negotiation Environments require more features than gymnasium environments in order to be used as interfaces in general game running code. +The two fundamental differences between gymnasium environments and Multi-Agent Negotiation Environments are: + +1. Response from the LLM is a text action, not a discrete action. Therefore, appropriate parsing of the text is required. The model may need to be run multiple times to get the full action. + This is why we introduce the `AgentHandler` class, which is responsible for parsing the LLM's response. +2. The environment needs to be able to handle multi-agent interactions. + This is why we introduce the `NegotiationEnvironment` class, which is responsible for handling the multi-agent interactions. +3. MARL environments are complex to describe. In different contexts, the same environment may be described differently. Therefore, both the environement and the agent handlers are + responsible for describing a particular trajectory. This information is given by the `get_log_info` method. +4. There might be a lot of overlap between the neural networks used by each agent. For instance, the same model may be used for all agents. This motivates a requirement for a + policy identifier for each agent. + +Taking inspiration from the `gymnasium `_ library, we introduce a new standard for Multi-Agent Negotiation Environments. + +Our standard is based on the following features: + +Environments are of the form: + +.. code-block:: python + + class MarlEnvironment(): + + def __init__(self): + """Initialize the environment.""" + pass + + def reset(self): + """Reset the environment to an initial state and return the initial observation. + Returns: + observation (dict): A dictionary where keys are agent identifiers and values are observations. + """ + # (...) + return observation + + def step(self, actions): + """Take a step in the environment using the provided actions. + + Args: + actions (dict): A dictionary where keys are agent identifiers and values are actions. + + Returns: + observations (dict): A dictionary where keys are agent identifiers and values are observations. + reward (dict): A dictionary where keys are agent identifiers and values are rewards. + done (bool): Whether the episode has ended. + info (dict): Additional information about the environment. + """ + # (...) + return observations, done, info + + def get_log_info(self): + """Get additional information about the environment. This information is used to log the game. + Returns: + log_info (dict): Information about the environment required to log the game. + """ + # (...) + return log_info + + def render(self): + """Render the current state of the environment.""" + pass + + def close(self): + """Perform any necessary cleanup.""" + pass + + + class AgentState(): + + def __init__(self): + """Initialize the agent state.""" + pass + + def step(self, observation_from_env, policy_output=None): + """Update the agent state based on the observation and action. + The action is the output of the LLM. + """ + + Args: + observation_from_env (dict): The observation of the environment. + policy_output : The output of the policy. + + Returns: + policy_id (str): The policy identifier. + policy_input (dict): The input to the policy. + action : The official action to be sent to the environment. + done (bool): Whether the LLM action is ready to be sent to the environment. + info (dict): Additional information about the agent. + """ + # (...) + return policy_id, policy_input, action, done, info + + def get_log_info(self): + """Get information about the agent required to log a trajectory. + Returns: + log_info (dict): Information about the agent required to log a trajectory. + """ + # (...) + return log_info + + def render(self): + """Render the current state of the environment.""" + pass + + def close(self): + """Perform any necessary cleanup.""" + pass + + +Implicitely, the keys of the `observations` in the `step` method of the `MarlEnvironment` interface represent the set of agents from which an action is expected at the current step. The next step should only expect actions from the agents in the `observations` dictionary. + +As you can see, both classes have a `get_log_info` method. This method is used to log the game. It returns a dictionary with keys being the agent identifiers and values being the information to log. The reason we need this is because the environment and the agent handler may need to log different information. It makes it easier to log from the perspective of each agent. The core environment class should not need to know about the details of the agent handler. + + + +Running Environments in Parallel +-------------------------------- +This standard allows the use of the `run_batched_matches` function (TODO: link) to run environments in an efficient way. The core idea is to batch the policy calls for all agents in the environment. + +.. note:: + The ``run_batched_matches`` function allows you to run multiple negotiation games, or "matches," in parallel. + After each environment is initialized, the function continuously loops over all active matches and checks which agents + are still pending actions. Each agent's logic can require multiple calls to the policy (e.g., an LLM) before an action + becomes "ready" to be sent to the environment. (For instance, an agent might need multiple policy calls before having a string which can be parsed into a valid action.) While an agent is waiting for a policy output, these calls for all agents across all matches are grouped together by unique policy identifier and processed in batch for efficiency. This is the core functionality of the ``run_batched_matches`` function. + + Only once all actions from the required agents at a given step for an environment are ready does the function make a single ``env.step(...)`` call; this ensures + every match moves forward in lockstep for all its active agents. As soon as an environment signals it is done, the function + retrieves logged information from both the environment and the agent states before removing this match from the active set. + + If there are more matches waiting to be processed, they are then started one by one to maintain the specified degree of parallelism. + This batching approach provides an efficient mechanism to handle multi-agent or multi-policy environments, ensuring minimal + overhead and a clear, unified flow for stepping through matches. + +Here is a diagram that shows how the `run_batched_matches` function works at a high level: + +.. image:: media/runbatch.png + :alt: Alternate text for the image + :width: 1000px diff --git a/src_code_for_reproducibility/docs/source/modules.rst b/src_code_for_reproducibility/docs/source/modules.rst new file mode 100644 index 0000000000000000000000000000000000000000..e9ff8ac1a89c7bd18e69d633121f5a4022ac6fdf --- /dev/null +++ b/src_code_for_reproducibility/docs/source/modules.rst @@ -0,0 +1,7 @@ +src +=== + +.. toctree:: + :maxdepth: 4 + + src diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst new file mode 100644 index 0000000000000000000000000000000000000000..d0e595aad169a5a8456f83afe5029e7475d7c9e7 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_game.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_game module +======================================= + +.. automodule:: src.environments.dond.dond_game + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..cf96327d1bcbc7f0f8785804a49a6975eef889c2 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_log_funcs.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_log\_funcs module +============================================= + +.. automodule:: src.environments.dond.dond_log_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst new file mode 100644 index 0000000000000000000000000000000000000000..bab97f1009eb2d5c4e387ac6a83982a51e33c9e3 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_player.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_agent module +========================================= + +.. automodule:: src.environments.dond.dond_agent + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..cf31d696a3ed580e24f3c5dffd6f7a2851d16320 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.dond_training_data_funcs.rst @@ -0,0 +1,7 @@ +src.environments.dond.dond\_training\_data\_funcs module +======================================================== + +.. automodule:: src.environments.dond.dond_training_data_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.dond.rst b/src_code_for_reproducibility/docs/source/src.environments.dond.rst new file mode 100644 index 0000000000000000000000000000000000000000..8462de2bdb96d31b5628cfd2942131e13e4e9dc3 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.dond.rst @@ -0,0 +1,19 @@ +src.environments.dond package +============================= + +.. automodule:: src.environments.dond + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.environments.dond.dond_agent + src.environments.dond.dond_game + src.environments.dond.dond_log_funcs + src.environments.dond.dond_statistics_funcs + src.environments.dond.dond_training_data_funcs diff --git a/src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst b/src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst new file mode 100644 index 0000000000000000000000000000000000000000..d22c53e31cd1c7c064955900c19f34ac51c7006f --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.environment_imports.rst @@ -0,0 +1,7 @@ +src.environments.environment\_imports module +============================================ + +.. automodule:: src.environments.environment_imports + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_agent.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_agent.rst new file mode 100644 index 0000000000000000000000000000000000000000..4845b371089c529493f70de77ceaee0b7500571b --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_agent.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_agent module +====================================== + +.. automodule:: src.environments.ipd.ipd_agent + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst new file mode 100644 index 0000000000000000000000000000000000000000..ede471ef9675c780410189fcf63df0c1a05496d0 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_game.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_game module +===================================== + +.. automodule:: src.environments.ipd.ipd_game + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..edec187f4876cdf653ae4f91035f43bc877a7d40 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_log_funcs.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_log\_funcs module +=========================================== + +.. automodule:: src.environments.ipd.ipd_log_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..5f54afac07c4d477067ef4c2bf5d883b236cf5fc --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_statistics_funcs.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_statistics\_funcs module +================================================== + +.. automodule:: src.environments.ipd.ipd_statistics_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_training_data_funcs.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_training_data_funcs.rst new file mode 100644 index 0000000000000000000000000000000000000000..8e4cecea10e25644ff677416823069cd65b500c5 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.ipd_training_data_funcs.rst @@ -0,0 +1,7 @@ +src.environments.ipd.ipd\_training\_data\_funcs module +====================================================== + +.. automodule:: src.environments.ipd.ipd_training_data_funcs + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.environments.ipd.rst b/src_code_for_reproducibility/docs/source/src.environments.ipd.rst new file mode 100644 index 0000000000000000000000000000000000000000..af26091b3a87dee4d6993f0ae09bdb1c380a130e --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.ipd.rst @@ -0,0 +1,19 @@ +src.environments.ipd package +============================ + +.. automodule:: src.environments.ipd + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.environments.ipd.ipd_agent + src.environments.ipd.ipd_game + src.environments.ipd.ipd_log_funcs + src.environments.ipd.ipd_statistics_funcs + src.environments.ipd.ipd_training_data_funcs diff --git a/src_code_for_reproducibility/docs/source/src.environments.rst b/src_code_for_reproducibility/docs/source/src.environments.rst new file mode 100644 index 0000000000000000000000000000000000000000..221ed1c07ebea145cd23bc06c6474d34b1d8a33e --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.environments.rst @@ -0,0 +1,25 @@ +src.environments package +======================== + +.. automodule:: src.environments + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + src.environments.dond + src.environments.ipd + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.environments.env_imports + src.environments.environment_imports diff --git a/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst b/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst new file mode 100644 index 0000000000000000000000000000000000000000..68e0f5da020aee80cc8895ca650a6067317f4bcd --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.experiments.arithmetic_test.rst @@ -0,0 +1,7 @@ +src.experiments.arithmetic\_test module +======================================= + +.. automodule:: src.experiments.arithmetic_test + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst b/src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst new file mode 100644 index 0000000000000000000000000000000000000000..6c94e4bc508836338d5d6393858d403e746b5d2d --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.experiments.dond_run_train.rst @@ -0,0 +1,7 @@ +src.experiments.dond\_run\_train module +======================================= + +.. automodule:: src.experiments.dond_run_train + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst b/src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst new file mode 100644 index 0000000000000000000000000000000000000000..d0d0a0ccdf7839f6c192107d1ad91af4aaadda7b --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.experiments.generate_and_train.rst @@ -0,0 +1,7 @@ +src.experiments.generate\_and\_train module +=========================================== + +.. automodule:: src.experiments.generate_and_train + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst b/src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst new file mode 100644 index 0000000000000000000000000000000000000000..1b868ee566283d662a51387046bc070a131f5222 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.experiments.last_completion.rst @@ -0,0 +1,7 @@ +src.experiments.last\_completion module +======================================= + +.. automodule:: src.experiments.last_completion + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.experiments.rst b/src_code_for_reproducibility/docs/source/src.experiments.rst new file mode 100644 index 0000000000000000000000000000000000000000..90f61ff53afa0281b9e6b82188eb2df30b81eb07 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.experiments.rst @@ -0,0 +1,17 @@ +src.experiments package +======================= + +.. automodule:: src.experiments + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.experiments.arithmetic_test + src.experiments.generate_and_train + src.experiments.last_completion diff --git a/src_code_for_reproducibility/docs/source/src.generation.rst b/src_code_for_reproducibility/docs/source/src.generation.rst new file mode 100644 index 0000000000000000000000000000000000000000..14bb2b1364da7067aed5c37e3c77d091d20f011b --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.generation.rst @@ -0,0 +1,15 @@ +src.generation package +====================== + +.. automodule:: src.generation + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.generation.run_games diff --git a/src_code_for_reproducibility/docs/source/src.generation.run_games.rst b/src_code_for_reproducibility/docs/source/src.generation.run_games.rst new file mode 100644 index 0000000000000000000000000000000000000000..dbf42d3f821df187cdd4a8bb9d093839ce6b608a --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.generation.run_games.rst @@ -0,0 +1,7 @@ +src.generation.run\_games module +================================ + +.. automodule:: src.generation.run_games + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst b/src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst new file mode 100644 index 0000000000000000000000000000000000000000..937900b392b98e7f01968d496ae0c350a836d632 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.models.dummy_hf_agent.rst @@ -0,0 +1,7 @@ +src.models.dummy\_hf\_agent module +================================== + +.. automodule:: src.models.dummy_llm_agent + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst b/src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst new file mode 100644 index 0000000000000000000000000000000000000000..13b40bd388e445fa60a3c3fc2e089ad89c452dbd --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.models.dummy_local_llm.rst @@ -0,0 +1,7 @@ +src.models.dummy\_local\_llm module +=================================== + +.. automodule:: src.models.dummy_local_llm + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.models.local_llm.rst b/src_code_for_reproducibility/docs/source/src.models.local_llm.rst new file mode 100644 index 0000000000000000000000000000000000000000..5c2eebb05e64919d1915eeb63dc18f5e9a36eb2c --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.models.local_llm.rst @@ -0,0 +1,7 @@ +src.models.local\_llm module +============================ + +.. automodule:: src.models.local_llm + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.models.rst b/src_code_for_reproducibility/docs/source/src.models.rst new file mode 100644 index 0000000000000000000000000000000000000000..d03983340a5b0317354d1895df709277d5a4baed --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.models.rst @@ -0,0 +1,20 @@ +src.models package +================== + +.. automodule:: src.models + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.models.dummy_local_llm + src.models.local_llm + src.models.new_local_llm + src.models.server_llm + src.models.updatable_worker + src.models.vllm_worker_wrap diff --git a/src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst b/src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst new file mode 100644 index 0000000000000000000000000000000000000000..ee05dfbe7dd407eed4e275f525c04ca6f2c857ae --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.models.updatable_worker.rst @@ -0,0 +1,7 @@ +src.models.updatable\_worker module +=================================== + +.. automodule:: src.models.updatable_worker + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.rst b/src_code_for_reproducibility/docs/source/src.rst new file mode 100644 index 0000000000000000000000000000000000000000..d2dcff7e18f979e933a27467f4893f3ad5372a88 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.rst @@ -0,0 +1,28 @@ +src package +=========== + +.. automodule:: src + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + src.environments + src.experiments + src.generation + src.models + src.training + src.utils + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.run diff --git a/src_code_for_reproducibility/docs/source/src.training.ppo_train_value_head.rst b/src_code_for_reproducibility/docs/source/src.training.ppo_train_value_head.rst new file mode 100644 index 0000000000000000000000000000000000000000..a8d6e526eacfce47fac83a7e4617e04db56ecee2 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.training.ppo_train_value_head.rst @@ -0,0 +1,7 @@ +src.training.ppo\_train\_value\_head module +=========================================== + +.. automodule:: src.training.ppo_train_value_head + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst b/src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst new file mode 100644 index 0000000000000000000000000000000000000000..5daf4b7250022f523242d6239d0921f362df6d24 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.training.reinforce_training.rst @@ -0,0 +1,7 @@ +src.training.reinforce\_training module +======================================= + +.. automodule:: src.training.reinforce_training + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst b/src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst new file mode 100644 index 0000000000000000000000000000000000000000..cf5db1aa0cb6d010fc70f86c341467ba5e9b485e --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.training.rl_convs_processing.rst @@ -0,0 +1,7 @@ +src.training.rl\_convs\_processing module +========================================= + +.. automodule:: src.training.rl_convs_processing + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.training.rst b/src_code_for_reproducibility/docs/source/src.training.rst new file mode 100644 index 0000000000000000000000000000000000000000..50539fcda2bffa46a72eb48874a7532bf296ff27 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.training.rst @@ -0,0 +1,19 @@ +src.training package +==================== + +.. automodule:: src.training + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.training.ppo_train + src.training.ppo_train_value_head + src.training.reinforce_training + src.training.rl_convs_processing + src.training.train_main diff --git a/src_code_for_reproducibility/docs/source/src.training.train_main.rst b/src_code_for_reproducibility/docs/source/src.training.train_main.rst new file mode 100644 index 0000000000000000000000000000000000000000..838f9ac1cfb2cecd48b2e3d241e7a80c7b86fe25 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.training.train_main.rst @@ -0,0 +1,7 @@ +src.training.train\_main module +=============================== + +.. automodule:: src.training.train_main + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.utils.export_ppo_training_set.rst b/src_code_for_reproducibility/docs/source/src.utils.export_ppo_training_set.rst new file mode 100644 index 0000000000000000000000000000000000000000..d05da7df25cb21d8abece3f920c1fb4b88f34d0f --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.utils.export_ppo_training_set.rst @@ -0,0 +1,7 @@ +src.utils.export\_ppo\_training\_set module +=========================================== + +.. automodule:: src.utils.export_ppo_training_set + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst b/src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst new file mode 100644 index 0000000000000000000000000000000000000000..44b83082b6eb027ef402603e034c712ccc2cbfcc --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst @@ -0,0 +1,7 @@ +src.utils.log\_gpu\_usage module +================================ + +.. automodule:: src.utils.log_gpu_usage + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/docs/source/src.utils.rst b/src_code_for_reproducibility/docs/source/src.utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..4f5cb352cc9ec645c968d0ae99798d47c018c750 --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.utils.rst @@ -0,0 +1,24 @@ +src.utils package +================= + +.. automodule:: src.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + src.utils.common_imports + src.utils.export_ppo_training_set + src.utils.extra_stats + src.utils.inherit_args + src.utils.log_gpu_usage + src.utils.log_statistics + src.utils.model_to_cpu + src.utils.parallel_shuffle + src.utils.quick_stats + src.utils.update_start_epoch diff --git a/src_code_for_reproducibility/docs/source/src.utils.update_start_epoch.rst b/src_code_for_reproducibility/docs/source/src.utils.update_start_epoch.rst new file mode 100644 index 0000000000000000000000000000000000000000..72cbad9bd09e056213a2e4cd00a6ba624be333cb --- /dev/null +++ b/src_code_for_reproducibility/docs/source/src.utils.update_start_epoch.rst @@ -0,0 +1,7 @@ +src.utils.update\_start\_epoch module +===================================== + +.. automodule:: src.utils.update_start_epoch + :members: + :undoc-members: + :show-inheritance: diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff402e23224fc7961d9e6796c40daaf2ab4bbaa --- /dev/null +++ b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_agent.py @@ -0,0 +1,259 @@ +from typing import Dict, List, Tuple, Optional, Any +import copy + +class DiplomacyAgent: + """Agent handler for Diplomacy game that follows the MARL standard. + + This class is responsible for parsing LLM output into valid Diplomacy orders, + managing the agent state, and providing information for logging. + """ + + def __init__(self, policy_id: str, power_name: str, random_valid_move=False): + """Initialize the agent handler for a power in the Diplomacy game. + + Args: + power_name: The name of the power this agent controls (e.g., 'FRANCE', 'ENGLAND') + policy_id: The identifier for the policy this agent uses + random_valid_move: If True, will select random valid moves instead of using LLM (default: False) + """ + self.policy_id = policy_id + self.power_name = power_name + self.orders = [] + self.wait = True + self.processing_state = "WAITING_FOR_ORDERS" + self.parsed_orders = [] + self.order_status = {} + self.message_history = [] + self.random_valid_move = random_valid_move + + def step(self, observation_from_env, policy_output=None): + """Update the agent state based on the observation and LLM output. + + Args: + observation_from_env: The observation from the environment + policy_output: The output from the LLM + + Returns: + policy_id: The policy identifier + policy_input: The input to the policy + action: The official action to be sent to the environment + done: Whether the LLM action is ready to be sent to the environment + info: Additional information about the agent + """ + info = {} + + # If random_valid_move is enabled, select random valid moves + if self.random_valid_move: + valid_orders = self._select_random_valid_moves(observation_from_env) + self.orders = valid_orders + self.wait = False + action = { + "orders": valid_orders, + "wait": False + } + return self.policy_id, {}, action, True, info + + # If no policy output, this is the initial step - prepare prompt + if policy_output is None: + # Create initial prompt for the LLM + phase = observation_from_env.get('phase', '') + units = observation_from_env.get('units', {}).get(self.power_name, []) + centers = observation_from_env.get('centers', {}).get(self.power_name, []) + orderable_locations = observation_from_env.get('orderable_locations', {}) + + prompt = self._create_prompt(phase, units, centers, orderable_locations) + + return self.policy_id, {"prompt": prompt}, None, False, info + + # Process the LLM output to extract orders + success, parsed_orders = self._parse_llm_output(policy_output) + self.parsed_orders = parsed_orders + + if not success: + # Need more information from LLM + clarification_prompt = self._create_clarification_prompt(policy_output, parsed_orders) + return self.policy_id, {"prompt": clarification_prompt}, None, False, info + + # Validate if the orders are valid for the current phase + valid_orders = self._validate_orders(parsed_orders, observation_from_env) + + if valid_orders: + # Orders are valid, prepare action for environment + self.orders = valid_orders + self.wait = False + action = { + "orders": valid_orders, + "wait": False + } + return self.policy_id, {}, action, True, info + else: + # Orders are invalid, ask for new ones + error_prompt = self._create_error_prompt(parsed_orders, observation_from_env) + return self.policy_id, {"prompt": error_prompt}, None, False, info + + def _create_prompt(self, phase, units, centers, orderable_locations): + """Create the initial prompt for the LLM. + + Args: + phase: The current game phase + units: List of units controlled by this power + centers: List of supply centers controlled by this power + orderable_locations: List of locations where orders can be issued + + Returns: + A prompt string for the LLM + """ + prompt = f"You are playing as {self.power_name} in Diplomacy. The current phase is {phase}.\n\n" + prompt += f"Your units: {', '.join(units)}\n" + prompt += f"Your supply centers: {', '.join(centers)}\n" + prompt += f"Locations you can order: {', '.join(orderable_locations)}\n\n" + + if phase.endswith('M'): # Movement phase + prompt += "Please provide orders for your units in the form:\n" + prompt += "- A LON H (hold)\n" + prompt += "- F NTH - NWY (move)\n" + prompt += "- A WAL S F LON (support)\n" + prompt += "- F NWG C A NWY - EDI (convoy)\n" + elif phase.endswith('R'): # Retreat phase + prompt += "Please provide retreat orders for your dislodged units:\n" + prompt += "- A PAR R MAR (retreat to MAR)\n" + prompt += "- A PAR D (disband)\n" + elif phase.endswith('A'): # Adjustment phase + if len(units) < len(centers): + prompt += "You can build units. Please provide build orders:\n" + prompt += "- A PAR B (build army in PAR)\n" + prompt += "- F BRE B (build fleet in BRE)\n" + prompt += "- WAIVE (waive a build)\n" + elif len(units) > len(centers): + prompt += "You must remove units. Please provide disbandment orders:\n" + prompt += "- A PAR D (disband army in PAR)\n" + prompt += "- F BRE D (disband fleet in BRE)\n" + + prompt += "\nProvide your orders as a list, one per line." + return prompt + + def _parse_llm_output(self, llm_output): + """Parse the LLM output to extract orders. + + Args: + llm_output: The raw output from the LLM + + Returns: + success: Whether parsing was successful + parsed_orders: List of parsed orders + """ + # Simple parsing for now - extract lines that look like orders + lines = llm_output.strip().split('\n') + orders = [] + + for line in lines: + # Remove list markers, hyphens, etc. + line = line.strip('- *•').strip() + + # Skip empty lines and lines that don't look like orders + if not line or line.startswith('I ') or line.startswith('Let\'s'): + continue + + # Check if it looks like a Diplomacy order + if (' H' in line or ' -' in line or ' S ' in line or ' C ' in line or + ' R ' in line or ' D' in line or ' B' in line or line == 'WAIVE'): + orders.append(line) + + return len(orders) > 0, orders + + def _validate_orders(self, orders, observation): + """Validate if the orders are valid for the current phase. + + Args: + orders: List of orders to validate + observation: Current observation from the environment + + Returns: + List of valid orders or None if invalid + """ + # For simplicity, we'll assume all parsed orders are valid + # In a real implementation, we would use the game's validation logic + return orders + + def _create_clarification_prompt(self, previous_output, parsed_orders): + """Create a prompt asking for clarification when orders couldn't be parsed. + + Args: + previous_output: The previous LLM output + parsed_orders: Any orders that were successfully parsed + + Returns: + A prompt string for the LLM + """ + prompt = f"I couldn't fully understand your orders for {self.power_name}. " + + if parsed_orders: + prompt += f"I understood these orders:\n" + for order in parsed_orders: + prompt += f"- {order}\n" + + prompt += "\nPlease provide clear, valid Diplomacy orders in the format:\n" + prompt += "- A LON H\n- F NTH - NWY\n- etc.\n" + return prompt + + def _create_error_prompt(self, invalid_orders, observation): + """Create a prompt when orders are invalid. + + Args: + invalid_orders: The invalid orders + observation: Current observation from the environment + + Returns: + A prompt string for the LLM + """ + prompt = f"The following orders for {self.power_name} are invalid:\n" + for order in invalid_orders: + prompt += f"- {order}\n" + + prompt += "\nPlease provide valid orders for your units." + return prompt + + def get_log_info(self): + """Get information about the agent required to log a trajectory. + + Returns: + log_info: Information about the agent required to log a trajectory. + """ + return { + "power_name": self.power_name, + "orders": self.orders, + "wait": self.wait, + "parsing_state": self.processing_state, + "message_history": self.message_history + } + + def render(self): + """Render the current state of the agent.""" + print(f"Power: {self.power_name}") + print(f"Orders: {self.orders}") + print(f"Wait: {self.wait}") + + def close(self): + """Perform any necessary cleanup.""" + pass + + def _select_random_valid_moves(self, observation): + """Select random valid moves for all units. + + Args: + observation: Current observation from the environment + + Returns: + List of valid orders + """ + import random + + possible_orders = observation.get('possible_orders', {}) + valid_orders = [] + + # For each location with possible orders, select one randomly + for location, orders in possible_orders.items(): + if orders: # If there are any possible orders for this location + valid_orders.append(random.choice(orders)) + + return valid_orders \ No newline at end of file diff --git a/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging.py b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..3f60e82d9738c116c0b8b8d3f7818eddebb18fa2 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/diplomacy/diplomacy_logging.py @@ -0,0 +1,360 @@ +import os +import json +from utils.common_imports import * + + + +def diplomacy_log_match( + path, + agents_log_info, + env_log_info, + metrics_func=None, + metrics_func_args=None + ): + """ + Logs the Diplomacy game data and generates HTML visualizations using the get_log_info methods. + + Args: + path (str): Base path to save the data. + agents_log_info (list): List of agent information dictionaries containing the get_log_info results. + env_log_info (dict): Environment information from its get_log_info method. + metrics_func (str, optional): Name of the function to calculate metrics. + metrics_func_args (dict, optional): Arguments for the metrics function. + """ + # Create directory structure + os.makedirs(path, exist_ok=True) + + # Save the environment log info + env_log_path = os.path.join(path, "env_log.json") + with open(env_log_path, "w") as f: + json.dump(env_log_info, f, indent=4, default=_json_serialize) + + # Process each agent's log info + for agent_log in agents_log_info: + power_name = agent_log["power_name"] + + # Define paths for raw data and statistics subfolders + power_path = os.path.join(path, power_name) + raw_data_path = os.path.join(power_path, "raw_data") + statistics_path = os.path.join(power_path, "statistics") + + # Ensure directories exist + os.makedirs(raw_data_path, exist_ok=True) + os.makedirs(statistics_path, exist_ok=True) + + # Determine the next available file number for raw data + raw_files = os.listdir(raw_data_path) + raw_numbers = [int(f.split('_')[-1].split('.')[0]) for f in raw_files if f.startswith("log_")] + next_raw_number = max(raw_numbers, default=0) + 1 + raw_file = os.path.join(raw_data_path, f"log_{next_raw_number}.json") + + # Save agent log info + with open(raw_file, "w") as f: + json.dump(agent_log, f, indent=4, default=_json_serialize) + + # Log metrics if a metrics function is provided + if metrics_func: + metrics_files = os.listdir(statistics_path) + metrics_numbers = [int(f.split('_')[-1].split('.')[0]) for f in metrics_files if f.startswith("metrics_")] + next_metrics_number = max(metrics_numbers, default=0) + 1 + metrics_file = os.path.join(statistics_path, f"metrics_{next_metrics_number}.json") + + metrics = globals()[metrics_func](agent_log, info, **metrics_func_args) + with open(metrics_file, "w") as f: + json.dump(metrics, f, indent=4) + + # Generate the HTML visualization + html_content = generate_diplomacy_html(agents_log_info, env_log_info) + + # Ensure the html directory exists + html_path = os.path.join(path, "html") + os.makedirs(html_path, exist_ok=True) + + # Determine the next available file number for HTML + html_files = os.listdir(html_path) + html_numbers = [int(f.split('_')[-1].split('.')[0]) for f in html_files if f.startswith("game_summary_")] + next_html_number = max(html_numbers, default=0) + 1 + html_file = os.path.join(html_path, f"game_summary_{next_html_number}.html") + + # Save the HTML content to a file + with open(html_file, "w") as f: + f.write(html_content) + +def generate_diplomacy_html(agent_infos, env_info): + """ + Generate HTML visualization for a Diplomacy game. + + Args: + agent_infos (list): List of agent information dictionaries from get_log_info. + env_info (dict): Environment information from get_log_info. + + Returns: + str: HTML content for the game visualization. + """ + # Extract game information + game_id = env_info.get("game_id", "Unknown") + phase = env_info.get("phase", "Unknown") + map_name = env_info.get("map_name", "standard") + is_game_done = env_info.get("is_game_done", False) + outcome = env_info.get("outcome", []) + + centers = env_info.get("centers", {}) + units = env_info.get("units", {}) + + # HTML head and style + html_content = """ + + + + + + Diplomacy Game {game_id} + + + +
+

Game Information

+
+
+

Game Details

+

Game ID: {game_id}

+

Phase: {phase}

+

Map: {map_name}

+

Status: {status}

+
+
+

Supply Centers

+
+ """.format( + game_id=game_id, + phase=phase, + map_name=map_name, + status="Completed" if is_game_done else "Active" + ) + + # Add supply center information + for power, power_centers in centers.items(): + html_content += f""" +
+ {power}: {len(power_centers)} +
+ """ + + html_content += """ +
+
+
+ """ + + # Add outcome if game is done + if is_game_done and outcome: + winners = outcome[1:] if len(outcome) > 1 else ["Draw"] + html_content += f""" +
+

Game Outcome

+

Winners: {', '.join(winners)}

+
+ """ + + html_content += """ +
+
+ """ + + # Add each power's information + for agent_log in agent_infos: + power_name = agent_log["power_name"] + power_class = power_name.lower() + orders = agent_log.get("orders", []) + message_history = agent_log.get("message_history", []) + + html_content += f""" +
+
{power_name}
+ +
+

Units

+
    + """ + + # Add units information + power_units = units.get(power_name, []) + for unit in power_units: + html_content += f"
  • {unit}
  • " + + html_content += """ +
+
+ +
+
Final Orders
+
    + """ + + # Add orders + for order in orders: + html_content += f"
  • {order}
  • " + + html_content += """ +
+
+ """ + + # Add message history + for message in message_history: + if isinstance(message, dict): + # Skip system messages or handle differently + if message.get("role") == "system": + continue + + role = message.get("role", "unknown") + content = message.get("content", "") + + role_class = "user" if role == "user" else "assistant" + role_display = "Environment" if role == "user" else f"LLM ({power_name})" + + # Escape HTML characters in content + content = content.replace("<", "<").replace(">", ">").replace("\n", "
") + + html_content += f""" +
+
{role_display}
+

{content}

+
+ """ + elif isinstance(message, str): + # Simple string messages (may be used in some implementations) + html_content += f""" +
+

{message}

+
+ """ + + html_content += """ +
+ """ + + html_content += """ +
+ + + """ + + return html_content + +def _json_serialize(obj): + """ + A helper function to convert non-JSON-serializable objects + (like OrderResult) into strings or dicts. + """ + # Check for the specific object types you know are problematic + if obj.__class__.__name__ == "OrderResult": + # Return a string representation or a dict + return str(obj) + + # Fallback: attempt to convert anything else to string + return str(obj) \ No newline at end of file diff --git a/src_code_for_reproducibility/markov_games/ipd/__init__.py b/src_code_for_reproducibility/markov_games/ipd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f2388f6380fd3a54a2c80d1f1f77ae1d1fd4c8 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/ipd/__init__.py @@ -0,0 +1,7 @@ +from .Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent + +__all__ = [ + "AlwaysCooperateIPDAgent", + "AlwaysDefectIPDAgent", +] + diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd4631ea1c691af05899c36322cfb2ffb4bba9a Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/Ipd_hard_coded_agents.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f31fbc475eddd63a85d36d6ba42a5232e0b175a2 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/__init__.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b76920f80b27aa0c326141cd8fdf1166fbb9258b Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342211af722d2aef07232358e0adfcc674718e16 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a41b797fd13e8f06b9b4d308b7ab53832604ffc4 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/ipd/__pycache__/ipd_statistics.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py b/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6f64b542dcc9ee7e114e617bde9cc1181ea301 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/ipd/ipd_agent.py @@ -0,0 +1,115 @@ +import copy +import json +import random +import re +from collections.abc import Callable +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from mllm.markov_games.agent import Agent +from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn + + +@dataclass +class IPDAgentState: + """ + TOWRITE + """ + + nb_retries: int + round_nb: int + chat_counter: int + chat_history: List[ChatTurn] + + +@dataclass +class IPDAgent(Agent): + seed: int + agent_id: str + agent_name: str + policy: Callable[[List[Dict]], str] + intro_prompt: str # Introduction prompt explaining the game rules + goal_prompt: str # Prompt explaining the agent's goal + strategy_prompt: str # Prompt suggesting a strategy to the agent + max_errors: int # Maximum number of errors allowed before default action + allow_reasoning: bool # Whether to allow reasoning in the response + max_reasoning_chars: int # Maximum number of characters for reasoning + cooperate_string: str # string parsed as playing cooperate by simulation + defect_string: str # string parsed as playing defect by simulation + + def __post_init__(self): + self.state = IPDAgentState( + nb_retries=0, round_nb=0, chat_counter=0, chat_history=[] + ) + + async def act(self, observation) -> Tuple[Any, AgentActLog]: + """ + TOWRITE + """ + + action = None + action_is_ready = False + round_nb = observation.round_nb + + # If it's the first round, we need to send the intro prompt + if round_nb == 0 and self.state.chat_counter == 0: + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="user", + content=self.intro_prompt, + is_state_end=True, + ) + ) + + # If new round + if round_nb > self.state.round_nb: + coagent_action = observation.last_coagent_move + user_message = f"Last round, the other agent played {coagent_action}." + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="user", + content=user_message, + is_state_end=True, + ) + ) + + # If not new round, try to get valid action from policy + output_chat_turn: ChatTurn = await self.policy( + state=self.state.chat_history, + agent_id=self.agent_id, + regex=f"({self.cooperate_string}|{self.defect_string})", + ) + self.state.chat_history.append(output_chat_turn) + action = output_chat_turn.content + + agent_step_log = AgentActLog( + chat_turns=self.state.chat_history[self.state.chat_counter :], info=None + ) + self.state.chat_counter = len(self.state.chat_history) + self.state.round_nb = round_nb + + return action, agent_step_log + + def get_safe_copy(self): + """ + Return a safe copy of the agent. + """ + agent_copy = copy.copy(self) + agent_copy.state = copy.deepcopy(self.state) + return agent_copy + + def reset(self): + self.state = IPDAgentState() + raise NotImplementedError + + def render(self): + pass + + def close(self): + pass + + def get_agent_info(self): + pass diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_simulation.py b/src_code_for_reproducibility/markov_games/ipd/ipd_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..238c7d319bd6e679284d2636aeffa194662a664b --- /dev/null +++ b/src_code_for_reproducibility/markov_games/ipd/ipd_simulation.py @@ -0,0 +1,162 @@ +import copy +import random +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from mllm.markov_games.markov_game import Simulation +from mllm.markov_games.rollout_tree import SimulationStepLog +from mllm.utils.get_coagent_id import get_coagent_id + + +@dataclass +class IPDState: + """ + State of the Iterated Prisoner's Dilemma game. + """ + + round_nb: int = 0 + done: bool = False + last_moves: Dict[str, str] | None = None + + +@dataclass +class IPDObs: + """ + Observation in Iterated Prisoner's Dilemma game. + """ + + round_nb: int + last_coagent_move: str | None + + +class IPD(Simulation): + """ + Iterated Prisoner's Dilemma simulation following the standard. + + In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D). + The payoffs are as follows: + - If both cooperate: Both receive the "reward" (usually 3 points) + - If both defect: Both receive the "punishment" (usually 1 point) + - If one cooperates and one defects: The defector receives the "temptation" (usually 5 points) + and the cooperator receives the "sucker" payoff (usually 0 points) + + The game is played for a specified number of rounds. + """ + + def __init__( + self, + agent_ids: List[str], + agent_names: List[str], + seed: int, + rounds_per_game: int, + reward: float, # Both cooperate + punishment: float, # Both defect + temptation: float, # Defector's reward when other cooperates + sucker: float, # Cooperator's reward when other defects + cooperate_actions: List[str], + defect_actions: List[str], + ): + self.agent_ids = agent_ids + self.agent_names = agent_names + self.seed = seed + self.rounds_per_game = rounds_per_game + self.reward = reward + self.punishment = punishment + self.temptation = temptation + self.sucker = sucker + self.cooperate_actions = cooperate_actions + self.defect_actions = defect_actions + self.state = IPDState() + + def step(self, actions: Dict[str, str]) -> Tuple[bool, SimulationStepLog]: + """ + Take a step in the environment using the provided actions. + Here, the observations are just the states of the game. + + Args: + actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D'). + + Returns: + observations (dict): A dictionary where keys are agent identifiers and values are observations. + done (bool): Whether the episode has ended. + info (dict): Additional information about the environment. + """ + + # Calculate rewards using payoff matrix + agent0_action = actions[self.agent_ids[0]] + agent1_action = actions[self.agent_ids[1]] + + # Normalize actions to standard cooperate/defect/gibberish format + def normalize_action(action): + if action in self.cooperate_actions: + return "C" + elif action in self.defect_actions: + return "D" + else: + return "D" + + norm_action0 = normalize_action(agent0_action) + norm_action1 = normalize_action(agent1_action) + + payoffs = { + ("C", "C"): [self.reward, self.reward], + ("C", "D"): [self.sucker, self.temptation], + ("D", "C"): [self.temptation, self.sucker], + ("D", "D"): [self.punishment, self.punishment], + } + + round_rewards = { + self.agent_ids[0]: payoffs[(norm_action0, norm_action1)][0], + self.agent_ids[1]: payoffs[(norm_action0, norm_action1)][1], + } + + # Update game state + self.state.round_nb += 1 + self.state.last_moves = copy.deepcopy(actions) + done = self.state.round_nb >= self.rounds_per_game + step_log = SimulationStepLog( + rewards=round_rewards, + info={ + "actions": { + self.agent_ids[0]: norm_action0, + self.agent_ids[1]: norm_action1, + } + }, + ) + + return done, step_log + + def get_obs(self): + """Returns all agent observations in dict + Returns: + observations + """ + observations = {} + for agent_id in self.agent_ids: + observations[agent_id] = self.get_obs_agent(agent_id) + return observations + + def get_obs_agent(self, agent_id): + """Returns observation for agent_id""" + if self.state.last_moves != None: + other_id = get_coagent_id(self.agent_ids, agent_id) + last_coagent_move = self.state.last_moves[other_id] + else: + last_coagent_move = None + obs = IPDObs(round_nb=self.state.round_nb, last_coagent_move=last_coagent_move) + return obs + + def reset(self): + """Returns initial observations and states""" + self.state = IPDState() + return self.get_obs() + + def get_safe_copy(self): + """ + Return a safe copy of the simulation. + """ + simulation_copy = copy.copy(self) + simulation_copy.state = copy.deepcopy(self.state) + return simulation_copy diff --git a/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py b/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..8740fda6bc2550c92aef27ed9fbe7bc945be42ca --- /dev/null +++ b/src_code_for_reproducibility/markov_games/ipd/ipd_statistics.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Dict, Callable, List, Tuple + +from mllm.markov_games.rollout_tree import SimulationStepLog + + +def avg_reward(sl: SimulationStepLog) -> List[Tuple[str, float]]: + for aid in sl.rewards.keys(): + if "buffer" in str(aid) and "live" not in str(aid): + return None + # One value per agent at each step + rewards_dict = {f"reward-{aid}": float(v) for aid, v in (sl.rewards or {}).items()} + return [(key, value) for key, value in rewards_dict.items() if value is not None] + +stat_functs: list[Callable[[SimulationStepLog], List[Tuple[str, float]]]] = [ + avg_reward, +] \ No newline at end of file diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce815f643a8b1900a43c4f4aef2d526537700de7 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32a288d376668e20c6920050cc222f6a2873626b Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/dond_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fe2554c39730af2a459983df421ba98da31ffe9 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63823004458193b44378e01e1385a8a71d7fb487 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_hard_coded_policies.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..885a652a7b17e8f1bfabd33e4d4ec0c9e435bd14 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/nego_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3ff2547a0f33a2b53ee15fe6f9951f6205a0a02 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/negotiation_statistics.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7281edd1342dbd96b0d8db1fe0c254ad44e180e Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/no_press_nego_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b13dbfa24f63cd641f0f98306b862327083e3b84 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4809c83abfe2fe1ca41508d6fcc081e9f584aab5 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d411e2237230583ef9935d9ffd9e61197c1890e Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_rps_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1092f7bfe2281d3c0708e5195bcc7305179f298c Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ce7e0ea40f3cb79879f15a5f8785e43abc9518f Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simple_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbda526f1d3a02b46ba18f0b1f3eec117378953d Binary files /dev/null and b/src_code_for_reproducibility/markov_games/negotiation/__pycache__/tas_simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py b/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5bf4e3ca4ee7faa982360674e19d9eff6980dc --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py @@ -0,0 +1,242 @@ +import copy +from abc import abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import numpy as np + +from mllm.markov_games.agent import Agent +from mllm.markov_games.negotiation.nego_simulation import Message, NegotiationObs, Split +from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn + + +@dataclass +class NegotiationAgentState: + round_nb: int + nb_messages_sent_this_round: int + chat_counter: int + chat_history: List[ChatTurn] + + +class NegotiationAgent(Agent): + def __init__( + self, + seed: int, + agent_id: str, + agent_name: str, + policy: Callable[[List[Dict]], str], + goal: str, + exploration_prompts: List[str] = [], + exploration_prompt_probs: List[float] = [], + ): + self.seed = seed + self.agent_id = agent_id + self.agent_name = agent_name + self.policy = policy + self.goal = goal + self.exploration_prompts_toggled = len(exploration_prompts) > 0 + if self.exploration_prompts_toggled: + exploration_prompts = copy.deepcopy(exploration_prompts) + exploration_prompts.append(None) + self.exploration_prompts = exploration_prompts + self.exploration_prompt_probs = np.array(exploration_prompt_probs) + assert self.exploration_prompt_probs.sum() <= 1 + assert np.all(self.exploration_prompt_probs >= 0) + self.exploration_prompt_probs = np.append( + self.exploration_prompt_probs, 1 - self.exploration_prompt_probs.sum() + ) + self.state = NegotiationAgentState( + round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[] + ) + + # Implemented in variants + self.intro_prompt = "" + self.new_round_prompt = "" + self.last_round_prompt = "" + self.send_split_prompt = "" + self.wait_for_message_prompt = "" + self.last_message_prompt = "" + self.send_message_prompt = "" + + @abstractmethod + def get_message_regex(self, observation: NegotiationObs) -> str: + pass + + @abstractmethod + def get_split_regex(self, observation: NegotiationObs) -> str: + pass + + @abstractmethod + def get_split_action( + self, policy_output: str, observation: NegotiationObs + ) -> Split: + pass + + async def act(self, observation: NegotiationObs) -> Tuple[Any, AgentActLog]: + def dict_to_str(d: dict) -> str: + return ", ".join(f"{v} {k}" for k, v in d.items()) + + def dict_to_eq_str(d: dict) -> str: + return ", ".join(f"{k}={v}" for k, v in d.items()) + + is_our_turn = observation.current_agent == self.agent_id + action: Any = None + round_nb = observation.round_nb + + prompt_parts: List[str] = [] + obs_ctx = vars(observation) + obs_ctx_formmated = obs_ctx.copy() + for key in obs_ctx_formmated: + if isinstance(obs_ctx_formmated[key], dict) and "value" not in key: + obs_ctx_formmated[key] = dict_to_str(obs_ctx_formmated[key]) + elif isinstance(obs_ctx_formmated[key], dict) and "value" in key: + obs_ctx_formmated[key] = dict_to_eq_str(obs_ctx_formmated[key]) + + ####################################### + # build user prompt + ####################################### + + # First-ever call + is_intro = round_nb == 0 and self.state.chat_counter == 0 + if is_intro: + prompt_parts.append( + self.intro_prompt.format( + goal=self.goal, agent=self.agent_name, **obs_ctx_formmated + ) + ) + + # New round + is_new_round = round_nb > self.state.round_nb + if is_new_round or is_intro: + self.state.nb_messages_sent_this_round = 0 + if not is_intro: + prompt_parts.append(self.last_round_prompt.format(**obs_ctx_formmated)) + prompt_parts.append(self.new_round_prompt.format(**obs_ctx_formmated)) + if self.exploration_prompts_toggled: + exploration_prompt = self.exploration_prompts[ + np.random.choice( + len(self.exploration_prompts), p=self.exploration_prompt_probs + ) + ] + if exploration_prompt is not None: + prompt_parts.append(exploration_prompt) + self.state.round_nb = round_nb + + # Wait for message + if not is_our_turn and not observation.split_phase: + prompt_parts.append( + self.wait_for_message_prompt.format(**obs_ctx_formmated) + ) + + # Get last message + if is_our_turn and not is_new_round and not is_intro: + prompt_parts.append(self.last_message_prompt.format(**obs_ctx_formmated)) + + # Prompt to send message + must_send_message = not observation.split_phase and is_our_turn + if must_send_message: + prompt_parts.append(self.send_message_prompt.format(**obs_ctx_formmated)) + + # Prompt to give split + must_send_split = not must_send_message and observation.split_phase + if must_send_split: + var_names = ["x", "y", "z", "w"] # Extend as needed + items_str = ", ".join( + [ + f"{var_names[i]} {item}" + for i, item in enumerate(obs_ctx["quantities"].keys()) + ] + ) + ranges_str = ", ".join( + [ + f"{var_names[i]}: 0-{obs_ctx['quantities'][item]} (integer)" + for i, item in enumerate(obs_ctx["quantities"].keys()) + ] + ) + proposal_style = f"Proposal: {items_str} where {ranges_str}." + proposal_style2 = ( + f" {items_str} where {ranges_str}." + ) + prompt_parts.append( + self.send_split_prompt.format( + proposal_style=proposal_style, + proposal_style2=proposal_style2, + **obs_ctx_formmated, + ) + ) + + # Append one ChatTurn with is_state_end=True + user_prompt = "\n".join(prompt_parts) + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="user", + content=user_prompt, + is_state_end=True, + ) + ) + + ####################################### + # Get policy action + ####################################### + + # Query policy for the appropriate format + if must_send_message: + return_regex = self.get_message_regex(observation) + policy_output = await self.policy( + state=self.state.chat_history, + agent_id=self.agent_id, + regex=return_regex, + ) + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="assistant", + content=policy_output.content, + reasoning_content=policy_output.reasoning_content, + log_probs=policy_output.log_probs, + out_token_ids=policy_output.out_token_ids, + is_state_end=False, + ) + ) + action = Message(message=policy_output.content) + self.state.nb_messages_sent_this_round += 1 + + elif must_send_split: + return_regex = self.get_split_regex(observation) + policy_output = await self.policy( + state=self.state.chat_history, + agent_id=self.agent_id, + regex=return_regex, + ) + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="assistant", + content=policy_output.content, + reasoning_content=policy_output.reasoning_content, + log_probs=policy_output.log_probs, + out_token_ids=policy_output.out_token_ids, + is_state_end=False, + ) + ) + action = self.get_split_action(policy_output.content, observation) + else: + action = None + + agent_step_log = AgentActLog( + chat_turns=self.state.chat_history[self.state.chat_counter :], info=None + ) + self.state.chat_counter = len(self.state.chat_history) + return action, agent_step_log + + def get_safe_copy(self): + agent_copy = copy.copy(self) + agent_copy.state = copy.deepcopy(self.state) + return agent_copy + + def reset(self): + self.state = NegotiationAgentState( + round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[] + ) diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4d18532ab0472cd8f83414d52cf6df589fe126 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/nego_simulation.py @@ -0,0 +1,241 @@ +""" +Negotiation simulation environment +other agent is set at the start of every round. Even though current agent changes over message turns in a round. +""" +import copy +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +from numpy.random import default_rng + +from mllm.markov_games.rollout_tree import SimulationStepLog +from mllm.markov_games.simulation import Simulation +from mllm.utils.get_coagent_id import get_coagent_id + +AgentId = str + + +@dataclass +class Split: + items_given_to_self: Dict[str, int] + + +@dataclass +class Message: + message: str + + +@dataclass # gets extended by variants +class NegotiationState: + round_nb: int + last_message: str + current_agent: AgentId + quantities: Dict[str, int] + values: Dict[AgentId, Dict[str, float]] + splits: Dict[AgentId, Split | None] + nb_messages_sent: Dict[AgentId, int] + previous_values: Dict[AgentId, Dict[str, float]] | None + previous_splits: Dict[AgentId, Dict[str, int] | None] | None + previous_points: Dict[AgentId, float] | None + previous_quantities: Dict[str, int] | None + split_phase: bool + + +@dataclass # gets extended by variants +class NegotiationObs: + round_nb: int + last_message: str + quota_messages_per_agent_per_round: int + current_agent: AgentId + other_agent: str + quantities: Dict[str, int] + item_types: List[str] + value: Dict[str, int] + split_phase: bool + last_split_agent: Dict[str, int] | None + last_value_agent: Dict[str, int] | None + last_points_agent: float | None + last_split_coagent: Dict[str, int] | None + last_value_coagent: Dict[str, int] | None + last_points_coagent: float | None + last_quantities: Dict[str, int] | None + + +def compute_tas_style_rewards( + agent_ids: List[AgentId], + values: Dict[AgentId, float], + splits: Dict[AgentId, Split], + quantities: Dict[str, int], +) -> Dict[AgentId, float]: + """ + TAS-like reward computation: if sum of proposed coins exceeds max_coins, + allocate proportionally. Otherwise, use proposed amounts directly. + Rewards are quantity_kept * per-coin value for each agent. + """ + a0, a1 = agent_ids[0], agent_ids[1] + r0, r1 = 0.0, 0.0 + + for item in quantities: + max_item = quantities[item] + item_to_self_0 = int( + (splits[a0].items_given_to_self.get(item, 0)) + if splits[a0] is not None + else 0 + ) + item_to_self_1 = int( + (splits[a1].items_given_to_self.get(item, 0)) + if splits[a1] is not None + else 0 + ) + denom = max(int(max_item), item_to_self_0 + item_to_self_1) + q0 = float(max_item) * float(item_to_self_0) / float(denom) + q1 = float(max_item) * float(item_to_self_1) / float(denom) + if type(values[a0]) is not dict: + r0 += q0 * float(values[a0]) + r1 += q1 * float(values[a1]) + else: + r0 += q0 * float(values[a0][item]) + r1 += q1 * float(values[a1][item]) + return {a0: r0, a1: r1} + + +class NegotiationSimulation(Simulation): + def __init__( + self, + agent_ids: List[AgentId], + agent_names: List[str], + seed: int, + nb_of_rounds: int, + quota_messages_per_agent_per_round: int, + item_types: List[str] | None = None, + ): + self.seed = seed + self.rng = default_rng(self.seed) + self.agent_ids = list(agent_ids) + self.agent_names = agent_names + self.agent_id_to_name = { + agent_id: agent_name for agent_id, agent_name in zip(agent_ids, agent_names) + } + self.nb_of_rounds = int(nb_of_rounds) + self.quota_messages_per_agent_per_round = int( + quota_messages_per_agent_per_round + ) + if item_types is not None: + self.item_types = [item.lower() for item in item_types] + else: + self.item_types = ["coins"] + self.state: NegotiationState | None = None + self._starting_agent_index = self.rng.choice([0, 1]) + self.reset() + + def _other(self, agent_id: AgentId) -> AgentId: + return get_coagent_id(self.agent_ids, agent_id) + + @abstractmethod + def set_new_round_of_variant(self): + pass + + @abstractmethod + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + pass + + def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]: + """ + Returns terminated, step_log + """ + assert self.state is not None + current_agent = self.state.current_agent + a0, a1 = self.agent_ids[0], self.agent_ids[1] + action = actions.get(current_agent) + + # Split phase: require both splits in the same timestep + if self.state.split_phase: + action_a0 = actions.get(a0) + action_a1 = actions.get(a1) + have_both_splits = isinstance(action_a0, Split) and isinstance( + action_a1, Split + ) + if not have_both_splits: + rewards = {agent_id: 0.0 for agent_id in self.agent_ids} + return False, SimulationStepLog( + rewards=rewards, info={"type": "waiting_for_splits"} + ) + + # Record splits + self.state.splits[a0] = action_a0 + self.state.splits[a1] = action_a1 + + # Compute rewards and end round + rewards = self.get_rewards(self.state.splits) + + # Info + info = self.get_info_of_variant(self.state, actions) + + # Prepare next round + # Alternate starting agent + self.state.round_nb += 1 + self._starting_agent_index = 1 - self._starting_agent_index + self.state.current_agent = self.agent_ids[self._starting_agent_index] + self.state.previous_values = copy.deepcopy(self.state.values) + self.state.previous_splits = copy.deepcopy(self.state.splits) + self.state.previous_quantities = copy.deepcopy(self.state.quantities) + self.state.previous_points = copy.deepcopy(rewards) + self.state.last_message = "" + self.set_new_round_of_variant() # variant specific + self.state.splits = {agent_id: None for agent_id in self.agent_ids} + self.state.nb_messages_sent = {agent_id: 0 for agent_id in self.agent_ids} + is_last_timestep_in_round = True + done = self.state.round_nb >= self.nb_of_rounds + + # Message phase + elif isinstance(action, Message): + self.state.last_message = action.message + self.state.nb_messages_sent[current_agent] += 1 + + # Move turn to other agent + self.state.current_agent = self._other(current_agent) + + # If both agents have reached their message quota, enter split phase + if all( + self.state.nb_messages_sent[agent_id] + >= self.quota_messages_per_agent_per_round + for agent_id in self.agent_ids + ): + self.state.split_phase = True + is_last_timestep_in_round = False + done = False + rewards = {agent_id: 0.0 for agent_id in self.agent_ids} + info = {"type": "message"} + + info[ + "is_last_timestep_in_round" + ] = is_last_timestep_in_round # Used later to group round timesteps if needed + return done, SimulationStepLog(rewards=rewards, info=info) + + def get_obs(self): + """Returns all agent observations in dict""" + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + @abstractmethod + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + pass + + @abstractmethod + def get_obs_agent(self, agent_id): + pass + + def get_state(self): + return self.state + + def get_safe_copy(self): + """Return a safe copy of the simulation.""" + simulation_copy = copy.copy(self) + simulation_copy.state = copy.deepcopy(self.state) + return simulation_copy + + @abstractmethod + def reset(self) -> dict[AgentId, NegotiationObs]: + pass diff --git a/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a62f2465ac8b34dc09cbc003dcc663b170ffe7 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_agent.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, List, Tuple + +from mllm.markov_games.negotiation.nego_agent import ( + NegotiationAgent, + NegotiationAgentState, +) +from mllm.markov_games.negotiation.nego_simulation import Split +from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressObs +from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn + + +class NoPressAgent(NegotiationAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # No communication in this variant + self.intro_prompt = ( + "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n" + "Setup:\n" + "1. The game consists of multiple independent rounds.\n" + "2. In each round, there are multiple items to split between the two agents.\n" + "3. Both agents are assigned a per-item value between 1 and 20 (inclusive) in each round.\n" + "4. You can observe per-item values of both agents.\n" + "5. Because assignments are random, both agents are equally likely to have same expected per-item value.\n" + "\n" + "Protocol:\n" + "1. Both agents simultaneously propose the amount of each item they will keep.\n" + "2. If the total sum of proposals is less than or equal to the item quantity, both agents receive their proposed amounts.\n" + "3. If the total sum of proposals exceeds the item quantity, they are allocated proportionally.\n" + "4. Your points for the round = (amount you receive per item) x (your per-item value for that round), added across all items.\n" + "5. Points are accumulated across rounds.\n" + "Your goal: {goal}\n" + ) + self.new_round_prompt = ( + "A New Round Begins\n" + "The items to split are {quantities}.\n" + "Your per-item values are {value} and {other_agent}'s per-item values are {other_value}." + ) + self.last_round_prompt = ( + "Last Round Summary:\n" + " - Items to split: {last_quantities}\n" + " - Your per-item values: {last_value_agent}\n" + " - {other_agent}'s per-item values: {last_value_coagent}\n" + " - You proposed: {last_split_agent}\n" + " - You earned: {last_points_agent} points\n" + " - {other_agent} proposed: {last_split_coagent}\n" + " - {other_agent} earned: {last_points_coagent} points\n" + " - Round Complete.\n" + ) + self.send_split_prompt = "Submit Your Proposal\n" "Respond as {proposal_style}" + + def get_message_regex(self, observation: NoPressObs) -> str: + return r"^$" # No messages allowed + + def get_split_regex(self, observation: NoPressObs) -> str: + items = list(observation.quantities.keys()) + # Accept both singular and plural forms + item_pattern = "|".join( + [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items] + ) + regex = rf"(?i)Proposal:\s*((?:\s*(?P(10|[0-9]))\s*(?P{item_pattern})\s*,?)+)" + return regex + + def get_split_action(self, policy_output: str, observation: NoPressObs) -> Split: + items = list(observation.quantities.keys()) + import re as _re + + split_regex = self.get_split_regex(observation) + items_given_to_self = {item: 0 for item in items} + m = _re.match(split_regex, policy_output.strip()) + if m: + # Find all (number, item) pairs + item_pattern = "|".join( + [ + f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" + for item in items + ] + ) + inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})" + + def normalize_item_name(item_str): + for orig in items: + if item_str.lower() == orig.lower(): + return orig + if orig.endswith("s") and item_str.lower() == orig[:-1].lower(): + return orig + if ( + not orig.endswith("s") + and item_str.lower() == orig.lower() + "s" + ): + return orig + + for num, item in _re.findall(inner_regex, m.group(1)): + items_given_to_self[normalize_item_name(item)] = int(num) + return Split(items_given_to_self=items_given_to_self) diff --git a/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..d182187cc72c889a76f2d1c5be4b3afb6b923ed8 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py @@ -0,0 +1,168 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Tuple + +from mllm.markov_games.negotiation.nego_simulation import ( + NegotiationObs, + NegotiationSimulation, + NegotiationState, + Split, + compute_tas_style_rewards, +) + +AgentId = str + + +@dataclass +class NoPressState(NegotiationState): + pass + + +@dataclass +class NoPressObs(NegotiationObs): + other_value: Dict[str, float] + + +class NoPressSimulation(NegotiationSimulation): + def __init__( + self, + game_type: Literal["10-1-exclusive", "10-1-ties", "1-to-20"] = "1-to-20", + same_round_value: bool = True, + atleast_one_conflict: bool = False, + *args, + **kwargs, + ): + self.game_type = game_type + self.same_round_value = same_round_value + self.atleast_one_conflict = atleast_one_conflict + super().__init__(*args, **kwargs) + + def _sample_values(self) -> Dict[AgentId, dict]: + values = defaultdict(dict) + if self.state is None: + item_types = self.item_types + else: + item_types = list(self.state.quantities.keys()) + while True: + for item in item_types: + if self.game_type == "10-1-exclusive": + v = int(self.rng.choice([1, 10])) + values[self.agent_ids[0]][item] = v + values[self.agent_ids[1]][item] = 10 if v == 1 else 1 + elif self.game_type == "10-1-ties": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.choice([1, 10])) + elif self.game_type == "1-to-20": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.integers(1, 21)) + if self.atleast_one_conflict: + has_conflict = False + for item in item_types: + agent_values_for_item = [ + values[aid][item] for aid in self.agent_ids + ] + if len(set(agent_values_for_item)) > 1: + has_conflict = True + break + if not has_conflict: + continue + agent_values = [sum(v.values()) for v in values.values()] + if len(set(agent_values)) == 1 or not self.same_round_value: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {item.lower(): 10 for item in self.item_types} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = True + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = self.state.previous_splits[ + other_id + ].items_given_to_self + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self + obs = NoPressObs( + round_nb=self.state.round_nb, + last_message="", + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + other_value=self.state.values[other_id], + last_quantities=self.state.previous_quantities, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = NoPressState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=True, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..002160873969ab7292f0f62a091e12ec376022c6 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py @@ -0,0 +1,108 @@ +from mllm.markov_games.negotiation.nego_agent import NegotiationAgent +from mllm.markov_games.negotiation.nego_simulation import Split +from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs + + +class TrustAndSplitAgent(NegotiationAgent): + def __init__(self, num_message_chars, *args, **kwargs): + self.num_message_chars = num_message_chars + super().__init__(*args, **kwargs) + self.intro_prompt = ( + "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n" + "Setup:\n" + "1. The game has multiple independent rounds.\n" + "2. In each round, there are multiple items to split between the two agents.\n" + "3. Both agents are assigned a per-item value between 1 and 20 (inclusive) in each round.\n" + "4. You can only observe your own per-item values.\n" + "5. Because assignments are random, both agents are equally likely to have same expected per-item value.\n" + "\n" + "Protocol:\n" + "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n" + "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the item.\n" + " - Use this chat to communicate your private per-item value to make informed proposals.\n" + "3. After the chat, both agents simultaneously propose the amount of each item they will keep.\n" + "4. If the total sum of proposals is less than or equal to the item quantity, both agents receive their proposed amounts.\n" + "5. If the total sum of proposals exceeds the item quantity, they are allocated proportionally.\n" + "6. Your points for the round = (amount you receive per item) x (your per-item value for that round), added across all items.\n" + "7. Points are accumulated across rounds.\n" + "Your goal: {goal}\n" + ) + self.new_round_prompt = ( + "A New Round Begins\n" + "The items to split are {quantities}.\n" + "Your per-item values are {value}." + ) + self.last_round_prompt = ( + "Last Round Summary:\n" + " - Items to split: {last_quantities}\n" + " - Your per-item values: {last_value_agent}\n" + " - {other_agent}'s per-item values: {last_value_coagent}\n" + " - You proposed: {last_split_agent}\n" + " - You earned: {last_points_agent} points\n" + " - {other_agent} proposed: {last_split_coagent}\n" + " - {other_agent} earned: {last_points_coagent} points\n" + " - Round Complete.\n" + ) + self.send_split_prompt = ( + "Message quota is finished for this round.\n" + "{other_agent} has finalized their proposal.\n" + "Submit your finalization now\n" + "Respond with {proposal_style2}" + ) + # self.wait_for_message_prompt = "Wait for {other_agent} to send a message..." + self.wait_for_message_prompt = "" + self.last_message_prompt = "{other_agent} said: {last_message}" + # self.send_message_prompt = ( + # f"Send your message now (max {self.num_message_chars} chars)." + # ) + self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)." + + def get_message_regex(self, observation: TrustAndSplitObs) -> str: + return rf"[\s\S]{{0,{self.num_message_chars}}}" + + # def get_message_regex(self, observation: TrustAndSplitObs) -> str: + # return rf"(?s).{{0,{self.num_message_chars}}}" + + def get_split_regex(self, observation: TrustAndSplitObs) -> str: + items = list(observation.quantities.keys()) + # Accept both singular and plural forms + item_pattern = "|".join( + [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items] + ) + regex = rf"(?i) ?((?:\s*(?P(10|[0-9]))\s*(?P{item_pattern})\s*,?)+) ?" + return regex + + def get_split_action( + self, policy_output: str, observation: TrustAndSplitObs + ) -> Split: + items = list(observation.quantities.keys()) + import re as _re + + split_regex = self.get_split_regex(observation) + items_given_to_self = {item: 0 for item in items} + m = _re.match(split_regex, policy_output.strip()) + if m: + # Find all (number, item) pairs + item_pattern = "|".join( + [ + f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" + for item in items + ] + ) + inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})" + + def normalize_item_name(item_str): + for orig in items: + if item_str.lower() == orig.lower(): + return orig + if orig.endswith("s") and item_str.lower() == orig[:-1].lower(): + return orig + if ( + not orig.endswith("s") + and item_str.lower() == orig.lower() + "s" + ): + return orig + + for num, item in _re.findall(inner_regex, m.group(1)): + items_given_to_self[normalize_item_name(item)] = int(num) + return Split(items_given_to_self=items_given_to_self) diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..e711c2a65d336e4d9b991c68662069e96b4dfee8 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py @@ -0,0 +1,118 @@ +import copy +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +from mllm.markov_games.agent import Agent +from mllm.markov_games.negotiation.nego_agent import ( + Message, + NegotiationAgent, + NegotiationAgentState, + Split, +) +from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSObs +from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn + + +class TrustAndSplitRPSAgent(NegotiationAgent): + def __init__( + self, + num_message_chars: int, + message_start_end_format: bool = False, + proposal_start_end_format: bool = False, + *args, + **kwargs, + ): + self.num_message_chars = num_message_chars + self.message_start_end_format = message_start_end_format + self.proposal_start_end_format = proposal_start_end_format + super().__init__(*args, **kwargs) + self.intro_prompt = ( + "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n" + "\n" + "Setup:\n" + "1. The game has multiple independent rounds.\n" + "2. In each round, there are 10 coins to split between the two agents.\n" + "3. Each agent's per-coin value for that round is determined as follows:\n" + " - Both agents are randomly assigned a rock, paper or scissors hands\n" + " - Rock has the upper hand over scissors, scissors has the upper hand over paper and paper has the upper hand over rock.\n" + " - The agent with the upper hand has a per-coin value of 10.\n" + " - The agent with the lower hand has a per-coin value of 1.\n" + "4. You only see your own hand, but you may communicate it in messages and infer your value based on the other agent's hand.\n" + "5. Over many rounds both agents are equally likely to have the upper and lower hand.\n" + "\n" + "Protocol:\n" + "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n" + "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the 10 coins.\n" + " - Use this chat to communicate your hand so that both agents can determine their per-coin values.\n" + "3. After the chat, both agents simultaneously propose how many coins they keep.\n" + "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n" + "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n" + "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n" + "7. The points are accumulated across rounds.\n" + "Your goal: {goal}\n" + ) + self.new_round_prompt = ( + "A New Round Begins\n" + "Your hand is {hand}. You don't know {other_agent}'s hand yet.\n" + ) + # self.last_round_prompt = ( + # "Last Round Summary:\n" + # " - Your hand: {last_hand_agent}\n" + # " - {other_agent}'s hand: {last_hand_coagent}\n" + # " - Your value per coin: {last_value_agent}\n" + # " - {other_agent}'s value per coin: {last_value_coagent}\n" + # " - You proposed: {last_split_agent} coins\n" + # " - You earned: {last_points_agent} points\n" + # " - {other_agent} proposed: {last_split_coagent} coins\n" + # " - {other_agent} earned: {last_points_coagent} points\n" + # " - Round Complete.\n" + # ) + self.last_round_prompt = "In the previous round, {other_agent} had a {last_hand_value_coagent} hand and proposed {last_split_coagent} coins.\n" + if self.proposal_start_end_format: + self.send_split_prompt = ( + "Submit your proposal\n" + "Respond with <> x <> where x is an integer in [0, 10]." + ) + else: + self.send_split_prompt = ( + "Submit your proposal\n" + "Respond with x where x is an integer in [0, 10]." + ) + self.wait_for_message_prompt = "Wait for {other_agent} to send a message..." + # self.wait_for_message_prompt = "" + self.last_message_prompt = "{other_agent} said: {last_message}" + if self.message_start_end_format: + self.send_message_prompt = f"Send your message now in <>...<> (<={self.num_message_chars} chars)." + else: + self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)." + + def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str: + if self.message_start_end_format: + return ( + rf"<>[\s\S]{{0,{self.num_message_chars}}}<>" + ) + else: + return rf"[\s\S]{{0,{self.num_message_chars}}}" + + def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str: + if self.proposal_start_end_format: + return r"<> ?(10|[0-9]) ?<>" + else: + return r" ?(10|[0-9]) ?" + + def get_split_action( + self, policy_output: str, observation: TrustAndSplitRPSObs + ) -> Split: + import re as _re + + if self.proposal_start_end_format: + m = _re.search( + r"<> ?(10|[0-9]) ?<>", policy_output + ) + else: + m = _re.search( + r" ?(10|[0-9]) ?", policy_output + ) + coins_int = int(m.group(1)) if m else int(policy_output) + return Split(items_given_to_self={"coins": coins_int}) diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simple_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..e4439b53a04e8efe4553cb1aa0d85459a6e90c9d --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_agent.py @@ -0,0 +1,90 @@ +from mllm.markov_games.negotiation.nego_agent import NegotiationAgent +from mllm.markov_games.negotiation.nego_simulation import Split +from mllm.markov_games.negotiation.tas_simple_simulation import TrustAndSplitSimpleObs + + +class TrustAndSplitSimpleAgent(NegotiationAgent): + def __init__( + self, + num_message_chars, + message_start_end_format: bool = False, + proposal_start_end_format: bool = False, + *args, + **kwargs, + ): + self.num_message_chars = num_message_chars + self.message_start_end_format = message_start_end_format + self.proposal_start_end_format = proposal_start_end_format + super().__init__(*args, **kwargs) + self.intro_prompt = ( + "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n" + "Setup:\n" + "1. The game has multiple independent rounds.\n" + "2. In each round, there are 10 coins to split between the two agents.\n" + "3. Both agents are assigned a per-coin value between 1 and 10 (inclusive) in each round.\n" + "4. You can only observe your own per-coin value.\n" + "5. Because assignments are random, both agents are equally likely to have same expected per-coin value.\n" + "\n" + "Protocol:\n" + "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n" + "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the coins.\n" + " - Use this chat to communicate your private per-coin value to make informed proposals.\n" + "3. After the chat, both agents simultaneously propose how many coins they keep.\n" + "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n" + "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n" + "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n" + "7. Points are accumulated across rounds.\n" + "Your goal: {goal}\n" + ) + self.new_round_prompt = ( + "A New Round Begins\n" + "Your per-coin value is {value}. You don't know {other_agent}'s value yet.\n" + ) + self.last_round_prompt = "In the previous round, {other_agent} had a {last_value_str_coagent} value and proposed {last_split_coagent} coins.\n" + if self.proposal_start_end_format: + self.send_split_prompt = ( + "Submit your proposal\n" + "Respond with <> x <> where x is an integer in [0, 10]." + ) + else: + self.send_split_prompt = ( + "Submit your proposal\n" + "Respond with x where x is an integer in [0, 10]." + ) + self.wait_for_message_prompt = "Wait for {other_agent} to send a message..." + # self.wait_for_message_prompt = "" + self.last_message_prompt = "{other_agent} said: {last_message}" + if self.message_start_end_format: + self.send_message_prompt = f"Send your message now in <>...<> (<={self.num_message_chars} chars)." + else: + self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)." + + def get_message_regex(self, observation: TrustAndSplitSimpleObs) -> str: + if self.message_start_end_format: + return ( + rf"<>[\s\S]{{0,{self.num_message_chars}}}<>" + ) + else: + return rf"[\s\S]{{0,{self.num_message_chars}}}" + + def get_split_regex(self, observation: TrustAndSplitSimpleObs) -> str: + if self.proposal_start_end_format: + return r"<> ?(10|[0-9]) ?<>" + else: + return r" ?(10|[0-9]) ?" + + def get_split_action( + self, policy_output: str, observation: TrustAndSplitSimpleObs + ) -> Split: + import re as _re + + if self.proposal_start_end_format: + m = _re.search( + r"<> ?(10|[0-9]) ?<>", policy_output + ) + else: + m = _re.search( + r" ?(10|[0-9]) ?", policy_output + ) + coins_int = int(m.group(1)) if m else int(policy_output) + return Split(items_given_to_self={"coins": coins_int}) diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbd0c43d73e3f7b18204b62e71d72b2df1d13e6 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py @@ -0,0 +1,169 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal + +from numpy.random import default_rng + +from mllm.markov_games.negotiation.nego_simulation import ( + NegotiationObs, + NegotiationSimulation, + NegotiationState, + Split, + compute_tas_style_rewards, +) + +AgentId = str + + +@dataclass +class TrustAndSplitSimpleState(NegotiationState): + pass + + +@dataclass +class TrustAndSplitSimpleObs(NegotiationObs): + last_value_str_coagent: str | None + + +class TrustAndSplitSimpleSimulation(NegotiationSimulation): + def __init__( + self, + game_type: Literal["10-1-exclusive", "1-to-10"] = "1-to-10", + dist_type: Literal["uniform", "bimodal"] = "uniform", + beta_dist_alpha: float = 0.1, + beta_dist_beta: float = 0.1, + *args, + **kwargs, + ): + self.game_type = game_type + self.dist_type = dist_type + self.beta_dist_alpha = beta_dist_alpha + self.beta_dist_beta = beta_dist_beta + super().__init__(*args, **kwargs) + + def _sample_values(self) -> Dict[AgentId, dict]: + values = {} + while True: + if self.game_type == "10-1-exclusive": + v = int(self.rng.choice([1, 10])) + values[self.agent_ids[0]] = v + values[self.agent_ids[1]] = 10 if v == 1 else 1 + elif self.game_type == "1-to-10": + for aid in self.agent_ids: + if self.dist_type == "uniform": + values[aid] = int(self.rng.integers(1, 11)) + elif self.dist_type == "bimodal": + alpha, beta = self.beta_dist_alpha, self.beta_dist_beta + values[aid] = int(round(self.rng.beta(alpha, beta) * 9) + 1) + if len(set(values.values())) != 1: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {"coins": 10} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = False + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + # "previous_values": copy.deepcopy(state.previous_values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = self.state.previous_splits[ + other_id + ].items_given_to_self["coins"] + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[ + "coins" + ] + if last_value_agent is None or last_value_coagent is None: + last_value_str_coagent = None + else: + if last_value_coagent > last_value_agent: + last_value_str_coagent = "higher" + elif last_value_coagent < last_value_agent: + last_value_str_coagent = "lower" + else: + raise ValueError("Should not be equal values") + + obs = TrustAndSplitSimpleObs( + round_nb=self.state.round_nb, + last_message=self.state.last_message, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + last_quantities=self.state.previous_quantities, + last_value_str_coagent=last_value_str_coagent, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = TrustAndSplitSimpleState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..5499a146e9da491757a8105965b2d210f8327134 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py @@ -0,0 +1,172 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal + +from numpy.random import default_rng + +from mllm.markov_games.negotiation.nego_simulation import ( + NegotiationObs, + NegotiationSimulation, + NegotiationState, + Split, + compute_tas_style_rewards, +) + +AgentId = str + + +@dataclass +class TrustAndSplitState(NegotiationState): + pass + + +@dataclass +class TrustAndSplitObs(NegotiationObs): + pass + + +class TrustAndSplitSimulation(NegotiationSimulation): + def __init__( + self, + game_type: Literal["10-1-exclusive", "10-1-ties", "1-to-20"] = "1-to-20", + same_round_value: bool = True, + atleast_one_conflict: bool = False, + *args, + **kwargs, + ): + self.game_type = game_type + self.same_round_value = same_round_value + self.atleast_one_conflict = atleast_one_conflict + super().__init__(*args, **kwargs) + + def _sample_values(self) -> Dict[AgentId, dict]: + values = defaultdict(dict) + if self.state is None: + item_types = self.item_types + else: + item_types = list(self.state.quantities.keys()) + while True: + for item in item_types: + if self.game_type == "10-1-exclusive": + v = int(self.rng.choice([1, 10])) + values[self.agent_ids[0]][item] = v + values[self.agent_ids[1]][item] = 10 if v == 1 else 1 + elif self.game_type == "10-1-ties": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.choice([1, 10])) + elif self.game_type == "1-to-20": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.integers(1, 21)) + agent_values = [sum(v.values()) for v in values.values()] + if self.atleast_one_conflict: + has_conflict = False + for item in item_types: + agent_values_for_item = [ + values[aid][item] for aid in self.agent_ids + ] + if ( + len(set(agent_values_for_item)) > 1 + ): # Different values for this item + has_conflict = True + break + if not has_conflict: + continue + if len(set(agent_values)) == 1 or not self.same_round_value: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {item.lower(): 10 for item in self.item_types} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = False + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + # "previous_values": copy.deepcopy(state.previous_values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = self.state.previous_splits[ + other_id + ].items_given_to_self + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self + obs = TrustAndSplitObs( + round_nb=self.state.round_nb, + last_message=self.state.last_message, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + last_quantities=self.state.previous_quantities, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = TrustAndSplitState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/utils/wandb_utils.py b/src_code_for_reproducibility/utils/wandb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5d83ed1a5304f208288f78457582b0acdb58c4 --- /dev/null +++ b/src_code_for_reproducibility/utils/wandb_utils.py @@ -0,0 +1,164 @@ +import os +from typing import Any, Dict, Optional + + +_WANDB_AVAILABLE = False +_WANDB_RUN = None + + +def _try_import_wandb(): + global _WANDB_AVAILABLE + if _WANDB_AVAILABLE: + return True + try: + import wandb # type: ignore + + _WANDB_AVAILABLE = True + return True + except Exception: + _WANDB_AVAILABLE = False + return False + + +def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any: + cur: Any = cfg + for key in path: + if not isinstance(cur, dict) or key not in cur: + return default + cur = cur[key] + return cur + + +def is_enabled(cfg: Dict[str, Any]) -> bool: + return bool(_safe_get(cfg, ["logging", "wandb", "enabled"], False)) + + +def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None: + """ + Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed. + """ + global _WANDB_RUN + if not is_enabled(cfg): + return + if not _try_import_wandb(): + return + + import wandb # type: ignore + + project = _safe_get(cfg, ["logging", "wandb", "project"], "llm-negotiation") + entity = _safe_get(cfg, ["logging", "wandb", "entity"], None) + mode = _safe_get(cfg, ["logging", "wandb", "mode"], "online") + tags = _safe_get(cfg, ["logging", "wandb", "tags"], []) or [] + notes = _safe_get(cfg, ["logging", "wandb", "notes"], None) + group = _safe_get(cfg, ["logging", "wandb", "group"], None) + name = _safe_get(cfg, ["logging", "wandb", "name"], run_name) + + # Ensure files are written into the hydra run directory + os.makedirs(run_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", run_dir) + + # Convert cfg to plain types for W&B config; fallback to minimal dictionary + try: + from omegaconf import OmegaConf # type: ignore + + cfg_container = OmegaConf.to_container(cfg, resolve=True) # type: ignore + except Exception: + cfg_container = cfg + + _WANDB_RUN = wandb.init( + project=project, + entity=entity, + mode=mode, + name=name, + group=group, + tags=tags, + notes=notes, + config=cfg_container, + dir=run_dir, + reinit=True, + ) + + +def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None: + """Log a flat dictionary of metrics to W&B if active.""" + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + try: + import wandb # type: ignore + + wandb.log(metrics if step is None else dict(metrics, step=step)) + except Exception: + pass + + +def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None: + for k, v in data.items(): + key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + _flatten(key, v, out) + else: + out[key] = v + + +def _summarize_value(value: Any) -> Dict[str, Any]: + import numpy as np # local import to avoid hard dependency during disabled mode + + if value is None: + return {"none": 1} + # Scalars + if isinstance(value, (int, float)): + return {"value": float(value)} + # Lists or arrays + try: + arr = np.asarray(value) + if arr.size == 0: + return {"size": 0} + return { + "mean": float(np.nanmean(arr)), + "min": float(np.nanmin(arr)), + "max": float(np.nanmax(arr)), + "last": float(arr.reshape(-1)[-1]), + "size": int(arr.size), + } + except Exception: + # Fallback: string repr + return {"text": str(value)} + + +def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None: + """ + Flatten and summarize Tally.array_tally and log to WandB. + Each leaf list/array is summarized with mean/min/max/last/size. + """ + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + summarized: Dict[str, Any] = {} + + def walk(node: Any, path: list[str]): + if isinstance(node, dict): + for k, v in node.items(): + walk(v, path + [k]) + return + # node is a list of values accumulated over time + key = ".".join([p for p in ([prefix] if prefix else []) + path]) + try: + summary = _summarize_value(node) + for sk, sv in summary.items(): + summarized[f"{key}.{sk}"] = sv + except Exception: + summarized[f"{key}.error"] = 1 + + walk(array_tally, []) + if summarized: + log(summarized, step=step) + + +def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None: + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + flat: Dict[str, Any] = {} + _flatten(prefix, stats, flat) + if flat: + log(flat, step=step) + +